In [1]:
%reload_ext autoreload
%autoreload 2

In [5]:
# ## Train dataset on MNIST :(

from torch.utils.data import DataLoader

import torch.optim as optim
import torch.nn as nn

import torchvision.transforms as transforms

import torch

from tqdm import tqdm

In [10]:
from phd_school.dataset import PrintedMNIST, AddGaussianNoise, AddSPNoise  # noqa

In [11]:
from phd_school.models import get_model

In [14]:
def main(lr, batch_size, n_epochs, model):    

    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        # AddGaussianNoise(0, 1.0),
        AddSPNoise(0.1),
    ])

    val_transforms = transforms.Compose([transforms.ToTensor()])

    train_set = PrintedMNIST(50000, -666, train_transform)
    val_set = PrintedMNIST(5000, 33, val_transforms)

    train_loader = DataLoader(train_set, batch_size=batch_size)
    val_loader = DataLoader(val_set, batch_size=batch_size)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Device: {device}")

    # Get network
    net = get_model(model)

    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    counter = 0

    for epoch in range(n_epochs):

        print(f"Running epoch {epoch+1}")

        running_loss = 0.0

        total_correct = 0
        total_processed = 0

        for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):

            net.train()

            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Compute actual predictions
            sm = torch.nn.functional.softmax(outputs, dim=1)
            _, y_hat = torch.max(sm, 1)

            total_correct += (labels == y_hat.to("cpu")).sum().item()
            total_processed += len(labels)

            counter += 1

            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                
                running_loss = 0.0

                # Check validation set
                net.eval()

                with torch.no_grad():

                    total_processed_val = 0
                    total_correct_val = 0

                    for data in val_loader:
                        # get the inputs; data is a list of [inputs, labels]
                        inputs, labels = data

                        # forward + backward + optimize
                        outputs = net(inputs.to(device))
                        loss = criterion(outputs, labels.to(device))

                        # Compute actual predictions
                        sm = torch.nn.functional.softmax(outputs, dim=1)
                        _, y_hat = torch.max(sm, 1)

                        total_processed_val += len(labels)
                        total_correct_val += (
                            (labels == y_hat.to("cpu")).sum().item()
                        )

                        running_loss += loss.item()

                    running_loss = 0.0

        # Save model at the end of every epoch
        torch.save(net.state_dict(), f"../models/{model}_mnist.pth")

    print("Finished Training")

In [17]:
main(1e-4, 256, 2, "resnet50")

Device: cpu
Running epoch 1



00%|███████████████████████████████████████████████████████████████████████████████| 196/196 [06:21<00:00,  1.95s/it]

Running epoch 2



28%|██████████████████████▍                                                         | 55/196 [01:48<04:38,  1.97s/it]

KeyboardInterrupt: 