# Finetuning on MNIST

Practical example of exponentially tilted finetuning for MNIST: train on a very different MNIST distribution, and use exponential tilting to fine-tune on uniform test data.


In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.pylabtools import figsize

import seaborn as sns
import plotly.express as px

import numpy as np
import pandas as pd

import polars as pl

import statsmodels.formula.api as smf

import torch

In [2]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Normalize
from functools import partial

from torch.utils.data import TensorDataset, Subset

mnist_train_data = MNIST('~/Datasets/', download=True, train=True,
                        transform=Compose(
                          [ToTensor(), Normalize(0., 1.),
                           partial(torch.reshape, shape=(-1,))
                           ])
                       )

mnist_test_data = MNIST('~/Datasets/', download=True, train=False,
                        transform=Compose(
                          [ToTensor(), Normalize(0., 1.),
                           partial(torch.reshape, shape=(-1,))
                           ])
                       )

## First try

In [3]:
num_per_bin = torch.bincount(mnist_train_data.targets)
baseline = num_per_bin.max()

In [4]:
# Resample unbalanced data distribution

num_per_bin = torch.bincount(mnist_train_data.targets)
num_to_sample = num_per_bin.max() - num_per_bin

unbalace = torch.tensor( 5*[2*baseline] + 5*[0])

num_to_sample += unbalace

extra_inds = []

for i in range(10):
    snum = num_to_sample[i]
    all_inds = torch.nonzero(mnist_train_data.targets == i).squeeze()
    conditional_inds = np.random.choice(num_per_bin[i], size=(num_to_sample[i],))
    extra_inds.append(all_inds[conditional_inds])

extra_inds.append(torch.arange(len(mnist_train_data)))

all_inds = torch.concat(extra_inds)

print(torch.bincount(mnist_train_data.targets[all_inds]))

tensor([20226, 20226, 20226, 20226, 20226,  6742,  6742,  6742,  6742,  6742])


In [5]:
unbalanced_mnist = Subset(mnist_train_data, all_inds)

Does preserving the exact balance in train and validation data matter in this instance?

In [6]:
targets = mnist_train_data.targets[all_inds]
per_bin = torch.bincount(mnist_train_data.targets[all_inds])


percent_train = 0.95

train_inds = []
val_inds = []

for i in range(10):
    num_total = per_bin[i]
    cutoff = round(num_total.item() * percent_train)

    data_inds = torch.nonzero((targets == i)).squeeze()
    perm = torch.randperm(num_total)
    
    train_inds.append(data_inds[perm[0:cutoff]])
    val_inds.append(data_inds[perm[cutoff:]])


train_inds = torch.concat(train_inds)
val_inds = torch.concat(val_inds)

train_set = Subset(unbalanced_mnist, train_inds)
val_set = Subset(unbalanced_mnist, val_inds)

In [7]:
dl = torch.utils.data.DataLoader(train_set, batch_size=len(train_set))
i, t = next(iter(dl))
print(torch.bincount(t))

tensor([19215, 19215, 19215, 19215, 19215,  6405,  6405,  6405,  6405,  6405])


Ok nice, exact sampling again.

In [8]:
from models.simple_examples import Basic_MNIST
%run trainers.py

In [9]:
mnist_model = Basic_MNIST()

mnist_model.data_train = train_set
mnist_model.data_val = val_set

In [10]:
ckpt = train_model(mnist_model, 'trainedParameters/MNIST_unbalanced/classifier')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlrast[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | model  | Sequential       | 22.3 K | train
1 | lossFn | CrossEntropyLoss | 0      | train
----------------------------------------------------
22.3 K    Trainable params
0         Non-trainable params
22.3 K    Total params
0.089     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

0,1
Train Loss,█▄▁▁▃▂▁▁▂▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▇▄▄▃▂▂▁▁▂▁▁▂▁▁▂▂▂▂▂▂▂▂▂▂▃▃▄▃▃▃▃▃▃▅▄▄▄▄▄
Val acc,▁▄▄▅▆▇▇▇▇▇▇▇▇▇█▇▇████████▇████████▇█████
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
trainer/global_step,▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████

0,1
Train Loss,0.02391
Val Loss,0.0847
Val acc,0.99021
epoch,53.0
trainer/global_step,216215.0


In [11]:
optimal_model = Basic_MNIST.load_from_checkpoint(ckpt)

In [14]:
dl = torch.utils.data.DataLoader(mnist_test_data, batch_size=len(mnist_test_data))
test_inputs, test_targets = next(iter(dl))

In [21]:
decoded = torch.argmax(optimal_model.forward(test_inputs.to(optimal_model.device)), dim=1).cpu()

In [26]:
(decoded == test_targets).sum() / test_targets.shape[0]

tensor(0.9653)

Unfortunately, this still gives very accurate outputs. In some ways, this makes sense: the model is still trained on all of the samples.

## Second try: strictly less data

In [3]:
num_per_bin = torch.bincount(mnist_train_data.targets)

In [4]:
num_per_bin

tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])

In [5]:
data_ind = []

