Load FashionMNIST Datasets

In [1]:
import torch
from torchvision import datasets, transforms

# Fixed seed so weight init, shuffling and augmentation are more reproducible across
# runs (some CUDA kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Training-time augmentation: random horizontal flips plus small random shifts.
# RandomCrop to the full 28x28 size is an identity transform unless the image is
# padded first, so pad by 2 pixels on each side and crop back to 28x28.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop((28, 28), padding=2),
    transforms.ToTensor(),
])
train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=train_transform)
# The test set is only converted to tensors, never augmented.
test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=transforms.ToTensor())

# Class index -> human-readable FashionMNIST class name.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

Using cuda device


Define CNN Model

In [2]:
import torchsummary

class BypassConvBlock(torch.nn.Module):
  """Residual-style block whose shortcut is a learned 1x1 projection.

  Main branch: 1x1 -> 3x3 -> 1x1 convolutions, each followed by BatchNorm and ReLU.
  Shortcut: 1x1 convolution followed by BatchNorm. The two branches are summed and
  passed through a final ReLU. Channel count and spatial size are preserved (the 3x3
  convolution uses padding 1).

  Args:
    channels: number of input (and output) channels.
  """

  def __init__(self, channels):
    super().__init__()

    def conv_bn(kernel_size, padding):
      # One Conv2d -> BatchNorm2d pair; the channel count never changes in this block.
      return [
        torch.nn.Conv2d(in_channels=channels, out_channels=channels,
                        kernel_size=kernel_size, padding=padding),
        torch.nn.BatchNorm2d(channels),
      ]

    # Main branch. NOTE(review): a standard ResNet bottleneck omits the last ReLU
    # before the addition; it is kept here so the block's behaviour is unchanged.
    branch_layers = []
    for kernel, pad in (((1, 1), (0, 0)), ((3, 3), (1, 1)), ((1, 1), (0, 0))):
      branch_layers += conv_bn(kernel, pad)
      branch_layers.append(torch.nn.ReLU())
    self.conv_stack = torch.nn.Sequential(*branch_layers)

    # Shortcut: learned 1x1 projection + BN, no activation before the sum.
    self.conv_bypass = torch.nn.Sequential(*conv_bn((1, 1), (0, 0)))

    self.act = torch.nn.ReLU()

  def forward(self, x):
    """Return ReLU(main_branch(x) + shortcut(x)); the output has the shape of x."""
    return self.act(self.conv_stack(x) + self.conv_bypass(x))


class CNN(torch.nn.Module):
  """Convolutional classifier for single-channel 28x28 images (e.g. FashionMNIST).

  Two convolutional stages, each ending in 2x2 max-pooling and BatchNorm, map a
  (N, 1, 28, 28) batch to (N, 64, 7, 7) feature maps. These are flattened and fed
  through a dropout-regularised MLP that returns unnormalised logits of shape
  (N, num_classes).

  Args:
    num_classes: number of output logits; defaults to 10 (the FashionMNIST classes).
  """

  def __init__(self, num_classes=10):
    super().__init__()

    self.conv0_stack = torch.nn.Sequential(
      # Stage 1: two 3x3 convs at 28x28, then pool to 14x14.
      torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(3,3), padding=(1,1)),
      torch.nn.ReLU(),
      torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding=(1,1)),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(kernel_size=(2,2), stride=2),
      torch.nn.BatchNorm2d(64),
      # Stage 2: three 3x3 convs at 14x14, then pool to 7x7.
      torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding=(1,1)),
      torch.nn.ReLU(),
      torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding=(1,1)),
      torch.nn.ReLU(),
      torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding=(1,1)),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(kernel_size=(2,2), stride=2),
      torch.nn.BatchNorm2d(64),
    )

    self.flatten_stack = torch.nn.Flatten()

    # The head ends with a plain Linear layer, so the model returns raw logits. That is
    # what CrossEntropyLoss / F.cross_entropy expect. An output activation chosen from
    # `self.training` inside __init__ cannot work, because modules are always in
    # training mode at construction time.
    # NOTE(review): 1028 looks like a typo for 1024. It is kept so parameter shapes,
    # and therefore saved checkpoints, stay compatible.
    self.fc_stack = torch.nn.Sequential(
      torch.nn.Dropout(),
      torch.nn.Linear(in_features=64*7*7, out_features=1028),
      torch.nn.ReLU(),
      torch.nn.Dropout(),
      torch.nn.Linear(in_features=1028, out_features=1028),
      torch.nn.ReLU(),
      torch.nn.Dropout(),
      torch.nn.Linear(in_features=1028, out_features=512),
      torch.nn.ReLU(),
      torch.nn.Dropout(),
      torch.nn.Linear(in_features=512, out_features=num_classes),
    )

  def forward(self, x):
    """Return logits of shape (N, num_classes) for an input batch x of shape (N, 1, 28, 28)."""
    conv_y = self.conv0_stack(x)

    return self.fc_stack(self.flatten_stack(conv_y))


