In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

import numpy as np
from pyhessian import hessian
from tqdm import tqdm
import pandas as pd
import os
import matplotlib.pyplot as plt

from train_mlp import muMLPTab9

device = "cuda"
# CUDA reads CUDA_VISIBLE_DEVICES only when it is first initialised. `import torch` alone does
# not initialise it, but if anything imported above already touched CUDA, the line below has
# no effect and "cuda" would refer to a different physical GPU than intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
if torch.cuda.is_initialized():
    import warnings
    warnings.warn(
        "CUDA was initialised before CUDA_VISIBLE_DEVICES was set; the GPU selection is "
        "ignored. Set it before anything touches CUDA (or restart the kernel)."
    )

In [None]:
def get_cifar(batch_size=128, num_classes=10, MSE=False, on_gpu=False, device=None,
              root='/tmp', download=False):
    """Load the CIFAR-10 training images of the first ``num_classes`` classes into memory.

    Every selected image is passed once through ``ToTensor`` + ``Normalize(0.5, 0.5)``.
    The results are stacked into a single tensor and wrapped in a ``TensorDataset``, so
    iterating the returned DataLoader does no per-sample transform work.

    Args:
        batch_size: Mini-batch size of the returned DataLoader.
        num_classes: Keep only samples whose label is in ``range(num_classes)``.
            Must be between 1 and 10.
        MSE: If True, labels are one-hot encoded as float tensors of shape
            ``(N, num_classes)``. Otherwise they are integer class indices of shape ``(N,)``.
        on_gpu: If True, the whole dataset is moved to ``device`` up front.
        device: Target device. Required when ``on_gpu`` is True.
        root: Directory holding the CIFAR-10 files (forwarded to ``datasets.CIFAR10``).
        download: Forwarded to ``datasets.CIFAR10``. Fetch the data if it is missing.

    Returns:
        ``(train_dl, tensor_ds)``: a shuffling DataLoader and the underlying TensorDataset.

    Raises:
        AssertionError: if ``num_classes`` is outside [1, 10], if ``on_gpu`` is set without
            a ``device``, or if the filtered subset does not contain exactly
            ``num_classes`` distinct labels.
    """
    # Validate arguments before doing any (slow) dataset I/O. num_classes == 0 would select
    # nothing and make torch.stack below fail with a confusing error.
    assert 1 <= num_classes <= 10, f"num_classes must be in [1, 10], got {num_classes}"
    if on_gpu:
        assert device is not None, "Please provide a device="

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_ds = datasets.CIFAR10(root=root, train=True, download=download, transform=transform)
    targets = np.array(train_ds.targets)
    indices = np.flatnonzero(np.isin(targets, np.arange(num_classes)))

    # This check must come after `targets`/`indices` exist. Placed at the top of the
    # function, it raised UnboundLocalError on every call.
    n_found = np.unique(targets[indices]).shape[0]
    assert n_found == num_classes, f"Number of classes {n_found} != {num_classes}"

    X, y = [], []
    for i in tqdm(indices):
        x, label = train_ds[i]
        X.append(x)
        y.append(label)
    X = torch.stack(X)
    y = torch.tensor(y)

    if MSE:
        y = F.one_hot(y, num_classes=num_classes).float()

    if on_gpu:
        X = X.to(device)
        y = y.to(device)
        size_mb = (X.numel() * X.element_size() + y.numel() * y.element_size()) / 1024 / 1024
        print(f"Estimated size of the dataset in MB: {size_mb:.2f}")

    tensor_ds = TensorDataset(X, y)
    # Pinned memory only speeds up host->device copies, so enable it only for CPU-resident data.
    train_dl = DataLoader(tensor_ds, batch_size=batch_size, shuffle=True, pin_memory=not on_gpu)

    return train_dl, tensor_ds


In [42]:
# Experiment settings shared by every training run below:
# RNG seed, number of training epochs, number of CIFAR-10 classes kept.
seed, epochs, classes = 1, 5, 2

# Dataset kept on the CPU; each batch is moved to the GPU inside the training loop

In [43]:
# Dataset stays on the CPU; each mini-batch is copied to `device` inside the loop.
dl, ds = get_cifar(batch_size=128, num_classes=classes, MSE=False, on_gpu=False)
print(len(dl))  # mini-batches per epoch

# Reseed right before the first random draw so every run starts from the same RNG state.
torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): with shuffle=True, starting a DataLoader iteration presumably draws from torch's
# global RNG, so this peek likely influences the weights created next. Keep it to reproduce
# the printed losses.
peek_shape = next(iter(dl))[1].shape
print(peek_shape)
model = muMLPTab9(128, classes).to(device)
criterion = nn.CrossEntropyLoss()

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for _ in range(epochs):
    running_loss = 0
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        batch_loss = criterion(logits, yb)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch figure is a per-sample mean (the final batch may be smaller).
        running_loss += batch_loss.item() * xb.size(0)
    print(running_loss / len(dl.dataset))

