In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import transforms, utils
import torchvision.datasets

from torch.utils.data import DataLoader,Subset

from tqdm import tqdm
import os

In [26]:
from models.mnist_classifier import MNISTClassifier

In [29]:
# Split configuration: how many training images are held out for evaluation,
# the seed that makes the hold-out split reproducible, and the loader batch size.
EVAL_SIZE = 5000
SPLIT_SEED = 0
BATCH_SIZE = 256

# MNIST train/test splits, with images converted to tensors by ToTensor.
data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
data_test = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Draw ONE seeded permutation and slice it into disjoint eval/train index sets.
# This replaces an `i not in <tensor>` membership test, which compares against
# the whole tensor for every candidate index, and makes the split repeatable
# across kernel restarts.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_idx = torch.randperm(len(data), generator=split_generator).tolist()

eval_idx = shuffled_idx[:EVAL_SIZE]
data_eval = Subset(data, eval_idx)

# Kept in ascending order, as before; the train loader shuffles anyway.
train_idx = sorted(shuffled_idx[EVAL_SIZE:])
data_train = Subset(data, train_idx)

print(len(data_eval), len(data_train))

# Only the training loader needs shuffling; accuracy on eval/test does not
# depend on batch order.
train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)
eval_loader = DataLoader(data_eval, batch_size=BATCH_SIZE, shuffle=False)

5000 55000


In [34]:
model = MNISTClassifier()
# Fall back to the CPU so the cell still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 15
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
model.to(device)

# One pass over the training set per epoch, followed by an accuracy check on
# the held-out eval subset.
for epoch in tqdm(range(epochs), desc="Train the classifier..."):
    # Re-enter training mode at the top of every epoch; the evaluation pass
    # below switches the model to eval mode.
    model.train()
    running_loss = 0.0  # sum of per-batch losses over this epoch

    # Unpack batches directly so the dataset variable `data` defined in the
    # data-loading cell is not overwritten by a loop variable.
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Evaluate on the eval subset: eval mode, no gradient tracking.
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in tqdm(eval_loader, desc="Eval the classifier"):
            inputs, labels = inputs.to(device), labels.to(device)

            # forward pass only
            outputs = model.get_probs(inputs)  # returns softmaxed output

            # predicted class = index of the highest probability
            predictions = torch.argmax(outputs, dim=-1)
            correct += (predictions == labels).sum().item()

    print('Finished Evaluating')
    print(f"Got : {correct / len(data_eval)} correct predictions on Eval")

    print('Loss: {}'.format(running_loss))

# The model is left in eval mode here, ready for inference in later cells.
print('Finished Training')

Eval the classifier: 20it [00:00, 26.72it/s]15 [00:00<?, ?it/s]
Train the classifier...:   7%|▋         | 1/15 [00:09<02:12,  9.47s/it]

Finished Evaluating
Got : 0.9642 correct predictions on Eval
Loss: 85.4907333701849


Eval the classifier: 20it [00:00, 26.99it/s]
Train the classifier...:  13%|█▎        | 2/15 [00:18<02:03,  9.47s/it]

Finished Evaluating
Got : 0.9772 correct predictions on Eval
Loss: 18.48391891643405


Eval the classifier: 20it [00:00, 27.12it/s]
Train the classifier...:  20%|██        | 3/15 [00:28<01:53,  9.47s/it]

Finished Evaluating
Got : 0.9832 correct predictions on Eval
Loss: 12.697551483288407


Eval the classifier: 20it [00:00, 26.87it/s]
Train the classifier...:  27%|██▋       | 4/15 [00:37<01:44,  9.50s/it]

Finished Evaluating
Got : 0.9842 correct predictions on Eval
Loss: 10.052294757217169


Eval the classifier: 20it [00:00, 26.50it/s]
Train the classifier...:  33%|███▎      | 5/15 [00:47<01:35,  9.51s/it]

Finished Evaluating
Got : 0.9836 correct predictions on Eval
Loss: 8.143664833158255


