In [1]:
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

In [2]:
class VGG16_A(nn.Module):
  """VGG-style convolutional classifier for 224x224 images.

  The network consists of a convolutional feature extractor (`self.feature`)
  followed by a three-layer fully connected head (`self.classifier`).
  Despite the class name, the stack defined here has 8 convolutional and
  3 linear layers (11 weight layers in total).

  The feature extractor contains five stride-2 max-pool layers, so a
  224x224 input is reduced to a 7x7 map with 512 channels; the first linear
  layer is hard-wired to that size (512*7*7), hence inputs must be 224x224.

  Args:
    in_channels: number of channels of the input images (default 3).
    classes: number of output classes, i.e. length of the logit vector
      returned per sample (default 10).

  `forward` returns raw, unnormalized logits of shape (batch, classes).
  No activation is applied after the last linear layer: losses such as
  `nn.CrossEntropyLoss` expect unbounded logits, and clamping them with a
  ReLU would zero out the gradient of every negative logit.
  """

  def __init__(self, in_channels = 3, classes = 10):
    super().__init__()
    # NOTE: the position of every layer inside the nn.Sequential containers
    # determines its state_dict key (e.g. "feature.19.weight"). Keep the
    # order stable so previously saved checkpoints remain loadable.
    self.feature = nn.Sequential(
        # conv 1: 224 -> 112
        nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # conv 2: 112 -> 56
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # conv 3: 56 -> 28
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # conv 4: 28 -> 14
        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # conv 5: 14 -> 7 (the only block that uses batch normalization)
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(512),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(512),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )
    self.classifier = nn.Sequential(
        # Input size is 224x224 and each of the 5 MaxPool2d layers halves the
        # spatial dims -> 224 / 2^5 = 7, so the last conv block outputs 7x7 maps.
        nn.Linear(in_features=512*7*7, out_features=4096),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(in_features=4096, out_features=4096),
        nn.ReLU(),
        nn.Dropout(),
        # Final layer emits raw logits; intentionally NOT followed by ReLU.
        nn.Linear(in_features=4096, out_features=classes),
    )

    self.feature.apply(self._init_weights)
    self.classifier.apply(self._init_weights)

  @staticmethod
  def _init_weights(m):
    """Initialize Conv2d/Linear weights from N(0, 0.01^2) and zero their biases.

    Other module types (e.g. BatchNorm2d) keep their default initialization.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
      nn.init.normal_(m.weight, mean=0, std=0.01)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)

  def forward(self, x):
    """Compute class logits for a batch of images.

    Args:
      x: image batch; the layer sizes require (batch, in_channels, 224, 224).

    Returns:
      Tensor of shape (batch, classes) holding unnormalized logits.
    """
    x = self.feature(x)
    x = torch.flatten(x, 1)  # keep the batch dim, flatten channels and space
    x = self.classifier(x)
    return x

### Training

In [None]:
# --- Data -------------------------------------------------------------------
# CIFAR-10 training split. Every image is resized to 224x224 (the size the
# network's first linear layer is built for), converted to a tensor and
# normalized per channel.
data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
cifar10_train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    transform=data_transforms,
    download=True,
)
train_dataloader = torch.utils.data.DataLoader(cifar10_train_dataset, batch_size=32, shuffle=True)

# --- Model ------------------------------------------------------------------
# Instantiate the VGG network on the GPU when one is available and switch it
# to training mode.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
model = VGG16_A().to(device).train()

# --- Loss and optimizer -----------------------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

# --- Training loop ----------------------------------------------------------
epoch_nums = 3 # small value, just for a working demonstration
for epoch in range(epoch_nums):
  # Sum of the losses of the mini-batches seen since the last report.
  current_loss = 0.0

  progress = tqdm(train_dataloader, unit="batch", total=len(train_dataloader), desc=f"Epoch {epoch}")
  for step, (inputs, labels) in enumerate(progress, start=1):
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Forward pass.
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    # Clear stale gradients, back-propagate, update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report the mean loss over each window of 100 mini-batches.
    current_loss += loss.item()
    if step % 100 == 0:
      print(f'current loss: {current_loss / 100:.3f}')
      current_loss = 0.0

# --- Persist the trained weights --------------------------------------------
print("Training done! Saving trained model to: './trained_model.pth'")
torch.save(model.state_dict(), './trained_model.pth')
print("Saved.")






### Testing

In [None]:
# Load the CIFAR-10 test split with the same preprocessing that was used for
# training (resize to 224x224, tensor conversion, per-channel normalization).
data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224,224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
cifar10_test_dataset = torchvision.datasets.CIFAR10(root='./data', download=True, train=False, transform=data_transforms)
test_dataloader = torch.utils.data.DataLoader(cifar10_test_dataset,
                                         batch_size = 128,
                                         shuffle=False)

# Rebuild the network and restore the trained weights.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VGG16_A().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can also be loaded on a CPU-only one.
model.load_state_dict(torch.load("./trained_model.pth", map_location=device))
# Switch to inference mode: disables Dropout and makes BatchNorm use its
# running statistics instead of per-batch statistics. Without this the
# reported accuracy is noisy and depends on the batch composition.
model.eval()

correct = 0  # number of correctly classified test images
total = 0    # number of test images seen so far
with torch.no_grad():  # no gradients are needed for evaluation
  for images, labels in tqdm(test_dataloader):
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)

    # The predicted class is the index of the largest output per sample.
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

# Accuracy is reported as a whole percentage (floor division).
print(f'\nAccuracy of VGG16_A on CIFAR10: {100 * correct // total} %')