In [4]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [5]:
from AIToolbox.torchtrain.train_loop import TrainLoop, TrainLoopModelCheckpointEndSave
from AIToolbox.torchtrain.model import ModelWrap
from AIToolbox.torchtrain.data.batch_model_feed_defs import AbstractModelFeedDefinition
from AIToolbox.torchtrain.callbacks.performance_eval_callbacks import ModelPerformanceEvaluation, ModelPerformancePrintReport
from AIToolbox.experiment.result_package.basic_packages import ClassificationResultPackage

## Define the model

In [6]:
class Net(nn.Module):
    """Small convolutional classifier producing 10 log-probabilities per sample.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed two
    fully connected layers. ``fc1`` takes 4*4*50 features per sample, which is
    what the convolutional stack produces for single-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 along each spatial axis).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Run the network on ``x`` and return log-softmax scores over dim 1.

        Raises:
            RuntimeError: if the feature map reaching ``fc1`` does not hold
                exactly ``fc1.in_features`` values per sample (i.e. the input
                had an unsupported spatial size).
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)

        # A bare ``view(-1, 800)`` happily regroups a mis-sized feature map into
        # a different number of rows (e.g. 16 samples of 1250 features become
        # 25 rows), silently changing the batch size. Check the per-sample
        # feature count first so such inputs fail loudly instead.
        per_sample_features = x.size(-3) * x.size(-2) * x.size(-1)
        if per_sample_features != self.fc1.in_features:
            raise RuntimeError(
                "Net expects {} features per sample after the conv stack, got {}; "
                "check the input image size".format(self.fc1.in_features, per_sample_features))

        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

## Define the batch feed definition

In [24]:
class MNISTModelFeedDefinition(AbstractModelFeedDefinition):
    """Feed definition for batches that unpack into an ``(inputs, targets)`` pair.

    Describes how one batch is pushed through the model to obtain a loss or a
    set of predictions.
    """

    def get_loss(self, model, batch_data, criterion, device):
        """Return ``criterion(model(inputs), targets)`` with both tensors on ``device``."""
        inputs, labels = batch_data
        inputs = inputs.to(device)
        labels = labels.to(device)

        return criterion(model(inputs), labels)

    def get_loss_eval(self, model, batch_data, criterion, device):
        """Evaluation-time loss; computed exactly like the training loss."""
        return self.get_loss(model, batch_data, criterion, device)

    def get_predictions(self, model, batch_data, device):
        """Return ``(y_test, y_pred, metadata)`` for one batch.

        ``y_test`` is handed back exactly as it arrived in ``batch_data``,
        ``y_pred`` is the arg-max over dim 1 of the model output moved to the
        CPU, and the metadata dict is always empty.
        """
        inputs, y_test = batch_data
        scores = model(inputs.to(device))

        # Position of the largest log-probability is the predicted class.
        y_pred = scores.argmax(dim=1)

        return y_test, y_pred.cpu(), {}

In [10]:
# Seed PyTorch's random number generator before any model or shuffling
# DataLoader is created, so a Restart & Run All starts from the same random
# state every time. Change SEED to draw a different run.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Extra DataLoader keyword arguments: a worker process and pinned host memory
# are only requested when CUDA is in use; on CPU the loader defaults apply.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [12]:
# Per-image preprocessing: convert to a tensor, then normalise the single
# channel with the given (mean,), (std,) pair -- presumably the customary
# MNIST statistics; verify before reusing on other data.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, downloaded into ../data if it is not there yet.
train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=train_transform)

# Shuffled mini-batches of 100 samples; `kwargs` carries the CUDA-only loader options.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True, **kwargs)

In [13]:
# Held-out split used for validation/testing. No `download=True` here: the
# files are expected to already be in ../data (the training-loader cell
# requests the download).
#
# Evaluation batches are served in dataset order (shuffle=False): shuffling
# only matters for SGD updates, and a fixed order keeps stored predictions
# aligned with the dataset from run to run.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       # Same (mean,), (std,) normalisation as the training loader.
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=100, shuffle=False, **kwargs)

In [15]:
# Fresh, untrained network for the first training run.
model = Net()

# SGD with momentum over every parameter of the network.
optimizer = optim.SGD(params=model.parameters(), momentum=0.5, lr=0.001)

# Net.forward ends in log_softmax, so negative log-likelihood is the matching loss.
criterion = F.nll_loss

In [26]:
# Hyperparameter record handed to the evaluation callback. The values are read
# back from the live loader and optimizer instead of being retyped, so the
# record cannot drift from what the run actually uses.
eval_hyperparams = {'batch_size': train_loader.batch_size,
                    'lr': optimizer.param_groups[0]['lr']}

# First callback evaluates on both the training and validation data; the second
# prints the two listed accuracy metrics.
callbacks = [ModelPerformanceEvaluation(ClassificationResultPackage(), eval_hyperparams,
                                        on_train_data=True, on_val_data=True),
             ModelPerformancePrintReport(['train_Accuracy', 'val_Accuracy'], strict_metric_reporting=True)]

In [27]:
# Build the training loop around the wrapped model. The three loader slots are
# filled with the training loader, the held-out loader and None -- presumably
# (train, validation, test); confirm against the AIToolbox TrainLoop signature.
train_loop = TrainLoop(
    ModelWrap(model, MNISTModelFeedDefinition()),
    train_loader, test_loader, None,
    optimizer, criterion,
)

# Train for three epochs with the evaluation/report callbacks attached. The call
# is left as the cell's last expression so its return value is still displayed.
train_loop(num_epoch=3, callbacks=callbacks)



  0%|          | 0/600 [00:00<?, ?it/s][A[A

  0%|          | 2/600 [00:00<00:41, 14.50it/s][A[A



Epoch: 0




  1%|          | 4/600 [00:00<00:40, 14.68it/s][A[A

  1%|          | 6/600 [00:00<00:39, 15.05it/s][A[A

  1%|▏         | 8/600 [00:00<00:38, 15.48it/s][A[A

  2%|▏         | 10/600 [00:00<00:37, 15.66it/s][A[A

  2%|▏         | 12/600 [00:00<00:38, 15.46it/s][A[A

  2%|▏         | 14/600 [00:00<00:37, 15.65it/s][A[A

  3%|▎         | 16/600 [00:01<00:37, 15.77it/s][A[A

  3%|▎         | 18/600 [00:01<00:36, 15.80it/s][A[A

  3%|▎         | 20/600 [00:01<00:37, 15.67it/s][A[A

  4%|▎         | 22/600 [00:01<00:36, 15.73it/s][A[A

  4%|▍         | 24/600 [00:01<00:36, 15.81it/s][A[A

  4%|▍         | 26/600 [00:01<00:36, 15.92it/s][A[A

  5%|▍         | 28/600 [00:01<00:35, 16.04it/s][A[A

  5%|▌         | 30/600 [00:01<00:36, 15.42it/s][A[A

  5%|▌         | 32/600 [00:02<00:36, 15.52it/s][A[A

  6%|▌         | 34/600 [00:02<00:36, 15.68it/s][A[A

  6%|▌         | 36/600 [00:02<00:35, 15.82it/s][A[A

  6%|▋         | 38/600 [00:02<00:36, 15.59it/s]

 49%|████▊     | 292/600 [00:19<00:20, 15.30it/s][A[A

 49%|████▉     | 294/600 [00:19<00:25, 12.07it/s][A[A

 49%|████▉     | 296/600 [00:19<00:26, 11.43it/s][A[A

 50%|████▉     | 298/600 [00:19<00:25, 11.88it/s][A[A

 50%|█████     | 300/600 [00:20<00:28, 10.43it/s][A[A

 50%|█████     | 302/600 [00:20<00:27, 10.87it/s][A[A

 51%|█████     | 304/600 [00:20<00:26, 11.06it/s][A[A

 51%|█████     | 306/600 [00:20<00:24, 11.80it/s][A[A

 51%|█████▏    | 308/600 [00:20<00:22, 12.94it/s][A[A

 52%|█████▏    | 310/600 [00:20<00:21, 13.56it/s][A[A

 52%|█████▏    | 312/600 [00:20<00:20, 14.05it/s][A[A

 52%|█████▏    | 314/600 [00:21<00:19, 14.74it/s][A[A

 53%|█████▎    | 316/600 [00:21<00:18, 15.52it/s][A[A

 53%|█████▎    | 318/600 [00:21<00:17, 15.72it/s][A[A

 53%|█████▎    | 320/600 [00:21<00:17, 15.90it/s][A[A

 54%|█████▎    | 322/600 [00:21<00:17, 16.23it/s][A[A

 54%|█████▍    | 324/600 [00:21<00:16, 16.40it/s][A[A

 54%|█████▍    | 326/600 [00:21

 96%|█████████▋| 578/600 [00:36<00:01, 16.85it/s][A[A

 97%|█████████▋| 580/600 [00:37<00:01, 16.85it/s][A[A

 97%|█████████▋| 582/600 [00:37<00:01, 16.70it/s][A[A

 97%|█████████▋| 584/600 [00:37<00:00, 16.87it/s][A[A

 98%|█████████▊| 586/600 [00:37<00:00, 16.91it/s][A[A

 98%|█████████▊| 588/600 [00:37<00:00, 16.89it/s][A[A

 98%|█████████▊| 590/600 [00:37<00:00, 17.00it/s][A[A

 99%|█████████▊| 592/600 [00:37<00:00, 16.93it/s][A[A

 99%|█████████▉| 594/600 [00:37<00:00, 17.02it/s][A[A

 99%|█████████▉| 596/600 [00:37<00:00, 17.07it/s][A[A

100%|█████████▉| 598/600 [00:38<00:00, 17.03it/s][A[A

100%|██████████| 600/600 [00:38<00:00, 16.96it/s][A[A

  0%|          | 0/600 [00:00<?, ?it/s][A[A

  1%|          | 4/600 [00:00<00:16, 35.57it/s][A[A

AVG BATCH ACCUMULATED TRAIN LOSS: 0.2084879569336772




  1%|▏         | 8/600 [00:00<00:16, 34.84it/s][A[A

  2%|▏         | 12/600 [00:00<00:16, 35.47it/s][A[A

  3%|▎         | 16/600 [00:00<00:16, 36.03it/s][A[A

  3%|▎         | 20/600 [00:00<00:15, 36.58it/s][A[A

  4%|▍         | 24/600 [00:00<00:15, 36.74it/s][A[A

  5%|▍         | 28/600 [00:00<00:15, 37.24it/s][A[A

  5%|▌         | 32/600 [00:00<00:15, 37.53it/s][A[A

  6%|▌         | 36/600 [00:00<00:14, 38.12it/s][A[A

  7%|▋         | 40/600 [00:01<00:14, 38.64it/s][A[A

  7%|▋         | 44/600 [00:01<00:15, 36.95it/s][A[A

  8%|▊         | 48/600 [00:01<00:14, 36.91it/s][A[A

  9%|▊         | 52/600 [00:01<00:14, 37.02it/s][A[A

 10%|▉         | 57/600 [00:01<00:14, 38.45it/s][A[A

 10%|█         | 62/600 [00:01<00:13, 39.24it/s][A[A

 11%|█         | 66/600 [00:01<00:13, 39.24it/s][A[A

 12%|█▏        | 71/600 [00:01<00:13, 39.74it/s][A[A

 13%|█▎        | 76/600 [00:01<00:12, 40.99it/s][A[A

 14%|█▎        | 81/600 [00:02<00:12, 41.71it/s

TRAIN LOSS: 0.19601456328605613




  9%|▉         | 9/100 [00:00<00:02, 39.64it/s][A[A

 13%|█▎        | 13/100 [00:00<00:02, 38.66it/s][A[A

 17%|█▋        | 17/100 [00:00<00:02, 38.52it/s][A[A

 21%|██        | 21/100 [00:00<00:02, 38.14it/s][A[A

 25%|██▌       | 25/100 [00:00<00:01, 37.61it/s][A[A

 29%|██▉       | 29/100 [00:00<00:01, 37.43it/s][A[A

 33%|███▎      | 33/100 [00:00<00:01, 37.32it/s][A[A

 37%|███▋      | 37/100 [00:00<00:01, 37.38it/s][A[A

 41%|████      | 41/100 [00:01<00:01, 37.21it/s][A[A

 46%|████▌     | 46/100 [00:01<00:01, 38.50it/s][A[A

 51%|█████     | 51/100 [00:01<00:01, 39.21it/s][A[A

 55%|█████▌    | 55/100 [00:01<00:01, 39.39it/s][A[A

 59%|█████▉    | 59/100 [00:01<00:01, 39.48it/s][A[A

 64%|██████▍   | 64/100 [00:01<00:00, 39.54it/s][A[A

 68%|██████▊   | 68/100 [00:01<00:00, 39.04it/s][A[A

 72%|███████▏  | 72/100 [00:01<00:00, 38.37it/s][A[A

 76%|███████▌  | 76/100 [00:01<00:00, 38.81it/s][A[A

 80%|████████  | 80/100 [00:02<00:00, 36.94it/s

VAL LOSS: 0.18336668550968171




  1%|▏         | 8/600 [00:00<00:15, 38.58it/s][A[A

  2%|▏         | 12/600 [00:00<00:15, 38.34it/s][A[A

  3%|▎         | 16/600 [00:00<00:15, 38.27it/s][A[A

  3%|▎         | 20/600 [00:00<00:14, 38.67it/s][A[A

  4%|▍         | 25/600 [00:00<00:14, 39.46it/s][A[A

  5%|▌         | 30/600 [00:00<00:14, 39.91it/s][A[A

  6%|▌         | 35/600 [00:00<00:13, 40.49it/s][A[A

  6%|▋         | 39/600 [00:00<00:14, 39.80it/s][A[A

  7%|▋         | 44/600 [00:01<00:13, 41.03it/s][A[A

  8%|▊         | 49/600 [00:01<00:13, 40.78it/s][A[A

  9%|▉         | 54/600 [00:01<00:13, 41.18it/s][A[A

 10%|▉         | 59/600 [00:01<00:13, 40.58it/s][A[A

 11%|█         | 64/600 [00:01<00:13, 40.76it/s][A[A

 12%|█▏        | 69/600 [00:01<00:12, 41.46it/s][A[A

 12%|█▏        | 74/600 [00:01<00:12, 41.72it/s][A[A

 13%|█▎        | 79/600 [00:01<00:12, 42.16it/s][A[A

 14%|█▍        | 84/600 [00:02<00:12, 41.55it/s][A[A

 15%|█▍        | 89/600 [00:02<00:12, 41.92it/s

100%|██████████| 600/600 [00:16<00:00, 36.58it/s][A[A

  0%|          | 0/100 [00:00<?, ?it/s][A[A

  4%|▍         | 4/100 [00:00<00:02, 33.69it/s][A[A

  8%|▊         | 8/100 [00:00<00:02, 32.97it/s][A[A

 11%|█         | 11/100 [00:00<00:02, 31.89it/s][A[A

 15%|█▌        | 15/100 [00:00<00:02, 32.56it/s][A[A

 19%|█▉        | 19/100 [00:00<00:02, 32.95it/s][A[A

 23%|██▎       | 23/100 [00:00<00:02, 33.76it/s][A[A

 27%|██▋       | 27/100 [00:00<00:02, 34.41it/s][A[A

 31%|███       | 31/100 [00:00<00:01, 34.95it/s][A[A

 35%|███▌      | 35/100 [00:01<00:01, 35.77it/s][A[A

 39%|███▉      | 39/100 [00:01<00:01, 36.37it/s][A[A

 43%|████▎     | 43/100 [00:01<00:01, 36.40it/s][A[A

 47%|████▋     | 47/100 [00:01<00:01, 36.70it/s][A[A

 51%|█████     | 51/100 [00:01<00:01, 36.81it/s][A[A

 55%|█████▌    | 55/100 [00:01<00:01, 37.22it/s][A[A

 59%|█████▉    | 59/100 [00:01<00:01, 37.30it/s][A[A

 63%|██████▎   | 63/100 [00:01<00:00, 37.39it/s][A[A

 6

------------------  End of epoch performance report  -------------------
train_Accuracy: 0.9429333333333333
val_Accuracy: 0.9482


Epoch: 1




  1%|          | 4/600 [00:00<00:40, 14.89it/s][A[A

  1%|          | 6/600 [00:00<00:39, 14.99it/s][A[A

  1%|▏         | 8/600 [00:00<00:39, 15.16it/s][A[A

  2%|▏         | 10/600 [00:00<00:38, 15.27it/s][A[A

  2%|▏         | 12/600 [00:00<00:38, 15.40it/s][A[A

  2%|▏         | 14/600 [00:00<00:38, 15.31it/s][A[A

  3%|▎         | 16/600 [00:01<00:38, 15.14it/s][A[A

  3%|▎         | 18/600 [00:01<00:38, 15.13it/s][A[A

  3%|▎         | 20/600 [00:01<00:38, 15.17it/s][A[A

  4%|▎         | 22/600 [00:01<00:37, 15.24it/s][A[A

  4%|▍         | 24/600 [00:01<00:37, 15.38it/s][A[A

  4%|▍         | 26/600 [00:01<00:37, 15.30it/s][A[A

  5%|▍         | 28/600 [00:01<00:37, 15.33it/s][A[A

  5%|▌         | 30/600 [00:01<00:37, 15.30it/s][A[A

  5%|▌         | 32/600 [00:02<00:36, 15.46it/s][A[A

  6%|▌         | 34/600 [00:02<00:36, 15.52it/s][A[A

  6%|▌         | 36/600 [00:02<00:36, 15.53it/s][A[A

  6%|▋         | 38/600 [00:02<00:35, 15.62it/s]

 49%|████▊     | 292/600 [00:19<00:20, 14.70it/s][A[A

 49%|████▉     | 294/600 [00:19<00:20, 14.70it/s][A[A

 49%|████▉     | 296/600 [00:19<00:20, 14.60it/s][A[A

 50%|████▉     | 298/600 [00:19<00:20, 14.70it/s][A[A

 50%|█████     | 300/600 [00:19<00:20, 14.84it/s][A[A

 50%|█████     | 302/600 [00:19<00:19, 14.95it/s][A[A

 51%|█████     | 304/600 [00:20<00:19, 14.98it/s][A[A

 51%|█████     | 306/600 [00:20<00:19, 14.94it/s][A[A

 51%|█████▏    | 308/600 [00:20<00:19, 14.93it/s][A[A

 52%|█████▏    | 310/600 [00:20<00:19, 14.77it/s][A[A

 52%|█████▏    | 312/600 [00:20<00:19, 14.94it/s][A[A

 52%|█████▏    | 314/600 [00:20<00:18, 15.13it/s][A[A

 53%|█████▎    | 316/600 [00:20<00:18, 15.15it/s][A[A

 53%|█████▎    | 318/600 [00:21<00:18, 15.04it/s][A[A

 53%|█████▎    | 320/600 [00:21<00:18, 14.79it/s][A[A

 54%|█████▎    | 322/600 [00:21<00:18, 14.82it/s][A[A

 54%|█████▍    | 324/600 [00:21<00:18, 14.99it/s][A[A

 54%|█████▍    | 326/600 [00:21

 96%|█████████▋| 578/600 [00:38<00:01, 15.18it/s][A[A

 97%|█████████▋| 580/600 [00:38<00:01, 15.14it/s][A[A

 97%|█████████▋| 582/600 [00:38<00:01, 15.23it/s][A[A

 97%|█████████▋| 584/600 [00:38<00:01, 15.41it/s][A[A

 98%|█████████▊| 586/600 [00:38<00:00, 15.45it/s][A[A

 98%|█████████▊| 588/600 [00:38<00:00, 15.41it/s][A[A

 98%|█████████▊| 590/600 [00:38<00:00, 15.52it/s][A[A

 99%|█████████▊| 592/600 [00:38<00:00, 15.65it/s][A[A

 99%|█████████▉| 594/600 [00:39<00:00, 15.67it/s][A[A

 99%|█████████▉| 596/600 [00:39<00:00, 15.66it/s][A[A

100%|█████████▉| 598/600 [00:39<00:00, 15.63it/s][A[A

100%|██████████| 600/600 [00:39<00:00, 15.76it/s][A[A

  0%|          | 0/600 [00:00<?, ?it/s][A[A

  1%|          | 4/600 [00:00<00:17, 34.22it/s][A[A

AVG BATCH ACCUMULATED TRAIN LOSS: 0.1857283940538764




  1%|▏         | 8/600 [00:00<00:17, 34.51it/s][A[A

  2%|▏         | 12/600 [00:00<00:16, 35.75it/s][A[A

  3%|▎         | 16/600 [00:00<00:15, 36.52it/s][A[A

  3%|▎         | 20/600 [00:00<00:15, 37.28it/s][A[A

  4%|▍         | 24/600 [00:00<00:15, 37.82it/s][A[A

  5%|▍         | 28/600 [00:00<00:15, 38.08it/s][A[A

  5%|▌         | 32/600 [00:00<00:15, 37.69it/s][A[A

  6%|▌         | 36/600 [00:00<00:15, 37.29it/s][A[A

  7%|▋         | 40/600 [00:01<00:15, 37.19it/s][A[A

  7%|▋         | 44/600 [00:01<00:14, 37.11it/s][A[A

  8%|▊         | 48/600 [00:01<00:14, 37.09it/s][A[A

  9%|▊         | 52/600 [00:01<00:14, 37.24it/s][A[A

  9%|▉         | 56/600 [00:01<00:14, 36.95it/s][A[A

 10%|█         | 60/600 [00:01<00:14, 36.28it/s][A[A

 11%|█         | 64/600 [00:01<00:14, 36.37it/s][A[A

 11%|█▏        | 68/600 [00:01<00:14, 36.09it/s][A[A

 12%|█▏        | 72/600 [00:01<00:14, 36.09it/s][A[A

 13%|█▎        | 76/600 [00:02<00:14, 36.52it/s

 97%|█████████▋| 584/600 [00:15<00:00, 37.94it/s][A[A

 98%|█████████▊| 588/600 [00:15<00:00, 37.48it/s][A[A

 99%|█████████▊| 592/600 [00:15<00:00, 37.33it/s][A[A

 99%|█████████▉| 596/600 [00:15<00:00, 37.34it/s][A[A

100%|██████████| 600/600 [00:16<00:00, 37.06it/s][A[A

  0%|          | 0/100 [00:00<?, ?it/s][A[A

  4%|▍         | 4/100 [00:00<00:02, 37.20it/s][A[A

TRAIN LOSS: 0.1803850051512321




  8%|▊         | 8/100 [00:00<00:02, 37.20it/s][A[A

 12%|█▏        | 12/100 [00:00<00:02, 37.08it/s][A[A

 16%|█▌        | 16/100 [00:00<00:02, 37.26it/s][A[A

 20%|██        | 20/100 [00:00<00:02, 37.18it/s][A[A

 24%|██▍       | 24/100 [00:00<00:02, 37.49it/s][A[A

 28%|██▊       | 28/100 [00:00<00:01, 37.84it/s][A[A

 32%|███▏      | 32/100 [00:00<00:01, 36.77it/s][A[A

 36%|███▌      | 36/100 [00:00<00:01, 36.98it/s][A[A

 40%|████      | 40/100 [00:01<00:01, 37.44it/s][A[A

 44%|████▍     | 44/100 [00:01<00:01, 37.26it/s][A[A

 48%|████▊     | 48/100 [00:01<00:01, 37.21it/s][A[A

 52%|█████▏    | 52/100 [00:01<00:01, 37.75it/s][A[A

 56%|█████▌    | 56/100 [00:01<00:01, 37.34it/s][A[A

 60%|██████    | 60/100 [00:01<00:01, 36.36it/s][A[A

 64%|██████▍   | 64/100 [00:01<00:01, 35.91it/s][A[A

 68%|██████▊   | 68/100 [00:01<00:00, 35.71it/s][A[A

 72%|███████▏  | 72/100 [00:01<00:00, 36.12it/s][A[A

 76%|███████▌  | 76/100 [00:02<00:00, 36.36it/s

VAL LOSS: 0.16884166572242976




  1%|▏         | 8/600 [00:00<00:16, 35.33it/s][A[A

  2%|▏         | 12/600 [00:00<00:16, 35.45it/s][A[A

  3%|▎         | 16/600 [00:00<00:16, 35.84it/s][A[A

  3%|▎         | 20/600 [00:00<00:15, 36.31it/s][A[A

  4%|▍         | 24/600 [00:00<00:15, 36.60it/s][A[A

  5%|▍         | 28/600 [00:00<00:15, 36.57it/s][A[A

  5%|▌         | 32/600 [00:00<00:15, 36.53it/s][A[A

  6%|▌         | 36/600 [00:00<00:15, 36.42it/s][A[A

  7%|▋         | 40/600 [00:01<00:15, 36.44it/s][A[A

  7%|▋         | 44/600 [00:01<00:15, 36.50it/s][A[A

  8%|▊         | 48/600 [00:01<00:15, 35.28it/s][A[A

  9%|▊         | 52/600 [00:01<00:15, 35.66it/s][A[A

  9%|▉         | 56/600 [00:01<00:14, 36.30it/s][A[A

 10%|█         | 60/600 [00:01<00:14, 36.67it/s][A[A

 11%|█         | 64/600 [00:01<00:14, 36.18it/s][A[A

 11%|█▏        | 68/600 [00:01<00:15, 35.39it/s][A[A

 12%|█▏        | 72/600 [00:01<00:14, 35.65it/s][A[A

 13%|█▎        | 76/600 [00:02<00:14, 35.66it/s

 97%|█████████▋| 584/600 [00:16<00:00, 35.27it/s][A[A

 98%|█████████▊| 588/600 [00:16<00:00, 35.49it/s][A[A

 99%|█████████▊| 592/600 [00:16<00:00, 35.89it/s][A[A

 99%|█████████▉| 596/600 [00:16<00:00, 36.27it/s][A[A

100%|██████████| 600/600 [00:16<00:00, 36.52it/s][A[A

  0%|          | 0/100 [00:00<?, ?it/s][A[A

  4%|▍         | 4/100 [00:00<00:02, 35.70it/s][A[A

Auto purging prediction store




  8%|▊         | 8/100 [00:00<00:02, 36.28it/s][A[A

 12%|█▏        | 12/100 [00:00<00:02, 35.90it/s][A[A

 16%|█▌        | 16/100 [00:00<00:02, 34.41it/s][A[A

 20%|██        | 20/100 [00:00<00:02, 33.53it/s][A[A

 24%|██▍       | 24/100 [00:00<00:02, 33.88it/s][A[A

 28%|██▊       | 28/100 [00:00<00:02, 34.32it/s][A[A

 32%|███▏      | 32/100 [00:00<00:01, 34.66it/s][A[A

 36%|███▌      | 36/100 [00:01<00:01, 35.26it/s][A[A

 40%|████      | 40/100 [00:01<00:01, 35.37it/s][A[A

 44%|████▍     | 44/100 [00:01<00:01, 36.02it/s][A[A

 48%|████▊     | 48/100 [00:01<00:01, 36.20it/s][A[A

 52%|█████▏    | 52/100 [00:01<00:01, 36.08it/s][A[A

 56%|█████▌    | 56/100 [00:01<00:01, 35.16it/s][A[A

 60%|██████    | 60/100 [00:01<00:01, 35.00it/s][A[A

 64%|██████▍   | 64/100 [00:01<00:01, 32.48it/s][A[A

 68%|██████▊   | 68/100 [00:01<00:00, 33.12it/s][A[A

 72%|███████▏  | 72/100 [00:02<00:00, 34.23it/s][A[A

 76%|███████▌  | 76/100 [00:02<00:00, 33.62it/s

------------------  End of epoch performance report  -------------------
train_Accuracy: 0.9482666666666667
val_Accuracy: 0.9532


Epoch: 2




  1%|          | 4/600 [00:00<00:40, 14.57it/s][A[A

  1%|          | 6/600 [00:00<00:40, 14.56it/s][A[A

  1%|▏         | 8/600 [00:00<00:41, 14.11it/s][A[A

  2%|▏         | 10/600 [00:00<00:42, 13.73it/s][A[A

  2%|▏         | 12/600 [00:00<00:44, 13.10it/s][A[A

  2%|▏         | 14/600 [00:01<00:43, 13.42it/s][A[A

  3%|▎         | 16/600 [00:01<00:44, 13.25it/s][A[A

  3%|▎         | 18/600 [00:01<00:43, 13.29it/s][A[A

  3%|▎         | 20/600 [00:01<00:42, 13.69it/s][A[A

  4%|▎         | 22/600 [00:01<00:42, 13.69it/s][A[A

  4%|▍         | 24/600 [00:01<00:41, 13.81it/s][A[A

  4%|▍         | 26/600 [00:01<00:43, 13.22it/s][A[A

  5%|▍         | 28/600 [00:02<00:43, 13.07it/s][A[A

  5%|▌         | 30/600 [00:02<00:42, 13.43it/s][A[A

  5%|▌         | 32/600 [00:02<00:41, 13.74it/s][A[A

  6%|▌         | 34/600 [00:02<00:40, 13.96it/s][A[A

  6%|▌         | 36/600 [00:02<00:39, 14.11it/s][A[A

  6%|▋         | 38/600 [00:02<00:40, 14.05it/s]

 49%|████▊     | 292/600 [00:20<00:20, 14.81it/s][A[A

 49%|████▉     | 294/600 [00:20<00:20, 14.92it/s][A[A

 49%|████▉     | 296/600 [00:20<00:20, 14.95it/s][A[A

 50%|████▉     | 298/600 [00:20<00:20, 14.91it/s][A[A

 50%|█████     | 300/600 [00:21<00:20, 14.84it/s][A[A

 50%|█████     | 302/600 [00:21<00:20, 14.80it/s][A[A

 51%|█████     | 304/600 [00:21<00:20, 14.70it/s][A[A

 51%|█████     | 306/600 [00:21<00:19, 14.85it/s][A[A

 51%|█████▏    | 308/600 [00:21<00:19, 14.84it/s][A[A

 52%|█████▏    | 310/600 [00:21<00:19, 14.65it/s][A[A

 52%|█████▏    | 312/600 [00:21<00:19, 14.68it/s][A[A

 52%|█████▏    | 314/600 [00:22<00:19, 14.86it/s][A[A

 53%|█████▎    | 316/600 [00:22<00:19, 14.89it/s][A[A

 53%|█████▎    | 318/600 [00:22<00:18, 15.03it/s][A[A

 53%|█████▎    | 320/600 [00:22<00:18, 15.15it/s][A[A

 54%|█████▎    | 322/600 [00:22<00:18, 15.15it/s][A[A

 54%|█████▍    | 324/600 [00:22<00:18, 15.14it/s][A[A

 54%|█████▍    | 326/600 [00:22

 96%|█████████▋| 578/600 [00:41<00:01, 15.58it/s][A[A

 97%|█████████▋| 580/600 [00:41<00:01, 15.23it/s][A[A

 97%|█████████▋| 582/600 [00:41<00:01, 15.27it/s][A[A

 97%|█████████▋| 584/600 [00:41<00:01, 15.12it/s][A[A

 98%|█████████▊| 586/600 [00:41<00:00, 15.04it/s][A[A

 98%|█████████▊| 588/600 [00:41<00:00, 15.13it/s][A[A

 98%|█████████▊| 590/600 [00:41<00:00, 14.89it/s][A[A

 99%|█████████▊| 592/600 [00:41<00:00, 15.03it/s][A[A

 99%|█████████▉| 594/600 [00:42<00:00, 15.03it/s][A[A

 99%|█████████▉| 596/600 [00:42<00:00, 14.67it/s][A[A

100%|█████████▉| 598/600 [00:42<00:00, 13.23it/s][A[A

100%|██████████| 600/600 [00:42<00:00, 12.04it/s][A[A

  0%|          | 0/600 [00:00<?, ?it/s][A[A

  0%|          | 2/600 [00:00<00:30, 19.40it/s][A[A

AVG BATCH ACCUMULATED TRAIN LOSS: 0.16800105535735688




  1%|          | 5/600 [00:00<00:27, 21.42it/s][A[A

  1%|▏         | 8/600 [00:00<00:26, 22.25it/s][A[A

  2%|▏         | 11/600 [00:00<00:26, 22.44it/s][A[A

  2%|▏         | 14/600 [00:00<00:24, 23.64it/s][A[A

  3%|▎         | 17/600 [00:00<00:24, 23.84it/s][A[A

  3%|▎         | 20/600 [00:00<00:25, 22.79it/s][A[A

  4%|▍         | 23/600 [00:00<00:25, 22.44it/s][A[A

  4%|▍         | 26/600 [00:01<00:25, 22.74it/s][A[A

  5%|▍         | 29/600 [00:01<00:23, 23.98it/s][A[A

  5%|▌         | 32/600 [00:01<00:22, 24.81it/s][A[A

  6%|▌         | 35/600 [00:01<00:22, 25.65it/s][A[A

  6%|▋         | 38/600 [00:01<00:21, 26.45it/s][A[A

  7%|▋         | 41/600 [00:01<00:20, 27.12it/s][A[A

  7%|▋         | 44/600 [00:01<00:20, 27.04it/s][A[A

  8%|▊         | 48/600 [00:01<00:19, 27.98it/s][A[A

  8%|▊         | 51/600 [00:01<00:19, 28.38it/s][A[A

  9%|▉         | 55/600 [00:02<00:19, 28.59it/s][A[A

 10%|▉         | 58/600 [00:02<00:19, 28.40it/s]

 93%|█████████▎| 558/600 [00:17<00:01, 33.49it/s][A[A

 94%|█████████▎| 562/600 [00:17<00:01, 31.57it/s][A[A

 94%|█████████▍| 566/600 [00:17<00:01, 31.28it/s][A[A

 95%|█████████▌| 570/600 [00:17<00:00, 31.23it/s][A[A

 96%|█████████▌| 574/600 [00:18<00:00, 32.17it/s][A[A

 96%|█████████▋| 578/600 [00:18<00:00, 32.65it/s][A[A

 97%|█████████▋| 582/600 [00:18<00:00, 32.40it/s][A[A

 98%|█████████▊| 586/600 [00:18<00:00, 32.55it/s][A[A

 98%|█████████▊| 590/600 [00:18<00:00, 32.46it/s][A[A

 99%|█████████▉| 594/600 [00:18<00:00, 31.52it/s][A[A

100%|█████████▉| 598/600 [00:18<00:00, 30.92it/s][A[A

100%|██████████| 600/600 [00:18<00:00, 31.77it/s][A[A

  0%|          | 0/100 [00:00<?, ?it/s][A[A

  4%|▍         | 4/100 [00:00<00:02, 33.24it/s][A[A

TRAIN LOSS: 0.16086503631745774




  8%|▊         | 8/100 [00:00<00:02, 34.24it/s][A[A

 12%|█▏        | 12/100 [00:00<00:02, 34.98it/s][A[A

 16%|█▌        | 16/100 [00:00<00:02, 35.46it/s][A[A

 20%|██        | 20/100 [00:00<00:02, 35.84it/s][A[A

 24%|██▍       | 24/100 [00:00<00:02, 36.03it/s][A[A

 28%|██▊       | 28/100 [00:00<00:02, 35.72it/s][A[A

 32%|███▏      | 32/100 [00:00<00:01, 35.65it/s][A[A

 36%|███▌      | 36/100 [00:01<00:01, 33.33it/s][A[A

 40%|████      | 40/100 [00:01<00:02, 29.70it/s][A[A

 43%|████▎     | 43/100 [00:01<00:01, 29.65it/s][A[A

 46%|████▌     | 46/100 [00:01<00:03, 16.15it/s][A[A

 49%|████▉     | 49/100 [00:01<00:02, 17.24it/s][A[A

 52%|█████▏    | 52/100 [00:01<00:02, 18.76it/s][A[A

 55%|█████▌    | 55/100 [00:02<00:02, 20.19it/s][A[A

 59%|█████▉    | 59/100 [00:02<00:01, 22.59it/s][A[A

 63%|██████▎   | 63/100 [00:02<00:01, 24.96it/s][A[A

 67%|██████▋   | 67/100 [00:02<00:01, 26.62it/s][A[A

 71%|███████   | 71/100 [00:02<00:01, 27.95it/s

VAL LOSS: 0.1496959038823843




  1%|▏         | 8/600 [00:00<00:17, 33.77it/s][A[A

  2%|▏         | 12/600 [00:00<00:16, 34.81it/s][A[A

  3%|▎         | 16/600 [00:00<00:16, 35.30it/s][A[A

  3%|▎         | 20/600 [00:00<00:16, 35.79it/s][A[A

  4%|▍         | 24/600 [00:00<00:15, 36.52it/s][A[A

  5%|▍         | 28/600 [00:00<00:15, 36.59it/s][A[A

  5%|▌         | 32/600 [00:00<00:15, 36.40it/s][A[A

  6%|▌         | 36/600 [00:00<00:15, 36.55it/s][A[A

  7%|▋         | 40/600 [00:01<00:15, 35.72it/s][A[A

  7%|▋         | 44/600 [00:01<00:15, 35.49it/s][A[A

  8%|▊         | 48/600 [00:01<00:15, 35.76it/s][A[A

  9%|▊         | 52/600 [00:01<00:15, 34.36it/s][A[A

  9%|▉         | 56/600 [00:01<00:15, 34.56it/s][A[A

 10%|█         | 60/600 [00:01<00:15, 35.34it/s][A[A

 11%|█         | 64/600 [00:01<00:15, 35.41it/s][A[A

 11%|█▏        | 68/600 [00:01<00:15, 35.35it/s][A[A

 12%|█▏        | 72/600 [00:02<00:14, 35.70it/s][A[A

 13%|█▎        | 76/600 [00:02<00:14, 35.87it/s

 96%|█████████▌| 576/600 [00:17<00:00, 34.36it/s][A[A

 97%|█████████▋| 580/600 [00:17<00:00, 34.47it/s][A[A

 97%|█████████▋| 584/600 [00:17<00:00, 34.71it/s][A[A

 98%|█████████▊| 588/600 [00:17<00:00, 34.19it/s][A[A

 99%|█████████▊| 592/600 [00:17<00:00, 34.62it/s][A[A

 99%|█████████▉| 596/600 [00:17<00:00, 34.86it/s][A[A

100%|██████████| 600/600 [00:17<00:00, 34.87it/s][A[A

  0%|          | 0/100 [00:00<?, ?it/s][A[A

  4%|▍         | 4/100 [00:00<00:02, 34.30it/s][A[A

Auto purging prediction store




  8%|▊         | 8/100 [00:00<00:02, 34.41it/s][A[A

 12%|█▏        | 12/100 [00:00<00:02, 34.35it/s][A[A

 16%|█▌        | 16/100 [00:00<00:02, 34.28it/s][A[A

 20%|██        | 20/100 [00:00<00:02, 34.13it/s][A[A

 24%|██▍       | 24/100 [00:00<00:02, 33.61it/s][A[A

 28%|██▊       | 28/100 [00:00<00:02, 32.84it/s][A[A

 32%|███▏      | 32/100 [00:00<00:02, 32.59it/s][A[A

 36%|███▌      | 36/100 [00:01<00:01, 33.48it/s][A[A

 40%|████      | 40/100 [00:01<00:01, 33.50it/s][A[A

 44%|████▍     | 44/100 [00:01<00:01, 33.41it/s][A[A

 48%|████▊     | 48/100 [00:01<00:01, 33.80it/s][A[A

 52%|█████▏    | 52/100 [00:01<00:01, 34.40it/s][A[A

 56%|█████▌    | 56/100 [00:01<00:01, 34.91it/s][A[A

 60%|██████    | 60/100 [00:01<00:01, 35.09it/s][A[A

 64%|██████▍   | 64/100 [00:01<00:01, 35.34it/s][A[A

 68%|██████▊   | 68/100 [00:01<00:00, 34.57it/s][A[A

 72%|███████▏  | 72/100 [00:02<00:00, 34.33it/s][A[A

 76%|███████▌  | 76/100 [00:02<00:00, 34.41it/s

------------------  End of epoch performance report  -------------------
train_Accuracy: 0.9535
val_Accuracy: 0.9576
Getting train set predictions from store
Getting validation set predictions from store
-----------------  End of training performance report  -----------------
train_end_train_Accuracy: 0.9535
train_end_val_Accuracy: 0.9576


Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

## Training with experiment tracking

In [28]:
# Start the tracked experiment from a brand-new network rather than the one
# trained above.
model = Net()

# The optimizer is rebuilt too, because it has to hold the new model's parameters.
optimizer = optim.SGD(params=model.parameters(), momentum=0.5, lr=0.001)

# Negative log-likelihood on the log_softmax output of Net.forward.
criterion = F.nll_loss

In [29]:
# Hyperparameter record for the evaluation callback of the tracked run, taken
# from the live loader and the freshly built optimizer so it always matches the
# configuration that is actually trained.
tracked_eval_hyperparams = {'batch_size': train_loader.batch_size,
                            'lr': optimizer.param_groups[0]['lr']}

# New callback instances for this run: evaluation on training and validation
# data, followed by a printed report of the two listed accuracy metrics.
callbacks = [ModelPerformanceEvaluation(ClassificationResultPackage(), tracked_eval_hyperparams,
                                        on_train_data=True, on_val_data=True),
             ModelPerformancePrintReport(['train_Accuracy', 'val_Accuracy'], strict_metric_reporting=True)]

In [31]:
# Same training setup as the plain run, but through the checkpoint/end-save
# train loop. The held-out loader is passed in both the second and third loader
# slots (presumably validation and test -- confirm against AIToolbox).
#
# The hyperparameters are read from the live loader/optimizer instead of being
# retyped, so what gets recorded with the experiment is what was really used.
#
# NOTE(review): the recorded execution of this cell ended with
#   FileNotFoundError: [Errno 2] No such file or directory: '<ipython-input-31-...>'
# after the first epoch's losses were printed. It looks like the loop tries to
# read the calling script's source file, which does not exist for a notebook
# cell -- confirm against the AIToolbox version in use before relying on this cell.
TrainLoopModelCheckpointEndSave(ModelWrap(model, MNISTModelFeedDefinition()),
                                train_loader, test_loader, test_loader,
                                optimizer, criterion,
                                project_name='localRunCNNTest',
                                experiment_name='CNN_MNIST_test',
                                local_model_result_folder_path='model_results',
                                hyperparams={'batch_size': train_loader.batch_size,
                                             'lr': optimizer.param_groups[0]['lr']},
                                test_result_package=ClassificationResultPackage())(num_epoch=3, callbacks=callbacks)



  0%|          | 0/600 [00:00<?, ?it/s][A[A

  0%|          | 2/600 [00:00<00:41, 14.25it/s][A[A



Epoch: 0




  1%|          | 4/600 [00:00<00:40, 14.63it/s][A[A

  1%|          | 6/600 [00:00<00:39, 15.15it/s][A[A

  1%|▏         | 8/600 [00:00<00:38, 15.52it/s][A[A

  2%|▏         | 10/600 [00:00<00:37, 15.73it/s][A[A

  2%|▏         | 12/600 [00:00<00:36, 15.92it/s][A[A

  2%|▏         | 14/600 [00:00<00:36, 15.88it/s][A[A

  3%|▎         | 16/600 [00:01<00:36, 15.93it/s][A[A

  3%|▎         | 18/600 [00:01<00:36, 15.79it/s][A[A

  3%|▎         | 20/600 [00:01<00:37, 15.62it/s][A[A

  4%|▎         | 22/600 [00:01<00:36, 15.75it/s][A[A

  4%|▍         | 24/600 [00:01<00:36, 15.70it/s][A[A

  4%|▍         | 26/600 [00:01<00:38, 14.86it/s][A[A

  5%|▍         | 28/600 [00:01<00:38, 15.03it/s][A[A

  5%|▌         | 30/600 [00:01<00:37, 15.19it/s][A[A

  5%|▌         | 32/600 [00:02<00:38, 14.92it/s][A[A

  6%|▌         | 34/600 [00:02<00:37, 14.93it/s][A[A

  6%|▌         | 36/600 [00:02<00:37, 14.98it/s][A[A

  6%|▋         | 38/600 [00:02<00:37, 15.16it/s]

 49%|████▊     | 292/600 [00:18<00:17, 17.23it/s][A[A

 49%|████▉     | 294/600 [00:18<00:18, 16.66it/s][A[A

 49%|████▉     | 296/600 [00:18<00:18, 16.67it/s][A[A

 50%|████▉     | 298/600 [00:18<00:18, 16.73it/s][A[A

 50%|█████     | 300/600 [00:18<00:17, 16.94it/s][A[A

 50%|█████     | 302/600 [00:18<00:17, 16.98it/s][A[A

 51%|█████     | 304/600 [00:18<00:17, 17.08it/s][A[A

 51%|█████     | 306/600 [00:19<00:17, 17.22it/s][A[A

 51%|█████▏    | 308/600 [00:19<00:17, 17.00it/s][A[A

 52%|█████▏    | 310/600 [00:19<00:16, 17.10it/s][A[A

 52%|█████▏    | 312/600 [00:19<00:16, 17.18it/s][A[A

 52%|█████▏    | 314/600 [00:19<00:16, 17.25it/s][A[A

 53%|█████▎    | 316/600 [00:19<00:16, 17.35it/s][A[A

 53%|█████▎    | 318/600 [00:19<00:16, 17.36it/s][A[A

 53%|█████▎    | 320/600 [00:19<00:16, 17.27it/s][A[A

 54%|█████▎    | 322/600 [00:20<00:16, 17.28it/s][A[A

 54%|█████▍    | 324/600 [00:20<00:15, 17.30it/s][A[A

 54%|█████▍    | 326/600 [00:20

 96%|█████████▋| 578/600 [00:35<00:01, 16.81it/s][A[A

 97%|█████████▋| 580/600 [00:36<00:01, 16.92it/s][A[A

 97%|█████████▋| 582/600 [00:36<00:01, 16.95it/s][A[A

 97%|█████████▋| 584/600 [00:36<00:00, 16.99it/s][A[A

 98%|█████████▊| 586/600 [00:36<00:00, 16.95it/s][A[A

 98%|█████████▊| 588/600 [00:36<00:00, 16.85it/s][A[A

 98%|█████████▊| 590/600 [00:36<00:00, 16.92it/s][A[A

 99%|█████████▊| 592/600 [00:36<00:00, 17.06it/s][A[A

 99%|█████████▉| 594/600 [00:36<00:00, 16.87it/s][A[A

 99%|█████████▉| 596/600 [00:37<00:00, 16.24it/s][A[A

100%|█████████▉| 598/600 [00:37<00:00, 16.08it/s][A[A

100%|██████████| 600/600 [00:37<00:00, 16.02it/s][A[A

  0%|          | 0/600 [00:00<?, ?it/s][A[A

  1%|          | 5/600 [00:00<00:14, 41.52it/s][A[A

AVG BATCH ACCUMULATED TRAIN LOSS: 0.5714592970907688




  2%|▏         | 10/600 [00:00<00:14, 41.24it/s][A[A

  2%|▏         | 13/600 [00:00<00:16, 35.98it/s][A[A

  3%|▎         | 16/600 [00:00<00:17, 33.77it/s][A[A

  3%|▎         | 20/600 [00:00<00:17, 34.12it/s][A[A

  4%|▍         | 24/600 [00:00<00:16, 34.11it/s][A[A

  5%|▍         | 28/600 [00:00<00:16, 34.37it/s][A[A

  5%|▌         | 32/600 [00:00<00:16, 33.54it/s][A[A

  6%|▌         | 36/600 [00:01<00:16, 34.99it/s][A[A

  7%|▋         | 40/600 [00:01<00:15, 35.95it/s][A[A

  7%|▋         | 44/600 [00:01<00:15, 36.50it/s][A[A

  8%|▊         | 48/600 [00:01<00:14, 36.90it/s][A[A

  9%|▊         | 52/600 [00:01<00:14, 37.00it/s][A[A

  9%|▉         | 56/600 [00:01<00:14, 37.34it/s][A[A

 10%|█         | 60/600 [00:01<00:14, 37.65it/s][A[A

 11%|█         | 64/600 [00:01<00:14, 37.33it/s][A[A

 11%|█▏        | 68/600 [00:01<00:14, 36.85it/s][A[A

 12%|█▏        | 72/600 [00:01<00:14, 36.61it/s][A[A

 13%|█▎        | 76/600 [00:02<00:14, 36.47it/

TRAIN LOSS: 0.39679430122176806




  8%|▊         | 8/100 [00:00<00:02, 33.47it/s][A[A

 12%|█▏        | 12/100 [00:00<00:02, 34.57it/s][A[A

 16%|█▌        | 16/100 [00:00<00:02, 36.00it/s][A[A

 20%|██        | 20/100 [00:00<00:02, 36.30it/s][A[A

 24%|██▍       | 24/100 [00:00<00:02, 36.85it/s][A[A

 28%|██▊       | 28/100 [00:00<00:01, 37.67it/s][A[A

 33%|███▎      | 33/100 [00:00<00:01, 38.45it/s][A[A

 37%|███▋      | 37/100 [00:00<00:01, 38.39it/s][A[A

 41%|████      | 41/100 [00:01<00:01, 37.59it/s][A[A

 45%|████▌     | 45/100 [00:01<00:01, 37.10it/s][A[A

 49%|████▉     | 49/100 [00:01<00:01, 37.11it/s][A[A

 53%|█████▎    | 53/100 [00:01<00:01, 37.52it/s][A[A

 57%|█████▋    | 57/100 [00:01<00:01, 38.09it/s][A[A

 61%|██████    | 61/100 [00:01<00:01, 37.75it/s][A[A

 65%|██████▌   | 65/100 [00:01<00:00, 37.46it/s][A[A

 69%|██████▉   | 69/100 [00:01<00:00, 37.81it/s][A[A

 74%|███████▍  | 74/100 [00:01<00:00, 38.64it/s][A[A

 78%|███████▊  | 78/100 [00:02<00:00, 38.69it/s

VAL LOSS: 0.37770835891366006


FileNotFoundError: [Errno 2] No such file or directory: '<ipython-input-31-6e3f288e0d40>'