Eval the classifier: 20it [00:00, 26.76it/s]
Train the classifier...:  40%|████      | 6/15 [00:57<01:25,  9.52s/it]

Finished Evaluating
Got : 0.9854 correct predictions on Eval
Loss: 6.820592971984297


Eval the classifier: 20it [00:00, 27.17it/s]
Train the classifier...:  47%|████▋     | 7/15 [01:06<01:16,  9.52s/it]

Finished Evaluating
Got : 0.9876 correct predictions on Eval
Loss: 6.0679896841757


Eval the classifier: 20it [00:00, 27.04it/s]
Train the classifier...:  53%|█████▎    | 8/15 [01:16<01:06,  9.51s/it]

Finished Evaluating
Got : 0.986 correct predictions on Eval
Loss: 5.367546751629561


Eval the classifier: 20it [00:00, 26.95it/s]
Train the classifier...:  60%|██████    | 9/15 [01:25<00:57,  9.52s/it]

Finished Evaluating
Got : 0.988 correct predictions on Eval
Loss: 4.76151297101751


Eval the classifier: 20it [00:00, 26.83it/s]
Train the classifier...:  67%|██████▋   | 10/15 [01:35<00:47,  9.51s/it]

Finished Evaluating
Got : 0.9868 correct predictions on Eval
Loss: 4.078572123777121


Eval the classifier: 20it [00:00, 26.64it/s]
Train the classifier...:  73%|███████▎  | 11/15 [01:44<00:38,  9.55s/it]

Finished Evaluating
Got : 0.9888 correct predictions on Eval
Loss: 3.507115508429706


Eval the classifier: 20it [00:00, 26.66it/s]
Train the classifier...:  80%|████████  | 12/15 [01:54<00:28,  9.60s/it]

Finished Evaluating
Got : 0.9852 correct predictions on Eval
Loss: 3.149536758195609


Eval the classifier: 20it [00:00, 26.61it/s]
Train the classifier...:  87%|████████▋ | 13/15 [02:04<00:19,  9.63s/it]

Finished Evaluating
Got : 0.9894 correct predictions on Eval
Loss: 2.7378638543887064


Eval the classifier: 20it [00:00, 26.59it/s]
Train the classifier...:  93%|█████████▎| 14/15 [02:13<00:09,  9.65s/it]

Finished Evaluating
Got : 0.99 correct predictions on Eval
Loss: 2.590119845001027


Eval the classifier: 20it [00:00, 26.68it/s]
Train the classifier...: 100%|██████████| 15/15 [02:23<00:00,  9.57s/it]

Finished Evaluating
Got : 0.9892 correct predictions on Eval
Loss: 2.031988272909075
Finished Training





In [37]:

# Measure accuracy of the trained classifier on the official MNIST test split.
# Fall back to the CPU so the cell still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Put the model in inference mode explicitly: the training cell finishes with
# model.train(), so without this the test pass would run in training mode.
model.eval()
correct = 0

# No gradients are needed for testing; skipping autograd saves memory and time.
with torch.no_grad():
    # Unpack batches directly so the dataset variable `data` is not clobbered.
    for inputs, labels in tqdm(test_loader, desc="Test the classifier"):
        inputs, labels = inputs.to(device), labels.to(device)

        # forward pass only
        outputs = model.get_probs(inputs)  # returns softmaxed output

        # predicted class = index of the highest probability
        predictions = torch.argmax(outputs, dim=-1)
        correct += (predictions == labels).sum().item()

print('Finished Testing')
print(f"Got : {correct / len(data_test)} correct predictions on Test")

Test the classifier: 40it [00:01, 27.59it/s]

Finished Testing
Got : 0.9913 correct predictions on Test





In [42]:
# Save the trained weights (state_dict only) under evaluation/.
# makedirs(..., exist_ok=True) keeps the cell re-runnable: a plain os.mkdir
# raises FileExistsError once the directory already exists.
os.makedirs("evaluation", exist_ok=True)
torch.save(model.state_dict(), "evaluation/trained_classifier")