In [None]:
import torch
import torchvision
from torchvision.transforms.functional import to_tensor
import torch.utils.data as D
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

# Prefer the first GPU, but fall back to CPU so the notebook still runs
# end-to-end (Restart & Run All) on machines without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# dtype used for the dataset tensors and the model parameters.
DTYPE = torch.bfloat16

In [None]:
def transform(x, device=None, dtype=None):
    """Flatten a batch of images and scale pixel values by 1/255.

    Args:
        x: Image tensor whose first dimension is the batch; all remaining
            dimensions are flattened into one (e.g. (N, H, W) -> (N, H*W)).
            Assumes raw pixel values in 0..255 — TODO confirm for other sources.
        device: Target device; defaults to the notebook-wide ``DEVICE``.
        dtype: Target dtype; defaults to the notebook-wide ``DTYPE``.

    Returns:
        Tensor of shape (N, prod(x.shape[1:])) on ``device`` with ``dtype``.
    """
    # Resolve defaults at call time so the globals need not exist at definition time.
    device = DEVICE if device is None else device
    dtype = DTYPE if dtype is None else dtype
    x = x.to(device=device, dtype=dtype)
    x = torch.flatten(x, start_dim=1) / 255.
    return x

def target_transform(y, device=None, dtype=None, num_classes=10):
    """One-hot encode integer class labels and move them to the target device/dtype.

    Args:
        y: Integer (int64) label tensor.
        device: Target device; defaults to the notebook-wide ``DEVICE``.
        dtype: Target dtype; defaults to the notebook-wide ``DTYPE``.
        num_classes: Width of the one-hot encoding (10 by default).

    Returns:
        Tensor of shape ``y.shape + (num_classes,)`` on ``device`` with ``dtype``.
    """
    # Resolve defaults at call time so the globals need not exist at definition time.
    device = DEVICE if device is None else device
    dtype = DTYPE if dtype is None else dtype
    y = F.one_hot(y, num_classes=num_classes)
    y = y.to(device=device, dtype=dtype)
    return y


In [None]:
# Load both official FashionMNIST splits from ./FMNIST, downloading if missing.
mnist_train, mnist_test = (
    torchvision.datasets.FashionMNIST("./FMNIST", train=is_train, download=True)
    for is_train in (True, False)
)


In [None]:
# Pool the official train and test splits (images and labels in the same
# order) so they can be re-split with a fixed seed further down.
mnist_data = torch.cat([mnist_train.data, mnist_test.data], dim=0)
mnist_labels = torch.cat([mnist_train.targets, mnist_test.targets])

In [None]:
class MNISTDataset(D.Dataset):
    """In-memory dataset over pre-loaded sample and label tensors.

    Both tensors are cloned at construction so later in-place edits to the
    caller's tensors never leak into the dataset (previously the clone was
    discarded whenever a transform was not supplied). Optional transforms are
    applied once, eagerly, to the whole tensor rather than per item.

    Args:
        data: Tensor of samples, indexed along its first dimension.
        targets: Tensor of class labels, one per sample.
        transform: Optional callable applied once to the full ``data`` tensor.
        target_transform: Optional callable applied once to the full ``targets`` tensor.

    Attributes:
        num_classes: Number of distinct label values present in ``targets``
            (counted before ``target_transform``; a class absent from the
            data is not counted).
        shape: Shape of the stored (possibly transformed) data tensor.
    """

    def __init__(self,
                 data,
                 targets,
                 transform=None,
                 target_transform=None):
        super().__init__()
        data = data.clone()
        targets = targets.clone()
        # Counted on the raw labels, before any one-hot style transform.
        self.num_classes = len(torch.unique(targets))
        self.data = transform(data) if transform is not None else data
        self.targets = target_transform(targets) if target_transform is not None else targets
        self.shape = self.data.shape

    def __len__(self):
        """Number of samples (size of the first data dimension)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair at ``idx``."""
        return (self.data[idx], self.targets[idx])

# Full pooled dataset with transforms applied eagerly to every sample/label.
MNIST_data = MNISTDataset(
    data=mnist_data,
    targets=mnist_labels,
    transform=transform,
    target_transform=target_transform,
)

In [None]:
# Dedicated, fixed-seed generator so the 80/10/10 split is reproducible.
rand_gen = torch.Generator().manual_seed(192)
split_fractions = [0.8, 0.1, 0.1]
train, val, test = D.random_split(MNIST_data, split_fractions, generator=rand_gen)

In [None]:
# Only the training split is shuffled; val/test are read in a fixed order.
train_loader = D.DataLoader(train, batch_size=1024, shuffle=True)
val_loader = D.DataLoader(val, batch_size=1024)
# NOTE(review): batch size presumably chosen to exceed the test split so it
# arrives as a single batch — confirm if the dataset size changes.
test_loader = D.DataLoader(test, batch_size=1_000_000)

In [None]:
# Features per flattened sample and number of distinct labels, read off the dataset.
INPUT_SHAPE, NUM_CLASSES = MNIST_data.shape[1], MNIST_data.num_classes

