In [1]:
import os
import copy
from typing import Tuple
from pathlib import Path
import random
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import h5py
import torch
from torchvision.datasets import MNIST
from torchvision.datasets import SVHN
from torch.utils.data.dataset import TensorDataset
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from data_fashion import FashionMNIST, FASHION_LABELS
from data_svhn import SVHN_LABELS
from data_generic import load_dataset
from utilities import iterate_dataset, get_loss_and_acc, get_gd_optimizer, compute_losses, get_hessian_eigenvalues
from archs import load_architecture

from data_coloured_mnist import AugmentTensorDataset, cifar_transform, flatten, make_labels, standardize, unflatten
from data_cifar_c import load_cifar_corrupted, CIFAR10C

In [2]:
# Project root = parent of the notebook's working directory. RESULTS and
# DATASETS default to sub-folders of it, but values already present in the
# environment are respected.
main_dir = str(Path().resolve().parent)
os.environ.setdefault("RESULTS", os.path.join(main_dir, "results"))
os.environ.setdefault("DATASETS", os.path.join(main_dir, "data"))

# Always read the directories back from the environment so `results_dir` and
# `data_dir` are defined even when the variables were exported before launch
# (assigning them only inside the "not set" branch left them undefined then,
# and later cells failed with NameError).
results_dir = os.environ["RESULTS"]
data_dir = os.environ["DATASETS"]

In [3]:
# Raw torchvision CIFAR-10 splits (no transforms) read from `data_dir`.
# NOTE(review): `d` / `dt` are not referenced by any other cell shown here.
d = CIFAR10(root=data_dir, train=True)
dt = CIFAR10(root=data_dir, train=False)

In [4]:
# Raw torchvision SVHN train split (downloaded into `data_dir` if missing).
svhn_train = SVHN(root=data_dir, split="train", download=True)

# Project-level SVHN train/test datasets prepared for the "mse" loss.
dataset_name = "svhn"
loss = "mse"
svhn_train_dataset, svhn_test_dataset = load_dataset(dataset_name, loss)

# Batch size used for iterating / evaluating datasets in later cells.
physical_batch_size = 1000

Using downloaded and verified file: /home/mateuszpyla/stan/sharpness/data/train_32x32.mat
Using downloaded and verified file: /home/mateuszpyla/stan/sharpness/data/train_32x32.mat
Using downloaded and verified file: /home/mateuszpyla/stan/sharpness/data/extra_32x32.mat
Using downloaded and verified file: /home/mateuszpyla/stan/sharpness/data/test_32x32.mat


In [5]:
# CIFAR-10 train/test datasets prepared for the "mse" loss. (The SVHN train
# split is already loaded in the previous cell; re-instantiating it here only
# repeated the download check.)
dataset_name = "cifar10"
loss = "mse"
cifar10_train_dataset, cifar10_test_dataset = load_dataset(dataset_name, loss)
physical_batch_size = 1000

# "cifar10-5k" subset, used by the training cell for Hessian eigenvalues.
# `load_dataset` returns a (train, test) pair, so unpack it: binding the whole
# tuple would hand a tuple, not a dataset, to `get_hessian_eigenvalues`.
cutted_dataset_name = "cifar10-5k"
cifar10_train_cutted_dataset, cifar10_test_cutted_dataset = load_dataset(cutted_dataset_name, loss)

Using downloaded and verified file: /home/mateuszpyla/stan/sharpness/data/train_32x32.mat
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Build the architecture for the current `dataset_name`, then move it to GPU.
arch_id = "resnet9"
dynamic = False
network = load_architecture(arch_id, dataset_name, dynamic)
network = network.cuda()

# Loss / accuracy callables for the configured loss string. Both get
# `individual = False`; presumably this selects an aggregate over the batch
# rather than per-example values — TODO confirm in utilities.
loss_fn, acc_fn = get_loss_and_acc(loss)
loss_fn.individual = False
acc_fn.individual = False

In [7]:
# Optimiser hyper-parameters.
opt = "sgd"
lr = 0.01
epochs = 20
momentum = 0.0

# Only parameters with requires_grad are passed on (lazy, single-use iterable).
params = (p for p in network.parameters() if p.requires_grad)
optimizer = get_gd_optimizer(params, opt, lr, momentum)

In [13]:
# Minibatch training on CIFAR-10. After every optimiser step, record:
#   * train / test loss and accuracy over the full datasets,
#   * the top-2 Hessian eigenvalues on the 5k-example subset,
#   * a deep copy of the network.
train_loss_s, train_acc_s, test_loss_s, test_acc_s = [], [], [], []
eigs = []
networks = []

for e in range(epochs):
    for i, (X, y) in enumerate(iterate_dataset(cifar10_train_dataset, physical_batch_size)):
        print(f"epoch {e} iter {i}")
        X, y = X.cuda(), y.cuda()

        # Clear gradients from the previous step; otherwise `.backward()`
        # keeps accumulating into `.grad`.
        optimizer.zero_grad()
        # Named `batch_loss` (not `loss`) so the "mse" config string bound to
        # `loss` in earlier cells is not overwritten by a tensor.
        batch_loss = loss_fn(network(X), y) / len(X)
        batch_loss.backward()
        # Apply the update; without this the weights never change.
        optimizer.step()

        train_loss, train_acc = compute_losses(network, [loss_fn, acc_fn], cifar10_train_dataset, physical_batch_size)
        test_loss, test_acc = compute_losses(network, [loss_fn, acc_fn], cifar10_test_dataset, physical_batch_size)
        current_eigs = get_hessian_eigenvalues(network, loss_fn, cifar10_train_cutted_dataset, neigs=2, physical_batch_size=100)

        train_loss_s.append(train_loss)
        train_acc_s.append(train_acc)
        test_loss_s.append(test_loss)
        test_acc_s.append(test_acc)
        eigs.append(current_eigs)
        # NOTE(review): one full deep copy per step is retained on the
        # network's (GPU) device; over many steps this can exhaust GPU memory.
        networks.append(copy.deepcopy(network))

In [None]:
# Train loss recorded once per training step (one point per minibatch).
fig, ax = plt.subplots()
ax.plot(train_loss_s)
ax.set(xlabel="iteration", ylabel="train loss", title="Training loss")
plt.show()

[]

In [None]:
# Select one of the per-step network snapshots collected by the training cell
# (one snapshot is appended per minibatch iteration).
checkpoint_id = 10
if not -len(networks) <= checkpoint_id < len(networks):
    raise IndexError(
        f"checkpoint_id={checkpoint_id} is out of range: only {len(networks)} snapshots were recorded"
    )
network_checkpoint = networks[checkpoint_id]
