In [None]:
# imports
import copy
import math
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib.patches import Rectangle
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from dataset import TrafficSignDataset
from networks import TrafficSignsClassifier


In [None]:
img_size = 32

# Per-channel mean / std of the training images (the magnitudes suggest a 0-255 pixel scale).
dataset_mean = [86.72383685, 79.56345902, 81.93326525]
dataset_std = [51.48834219, 50.93286751, 53.30977311]

# Building blocks shared by both pipelines; neither keeps per-call state.
_resize = transforms.Resize((img_size, img_size))
_normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)

# Deterministic preprocessing: resize to img_size x img_size, then standardize each channel.
transform = transforms.Compose([_resize, _normalize])

# Training-time preprocessing: same as above plus a mild random perspective warp applied
# with probability 0.5. Disabled candidates kept for reference:
#   transforms.GaussianBlur(kernel_size=5, sigma=(0.01, 0.6))
#   transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
augmentation_transform = transforms.Compose([
    _resize,
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
    _normalize,
])

In [None]:
# Full (non-augmented) dataset; `val` and any later use of `dataset` see un-augmented images.
dataset = TrafficSignDataset("../data/Train.csv", "../data", transform=transform)

# randomly split into train and validation, regardless of frame sequences.
# n_val is derived from n_train so the two lengths always sum to len(dataset);
# with the fixed seed this reproduces the original floor(0.8n)/ceil(0.2n) split.
n_train = math.floor(len(dataset) * 0.8)
n_val = len(dataset) - n_train
train_split, val = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))

# random_split returns Subsets that all point at the SAME `dataset` object, so assigning
# `train.dataset.transform` would also augment the validation images. Give the training
# subset its own shallow copy (shares targets/file list, has an independent `transform`).
# NOTE(review): assumes TrafficSignDataset.__getitem__ reads `self.transform` and is
# shallow-copyable — confirm against dataset.py.
train_dataset = copy.copy(dataset)
train_dataset.transform = augmentation_transform
train = Subset(train_dataset, train_split.indices)

print(len(train), len(val))

In [None]:
# Recompute the per-channel mean and standard deviation hard-coded above as
# `dataset_mean` / `dataset_std` (uncomment to run).
# from preprocessing import calculate_mean_std

# calculate_mean_std(dataset)

In [None]:
rows, cols = 5, 6    # grid layout: 5 rows x 6 columns
page_size = rows * cols
idx = 0  # which page of `page_size` training samples to show; `train` is a random split,
         # so consecutive indices are NOT a frame sequence

# Images come out of `transform`/`augmentation_transform` already normalized; casting them
# straight to uint8 truncates/wraps the values and shows an almost black image. Undo the
# normalization first so pixels are back on the dataset_mean/dataset_std scale.
pixel_mean = torch.tensor(dataset_mean).view(-1, 1, 1)
pixel_std = torch.tensor(dataset_std).view(-1, 1, 1)

figure = plt.figure(figsize=(20, 24))

