<a href="https://colab.research.google.com/github/nmonson1/mnist_exploration/blob/main/basic_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
import time, tqdm
import numpy as np

In [None]:
# Training images get light augmentation: upscale to 32x32, rotate by a random
# angle in [-5, 5] degrees, then take a random 28x28 crop.
train_transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.RandomRotation(5),
    transforms.RandomCrop(28),
    transforms.ToTensor(),
])

# Test images go through the same upscale but a deterministic centre crop, so
# evaluation sees the same scale as training without any randomness.
test_transform = transforms.Compose([
    transforms.Resize([32, 32]),
    transforms.CenterCrop(28),
    transforms.ToTensor(),
])

training_data = datasets.MNIST(root="data", train=True, download=True, transform=train_transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=test_transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 199064955.32it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 112058921.21it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 106624701.62it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 21872019.25it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
n_epochs = 20
batch_size_train = 64
batch_size_test = 64
learning_rate = 0.0001
# Print a progress line every `log_interval` training batches. One epoch is
# 938 batches at batch size 64, so 100000 effectively disables per-batch logging.
log_interval = 100000
# Fall back to the CPU when no GPU is present instead of failing on the first .to(device).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Only the training split is shuffled; the test split is read in a fixed order
# so every evaluation pass visits the samples identically.
train_loader = torch.utils.data.DataLoader(
    dataset=training_data,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=2,
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=2,
)

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 512 -> 10.

    ``nn.Flatten`` flattens every dimension after the batch dimension, so e.g.
    a batch shaped (N, 1, 28, 28) becomes (N, 784). The output is raw logits
    (no softmax), which is what ``nn.CrossEntropyLoss`` expects.

    The attribute names (``flatten``, ``linear_relu_stack``) and the layer order
    inside the stack determine the ``state_dict`` keys used by saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))

In [None]:
# Build the model, place it on the configured device, and attach Adam to it.
network = NeuralNetwork()
network = network.to(device)

optimizer = optim.Adam(network.parameters(), lr=learning_rate)

# Passing network.parameters() gives Adam a single parameter group; keep a handle
# on that list of tensors so the loss can later be differentiated w.r.t. them.
params = optimizer.param_groups[0]['params']

In [None]:
def train(epoch):
    """Run one pass of mini-batch training over ``train_loader``.

    Uses the module-level ``network``, ``optimizer``, ``train_loader`` and
    ``log_interval``. Each batch is moved to the device the network's
    parameters live on, so data always follows the model.

    Args:
        epoch: epoch number, used only in progress messages.

    Returns:
        The mean of the per-batch cross-entropy losses for this epoch.
    """
    start = time.time()
    network.train()
    # Take the device from the model itself rather than guessing from
    # torch.cuda.is_available(), which can disagree with where `network` was placed.
    device = next(network.parameters()).device
    loss_fn = nn.CrossEntropyLoss()
    n_batches = len(train_loader)
    train_loss = []
    for batch_idx, (data, target) in enumerate(tqdm.tqdm(train_loader)):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = network(data)
        loss = loss_fn(output, target)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        # Log after every `log_interval` completed batches. Counting completed
        # batches (batch_idx + 1) keeps the check reachable; a modulo with a
        # positive divisor is never negative. log_interval=0 disables logging.
        if log_interval and (batch_idx + 1) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / n_batches, loss.item()))

    tot_time = time.time() - start
    mean_loss = np.mean(train_loss)
    print('\nTrain set: Avg. loss: {:.4f}, Time: {:.1f} secs \n'.format(mean_loss, tot_time))
    return mean_loss

