## Settings

In [1]:
import torch

PARTS = 10
IMAGE_SIZE = (40, 40)
BATCH_SIZE = 32
assert torch.cuda.is_available()
DEVICE = torch.device('cuda')


## Load the Dataset

In [2]:
import torchvision.transforms
to_tensor = torchvision.transforms.ToTensor()
upscaling = torchvision.transforms.Resize(size=IMAGE_SIZE, antialias=True)
dataset_transform = torchvision.transforms.Compose([to_tensor, upscaling])


In [3]:
import torch.utils.data
import torchvision.datasets

train = torchvision.datasets.MNIST(
    root = "./data",
    train = True,
    download = True,
    transform = dataset_transform
)

train_parts = torch.utils.data.random_split(train, lengths=[1/PARTS - 1e-12]*PARTS)


## Build the Model


In [4]:
import torch.nn


class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        #  1x40x40
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, padding='same')
        # 32x40x40
        self.maxpl = torch.nn.MaxPool2d(kernel_size=2)
        # 16x20x20
        self.conv2 = torch.nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding='same')
        # 64x20x20
        self.avgpl = torch.nn.AvgPool2d(kernel_size=4)
        # 64x5x5
        self.flatt = torch.nn.Flatten()
        # 1600
        self.line1 = torch.nn.Linear(in_features=1600, out_features=128)
        # 128
        self.activ = torch.nn.ReLU()
        # 128
        self.feats = torch.nn.Linear(in_features=128, out_features=10)
        # 10

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.maxpl(x)
        x = self.conv2(x)
        x = self.avgpl(x)
        x = self.flatt(x)
        x = self.line1(x)
        x = self.activ(x)
        x = self.feats(x)
        return x


model = SimpleCNN().to(DEVICE)

## Train the Model


In [5]:
import torch.nn

criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters())


In [6]:
from torch.utils.data import DataLoader, ConcatDataset

def generate_loaders(vpart: int) -> tuple[DataLoader, DataLoader]:
    train_dataset = ConcatDataset(train_parts[:vpart] + train_parts[vpart + 1:])
    train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)

    validation_dataset = train_parts[vpart]
    validation_loader = DataLoader(validation_dataset, batch_size = BATCH_SIZE, shuffle=False)
    return train_loader, validation_loader


In [7]:
from tqdm.notebook import tqdm_notebook
import torch
import torch.utils.data

def train_epoch(loader: torch.utils.data.DataLoader, train: bool) -> tuple[int, float]:
    total_acc = 0
    total_loss = 0

    for x, y in tqdm_notebook(loader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        output = model(x)
        batch_loss = criterion(output, y)
        total_loss += batch_loss
        
        if train:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        total_acc += (output.argmax(dim=1) == y).sum()
    return total_acc.item()/len(loader.dataset), total_loss.item()/len(loader.dataset)


In [8]:
best_loss_val = 1e18
for epoch in range(PARTS):
    train_loader, validation_loader = generate_loaders(epoch % PARTS)
    acc_train, loss_train = train_epoch(train_loader, True)
    with torch.no_grad():
        acc_val, loss_val = train_epoch(validation_loader, False)
    print(f"Epoch {epoch+1}: {acc_train = } {loss_train = } {acc_val = } {loss_val = }")
    if best_loss_val > loss_val:
        print(f"Saving model: {best_loss_val = } > {loss_val = }")
        best_loss_val = loss_val
        torch.save(model.state_dict(), './data/model_dict.pt')


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 1: acc_train = 0.9105925925925926 loss_train = 0.008874165287724247 acc_val = 0.9523333333333334 loss_val = 0.004857874234517416
Saving model: best_loss_val = 1e+18 > loss_val = 0.004857874234517416


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 2: acc_train = 0.9611296296296297 loss_train = 0.003899815594708478 acc_val = 0.9631666666666666 loss_val = 0.003648871421813965
Saving model: best_loss_val = 0.004857874234517416 > loss_val = 0.003648871421813965


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 3: acc_train = 0.9738518518518519 loss_train = 0.002611234311704282 acc_val = 0.9661666666666666 loss_val = 0.0032255420684814454
Saving model: best_loss_val = 0.003648871421813965 > loss_val = 0.0032255420684814454


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 4: acc_train = 0.98 loss_train = 0.001980682796902127 acc_val = 0.9848333333333333 loss_val = 0.0016717012723286946
Saving model: best_loss_val = 0.0032255420684814454 > loss_val = 0.0016717012723286946


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 5: acc_train = 0.9826481481481482 loss_train = 0.00164574220445421 acc_val = 0.9805 loss_val = 0.0019105979601542155


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 6: acc_train = 0.9854814814814815 loss_train = 0.001408216264512804 acc_val = 0.9865 loss_val = 0.001198611815770467
Saving model: best_loss_val = 0.0016717012723286946 > loss_val = 0.001198611815770467


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 7: acc_train = 0.9874444444444445 loss_train = 0.001227132726598669 acc_val = 0.9846666666666667 loss_val = 0.0012881425221761068


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 8: acc_train = 0.9897592592592592 loss_train = 0.0010173548945674189 acc_val = 0.988 loss_val = 0.0011150782108306884
Saving model: best_loss_val = 0.001198611815770467 > loss_val = 0.0011150782108306884


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 9: acc_train = 0.9900555555555556 loss_train = 0.0009392291881419995 acc_val = 0.9885 loss_val = 0.0013371709187825521


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 10: acc_train = 0.9917222222222222 loss_train = 0.000815513257627134 acc_val = 0.9833333333333333 loss_val = 0.0017046488126118977


## Testing

In [9]:
import torch.utils.data
import torchvision.datasets

test = torchvision.datasets.MNIST(
    root = "./data",
    train = False,
    download = True,
    transform = dataset_transform
)


In [10]:
model.load_state_dict(torch.load('./data/model_dict.pt'))
test_loader = DataLoader(test, batch_size = BATCH_SIZE, shuffle=False)
with torch.no_grad():
    acc_test, loss_test = train_epoch(test_loader, False)

print(f"{acc_test = }, {loss_test = }")


  0%|          | 0/313 [00:00<?, ?it/s]

acc_test = 0.9869, loss_test = 0.0013912675857543945


## Export


In [11]:
import torch

model.load_state_dict(torch.load('./data/model_dict.pt'))

random_input = torch.rand(1, 1, *IMAGE_SIZE).to(DEVICE)
script_module = torch.jit.trace(model, random_input)

script_module.save('./data/model_exported.pt')
