In [1]:
import torch
import torchvision
from torchvision.transforms.functional import to_tensor
import torch.utils.data as D
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

DEVICE = "cuda:0"
DTYPE = torch.bfloat16

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
def transform(x):
    x = to_tensor(x)
    x = x.to(device=DEVICE, dtype=DTYPE)
    return x

def target_transform(y):
    y = torch.tensor(y).long()
    y = F.one_hot(y, num_classes=10)
    y = y.to(device=DEVICE, dtype=DTYPE)
    return y


In [16]:
mnist_train = torchvision.datasets.MNIST("./MNIST",
                                        transform=transform,
                                        target_transform=target_transform,
                                        train=True,
                                        download=True)

mnist_test = torchvision.datasets.MNIST("./MNIST",
                                        transform=transform,
                                        target_transform=target_transform,
                                        train=False,
                                        download=True)
                                        

combined = mnist_train + mnist_test


In [17]:
rand_gen = torch.Generator().manual_seed(192)
train, val, test = D.random_split(combined, [0.8, 0.1, 0.1], generator=rand_gen)

In [155]:
train_loader = D.DataLoader(train, 512, shuffle=True)
val_loader = D.DataLoader(val, 512)
test_loader = D.DataLoader(test, 64)

In [156]:
INPUT_SHAPE = torch.prod(torch.tensor(combined[0][0].shape)).item()

In [157]:
class Model(nn.Module):
    def __init__(self,
                 input_dim: int,
                 output_dim: int):
        super().__init__()
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._flatten = nn.Flatten(start_dim=1, end_dim=-1)
    
        self._linear_stack = nn.Sequential(
            nn.Linear(input_dim, 200),
            nn.ReLU(False),
            nn.BatchNorm1d(200),
            nn.Linear(200, 100),
            nn.ReLU(False),
            nn.BatchNorm1d(100),
            nn.Linear(100, 50),
            nn.ReLU(False),
            nn.BatchNorm1d(50),
            nn.Linear(50, output_dim),
            nn.Softmax(dim=-1)
        )
        
    def forward(self, x):
        x_ = self._flatten(x)
        return self._linear_stack(x_)
    
    #@torch.inference_mode()
    #def predict(self, x):
        #return self._net(x)

In [172]:
from math import ceil
from IPython.display import clear_output

def compute_acc(pred, target) -> float:
    pred_= pred.clone()
    target_ = target.clone()
    
    if pred_.shape[1] > 1:
        pred_ = torch.argmax(pred_, dim=1)
    if target_.shape[1] > 1:
        target_ = torch.argmax(target_, dim=1)
        
    diff = pred_ - target_
    missclass = torch.nonzero(diff)
    
    return 1 - (len(missclass) / len(pred_))

def train_model(
    model: nn.Module,
    train_loader: D.DataLoader,
    val_loader: D.DataLoader,
    optimizer,
    criterion,
    epochs=100):
    
    history = {"train":[], "val":[]}
    num_train_batches = ceil(len(train_loader.dataset) / train_loader.batch_size)
    num_val_batches = ceil(len(val_loader.dataset) / val_loader.batch_size)
    
    for epoch in range(epochs):
        avg_train_loss = 0
        avg_train_acc = 0
        for step, (train_batch, train_labels) in enumerate(train_loader):
            ###########
            # Forward #
            ###########
            model.train()
            preds = model(train_batch)
            
            ###################          
            # Compute Metrics #
            ###################
            train_loss = criterion(preds, train_labels)
            model.eval()
            with torch.inference_mode():
                train_acc = compute_acc(preds, train_labels)
                avg_train_loss += train_loss.item()
                avg_train_acc += train_acc
                
            ###############
            # Make output #
            ###############
            num_bars = int(((step / num_train_batches) * 20)) + 1
            completion_string = "="*num_bars
            completion_string += "-"*(20 - num_bars)
            output = "Epoch: {} \t [{}] \t Train loss: {} \t Train acc: {}"\
                .format(epoch,
                        completion_string,
                        train_loss,
                        train_acc
                        )
            print(output, end="\r")
            
            #################
            # Backpropogate #
            #################
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        
        ##############
        # Validation #
        ##############
        model.eval()
        with torch.inference_mode():
            avg_val_loss = 0
            avg_val_acc = 0 
            for _, (val_batch, val_labels) in enumerate(val_loader):
                preds = model(val_batch)
                targets = val_labels
                
                val_loss = criterion(preds, targets)
                avg_val_loss += val_loss
                
                val_acc = compute_acc(preds, targets)
                avg_val_acc += val_acc
            
            avg_train_loss /= num_train_batches
            avg_val_loss /= num_val_batches
            print("\nAvg Train Loss: {:.3f} \t Avg Val loss: {:.3f}".format(avg_train_loss, avg_val_loss))
            print("-"*len(output), "\n")
            
        ##########
        # Record #
        ##########
        history["train"].append(avg_train_loss)
        history["val"].append(avg_val_loss)
        
    return history