model = CNN().to(device)
# torchsummary.summary prints the layer table itself; wrapping it in print() would only
# echo its return value. Pass `device` explicitly because some torchsummary versions
# default to a CUDA device for the dummy input.
torchsummary.summary(model=model, input_size=(1, 28, 28), batch_size=128, device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [128, 64, 28, 28]             640
              ReLU-2          [128, 64, 28, 28]               0
            Conv2d-3          [128, 64, 28, 28]          36,928
              ReLU-4          [128, 64, 28, 28]               0
         MaxPool2d-5          [128, 64, 14, 14]               0
       BatchNorm2d-6          [128, 64, 14, 14]             128
            Conv2d-7          [128, 64, 14, 14]          36,928
              ReLU-8          [128, 64, 14, 14]               0
            Conv2d-9          [128, 64, 14, 14]          36,928
             ReLU-10          [128, 64, 14, 14]               0
           Conv2d-11          [128, 64, 14, 14]          36,928
             ReLU-12          [128, 64, 14, 14]               0
        MaxPool2d-13            [128, 64, 7, 7]               0
      BatchNorm2d-14            [128, 6

Training

In [3]:
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from tqdm.notebook import tqdm
from src.common.plot import plotImagesProbBar
import matplotlib.pyplot as plt


# Both loaders use the same mini-batch size.
BATCH_SIZE = 128

# Training batches come in a fresh random order each epoch. The test loader keeps
# dataset order so per-sample results line up with `test_data` indices.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# Adam with default hyper-parameters over all model parameters; the loss takes raw logits.
optimizer = torch.optim.Adam(params=model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

# TensorBoard event writer using its default log directory.
writer = SummaryWriter()

def trainEpoch():
  """Run one optimisation pass over `train_loader`.

  Uses the module-level `model`, `optimizer`, `loss_function`, `train_loader` and
  `device`.

  Returns:
    (mean_loss, accuracy): the per-sample average training loss and the fraction of
    correctly classified training samples over the whole epoch.
  """
  # Make sure dropout is active and BatchNorm uses batch statistics, even if an
  # earlier evaluation left the model in eval mode.
  model.train()
  epoch_loss, epoch_correct = 0.0, 0
  n_samples = len(train_loader.dataset)

  with tqdm(train_loader, desc="Epoch", unit="batch", position=1, leave=False) as bar:
    for X, y in bar:
      X = X.to(device)
      y = y.to(device)
      pred = model(X)
      loss = loss_function(pred, y)
      # loss_function is expected to average over the batch (CrossEntropyLoss's
      # default 'mean' reduction). Weight it by the batch size so the running sum is
      # a per-sample total; the last batch may be smaller than the others.
      epoch_loss += loss.item() * X.size(0)
      epoch_correct += (pred.argmax(1) == y).sum().item()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

  return epoch_loss / n_samples, epoch_correct / n_samples


def evalEpoch():
  """Evaluate the module-level `model` on every batch of `test_loader`.

  The model is put into eval mode (dropout disabled, BatchNorm uses its running
  statistics) for the pass, and its previous train/eval mode is restored afterwards.
  No gradients are tracked.

  Returns:
    test_loss: per-sample cross-entropy, shape (N,).
    test_probs: softmax class probabilities, (N, C) for 2-D model outputs.
    test_preds: argmax class index per sample, shape (N,).
    test_labels: ground-truth labels, shape (N,).
  All results are in `test_loader` order.
  """
  was_training = model.training
  model.eval()
  batch_labels, batch_probs, batch_preds, batch_loss = [], [], [], []

  try:
    with torch.no_grad():
      for X, y in tqdm(test_loader, desc="Test", unit="batch", position=1, leave=False):
        X = X.to(device)
        y = y.to(device)
        logits = model(X)
        batch_probs.append(F.softmax(logits, dim=1))
        batch_preds.append(logits.argmax(1))
        batch_labels.append(y)
        # cross_entropy applies log_softmax internally, so it takes raw logits.
        batch_loss.append(F.cross_entropy(logits, y, reduction='none'))
  finally:
    model.train(was_training)

  test_probs = torch.cat(batch_probs, dim=0)
  test_loss = torch.cat(batch_loss)
  test_labels = torch.cat(batch_labels)
  test_preds = torch.cat(batch_preds)
  return test_loss, test_probs, test_preds, test_labels

NUM_EPOCHS = 20
# Images per logged figure; passed to plotImagesProbBar together with the "5, 3"
# layout arguments (presumably a 5x3 grid -- TODO confirm).
N_EXAMPLES = 15


def _log_examples(tag, idx, labels, preds, probs, step):
  """Plot the test images at positions `idx` with their label/prediction/probabilities
  and log the figure to TensorBoard under `tag` at `step`.

  `idx` indexes the test set in `test_loader` order. That order matches
  `test_data.data` because `test_loader` is not shuffled.
  """
  fig = plotImagesProbBar(
    # test_data.data lives on the CPU, so it must be indexed with CPU indices.
    test_data.data[idx.cpu()],
    labels[idx],
    preds[idx],
    probs[idx],
    labels_map, 5, 3
  )
  # add_figure closes the matplotlib figure by default, so figures do not pile up.
  writer.add_figure(tag, fig, step)


try:
  writer.add_graph(model, next(iter(test_loader))[0].to(device))

  with tqdm(range(1, NUM_EPOCHS + 1), desc="Training Epochs", unit="epoch", position=0) as bar:
    for epoch in bar:
      train_loss, train_acc = trainEpoch()
      test_loss, test_probs, test_preds, test_labels = evalEpoch()
      test_acc = (test_preds == test_labels).float().mean().item()
      writer.add_scalars("Loss", {'train_loss': train_loss, 'test_loss': test_loss.mean().item()}, epoch)
      writer.add_scalars("Accuracy", {'train_acc': train_acc, 'test_acc': test_acc}, epoch)

      # One-vs-rest precision/recall curve for each class.
      for c_i, class_name in labels_map.items():
        writer.add_pr_curve(class_name, test_labels == c_i, test_probs[:, c_i], epoch)

      # Test samples with the largest per-sample loss.
      worst_ix = torch.topk(test_loss, k=N_EXAMPLES).indices
      _log_examples("Worst Predictions", worst_ix, test_labels, test_preds, test_probs, epoch)

      # Test samples whose predicted distribution has the highest entropy, i.e. the
      # most uncertain predictions. Clamping keeps log(0) finite, so a class with
      # probability 0 contributes 0 instead of NaN (0 * -inf).
      log_probs = torch.log(test_probs.clamp(min=torch.finfo(test_probs.dtype).tiny))
      test_entropy = -(test_probs * log_probs).sum(dim=1)
      uncertain_ix = torch.topk(test_entropy, k=N_EXAMPLES).indices
      _log_examples("Highest entropy predictions", uncertain_ix, test_labels, test_preds, test_probs, epoch)

      bar.set_postfix_str(f"TrainLoss={train_loss:.03e} TestAcc={test_acc*100:.02f}%")
finally:
  # Flush and close the event file even if training is interrupted.
  writer.close()

Training Epochs:   0%|          | 0/20 [00:00<?, ?epoch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]

Epoch:   0%|          | 0/469 [00:00<?, ?batch/s]

Test:   0%|          | 0/79 [00:00<?, ?batch/s]