### Import headers

In [1]:
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
sys.path.append("..")

import argparse
import os
import time
from dataclasses import dataclass

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn

from dataloader import cifar10, mnist
from models import FullyConnectedNet, LeNet
from src import hessians, lanczos, regularization

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Repository root relative to this notebook; checkpoints are read from f"{root}/checkpoints/...".
root = ".."

# Seed the global RNGs so re-running the notebook reproduces the same batches (if a loader
# shuffles) and, presumably, the same Lanczos start vector — TODO confirm which RNG lanczos uses.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

### Define a wrapper function

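This makes my life easier: it loads a trained checkpoint, evaluates the regularized loss on a single batch, and counts how many of the eigenvalues returned by `lanczos.lanczos` are negative (below -1e-8).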

In [2]:
def compute_eigvals(data, target, path:int, network: str, loss:str, alpha: float, num_eigval: int=800):
    """Count negative Lanczos eigenvalues for one trained checkpoint.

    Builds the requested network, loads
    ``{root}/checkpoints/{path}/{net_name}/{loss}/ckpt_{alpha}.pth`` into it,
    evaluates the regularized loss on ``(data, target)`` and runs
    ``lanczos.lanczos`` on that loss value.

    Args:
        data: Input batch; moved to ``device`` before the forward pass.
        target: Targets for ``data``; moved to ``device``.
        path: Checkpoint index (sub-directory under ``checkpoints/``).
        network: ``"LeNet"`` or ``"FCN"`` (the fully-connected net).
        loss: ``"cross_entropy"`` or ``"mse"``; selects the base criterion and
            the checkpoint sub-directory, and must match the checkpoint's
            stored ``criterion``.
        alpha: Regularization weight given to ``RegularizedLoss``; also part of
            the checkpoint file name and must match the stored ``alpha``.
        num_eigval: Number of eigenpairs requested from Lanczos.

    Returns:
        int: how many of the returned eigenvalues are strictly below ``-1e-8``.

    Raises:
        ValueError: if ``network`` or ``loss`` is not a supported name.
        AssertionError: if the checkpoint file is missing or its stored
            ``alpha`` / ``criterion`` disagree with the arguments.
    """
    # Network configuration
    if network == "LeNet":
        net = LeNet().to(device)
    elif network == "FCN":
        # NOTE(review): positional args forwarded unchanged; presumably
        # (in_dim, width, n_classes, depth, dropout) — confirm in models.FullyConnectedNet.
        net = FullyConnectedNet(28 * 28, 8, 10, 3, 0.1).to(device)
    else:
        raise ValueError(f"Unknown network {network!r}; expected 'LeNet' or 'FCN'")
    if loss not in ("cross_entropy", "mse"):
        raise ValueError(f"Unknown loss {loss!r}; expected 'cross_entropy' or 'mse'")
    # The checkpoint directory is named after the bare model class, taken before any wrapping.
    net_name = net.__class__.__name__

    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # Load checkpoint.
    rel_ckpt_path = f"checkpoints/{path}/{net_name}/{loss}/ckpt_{alpha}.pth"
    ckpt_path = f"{root}/{rel_ckpt_path}"
    assert os.path.isfile(ckpt_path), "Error: no checkpoint file found!: " + rel_ckpt_path
    # map_location lets a checkpoint written on a GPU machine be loaded when only a CPU is available.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # NOTE(review): on CPU the net is not wrapped in DataParallel; if the checkpoint was saved
    # from a DataParallel model its keys carry a "module." prefix and this load will fail — confirm.
    net.load_state_dict(checkpoint["net"])
    assert checkpoint["alpha"] == alpha, "Error: alpha is not equal to checkpoint value!"
    assert checkpoint["criterion"] == loss, "Error: loss is not equal to checkpoint value!"

    # Loss configuration: the regularized loss wraps the base criterion with weight alpha.
    base_criterion = nn.CrossEntropyLoss() if loss == "cross_entropy" else nn.MSELoss()
    criterion = regularization.RegularizedLoss(net, base_criterion, alpha)

    # NOTE(review): the net stays in training mode (PyTorch's default for a new module; no
    # net.eval() here). If FullyConnectedNet's 0.1 argument is dropout, this loss uses a random
    # dropout mask — confirm that is intended for the eigenvalue analysis.
    loss_value = criterion(net(data.to(device)), target.to(device))

    # Only the GPU path passes use_gpu; on CPU the lanczos default is left in effect.
    lanczos_kwargs = {"use_gpu": True} if device == "cuda" else {}
    eigvals_lanczos, _ = lanczos.lanczos(
        loss_value,
        net,
        num_eigenthings=num_eigval,
        tol=0,
        **lanczos_kwargs,
    )

    # A small negative threshold, presumably so round-off noise around zero is not counted.
    return int(np.sum(eigvals_lanczos < -1e-8))

### Call dataloaders

In [3]:
# Experiment grid covered by the sweeps below:
#   * LeNet on CIFAR10, the fully-connected net ("FCN") on MNIST
#   * loss: cross_entropy or mse
#   * checkpoint indices (`path`) are chosen inside recursive_computation
#   * alpha (regularization weight): 0.00 to 0.25 in steps of 0.01

# Eigenvalues requested from Lanczos per checkpoint (also read as a global by recursive_computation).
num_eigval = 400

# Settings shared by every data loader built below.
batch_size, num_workers = 512, 2

# MNIST with flatten=True: one_hot=False for the cross-entropy sweep, one_hot=True for MSE.
# Only the third loader returned by get_data_loaders (the test loader) is kept.
data_loader_mnist = mnist.MNISTDataLoader(batch_size, num_workers, one_hot=False, flatten=True)
_, _, test_loader_mnist_ce = data_loader_mnist.get_data_loaders()
data_loader_mnist = mnist.MNISTDataLoader(batch_size, num_workers, one_hot=True, flatten=True)
_, _, test_loader_mnist_mse = data_loader_mnist.get_data_loaders()

# CIFAR10, again one loader per target encoding.
data_loader_cifar10 = cifar10.CIFAR10DataLoader(batch_size, num_workers, one_hot=False)
_, _, test_loader_cifar10_ce = data_loader_cifar10.get_data_loaders()
data_loader_cifar10 = cifar10.CIFAR10DataLoader(batch_size, num_workers, one_hot=True)
_, _, test_loader_cifar10_mse = data_loader_cifar10.get_data_loaders()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [8]:
def recursive_computation(data_loader, network, loss, alphas=None, paths=range(1, 6), num_batches=1):
    """Sweep regularization weights and checkpoints, recording negative-eigenvalue rates.

    Despite the name this is not recursive. For every ``alpha`` it calls
    ``compute_eigvals`` on each checkpoint in ``paths``. For each checkpoint it
    uses the first ``num_batches`` batches of a fresh pass over ``data_loader``.
    It then averages the fraction of negative eigenvalues over all
    (checkpoint, batch) pairs.

    The number of eigenvalues per call is read from the notebook-level global
    ``num_eigval``.

    Args:
        data_loader: Iterable yielding ``(data, target)`` batches. A new iterator
            is started for every checkpoint, so with an unshuffled loader every
            checkpoint sees the same leading batches.
        network: Forwarded to ``compute_eigvals`` (``"LeNet"`` or ``"FCN"``).
        loss: Forwarded to ``compute_eigvals`` (``"cross_entropy"`` or ``"mse"``).
        alphas: Regularization weights to sweep. Defaults to 0.00, 0.01, ..., 0.25,
            computed as ``i / 100`` so the values print exactly like the
            checkpoint file names (``ckpt_0.07.pth``).
        paths: Checkpoint indices evaluated for each alpha (default 1..5).
        num_batches: Batches evaluated per checkpoint (default 1).

    Returns:
        list[float]: one entry per alpha, the mean negative-eigenvalue fraction in percent.

    Raises:
        ValueError: if ``paths`` is empty, ``num_batches`` < 1, or the loader
            yields fewer than ``num_batches`` batches.
    """
    if alphas is None:
        alphas = [step / 100. for step in range(0, 26)]
    paths = list(paths)
    if not paths or num_batches < 1:
        raise ValueError("need at least one checkpoint path and at least one batch per path")

    ratio_list = []
    for alpha in alphas:
        num_negative = []
        for path in paths:
            # Fresh iterator per checkpoint: each checkpoint starts from the loader's first batch.
            batches = iter(data_loader)
            for batch in range(num_batches):
                item = next(batches, None)
                if item is None:
                    raise ValueError(f"data_loader yielded fewer than {num_batches} batches")
                data, target = item
                result = compute_eigvals(data, target, path, network, loss, alpha, num_eigval)
                num_negative.append(result / num_eigval)
                print(f"{path = }, {batch = }, {result = }")
        average = sum(num_negative) / len(num_negative) * 100
        print(f"{average = :.2f} %")
        ratio_list.append(average)
    return ratio_list

In [5]:
# LeNet on CIFAR10 with the cross-entropy objective.
ratio_list_cifar10_ce = recursive_computation(
    data_loader=test_loader_cifar10_ce,
    network="LeNet",
    loss="cross_entropy",
)

path = 1, batch = 0, result = 77
path = 2, batch = 0, result = 98
path = 3, batch = 0, result = 86
path = 4, batch = 0, result = 91
path = 5, batch = 0, result = 76
path = 6, batch = 0, result = 84
path = 7, batch = 0, result = 80
path = 8, batch = 0, result = 98
path = 9, batch = 0, result = 98
path = 10, batch = 0, result = 72
path = 11, batch = 0, result = 73
path = 12, batch = 0, result = 76
path = 13, batch = 0, result = 81
path = 14, batch = 0, result = 76
path = 15, batch = 0, result = 73
path = 16, batch = 0, result = 79
path = 17, batch = 0, result = 84
path = 18, batch = 0, result = 74
path = 19, batch = 0, result = 81
path = 20, batch = 0, result = 102
path = 21, batch = 0, result = 89
path = 22, batch = 0, result = 88
path = 23, batch = 0, result = 69
path = 24, batch = 0, result = 87
path = 25, batch = 0, result = 93
average = 20.85 %
path = 1, batch = 0, result = 87
path = 2, batch = 0, result = 90
path = 3, batch = 0, result = 84
path = 4, batch = 0, result = 98
path = 5

path = 12, batch = 0, result = 108
path = 13, batch = 0, result = 155
path = 14, batch = 0, result = 133
path = 15, batch = 0, result = 137
path = 16, batch = 0, result = 135
path = 17, batch = 0, result = 140
path = 18, batch = 0, result = 139
path = 19, batch = 0, result = 159
path = 20, batch = 0, result = 144
path = 21, batch = 0, result = 125
path = 22, batch = 0, result = 157
path = 23, batch = 0, result = 145
path = 24, batch = 0, result = 135
path = 25, batch = 0, result = 125
average = 35.29 %
path = 1, batch = 0, result = 161
path = 2, batch = 0, result = 169
path = 3, batch = 0, result = 140
path = 4, batch = 0, result = 152
path = 5, batch = 0, result = 160
path = 6, batch = 0, result = 121
path = 7, batch = 0, result = 126
path = 8, batch = 0, result = 145
path = 9, batch = 0, result = 146
path = 10, batch = 0, result = 154
path = 11, batch = 0, result = 129
path = 12, batch = 0, result = 154
path = 13, batch = 0, result = 143
path = 14, batch = 0, result = 125
path = 15, 

path = 19, batch = 0, result = 174
path = 20, batch = 0, result = 182
path = 21, batch = 0, result = 169
path = 22, batch = 0, result = 176
path = 23, batch = 0, result = 149
path = 24, batch = 0, result = 163
path = 25, batch = 0, result = 164
average = 39.79 %
path = 1, batch = 0, result = 163
path = 2, batch = 0, result = 180
path = 3, batch = 0, result = 180
path = 4, batch = 0, result = 185
path = 5, batch = 0, result = 161
path = 6, batch = 0, result = 144
path = 7, batch = 0, result = 175
path = 8, batch = 0, result = 151
path = 9, batch = 0, result = 177
path = 10, batch = 0, result = 170
path = 11, batch = 0, result = 130
path = 12, batch = 0, result = 161
path = 13, batch = 0, result = 158
path = 14, batch = 0, result = 189
path = 15, batch = 0, result = 188
path = 16, batch = 0, result = 182
path = 17, batch = 0, result = 150
path = 18, batch = 0, result = 97
path = 19, batch = 0, result = 0
path = 20, batch = 0, result = 127
path = 21, batch = 0, result = 183
path = 22, bat

In [6]:
# LeNet on CIFAR10 with the MSE objective.
ratio_list_cifar10_mse = recursive_computation(
    data_loader=test_loader_cifar10_mse,
    network="LeNet",
    loss="mse",
)

path = 1, batch = 0, result = 100
path = 2, batch = 0, result = 88
path = 3, batch = 0, result = 90
path = 4, batch = 0, result = 92
path = 5, batch = 0, result = 88
path = 6, batch = 0, result = 95
path = 7, batch = 0, result = 94
path = 8, batch = 0, result = 68
path = 9, batch = 0, result = 86
path = 10, batch = 0, result = 94
path = 11, batch = 0, result = 71
path = 12, batch = 0, result = 96
path = 13, batch = 0, result = 95
path = 14, batch = 0, result = 96
path = 15, batch = 0, result = 94
path = 16, batch = 0, result = 94
path = 17, batch = 0, result = 95
path = 18, batch = 0, result = 92
path = 19, batch = 0, result = 88
path = 20, batch = 0, result = 101
path = 21, batch = 0, result = 92


KeyboardInterrupt: 

In [None]:
ratio_list_mnist_ce = recursive_computation(test_loader_mnist_mse, "FCN", "mse")

In [None]:
ratio_list_mnist_mse = recursive_computation(test_loader_mnist_ce, "FCN", "cross_entropy")

path = 1, batch = 0, result = 79
path = 2, batch = 0, result = 80
path = 3, batch = 0, result = 80
path = 4, batch = 0, result = 81
path = 5, batch = 0, result = 84
average = 20.20 %
path = 1, batch = 0, result = 79
path = 2, batch = 0, result = 79
path = 3, batch = 0, result = 75
path = 4, batch = 0, result = 80
path = 5, batch = 0, result = 80
average = 19.65 %
path = 1, batch = 0, result = 77
path = 2, batch = 0, result = 73
path = 3, batch = 0, result = 75
path = 4, batch = 0, result = 78
path = 5, batch = 0, result = 78
average = 19.05 %
path = 1, batch = 0, result = 75
path = 2, batch = 0, result = 74
path = 3, batch = 0, result = 74
path = 4, batch = 0, result = 74
path = 5, batch = 0, result = 75
average = 18.60 %
path = 1, batch = 0, result = 73
path = 2, batch = 0, result = 72
path = 3, batch = 0, result = 72
path = 4, batch = 0, result = 72
path = 5, batch = 0, result = 73
average = 18.10 %
path = 1, batch = 0, result = 71
path = 2, batch = 0, result = 70
path = 3, batch = 0

In [None]:
# pyplot (not the top-level matplotlib package) provides subplots/plot/show.
import matplotlib.pyplot as plt

# One panel per (dataset, loss) sweep; each list holds one percentage per alpha.
panels = [
    (ratio_list_cifar10_ce, "CIFAR10/CE"),
    (ratio_list_cifar10_mse, "CIFAR10/MSE"),
    (ratio_list_mnist_ce, "MNIST/CE"),
    (ratio_list_mnist_mse, "MNIST/MSE"),
]

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
for ax, (ratios, label) in zip(axes.flat, panels):
    # Rebuild the x values as i / 100, like recursive_computation's default alphas, and
    # size them from the data; a float-step np.arange can disagree with the list length.
    W_reg = np.arange(len(ratios)) / 100
    ax.plot(W_reg, ratios, label=label)
    ax.set_xlabel('Regularization Weight')
    ax.set_ylabel('Negative Eigenvalue Rate (%)')
    ax.legend()

fig.tight_layout()
plt.show()