# Read dataset and create data loaders

In [36]:
# Import torch and CIFAR dataset
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F

# Import matplotlib and numpy for graphs
import matplotlib.pyplot as plt
import numpy as np


In [37]:
'''
Import CIFAR dataset, define labbels and load training and validation dataset
Reference for loading dataset: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
Reference for augmentation: https://pytorch.org/vision/stable/transforms.html
'''
batch_size=64 
print('Batch size:', batch_size)

# Normalisation and std values for RGB in dataset
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Data augmentation for training set
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # Randomly crop the image with padding
    transforms.RandomHorizontalFlip(),    # Randomly flip the image horizontally
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Adjust brightness, contrast, etc.
    transforms.RandomRotation(15),        # Randomly rotate the image by up to 15 degrees
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),  # Randomly translate the image
    transforms.ToTensor(),                # Convert image to tensor
    transforms.Normalize(mean=mean, std=std),  # Normalize with mean and std
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.3))  # Randomly erase a portion of the image (optional)
])

# No augmentation for test set (only normalization)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)  # Normalize with mean and std
])

# Load training and testing datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Define labels
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'lorry')

Batch size: 64
Files already downloaded and verified
Files already downloaded and verified


In [38]:
# # From the PyTorch's tutorial on image classification
# import matplotlib.pyplot as plt
# import numpy as np

# def imshow(img):
#     '''
#     Show an image
#     Input: image file to show
#     Output: image
#     '''
#     img = img / 2 + 0.5     # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()

# # Get random training images
# dataiter = iter(trainloader)
# images, labels = next(dataiter)

# # Show images
# imshow(torchvision.utils.make_grid(images))
# # Print labels
# print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

# Main model
Divided as such:


*   **Stem**: takes the images as inputs, extracts features from them
*   **Backbone**: made up of *K* branches, made up of an expert branch
*   **Classifier**: takes input from the last block
*   **Model**: wraps all together







## Stem
*   Takes images as inputs
*   Extracts a feature representation from them

In [39]:
class Stem(nn.Module):
  '''
  Extract features using a Resnet-18 stem
  Reference: Week 09 Lab
  '''
  def __init__(self, input_channels, middle_channels, output_channels):
     super(Stem,self).__init__()
     # Default parameters
     kernel_size=3
     stride=1
     padding=1
     
     # Combine multiple layers
     self.stem = nn.Sequential(
       nn.Conv2d(input_channels, middle_channels, kernel_size = kernel_size, stride = stride, padding = padding),
       nn.BatchNorm2d(middle_channels), 
       nn.ReLU(inplace=True),
       nn.Conv2d(middle_channels, middle_channels,kernel_size = kernel_size, stride = stride, padding = padding),
       nn.BatchNorm2d(middle_channels),
       nn.ReLU(inplace=True),
       nn.MaxPool2d(2), # Half the size of the image
       nn.Conv2d(middle_channels, output_channels, kernel_size = kernel_size, stride = stride, padding = padding),
       nn.BatchNorm2d(output_channels),
       nn.ReLU(inplace=True),
       nn.MaxPool2d(2) # Half the size of the image
       )

  def forward(self,x):
    x = self.stem(x)
    return x

## Block

