In [None]:
import random

import matplotlib.pyplot as plt
import torch
import torch.utils.checkpoint
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm.auto import tqdm

# Print the installed PyTorch / torchvision versions so a run can be reproduced later.
print(f"PyTorch version: {torch.__version__}\ntorchvision version: {torchvision.__version__}")

In [None]:
# FashionMNIST stored under ./data (downloaded on first use), with every image
# converted to a tensor by ToTensor(). One dataset object per split.
train_data = datasets.FashionMNIST(
    root='data',
    train=True,
    transform=ToTensor(),
    download=True,
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    transform=ToTensor(),
    download=True,
)

In [None]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 32

# Only the training split is shuffled; the test split keeps a fixed order.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Display the dataset's class-name list (the plotting helpers index it by integer label).
train_data.classes

In [None]:
def visualizeData(data):
    """Show one randomly chosen sample from ``data``, titled with its class name.

    Args:
        data: Indexable dataset whose items are ``(image, label)`` pairs and
            which exposes a ``classes`` sequence mapping label -> class name.
            The image must provide ``squeeze()`` (e.g. a tensor).

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    if len(data) == 0:
        raise ValueError("visualizeData() needs a non-empty dataset")

    # randrange() excludes the upper bound. randint(0, len(data)) is inclusive
    # on both ends and could return len(data), indexing one past the end.
    idx = random.randrange(len(data))
    img, label = data[idx]

    plt.imshow(img.squeeze())
    plt.title(data.classes[label])
    plt.show()

In [None]:
# Display the shape of the image element of the first training sample.
train_data[0][0].shape

In [None]:
class SampleModel(nn.Module):
    """Six-conv CNN classifier with optional activation checkpointing.

    The network is three stages of two ``Conv2d -> GELU -> BatchNorm2d`` units,
    with a 2x2 max-pool after each of the first two stages, followed by
    ``Flatten`` and a ``Linear(490, 10)`` head. The head size (10 channels x
    7 x 7) matches single-channel 28x28 inputs.

    The layer order and the ``block1`` attribute name are part of the saved
    ``state_dict`` layout; do not reorder them.

    Args:
        scale: Integer width multiplier applied to every hidden channel count.
        use_gradient_checkpoint: If True, run ``block1`` through
            ``checkpoint_sequential`` whenever gradients are being recorded,
            trading recomputation for lower activation memory.
        base_channels: Channel count of the first convolution before scaling
            (hidden widths are 1x, 2x and 4x ``base_channels * scale``).
        checkpoint_segments: Number of segments passed to
            ``checkpoint_sequential``.

    NOTE(review): checkpointing re-executes the forward of each checkpointed
    segment during backward, so BatchNorm running statistics are presumably
    updated more than once per step while it is active - confirm if exact
    running stats matter.
    """

    def __init__(self, scale=2, use_gradient_checkpoint=False, base_channels=128, checkpoint_segments=10):
        super().__init__()
        narrow = base_channels * scale   # 128*scale with the defaults
        medium = 2 * narrow              # 256*scale
        wide = 4 * narrow                # 512*scale

        self.block1 = nn.Sequential(
            # Stage 1 (full resolution)
            nn.Conv2d(1, narrow, 3, 1, 1),
            nn.GELU(),
            nn.BatchNorm2d(narrow),
            nn.Conv2d(narrow, medium, 3, 1, 1),
            nn.GELU(),
            nn.BatchNorm2d(medium),
            nn.MaxPool2d(2),
            # Stage 2 (1/2 resolution)
            nn.Conv2d(medium, medium, 3, 1, 1),
            nn.GELU(),
            nn.BatchNorm2d(medium),
            nn.Conv2d(medium, wide, 3, 1, 1),
            nn.GELU(),
            nn.BatchNorm2d(wide),
            nn.MaxPool2d(2),
            # Stage 3 (1/4 resolution), squeezed down to 10 channels
            nn.Conv2d(wide, wide, 3, 1, 1),
            nn.GELU(),
            nn.BatchNorm2d(wide),
            nn.Conv2d(wide, 10, 3, 1, 1),
            nn.GELU(),
            nn.BatchNorm2d(10),
            # Head: 10 channels * 7 * 7 = 490 features -> 10 logits
            nn.Flatten(),
            nn.Linear(490, 10),
        )
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.checkpoint_segments = checkpoint_segments

    def forward(self, x):
        """Return the 10 class logits for a batch ``x``."""
        # Checkpointing only helps when a backward pass will follow. Under
        # no_grad()/inference_mode() nothing is stored for backward anyway,
        # so take the plain path and skip the checkpoint machinery.
        if not self.use_gradient_checkpoint or not torch.is_grad_enabled():
            return self.block1(x)

        return torch.utils.checkpoint.checkpoint_sequential(
            self.block1, self.checkpoint_segments, x, use_reentrant=False
        )
# Build the network at its default width with activation checkpointing enabled.
model = SampleModel(scale=2, use_gradient_checkpoint=True)
# Optional per-layer summary for one batch of inputs (left disabled):
#summary(model, (BATCH_SIZE, 1, 28, 28))

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move the model to that device; as the cell's last expression this also
# displays the module.
model.to(device)

In [None]:
# Adam over all model parameters with a 1e-3 learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Cross-entropy loss for the classification objective.
loss_fn = torch.nn.CrossEntropyLoss()
# Multiply the learning rate by 0.1 once the minimised metric passed to
# scheduler.step() has stopped improving for longer than the patience of 5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)
def accuracy_fn(y_pred, y_true):
    """Return the percentage of predictions that equal their targets.

    Args:
        y_pred: Tensor (or array-like) of predicted class indices.
        y_true: Tensor (or array-like) of ground-truth class indices with the
            same shape as ``y_pred``.

    Returns:
        A 0-dim floating-point tensor holding the accuracy in percent.

    Raises:
        ValueError: If the shapes differ or the inputs are empty.
    """
    y_pred = torch.as_tensor(y_pred)
    y_true = torch.as_tensor(y_true)

    # Guard against silent broadcasting (e.g. (N,) vs (N, 1) would compare
    # every prediction with every target).
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"shape mismatch: y_pred {tuple(y_pred.shape)} vs y_true {tuple(y_true.shape)}"
        )
    total = y_true.numel()
    if total == 0:
        raise ValueError("accuracy_fn() needs at least one prediction")

    # One vectorised comparison instead of a Python loop over elements.
    correct = (y_pred == y_true).sum()
    return 100.0 * correct / total

In [None]:
def train(model, train_dataloader, test_dataloader, device, epochs=1, save=False, savefilename='model.pth'):
    """Train ``model`` for ``epochs`` epochs, evaluating on the test loader after each one.

    Relies on the notebook-level ``optimizer``, ``loss_fn``, ``scheduler``,
    ``accuracy_fn`` and ``tqdm`` names. ``loss_fn`` is assumed to return the
    mean loss over the batch; the per-sample weighting below depends on that.

    Args:
        model: Module to optimise (already placed on ``device``).
        train_dataloader: Iterable of ``(X, y)`` training batches.
        test_dataloader: Iterable of ``(X, y)`` evaluation batches.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``train_dataloader``.
        save: If True, write ``model.state_dict()`` to ``savefilename``
            whenever the test loss reaches a new minimum.
        savefilename: Destination path of the checkpoint.

    Returns:
        None. Progress is reported via ``print``.
    """
    best_test_loss = float('inf')
    for epoch in range(epochs):
        print(f'Epoch: {epoch}')

        # ---- optimisation pass ----
        model.train()
        train_loss = 0.0
        train_samples = 0
        for (X, y) in tqdm(train_dataloader):
            X = X.to(device)
            y = y.to(device)

            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() yields a plain float. Summing the loss tensor itself
            # would keep it attached to autograd for the whole epoch.
            batch_size = len(y)
            train_loss += loss.item() * batch_size
            train_samples += batch_size
        # Sample-weighted mean, so a smaller final batch is not over-counted.
        train_loss /= max(train_samples, 1)
        print(f'Train loss: {train_loss}')
        # The plateau scheduler is driven by the epoch's training loss.
        scheduler.step(train_loss)

        # ---- evaluation pass ----
        model.eval()
        test_acc = 0.0
        test_loss = 0.0
        test_samples = 0
        with torch.inference_mode():
            for (X, y) in test_dataloader:
                X = X.to(device)
                y = y.to(device)

                y_pred = model(X)
                batch_size = len(y)
                test_acc += float(accuracy_fn(torch.argmax(y_pred, dim=1), y)) * batch_size
                test_loss += loss_fn(y_pred, y).item() * batch_size
                test_samples += batch_size
        test_acc /= max(test_samples, 1)
        test_loss /= max(test_samples, 1)

        # Keep only the weights with the lowest test loss seen so far.
        if save and test_loss < best_test_loss:
            best_test_loss = test_loss
            torch.save(model.state_dict(), savefilename)
            print("Saved")
        print(f'Test loss: {test_loss}, Test acc: {test_acc:.2f}%')


In [None]:
# Run a single epoch; the best-scoring weights are written to the default checkpoint file.
train(model, train_dataloader, test_dataloader, device, epochs=1, save=True)

In [None]:
def visualizeModelPredict(model, data):
    """Plot a 2x5 grid of random samples titled with predicted vs. true class.

    Titles are green when the prediction matches the label and red otherwise.
    The model is left in eval mode. Inputs are moved to the notebook-level
    ``device`` before the forward pass.

    Args:
        model: Classifier returning per-class scores for a batch of images.
        data: Indexable dataset of ``(image_tensor, label)`` pairs exposing a
            ``classes`` sequence mapping label -> class name.

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    if len(data) == 0:
        raise ValueError("visualizeModelPredict() needs a non-empty dataset")

    rows, cols = 2, 5
    _, axs = plt.subplots(rows, cols, figsize=(15, 6))

    # Switch to eval mode and disable autograd once, not once per tile.
    model.eval()
    with torch.inference_mode():
        for i in range(rows):
            for j in range(cols):
                # randrange() excludes len(data); randint(0, len(data)) could
                # return it and index one past the end.
                idx = random.randrange(len(data))
                img, label = data[idx]  # fetch the sample only once

                scores = model(img.unsqueeze(dim=0).to(device))
                pred = data.classes[torch.argmax(scores, dim=1).item()]
                truth = data.classes[label]

                ax = axs[i, j]
                ax.imshow(img.squeeze(), cmap='gray')
                ax.set_title(
                    'Pred:'+pred+', Truth:'+truth,
                    color='green' if pred == truth else 'red',
                )
                ax.axis('off')
    plt.show()

In [None]:
# Plot predictions of the trained model on random test-set samples.
visualizeModelPredict(model=model, data=test_data)