## Settings

In [9]:
import torch

PARTS = 100
IMAGE_SIZE = (40, 40)
BATCH_SIZE = 32
assert torch.cuda.is_available()
DEVICE = torch.device('cuda')


## Load the Dataset

In [10]:
import torchvision.transforms
to_tensor = torchvision.transforms.ToTensor()
upscaling = torchvision.transforms.Resize(size=IMAGE_SIZE, antialias=True)
dataset_transform = torchvision.transforms.Compose([to_tensor, upscaling])


In [11]:
import torch.utils.data
import torchvision.datasets

train = torchvision.datasets.MNIST(
    root = "./data",
    train = True,
    download = True,
    transform = dataset_transform
)

train_parts = torch.utils.data.random_split(train, lengths=[1 / PARTS - 1e-12]*PARTS)


## Build the Model


In [12]:
import torch.nn


class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        #  1x40x40
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, padding='same')
        # 32x40x40
        self.maxpl = torch.nn.MaxPool2d(kernel_size=2)
        # 16x20x20
        self.conv2 = torch.nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding='same')
        # 64x20x20
        # self.avgpl = torch.nn.AvgPool2d(kernel_size=4)
        # 64x5x5
        self.flatt = torch.nn.Flatten()
        # 1600
        self.line1 = torch.nn.Linear(in_features=64*20*20, out_features=128)
        # 128
        self.activ = torch.nn.ReLU()
        # 128
        self.feats = torch.nn.Linear(in_features=128, out_features=10)
        # 10

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.maxpl(x)
        x = self.conv2(x)
        # x = self.avgpl(x)
        x = self.flatt(x)
        x = self.line1(x)
        x = self.activ(x)
        x = self.feats(x)
        return x


model = SimpleCNN().to(DEVICE)

## Train the Model


In [13]:
import torch.nn

criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters())


In [14]:
from torch.utils.data import DataLoader, ConcatDataset

def generate_loaders(vpart: int) -> tuple[DataLoader, DataLoader]:
    train_dataset = ConcatDataset(train_parts[:vpart] + train_parts[vpart + 1:])
    train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)

    validation_dataset = train_parts[vpart]
    validation_loader = DataLoader(validation_dataset, batch_size = BATCH_SIZE, shuffle=False)
    return train_loader, validation_loader


In [15]:
from tqdm.notebook import tqdm_notebook
import torch
import torch.utils.data

