### Import headers

In [7]:
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
sys.path.append("..")

import argparse
import os
import time
from dataclasses import dataclass

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn

from dataloader import cifar10, mnist
from models import FullyConnectedNet, LeNet
from src import hessians, lanczos, regularization

# Run on the GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Project root relative to this notebook; used to build checkpoint paths.
root = ".."

### Define a wrapper function

This makes my life easier.

In [14]:
def compute_eigvals(data, target, path: int, network: str, loss: str, alpha: float, num_eigval: int = 500):
    """Count the negative eigenvalues Lanczos returns for one checkpoint on one batch.

    The function rebuilds the requested network, restores its weights from
    ``{root}/checkpoints/Figure_2/{path}/{net_name}/{loss}/ckpt_{alpha}.pth``,
    evaluates the regularized loss on ``data``/``target`` and hands that loss
    and the network to ``lanczos.lanczos`` (presumably to obtain eigenvalues
    of the loss Hessian w.r.t. the parameters -- confirm in ``src.lanczos``).

    Args:
        data: Input batch; moved to ``device`` and fed through the network.
        target: Targets for ``data``; moved to ``device`` and passed to the
            criterion.
        path: Index of the checkpoint directory under ``checkpoints/Figure_2``.
        network: ``"LeNet"`` selects ``LeNet``; any other value selects
            ``FullyConnectedNet(28 * 28, 8, 10, 3, 0.1)``.
        loss: ``"cross_entropy"`` selects ``nn.CrossEntropyLoss``; any other
            value selects ``nn.MSELoss``. Also part of the checkpoint path and
            must equal ``checkpoint["criterion"]``.
        alpha: Regularization weight given to ``RegularizedLoss``. Also part of
            the checkpoint file name and must equal ``checkpoint["alpha"]``.
        num_eigval: Forwarded to ``lanczos.lanczos`` as ``num_eigenthings``.

    Returns:
        The number of values returned by ``lanczos.lanczos`` that are below
        ``-1e-8`` (a NumPy integer).

    Raises:
        AssertionError: If the checkpoint file is missing, or if the stored
            ``alpha`` / ``criterion`` do not match the arguments.
    """
    # Network configuration
    if network == "LeNet":
        net = LeNet().to(device)
    else:
        net = FullyConnectedNet(28 * 28, 8, 10, 3, 0.1).to(device)
    # Read the class name before any DataParallel wrapping so the checkpoint
    # directory is named after the real model class.
    net_name = net.__class__.__name__

    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # Load checkpoint. Build the path once so the existence check, the error
    # message and the actual load can never disagree.
    ckpt_path = f"{root}/checkpoints/Figure_2/{path}/{net_name}/{loss}/ckpt_{alpha}.pth"
    assert os.path.isfile(ckpt_path), "Error: no checkpoint file found!: " + ckpt_path
    # map_location keeps checkpoints that were saved on a GPU loadable on a
    # CPU-only machine.
    # NOTE(review): on CUDA the net is wrapped in DataParallel (state-dict keys
    # get a "module." prefix), on CPU it is not -- confirm the checkpoints
    # match the wrapping used here.
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint["net"])
    assert (
        checkpoint["alpha"] == alpha
    ), "Error: alpha is not equal to checkpoint value!"
    assert (
        checkpoint["criterion"] == loss
    ), "Error: loss is not equal to checkpoint value!"

    # Loss configuration
    if loss == "cross_entropy":
        base_criterion = nn.CrossEntropyLoss()
    else:
        base_criterion = nn.MSELoss()
    criterion = regularization.RegularizedLoss(net, base_criterion, alpha)

    # Keep the loss tensor under its own name instead of overwriting the
    # `loss` string argument.
    # NOTE(review): the network is left in its default (training) mode here;
    # confirm that is intended if the model contains dropout.
    batch_loss = criterion(net(data.to(device)), target.to(device))

    # `use_gpu` is only passed on CUDA, so the CPU call keeps relying on
    # whatever default lanczos.lanczos defines.
    lanczos_kwargs = {"num_eigenthings": num_eigval, "tol": 0}
    if device == "cuda":
        lanczos_kwargs["use_gpu"] = True
    eigvals_lanczos = lanczos.lanczos(batch_loss, net, **lanczos_kwargs)

    # Values within 1e-8 of zero are not counted as negative.
    return np.sum(eigvals_lanczos < -1e-8)

### Call dataloaders

In [15]:
# Experiment grid (author's notes):
#   networks / datasets: LeNet on CIFAR10, FCN on MNIST
#   losses:              cross_entropy, mse
#   models:              1-10
#   alpha:               0.00 - 0.25 in steps of 0.01

# Number of eigenvalues requested from Lanczos for every checkpoint.
num_eigval = 400

# Data
batch_size = 512
num_workers = 2

# One loader per (dataset, label encoding) pair: integer labels for
# cross-entropy, one-hot labels for MSE. Only the third loader returned by
# get_data_loaders() is kept (presumably the test split).
data_loader_mnist = mnist.MNISTDataLoader(
    batch_size,
    num_workers,
    one_hot=False,
    flatten=True,
)
_, _, test_loader_mnist_ce = data_loader_mnist.get_data_loaders()