In [None]:
class Model(nn.Module):
    """Single linear layer classifier with dropout on the input features.

    ``forward`` returns raw, unnormalised logits. ``nn.CrossEntropyLoss``
    applies log-softmax internally, so a trailing ``nn.Softmax`` would
    normalise twice and squash the gradients. Use ``predict`` when class
    probabilities are wanted.

    Args:
        input_dim: Number of features per sample after flattening.
        output_dim: Number of classes (length of the logit vector).
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int):
        super().__init__()
        self._input_dim = input_dim
        self._output_dim = output_dim
        # Collapses every dimension after the batch dimension into one.
        self._flatten = nn.Flatten(start_dim=1, end_dim=-1)
        # Indices 0 (Dropout) and 1 (Linear) match the previous layout, so
        # existing state_dict keys (``_linear_stack.1.*``) still load.
        self._linear_stack = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, output_dim)``."""
        return self._linear_stack(self._flatten(x))

    @torch.inference_mode()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax class probabilities without tracking gradients.

        Does not change train/eval mode; call ``model.eval()`` first so that
        dropout is disabled.
        """
        return torch.softmax(self.forward(x), dim=-1)

In [None]:
from math import ceil
from IPython.display import clear_output
from torchmetrics import Accuracy

def _progress_bar(step: int, total: int, width: int = 20) -> str:
    """Render a fixed-width text bar such as ``=====-----`` for batch ``step`` of ``total``."""
    num_bars = int((step / total) * width) + 1
    return "=" * num_bars + "-" * (width - num_bars)


def train_model(
    model: nn.Module,
    train_loader: D.DataLoader,
    val_loader: D.DataLoader,
    optimizer,
    criterion,
    epochs=100):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    Each batch is expected to be ``(inputs, one_hot_labels)``; labels are
    turned into class indices with ``argmax(dim=1)`` for the accuracy metrics.
    A running loss/accuracy line is printed in place during each epoch,
    followed by an epoch summary.

    Args:
        model: Module to optimise. Put in train mode for the training pass and
            left in eval mode after each validation pass.
        train_loader: Loader yielding training batches.
        val_loader: Loader yielding validation batches.
        optimizer: Optimiser over ``model``'s parameters.
        criterion: Loss callable, ``criterion(preds, labels) -> scalar tensor``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        dict of per-epoch lists: ``"train"`` / ``"val"`` hold the mean batch
        losses, ``"train_acc"`` / ``"val_acc"`` hold the accuracies.
    """
    history = {"train": [], "val": [], "train_acc": [], "val_acc": []}
    # len(loader) is the real number of batches per epoch: it honours
    # drop_last and custom batch samplers (where loader.batch_size is None).
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    mca_train = Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device=DEVICE)
    mca_val = mca_train.clone()

    for epoch in range(epochs):
        model.train()
        sum_train_loss = 0.0
        for step, (train_batch, train_labels) in enumerate(train_loader):
            # Forward pass and loss (graph kept for backprop).
            preds = model(train_batch)
            train_loss = criterion(preds, train_labels)

            # Metric bookkeeping needs no autograd graph.
            # NOTE(review): the reason for the float16 cast is not visible here
            # (possibly bfloat16 support in the metric) — confirm before removing.
            with torch.inference_mode():
                mca_train.update(
                    preds.to(torch.float16),
                    torch.argmax(train_labels, dim=1),
                )
                sum_train_loss += train_loss.item()

            print(
                "Epoch: {} \t [{}] \t Train loss: {:.3f} \t Train acc: {:.3f}".format(
                    epoch,
                    _progress_bar(step, num_train_batches),
                    sum_train_loss / (step + 1),
                    mca_train.compute().item(),
                ),
                end="\r",
            )

            # Backpropagate.
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # max(..., 1) avoids ZeroDivisionError when a loader yields no batches.
        avg_train_loss = sum_train_loss / max(num_train_batches, 1)

        # Validation: eval mode disables dropout (and batch-norm statistic updates).
        model.eval()
        with torch.inference_mode():
            sum_val_loss = 0.0
            for val_batch, val_labels in val_loader:
                preds = model(val_batch)
                sum_val_loss += criterion(preds, val_labels).item()
                mca_val.update(
                    preds.to(torch.float16),
                    torch.argmax(val_labels, dim=1),
                )
            avg_val_loss = sum_val_loss / max(num_val_batches, 1)
            train_acc = mca_train.compute().item()
            val_acc = mca_val.compute().item()

        print("\nAvg Train Loss: {:.3f} \t Avg Val Loss: {:.3f} \t "
              "Avg Train Acc: {:.3f} \t Avg Val Acc: {:.3f}".format(
                  avg_train_loss, avg_val_loss, train_acc, val_acc))
        print("-" * 140, "\n")

        # Metrics accumulate across updates, so clear them for the next epoch.
        mca_train.reset()
        mca_val.reset()

        history["train"].append(avg_train_loss)
        history["val"].append(avg_val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

    return history

In [None]:
# Seed *before* building the model so that weight initialisation, as well as
# the dropout masks and loader shuffling that follow, are reproducible.
torch.manual_seed(192)
model = Model(INPUT_SHAPE, NUM_CLASSES)
model.to(device=DEVICE, dtype=DTYPE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

hist = train_model(model,
                   train_loader,
                   val_loader,
                   optimizer,
                   criterion,
                   epochs=50)


In [None]:
mca_test = Accuracy(
    task="multiclass",
    num_classes=NUM_CLASSES).to(device=DEVICE)

# Evaluate explicitly in eval mode (dropout off) and without building an
# autograd graph, instead of relying on the mode train_model left the model in.
model.eval()
with torch.inference_mode():
    for img, label in test_loader:
        pred = model(img)
        mca_test.update(pred.to(torch.float16),
                        torch.argmax(label, dim=1))

mca_test.compute()