In [None]:
def test():
    """Evaluate ``network`` on ``test_loader`` and checkpoint it.

    Uses the module-level ``network``, ``optimizer`` and ``test_loader``.

    Returns:
        CM: 10x10 float array of counts; CM[t, p] is the number of test samples
            with true label t that were predicted as p.
        test_loss: list of per-batch mean cross-entropy losses.
        acc: accuracy as a percentage of ``len(test_loader.dataset)``.

    Side effects: prints a summary line and saves the current model and
    optimizer state to ./model_wo_augmentation.pth and
    ./optimizer_wo_augmentation.pth.
    """
    start = time.time()
    network.eval()
    # Evaluate on whatever device the model lives on instead of assuming a GPU.
    device = next(network.parameters()).device
    loss_fn = nn.CrossEntropyLoss()
    CM = np.zeros((10, 10))
    test_loss = []
    with torch.no_grad():
        for data, target in tqdm.tqdm(test_loader):
            data = data.to(device)
            target = target.to(device)
            output = network(data)
            test_loss.append(loss_fn(output, target).item())
            pred = output.argmax(dim=1)
            # Unbuffered scatter-add: repeated (true, predicted) pairs within a
            # batch are each counted, unlike CM[t, p] += 1 with fancy indexing.
            np.add.at(CM, (target.cpu().numpy(), pred.cpu().numpy()), 1)
    correct = np.trace(CM)
    tot_time = time.time() - start
    # NOTE(review): the "wo_augmentation" names look stale -- training_data is
    # built with RandomRotation/RandomCrop -- but the paths are kept unchanged.
    torch.save(network.state_dict(), './model_wo_augmentation.pth')
    torch.save(optimizer.state_dict(), './optimizer_wo_augmentation.pth')
    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} - {:.2f}%, Time: {:.1f} secs\n'.format(np.mean(test_loss), int(correct), len(test_loader.dataset), acc,  tot_time))
    return CM, test_loss, acc

In [None]:
# Train for n_epochs, evaluating *after* each epoch so the reported metrics and the
# saved checkpoints always describe the model that just finished training --
# including the final epoch, which was previously never evaluated or saved.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
network.to(device)
train_losses = []
test_losses = []
test_acc = []
CMs = []

for epoch in range(1, n_epochs + 1):
    print('Epoch ' + str(epoch))
    train_losses.append(train(epoch))
    torch.save(network.state_dict(), './model.pth')
    torch.save(optimizer.state_dict(), './optimizer.pth')

    CM, test_loss, acc = test()
    test_losses.append(test_loss)
    CMs.append(CM)
    test_acc.append(acc)
    # The entries are counts; print them as integers rather than 9.71e+02 etc.
    print(CM.astype(int))

100%|██████████| 938/938 [00:18<00:00, 50.68it/s]



Train set: Avg. loss: 0.1214, Time: 18.5 secs 

Epoch 8


100%|██████████| 157/157 [00:02<00:00, 56.14it/s]



Test set: Avg. loss: 0.0726, Accuracy: 9768/10000 - 97.68%, Time: 2.8 secs