100%|██████████| 10000/10000 [00:02<00:00, 4600.45it/s]


79
torch.Size([128])
0.6877436098098755
0.6252798287391662
0.5944725264549255
0.5714316897392273
0.553438679933548


In [44]:
# Repeat the run on the same loader: identical losses confirm the pipeline is deterministic.
torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): this peek presumably consumes torch's global RNG (shuffled DataLoader), so it
# likely affects the model initialisation below. Keep it for a like-for-like comparison.
peek_shape = next(iter(dl))[1].shape
print(peek_shape)
model = muMLPTab9(128, classes).to(device)
criterion = nn.CrossEntropyLoss()

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for _ in range(epochs):
    running_loss = 0
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)  # data lives on the CPU in this loader
        optimizer.zero_grad()
        batch_loss = criterion(model(xb), yb)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item() * xb.size(0)  # per-sample weighting
    print(running_loss / len(dl.dataset))

torch.Size([128])
0.6877436098098755
0.6252798287391662
0.5944725264549255
0.5714316897392273
0.553438679933548


# Whole dataset preloaded on the GPU (no per-batch host-to-device copy)

In [46]:
# Whole dataset preloaded on `device`: batches come out of the loader already on the GPU,
# so the loop below does no per-batch transfer.
dl, ds = get_cifar(batch_size=128, num_classes=classes, MSE=False, on_gpu=True, device=device)
print(len(dl))  # mini-batches per epoch

torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): this peek presumably draws from torch's global RNG (shuffled DataLoader) and
# so likely shifts the weight initialisation. Keep it to match the other runs.
peek_shape = next(iter(dl))[1].shape
print(peek_shape)
model = muMLPTab9(128, classes).to(device)
criterion = nn.CrossEntropyLoss()

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for _ in range(epochs):
    running_loss = 0
    for xb, yb in dl:
        optimizer.zero_grad()
        logits = model(xb)
        batch_loss = criterion(logits, yb)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item() * xb.size(0)  # per-sample weighting
    print(running_loss / len(dl.dataset))

100%|██████████| 10000/10000 [00:02<00:00, 4246.00it/s]


Estimated size of the dataset in MB: 117.26
79
torch.Size([128])
0.6877436098098755
0.6252798287391662
0.5944725264549255
0.5714316897392273
0.553438679933548


In [47]:
# Repeat the GPU-preloaded run on the same loader to check determinism.
torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): the peek presumably consumes torch's global RNG (shuffled DataLoader).
peek_shape = next(iter(dl))[1].shape
print(peek_shape)
model = muMLPTab9(128, classes).to(device)
criterion = nn.CrossEntropyLoss()

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for _ in range(epochs):
    running_loss = 0
    for xb, yb in dl:  # batches are already on `device`
        optimizer.zero_grad()
        batch_loss = criterion(model(xb), yb)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item() * xb.size(0)
    print(running_loss / len(dl.dataset))

torch.Size([128])
0.6877436098098755
0.6252798287391662
0.5944725264549255
0.5714316897392273
0.553438679933548


# MSE loss on one-hot targets, with the dataset preloaded on the GPU

In [48]:
# MSE variant: labels are one-hot float vectors (MSE=True) and the dataset is preloaded on `device`.
dl, ds = get_cifar(batch_size=128, num_classes=classes, MSE=True, on_gpu=True, device=device)
print(len(dl))  # mini-batches per epoch

torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): this peek presumably draws from torch's global RNG (shuffled DataLoader), so it
# likely affects the weights created next. Keep it for reproducibility.
peek_shape = next(iter(dl))[1].shape
print(peek_shape)
model = muMLPTab9(128, classes).to(device)
criterion = nn.MSELoss()

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for _ in range(epochs):
    running_loss = 0
    for xb, yb in dl:
        # No-op copy here, since the loader already yields tensors on `device`.
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        batch_loss = criterion(preds, yb)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item() * xb.size(0)  # per-sample weighting
    print(running_loss / len(dl.dataset))

100%|██████████| 10000/10000 [00:02<00:00, 4303.08it/s]


Estimated size of the dataset in MB: 117.26
79
torch.Size([128, 2])
0.6868212818145752
0.4809396454811096
0.4108572193145752
0.36817205924987795
0.33871090376377105


In [49]:
# Repeat the MSE run on the same loader to check determinism.
torch.manual_seed(seed)
np.random.seed(seed)
# NOTE(review): the peek presumably consumes torch's global RNG (shuffled DataLoader).
peek_shape = next(iter(dl))[1].shape
print(peek_shape)
model = muMLPTab9(128, classes).to(device)
criterion = nn.MSELoss()

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for _ in range(epochs):
    running_loss = 0
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)  # no-op: loader output is already on `device`
        optimizer.zero_grad()
        batch_loss = criterion(model(xb), yb)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item() * xb.size(0)
    print(running_loss / len(dl.dataset))

torch.Size([128, 2])
0.6868212818145752
0.4809396454811096
0.4108572193145752
0.36817205924987795
0.33871090376377105
