In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import torch.backends.cudnn as cudnn
import copy
import math
from torch.autograd import Function
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import os

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 256

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# Split training set for training and validation
train_size = int(0.8 * len(train_set))
val_size = len(train_set) - train_size
train_set, val_set = random_split(train_set, [train_size, val_size])

# DataLoader for validation set
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last = True)
train_loader =  DataLoader(train_set, batch_size=256, shuffle=False, drop_last = True)
test_loader =  DataLoader(test_set, batch_size=256, shuffle=False, drop_last = True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
            nn.Softmax(dim=1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 512)
        x = self.classifier(x)
        return x

In [5]:
def test(model, testloader):
    correct = 0
    total = 0

    for images, label in testloader:
        images, label = images.to(device), label.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

    return correct / total

In [6]:
def train(model, epoches, criterion, optimizer):
    best_model_wts = None
    best_loss = float('inf')
    batch_num = 0
    warm_up_batch = 3

    for inputs, labels in train_loader:
        if (best_model_wts):
            model.load_state_dict(best_model_wts)

        inputs, labels = inputs.to(device), labels.to(device)
        prev_loss = float('inf')
        for epoch in range(epoches):
            model.train()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                    outputs = model(val_inputs)
                    batch_loss = criterion(outputs, val_labels)
                    val_loss += batch_loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

            print(f"Batch: {batch_num}, epoch: {epoch}, Train Loss: {loss.item()}, Val Loss: {val_loss}")
            if (prev_loss < val_loss and warm_up_batch < batch_num):
                break
            prev_loss = val_loss
        with torch.no_grad():
            test_acc = test(model, test_loader)
            
        print(f"epoch: {epoch}, Test Acc: {test_acc}")

        batch_num += 1

    return best_model_wts

In [7]:
model = VGG().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

best_model_wts = train(model, 20, criterion, optimizer)

Batch: 0, epoch: 0, Train Loss: 2.300583839416504, Val Loss: 89.79979133605957
Batch: 0, epoch: 1, Train Loss: 2.2727200984954834, Val Loss: 89.8000717163086
Batch: 0, epoch: 2, Train Loss: 2.244594097137451, Val Loss: 89.79906725883484
Batch: 0, epoch: 3, Train Loss: 2.2109744548797607, Val Loss: 89.79881739616394
Batch: 0, epoch: 4, Train Loss: 2.187385320663452, Val Loss: 89.80052065849304
Batch: 0, epoch: 5, Train Loss: 2.1453003883361816, Val Loss: 89.80090165138245
Batch: 0, epoch: 6, Train Loss: 2.1085588932037354, Val Loss: 89.7977991104126
Batch: 0, epoch: 7, Train Loss: 2.0575807094573975, Val Loss: 89.79571390151978
Batch: 0, epoch: 8, Train Loss: 1.9971327781677246, Val Loss: 89.79208707809448
Batch: 0, epoch: 9, Train Loss: 1.9513640403747559, Val Loss: 89.78596496582031
Batch: 0, epoch: 10, Train Loss: 1.890354037284851, Val Loss: 89.77544331550598
Batch: 0, epoch: 11, Train Loss: 1.8344842195510864, Val Loss: 89.76152038574219
Batch: 0, epoch: 12, Train Loss: 1.783623576

In [9]:
with torch.no_grad():
    torch.save(model.state_dict(), 'cifar-10-baseline.pth')

In [11]:
model = VGG().to(device)
model.load_state_dict(torch.load('cifar-10-baseline.pth'))

test_acc = test(model, test_loader)

In [12]:
test_acc

0.5663060897435898