In [13]:

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [4]:
# locate the MNIST data
def load_fashion_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Resize(X_DIM),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_set = torchvision.datasets.FashionMNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=8,
        shuffle=False,
        num_workers=2
    )

    test_set = torchvision.datasets.FashionMNIST(
        root='./data',
        train=False,
        download=True,
        transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,
        shuffle=False
    )

    return train_loader, test_loader

In [5]:
train_loader, test_loader = load_fashion_dataloaders()
print(len(train_loader))
print(len(test_loader))

7500
10000


In [20]:
img_show, label_show = next(iter(test_loader))
label_show.item()

9

In [8]:
class CNN(nn.Module):

	def __init__(self):
		super(CNN, self).__init__()
		self.main = nn.Sequential(

			# input is Z, going into a convolution
			nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2),
			nn.BatchNorm2d(8),
			nn.ReLU(True),
			nn.Dropout2d(p=0.1),

			nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2),
			nn.BatchNorm2d(16),
			nn.ReLU(True),
			nn.Dropout2d(p=0.1),

			nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
			nn.BatchNorm2d(32),
			nn.ReLU(True),
			nn.Dropout2d(p=0.1),

			nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
			nn.BatchNorm2d(64),
			nn.ReLU(True),
			nn.Dropout2d(p=0.2),

			nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(True)
		)

		self.classifier = nn.Sequential(
			nn.Linear(128, 10),
		)

	def forward(self, x):
		x = self.main(x) #
		x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)  # GAP Layer
		logits = self.classifier(x)
		pred = F.softmax(logits, dim=1)

		return pred, logits, x


In [10]:
model = CNN()
model_checkpoint = torch.load('cnn_lr_2e_4_epoch50.pth', map_location='cpu')
model.load_state_dict(model_checkpoint)

model

CNN(
  (main): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout2d(p=0.1, inplace=False)
    (4): Conv2d(8, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Dropout2d(p=0.1, inplace=False)
    (8): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Dropout2d(p=0.1, inplace=False)
    (12): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU(inplace=True)
    (15): Dropout2d(p=0.2, inplace=False)
    (16): Conv2d(64, 128, kernel_size=(

In [26]:
# 10 types of clothes in FashionMNIST dataset
def output_label(label):
  output_mapping = {
      0: "T-shirt",
      1: "Trouser",
      2: "Pullover",
      3: "Dress",
      4: "Coat",
      5: "Sandal",
      6: "Shirt",
      7: "Sneaker",
      8: "Bag",
      9: "Ankle Boot"
  }

  input = (label.item() if type(label) == torch.Tensor else label)

  return output_mapping[input]

In [27]:
correct_predictions = 0
total_samples = 0
count = 0
incorrect_indices = []

# Use the model for inference on the test set
with torch.no_grad():  # Disable gradient calculation during inference
    for batch_idx, (images, labels) in enumerate(test_loader):
        outputs, _, _ = model(images)
        predicted = torch.argmax(outputs, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        
        if predicted != labels:
            
            incorrect_indices.append(batch_idx)
            
            image = images.squeeze()
            image = transforms.functional.to_pil_image(image)
            filename = f"true_label_{output_label(labels)}_predicted_label_{output_label(predicted)}.png"
            image.save(os.path.join('./yifan_data/incorrect_pred_images', filename))
            count += 1

# Calculate accuracy
accuracy = correct_predictions / total_samples
print(f'Test Accuracy: {accuracy * 100:.2f}%')
print("Number of incorrect predicted images: ", count)

Test Accuracy: 85.84%
Number of incorrect predicted images:  1416
