In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, Subset
# from S6_model import get_model, save_model
from Newmodel import get_model, save_model
from tqdm import tqdm
import os
import argparse

In [2]:
def get_transform(apply_augmentation=False):
    """Get data transformation pipeline with enhanced augmentation"""
    transforms_list = [
        transforms.ToTensor(),  # Convert to tensor first
        transforms.Normalize((0.1307,), (0.3081,))
    ]
    
    if apply_augmentation:
        transforms_list = [
            transforms.RandomAffine(
                degrees=10,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                fill=0,
            ),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        ] + transforms_list  # Add base transforms after augmentation
        
        # Add RandomErasing after converting to tensor
        transforms_list.append(transforms.RandomErasing(p=0.1))
    
    return transforms.Compose(transforms_list)

In [3]:
def setup_directories():
    """Create necessary directories"""
    if not os.path.exists('models'):
        os.makedirs('models')
        print("Created 'models' directory for saving checkpoints")


In [4]:
def setup_device():
    """Set up and return the device to use"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    return device

In [5]:
def load_data(use_augmentation, batch_size=128):
    """Load and prepare data loaders"""
    train_transform = get_transform(apply_augmentation=use_augmentation)
    test_transform = get_transform(apply_augmentation=False)
    
    full_train_dataset = datasets.MNIST('./data', train=True, download=True, transform=train_transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=test_transform)
    
    train_indices = list(range(50000))
    val_indices = list(range(50000, 60000))
    
    train_dataset = Subset(full_train_dataset, train_indices)
    val_dataset = Subset(
        datasets.MNIST('./data', train=True, download=False, transform=test_transform),
        val_indices
    )
    
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader, len(train_dataset), len(val_dataset), len(test_dataset)

In [6]:
def print_training_config(use_augmentation, initial_lr):
    """Print training configuration details"""
    print("\n=== Training Configuration ===")
    print(f"Initial Learning Rate: {initial_lr}")
    print(f"Optimizer: Adam (betas=(0.9, 0.999), eps=1e-08)")
    print(f"Learning Rate Scheduler: ReduceLROnPlateau")
    print(f" - mode: max (tracking validation accuracy)")
    print(f" - factor: 0.1")
    print(f" - patience: 3 epochs")
    print(f" - min_lr: 1e-6")
    
    print("\n=== Data Augmentation Settings ===")
    if use_augmentation:
        print("Data Augmentation: Enabled for training")
        print(" - Random rotation: ±10 degrees")
        print(" - Random zoom: ±10%")
        print(" - Random shift: ±10% horizontal and vertical")
    else:
        print("Data Augmentation: Disabled")

In [7]:
def train_epoch(model, train_loader, optimizer, criterion, device, scheduler):
    """Train for one epoch"""
    model.train()
    train_correct = 0
    train_total = 0
    train_loss = 0.0
    
    train_pbar = tqdm(train_loader, desc=f'Training (lr={scheduler.get_last_lr()[0]:.6f})', leave=False)
    for batch_idx, (data, target) in enumerate(train_pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        train_total += target.size(0)
        train_correct += (predicted == target).sum().item()
        
        train_pbar.set_postfix({
            'loss': f'{train_loss/(batch_idx+1):.4f}',
            'acc': f'{100.*train_correct/train_total:.2f}%',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })
    
    return train_loss/len(train_loader), 100 * train_correct / train_total

In [8]:
def test(model, test_loader, criterion, device):
    """Test the model"""
    model.eval()
    test_correct = 0
    test_total = 0
    test_loss = 0.0
    
    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc='Testing')
        for batch_idx, (data, target) in enumerate(test_pbar):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            test_loss += loss.item()
            
            _, predicted = torch.max(output.data, 1)
            test_total += target.size(0)
            test_correct += (predicted == target).sum().item()
            
            test_pbar.set_postfix({
                'loss': f'{test_loss/(batch_idx+1):.4f}',
                'acc': f'{100.*test_correct/test_total:.2f}%'
            })
    
    return test_loss/len(test_loader), 100 * test_correct / test_total

In [9]:
def validate(model, val_loader, criterion, device):
    """Validate the model"""
    model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc='Validation', leave=False)
        for batch_idx, (data, target) in enumerate(val_pbar):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item()
            
            _, predicted = torch.max(output.data, 1)
            val_total += target.size(0)
            val_correct += (predicted == target).sum().item()
            
            val_pbar.set_postfix({
                'loss': f'{val_loss/(batch_idx+1):.4f}',
                'acc': f'{100.*val_correct/val_total:.2f}%'
            })
    
    return val_loss/len(val_loader), 100 * val_correct / val_total

In [10]:
def get_lr(epoch, initial_lr=0.003):
    """Custom learning rate scheduler"""
    return round(initial_lr * 1/(1 + 0.319 * epoch), 10)

In [11]:
def train(use_augmentation=True):
    """Main training function"""
    BATCH_SIZE = 128
    initial_lr = 0.001  # Initial learning rate, 0.003 for SGD
    
    print("\n=== Initializing Training Pipeline ===")
    setup_directories()
    device = setup_device()
    
    print("\n=== Preparing Data ===")
    train_loader, val_loader, test_loader, train_size, val_size, test_size = load_data(
        use_augmentation, BATCH_SIZE
    )
    
    print("\n=== Dataset Statistics ===")
    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Test samples: {test_size}")
    
    print_training_config(use_augmentation, initial_lr)
    
    model = get_model().to(device)
    #optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
    optimizer = optim.Adam(
        model.parameters(),
        lr=initial_lr,
        betas=(0.9, 0.999),  # default Adam parameters
        eps=1e-08,           # default numerical stability constant
        weight_decay=0       # L2 penalty (if needed)
    )
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    
    # ReduceLROnPlateau scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',           # Since we're tracking validation accuracy
        factor=0.1,          # Reduce LR by factor of 10
        patience=3,          # Number of epochs with no improvement after which LR will be reduced
        #verbose=True,        # Print message when LR is reduced
        min_lr=1e-6,        # Minimum LR
        threshold=0.001,     # Minimum change to qualify as an improvement
        threshold_mode='rel' # Relative change
    )
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    
    print("\n=== Starting Training ===")
    best_accuracy = 0.0
    
    for epoch in range(20):
        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nEpoch Summary:")
        print(f"Epoch {epoch+1}/20 (LR={current_lr:.6f}):")
        
        # Create a simple scheduler wrapper for train_epoch function
        class SimpleScheduler:
            def get_last_lr(self):
                return [current_lr]
        
        temp_scheduler = SimpleScheduler()
        
        train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, criterion, device, temp_scheduler)
        val_loss, val_accuracy = validate(model, val_loader, criterion, device)
        
        print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}% | "
              f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
        
        # Step the scheduler with validation accuracy
        scheduler.step(val_accuracy)
        
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            model_filename = save_model(model, val_accuracy)
            print(f"✓ New best model saved as {model_filename}")
        
        if val_accuracy >= 99.4:
            print("\n🎉 Reached target validation accuracy of 99.4%!")
            break
    
    print("\n=== Final Evaluation ===")
    test_loss, test_accuracy = test(model, test_loader, criterion, device)
    print(f"\nFinal Test Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    print("\n=== Training Complete ===")

In [13]:
train(use_augmentation=False)


=== Initializing Training Pipeline ===
Using device: cuda

=== Preparing Data ===

=== Dataset Statistics ===
Training samples: 50000
Validation samples: 10000
Test samples: 10000

=== Training Configuration ===
Initial Learning Rate: 0.001
Optimizer: Adam (betas=(0.9, 0.999), eps=1e-08)
Learning Rate Scheduler: ReduceLROnPlateau
 - mode: max (tracking validation accuracy)
 - factor: 0.1
 - patience: 3 epochs
 - min_lr: 1e-6

=== Data Augmentation Settings ===
Data Augmentation: Disabled
Total parameters: 10,624

=== Starting Training ===

Epoch Summary:
Epoch 1/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.9077, Training Accuracy: 90.88% | Validation Loss: 0.7657, Validation Accuracy: 96.46%
✓ New best model saved as mnist_model_96.46_20241129_092814.pth

Epoch Summary:
Epoch 2/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.6478, Training Accuracy: 97.41% | Validation Loss: 0.6656, Validation Accuracy: 97.60%
✓ New best model saved as mnist_model_97.60_20241129_092832.pth

Epoch Summary:
Epoch 3/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.6029, Training Accuracy: 98.04% | Validation Loss: 0.6158, Validation Accuracy: 97.98%
✓ New best model saved as mnist_model_97.98_20241129_092850.pth

Epoch Summary:
Epoch 4/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5856, Training Accuracy: 98.36% | Validation Loss: 0.5849, Validation Accuracy: 98.71%
✓ New best model saved as mnist_model_98.71_20241129_092908.pth

Epoch Summary:
Epoch 5/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5769, Training Accuracy: 98.64% | Validation Loss: 0.5634, Validation Accuracy: 98.87%
✓ New best model saved as mnist_model_98.87_20241129_092925.pth

Epoch Summary:
Epoch 6/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5712, Training Accuracy: 98.75% | Validation Loss: 0.5612, Validation Accuracy: 99.00%
✓ New best model saved as mnist_model_99.00_20241129_092943.pth

Epoch Summary:
Epoch 7/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5682, Training Accuracy: 98.88% | Validation Loss: 0.5635, Validation Accuracy: 98.97%

Epoch Summary:
Epoch 8/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5659, Training Accuracy: 98.88% | Validation Loss: 0.5573, Validation Accuracy: 99.00%

Epoch Summary:
Epoch 9/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5628, Training Accuracy: 99.00% | Validation Loss: 0.5537, Validation Accuracy: 99.03%
✓ New best model saved as mnist_model_99.03_20241129_093036.pth

Epoch Summary:
Epoch 10/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5610, Training Accuracy: 99.00% | Validation Loss: 0.5492, Validation Accuracy: 99.23%
✓ New best model saved as mnist_model_99.23_20241129_093054.pth

Epoch Summary:
Epoch 11/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5588, Training Accuracy: 99.05% | Validation Loss: 0.5518, Validation Accuracy: 99.16%

Epoch Summary:
Epoch 12/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5577, Training Accuracy: 99.07% | Validation Loss: 0.5464, Validation Accuracy: 99.20%

Epoch Summary:
Epoch 13/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5566, Training Accuracy: 99.12% | Validation Loss: 0.5506, Validation Accuracy: 99.14%

Epoch Summary:
Epoch 14/20 (LR=0.001000):


                                                                                                                       

Training Loss: 0.5547, Training Accuracy: 99.12% | Validation Loss: 0.5456, Validation Accuracy: 99.22%

Epoch Summary:
Epoch 15/20 (LR=0.000100):


                                                                                                                       

Training Loss: 0.5507, Training Accuracy: 99.27% | Validation Loss: 0.5453, Validation Accuracy: 99.25%
✓ New best model saved as mnist_model_99.25_20241129_093223.pth

Epoch Summary:
Epoch 16/20 (LR=0.000100):


                                                                                                                       

Training Loss: 0.5495, Training Accuracy: 99.35% | Validation Loss: 0.5420, Validation Accuracy: 99.30%
✓ New best model saved as mnist_model_99.30_20241129_093241.pth

Epoch Summary:
Epoch 17/20 (LR=0.000100):


                                                                                                                       

Training Loss: 0.5499, Training Accuracy: 99.32% | Validation Loss: 0.5434, Validation Accuracy: 99.28%

Epoch Summary:
Epoch 18/20 (LR=0.000100):


                                                                                                                       

Training Loss: 0.5492, Training Accuracy: 99.28% | Validation Loss: 0.5440, Validation Accuracy: 99.29%

Epoch Summary:
Epoch 19/20 (LR=0.000010):


                                                                                                                       

Training Loss: 0.5476, Training Accuracy: 99.38% | Validation Loss: 0.5429, Validation Accuracy: 99.29%

Epoch Summary:
Epoch 20/20 (LR=0.000010):


                                                                                                                       

Training Loss: 0.5487, Training Accuracy: 99.36% | Validation Loss: 0.5416, Validation Accuracy: 99.29%

=== Final Evaluation ===


Testing: 100%|████████████████████████████████████████████████| 79/79 [00:02<00:00, 30.62it/s, loss=0.5415, acc=99.41%]



Final Test Results:
Test Loss: 0.5415
Test Accuracy: 99.41%

=== Training Complete ===