for i in range(idx * page_size, (idx + 1) * page_size):
    img, target = train[i]

    ax = figure.add_subplot(rows, cols, i - idx * page_size + 1)
    img_display = (img * pixel_std + pixel_mean).clamp(0, 255).byte().permute(1, 2, 0)
    ax.imshow(img_display)

    # bbox[2:6] is read as (x1, y1, x2, y2).
    # NOTE(review): the image is resized to img_size (and possibly warped by
    # RandomPerspective) but the bbox is not transformed; if it is in source-image
    # pixels the rectangle will be misplaced — confirm and rescale if needed.
    bbox = target["bbox"]
    x = bbox[2]
    y = bbox[3]
    box_width = bbox[4] - x
    box_height = bbox[5] - y

    rect = Rectangle((x, y), box_width, box_height, linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(rect)
    ax.axis("off")

plt.show()

In [None]:
# Class balance of the two splits (as % of each split) and the per-class difference.
fig, ax = plt.subplots(3, 1, figsize=(15, 15))

train_labels, train_counts = np.unique((train.dataset.targets["ClassId"])[train.indices], return_counts=True)
val_labels, val_counts = np.unique((val.dataset.targets["ClassId"])[val.indices], return_counts=True)

train_pct = train_counts / len(train) * 100
val_pct = val_counts / len(val) * 100

# training data
ax[0].bar(train_labels, train_pct)
ax[0].set(title="Training split", ylabel="% of samples")

# validation data — same unit as the training panel so the two are directly comparable
ax[1].bar(val_labels, val_pct, color="orange")
ax[1].set(title="Validation split", ylabel="% of samples")

# val - train, aligned by class id so a class present in only one split cannot
# shift the bars or cause a shape mismatch
all_labels = np.union1d(train_labels, val_labels)
diff = np.zeros(len(all_labels))
diff[np.searchsorted(all_labels, val_labels)] += val_pct
diff[np.searchsorted(all_labels, train_labels)] -= train_pct

ax[2].bar(all_labels, diff, color=np.where(diff >= 0, "orange", "C0"))
ax[2].set(title="Validation - training", xlabel="ClassId", ylabel="percentage points")

plt.show()


In [None]:
batch_size = 256

# One loader per split; only the training stream is shuffled.
train_dataloader, val_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=shuffle)
    for split, shuffle in ((train, True), (val, False))
)

# TODO: test_dataloader

# Exactly one batch (the first `batch_size` items of `dataset`), for smoke-testing
# the training setup for errors.
sample_subset = Subset(dataset, np.arange(batch_size))
sample_dataloader = DataLoader(sample_subset, batch_size=batch_size, shuffle=True)

In [None]:
# setup device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")

In [None]:
# Hyper-parameters. lr / momentum / weight_decay / dampening feed the SGD optimizer;
# the whole dict is also passed to TrafficSignsClassifier (bn_momentum is presumably
# read there — it is not used elsewhere in this notebook).
hparams = dict(
    lr=1e-3,
    bn_momentum=0.1,
    momentum=0.9,
    weight_decay=0,
    dampening=0,
)

In [None]:
# init model
# `input_size` must match the side length the transforms resize images to, so use img_size
# rather than repeating the literal.
# NOTE(review): num_classes = number of distinct ClassIds in the training split; this
# assumes ids are 0..K-1 and every class occurs in train — confirm.
model = TrafficSignsClassifier(hparams, input_size=img_size, num_classes=len(train_labels))

model = model.to(device)

# count trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {total_params}")

In [None]:
# define optimizer
# Cross-entropy on the model's raw logits, minimized with Nesterov-momentum SGD whose
# settings all come from `hparams` (PyTorch requires momentum > 0 and dampening == 0
# when nesterov=True).
criterion = nn.CrossEntropyLoss()
sgd_settings = {key: hparams[key] for key in ("lr", "momentum", "weight_decay", "dampening")}
optimizer = optim.SGD(model.parameters(), nesterov=True, **sgd_settings)

In [None]:
# new tensorboard log
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name = f"2_conv_2_k3_linear_relu_bn_maxp_b256_randp01_1e-3_val20_SGD{timestamp}"
writer = SummaryWriter(f"../runs/{run_name}")

max_epochs = 30
validate_every = 1      # run validation after every `validate_every`-th epoch
stopping_threshold = 5  # stop after this many consecutive validations without a new best val loss