def train_epoch(loader: torch.utils.data.DataLoader, train: bool) -> tuple[int, float]:
    total_acc = 0
    total_loss = 0

    for x, y in tqdm_notebook(loader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        output = model(x)
        batch_loss = criterion(output, y)
        total_loss += batch_loss
        
        if train:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        total_acc += (output.argmax(dim=1) == y).sum()
    return total_acc.item()/len(loader.dataset), total_loss.item()/len(loader.dataset)


In [16]:
best_loss_val = 1e18
for epoch in range(PARTS):
    train_loader, validation_loader = generate_loaders(epoch % PARTS)
    acc_train, loss_train = train_epoch(train_loader, True)
    with torch.no_grad():
        acc_val, loss_val = train_epoch(validation_loader, False)
    print(f"Epoch {epoch+1}: {acc_train = } {loss_train = } {acc_val = } {loss_val = }")
    if best_loss_val > loss_val:
        print(f"Saving model: {best_loss_val = } > {loss_val = }")
        best_loss_val = loss_val
        torch.save(model.state_dict(), './data/model_dict.pt')


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 1: acc_train = 0.9471043771043771 loss_train = 0.00544374613649516 acc_val = 0.98 loss_val = 0.0022764186064402263
Saving model: best_loss_val = 1e+18 > loss_val = 0.0022764186064402263


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 2: acc_train = 0.9793939393939394 loss_train = 0.0020711227057357427 acc_val = 0.9916666666666667 loss_val = 0.0006517524023850759
Saving model: best_loss_val = 0.0022764186064402263 > loss_val = 0.0006517524023850759


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 3: acc_train = 0.9845286195286195 loss_train = 0.0015946264299078019 acc_val = 0.9816666666666667 loss_val = 0.002106059392293294


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 4: acc_train = 0.9869360269360269 loss_train = 0.0013666558827615347 acc_val = 0.985 loss_val = 0.0012057883540789287


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 5: acc_train = 0.9891919191919192 loss_train = 0.0011357132112136995 acc_val = 0.99 loss_val = 0.0008812300364176432


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 6: acc_train = 0.991010101010101 loss_train = 0.0009872267623541732 acc_val = 0.9866666666666667 loss_val = 0.0011345015962918599


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 7: acc_train = 0.9913636363636363 loss_train = 0.000973891890811599 acc_val = 0.9866666666666667 loss_val = 0.00216267466545105


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 8: acc_train = 0.9921717171717171 loss_train = 0.0008877343200272582 acc_val = 0.9883333333333333 loss_val = 0.0013707684477170308


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 9: acc_train = 0.9924915824915825 loss_train = 0.0008039060265126854 acc_val = 0.9916666666666667 loss_val = 0.0014043123523394267


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 10: acc_train = 0.9937205387205387 loss_train = 0.0007939244279957781 acc_val = 0.9966666666666667 loss_val = 0.0005626798669497172
Saving model: best_loss_val = 0.0006517524023850759 > loss_val = 0.0005626798669497172


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 11: acc_train = 0.9941919191919192 loss_train = 0.0007963159991434528 acc_val = 0.99 loss_val = 0.00064613143603007


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 12: acc_train = 0.9949663299663299 loss_train = 0.0006640388668586911 acc_val = 0.9933333333333333 loss_val = 0.0008923548460006714


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 13: acc_train = 0.9942087542087542 loss_train = 0.0008756008212413853 acc_val = 1.0 loss_val = 9.419346228241921e-05
Saving model: best_loss_val = 0.0005626798669497172 > loss_val = 9.419346228241921e-05


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 14: acc_train = 0.9946801346801347 loss_train = 0.0007405406939060198 acc_val = 0.9933333333333333 loss_val = 0.0010505211353302002


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 15: acc_train = 0.9955723905723906 loss_train = 0.0006632086885497225 acc_val = 0.995 loss_val = 0.0005079635481039683


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 16: acc_train = 0.9953367003367003 loss_train = 0.0007545250915116333 acc_val = 0.995 loss_val = 0.0004916038612524669


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 17: acc_train = 0.9957575757575757 loss_train = 0.0007010465037541758 acc_val = 0.9866666666666667 loss_val = 0.00246644655863444


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 18: acc_train = 0.9958080808080808 loss_train = 0.0006300118314698088 acc_val = 0.9916666666666667 loss_val = 0.0021887675921122233


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 19: acc_train = 0.9958585858585859 loss_train = 0.0007901972632616859 acc_val = 0.9966666666666667 loss_val = 0.0008542049924532573


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 20: acc_train = 0.9958585858585859 loss_train = 0.0007176171890412918 acc_val = 0.9966666666666667 loss_val = 0.0015613104899724324


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 21: acc_train = 0.9965319865319865 loss_train = 0.0007205956712716356 acc_val = 0.9883333333333333 loss_val = 0.005819808642069498


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 22: acc_train = 0.9962289562289562 loss_train = 0.000795787528709129 acc_val = 0.9983333333333333 loss_val = 8.705512310067813e-05
Saving model: best_loss_val = 9.419346228241921e-05 > loss_val = 8.705512310067813e-05


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 23: acc_train = 0.9961952861952862 loss_train = 0.0007411605501014375 acc_val = 0.9916666666666667 loss_val = 0.0007455598811308543


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 24: acc_train = 0.9971548821548821 loss_train = 0.0005934929767441669 acc_val = 0.9983333333333333 loss_val = 0.0015149610241254172


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 25: acc_train = 0.9962962962962963 loss_train = 0.0007736962572091356 acc_val = 0.9983333333333333 loss_val = 0.0004964552819728852


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 26: acc_train = 0.9965151515151515 loss_train = 0.0008213550554782854 acc_val = 0.995 loss_val = 0.0009253327051798502


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 27: acc_train = 0.9970707070707071 loss_train = 0.0007461629411588213 acc_val = 0.9983333333333333 loss_val = 0.00034904027978579203


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 28: acc_train = 0.9971548821548821 loss_train = 0.0006445038920701152 acc_val = 1.0 loss_val = 1.9187815487384795e-05
Saving model: best_loss_val = 8.705512310067813e-05 > loss_val = 1.9187815487384795e-05


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 29: acc_train = 0.9971043771043772 loss_train = 0.0008052454091081716 acc_val = 0.99 loss_val = 0.0027979818979899087


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 30: acc_train = 0.9975757575757576 loss_train = 0.0005531338649968105 acc_val = 0.9916666666666667 loss_val = 0.001980628768603007


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 31: acc_train = 0.9973737373737374 loss_train = 0.0008491796756834294 acc_val = 0.9933333333333333 loss_val = 0.0021211689710617064


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 32: acc_train = 0.9973737373737374 loss_train = 0.0008146420873776831 acc_val = 0.9983333333333333 loss_val = 0.0014318133393923442


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 33: acc_train = 0.9976094276094276 loss_train = 0.0006875808150679977 acc_val = 0.995 loss_val = 0.0014083263278007506


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 34: acc_train = 0.9967508417508417 loss_train = 0.0011071964867588647 acc_val = 0.9983333333333333 loss_val = 0.0007237737874190012


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 35: acc_train = 0.997979797979798 loss_train = 0.0005917170232393926 acc_val = 0.9966666666666667 loss_val = 0.00034768857061862946


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 36: acc_train = 0.9976599326599327 loss_train = 0.0007146165507409709 acc_val = 0.9966666666666667 loss_val = 0.001474528710047404


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 37: acc_train = 0.9976430976430977 loss_train = 0.0008721763919098209 acc_val = 0.9983333333333333 loss_val = 6.829700122276942e-05


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 38: acc_train = 0.9978451178451179 loss_train = 0.0008067174712415496 acc_val = 0.9983333333333333 loss_val = 0.0010648043950398764


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 39: acc_train = 0.9977609427609427 loss_train = 0.0008996350837476326 acc_val = 0.9983333333333333 loss_val = 0.0001773823673526446


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 40: acc_train = 0.9974915824915825 loss_train = 0.000874313007701527 acc_val = 0.995 loss_val = 0.0016895014047622681


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 41: acc_train = 0.9979124579124579 loss_train = 0.0008778577220158947 acc_val = 1.0 loss_val = 3.042311883897734e-10
Saving model: best_loss_val = 1.9187815487384795e-05 > loss_val = 3.042311883897734e-10


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 42: acc_train = 0.9981313131313131 loss_train = 0.0007682246070116859 acc_val = 1.0 loss_val = 2.781517309813353e-09


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 43: acc_train = 0.9977272727272727 loss_train = 0.0009166519730179399 acc_val = 0.9933333333333333 loss_val = 0.0014323778947194416


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 44: acc_train = 0.9981313131313131 loss_train = 0.0007665430576311618 acc_val = 0.9983333333333333 loss_val = 0.00028869556883970895


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 45: acc_train = 0.9978114478114478 loss_train = 0.0011345986003426188 acc_val = 0.9983333333333333 loss_val = 3.7871692329645155e-05


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 46: acc_train = 0.9981481481481481 loss_train = 0.0008945364261717106 acc_val = 0.995 loss_val = 0.004016694227854411


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 47: acc_train = 0.9975757575757576 loss_train = 0.0009972991044272478 acc_val = 0.995 loss_val = 0.003440189758936564


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 48: acc_train = 0.9983838383838384 loss_train = 0.0006767435427065248 acc_val = 1.0 loss_val = 2.955568682712813e-07


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 49: acc_train = 0.9980471380471381 loss_train = 0.0010975143483993581 acc_val = 0.9933333333333333 loss_val = 0.005298977295557658


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 50: acc_train = 0.9983164983164983 loss_train = 0.0010598150567976312 acc_val = 1.0 loss_val = 6.950622579703728e-06


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 51: acc_train = 0.9986531986531987 loss_train = 0.0007388212544348103 acc_val = 0.9983333333333333 loss_val = 0.0008457465966542561


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 52: acc_train = 0.9986363636363637 loss_train = 0.0007035741581258549 acc_val = 0.995 loss_val = 0.00858082930246989


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 53: acc_train = 0.9981481481481481 loss_train = 0.0011272576842645203 acc_val = 0.9983333333333333 loss_val = 0.0017222978671391806


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 54: acc_train = 0.9983501683501683 loss_train = 0.0008530936578307489 acc_val = 1.0 loss_val = 5.525828328245552e-10


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 55: acc_train = 0.9983670033670033 loss_train = 0.0011106758246116767 acc_val = 0.9966666666666667 loss_val = 0.0034439233938852944


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 56: acc_train = 0.9987878787878788 loss_train = 0.000908779182819405 acc_val = 1.0 loss_val = 0.0
Saving model: best_loss_val = 3.042311883897734e-10 > loss_val = 0.0


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 57: acc_train = 0.9984343434343435 loss_train = 0.0008935362562185988 acc_val = 0.9983333333333333 loss_val = 0.00035733431577682494


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 58: acc_train = 0.9986531986531987 loss_train = 0.0009145300637190591 acc_val = 1.0 loss_val = 1.9247297681583102e-10


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 59: acc_train = 0.9984511784511785 loss_train = 0.000919802887271149 acc_val = 0.995 loss_val = 0.0024951553344726564


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 60: acc_train = 0.9987037037037036 loss_train = 0.0009402060589003644 acc_val = 1.0 loss_val = 6.425265382858924e-08


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 61: acc_train = 0.9987037037037036 loss_train = 0.000936611573704164 acc_val = 0.9983333333333333 loss_val = 0.0013759112358093261


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 62: acc_train = 0.9986026936026936 loss_train = 0.0009619721178253893 acc_val = 0.9983333333333333 loss_val = 0.005543512503306071


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 63: acc_train = 0.9983333333333333 loss_train = 0.0009693924104324495 acc_val = 0.9966666666666667 loss_val = 0.0007266632715861003


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 64: acc_train = 0.9983501683501683 loss_train = 0.0009901527764419914 acc_val = 0.9983333333333333 loss_val = 0.0013416463136672974


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 65: acc_train = 0.9986868686868687 loss_train = 0.0011045299234615032 acc_val = 1.0 loss_val = 0.0


  0%|          | 0/1857 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 66: acc_train = 0.9988552188552189 loss_train = 0.0008018417229957452 acc_val = 0.9983333333333333 loss_val = 0.0003016261011362076


  0%|          | 0/1857 [00:00<?, ?it/s]

## Testing

In [None]:
import torch.utils.data
import torchvision.datasets

test = torchvision.datasets.MNIST(
    root = "./data",
    train = False,
    download = True,
    transform = dataset_transform
)


In [None]:
model.load_state_dict(torch.load('./data/model_dict.pt'))
test_loader = DataLoader(test, batch_size = BATCH_SIZE, shuffle=False)
with torch.no_grad():
    acc_test, loss_test = train_epoch(test_loader, False)

print(f"{acc_test = }, {loss_test = }")


## Export


In [None]:
import torch

model.load_state_dict(torch.load('./data/model_dict.pt'))

random_input = torch.rand(1, 1, *IMAGE_SIZE).to(DEVICE)
script_module = torch.jit.trace(model, random_input)

script_module.save('./data/model_exported.pt')
