In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import math

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier producing 10 class scores.

    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool. The flatten
    step is sized for 32x32 RGB inputs, which the two poolings reduce to 8x8
    feature maps (32 -> 16 -> 8).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 (RGB) -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # A single 2x2 / stride-2 pooling layer, reused after each convolution.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier: 64 channels * 8 * 8 positions -> 512 hidden units -> 10 classes.
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return the unnormalized class scores for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))  # conv -> ReLU -> pool
        flat = x.view(-1, 64 * 8 * 8)  # flatten for the fully connected layers
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
# Build the baseline CNN on the selected device and list its parameter names.
model = CNN().to(device)
for param_name, _param in model.named_parameters():
    print(param_name)

In [None]:
class CustomLinear(nn.Module):
    """Hand-written fully connected layer computing ``input @ weight.T + bias``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: When False the layer has no additive bias and ``self.bias`` is
            None. Defaults to True, which matches the previous behavior.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty allocates uninitialized storage in the current default
        # dtype (the legacy torch.Tensor(...) constructor ignores it);
        # reset_parameters() below fills in the actual values.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Registering None keeps `self.bias` a valid attribute for forward().
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize the weight with Kaiming-uniform and the bias uniformly."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # The weight is (out_features, in_features), so its fan-in is simply
            # in_features; no need for the private nn.init fan helper.
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the linear transformation: input * weight^T (+ bias if present)."""
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        """Describe the layer's configuration when the module is printed."""
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )

In [None]:
class CustomCNN(nn.Module):
    """Two-stage convolutional classifier whose dense layers are CustomLinear.

    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool. The flatten
    step is sized for 32x32 RGB inputs, which two poolings reduce to 8x8.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 3 (RGB) -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # Shared pooling layer that halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Dense classifier built from the hand-written CustomLinear layer:
        # 64 * 8 * 8 flattened features -> 512 hidden units -> 10 classes.
        self.fc1 = CustomLinear(64 * 8 * 8, 512)
        self.fc2 = CustomLinear(512, 10)

    def forward(self, x):
        """Return the unnormalized class scores for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))  # conv -> ReLU -> pool
        flat = x.view(-1, 64 * 8 * 8)  # flatten for the fully connected layers
        hidden = F.relu(self.fc1(flat))  # first CustomLinear + ReLU
        return self.fc2(hidden)  # output layer, no activation

In [None]:
# Build the CustomLinear-based CNN on the selected device, then print the
# model itself followed by each of its submodules.
custom_model = CustomCNN().to(device)
for submodule in custom_model.modules():
    print(submodule)

In [None]:
def get_cifar10_dataloaders(batch_size=64, root='./data', num_workers=2):
    """Build the CIFAR-10 train and test DataLoaders.

    Images are converted to tensors and normalized per channel with mean 0.5
    and std 0.5. Both dataset splits are requested with ``download=True``.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory where the dataset is stored / downloaded.
            Defaults to ``'./data'`` (the previously hard-coded location).
        num_workers: Worker processes used by each DataLoader. Defaults to 2
            (the previously hard-coded value); pass 0 to load in the main
            process.

    Returns:
        tuple: ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    # Data transforms including normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Download (if needed) and load the training and test datasets
    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform
    )

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return trainloader, testloader

In [None]:
# Loss function and Adam optimizer used to train the baseline CNN (`model`).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def train(model, trainloader, criterion, optimizer, device):
    """Run one optimization pass over ``trainloader``.

    Args:
        model: Module to optimize; it is switched to training mode.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Mean of the per-batch losses weighted by batch size, or NaN if
        ``trainloader`` yields no batches. Earlier versions returned None, so
        callers that ignore the result are unaffected.
    """
    model.train()  # Set model to training mode
    loss_sum = 0.0
    samples_seen = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()       # Clear gradients left over from the previous step
        outputs = model(inputs)     # Forward pass
        loss = criterion(outputs, labels)
        loss.backward()             # Backward pass
        optimizer.step()            # Update parameters
        # Weight each batch loss by its size so a short final batch does not
        # skew the average (same accumulation scheme as the eval function).
        batch_size = inputs.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size
    if samples_seen == 0:
        return float("nan")
    return loss_sum / samples_seen

In [None]:
def eval(model, testloader, criterion, device):
    """Evaluate ``model`` on ``testloader`` and print/return loss and accuracy.

    NOTE: this function shadows Python's builtin ``eval`` in this namespace;
    the name is kept because other cells call it.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        testloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        device: Device each batch is moved to before the forward pass.

    Returns:
        tuple: ``(avg_loss, accuracy)`` where ``avg_loss`` is the per-batch
        loss averaged with batch-size weights and ``accuracy`` is a percentage
        in [0, 100].

    Raises:
        ValueError: If ``testloader`` yields no samples (previously this
            surfaced as an unexplained ZeroDivisionError).
    """
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # Accumulate the loss weighted by batch size
            total_loss += criterion(outputs, labels).item() * batch_size
            # Index of the highest score per row is the predicted class
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute evaluation metrics")
    avg_loss = total_loss / total
    accuracy = 100 * correct / total
    print(f'Evaluation loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
    return avg_loss, accuracy

In [None]:
# Train the baseline CNN: one pass over the training set per epoch, followed
# by an evaluation on the test set.
num_epochs = 10
trainloader, testloader = get_cifar10_dataloaders()
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    train(
        model=model,
        trainloader=trainloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    eval(
        model=model,
        testloader=testloader,
        criterion=criterion,
        device=device,
    )

In [None]:
# Preprocessing for the ViT pipeline: resize every image to 224x224, convert
# it to a tensor, then normalize each channel with the given mean/std (these
# look like the usual ImageNet statistics).
transform_vit = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

def get_cifar10_dataloaders_vit(batch_size=64, root='./data', num_workers=2):
    """Build CIFAR-10 train and test DataLoaders preprocessed for the ViT.

    Both splits use the module-level ``transform_vit`` pipeline and are
    requested with ``download=True``.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory where the dataset is stored / downloaded.
            Defaults to ``'./data'`` (the previously hard-coded location).
        num_workers: Worker processes used by each DataLoader. Defaults to 2
            (the previously hard-coded value); pass 0 to load in the main
            process.

    Returns:
        tuple: ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_vit
    )
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_vit
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return trainloader, testloader

# Load a pretrained ViT model (vit_b_16) and replace its head for 10 CIFAR10 classes.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favor of the `weights=` argument; kept here so older installs keep working.
vit_model = models.vit_b_16(pretrained=True)
# torchvision's VisionTransformer keeps its classifier at `heads.head` (there is
# no top-level `.head` attribute). Without this replacement the model would
# keep the output size it was pretrained with instead of 10 classes.
vit_model.heads.head = nn.Linear(vit_model.heads.head.in_features, 10)
vit_model = vit_model.to(device)

# Uncomment to list the ViT parameter names:
# print("ViT Model parameters:")
# for name, parameter in vit_model.named_parameters():
#     print(name)

vit_criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=0.001 is on the high side for fine-tuning every weight of a
# pretrained transformer; consider a smaller rate if training is unstable.
vit_optimizer = optim.Adam(vit_model.parameters(), lr=0.001)

vit_trainloader, vit_testloader = get_cifar10_dataloaders_vit()

# Reuses `num_epochs` and the train/eval helpers defined in earlier cells.
print("Finetuning ViT model on CIFAR10")
for epoch in range(num_epochs):
    print(f'ViT Epoch {epoch + 1}/{num_epochs}')
    train(vit_model, vit_trainloader, vit_criterion, vit_optimizer, device)
    eval(vit_model, vit_testloader, vit_criterion, device)