def _run_epoch(model, loader, optimizer=None, progress=False):
    """Make one pass over `loader` and return (mean batch loss, accuracy in [0, 1]).

    If `optimizer` is given the model is trained: train mode, gradients enabled, one
    optimizer step per batch. Otherwise it is only evaluated: eval mode (so e.g.
    BatchNorm uses its running statistics and does not update them) and no gradients.
    Uses the notebook-level `criterion` and `device`; targets are read from y["label"].
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, target in (tqdm(loader) if progress else loader):
            x = x.to(device)
            y = target["label"].to(device)

            if training:
                optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += y.size(0)

    return loss_sum / len(loader), n_correct / n_seen


min_val_loss = None        # lowest validation loss seen so far
best_acc = None            # validation accuracy at that epoch
best_epoch = None          # 0-based index of that epoch
best_model_weights = None  # deep copy of model.state_dict() at that epoch
patience = 0               # consecutive validations without improvement

try:
    for epoch in range(max_epochs):
        running_loss, train_acc = _run_epoch(model, train_dataloader, optimizer, progress=True)
        print("[Epoch %d/%d] loss: %.3f acc: %.2f %%" % (epoch + 1, max_epochs, running_loss, 100 * train_acc))

        # One integer x-position per epoch shared by all curves: roughly the number of
        # 32-sample batches processed before this epoch, so runs with different batch
        # sizes line up in TensorBoard.
        step = epoch * len(train_dataloader) * batch_size // 32
        writer.add_scalar('Training loss', running_loss, step)
        writer.add_scalar('Training acc', train_acc, step)

        if epoch % validate_every != validate_every - 1:
            continue

        val_running_loss, val_acc = _run_epoch(model, val_dataloader)
        print("Validation loss: %.3f acc: %.2f %%" % (val_running_loss, val_acc * 100))
        writer.add_scalar('Validation loss', val_running_loss, step)
        writer.add_scalar('Validation acc', val_acc, step)

        # save best model and early stopping
        if min_val_loss is None or val_running_loss < min_val_loss:
            min_val_loss, best_acc, best_epoch = val_running_loss, val_acc, epoch
            best_model_weights = copy.deepcopy(model.state_dict())
            patience = 0
        else:
            patience += 1
            if patience >= stopping_threshold:
                print("Early stopping, best model at epoch ", best_epoch + 1)
                break
    else:
        # loop ran to max_epochs without early stopping
        if best_epoch is not None:
            print("Best model at epoch ", best_epoch + 1)
finally:
    # flush pending events and release the event file even if training is interrupted
    writer.close()

if best_model_weights is None:
    raise RuntimeError("No validation epoch ran, so there is no best model to report")

print("Val loss: %.3f acc: %.2f %%" % (min_val_loss, best_acc * 100))

In [None]:
# load best weights
# Restore the parameters from the epoch with the lowest validation loss, and put the
# model in eval mode so normalization/dropout layers (if any) behave deterministically
# in the inference cells below. load_state_dict stays last so its result is displayed.
model.eval()
model.load_state_dict(best_model_weights)

In [None]:
print(repr(model))  # nn.Module's repr lists every submodule with its configuration

In [None]:
# save best model
# Forward slashes work on every OS (and match the ../runs/ log path); the old
# "..\models\{...}" literal relied on invalid escape sequences. torch.save does not
# create missing directories, so make sure ../models exists first.
os.makedirs("../models", exist_ok=True)
torch.save(model.state_dict(), f"../models/{run_name}.pth")

### Confusion matrix

In [None]:
num_classes = len(train_labels)

# Rows = true class, columns = predicted class, counted over the validation split.
# eval() so layers like BatchNorm use running statistics instead of per-batch ones.
model.eval()
confusion_matrix = torch.zeros(num_classes, num_classes)
with torch.no_grad():
    for x, target in val_dataloader:
        y = target["label"].to(device).reshape(-1)
        preds = model(x.to(device)).argmax(dim=1).reshape(-1)
        # Encode each (true, pred) pair as one flat index and count all pairs at once.
        pair_counts = torch.bincount(y * num_classes + preds, minlength=num_classes * num_classes)
        confusion_matrix += pair_counts.view(num_classes, num_classes).cpu()

print(confusion_matrix)

In [None]:
# Row-normalize so each row shows how the samples of one true class were spread over
# the predicted classes (the diagonal is per-class recall). A class with no validation
# samples gives a 0/0 = NaN row.
rows_sum = confusion_matrix.sum(dim=1)
norm_confusion_matrix = confusion_matrix / rows_sum[:, np.newaxis]

# Explicit figure instead of sns.set(rc=...), which would change the style of every later plot.
fig, ax = plt.subplots(figsize=(24, 20))
sns.heatmap(norm_confusion_matrix.numpy(), annot=True, cmap='Blues', fmt='.2f', ax=ax)
ax.set(title="Validation confusion matrix (row-normalized)", xlabel="Predicted class", ylabel="True class")

plt.show()