## Load Dataset

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
import gc

from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torch.optim import SGD, RMSprop, Adam
from torch import autograd
from sklearn.metrics import accuracy_score
from tqdm import tqdm

In [2]:
# Define the transforms: ToTensor scales pixels to [0, 1], Normalize then maps them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load Dataset
# The held-out split must use train = False; loading the training split twice would make
# every "validation" metric a measurement on data the model has already been fitted to.
train_dataset = FashionMNIST(root = './data', train = True, transform = transform, download = True)
test_dataset = FashionMNIST(root = './data', train = False, transform = transform, download = True)

In [3]:
# DataLoaders: batches of 128; only the training loader reshuffles between epochs
train_loader = DataLoader(dataset = train_dataset, batch_size = 128, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size = 128, shuffle = False)

In [4]:
# Mapping: integer class id (0-9) -> human-readable label, in class-id order
mapping = dict(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))

## Modeling

In [5]:
class ConvNet(nn.Module):
  """Small CNN classifier: three 3x3 'same'-padded conv layers (1 -> 48 -> 32 -> 16 channels),
  each followed by ReLU, then a flatten and a single linear layer producing 10 logits.

  The linear layer expects 16 * 28 * 28 features, i.e. 28x28 spatial inputs.
  """

  def __init__(self):
    super().__init__()
    # Every conv layer shares the same kernel size and padding mode
    conv_kwargs = dict(kernel_size = (3, 3), padding = 'same')
    self.seq = nn.Sequential(
        nn.Conv2d(1, 48, **conv_kwargs),
        nn.ReLU(),
        nn.Conv2d(48, 32, **conv_kwargs),
        nn.ReLU(),
        nn.Conv2d(32, 16, **conv_kwargs),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(in_features = 16 * 28 * 28, out_features = 10),
    )

  def forward(self, x_batch):
    """Run `x_batch` through the sequential stack and return the raw logits."""
    return self.seq(x_batch)

conv_net = ConvNet()

In [6]:
# Display the module tree of the ConvNet instance created in the previous cell
conv_net

ConvNet(
  (seq): Sequential(
    (0): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (3): ReLU()
    (4): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=12544, out_features=10, bias=True)
  )
)

## Training

In [7]:
def CalcValLoss(model, loss_func, val_loader):
  """Print and return the mean per-batch loss of `model` over `val_loader`.

  Args:
    model: callable mapping an input batch to predictions.
    loss_func: callable `(preds, targets) -> scalar tensor`.
    val_loader: iterable yielding `(X_batch, Y_batch)` pairs.

  Returns:
    The unweighted mean of the per-batch losses as a Python float
    (NaN when `val_loader` yields no batches).
  """
  val_losses = []
  with torch.no_grad(): # evaluation only: no autograd graph needed
    for X_batch, Y_batch in val_loader:
      preds = model(X_batch)
      # Store plain floats so no tensors are kept alive between batches
      val_losses.append(loss_func(preds, Y_batch).item())

  mean_loss = sum(val_losses) / len(val_losses) if val_losses else float('nan')
  # `.format` (not `, format`) so the placeholder is actually filled in
  print('Valid Categorical CrossEntropy : {:.4f}'.format(mean_loss))
  return mean_loss

def MakePredictions(model, loader):
  """Run `model` over every batch in `loader` and collect targets and predicted class ids.

  Args:
    model: callable mapping an input batch to per-class scores (last axis = classes).
    loader: iterable yielding `(X_batch, Y_batch)` pairs.

  Returns:
    `(Y_shuffled, preds)`: the concatenated targets in the order the loader produced them,
    and the argmax class id for each sample.
  """
  preds, Y_shuffled = [], []
  with torch.no_grad(): # inference only: avoid keeping an autograd graph per batch
    for X_batch, Y_batch in loader:
      preds.append(model(X_batch))
      Y_shuffled.append(Y_batch)

  preds = torch.cat(preds).argmax(dim = -1)
  Y_shuffled = torch.cat(Y_shuffled) # was `torch.act`, which does not exist

  return Y_shuffled, preds