for i in range(10):
    all_inds = torch.nonzero(mnist_train_data.targets == i).squeeze()
    
    if i >= 5:
        to_keep = all_inds[0:all_inds.shape[0]//100]
        data_ind.append(to_keep)
    else:
        data_ind.append(all_inds)

data_ind = torch.concat(data_ind)

In [6]:
from torch.utils.data import random_split

unbalanced_mnist = Subset(mnist_train_data, data_ind)
train_set, val_set = random_split(unbalanced_mnist, (0.95, 0.05))

In [7]:
dl = torch.utils.data.DataLoader(unbalanced_mnist, batch_size=len(unbalanced_mnist))
torch.bincount(next(iter(dl))[1])

tensor([5923, 6742, 5958, 6131, 5842,   54,   59,   62,   58,   59])

In [8]:
from models.simple_examples import Basic_MNIST
%run trainers.py

In [9]:
mnist_model = Basic_MNIST()

mnist_model.data_train = train_set
mnist_model.data_val = val_set

### Half the amount of data:

In [9]:
ckpt = train_model(mnist_model, 'trainedParameters/MNIST_unbalanced/classifier2')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlrast[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | model  | Sequential       | 22.3 K | train
1 | lossFn | CrossEntropyLoss | 0      | train
----------------------------------------------------
22.3 K    Trainable params
0         Non-trainable params
22.3 K    Total params
0.089     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

0,1
Train Loss,▅▃█▃▂▂▁▁▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
Val Loss,█▄▃▂▂▁▁▁▁▁▁▁▁▂▂▂▂▃▂▂▂▃▃▃▃▃▃▄▄▄▄▅▄▄▄▅▅▅▆▆
Val acc,▁▂▄▅▆▆▆▇▇▇██▇▇▇▇▇▇▇▇█▇▇▇▇██▇██▇█▇█▇█████
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇████
trainer/global_step,▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█

0,1
Train Loss,7e-05
Val Loss,0.21848
Val acc,0.96466
epoch,56.0
trainer/global_step,76664.0


In [10]:
optimal_model = Basic_MNIST.load_from_checkpoint(ckpt)
dl = torch.utils.data.DataLoader(mnist_test_data, batch_size=len(mnist_test_data))
test_inputs, test_targets = next(iter(dl))

In [11]:
decoded = torch.argmax(optimal_model.forward(test_inputs.to(optimal_model.device)), dim=1).cpu()
(decoded == test_targets).sum() / test_targets.shape[0]

tensor(0.9606)

Still no. No loss of accuracy.

### One tenth the amount of data

In [10]:
ckpt = train_model(mnist_model, 'trainedParameters/MNIST_unbalanced/classifier3')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlrast[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | model  | Sequential       | 22.3 K | train
1 | lossFn | CrossEntropyLoss | 0      | train
----------------------------------------------------
22.3 K    Trainable params
0         Non-trainable params
22.3 K    Total params
0.089     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

0,1
Train Loss,▄█▃▁▁▅▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,▅▄▃▃▂▂▂▂▁▁▁▂▁▂▁▂▂▂▂▂▃▃▃▃▃▃▃▅▄▄▄▄▄▄▅▅▅█▅▅
Val acc,▁▃▄▄▅▅▆▆▆▇▆▆▇▆▇▆▆▇▇▇▇▇█▇▇▇▇█▇▆█▇▇▇▇▇█▅█▇
epoch,▁▁▁▁▂▂▂▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇██
trainer/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████

0,1
Train Loss,0.00038
Val Loss,0.2016
Val acc,0.97255
epoch,52.0
trainer/global_step,52787.0


In [11]:
optimal_model = Basic_MNIST.load_from_checkpoint(ckpt)
dl = torch.utils.data.DataLoader(mnist_test_data, batch_size=len(mnist_test_data))
test_inputs, test_targets = next(iter(dl))

In [12]:
decoded = torch.argmax(optimal_model.forward(test_inputs.to(optimal_model.device)), dim=1).cpu()
(decoded == test_targets).sum() / test_targets.shape[0]

tensor(0.9235)

Wow. Still pretty good.

### One hundredth.

In [10]:
ckpt = train_model(mnist_model, 'trainedParameters/MNIST_unbalanced/classifier3')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlrast[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/luke/Documents/activeProjects/learnedExpFam/trainedParameters/MNIST_unbalanced/classifier3 exists and is not empty.

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | model  | Sequential       | 22.3 K | train
1 | lossFn | CrossEntropyLoss | 0      | train
----------------------------------------------------
22.3 K    Trainable params
0         Non-trainable params
22.3 K    Total params


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

0,1
Train Loss,█▄▄▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▆▄▃▂▂▂▂▂▁▂▁▂▁▂▂▃▃▃▃▄▃▄▄▃▄▃▄▄▄▅▅▅▅▆▅▅▆▆▆
Val acc,▁▂▅▆▆▆▇▆▆█▇▇█▇██▆▇▇▇▆▇█▇█▇███▆▇█▇█████▇█
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
trainer/global_step,▁▁▁▁▁▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████

0,1
Train Loss,7e-05
Val Loss,0.10918
Val acc,0.98575
epoch,50.0
trainer/global_step,46766.0


In [11]:
optimal_model = Basic_MNIST.load_from_checkpoint(ckpt)
dl = torch.utils.data.DataLoader(mnist_test_data, batch_size=len(mnist_test_data))
test_inputs, test_targets = next(iter(dl))

In [12]:
decoded = torch.argmax(optimal_model.forward(test_inputs.to(optimal_model.device)), dim=1).cpu()
(decoded == test_targets).sum() / test_targets.shape[0]

tensor(0.7775)

Aha. Finally, a reasonable decrease in performance. I suppose it tracks that we need such an extreme example for this very simple task.