In [None]:
import torch as pt
pt.manual_seed(42)

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if pt.cuda.is_available():
    device = pt.device("cuda:0")
else:
    device = pt.device("cpu")
device

In [None]:
# Per-sample preprocessing: convert the image to a tensor, then move that
# tensor onto `device`, so samples leave the dataset already on the target
# device.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.to(device))])


# Training split of Fashion-MNIST, downloaded into ./fmnist when missing.
train_ds = torchvision.datasets.FashionMNIST('./fmnist', 
                                              download = True, 
                                              train = True,
                                              transform = transform)

# Shuffle so every epoch visits the samples in a fresh random order (the
# order is still reproducible through pt.manual_seed).
# num_workers is deliberately left at its default of 0: the transform above
# touches `device` inside the loading code, which is presumably why worker
# processes were disabled -- move the device transfer into the training loop
# before enabling workers.
train_dl = pt.utils.data.DataLoader(train_ds,
                                    batch_size=4,
                                    shuffle=True)

In [None]:
# Held-out split of Fashion-MNIST; shares the download directory and the
# preprocessing pipeline with the training split.
test_ds = torchvision.datasets.FashionMNIST(
    root='./fmnist',
    train=False,
    transform=transform,
    download=True,
)

# Evaluation batches are served in dataset order (no shuffling).
test_dl = pt.utils.data.DataLoader(dataset=test_ds, batch_size=4)

In [None]:
# Human-readable names, ordered so that CLASSES[label] names a sample's class.
CLASSES = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

# Peek at the very first training sample.
idx = 0
image, label = train_ds[idx]
idx, image.shape, label

In [None]:
# Bring the sample back to host memory and drop singleton axes so matplotlib
# receives a plain 2-D array.
pixels = image.cpu().squeeze().numpy()
plt.imshow(pixels, cmap='binary')
plt.xlabel(CLASSES[label]);

In [None]:
import numpy as np

# Dedicated, seeded generator: the preview grid shows the same 25 samples on
# every Restart & Run All, without consuming the torch RNG stream that the
# model initialisation relies on.
rng = np.random.default_rng(42)

plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    # Must be a real boolean: the string 'off' is truthy and would switch
    # the grid ON in current matplotlib releases.
    plt.grid(False)
    # int(...) keeps the index a plain Python int, as before.
    idx = int(rng.integers(0, len(train_ds)))
    image, label = train_ds[idx]
    plt.imshow(image.cpu().squeeze().numpy(), cmap='binary')
    plt.xlabel(CLASSES[label])

In [None]:
from torch import nn

class Lambda(nn.Module):
  """Adapter that lets an arbitrary callable sit inside an nn.Sequential."""

  def __init__(self, fn):
    """Store `fn`, the callable applied to every input in `forward`."""
    super().__init__()
    self.fn = fn

  def forward(self, x):
    """Return `fn(x)` unchanged."""
    result = self.fn(x)
    return result

# Small CNN classifier: two 3x3 convolutions, one 2x2 max-pool, then a
# two-layer fully connected head producing 10 class scores.
# Shape comments are channels x height x width per sample and assume
# 1x28x28 inputs (consistent with the 12 * 12 * 64 flatten size below).
model = nn.Sequential(
  nn.Conv2d(1, 32, 3), # 1x28x28 -> 32x26x26
  nn.ReLU(),
  nn.Conv2d(32, 64, 3), # 32x26x26 -> 64x24x24
  nn.ReLU(),
  nn.MaxPool2d(2, 2), # 64x24x24 -> 64x12x12
  nn.Dropout2d(), # channel-wise dropout: input is still a 4-D feature map
  Lambda(lambda x: x.view(-1, 12 * 12 * 64)), # flatten to (batch, 9216)
  nn.Linear(12 * 12 * 64, 128),
  nn.ReLU(),
  # Plain element-wise dropout: activations are 2-D (batch, 128) here, and
  # feeding 2-D input to Dropout2d is deprecated in PyTorch.
  nn.Dropout(),
  nn.Linear(128, 10)
).to(device)

In [None]:
# model = Net().to(device)
def forward(X):
  """Run the batch `X` through the module-level `model` and return its output."""
  output = model(X)
  return output

def loss(y_pred, y):
  """Cross-entropy between the predictions `y_pred` and the targets `y`.

  Delegates to `torch.nn.functional.cross_entropy` with its default settings.
  """
  return pt.nn.functional.cross_entropy(input = y_pred, target = y)

# AdamW over every parameter of the model, with the optimizer's default
# hyper-parameters.
optimizer = pt.optim.AdamW(params = model.parameters())