def TrainModelInBatchesV1(model, loss_func, optimizer, train_loader, val_loader, epochs = 5):
  """Train `model` for `epochs` passes over `train_loader`, reporting metrics after each pass.

  After every epoch this prints the mean training loss, calls `CalcValLoss` on `val_loader`,
  and prints the accuracy of `MakePredictions` on `val_loader`.

  Args:
    model: the network to optimise.
    loss_func: callable `(preds, targets) -> scalar loss tensor`.
    optimizer: optimizer already bound to `model`'s parameters.
    train_loader: iterable of `(X_batch, Y_batch)` training batches.
    val_loader: iterable of `(X_batch, Y_batch)` evaluation batches.
    epochs: number of passes over `train_loader`.
  """
  for i in range(epochs):
    losses = []
    for X_batch, Y_batch in tqdm(train_loader):
      preds = model(X_batch) # Forward pass through the network to get predictions
      loss = loss_func(preds, Y_batch) # Compute the loss
      # Record a plain float; appending the tensor itself would keep a graph-attached
      # tensor alive for every batch until the end of the epoch
      losses.append(loss.item())
      optimizer.zero_grad() # Clear gradients accumulated by the previous step
      loss.backward() # Compute gradients
      optimizer.step() # Update the weights

    print('Train Categorical CrossEntropy : {:.4f}'.format(torch.tensor(losses).mean()))
    CalcValLoss(model, loss_func, val_loader)

    Y_test_shuffled, test_preds = MakePredictions(model, val_loader)
    val_acc = accuracy_score(Y_test_shuffled, test_preds)
    print('Val ACC : {:.4f}'.format(val_acc))
    gc.collect()

In [None]:
torch.manual_seed(42) # fix the RNG so weight initialisation is reproducible
epochs = 5
# Plain Python float: a tensor learning rate is not supported by every Adam implementation
learning_rate = 1e-3

conv_net = ConvNet()
cel = nn.CrossEntropyLoss()
optimizer = Adam(params = conv_net.parameters(), lr = learning_rate)

TrainModelInBatchesV1(conv_net, cel, optimizer, train_loader, test_loader, epochs)

 69%|██████▉   | 323/469 [02:41<00:46,  3.14it/s]

In [None]:
# # Save the trained model
# torch.save(conv_net, 'trained_model.pth')

# # Load the trained model
# conv_net = torch.load('trained_model.pth')

## Grad-CAM

### 1. Capture Output of Last Convolution Layer

In [None]:
# Show the first child module of conv_net (the container holding the layer stack)
next(conv_net.children())

In [None]:
class LastConvLayerModel(nn.Module):
  """Replays a trained network layer by layer and records the last conv layer's output.

  The layers are taken from the first child of `model` (for `ConvNet`, its `seq` container)
  and shared, not copied. After each forward pass `conv_layer_output` holds the activation
  produced by the last `nn.Conv2d` in that stack, still attached to the autograd graph so
  gradients with respect to it can be taken for Grad-CAM.

  Args:
    model: network whose first child is a container of layers. Defaults to the
      notebook-level `conv_net` so existing `LastConvLayerModel()` calls keep working.

  Raises:
    ValueError: if the layer stack contains no `nn.Conv2d`.
  """

  def __init__(self, model = None):
    super(LastConvLayerModel, self).__init__()
    if model is None:
      model = conv_net
    # ModuleList (rather than a plain list) registers the layers as submodules
    self.layers = nn.ModuleList(list(model.children())[0].children())

    conv_indices = [i for i, layer in enumerate(self.layers) if isinstance(layer, nn.Conv2d)]
    if not conv_indices:
      raise ValueError('model contains no nn.Conv2d layer to capture')
    # Locate the last conv layer instead of hard-coding its position in the stack
    self.last_conv_idx = conv_indices[-1]
    self.conv_layer_output = None # populated by forward()

  def forward(self, X_batch):
    """Return the network output for `X_batch`, storing the last conv activation on the way."""
    x = X_batch
    for i, layer in enumerate(self.layers):
      x = layer(x)
      if i == self.last_conv_idx: # output of the last convolution, before its ReLU
        self.conv_layer_output = x

    return x

