In [143]:
import torch
import torchvision
from torchvision.transforms.functional import to_tensor
import torch.utils.data as D
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

DEVICE = "cuda:0"
DTYPE = torch.bfloat16

In [144]:
def transform(x):
    x = to_tensor(x)
    x = x.to(device=DEVICE, dtype=DTYPE)
    return x


In [165]:
mnist_train = torchvision.datasets.MNIST("./MNIST",
                                        transform=transform,
                                        train=True,
                                        download=True)

mnist_test = torchvision.datasets.MNIST("./MNIST",
                                        transform=transform,
                                        train=False,
                                        download=True)
                                        

combined = mnist_train + mnist_test


In [146]:
rand_gen = torch.Generator().manual_seed(192)
train, val, test = D.random_split(combined, [0.8, 0.1, 0.1], generator=rand_gen)

In [147]:
train_loader = D.DataLoader(train, 128, shuffle=True)
val_loader = D.DataLoader(val, 128)
test_loader = D.DataLoader(test, 64)

In [148]:
INPUT_SHAPE = torch.prod(torch.tensor(combined[0][0].shape)).item()

In [158]:
class Model(nn.Module):
    def __init__(self,
                 input_dim: int,
                 output_dim: int):
        super().__init__()
        self._input_dim = input_dim
        self._output_dim = output_dim
        
        self._net = nn.Sequential(
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(input_dim, 200),
            nn.ReLU(False),
            #nn.BatchNorm1d(200),
            nn.Linear(200, 100),
            nn.ReLU(False),
            #nn.BatchNorm1d(100),
            nn.Linear(100, 50),
            nn.ReLU(False),
            #nn.BatchNorm1d(50),
            nn.Linear(50, output_dim),
            nn.Softmax(dim=-1)
        )
        
    def forward(self, x):
        return self._net(x)
    
    #@torch.inference_mode()
    #def predict(self, x):
        #return self._net(x)

In [163]:
from math import ceil
from IPython.display import clear_output
from torchmetrics.classification import MulticlassAccuracy

def train_model(
    model: nn.Module,
    train_loader: D.DataLoader,
    val_loader: D.DataLoader,
    optimizer,
    criterion,
    epochs=100):
    
    mca = MulticlassAccuracy(10).to(device=DEVICE)
    
    history = {"train":[], "val":[]}
    num_train_batches = ceil(len(train_loader.dataset) / train_loader.batch_size)
    num_val_batches = ceil(len(val_loader.dataset) / val_loader.batch_size)
    
    for epoch in range(epochs):
        avg_train_loss = 0
        avg_train_acc = 0
        
        for step, (train_batch, train_labels) in enumerate(train_loader):
            # Forward
            preds = model(train_batch)
            targets = F.one_hot(train_labels).to(device=DEVICE, dtype=DTYPE)
            
            # Compute Metrics
            train_loss = criterion(preds, targets)
            print(train_labels.device)
            with torch.inference_mode():
                train_acc = mca(preds, train_labels)
                avg_train_loss += train_loss.item()
                avg_train_acc += train_acc.item()

            #Make output
            #
            #num_bars = int(((step / num_train_batches) * 20)) + 1
            #completion_string = "="*num_bars
            #completion_string += "-"*(20 - num_bars)
            completion_string = "$"
            print(
                "Epoch: {} \t [{}] \t Train loss: {} \t Train acc: {}"\
                .format(epoch,
                        completion_string,
                        train_loss,
                        train_acc
                        ),
                end="\r"
                )
            
            # Backpropogate
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
 
        # Validation
        with torch.inference_mode():
            avg_val_loss = 0
            for _, (val_batch, val_labels) in enumerate(val_loader):
                val_loss = criterion(
                    model.forward(val_batch),
                    F.one_hot(val_labels, model._output_dim).to(device="cuda:0", dtype=DTYPE)
                    )
                avg_val_loss += val_loss
            
            avg_val_loss /= num_val_batches
            print("\nAvg Val loss: {:.3f}".format(avg_val_loss))

        # Record
        avg_train_loss /= num_train_batches
        avg_train_acc /= num_train_batches
        history["train"].append(avg_train_loss)
        history["val"].append(avg_val_loss)
        
    return history

In [164]:
model = Model(INPUT_SHAPE, 10)
model.to(device=DEVICE, dtype=DTYPE)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

hist = train_model(model,
                   train_loader,
                   val_loader,
                   optimizer,
                   criterion,
                   epochs=50)


cpu


RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=MulticlassAccuracy(...)` try to do `metric=MulticlassAccuracy(...).to(device)` where device corresponds to the device of the input.

In [131]:
from sklearn.metrics import accuracy_score
mca = MulticlassAccuracy(10).to(device=DEVICE)

for i, (img, label) in enumerate(train_loader):
    pred = model(img)
    print(mca(pred, label.to(device=DEVICE)))
    
    

tensor(0.7239, device='cuda:0')
tensor(0.7289, device='cuda:0')
tensor(0.6800, device='cuda:0')
tensor(0.7679, device='cuda:0')
tensor(0.6963, device='cuda:0')
tensor(0.7662, device='cuda:0')
tensor(0.7713, device='cuda:0')
tensor(0.6979, device='cuda:0')
tensor(0.7642, device='cuda:0')
tensor(0.7395, device='cuda:0')
tensor(0.7269, device='cuda:0')
tensor(0.7515, device='cuda:0')
tensor(0.7302, device='cuda:0')
tensor(0.7048, device='cuda:0')
tensor(0.7585, device='cuda:0')
tensor(0.7291, device='cuda:0')
tensor(0.7104, device='cuda:0')
tensor(0.7161, device='cuda:0')
tensor(0.7446, device='cuda:0')
tensor(0.6651, device='cuda:0')
tensor(0.7314, device='cuda:0')
tensor(0.7222, device='cuda:0')
tensor(0.7402, device='cuda:0')
tensor(0.7448, device='cuda:0')
tensor(0.7256, device='cuda:0')
tensor(0.6812, device='cuda:0')
tensor(0.6981, device='cuda:0')
tensor(0.7429, device='cuda:0')
tensor(0.6393, device='cuda:0')
tensor(0.7467, device='cuda:0')
tensor(0.7316, device='cuda:0')
tensor(0

KeyboardInterrupt: 

In [None]:
F.cross_entropy()