## Settings

In [None]:
import torch

# Global configuration shared by every section below.
PARTS = 20             # number of folds the training set is split into (one is held out per epoch)
IMAGE_SIZE = (40, 40)  # size every image is resized to before reaching the model
BATCH_SIZE = 66
SEED = 69420           # seeds torch's RNG so the fold split, weight init and shuffling repeat across runs

# Seed before anything random happens (random_split, model init, DataLoader shuffling).
# NOTE: this does not force deterministic cuDNN kernels, so GPU results may still vary slightly.
torch.manual_seed(SEED)

assert torch.cuda.is_available(), "this notebook expects a CUDA-capable GPU"
DEVICE = torch.device('cuda')


## Load the Dataset

In [None]:
import torchvision.transforms
# import random
# random.seed(69420)
# Top-left corner (row, column) of the square patch that CartoonNetwork erases.
# These are module-level globals read at call time, so reassigning them (as the
# training loop does) moves the patch for every sample loaded afterwards.
stx, sty = 0, 0
to_tensor = torchvision.transforms.ToTensor()
upscaling = torchvision.transforms.Resize(size=IMAGE_SIZE, antialias=True)


def CartoonNetwork(x: torch.Tensor, size: int = 10) -> torch.Tensor:
    """Blank out a ``size`` x ``size`` square of image tensor ``x`` (value 0).

    The square's top-left corner is row ``stx``, column ``sty`` — the module-level
    globals above, read at call time rather than bound at definition.

    Args:
        x: image tensor as produced by ``to_tensor``/``upscaling``
           (presumably (C, H, W); erase works on the last two dims).
        size: side length of the erased square in pixels (default 10).

    Returns:
        The image returned by ``torchvision.transforms.functional.erase``; it is
        called with its default ``inplace`` setting.
    """
    return torchvision.transforms.functional.erase(
        img=x, i=stx, j=sty, h=size, w=size, v=0)


# NOTE(review): this transform is also used for the test set, so test images are
# occluded at whatever (stx, sty) the globals hold at evaluation time.
dataset_transform = torchvision.transforms.Compose([to_tensor, upscaling, CartoonNetwork])


In [None]:
import torch.utils.data
import torchvision.datasets

# MNIST training split, stored under ./data (downloaded if missing);
# every sample is passed through dataset_transform when it is fetched.
train = torchvision.datasets.MNIST(
    "./data",
    transform=dataset_transform,
    download=True,
    train=True,
)

# Split the training set into PARTS disjoint folds using exact integer lengths
# (fractional lengths needed a `- 1e-12` float fudge to sum to <= 1). Leftover
# samples, if len(train) is not divisible by PARTS, go one each to the first folds.
fold_size, fold_remainder = divmod(len(train), PARTS)
train_parts = torch.utils.data.random_split(
    train, lengths=[fold_size + (k < fold_remainder) for k in range(PARTS)])
# Quick sanity check of the split and of one transformed sample.
# Indexing a Subset re-runs dataset_transform, so fetch the sample once and reuse it.
sample_image, sample_label = train_parts[0][0]
print(f"{len(train_parts)} folds, {len(train_parts[0])} samples in fold 0, "
      f"sample image shape {tuple(sample_image.shape)}")

from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.imshow(sample_image[0], cmap='gray')
ax.set_title(f"Fold 0, sample 0 (label {sample_label})")
plt.show()

## Build the Model


In [None]:
import torch.nn


class SimpleCNN(torch.nn.Module):
    """Two-convolution CNN mapping 1x40x40 images to 10 class logits.

    ``forward`` returns raw logits (no softmax); ``torch.nn.CrossEntropyLoss``
    applies log-softmax itself, so adding a softmax here would be redundant.

    Attribute names form the ``state_dict`` keys of saved checkpoints, so renaming
    them would break loading existing ``*.pt`` files.
    """

    def __init__(self):
        super().__init__()

        # in:  1x40x40
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, padding='same')
        # -> 32x40x40
        self.maxpl = torch.nn.MaxPool2d(kernel_size=2)
        # -> 32x20x20
        self.conv2 = torch.nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding='same')
        # -> 64x20x20
        self.avgpl = torch.nn.AvgPool2d(kernel_size=4)
        # -> 64x5x5
        self.flatt = torch.nn.Flatten()
        # -> 1600 (= 64 * 5 * 5)
        self.line1 = torch.nn.Linear(in_features=1600, out_features=128)
        # -> 128
        # ReLU is stateless, so a single instance is reused for every activation.
        self.activ = torch.nn.ReLU()
        self.feats = torch.nn.Linear(in_features=128, out_features=10)
        # -> 10 logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to class logits.

        Args:
            x: assumed (batch, 1, 40, 40); ``line1`` requires the pooled feature
               map to be 64x5x5.

        Returns:
            Unnormalised logits of shape (batch, 10).
        """
        x = self.conv1(x)
        x = self.activ(x)
        x = self.maxpl(x)
        x = self.conv2(x)
        x = self.activ(x)
        x = self.avgpl(x)
        x = self.flatt(x)
        x = self.line1(x)
        x = self.activ(x)
        x = self.feats(x)
        return x


# Build the network and move its parameters onto DEVICE.
model = SimpleCNN()
model = model.to(DEVICE)
# To resume from an earlier checkpoint instead:
# model.load_state_dict(torch.load('./data/model_dict.pt'))

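## Train the Model

Training follows a sliding-occlusion schedule. `CartoonNetwork` blanks a 10×10 patch whose top-left corner is `(stx, sty)`, and that corner walks an 11×11 grid of positions spaced 3 pixels apart. Each position is trained for one full rotation through the `PARTS` cross-validation folds (20 epochs with `PARTS = 20`), with a different fold held out for validation each epoch. The checkpoint with the lowest validation loss at each position is saved as `./data/<position index>model_dict.pt`.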


In [None]:
import torch.nn

# Adam with default hyper-parameters over every parameter of the current `model`.
# NOTE: bound to this `model` object — re-run this cell after re-creating the model.
optimizer = torch.optim.Adam(params=model.parameters())
# Cross-entropy on raw logits (SimpleCNN.forward returns unnormalised scores).
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(DEVICE)


In [None]:
from torch.utils.data import DataLoader, ConcatDataset

def generate_loaders(vpart: int) -> tuple[DataLoader, DataLoader]:
    """Build the (train, validation) loaders for cross-validation fold ``vpart``.

    Fold ``vpart`` of ``train_parts`` becomes the unshuffled validation set; all
    other folds are concatenated into the shuffled training set.
    """
    fit_parts = [*train_parts[:vpart], *train_parts[vpart + 1:]]
    fit_loader = DataLoader(ConcatDataset(fit_parts), batch_size=BATCH_SIZE, shuffle=True)

    held_out_loader = DataLoader(train_parts[vpart], batch_size=BATCH_SIZE, shuffle=False)
    return fit_loader, held_out_loader


In [None]:
from tqdm.notebook import tqdm_notebook
import torch
import torch.utils.data

def train_epoch(loader: torch.utils.data.DataLoader, train: bool) -> tuple[float, float]:
    """Run the global ``model`` over every batch of ``loader`` once.

    With ``train=True`` each batch is followed by an optimisation step using the
    global ``optimizer``; with ``train=False`` the model is only evaluated and
    gradient tracking is disabled.

    Args:
        loader: yields ``(x, y)`` batches; ``y`` holds the targets for ``criterion``
            and is compared with ``output.argmax(dim=1)`` for accuracy.
            ``len(loader.dataset)`` is used as the sample count.
        train: whether to update the model's parameters.

    Returns:
        ``(accuracy, mean_loss)`` averaged per sample over ``loader.dataset``;
        ``(0.0, 0.0)`` if the dataset is empty.
    """
    model.train(train)
    # Keep running sums on DEVICE and detached from autograd: adding the raw batch
    # loss would chain every batch's graph onto one tensor, and on-device sums
    # avoid a host sync per batch.
    total_correct = torch.zeros((), dtype=torch.long, device=DEVICE)
    total_loss = torch.zeros((), device=DEVICE)

    with torch.set_grad_enabled(train):
        for x, y in tqdm_notebook(loader):
            x, y = x.to(DEVICE), y.to(DEVICE)

            output = model(x)
            batch_loss = criterion(output, y)

            if train:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            # The default CrossEntropyLoss reduction is the batch *mean*; weight by
            # batch size so dividing by the dataset size gives a true per-sample mean
            # (also correct for a smaller final batch).
            total_loss += batch_loss.detach() * y.size(0)
            total_correct += (output.argmax(dim=1) == y).sum()

    n_samples = len(loader.dataset)
    if n_samples == 0:
        return 0.0, 0.0
    return total_correct.item() / n_samples, total_loss.item() / n_samples


In [None]:
# Sliding-occlusion schedule: CartoonNetwork erases a square patch whose top-left
# corner is (stx, sty). The corner steps over a GRID_STEPS x GRID_STEPS grid with
# PATCH_STRIDE pixels between positions. Each position is trained for
# EPOCHS_PER_PATCH epochs (one full rotation through the PARTS validation folds),
# and the checkpoint with the lowest validation loss at each position is saved.
GRID_STEPS = 11            # patch positions per axis
PATCH_STRIDE = 3           # pixels between neighbouring positions (0, 3, ..., 30)
EPOCHS_PER_PATCH = PARTS   # one epoch per validation fold at every patch position

best_loss_val = float('inf')
for epoch in range(EPOCHS_PER_PATCH * GRID_STEPS ** 2):
    if epoch % EPOCHS_PER_PATCH == 0:
        # New patch position: best-checkpoint tracking restarts for each position.
        best_loss_val = float('inf')
    patch_idx = epoch // EPOCHS_PER_PATCH
    # Rebinding these module-level globals moves the patch for every sample that
    # dataset_transform processes from now on.
    stx = (patch_idx // GRID_STEPS) * PATCH_STRIDE
    sty = (patch_idx % GRID_STEPS) * PATCH_STRIDE
    print(stx, sty)

    train_loader, validation_loader = generate_loaders(epoch % PARTS)
    acc_train, loss_train = train_epoch(train_loader, True)
    with torch.no_grad():
        acc_val, loss_val = train_epoch(validation_loader, False)
    print(f"Epoch {epoch+1}: {acc_train = } {loss_train = } {acc_val = } {loss_val = }")

    if loss_val < best_loss_val:
        print(f"Saving model: {best_loss_val = } > {loss_val = }")
        best_loss_val = loss_val
        checkpoint = model.state_dict()
        torch.save(checkpoint, f'./data/{patch_idx}model_dict.pt')
        # Also refresh the fixed-name checkpoint that the Export section loads, so a
        # fresh Restart & Run All does not fail on a missing './data/model_dict.pt'.
        torch.save(checkpoint, './data/model_dict.pt')


## Testing

In [None]:
import torch.utils.data
import torchvision.datasets

# MNIST test split, stored under ./data (downloaded if missing), using the training transform.
# NOTE(review): dataset_transform erases a patch at the global (stx, sty), so test images
# are occluded wherever the training loop last left the patch — confirm this is intended.
test = torchvision.datasets.MNIST(
    "./data",
    transform=dataset_transform,
    download=True,
    train=False,
)


In [None]:
 
# Evaluate the in-memory `model` (presumably its state after the final training epoch,
# not a reloaded best checkpoint) on the MNIST test split, without gradient tracking.
test_loader = DataLoader(dataset=test, batch_size=BATCH_SIZE, shuffle=False)
with torch.no_grad():
    acc_test, loss_test = train_epoch(loader=test_loader, train=False)

print(f"{acc_test = }, {loss_test = }")


## Export


In [None]:
import torch

# Load the checkpoint stored under the fixed name './data/model_dict.pt' and export it
# as a TorchScript module.
# NOTE(review): that file must already exist; the training loop's per-patch checkpoints
# use the '<index>model_dict.pt' naming — confirm which checkpoint is meant to be exported.
model.load_state_dict(torch.load('./data/model_dict.pt', map_location=DEVICE))
# Trace in eval mode so any train/eval-dependent layers (SimpleCNN currently has none)
# are recorded in their inference form.
model.eval()

random_input = torch.rand(1, 1, *IMAGE_SIZE).to(DEVICE)
script_module = torch.jit.trace(model, random_input)

script_module.save('./data/model_exported.pt')