[[9.710e+02 0.000e+00 1.000e+00 0.000e+00 0.000e+00 1.000e+00 3.000e+00
  1.000e+00 2.000e+00 1.000e+00]
 [0.000e+00 1.124e+03 4.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 5.000e+00 0.000e+00]
 [5.000e+00 0.000e+00 1.010e+03 3.000e+00 1.000e+00 0.000e+00 1.000e+00
  4.000e+00 8.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 2.000e+00 9.790e+02 0.000e+00 8.000e+00 0.000e+00
  9.000e+00 6.000e+00 6.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.630e+02 0.000e+00 3.000e+00
  2.000e+00 1.000e+00 1.300e+01]
 [3.000e+00 0.000e+00 0.000e+00 9.000e+00 1.000e+00 8.700e+02 3.000e+00
  1.000e+00 0.000e+00 5.000e+00]
 [3.000e+00 2.000e+00 1.000e+00 1.000e+00 6.000e+00 4.000e+00 9.400e+02
  0.000e+00 1.000e+00 0.000e+00]
 [0.000e+00 4.000e+00 9.000e+00 0.000e+00 1.000e+00 0.000e+00 0.000e+00
  1.000e+03 3.000e+00 1.100e+01]
 [4.000e+00 3.000e+00 1.000e+00 2.000e+00 5.000e+00 9.000e+00 4.000e+00
  2.000e+00

100%|██████████| 938/938 [00:18<00:00, 51.06it/s]



Train set: Avg. loss: 0.1108, Time: 18.4 secs 

Epoch 9


100%|██████████| 157/157 [00:02<00:00, 78.10it/s]



Test set: Avg. loss: 0.0689, Accuracy: 9790/10000 - 97.90%, Time: 2.0 secs

[[9.660e+02 0.000e+00 1.000e+00 0.000e+00 4.000e+00 1.000e+00 3.000e+00
  1.000e+00 3.000e+00 1.000e+00]
 [0.000e+00 1.126e+03 3.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 4.000e+00 0.000e+00]
 [3.000e+00 1.000e+00 1.012e+03 2.000e+00 1.000e+00 0.000e+00 0.000e+00
  4.000e+00 9.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 4.000e+00 9.930e+02 0.000e+00 0.000e+00 0.000e+00
  5.000e+00 6.000e+00 2.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.710e+02 0.000e+00 2.000e+00
  2.000e+00 1.000e+00 6.000e+00]
 [2.000e+00 0.000e+00 0.000e+00 1.800e+01 2.000e+00 8.580e+02 4.000e+00
  1.000e+00 5.000e+00 2.000e+00]
 [1.000e+00 4.000e+00 1.000e+00 0.000e+00 7.000e+00 2.000e+00 9.410e+02
  0.000e+00 2.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 9.000e+00 2.000e+00 1.000e+00 0.000e+00 0.000e+00
  1.001e+03 3.000e+00 1.000e+01]
 [2.000e+00 1.000e+00 3.000e+00 1.000e+00 7.000e+00 4.000e+00 0.000e+00
  2.000e+00

100%|██████████| 938/938 [00:22<00:00, 41.51it/s]



Train set: Avg. loss: 0.1033, Time: 22.6 secs 

Epoch 10


100%|██████████| 157/157 [00:02<00:00, 61.14it/s]



Test set: Avg. loss: 0.0641, Accuracy: 9793/10000 - 97.93%, Time: 2.6 secs

[[9.720e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 3.000e+00
  1.000e+00 3.000e+00 0.000e+00]
 [0.000e+00 1.124e+03 3.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 6.000e+00 0.000e+00]
 [5.000e+00 0.000e+00 1.010e+03 4.000e+00 0.000e+00 0.000e+00 2.000e+00
  3.000e+00 8.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 1.000e+03 0.000e+00 3.000e+00 0.000e+00
  2.000e+00 4.000e+00 1.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.500e+02 0.000e+00 4.000e+00
  3.000e+00 2.000e+00 2.300e+01]
 [2.000e+00 0.000e+00 0.000e+00 1.100e+01 0.000e+00 8.710e+02 2.000e+00
  1.000e+00 2.000e+00 3.000e+00]
 [2.000e+00 1.000e+00 0.000e+00 0.000e+00 2.000e+00 2.000e+00 9.500e+02
  0.000e+00 1.000e+00 0.000e+00]
 [1.000e+00 3.000e+00 9.000e+00 7.000e+00 1.000e+00 0.000e+00 0.000e+00
  9.910e+02 3.000e+00 1.300e+01]
 [4.000e+00 1.000e+00 2.000e+00 4.000e+00 1.000e+00 5.000e+00 3.000e+00
  2.000e+00

100%|██████████| 938/938 [00:20<00:00, 46.83it/s]



Train set: Avg. loss: 0.0978, Time: 20.0 secs 

Epoch 11


100%|██████████| 157/157 [00:02<00:00, 71.57it/s]



Test set: Avg. loss: 0.0619, Accuracy: 9800/10000 - 98.00%, Time: 2.2 secs

[[9.730e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 3.000e+00
  1.000e+00 2.000e+00 0.000e+00]
 [1.000e+00 1.121e+03 3.000e+00 1.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 7.000e+00 0.000e+00]
 [5.000e+00 2.000e+00 9.990e+02 6.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+01 9.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 1.000e+03 0.000e+00 1.000e+00 0.000e+00
  4.000e+00 3.000e+00 2.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.580e+02 0.000e+00 4.000e+00
  3.000e+00 1.000e+00 1.600e+01]
 [2.000e+00 0.000e+00 0.000e+00 1.100e+01 0.000e+00 8.680e+02 2.000e+00
  1.000e+00 2.000e+00 6.000e+00]
 [2.000e+00 1.000e+00 0.000e+00 1.000e+00 4.000e+00 2.000e+00 9.470e+02
  0.000e+00 1.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 6.000e+00 3.000e+00 1.000e+00 0.000e+00 0.000e+00
  1.008e+03 3.000e+00 5.000e+00]
 [4.000e+00 1.000e+00 1.000e+00 4.000e+00 1.000e+00 6.000e+00 4.000e+00
  3.000e+00

100%|██████████| 938/938 [00:18<00:00, 50.99it/s]



Train set: Avg. loss: 0.0912, Time: 18.4 secs 

Epoch 12


100%|██████████| 157/157 [00:03<00:00, 48.09it/s]



Test set: Avg. loss: 0.0618, Accuracy: 9806/10000 - 98.06%, Time: 3.3 secs

[[9.720e+02 0.000e+00 1.000e+00 0.000e+00 0.000e+00 1.000e+00 3.000e+00
  1.000e+00 2.000e+00 0.000e+00]
 [1.000e+00 1.129e+03 2.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  0.000e+00 2.000e+00 0.000e+00]
 [5.000e+00 3.000e+00 1.009e+03 1.000e+00 2.000e+00 0.000e+00 1.000e+00
  5.000e+00 6.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 1.000e+03 0.000e+00 0.000e+00 0.000e+00
  5.000e+00 3.000e+00 2.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.700e+02 0.000e+00 2.000e+00
  3.000e+00 1.000e+00 6.000e+00]
 [4.000e+00 2.000e+00 0.000e+00 1.800e+01 1.000e+00 8.570e+02 4.000e+00
  2.000e+00 3.000e+00 1.000e+00]
 [1.000e+00 2.000e+00 0.000e+00 1.000e+00 6.000e+00 1.000e+00 9.470e+02
  0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 5.000e+00 7.000e+00 1.000e+00 3.000e+00 0.000e+00 0.000e+00
  1.005e+03 3.000e+00 4.000e+00]
 [7.000e+00 1.000e+00 2.000e+00 4.000e+00 5.000e+00 3.000e+00 3.000e+00
  3.000e+00

100%|██████████| 938/938 [00:18<00:00, 51.18it/s]



Train set: Avg. loss: 0.0861, Time: 18.3 secs 

Epoch 13


100%|██████████| 157/157 [00:02<00:00, 74.88it/s]



Test set: Avg. loss: 0.0564, Accuracy: 9823/10000 - 98.23%, Time: 2.1 secs

[[9.710e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 5.000e+00
  0.000e+00 3.000e+00 0.000e+00]
 [0.000e+00 1.127e+03 2.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 4.000e+00 0.000e+00]
 [3.000e+00 3.000e+00 1.018e+03 1.000e+00 0.000e+00 0.000e+00 0.000e+00
  3.000e+00 4.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 9.970e+02 0.000e+00 3.000e+00 0.000e+00
  3.000e+00 2.000e+00 4.000e+00]
 [0.000e+00 1.000e+00 1.000e+00 0.000e+00 9.540e+02 0.000e+00 5.000e+00
  4.000e+00 2.000e+00 1.500e+01]
 [2.000e+00 0.000e+00 0.000e+00 1.000e+01 0.000e+00 8.740e+02 2.000e+00
  1.000e+00 0.000e+00 3.000e+00]
 [1.000e+00 2.000e+00 0.000e+00 0.000e+00 1.000e+00 5.000e+00 9.480e+02
  0.000e+00 1.000e+00 0.000e+00]
 [0.000e+00 5.000e+00 9.000e+00 5.000e+00 0.000e+00 0.000e+00 0.000e+00
  1.002e+03 3.000e+00 4.000e+00]
 [3.000e+00 1.000e+00 1.000e+00 5.000e+00 0.000e+00 4.000e+00 0.000e+00
  2.000e+00

100%|██████████| 938/938 [00:19<00:00, 48.55it/s]



Train set: Avg. loss: 0.0810, Time: 19.3 secs 

Epoch 14


100%|██████████| 157/157 [00:02<00:00, 75.12it/s]



Test set: Avg. loss: 0.0533, Accuracy: 9829/10000 - 98.29%, Time: 2.1 secs

[[9.710e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 4.000e+00
  1.000e+00 3.000e+00 0.000e+00]
 [1.000e+00 1.128e+03 3.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 1.000e+00 0.000e+00]
 [1.000e+00 0.000e+00 1.020e+03 1.000e+00 0.000e+00 0.000e+00 0.000e+00
  4.000e+00 6.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 9.990e+02 0.000e+00 2.000e+00 0.000e+00
  4.000e+00 2.000e+00 2.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 0.000e+00 9.690e+02 0.000e+00 3.000e+00
  2.000e+00 1.000e+00 6.000e+00]
 [2.000e+00 0.000e+00 0.000e+00 1.500e+01 0.000e+00 8.670e+02 2.000e+00
  1.000e+00 1.000e+00 4.000e+00]
 [3.000e+00 2.000e+00 0.000e+00 1.000e+00 3.000e+00 3.000e+00 9.460e+02
  0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 1.000e+01 2.000e+00 1.000e+00 0.000e+00 0.000e+00
  1.006e+03 3.000e+00 4.000e+00]
 [5.000e+00 1.000e+00 4.000e+00 4.000e+00 2.000e+00 3.000e+00 1.000e+00
  3.000e+00

100%|██████████| 938/938 [00:18<00:00, 50.78it/s]



Train set: Avg. loss: 0.0782, Time: 18.5 secs 

Epoch 15


100%|██████████| 157/157 [00:02<00:00, 55.39it/s]



Test set: Avg. loss: 0.0521, Accuracy: 9830/10000 - 98.30%, Time: 2.8 secs

[[9.700e+02 0.000e+00 0.000e+00 0.000e+00 1.000e+00 1.000e+00 3.000e+00
  1.000e+00 3.000e+00 1.000e+00]
 [0.000e+00 1.128e+03 2.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 3.000e+00 0.000e+00]
 [3.000e+00 2.000e+00 1.001e+03 2.000e+00 6.000e+00 0.000e+00 0.000e+00
  1.000e+01 8.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 9.960e+02 0.000e+00 2.000e+00 0.000e+00
  3.000e+00 2.000e+00 7.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.690e+02 0.000e+00 2.000e+00
  3.000e+00 1.000e+00 7.000e+00]
 [2.000e+00 0.000e+00 0.000e+00 1.300e+01 0.000e+00 8.700e+02 1.000e+00
  1.000e+00 2.000e+00 3.000e+00]
 [1.000e+00 2.000e+00 0.000e+00 1.000e+00 7.000e+00 5.000e+00 9.420e+02
  0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 3.000e+00 2.000e+00 0.000e+00 0.000e+00 0.000e+00
  1.014e+03 3.000e+00 4.000e+00]
 [3.000e+00 1.000e+00 1.000e+00 1.000e+00 3.000e+00 3.000e+00 2.000e+00
  3.000e+00

100%|██████████| 938/938 [00:18<00:00, 50.87it/s]



Train set: Avg. loss: 0.0732, Time: 18.4 secs 

Epoch 16


100%|██████████| 157/157 [00:02<00:00, 72.93it/s]



Test set: Avg. loss: 0.0501, Accuracy: 9845/10000 - 98.45%, Time: 2.2 secs

[[9.750e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 2.000e+00
  0.000e+00 2.000e+00 0.000e+00]
 [1.000e+00 1.131e+03 2.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  0.000e+00 0.000e+00 0.000e+00]
 [2.000e+00 0.000e+00 1.019e+03 1.000e+00 1.000e+00 0.000e+00 0.000e+00
  4.000e+00 5.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 1.000e+03 0.000e+00 1.000e+00 0.000e+00
  2.000e+00 3.000e+00 3.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 0.000e+00 9.600e+02 0.000e+00 3.000e+00
  3.000e+00 0.000e+00 1.500e+01]
 [2.000e+00 1.000e+00 0.000e+00 1.700e+01 0.000e+00 8.650e+02 2.000e+00
  1.000e+00 3.000e+00 1.000e+00]
 [1.000e+00 2.000e+00 0.000e+00 1.000e+00 3.000e+00 2.000e+00 9.470e+02
  0.000e+00 2.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 9.000e+00 2.000e+00 0.000e+00 0.000e+00 0.000e+00
  1.005e+03 3.000e+00 7.000e+00]
 [3.000e+00 1.000e+00 5.000e+00 2.000e+00 0.000e+00 2.000e+00 2.000e+00
  2.000e+00

100%|██████████| 938/938 [00:19<00:00, 48.10it/s]



Train set: Avg. loss: 0.0694, Time: 19.5 secs 

Epoch 17


100%|██████████| 157/157 [00:02<00:00, 75.15it/s]



Test set: Avg. loss: 0.0527, Accuracy: 9823/10000 - 98.23%, Time: 2.1 secs

[[9.750e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 2.000e+00
  1.000e+00 1.000e+00 0.000e+00]
 [1.000e+00 1.127e+03 2.000e+00 0.000e+00 0.000e+00 0.000e+00 2.000e+00
  1.000e+00 2.000e+00 0.000e+00]
 [5.000e+00 0.000e+00 1.008e+03 2.000e+00 2.000e+00 0.000e+00 1.000e+00
  9.000e+00 5.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 9.920e+02 0.000e+00 3.000e+00 0.000e+00
  5.000e+00 0.000e+00 9.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 0.000e+00 9.650e+02 0.000e+00 3.000e+00
  2.000e+00 0.000e+00 1.100e+01]
 [2.000e+00 0.000e+00 0.000e+00 1.200e+01 0.000e+00 8.710e+02 2.000e+00
  1.000e+00 0.000e+00 4.000e+00]
 [2.000e+00 2.000e+00 0.000e+00 1.000e+00 5.000e+00 2.000e+00 9.460e+02
  0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 7.000e+00 1.000e+00 2.000e+00 0.000e+00 0.000e+00
  1.013e+03 1.000e+00 2.000e+00]
 [8.000e+00 1.000e+00 1.000e+00 2.000e+00 3.000e+00 3.000e+00 4.000e+00
  5.000e+00

100%|██████████| 938/938 [00:18<00:00, 49.84it/s]



Train set: Avg. loss: 0.0684, Time: 18.8 secs 

Epoch 18


100%|██████████| 157/157 [00:02<00:00, 58.91it/s]



Test set: Avg. loss: 0.0471, Accuracy: 9845/10000 - 98.45%, Time: 2.7 secs

[[9.690e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 5.000e+00
  0.000e+00 6.000e+00 0.000e+00]
 [0.000e+00 1.125e+03 2.000e+00 0.000e+00 0.000e+00 0.000e+00 2.000e+00
  1.000e+00 5.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 1.016e+03 1.000e+00 2.000e+00 0.000e+00 0.000e+00
  4.000e+00 9.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 9.950e+02 0.000e+00 5.000e+00 0.000e+00
  3.000e+00 4.000e+00 3.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 1.000e+00 9.690e+02 0.000e+00 3.000e+00
  3.000e+00 0.000e+00 5.000e+00]
 [2.000e+00 0.000e+00 0.000e+00 9.000e+00 0.000e+00 8.760e+02 2.000e+00
  1.000e+00 1.000e+00 1.000e+00]
 [1.000e+00 2.000e+00 0.000e+00 1.000e+00 3.000e+00 4.000e+00 9.470e+02
  0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 6.000e+00 1.000e+00 1.000e+00 0.000e+00 0.000e+00
  1.014e+03 3.000e+00 1.000e+00]
 [1.000e+00 0.000e+00 1.000e+00 2.000e+00 0.000e+00 2.000e+00 2.000e+00
  2.000e+00

100%|██████████| 938/938 [00:18<00:00, 50.08it/s]



Train set: Avg. loss: 0.0638, Time: 18.7 secs 

Epoch 19


100%|██████████| 157/157 [00:02<00:00, 76.10it/s]



Test set: Avg. loss: 0.0529, Accuracy: 9822/10000 - 98.22%, Time: 2.1 secs

[[9.760e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 2.000e+00
  0.000e+00 2.000e+00 0.000e+00]
 [0.000e+00 1.133e+03 1.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  0.000e+00 0.000e+00 0.000e+00]
 [2.000e+00 3.000e+00 1.017e+03 1.000e+00 1.000e+00 0.000e+00 0.000e+00
  5.000e+00 3.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 2.000e+00 9.940e+02 0.000e+00 2.000e+00 0.000e+00
  4.000e+00 0.000e+00 8.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.560e+02 0.000e+00 2.000e+00
  3.000e+00 0.000e+00 2.100e+01]
 [2.000e+00 1.000e+00 0.000e+00 1.400e+01 0.000e+00 8.690e+02 2.000e+00
  1.000e+00 0.000e+00 3.000e+00]
 [4.000e+00 2.000e+00 1.000e+00 1.000e+00 5.000e+00 3.000e+00 9.410e+02
  0.000e+00 0.000e+00 1.000e+00]
 [0.000e+00 2.000e+00 5.000e+00 2.000e+00 0.000e+00 0.000e+00 0.000e+00
  1.015e+03 1.000e+00 3.000e+00]
 [6.000e+00 1.000e+00 4.000e+00 9.000e+00 1.000e+00 5.000e+00 2.000e+00
  4.000e+00

100%|██████████| 938/938 [00:19<00:00, 48.07it/s]



Train set: Avg. loss: 0.0615, Time: 19.5 secs 

Epoch 20


100%|██████████| 157/157 [00:02<00:00, 73.52it/s]



Test set: Avg. loss: 0.0504, Accuracy: 9835/10000 - 98.35%, Time: 2.1 secs

[[9.750e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 2.000e+00
  1.000e+00 2.000e+00 0.000e+00]
 [0.000e+00 1.128e+03 3.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  1.000e+00 2.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 1.023e+03 0.000e+00 1.000e+00 0.000e+00 1.000e+00
  4.000e+00 3.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 3.000e+00 1.000e+03 0.000e+00 1.000e+00 0.000e+00
  3.000e+00 1.000e+00 2.000e+00]
 [0.000e+00 0.000e+00 1.000e+00 1.000e+00 9.610e+02 0.000e+00 3.000e+00
  5.000e+00 0.000e+00 1.100e+01]
 [2.000e+00 0.000e+00 0.000e+00 2.400e+01 0.000e+00 8.620e+02 1.000e+00
  2.000e+00 1.000e+00 0.000e+00]
 [2.000e+00 2.000e+00 1.000e+00 1.000e+00 4.000e+00 3.000e+00 9.430e+02
  0.000e+00 2.000e+00 0.000e+00]
 [0.000e+00 2.000e+00 8.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00
  1.012e+03 2.000e+00 3.000e+00]
 [3.000e+00 1.000e+00 4.000e+00 6.000e+00 0.000e+00 2.000e+00 1.000e+00
  4.000e+00

100%|██████████| 938/938 [00:19<00:00, 47.97it/s]


Train set: Avg. loss: 0.0602, Time: 19.6 secs 






In [None]:
# Re-evaluate the training objective at the trained parameters WITHOUT updating
# them, keeping the autograd graph so total_loss can be differentiated w.r.t.
# `params` afterwards.
device = next(network.parameters()).device
network.eval()
loss_fn = nn.CrossEntropyLoss()
# Start from an explicit zero: torch.Tensor(1) allocates uninitialised memory.
total_loss = torch.zeros(1, device=device)
for data, target in tqdm.tqdm(train_loader):
    data = data.to(device)
    target = target.to(device)
    output = network(data)
    loss = loss_fn(output, target)
    # No loss.backward() / optimizer.step() here: backward() frees the graph that
    # total_loss is built from, and step() changes the parameters in place, so a
    # later torch.autograd.grad(total_loss, params) would raise a RuntimeError.
    total_loss = total_loss + loss

# NOTE(review): train_loader shuffles and applies random rotation/crop, so this is
# a stochastic estimate of the training loss; and every batch's graph stays alive
# until total_loss is released, which can use a lot of memory.


100%|██████████| 938/938 [00:20<00:00, 44.95it/s]


In [None]:
# `loss` is a non-leaf tensor (computed from the network output), so PyTorch never
# populates `loss.grad`: it is always None and reading it emits a warning.
# Gradients accumulate on the leaf parameters instead; show each one's norm
# (None where no backward pass has filled it in).
[None if p.grad is None else p.grad.norm().item() for p in params]

  loss.grad


In [None]:
# Our params should now be at (or near) a local minimum. What is the Hessian there?
# Differentiate total_loss once with create_graph=True so the gradient is itself
# differentiable, then differentiate each gradient entry of the first parameter
# tensor again with respect to that tensor.
env_grads = torch.autograd.grad(total_loss, params, retain_graph=True, create_graph=True)

print(env_grads[0])

# Diagonal of the Hessian block for params[0]. Every entry needs its own backward
# pass through the whole graph (env_grads[0].numel() entries in total), so cap the
# count while exploring; set to None to compute all of them.
max_hess_entries = 10

first_grad = env_grads[0].reshape(-1)  # flatten so parameters of any rank work
hess_flat = torch.zeros_like(first_grad)  # plain tensor: requires_grad is not copied
n_entries = first_grad.numel()
if max_hess_entries is not None:
    n_entries = min(n_entries, max_hess_entries)
for k in range(n_entries):
    # Only d/d params[0] is needed, so don't request gradients for every parameter.
    (second,) = torch.autograd.grad(
        first_grad[k], params[0], retain_graph=True, allow_unused=True
    )
    if second is not None:  # None means no dependence on params[0]: leave it at 0
        hess_flat[k] = second.reshape(-1)[k]
hess_params = hess_flat.reshape(env_grads[0].shape)
print(hess_params)

RuntimeError: ignored