data_loader_mnist = mnist.MNISTDataLoader(
    batch_size,
    num_workers,
    one_hot=True,
    flatten=True,
)
_, _, test_loader_mnist_mse = data_loader_mnist.get_data_loaders()

data_loader_cifar10 = cifar10.CIFAR10DataLoader(
    batch_size,
    num_workers,
    one_hot=False,
)
_, _, test_loader_cifar10_ce = data_loader_cifar10.get_data_loaders()

data_loader_cifar10 = cifar10.CIFAR10DataLoader(
    batch_size,
    num_workers,
    one_hot=True,
)
_, _, test_loader_cifar10_mse = data_loader_cifar10.get_data_loaders()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [16]:
def recursive_computation(
    data_loader,
    network,
    loss,
    alpha_percents=range(8, 26),
    paths=range(1, 101),
    num_batches=1,
):
    """Sweep regularization weights and average the negative-eigenvalue ratio.

    Despite the name, nothing here is recursive: for every alpha the function
    loops over the checkpoint indices in ``paths``, calls ``compute_eigvals``
    on the first ``num_batches`` batches of ``data_loader`` and averages
    ``result / num_eigval`` over all those calls. ``num_eigval`` is read from
    the notebook's global namespace.

    Args:
        data_loader: Iterable yielding ``(data, target)`` batches.
        network: Network name forwarded to ``compute_eigvals``.
        loss: Loss name forwarded to ``compute_eigvals``.
        alpha_percents: Regularization weights in percent; each value is
            divided by 100 before use. Defaults to 8..25, i.e. alpha
            0.08..0.25.
        paths: Checkpoint indices to evaluate. Defaults to 1..100.
        num_batches: Number of batches taken from the start of
            ``data_loader`` for every checkpoint. Defaults to 1.

    Returns:
        A list with one entry per alpha: the mean ratio of negative
        eigenvalues, expressed in percent.

    Raises:
        ValueError: If ``paths`` is empty or ``num_batches`` is below 1, so
            there is nothing to average.
        StopIteration: If ``data_loader`` yields fewer than ``num_batches``
            batches.
    """
    ratio_list = []
    for percent in alpha_percents:
        alpha = percent / 100.
        num_negative = []
        for path in paths:
            # One iterator per checkpoint: every checkpoint sees the same
            # leading batches, and consecutive batches really differ
            # (re-creating the iterator per batch would always restart the
            # loader at its first batch).
            batches = iter(data_loader)
            for batch in range(num_batches):
                data, target = next(batches)
                result = compute_eigvals(data, target, path, network, loss, alpha, num_eigval)
                num_negative.append(result / num_eigval)
                print(f"{path = }, {batch = }, {result = }")
        if not num_negative:
            raise ValueError("no checkpoints/batches were evaluated; nothing to average")
        average = sum(num_negative)/len(num_negative)*100
        print(alpha, f"{average = :.2f} %")
        ratio_list.append(average)
    return ratio_list

In [17]:
# NOTE: despite the "_ce" suffix, this run uses the one-hot MNIST loader with
# loss="mse", so the list holds the MSE results.
ratio_list_mnist_ce = recursive_computation(
    data_loader=test_loader_mnist_mse,
    network="FCN",
    loss="mse",
)

path = 1, batch = 0, result = 0
path = 2, batch = 0, result = 0
path = 3, batch = 0, result = 0


ArpackError: ARPACK error 3: No shifts could be applied during a cycle of the Implicitly restarted Arnoldi iteration. One possibility is to increase the size of NCV relative to NEV. 

In [8]:
# NOTE: despite the "_mse" suffix, this run uses the integer-label MNIST loader
# with loss="cross_entropy", so the list holds the cross-entropy results.
ratio_list_mnist_mse = recursive_computation(
    data_loader=test_loader_mnist_ce,
    network="FCN",
    loss="cross_entropy",
)

path = 1, batch = 0, result = 79


KeyboardInterrupt: 

In [None]:
# The plotting API lives in matplotlib.pyplot; the top-level `matplotlib`
# package has no subplot/plot/show functions.
import matplotlib.pyplot as plt

# Full regularization grid: alpha = 0.00, 0.01, ..., 0.25. Built from integers
# so the number of points does not depend on floating-point rounding.
W_reg = (np.arange(0, 26) / 100.0).tolist()

# NOTE: the result names are swapped relative to the runs that produced them:
# `ratio_list_mnist_mse` was computed with loss="cross_entropy" and
# `ratio_list_mnist_ce` with loss="mse". Each series is paired here with the
# loss it was actually computed with so the legends are truthful.
panels = [
    ("MNIST/CE", ratio_list_mnist_mse),
    ("MNIST/MSE", ratio_list_mnist_ce),
]

fig, axes = plt.subplots(1, 2)
for ax, (label, ratios) in zip(axes, panels):
    # recursive_computation sweeps alpha = 0.08 ... 0.25 by default, i.e. only
    # the upper end of the grid. Align x with the tail of W_reg so x and y have
    # the same length (assumes the sweep ends at alpha = 0.25 in 0.01 steps).
    x_values = W_reg[len(W_reg) - len(ratios):]
    ax.plot(x_values, ratios, label=label)
    ax.set_xlabel('Regularization Weight')
    ax.set_ylabel('Negative Eigen Value Rate')
    ax.legend()

plt.show()