In [40]:
class ExpertBranch(nn.Module):
  '''
  Expert branch predicting vector a with K elements from input tensor X
  '''
  def __init__(self, input_channels, k, r):
    super(ExpertBranch,self).__init__()
    # Spatially pool x
    self.pool= nn.AdaptiveAvgPool2d(1)
    #Forward through fc1, reducing by r
    self.fc1= nn.Linear(input_channels, input_channels//r)
    # Activation function ReLu
    self.relu= nn.ReLU()
    # Forward through fc2
    self.fc2= nn.Linear(input_channels//r,k)

  def forward(self,x):
    # Spatially pool X
    x = self.pool(x)
    # Forward through fc1, reducing by r
    x= x.squeeze(-1).squeeze(-1)
    x = self.fc1(x)
    # Processed through non-linear activation g
    x = F.relu(x)
    # Pass through fc2
    x = self.fc2(x)
    # Forward with softmax
    x = F.softmax(x,dim=1)
    return x

In [41]:
class Block(nn.Module):
  '''
  Block
  '''
  def __init__(self, input_channels, output_channels, k, r):
    super(Block, self).__init__()
    # Default parameters
    kernel_size=3
    stride=1
    padding=1
    # Set parameters
    self.k= k
    self.expertBranch = ExpertBranch(input_channels, k=k, r=r)
    # Input from first block
    # Input from previous block for rest
    # Generate vector a with K elements from X as a= E(X)
    # Create K convolutional layers
    self.convs= nn.ModuleList([
        nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride= stride, padding=padding)
        for _ in range(k)
    ])

  def forward(self,x):
    identity= x
    # Vector a from expert branch
    a = self.expertBranch(x)
    # Convolutional layers 
    conv_outputs = [conv(x) for conv in self.convs]
    stacked = torch.stack(conv_outputs, dim=1)
    # Create vector O
    a= a.view(a.size(0), self.k, 1,1,1)

    out = (a* stacked).sum(dim=1)
    # Skip connection to stablise gradient descent
    out += identity
    out = F.relu(out) # activation after skip

    return out

## Backbone

In [42]:
class Backbone(nn.Module):
  '''
  N blocks
  '''
  def __init__(self, input_channels, hidden_channels, num_blocks, k, r):
    super(Backbone, self).__init__()
    self.blocks= nn.ModuleList()

    # First block takes input from stem
    self.blocks.append(Block(input_channels, hidden_channels, k=k, r=r))

    # Rest of blocks take input form previous block
    for _ in range(1, num_blocks):
      self.blocks.append(Block(hidden_channels, hidden_channels, k=k, r=r))

  def forward(self, x):
    for idx, block in enumerate(self.blocks):
      x = block(x)
    return x

## Classifier

In [43]:
class Classifier(nn.Module):
  def __init__(self, input_channels, num_classes, use_mlp):
    super(Classifier,self).__init__()
    # Default parameters
    dropout_rate=0.25
    # Spatially pool
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.use_mlp= use_mlp

    if use_mlp:
      self.classifier= nn.Sequential(
          nn.Linear(input_channels, input_channels*2),
          nn.ReLU(),
          nn.Dropout(dropout_rate), # Deeper network with 3 layers
          nn.Linear(input_channels*2, input_channels),
          nn.ReLU(),
          nn.Dropout(dropout_rate),
          nn.Linear(input_channels, num_classes)
      )
    else:
      self.classifier= nn.Linear(input_channels, num_classes)

  def forward(self, x):
    x = self.pool(x).squeeze(-1).squeeze(-1)
    out = self.classifier(x)
    return out


# Model

In [44]:
class Model(nn.Module):
  def __init__(self, input_channels, output_channels, middle_channels, hidden_channels, num_blocks, k, r, num_classes, use_mlp):
    super(Model, self).__init__()
    # Call stem
    self.stem= Stem(
      input_channels=input_channels,
      middle_channels=middle_channels,
      output_channels=output_channels
    )
    # Call backbone
    self.backbone= Backbone(
      input_channels=output_channels, 
      hidden_channels= hidden_channels, 
      num_blocks=num_blocks,
      k=k, 
      r=r)
    # Call classifier
    self.classifier= Classifier(
      input_channels=hidden_channels, 
      num_classes=num_classes,
      use_mlp= use_mlp)

  def forward(self,x):
    x= self.stem(x)
    x= self.backbone(x)
    x= self.classifier(x)
    return x

# Create the loss and optmiser


In [None]:
model = Model(
    input_channels=3,
    output_channels=256,
    middle_channels=64,
    hidden_channels=256,
    num_blocks=6,
    k=4,
    r=4,
    num_classes=10,
    use_mlp=True
)

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

model.apply(init_weights)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# optimizer = optim.SGD(model.parameters(), lr=0.0001, weight_decay=1e-4, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Training & Testing

In [46]:
# Set up device 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Save model
model.to(device)

# Log training 
train_losses, val_losses = [], []
train_accuracies = []
val_accuracies = []

# Training and Validation Loops 
def train(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return running_loss / len(loader), 100 * correct / total

def evaluate(model, loader, criterion, device):
    model.eval()
    total = 0
    correct = 0
    loss = 0.0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validating", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss += criterion(outputs, labels).item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return loss / len(loader), 100 * correct / total

# Main Loop 
# patience = 20  # Number of epochs to wait for improvement
early_stop_counter = 0 # Counter for early stopping
epochs = 200
best_acc = 0.0

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, testloader, criterion, device)


    # Log metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Accuracy: {train_acc:.2f}%")
    print(f"Val   Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        early_stop_counter=0
        torch.save(model.state_dict(), "best_model.pth")
        print("Saved best model.")
    else:
        early_stop_counter += 1
        print(f"No improvement for {early_stop_counter} epochs.")

    # if early_stop_counter >= patience:
    #     print(f"Early stopping triggered after {epoch+1} epochs.")
    #     break
print("\nTraining Complete")

# Print Final Averages 
avg_train_loss = sum(train_losses) / len(train_losses)
avg_val_loss = sum(val_losses) / len(val_losses)
avg_train_acc = sum(train_accuracies) / len(train_accuracies)
avg_val_acc = sum(val_accuracies) / len(val_accuracies)

print("\nFinal Averages Over All Epochs")
print(f"Average Train Loss: {avg_train_loss:.4f}")
print(f"Average Train Accuracy: {avg_train_acc:.2f}%")
print(f"Average Val   Loss: {avg_val_loss:.4f}")
print(f"Average Val   Accuracy: {avg_val_acc:.2f}%")


# Plot results

# Plot Loss
plt.figure()
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title("Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.savefig("loss_curve.png")

# Plot Accuracy
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.title("Accuracy Curve")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.grid()
plt.savefig("accuracy_curve.png")

print("Plots saved: loss_curve.png and accuracy_curve.png")



Epoch 1/200


                                                              

Train Loss: 2.0396 | Accuracy: 24.85%
Val   Loss: 1.7676 | Accuracy: 33.54%
Saved best model.

Epoch 2/200


                                                              

Train Loss: 1.7498 | Accuracy: 34.85%
Val   Loss: 1.5071 | Accuracy: 43.76%
Saved best model.

Epoch 3/200


                                                             

Train Loss: 1.6050 | Accuracy: 41.10%
Val   Loss: 1.3659 | Accuracy: 49.36%
Saved best model.

Epoch 4/200


                                                             

Train Loss: 1.4891 | Accuracy: 45.83%
Val   Loss: 1.2985 | Accuracy: 52.42%
Saved best model.

Epoch 5/200


                                                             

Train Loss: 1.3958 | Accuracy: 49.56%
Val   Loss: 1.2659 | Accuracy: 54.90%
Saved best model.

Epoch 6/200


                                                             

Train Loss: 1.3235 | Accuracy: 52.52%
Val   Loss: 1.1097 | Accuracy: 59.46%
Saved best model.

Epoch 7/200


                                                             

Train Loss: 1.2693 | Accuracy: 55.04%
Val   Loss: 1.0885 | Accuracy: 60.87%
Saved best model.

Epoch 8/200


                                                             

Train Loss: 1.2190 | Accuracy: 56.45%
Val   Loss: 1.0369 | Accuracy: 64.03%
Saved best model.

Epoch 9/200


                                                             

Train Loss: 1.1686 | Accuracy: 58.73%
Val   Loss: 0.9775 | Accuracy: 65.44%
Saved best model.

Epoch 10/200


                                                             

Train Loss: 1.1354 | Accuracy: 59.97%
Val   Loss: 0.9117 | Accuracy: 67.75%
Saved best model.

Epoch 11/200


                                                             

Train Loss: 1.1005 | Accuracy: 61.45%
Val   Loss: 0.8212 | Accuracy: 71.11%
Saved best model.

Epoch 12/200


                                                             

Train Loss: 1.0675 | Accuracy: 62.52%
Val   Loss: 0.8320 | Accuracy: 71.11%
No improvement for 1 epochs.

Epoch 13/200


                                                             

Train Loss: 1.0302 | Accuracy: 63.94%
Val   Loss: 0.8116 | Accuracy: 71.63%
Saved best model.

Epoch 14/200


                                                             

Train Loss: 0.9998 | Accuracy: 64.94%
Val   Loss: 0.7568 | Accuracy: 73.23%
Saved best model.

Epoch 15/200


                                                             

Train Loss: 0.9727 | Accuracy: 66.35%
Val   Loss: 0.7922 | Accuracy: 72.43%
No improvement for 1 epochs.

Epoch 16/200


                                                             

Train Loss: 0.9527 | Accuracy: 66.85%
Val   Loss: 0.7160 | Accuracy: 74.93%
Saved best model.

Epoch 17/200


                                                             

Train Loss: 0.9272 | Accuracy: 67.66%
Val   Loss: 0.6931 | Accuracy: 75.38%
Saved best model.

Epoch 18/200


                                                             

Train Loss: 0.9078 | Accuracy: 68.69%
Val   Loss: 0.6726 | Accuracy: 76.60%
Saved best model.

Epoch 19/200


                                                             

Train Loss: 0.8890 | Accuracy: 69.10%
Val   Loss: 0.7249 | Accuracy: 74.64%
No improvement for 1 epochs.

Epoch 20/200


                                                             

Train Loss: 0.8717 | Accuracy: 69.88%
Val   Loss: 0.6726 | Accuracy: 76.49%
No improvement for 2 epochs.

Epoch 21/200


                                                             

Train Loss: 0.8444 | Accuracy: 70.94%
Val   Loss: 0.6674 | Accuracy: 76.26%
No improvement for 3 epochs.

Epoch 22/200


                                                             

Train Loss: 0.8352 | Accuracy: 71.20%
Val   Loss: 0.6455 | Accuracy: 77.55%
Saved best model.

Epoch 23/200


                                                             

Train Loss: 0.8150 | Accuracy: 71.79%
Val   Loss: 0.6004 | Accuracy: 79.30%
Saved best model.

Epoch 24/200


                                                             

Train Loss: 0.8039 | Accuracy: 72.37%
Val   Loss: 0.5660 | Accuracy: 80.37%
Saved best model.

Epoch 25/200


                                                             

Train Loss: 0.7857 | Accuracy: 72.93%
Val   Loss: 0.5718 | Accuracy: 79.92%
No improvement for 1 epochs.

Epoch 26/200


                                                             

Train Loss: 0.7729 | Accuracy: 73.54%
Val   Loss: 0.5579 | Accuracy: 80.54%
Saved best model.

Epoch 27/200


                                                             

Train Loss: 0.7604 | Accuracy: 73.76%
Val   Loss: 0.5560 | Accuracy: 80.92%
Saved best model.

Epoch 28/200


                                                             

Train Loss: 0.7509 | Accuracy: 74.09%
Val   Loss: 0.5573 | Accuracy: 80.51%
No improvement for 1 epochs.

Epoch 29/200


                                                             

Train Loss: 0.7451 | Accuracy: 74.06%
Val   Loss: 0.5207 | Accuracy: 81.97%
Saved best model.

Epoch 30/200


                                                             

Train Loss: 0.7309 | Accuracy: 74.90%
Val   Loss: 0.5177 | Accuracy: 81.97%
No improvement for 1 epochs.

Epoch 31/200


                                                             

Train Loss: 0.7194 | Accuracy: 75.27%
Val   Loss: 0.5105 | Accuracy: 82.49%
Saved best model.

Epoch 32/200


                                                             

Train Loss: 0.7066 | Accuracy: 75.69%
Val   Loss: 0.5379 | Accuracy: 81.12%
No improvement for 1 epochs.

Epoch 33/200


                                                             

Train Loss: 0.7066 | Accuracy: 75.33%
Val   Loss: 0.5102 | Accuracy: 82.45%
No improvement for 2 epochs.

Epoch 34/200


                                                             

Train Loss: 0.6918 | Accuracy: 76.09%
Val   Loss: 0.4966 | Accuracy: 82.57%
Saved best model.

Epoch 35/200


                                                             

Train Loss: 0.6804 | Accuracy: 76.59%
Val   Loss: 0.5447 | Accuracy: 81.53%
No improvement for 1 epochs.

Epoch 36/200


                                                             

Train Loss: 0.6735 | Accuracy: 76.86%
Val   Loss: 0.4970 | Accuracy: 83.18%
Saved best model.

Epoch 37/200


                                                             

Train Loss: 0.6647 | Accuracy: 77.09%
Val   Loss: 0.5388 | Accuracy: 81.58%
No improvement for 1 epochs.

Epoch 38/200


                                                             

Train Loss: 0.6572 | Accuracy: 77.34%
Val   Loss: 0.4913 | Accuracy: 83.38%
Saved best model.

Epoch 39/200


                                                             

Train Loss: 0.6509 | Accuracy: 77.58%
Val   Loss: 0.5078 | Accuracy: 82.60%
No improvement for 1 epochs.

Epoch 40/200


                                                             

Train Loss: 0.6420 | Accuracy: 77.99%
Val   Loss: 0.4850 | Accuracy: 84.19%
Saved best model.

Epoch 41/200


                                                             

Train Loss: 0.6329 | Accuracy: 78.26%
Val   Loss: 0.4594 | Accuracy: 84.62%
Saved best model.

Epoch 42/200


                                                             

Train Loss: 0.6264 | Accuracy: 78.36%
Val   Loss: 0.4754 | Accuracy: 83.82%
No improvement for 1 epochs.

Epoch 43/200


                                                             

Train Loss: 0.6115 | Accuracy: 78.84%
Val   Loss: 0.4493 | Accuracy: 84.70%
Saved best model.

Epoch 44/200


                                                             

Train Loss: 0.6130 | Accuracy: 78.93%
Val   Loss: 0.4410 | Accuracy: 85.10%
Saved best model.

Epoch 45/200


                                                             

Train Loss: 0.6064 | Accuracy: 79.14%
Val   Loss: 0.4460 | Accuracy: 85.11%
Saved best model.

Epoch 46/200


                                                             

Train Loss: 0.5917 | Accuracy: 79.61%
Val   Loss: 0.4363 | Accuracy: 85.36%
Saved best model.

Epoch 47/200


                                                             

Train Loss: 0.5932 | Accuracy: 79.56%
Val   Loss: 0.4329 | Accuracy: 85.47%
Saved best model.

Epoch 48/200


                                                             

Train Loss: 0.5811 | Accuracy: 79.97%
Val   Loss: 0.4406 | Accuracy: 84.94%
No improvement for 1 epochs.

Epoch 49/200


                                                             

Train Loss: 0.5815 | Accuracy: 80.04%
Val   Loss: 0.4623 | Accuracy: 84.30%
No improvement for 2 epochs.

Epoch 50/200


                                                             

Train Loss: 0.5752 | Accuracy: 79.99%
Val   Loss: 0.4414 | Accuracy: 85.29%
No improvement for 3 epochs.

Epoch 51/200


                                                             

Train Loss: 0.5763 | Accuracy: 80.18%
Val   Loss: 0.4434 | Accuracy: 85.32%
No improvement for 4 epochs.

Epoch 52/200


                                                             

Train Loss: 0.5621 | Accuracy: 80.84%
Val   Loss: 0.4481 | Accuracy: 84.76%
No improvement for 5 epochs.

Epoch 53/200


                                                             

Train Loss: 0.5614 | Accuracy: 80.65%
Val   Loss: 0.4367 | Accuracy: 85.06%
No improvement for 6 epochs.

Epoch 54/200


                                                             

Train Loss: 0.5584 | Accuracy: 80.72%
Val   Loss: 0.4282 | Accuracy: 85.71%
Saved best model.

Epoch 55/200


                                                             

Train Loss: 0.5472 | Accuracy: 81.24%
Val   Loss: 0.4186 | Accuracy: 86.04%
Saved best model.

Epoch 56/200


                                                             

Train Loss: 0.5426 | Accuracy: 81.23%
Val   Loss: 0.4106 | Accuracy: 86.31%
Saved best model.

Epoch 57/200


                                                             

Train Loss: 0.5330 | Accuracy: 81.57%
Val   Loss: 0.4336 | Accuracy: 85.75%
No improvement for 1 epochs.

Epoch 58/200


                                                             

Train Loss: 0.5289 | Accuracy: 81.66%
Val   Loss: 0.4403 | Accuracy: 85.71%
No improvement for 2 epochs.

Epoch 59/200


                                                             

Train Loss: 0.5323 | Accuracy: 81.70%
Val   Loss: 0.4251 | Accuracy: 85.83%
No improvement for 3 epochs.

Epoch 60/200


                                                             

Train Loss: 0.5193 | Accuracy: 82.10%
Val   Loss: 0.4121 | Accuracy: 86.27%
No improvement for 4 epochs.

Epoch 61/200


                                                             

Train Loss: 0.5187 | Accuracy: 82.07%
Val   Loss: 0.4057 | Accuracy: 86.43%
Saved best model.

Epoch 62/200


                                                             

Train Loss: 0.5160 | Accuracy: 82.26%
Val   Loss: 0.4370 | Accuracy: 85.56%
No improvement for 1 epochs.

Epoch 63/200


                                                             

Train Loss: 0.5058 | Accuracy: 82.53%
Val   Loss: 0.4128 | Accuracy: 86.45%
Saved best model.

Epoch 64/200


                                                             

Train Loss: 0.5066 | Accuracy: 82.45%
Val   Loss: 0.3984 | Accuracy: 86.47%
Saved best model.

Epoch 65/200


                                                             

Train Loss: 0.4972 | Accuracy: 82.71%
Val   Loss: 0.4265 | Accuracy: 86.09%
No improvement for 1 epochs.

Epoch 66/200


                                                             

Train Loss: 0.4978 | Accuracy: 82.80%
Val   Loss: 0.4118 | Accuracy: 86.62%
Saved best model.

Epoch 67/200


                                                             

Train Loss: 0.4951 | Accuracy: 83.00%
Val   Loss: 0.3950 | Accuracy: 87.11%
Saved best model.

Epoch 68/200


                                                             

Train Loss: 0.4857 | Accuracy: 83.34%
Val   Loss: 0.4017 | Accuracy: 86.89%
No improvement for 1 epochs.

Epoch 69/200


                                                             

Train Loss: 0.4825 | Accuracy: 83.32%
Val   Loss: 0.4368 | Accuracy: 85.74%
No improvement for 2 epochs.

Epoch 70/200


                                                             

Train Loss: 0.4834 | Accuracy: 83.33%
Val   Loss: 0.4104 | Accuracy: 86.34%
No improvement for 3 epochs.

Epoch 71/200


                                                             

Train Loss: 0.4744 | Accuracy: 83.53%
Val   Loss: 0.3958 | Accuracy: 86.98%
No improvement for 4 epochs.

Epoch 72/200


                                                             

Train Loss: 0.4705 | Accuracy: 83.84%
Val   Loss: 0.4179 | Accuracy: 86.19%
No improvement for 5 epochs.

Epoch 73/200


                                                             

Train Loss: 0.4643 | Accuracy: 83.88%
Val   Loss: 0.3882 | Accuracy: 87.10%
No improvement for 6 epochs.

Epoch 74/200


                                                             

Train Loss: 0.4537 | Accuracy: 84.41%
Val   Loss: 0.3931 | Accuracy: 87.51%
Saved best model.

Epoch 75/200


                                                             

Train Loss: 0.4543 | Accuracy: 84.33%
Val   Loss: 0.4146 | Accuracy: 86.53%
No improvement for 1 epochs.

Epoch 76/200


                                                             

Train Loss: 0.4577 | Accuracy: 84.15%
Val   Loss: 0.4236 | Accuracy: 86.49%
No improvement for 2 epochs.

Epoch 77/200


                                                             

Train Loss: 0.4540 | Accuracy: 84.34%
Val   Loss: 0.3996 | Accuracy: 87.10%
No improvement for 3 epochs.

Epoch 78/200


                                                             

Train Loss: 0.4479 | Accuracy: 84.65%
Val   Loss: 0.3851 | Accuracy: 87.56%
Saved best model.

Epoch 79/200


                                                             

Train Loss: 0.4464 | Accuracy: 84.67%
Val   Loss: 0.4050 | Accuracy: 86.78%
No improvement for 1 epochs.

Epoch 80/200


                                                             

Train Loss: 0.4362 | Accuracy: 84.84%
Val   Loss: 0.3920 | Accuracy: 87.22%
No improvement for 2 epochs.

Epoch 81/200


                                                             

Train Loss: 0.4471 | Accuracy: 84.60%
Val   Loss: 0.4086 | Accuracy: 87.00%
No improvement for 3 epochs.

Epoch 82/200


                                                             

Train Loss: 0.4356 | Accuracy: 84.89%
Val   Loss: 0.3891 | Accuracy: 87.34%
No improvement for 4 epochs.

Epoch 83/200


                                                             

Train Loss: 0.4362 | Accuracy: 84.97%
Val   Loss: 0.4045 | Accuracy: 87.20%
No improvement for 5 epochs.

Epoch 84/200


                                                             

Train Loss: 0.4271 | Accuracy: 85.36%
Val   Loss: 0.3995 | Accuracy: 87.18%
No improvement for 6 epochs.

Epoch 85/200


                                                             

Train Loss: 0.4246 | Accuracy: 85.30%
Val   Loss: 0.3947 | Accuracy: 87.64%
Saved best model.

Epoch 86/200


                                                             

Train Loss: 0.4250 | Accuracy: 85.30%
Val   Loss: 0.3919 | Accuracy: 87.52%
No improvement for 1 epochs.

Epoch 87/200


                                                             

Train Loss: 0.4152 | Accuracy: 85.78%
Val   Loss: 0.3901 | Accuracy: 87.69%
Saved best model.

Epoch 88/200


                                                             

Train Loss: 0.4188 | Accuracy: 85.58%
Val   Loss: 0.3867 | Accuracy: 87.60%
No improvement for 1 epochs.

Epoch 89/200


                                                             

Train Loss: 0.4105 | Accuracy: 85.96%
Val   Loss: 0.4204 | Accuracy: 86.85%
No improvement for 2 epochs.

Epoch 90/200


                                                             

Train Loss: 0.4080 | Accuracy: 85.90%
Val   Loss: 0.3898 | Accuracy: 87.83%
Saved best model.

Epoch 91/200


                                                             

Train Loss: 0.4082 | Accuracy: 85.95%
Val   Loss: 0.4036 | Accuracy: 87.41%
No improvement for 1 epochs.

Epoch 92/200


                                                             

Train Loss: 0.4065 | Accuracy: 85.82%
Val   Loss: 0.3891 | Accuracy: 86.95%
No improvement for 2 epochs.

Epoch 93/200


                                                             

Train Loss: 0.4037 | Accuracy: 86.15%
Val   Loss: 0.3740 | Accuracy: 88.08%
Saved best model.

Epoch 94/200


                                                             

Train Loss: 0.3993 | Accuracy: 85.92%
Val   Loss: 0.3759 | Accuracy: 87.74%
No improvement for 1 epochs.

Epoch 95/200


                                                             

Train Loss: 0.3974 | Accuracy: 86.42%
Val   Loss: 0.3935 | Accuracy: 87.64%
No improvement for 2 epochs.

Epoch 96/200


                                                             

Train Loss: 0.3934 | Accuracy: 86.18%
Val   Loss: 0.3925 | Accuracy: 87.77%
No improvement for 3 epochs.

Epoch 97/200


                                                             

Train Loss: 0.3885 | Accuracy: 86.63%
Val   Loss: 0.3891 | Accuracy: 87.76%
No improvement for 4 epochs.

Epoch 98/200


                                                             

Train Loss: 0.3885 | Accuracy: 86.58%
Val   Loss: 0.4115 | Accuracy: 87.44%
No improvement for 5 epochs.

Epoch 99/200


                                                             

Train Loss: 0.3833 | Accuracy: 86.72%
Val   Loss: 0.4027 | Accuracy: 87.24%
No improvement for 6 epochs.

Epoch 100/200


                                                             

Train Loss: 0.3854 | Accuracy: 86.62%
Val   Loss: 0.3967 | Accuracy: 88.07%
No improvement for 7 epochs.

Epoch 101/200


                                                             

Train Loss: 0.3811 | Accuracy: 86.72%
Val   Loss: 0.4222 | Accuracy: 87.03%
No improvement for 8 epochs.

Epoch 102/200


                                                             

Train Loss: 0.3767 | Accuracy: 86.82%
Val   Loss: 0.4062 | Accuracy: 87.86%
No improvement for 9 epochs.

Epoch 103/200


                                                             

Train Loss: 0.3800 | Accuracy: 86.83%
Val   Loss: 0.4049 | Accuracy: 87.62%
No improvement for 10 epochs.

Epoch 104/200


                                                             

Train Loss: 0.3738 | Accuracy: 87.04%
Val   Loss: 0.3906 | Accuracy: 87.83%
No improvement for 11 epochs.

Epoch 105/200


                                                             

Train Loss: 0.3764 | Accuracy: 86.93%
Val   Loss: 0.3974 | Accuracy: 87.64%
No improvement for 12 epochs.

Epoch 106/200


                                                             

Train Loss: 0.3704 | Accuracy: 87.21%
Val   Loss: 0.3942 | Accuracy: 88.37%
Saved best model.

Epoch 107/200


                                                             

Train Loss: 0.3705 | Accuracy: 87.15%
Val   Loss: 0.3893 | Accuracy: 88.10%
No improvement for 1 epochs.

Epoch 108/200


                                                             

Train Loss: 0.3698 | Accuracy: 87.17%
Val   Loss: 0.3681 | Accuracy: 88.52%
Saved best model.

Epoch 109/200


                                                             

Train Loss: 0.3610 | Accuracy: 87.36%
Val   Loss: 0.4587 | Accuracy: 86.36%
No improvement for 1 epochs.

Epoch 110/200


                                                             

Train Loss: 0.3568 | Accuracy: 87.60%
Val   Loss: 0.3810 | Accuracy: 88.08%
No improvement for 2 epochs.

Epoch 111/200


                                                             

Train Loss: 0.3621 | Accuracy: 87.44%
Val   Loss: 0.4158 | Accuracy: 87.17%
No improvement for 3 epochs.

Epoch 112/200


                                                             

Train Loss: 0.3620 | Accuracy: 87.45%
Val   Loss: 0.3891 | Accuracy: 87.97%
No improvement for 4 epochs.

Epoch 113/200


                                                             

Train Loss: 0.3562 | Accuracy: 87.73%
Val   Loss: 0.3886 | Accuracy: 87.48%
No improvement for 5 epochs.

Epoch 114/200


                                                             

Train Loss: 0.3544 | Accuracy: 87.83%
Val   Loss: 0.3772 | Accuracy: 88.59%
Saved best model.

Epoch 115/200


                                                             

Train Loss: 0.3520 | Accuracy: 87.83%
Val   Loss: 0.3900 | Accuracy: 88.30%
No improvement for 1 epochs.

Epoch 116/200


                                                             

Train Loss: 0.3461 | Accuracy: 88.06%
Val   Loss: 0.3860 | Accuracy: 88.11%
No improvement for 2 epochs.

Epoch 117/200


                                                             

Train Loss: 0.3499 | Accuracy: 87.81%
Val   Loss: 0.4000 | Accuracy: 87.97%
No improvement for 3 epochs.

Epoch 118/200


                                                             

Train Loss: 0.3460 | Accuracy: 88.13%
Val   Loss: 0.4039 | Accuracy: 87.81%
No improvement for 4 epochs.

Epoch 119/200


                                                             

Train Loss: 0.3402 | Accuracy: 88.30%
Val   Loss: 0.4034 | Accuracy: 88.25%
No improvement for 5 epochs.

Epoch 120/200


                                                             

Train Loss: 0.3438 | Accuracy: 88.05%
Val   Loss: 0.4144 | Accuracy: 87.73%
No improvement for 6 epochs.

Epoch 121/200


                                                             

Train Loss: 0.3395 | Accuracy: 88.31%
Val   Loss: 0.4120 | Accuracy: 88.13%
No improvement for 7 epochs.

Epoch 122/200


                                                             

Train Loss: 0.3368 | Accuracy: 88.37%
Val   Loss: 0.4112 | Accuracy: 88.01%
No improvement for 8 epochs.

Epoch 123/200


                                                             

Train Loss: 0.3368 | Accuracy: 88.37%
Val   Loss: 0.4167 | Accuracy: 87.55%
No improvement for 9 epochs.

Epoch 124/200


                                                             

Train Loss: 0.3331 | Accuracy: 88.47%
Val   Loss: 0.3635 | Accuracy: 89.04%
Saved best model.

Epoch 125/200


                                                             

Train Loss: 0.3319 | Accuracy: 88.60%
Val   Loss: 0.4032 | Accuracy: 88.19%
No improvement for 1 epochs.

Epoch 126/200


                                                             

Train Loss: 0.3292 | Accuracy: 88.64%
Val   Loss: 0.4086 | Accuracy: 87.77%
No improvement for 2 epochs.

Epoch 127/200


                                                             

Train Loss: 0.3278 | Accuracy: 88.78%
Val   Loss: 0.3850 | Accuracy: 87.71%
No improvement for 3 epochs.

Epoch 128/200


                                                             

Train Loss: 0.3248 | Accuracy: 88.78%
Val   Loss: 0.4119 | Accuracy: 87.44%
No improvement for 4 epochs.

Epoch 129/200


                                                             

Train Loss: 0.3250 | Accuracy: 88.76%
Val   Loss: 0.3939 | Accuracy: 88.00%
No improvement for 5 epochs.

Epoch 130/200


                                                             

Train Loss: 0.3214 | Accuracy: 88.84%
Val   Loss: 0.4451 | Accuracy: 87.35%
No improvement for 6 epochs.

Epoch 131/200


                                                             

Train Loss: 0.3189 | Accuracy: 88.81%
Val   Loss: 0.4017 | Accuracy: 88.47%
No improvement for 7 epochs.

Epoch 132/200


                                                             

Train Loss: 0.3200 | Accuracy: 88.85%
Val   Loss: 0.4103 | Accuracy: 87.74%
No improvement for 8 epochs.

Epoch 133/200


                                                             

Train Loss: 0.3154 | Accuracy: 88.96%
Val   Loss: 0.3952 | Accuracy: 88.20%
No improvement for 9 epochs.

Epoch 134/200


                                                             

Train Loss: 0.3186 | Accuracy: 89.18%
Val   Loss: 0.4246 | Accuracy: 87.90%
No improvement for 10 epochs.

Epoch 135/200


                                                             

Train Loss: 0.3156 | Accuracy: 89.08%
Val   Loss: 0.4143 | Accuracy: 87.96%
No improvement for 11 epochs.

Epoch 136/200


                                                             

Train Loss: 0.3114 | Accuracy: 89.14%
Val   Loss: 0.3982 | Accuracy: 88.17%
No improvement for 12 epochs.

Epoch 137/200


                                                             

Train Loss: 0.3061 | Accuracy: 89.45%
Val   Loss: 0.4223 | Accuracy: 88.21%
No improvement for 13 epochs.

Epoch 138/200


                                                             

Train Loss: 0.3138 | Accuracy: 89.18%
Val   Loss: 0.3913 | Accuracy: 88.41%
No improvement for 14 epochs.

Epoch 139/200


                                                             

Train Loss: 0.3106 | Accuracy: 89.22%
Val   Loss: 0.3988 | Accuracy: 88.41%
No improvement for 15 epochs.

Epoch 140/200


                                                             

Train Loss: 0.3074 | Accuracy: 89.28%
Val   Loss: 0.3878 | Accuracy: 88.85%
No improvement for 16 epochs.

Epoch 141/200


                                                           

KeyboardInterrupt: 

In [None]:
# ### Data loading and augmentation from test_train.py ###
# # Added Normalize with the standard CIFAR-10 statistics
# transform_train = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
# ])
# transform_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
# ])

# # Downloading and creating the Datasets here
# train_dataset = torchvision.datasets.CIFAR10(
#     root='./data', train=True, download=True, transform=transform_train
# )
# test_dataset = torchvision.datasets.CIFAR10(
#     root='./data', train=False, download=True, transform=transform_test
# )

# # Creating DataLoaders here
# batch_size = 128
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# ### Training utilities from test_train.py (with fixed method names) ###
# class Accumulator:
#     """For accumulating sums over n variables."""
#     def __init__(self, n):
#         self.data = [0.0] * n
#     def add(self, *args):
#         self.data = [a + float(b) for a, b in zip(self.data, args)]
#     def reset(self):
#         self.data = [0.0] * len(self.data)
#     def __getitem__(self, idx):
#         return self.data[idx]

# def accuracy(y_hat, y):
#     """Compute the number of correct predictions."""
#     if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
#         y_hat = y_hat.argmax(axis=1)
#     cmp = (y_hat.type(y.dtype) == y)
#     return float(torch.sum(cmp))

# def evaluate_accuracy(net, data_iter, device): 
#     """Compute the accuracy for a model on a dataset."""
#     net.eval()
#     metric = Accumulator(2)  # No. of correct predictions, no. of predictions
#     with torch.no_grad():
#         for X, y in data_iter:
#             X, y = X.to(device), y.to(device)
#             metric.add(accuracy(net(X), y), y.numel())
#     return metric[0] / metric[1]

# def train_epoch(net, train_iter, loss, optimizer, device):
#     """Training function for one epoch."""
#     net.train()
#     metric = Accumulator(3)  # train_loss, train_acc, num_examples
#     for X, y in train_iter:
#         X, y = X.to(device), y.to(device)
#         optimizer.zero_grad()
#         y_hat = net(X)
#         l = loss(y_hat, y)
#         l.backward()
#         optimizer.step()
#         metric.add(float(l) * len(y), accuracy(y_hat, y), y.numel())
#     return metric[0] / metric[2], metric[1] / metric[2]

# def train_model(net, train_iter, test_iter, loss, optimizer, num_epochs, device):
#     """Train and evaluate a model."""
#     print('-' * 50)
#     print('Starting training...')
    
#     train_losses = []
#     train_accs = []
#     test_accs = []
    
#     for epoch in range(num_epochs):
#         train_metrics = train_epoch(net, train_iter, loss, optimizer, device)
#         test_acc = evaluate_accuracy(net, test_iter, device)
#         train_loss, train_acc = train_metrics
        
#         train_losses.append(train_loss)
#         train_accs.append(train_acc)
#         test_accs.append(test_acc)
        
#         print(f'Epoch {epoch + 1}:')
#         print(f'  Train loss: {train_loss:.3f}')
#         print(f'  Train accuracy: {train_acc:.3f} ({train_acc*100:.1f}%)')
#         print(f'  Test accuracy:  {test_acc:.3f} ({test_acc*100:.1f}%)')
    
#     # Plot metrics
#     plt.figure(figsize=(12, 4))
#     plt.subplot(1, 2, 1)
#     plt.plot(train_losses, label='train loss')
#     plt.xlabel('epoch')
#     plt.ylabel('loss')
#     plt.legend()
    
#     plt.subplot(1, 2, 2)
#     plt.plot([x*100 for x in train_accs], label='train acc (%)')
#     plt.plot([x*100 for x in test_accs], label='test acc (%)')
#     plt.xlabel('epoch')
#     plt.ylabel('accuracy (%)')
#     plt.legend()
#     plt.savefig('training_results.png')
#     plt.show()
    
#     return train_losses, train_accs, test_accs

# ### Main execution block ###
# if __name__ == '__main__':
#     # Device configuration
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     if device.type == 'cuda':
#         print('GPU training enabled')  # Simplified device info
    
#     # Create your model from mymodel.py
#     model = Model(
#         stem_channels=128,
#         hidden_channels=128,
#         num_blocks=3,
#         k=4,
#         r=4,
#         num_classes=10,
#         use_mlp=True
#     ).to(device)
    
#     # Define loss function and optimizer
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=0.001)
    
#     # Train the model
#     train_losses, train_accs, test_accs = train_model(
#         model, train_loader, test_loader, criterion, optimizer, num_epochs=25, device=device
#     )
    
#     # Save model
#     torch.save(model.state_dict(), "best_model.pth")
#     print("Model saved as best_model.pth")
    
#     # Print final metrics
#     print("\nFinal Metrics:")
#     print(f"Final train loss: {train_losses[-1]:.4f}")
#     print(f"Final train accuracy: {train_accs[-1]*100:.2f}%")
#     print(f"Final test accuracy: {test_accs[-1]*100:.2f}%")

Averages:


*   Train Loss: 1.7223, Accuracy: 38.21%, Validation Loss: 1.7194, Accuracy: 38.25%
*   Train Loss: 1.7106, Accuracy: 34.80%, Validation Loss: 1.7984, Accuracy: 35.98%

*   Train Loss: 1.8150, Accuracy: 34.54%, Val   Loss: 1.7848 Accuracy: 36.13%

*   Train Loss: 1.9579, Accuracy: 28.84%, Val   Loss: 1.8691, Accuracy: 32.51%
*   Train Loss: 1.9712, Accuracy: 27.54%, Val   Loss: 1.9107 ,Accuracy: 30.24%
*   Train Loss: 2.1609, Accuracy: 16.97%, Val   Loss: 2.1343, Accuracy: 18.28%
*   Train Loss: 1.9798, Accuracy: 27.24%, Val   Loss: 1.9312, Accuracy: 29.65%
*   Train Loss: 1.4970, Accuracy: 44.11%, Val   Loss: 1.3675, Accuracy: 48.65%
*   Train Loss: 1.3648, Accuracy: 51.66%, Val   Loss: 1.2319, Val   Accuracy: 55.56%
*   Train Loss: 0.7390, Accuracy: 74.50%, Val   Loss: 0.8193, Val   Accuracy: 72.47%
*  Train Loss: 0.7262, Accuracy: 74.99%, Val   Loss: 0.8539, Accuracy: 71.57%
*   Train Loss: 0.6575, Accuracy: 76.88%, Val   Loss: 0.7876, Accuracy: 73.31%
*   Train Loss: 0.6564, Accuracy: 76.91%, Val   Loss: 0.7731, Accuracy: 73.89%
*   Train Loss: 0.6747, Accuracy: 76.24%, Val   Loss: 0.7645, Accuracy: 73.83%
*   Train Loss: 0.7119, Accuracy: 74.75%, Val   Loss: 0.8092, Accuracy: 72.01%
* Train Loss: 1.0820, Accuracy: 61.24%,  Val   Loss: 0.9241, Accuracy: 66.86%
* Train Accuracy: 62.14%, Val   Loss: 0.8826, Accuracy: 68.16%
* Train Loss: 1.1482, Accuracy: 59.56%, Val   Loss: 0.9446, Accuracy: 66.12%
* Train Loss: 1.0876, Accuracy: 61.30%, Val   Loss: 0.9327,Accuracy: 66.59%
* Train Loss: 1.1938, Accuracy: 57.16%, Val   Loss: 1.0265, Accuracy: 62.66%
* Train Loss: 1.3364, Accuracy: 51.05%, Val   Loss: 1.1814, Accuracy: 56.93%