In [None]:
# Build model inputs from the raw uint8 images and apply the same preprocessing as the
# training transform (ToTensor scaling to [0, 1], then Normalize with mean 0.5 / std 0.5);
# otherwise the network sees 0-255 values it was never trained on.
X_test = test_dataset.data.reshape(-1, 1, 28, 28).type(torch.float32)
X_test = (X_test / 255.0 - 0.5) / 0.5
Y_test = test_dataset.targets

conv_model = LastConvLayerModel()
idx = np.random.choice(range(len(X_test))) # random sample index, valid for any dataset size
pred = conv_model(X_test[idx:idx+1]) # slice keeps the batch dimension

# Predicted class id and its softmax probability
probs = F.softmax(pred, dim=-1)
probs.argmax(), probs.max()

In [None]:
# Size of the activation captured from the last convolution layer
conv_model.conv_layer_output.size()

In [None]:
# Translate the true and predicted class ids into readable labels
print('Actual Target : {}'.format(mapping[int(Y_test[idx])]))
print('Predicted Target : {}'.format(mapping[int(pred.argmax(dim = -1))]))

### 2. Take Gradients of Last Convolution Layer Output with Respect to Prediction

In [None]:
# Gradient of the top-scoring class logit with respect to the captured conv activation
grads = autograd.grad(
    outputs = pred[:, int(pred.argmax())],
    inputs = conv_model.conv_layer_output,
)

grads[0].size()

### 3. Average Gradients

In [None]:
# Average over every axis except axis 1, leaving one weight per channel
pooled_grads = torch.mean(grads[0], dim = (0, 2, 3))

pooled_grads

### 4. Multiply Convolution Layer Output with Averaged Gradients

In [None]:
# Drop the singleton axes from the captured activation and keep only its positive part
conv_output = F.relu(torch.squeeze(conv_model.conv_layer_output))

conv_output.size()

In [None]:
# Scale each channel's activation map by its pooled gradient.
# This is recomputed from the captured activation with a broadcast multiply rather than an
# in-place loop, so re-running the cell does not apply the weights a second time.
conv_output = F.relu(conv_model.conv_layer_output.squeeze(0)) * pooled_grads.view(-1, 1, 1)

conv_output.shape

### 5. Average Output at Channel axis to Create Heatmap

In [None]:
# Collapse the channel axis into a single 2-D map
heatmap = conv_output.mean(dim = 0).squeeze()

# heatmap normalize: keep positive evidence only, then scale to [0, 1].
# Skip the division when nothing is positive; dividing by a max <= 0 would yield NaN or
# sign-flipped values instead of an all-zero map.
heatmap = F.relu(heatmap)
peak = heatmap.max()
if peak > 0:
  heatmap = heatmap / peak

heatmap.shape

### 6. Visualize Original Image and Heatmap

In [None]:
def plot_actual_and_heatmap(idx, heatmap):
  """Show the input image at `X_test[idx]` next to its Grad-CAM heatmap.

  Args:
    idx: index into the notebook-level `X_test` tensor.
    heatmap: 2-D array-like to render with the 'Reds' colormap; must not require grad.
  """
  # The colormap is passed to imshow by name below; the former `matplotlib.cm.get_cmap`
  # call was unused, and that function has been removed from recent Matplotlib releases.
  fig = plt.figure(figsize = (10, 10))

  ax1 = fig.add_subplot(121)
  ax1.imshow(X_test[idx].numpy().squeeze(), cmap = 'gray')
  ax1.set_title('Actual')
  ax1.set_xticks([]); ax1.set_yticks([]) # hide the axis ticks

  ax2 = fig.add_subplot(122)
  ax2.imshow(heatmap, cmap = 'Reds')
  ax2.set_title('Gradients')
  ax2.set_xticks([]); ax2.set_yticks([])

# detach() so matplotlib can convert the tensor without an autograd graph attached
plot_actual_and_heatmap(idx, heatmap.detach())