In [173]:
model = Model(INPUT_SHAPE, 10)
model.to(device=DEVICE, dtype=DTYPE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

hist = train_model(model,
                   train_loader,
                   val_loader,
                   optimizer,
                   criterion,
                   epochs=50)


Avg Train Loss: 1.556 	 Avg Val loss: 1.508
----------------------------------------------------------------------------------- 

Avg Train Loss: 1.502 	 Avg Val loss: 1.500
----------------------------------------------------------------------------------- 

Avg Train Loss: 1.492 	 Avg Val loss: 1.500
------------------------------------------------------------------------ 

Avg Train Loss: 1.486 	 Avg Val loss: 1.500
--------------------------------------------------------------------------------------- 

Avg Train Loss: 1.485 	 Avg Val loss: 1.492
------------------------------------------------------------------------------- 

Avg Train Loss: 1.484 	 Avg Val loss: 1.492
----------------------------------------------------------------------------------------- 

Avg Train Loss: 1.481 	 Avg Val loss: 1.492
----------------------------------------------------------------------------------------- 

Avg Train Loss: 1.478 	 Avg Val loss: 1.492
---------------------------------------------

KeyboardInterrupt: 

Probably makes more sense to update the average with each batch, and then output that instead. Otherwise it doesn't even look like we're learning anything.

In [147]:
from sklearn.metrics import accuracy_score
from torchmetrics import Accuracy
mca = Accuracy(task="multiclass", num_classes=10, ignore_index=True).to(device=DEVICE)

torch.manual_seed(192)
for i, (img, label) in enumerate(train_loader):
    pred = model(img)
    print(mca.forward(pred.to(torch.float16),
                      label.to(torch.float16)))
    
    print(accuracy_score(torch.argmax(pred.cpu().detach().to(torch.float16), dim=1).numpy(),
                         torch.argmax(label.cpu().detach().to(torch.float16), dim=1).numpy()))
    
    print(compute_acc(pred, label))

    
    

tensor(0.9766, device='cuda:0')
0.9375
0.9375
tensor(0.9852, device='cuda:0')
0.9140625
0.9140625
tensor(0.9852, device='cuda:0')
0.9296875
0.9296875
tensor(0.9844, device='cuda:0')
0.8984375
0.8984375
tensor(0.9792, device='cuda:0')
0.890625
0.890625
tensor(0.9896, device='cuda:0')
0.9296875
0.9296875
tensor(0.9861, device='cuda:0')
0.9609375
0.9609375
tensor(0.9896, device='cuda:0')
0.9453125
0.9453125
tensor(0.9809, device='cuda:0')
0.9375
0.9375
tensor(0.9905, device='cuda:0')
0.9609375
0.9609375
tensor(0.9887, device='cuda:0')
0.9296875
0.9296875
tensor(0.9931, device='cuda:0')
0.9609375
0.9609375
tensor(0.9809, device='cuda:0')
0.9296875
0.9296875
tensor(0.9887, device='cuda:0')
0.96875
0.96875
tensor(0.9887, device='cuda:0')
0.9453125
0.9453125
tensor(0.9922, device='cuda:0')
0.9609375
0.9609375
tensor(0.9740, device='cuda:0')
0.890625
0.890625
tensor(0.9878, device='cuda:0')
0.9296875
0.9296875
tensor(0.9931, device='cuda:0')
0.9609375
0.9609375
tensor(0.9852, device='cuda:0')


KeyboardInterrupt: 

0.9921875

In [141]:
torch.nonzero(torch.tensor([1, 0, 1]))

tensor([[0],
        [2]])