In [1]:
# (Bronte) Sihan Li, Cole Crescas 2023

In [2]:
import sys, os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.autograd import Variable
import wandb

In [3]:
nclasses = 43 # GTSRB as 43 classes


# This model architecture is based on https://github.com/mmoraes-rafael/gtsrb_resnet/blob/master/code/model48.py
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3_drop = nn.Dropout2d(p=0.2)
        self.conv3_bn = nn.BatchNorm2d(32)

        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4_bn = nn.BatchNorm2d(64)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv5_bn = nn.BatchNorm2d(64)
        self.conv6 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv6_drop = nn.Dropout2d(p=0.2)
        self.conv6_bn = nn.BatchNorm2d(64)

        self.conv7 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv7_bn = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv8_bn = nn.BatchNorm2d(128)
        self.conv9 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv9_drop = nn.Dropout2d(p=0.2)
        self.conv9_bn = nn.BatchNorm2d(128)

        self.conv10 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv10_bn = nn.BatchNorm2d(256)
        self.conv11 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv11_bn = nn.BatchNorm2d(256)
        self.conv12 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv12_drop = nn.Dropout2d(p=0.2)
        self.conv12_bn = nn.BatchNorm2d(256)

        self.fc1 = nn.Linear(256 * 6 * 6 , 2048)
        self.fc1_bn = nn.BatchNorm1d(2048)
        self.fc2 = nn.Linear(2048, nclasses)

    def forward(self, inp):
        res = F.relu(self.conv1_bn(self.conv1(inp)))
        x =   F.relu(self.conv2_bn(self.conv2(res)))
        x =   self.conv3_drop(self.conv3(x))
        block1_out = F.relu(self.conv3_bn(F.max_pool2d(x + res, 2)))    

        res = F.relu(self.conv4_bn(self.conv4(block1_out)))
        x =   F.relu(self.conv5_bn(self.conv5(res)))
        x = self.conv6_drop(self.conv6(x))
        block2_out = F.relu(self.conv6_bn(F.max_pool2d(x + res, 2)))

        res = F.relu(self.conv7_bn(self.conv7(block2_out)))
        x =   F.relu(self.conv8_bn(self.conv8(res)))
        x = self.conv9_drop(self.conv9(x))
        block3_out = F.relu(self.conv9_bn(F.max_pool2d(x + res, 2)))

        res = F.relu(self.conv10_bn(self.conv10(block3_out)))
        x =   F.relu(self.conv11_bn(self.conv11(res)))
        x = F.relu(self.conv12_bn(self.conv12_drop(self.conv12(x + res))))
   
        x = x.view(-1, 256 * 6 * 6)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.dropout(x, training=self.training, p=0.2)
        x = self.fc2(x)
        return F.log_softmax(x)


In [4]:
# Define datasets

batch_size = 32

data_transforms = transforms.Compose([
	transforms.Resize((48, 48)),    
    transforms.ToTensor(),
    transforms.Normalize((0.3337, 0.3064, 0.3171), ( 0.2672, 0.2564, 0.2629))
])

# Split training data into training and validation sets
data_dir = 'data/GTSRB_train/Final_Training/Images'
dataset = ImageFolder(data_dir, transform=data_transforms)

# Define the indices for the training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Randomly split the dataset into training and validation sets
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Define the data loaders for the training and validation sets
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True)

In [5]:
# Verify the train loader and val loader
print("Train loader length: ", len(train_loader))
print("Val loader length: ", len(val_loader))
for i, (data, target) in enumerate(train_loader):
    print("Train data shape: ", data.shape)
    print("Train target shape: ", target.shape)
    break


Train loader length:  981
Val loader length:  246
Train data shape:  torch.Size([32, 3, 48, 48])
Train target shape:  torch.Size([32])


In [6]:
# Define parameters
first_epoch = 0
nepochs = 50
lr = 0.001
momentum = 0.9
wd = 1e-6
log_interval = 10

# Define the checkpoint directory
checkpoint_dir = 'checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the model
model = ResNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

# initialize wandb
experiment = wandb.init(project='gtsrb-resnet', resume='allow', anonymous='must')
experiment.config.update(
    dict(
        epochs=nepochs,
        batch_size=batch_size,
        learning_rate=lr,
        momentum=momentum,)
)


[34m[1mwandb[0m: Currently logged in as: [33mbronte[0m ([33mfire-dream[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Define training and validation functions
def train(epoch):
    avg_loss = 0
    steps = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        steps += 1
        avg_loss += loss.item()
        if batch_idx % log_interval == 0:
            experiment.log(
                {'training loss' : loss.item(),
                 'steps' : steps,
                 'epoch' : epoch
                 }
            )
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    return float(avg_loss) / steps

def validation(epoch):
    model.eval()
    validation_loss = 0
    correct = 0
    for data, target in val_loader:
        data, target = Variable(data, volatile=True).to(device), Variable(target).to(device)
        output = model(data)
        validation_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    validation_loss /= len(val_loader.dataset)
    experiment.log(
        {'validation loss' : validation_loss,
         'validation accuracy' : 100. * correct / len(val_loader.dataset),
         'epoch' : epoch,
         }
    )
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        validation_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))

    return validation_loss

In [8]:
losses = []
for epoch in range(first_epoch, nepochs + 1):
    train_loss = train(epoch)
    val_loss = validation(epoch)
    losses.append((epoch, train_loss, val_loss))
    model_file = checkpoint_dir + 'model_' + str(epoch) + '.pth'
    torch.save(model.state_dict(), model_file)
    print('\nSaved model to ' + model_file)

  return F.log_softmax(x)




  data, target = Variable(data, volatile=True).to(device), Variable(target).to(device)



Validation set: Average loss: 0.1193, Accuracy: 7550/7842 (96%)


Saved model to checkpoints/model_0.pth

Validation set: Average loss: 0.0490, Accuracy: 7732/7842 (99%)


Saved model to checkpoints/model_1.pth

Validation set: Average loss: 0.0497, Accuracy: 7717/7842 (98%)


Saved model to checkpoints/model_2.pth

Validation set: Average loss: 0.0277, Accuracy: 7785/7842 (99%)


Saved model to checkpoints/model_3.pth

Validation set: Average loss: 0.0615, Accuracy: 7724/7842 (98%)


Saved model to checkpoints/model_4.pth

Validation set: Average loss: 0.0626, Accuracy: 7723/7842 (98%)


Saved model to checkpoints/model_5.pth

Validation set: Average loss: 0.0125, Accuracy: 7816/7842 (100%)


Saved model to checkpoints/model_6.pth

Validation set: Average loss: 0.0191, Accuracy: 7803/7842 (100%)


Saved model to checkpoints/model_7.pth

Validation set: Average loss: 0.0251, Accuracy: 7803/7842 (100%)


Saved model to checkpoints/model_8.pth

Validation set: Average loss: 0.0164, Accu