## Settings

In [1]:
import torch

PARTS = 10
IMAGE_SIZE = (40, 40)
BATCH_SIZE = 32
assert torch.cuda.is_available()
DEVICE = torch.device('cpu')


AssertionError: 

: 

## Load the Dataset

In [2]:
import torchvision.transforms
to_tensor = torchvision.transforms.ToTensor()
upscaling = torchvision.transforms.Resize(size=IMAGE_SIZE, antialias=True)
dataset_transform = torchvision.transforms.Compose([to_tensor, upscaling])


In [3]:
import torch.utils.data
import torchvision.datasets

train = torchvision.datasets.MNIST(
    root = "./data",
    train = True,
    download = True,
    transform = dataset_transform
)

train_parts = torch.utils.data.random_split(train, lengths=[1/PARTS - 1e-12]*PARTS)


## Build the Model


In [4]:
import torch.nn


class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        #  1x40x40
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, padding='same')
        # 32x40x40
        self.maxpl = torch.nn.MaxPool2d(kernel_size=2)
        # 16x20x20
        self.conv2 = torch.nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding='same')
        # 64x20x20
        self.avgpl = torch.nn.AvgPool2d(kernel_size=4)
        # 64x5x5
        self.flatt = torch.nn.Flatten()
        # 1600
        self.line1 = torch.nn.Linear(in_features=1600, out_features=128)
        # 128
        self.activ = torch.nn.ReLU()
        # 128
        self.feats = torch.nn.Linear(in_features=128, out_features=10)
        # 10

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = self.maxpl(x)
        x = self.conv2(x)
        x = self.avgpl(x)
        x = self.flatt(x)
        x = self.line1(x)
        x = self.activ(x)
        x = self.feats(x)
        return x


model = SimpleCNN().to(DEVICE)

## Train the Model


In [5]:
import torch.nn

criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters())


In [6]:
from torch.utils.data import DataLoader, ConcatDataset

def generate_loaders(vpart: int) -> tuple[DataLoader, DataLoader]:
    train_dataset = ConcatDataset(train_parts[:vpart] + train_parts[vpart + 1:])
    train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)

    validation_dataset = train_parts[vpart]
    validation_loader = DataLoader(validation_dataset, batch_size = BATCH_SIZE, shuffle=False)
    return train_loader, validation_loader


In [7]:
from tqdm.notebook import tqdm_notebook
import torch
import torch.utils.data

def train_epoch(loader: torch.utils.data.DataLoader, train: bool) -> tuple[int, float]:
    total_acc = 0
    total_loss = 0

    for x, y in tqdm_notebook(loader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        output = model(x)
        batch_loss = criterion(output, y)
        total_loss += batch_loss
        
        if train:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        total_acc += (output.argmax(dim=1) == y).sum()
    return total_acc.item()/len(loader.dataset), total_loss.item()/len(loader.dataset)


In [8]:
best_loss_val = 1e18
for epoch in range(PARTS):
    train_loader, validation_loader = generate_loaders(epoch % PARTS)
    acc_train, loss_train = train_epoch(train_loader, True)
    with torch.no_grad():
        acc_val, loss_val = train_epoch(validation_loader, False)
    print(f"Epoch {epoch+1}: {acc_train = } {loss_train = } {acc_val = } {loss_val = }")
    if best_loss_val > loss_val:
        print(f"Saving model: {best_loss_val = } > {loss_val = }")
        best_loss_val = loss_val
        torch.save(model.state_dict(), './data/model_dict.pt')


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 1: acc_train = 0.9084444444444445 loss_train = 0.009213874534324363 acc_val = 0.9433333333333334 loss_val = 0.005899653116861979
Saving model: best_loss_val = 1e+18 > loss_val = 0.005899653116861979


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 2: acc_train = 0.9597777777777777 loss_train = 0.004212257667824074 acc_val = 0.9663333333333334 loss_val = 0.0035088698069254557
Saving model: best_loss_val = 0.005899653116861979 > loss_val = 0.0035088698069254557


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 3: acc_train = 0.9713703703703703 loss_train = 0.0029046777795862268 acc_val = 0.9721666666666666 loss_val = 0.0026678361892700196
Saving model: best_loss_val = 0.0035088698069254557 > loss_val = 0.0026678361892700196


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 4: acc_train = 0.978 loss_train = 0.0022422088340476706 acc_val = 0.9793333333333333 loss_val = 0.00213171116511027
Saving model: best_loss_val = 0.0026678361892700196 > loss_val = 0.00213171116511027


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 5: acc_train = 0.9817037037037037 loss_train = 0.0017918551409686052 acc_val = 0.9821666666666666 loss_val = 0.0017542096773783366
Saving model: best_loss_val = 0.00213171116511027 > loss_val = 0.0017542096773783366


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 6: acc_train = 0.9850185185185185 loss_train = 0.0015186908863208913 acc_val = 0.9853333333333333 loss_val = 0.0014068929354349772
Saving model: best_loss_val = 0.0017542096773783366 > loss_val = 0.0014068929354349772


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 7: acc_train = 0.9867222222222222 loss_train = 0.0012797833195439092 acc_val = 0.9868333333333333 loss_val = 0.001229344367980957
Saving model: best_loss_val = 0.0014068929354349772 > loss_val = 0.001229344367980957


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 8: acc_train = 0.9887592592592592 loss_train = 0.0010895832909478083 acc_val = 0.9906666666666667 loss_val = 0.0008855327765146891
Saving model: best_loss_val = 0.001229344367980957 > loss_val = 0.0008855327765146891


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 9: acc_train = 0.9897962962962963 loss_train = 0.0009706606688322844 acc_val = 0.9863333333333333 loss_val = 0.0014676952362060547


  0%|          | 0/1688 [00:00<?, ?it/s]

  0%|          | 0/188 [00:00<?, ?it/s]

Epoch 10: acc_train = 0.9909814814814815 loss_train = 0.0008677931185121889 acc_val = 0.9911666666666666 loss_val = 0.0008827939033508301
Saving model: best_loss_val = 0.0008855327765146891 > loss_val = 0.0008827939033508301


## Testing

In [9]:
import torch.utils.data
import torchvision.datasets

test = torchvision.datasets.MNIST(
    root = "./data",
    train = False,
    download = True,
    transform = dataset_transform
)


In [10]:
model.load_state_dict(torch.load('./data/model_dict.pt'))
test_loader = DataLoader(test, batch_size = BATCH_SIZE, shuffle=False)
with torch.no_grad():
    acc_test, loss_test = train_epoch(test_loader, False)

print(f"{acc_test = }, {loss_test = }")


  0%|          | 0/313 [00:00<?, ?it/s]

acc_test = 0.9865, loss_test = 0.0015363811492919923


## Export


In [11]:
import torch

model.load_state_dict(torch.load('./data/model_dict.pt'))

random_input = torch.rand(1, 1, *IMAGE_SIZE).to(DEVICE)
script_module = torch.jit.trace(model, random_input)

script_module.save('./data/model_exported.pt')


[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
