In [2]:
# And now we do our imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from miniminiai import *
from torcheval.metrics import MulticlassAccuracy, Metric, Mean

In [3]:
from datasets import load_dataset, DatasetDict

# We'll use the MNIST dataset for this tutorial, loaded through the
# `datasets` library imported above.
dataset = load_dataset('mnist')

# # Optional: shrink the data for quick experiments
# # (first 3000 training samples and first 100 test samples)
# dataset = DatasetDict({
#     'train':dataset['train'].select(range(3000)),
#     'test':dataset['test'].select(range(100))
# })

Found cached dataset mnist (/Users/tcapelle/.cache/huggingface/datasets/mnist/mnist/1.0.0/fda16c03c4ecfb13f165ba7e29cf38129ce035011519968cdaf74894ce91c9d4)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 81.42it/s]


In [4]:

# Turn it into dataloaders
import torchvision.transforms.functional as TF

def transforms(b):
    """Convert every entry of `b['image']` with `TF.to_tensor`.

    The batch dict `b` is modified in place (its 'image' entry is replaced by
    a list of the converted items) and the same dict is returned.
    """
    b['image'] = list(map(TF.to_tensor, b['image']))
    return b
     
# Attach the image conversion to the dataset, then build the dataloaders with
# a batch size of 64 (`DataLoaders.from_dd` comes from miniminiai).
dataset = dataset.with_transform(transforms)
dls = DataLoaders.from_dd(dataset,batch_size=64)

# Pull one training batch to inspect it: the output below shows images of
# shape (64, 1, 28, 28) and labels of shape (64,).
xb, yb = next(iter(dls.train))
xb.shape, yb.shape, yb[:5]

(torch.Size([64, 1, 28, 28]), torch.Size([64]), tensor([3, 7, 3, 1, 9]))

In [5]:
from torch import nn

# Three stride-2, 3x3 conv + ReLU stages (channels 1 -> 16 -> 16 -> 10),
# followed by global average pooling and a flatten, so each image ends up as
# a vector of 10 values.
# NOTE(review): the last ReLU keeps every output non-negative before the
# loss is applied — confirm that is intended.
_channels = (1, 16, 16, 10)
_stages = []
for _c_in, _c_out in zip(_channels, _channels[1:]):
    _stages.append(nn.Conv2d(_c_in, _c_out, kernel_size=3, stride=2, padding=1))
    _stages.append(nn.ReLU())
model = nn.Sequential(*_stages, nn.AdaptiveAvgPool2d(1), nn.Flatten())

# Sanity check: the batch pushed through the model should give one row of 10
# outputs per image (see the shape printed below).
model(xb).shape

torch.Size([64, 10])

In [6]:
# Shared cross-entropy criterion used by the combined loss below.
l = nn.CrossEntropyLoss()
def complicated_loss_function(inp, target, return_multi=False):
    """Cross-entropy of `inp` against `target` plus the mean of `inp` squared.

    Args:
        inp: model outputs; fed to the cross-entropy criterion and also
            squared and averaged for the second term.
        target: labels for the cross-entropy criterion.
        return_multi: when true, return the individual terms as well.

    Returns:
        The summed loss, or the tuple `(loss, cross_entropy, mean_square)`
        when `return_multi` is true.
    """
    ce_term = l(inp, target)
    sq_term = (inp ** 2).mean()
    total = ce_term + sq_term
    return (total, ce_term, sq_term) if return_multi else total

In [7]:
class MultiMetric(Mean):
    """Running average of one component of `complicated_loss_function`.

    `idx` picks the element of the `(loss, a, b)` tuple returned with
    `return_multi=True` that is tracked: 0 is the total, 1 the cross-entropy
    term and 2 the mean-squared-output term.
    """
    def __init__(self, *, device=None, idx=0):
        super().__init__(device=device)
        self.idx=idx
    def update(self, inp, targets):
        """Fold one batch into the running average and return `self`.

        `weighted_sum` and `weights` are the accumulators provided by the
        `Mean` base class.
        """
        value = complicated_loss_function(inp, targets, return_multi=True)[self.idx]
        # Every component is a per-batch mean, so weight it by the batch size;
        # otherwise a short final batch counts as much as a full one and the
        # epoch figure is not a true per-sample average.
        n = len(targets)
        # detach() so the running sum never holds on to an autograd graph.
        self.weighted_sum += value.detach() * n
        self.weights += n
        return self

In [8]:
# Track the two loss components separately. The keyword names passed to
# MetricsCB ('a' and 'b') appear as column headers in the results table below.
cbs = [MetricsCB(a=MultiMetric(idx=1),
                 b=MultiMetric(idx=2)), 
       DeviceCB(), ProgressCB()]
# Train on the combined loss for 3 epochs with lr=0.1.
learn = TrainLearner(model, dls, complicated_loss_function, lr=0.1, cbs=cbs)
learn.fit(3)

a,b,loss,epoch,train
2.087,0.064,2.151,0,train
1.711,0.197,1.909,0,eval
1.563,0.242,1.805,1,train
1.394,0.316,1.711,1,eval
1.359,0.312,1.671,2,train
1.266,0.358,1.625,2,eval


In [None]:
class ReconLoss(Mean):
    """Running average of element 1 of `complicated_loss_function`'s tuple.

    With `return_multi=True` that element is the cross-entropy term.
    NOTE(review): despite the class name, nothing in this notebook computes a
    reconstruction loss; the name is kept because it is used as a column label.
    """
    def __init__(self, *, device=None): super().__init__(device=device)
    def update(self, inp, targets):
        """Fold one batch into the running average and return `self`."""
        value = complicated_loss_function(inp, targets, return_multi=True)[1]
        # The term is a per-batch mean, so weight it by the batch size; a
        # short final batch must not count as much as a full one.
        n = len(targets)
        # detach() so the running sum never holds on to an autograd graph.
        self.weighted_sum += value.detach() * n
        self.weights += n
        return self
        

class KLDLoss(Mean):
    """Running average of element 2 of `complicated_loss_function`'s tuple.

    With `return_multi=True` that element is the mean of the squared model
    outputs. NOTE(review): despite the class name, no KL divergence is
    computed here; the name is kept because it is used as a column label.
    """
    def __init__(self, *, device=None): super().__init__(device=device)
    def update(self, inp, targets):
        """Fold one batch into the running average and return `self`."""
        value = complicated_loss_function(inp, targets, return_multi=True)[2]
        # The term is a per-batch mean, so weight it by the batch size; a
        # short final batch must not count as much as a full one.
        n = len(targets)
        # detach() so the running sum never holds on to an autograd graph.
        self.weighted_sum += value.detach() * n
        self.weights += n
        return self
        

In [None]:
# Metrics are passed positionally here; the results table below labels each
# column with the metric's class name.
cbs = [MetricsCB(ReconLoss(), KLDLoss(), MulticlassAccuracy()), 
       DeviceCB(), ProgressCB()]
# NOTE(review): this reuses the `model` object that the earlier
# `learn.fit(3)` already trained; re-create the model first if a
# from-scratch run is intended.
learn = TrainLearner(model, dls, complicated_loss_function, lr=0.1, cbs=cbs)
learn.fit(3)

ReconLoss,KLDLoss,MulticlassAccuracy,loss,epoch,train
2.208,0.025,0.259,2.233,0,train
1.909,0.119,0.564,2.028,0,eval
1.645,0.215,0.659,1.86,1,train
1.518,0.257,0.701,1.775,1,eval
1.427,0.291,0.725,1.718,2,train
1.36,0.318,0.751,1.678,2,eval