In [None]:
class ConfusionMatrix():
  """Running confusion matrix over a fixed list of classes.

  Orientation: `values[p, t]` counts samples that were *predicted* as class
  `p` while their *true* class was `t` (rows = predictions, columns = truth).
  The per-row ratio reported by this class is therefore
  correct / number-of-predictions for each class.

  Attributes:
    classes: sequence of class names; position i names class index i.
    side:    number of classes.
    values:  (side, side) float array of counts.
    total:   number of samples recorded so far.
  """

  def __init__(self, classes):
    self.classes = classes
    self.side = len(classes)
    self.values = np.zeros( (self.side, self.side) )
    self.total = int(0)

  def _row_ratios(self):
    """Return diagonal / row-sum per class, NaN where a row has no counts."""
    row_totals = self.values.sum(axis = 1)
    ratios = np.full(self.side, np.nan)
    # `where` skips empty rows, so no 0/0 is evaluated and no RuntimeWarning
    # is emitted; skipped entries keep their NaN placeholder.
    np.divide(self.values.diagonal(), row_totals, out = ratios, where = row_totals > 0)
    return ratios

  def average_accuracy(self):
    """Mean of the per-row ratios over classes that have at least one count.

    Rows without any recorded prediction are left out instead of turning the
    whole average into NaN. Returns NaN only when nothing has been recorded.
    """
    ratios = self._row_ratios()
    observed = ~np.isnan(ratios)
    if not observed.any():
      return np.nan
    return ratios[observed].mean()

  def __call__(self, y_pred, y):
    """Record one batch of (predicted, true) class-index pairs.

    Args:
      y_pred: sized iterable of predicted class indices.
      y:      sized iterable of true class indices, same length as `y_pred`.

    Returns:
      self, so calls can be chained or displayed directly.

    Raises:
      ValueError: if `y_pred` and `y` differ in length (zip would otherwise
        silently drop the surplus entries while `total` still counted them).
    """
    if len(y_pred) != len(y):
      raise ValueError(
        "y_pred and y must have the same length, got {} and {}".format(len(y_pred), len(y)))
    for row, col in zip(y_pred, y):
      self.values[row, col] += 1
    self.total += len(y_pred)
    return self

  def update(self, y_pred, y):
    """Alias for calling the instance: record a batch and return self."""
    return self.__call__(y_pred, y)

  def __repr__(self):
    """One `name: xx.xx% ` entry per class; `n/a` for classes with no counts."""
    parts = []
    for name, ratio in zip(self.classes, self._row_ratios()):
      if np.isnan(ratio):
        parts.append("{}: n/a ".format(name))
      else:
        parts.append("{}: {:.2f}% ".format(name, 100.0 * ratio))
    return "".join(parts)


# Two-class smoke test: both samples are predicted correctly, so each class
# should report 100%.
TOY_CLASSES = ['Hot dog', 'Not hot dog']
confusion_matrix = ConfusionMatrix(TOY_CLASSES)
confusion_matrix.update(y_pred = [0, 1], y = [0, 1])

In [None]:
# ConfusionMatrix stores counts as values[predicted, true], so image rows
# (y axis) are predictions and image columns (x axis) are true classes.
# Label the axes explicitly because this is the transpose of the more common
# layout.
plt.imshow(confusion_matrix.values, cmap = 'binary')
plt.xticks(np.arange(len(TOY_CLASSES)), TOY_CLASSES, rotation = 90)
plt.yticks(np.arange(len(TOY_CLASSES)), TOY_CLASSES)
plt.xlabel('True class')
plt.ylabel('Predicted class')
plt.colorbar();
confusion_matrix.average_accuracy()

In [None]:
# Train for EPOCHS passes over train_dl, logging every LOG_EVERY batches.
# NOTE: training_cm is never reset, so its counts (and the metric printed
# below) accumulate over all epochs rather than describing a single epoch.
training_cm = ConfusionMatrix(CLASSES)
EPOCHS = 5
LOG_EVERY = 1000
for epoch in range(EPOCHS):
  # Loss accumulated since the previous log line. Averaging over the window
  # is far less noisy than reporting one 4-sample batch.
  running_loss = 0.0
  batches_since_log = 0
  for batch_idx, (X_batch, y_batch) in enumerate(train_dl):
      # Explicit transfer: a no-op when the tensors are already on `device`,
      # and keeps the loop correct if the dataset yields CPU tensors.
      X_batch = X_batch.to(device)
      y_batch = y_batch.to(device)

      y_pred_batch = forward(X_batch)
      xe = loss(y_pred_batch, y_batch)

      training_cm.update(y_pred_batch.argmax(dim = 1).cpu().detach().numpy(),
                          y_batch.cpu().detach().numpy())

      # .item() yields a plain Python float that is detached from the graph.
      running_loss += xe.item()
      batches_since_log += 1
      if batch_idx % LOG_EVERY == 0:
        print("Loss: ", running_loss / batches_since_log, " Metric: ", training_cm.average_accuracy())
        running_loss = 0.0
        batches_since_log = 0

      xe.backward()

      optimizer.step()
      # Clear gradients so the next batch starts from zero.
      optimizer.zero_grad()

  print(training_cm)

In [None]:
# ConfusionMatrix stores counts as values[predicted, true], so image rows
# (y axis) are predictions and image columns (x axis) are true classes.
plt.imshow(training_cm.values, cmap = 'binary')
plt.xticks(np.arange(len(CLASSES)), CLASSES, rotation = 90)
plt.yticks(np.arange(len(CLASSES)), CLASSES)
plt.xlabel('True class')
plt.ylabel('Predicted class')
plt.colorbar();

Copyright 2021 CounterFactual.AI LLC. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.