In [None]:
import torch
from torchvision import datasets, transforms
from pyhessian import hessian # Hessian computation
from pytorchcv.model_provider import get_model as ptcv_get_model # model
import numpy as np
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import expm, expm_multiply
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [None]:

def getData(name='cifar10', train_bs=128, test_bs=1000):
    """
    Build the CIFAR-10 train/test dataloaders.

    Args:
        name: 'cifar10' (training split uses random-crop + horizontal-flip
            augmentation) or 'cifar10_without_dataaugmentation' (training
            split is only tensor-converted and normalized, like the test split).
        train_bs: batch size of the shuffled training loader.
        test_bs: batch size of the unshuffled test loader.

    Returns:
        (train_loader, test_loader)

    Raises:
        ValueError: if ``name`` is not one of the supported dataset names.
    """
    supported = ('cifar10', 'cifar10_without_dataaugmentation')
    if name not in supported:
        # Previously an unknown name skipped both branches and crashed with
        # UnboundLocalError at the return statement.
        raise ValueError(
            f"unknown dataset name {name!r}; expected one of {supported}")

    # Per-channel normalization constants shared by both splits.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if name == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # Without augmentation the train split uses exactly the test pipeline.
        transform_train = transform_test

    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=True)

    # NOTE(review): download=False assumes the data was already fetched by the
    # train-split call above — confirm on a fresh environment.
    testset = datasets.CIFAR10(root='../data',
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=False)

    return train_loader, test_loader

def hessian_vector_product(gradsH, params, v):
    """
    Hessian-vector product H v obtained by differentiating the gradient again.

    ``gradsH`` holds the gradients at the current point and must still carry an
    autograd graph so it can be differentiated with respect to ``params``;
    ``v`` supplies the vector (one tensor per entry of ``params``).
    The graph is retained, so the same ``gradsH`` can serve many products.

    Returns:
        tuple of tensors, one per entry of ``params``.
    """
    return torch.autograd.grad(
        gradsH,
        params,
        retain_graph=True,
        only_inputs=True,
        grad_outputs=v,
    )
def group_product(xs, ys):
    """
    Inner product of two lists of tensors.

    Pairs entries with ``zip`` (extra entries of the longer list are ignored)
    and sums ``torch.sum(x * y)`` over the pairs.

    :param xs: sequence of tensors
    :param ys: sequence of tensors, broadcast-compatible with ``xs`` entry-wise
    :return: 0-d tensor, or the int 0 when there are no pairs
    """
    total = 0
    for left, right in zip(xs, ys):
        total = total + torch.sum(left * right)
    return total

In [None]:

def RademacherVector(n):
    """
        Draw one Rademacher probe shaped like a list of tensors.

        Input:
            n: iterable of tensors (e.g. model parameters); despite the name
               it is not an integer size
        Output:
            list with one tensor per entry of n, created with
            torch.randint_like (so same shape/dtype/device), whose entries
            are -1 or +1, each with probability 1/2
    """
    # randint_like(..., high=2) yields 0/1; the affine map sends 0 -> -1, 1 -> +1.
    return [torch.randint_like(template, high=2) * 2 - 1 for template in n]

def Hutchinson(oracle, n, l: int):
    """
        Hutchinson estimate of tr(A) using Rademacher probes.

        Input:
            oracle: callable mapping a probe (list of tensors shaped like n)
                    to A @ probe, returned in the same list-of-tensors form
            n: list of tensors that defines the probe shapes (passed to
               RademacherVector); it is not an integer size
            l: number of probes to average; must be positive
        Output:
            approximation: float, mean of <A g, g> over l probes g
        Raises:
            ValueError: if l is not positive
    """
    if l <= 0:
        # The old `assert l >= 0` let l == 0 through and then divided by zero.
        raise ValueError(f"l must be a positive number of probes, got {l}")

    approximation = 0
    for _ in range(l):
        g = RademacherVector(n)
        approximation += group_product(oracle(g), g).cpu().item()

    return approximation / l

def SquaredFrobenius(oracle, n, l: int):
    """
        Estimate ||A||_F^2 as the mean of ||A g||^2 over Rademacher probes g.

        Input:
            oracle: callable mapping a probe (list of tensors shaped like n)
                    to A @ probe, returned in the same list-of-tensors form
            n: list of tensors that defines the probe shapes (passed to
               RademacherVector); it is not an integer size
            l: number of probes to average; must be positive
        Output:
            approximation: float estimate of the squared Frobenius norm of A
        Raises:
            ValueError: if l is not positive
    """
    if l <= 0:
        raise ValueError(f"l must be a positive number of probes, got {l}")

    approximation = 0
    for _ in range(l):
        g = oracle(RademacherVector(n))
        approximation += group_product(g, g).cpu().item()

    # approximation is a plain Python float; the old `approximation[0, 0]`
    # indexing raised TypeError on every call.
    return approximation / l

def write_array(arr, file_name):
    """
    Write the items of ``arr`` on one line of ``file_name`` (overwriting it).

    Each item is rendered with ``str`` and followed by a single space; the line
    is terminated with a newline. The file is managed by a ``with`` block, so it
    is closed even if rendering an item or writing raises.
    """
    with open(file_name, 'w') as f:
        for item in arr:
            f.write(str(item) + " ")
        f.write("\n")

In [None]:

def DeltaShiftRestart(prev, cur, iter, last_app, n: int, l0: int, l: int, q: int):
    """
        One step of DeltaShift with a periodic restart.

        On every step where iter % q == 0 the estimate is recomputed from
        scratch with Hutchinson(cur, n, l0). Otherwise the previous estimate is
        shifted by the mean of <A_cur g, g> - <A_prev g, g> over l probes g.

        Input:
            prev: oracle g -> A_prev @ g for the previous matrix
                  (not called on restart steps)
            cur: oracle g -> A_cur @ g for the current matrix
            iter: index of the current step
            last_app: estimate returned for the previous step
            n: list of tensors shaping the probes (passed to RademacherVector)
            l0: number of probes used on a restart step
            l: number of probes used on an incremental step
            q: restart period
        Output:
            estimate of tr(A_cur)
    """
    if iter % q == 0:
        return Hutchinson(cur, n, l0)

    estimate = last_app
    for _ in range(l):
        probe = RademacherVector(n)
        cur_term = group_product(cur(probe), probe).cpu().item()
        prev_term = group_product(prev(probe), probe).cpu().item()
        estimate += (cur_term - prev_term) / l
    return estimate


def DeltaShift(prev, cur, var, iter, last_app, n: int, l0: int, l: int):
    """
        One step of the adaptive (variance-tracking) DeltaShift trace estimator.

        NOTE(review): the gamma / estimate / variance formulas below appear to
        follow the parameter-free DeltaShift algorithm of Dharangutte & Musco
        ("Dynamic Trace Estimation") -- confirm against the paper.

        Input:
            prev: oracle g -> A_prev @ g for the previous matrix; only called
                  when iter != 0
            cur: oracle g -> A_cur @ g for the current matrix
            var: variance estimate returned by the previous step
                 (ignored when iter == 0)
            iter: step index; iter == 0 runs a plain Hutchinson estimate
            last_app: trace estimate returned by the previous step
                      (ignored when iter == 0)
            n: list of tensors passed to RademacherVector to shape the probes
               (despite the `int` annotation)
            l0: number of probes on the first step
            l: number of probes on every later step
        Output:
            (estimate, variance): trace estimate of A_cur and the running
            variance estimate; the caller feeds them back as last_app / var.
    """
    if iter == 0:
        # First step: Hutchinson estimate of tr(A_cur) from l0 probes, while
        # N accumulates ||A_cur g||^2 (mean estimates ||A_cur||_F^2).
        approximation = 0
        N = 0

        for _ in range(l0):
            g = RademacherVector(n)
            z = cur(g)
            approximation += group_product(z, g).cpu().item()
            N += group_product(z, z).cpu().item()

        approximation /= l0
        N /= l0
        # Initial variance estimate: 2 * ||A_cur||_F^2 / l0.
        variance = 2 / l0 * N
        return approximation, variance
    else:
        approximation = 0  # unused in this branch
        # Draw all l probes up front and keep both products, because every
        # sum below is taken over the same probes.
        z = []  # z[j] = A_prev @ g[j]
        w = []  # w[j] = A_cur  @ g[j]
        g = []  # g[j] = j-th Rademacher probe
        for j in range(l):
            current = RademacherVector(n)
            z.append(prev(current))
            w.append(cur(current))
            g.append(current)

        # Monte-Carlo estimates: N ~ ||A_prev||_F^2, M ~ ||A_cur||_F^2,
        # C ~ <A_cur, A_prev>_F (means over the l probes).
        N = 0
        M = 0
        C = 0
        for j in range(l):
            N += group_product(z[j], z[j]).cpu().item() / l
            M += group_product(w[j], w[j]).cpu().item() / l
            C += group_product(w[j], z[j]).cpu().item() / l

        # Damping factor chosen from the tracked variance. NOTE(review): this
        # divides by zero if l * var + 2 * N == 0 (e.g. an all-zero Hessian).
        gamma = 1 - (2 * C) / (l * var + 2 * N)

        # New estimate: (1 - gamma) * previous estimate plus the mean of
        # g^T A_cur g - (1 - gamma) g^T A_prev g over the probes.
        t = (1 - gamma) * last_app
        for j in range(l):
            t += (group_product(g[j], w[j]).cpu().item() - (1 - gamma) * group_product(g[j], z[j]).cpu().item()) / l

        # Propagate the variance estimate to the next step.
        variance = (1 - gamma)**2 * var + 2 / l * (M + (1 - gamma)**2 * N - 2 * (1 - gamma) * C)
        return t, variance

In [None]:
def GetTraceAndFrob(oracle, n, l: int):
    """
        Estimate tr(A) and ||A||_F^2 from the same Rademacher probes.

        Input:
            oracle: callable mapping a probe (list of tensors shaped like n)
                    to A @ probe, returned in the same list-of-tensors form
            n: list of tensors that defines the probe shapes (passed to
               RademacherVector); it is not an integer size
            l: number of probes to average; must be positive
        Output:
            (trace_estimate, squared_frobenius_estimate): means of <A g, g>
            and <A g, A g> over l probes g
        Raises:
            ValueError: if l is not positive
    """
    if l <= 0:
        # The old `assert l >= 0` let l == 0 through and then divided by zero.
        raise ValueError(f"l must be a positive number of probes, got {l}")

    approximation_tr = 0
    approximation_fr = 0

    for _ in range(l):
        g = RademacherVector(n)
        t = oracle(g)  # one oracle call serves both estimators
        approximation_tr += group_product(t, g).cpu().item()
        approximation_fr += group_product(t, t).cpu().item()

    return approximation_tr / l, approximation_fr / l

def DeltaShiftFrob(prev, cur, iter, last_app_e, last_app_h, n: int, l0: int, l: int):
    """
        One step of a DeltaShift-style estimator of ||A_cur||_F^2.

        Uses ||A_cur||_F^2 = ||A_prev||_F^2 + 2 <A_prev, D>_F + ||D||_F^2 with
        D = A_cur - A_prev. The correction terms are estimated with a single
        Rademacher probe g per sample as <A_prev g, D g> and ||D g||^2.

        Input:
            prev: oracle g -> A_prev @ g (not called when iter == 0)
            cur: oracle g -> A_cur @ g
            iter: step index; iter == 0 estimates ||A_cur||_F^2 from scratch
            last_app_e: previous "easy" (first-order) estimate
            last_app_h: previous "hard" (full-update) estimate
            n: list of tensors shaping the probes (passed to RademacherVector)
            l0: number of probes on the first step
            l: probe budget for later steps; l // 2 probes are drawn
               (presumably because each probe costs two oracle calls)
        Output:
            (easy_way, hard_way). On the first step both equal the plain
            estimate. easy_way adds only 2 <A_prev g, D g>; hard_way also
            adds ||D g||^2.
    """
    if iter == 0:
        approximation = 0
        for _ in range(l0):
            g = cur(RademacherVector(n))
            approximation += group_product(g, g).cpu().item()
        approximation /= l0
        # Return a pair like the other branch: callers unpack
        # `easy, hard = DeltaShiftFrob(...)`, which failed on the old scalar.
        return approximation, approximation

    easy_way = last_app_e
    hard_way = last_app_h
    num_probes = l // 2
    for _ in range(num_probes):
        # Both oracles must see the SAME probe: with independent probes the
        # cross term <A_prev g1, D g2> has zero mean, so the update was wrong.
        probe = RademacherVector(n)
        prev_g = prev(probe)
        cur_g = cur(probe)
        # Oracles return sequences of tensors; subtract entry-wise
        # (tuple - tuple raises TypeError).
        delta_g = [c - p for p, c in zip(prev_g, cur_g)]
        cross = 2 * group_product(prev_g, delta_g).cpu().item() / num_probes
        easy_way += cross
        hard_way += cross
        hard_way += group_product(delta_g, delta_g).cpu().item() / num_probes

    return easy_way, hard_way

In [None]:

def CheckAlgorithms(oracles, n, correct_ans, title = None):
    """
        Write the relative errors of several trace estimators to text files.

        Input:
            oracles: list of oracles g -> A_k @ g for matrices A_1, ..., A_m
            n: probe-shape argument forwarded unchanged to the estimators
            correct_ans: list of reference traces of A_1, ..., A_m
            title: currently unused (kept for backward compatibility; no
                   graphic is produced)
        Output:
            None. Writes 'Hutchinson.txt', 'Deltshift.txt' and
            'Deltshiftrest.txt', each holding |correct - estimate| / max(correct)
            per matrix.

        NOTE(review): SimpleHutchinson and ParameterFreeDeltaShift are not
        defined in the visible notebook; they must be provided elsewhere.
    """
    l0 = 100
    l = 50

    correct_ans = np.array(correct_ans)
    scale = max(correct_ans)

    simpl_hutchinson = np.array(SimpleHutchinson(oracles, n, l))
    write_array(abs(correct_ans - simpl_hutchinson) / scale, 'Hutchinson.txt')

    print("1")

    delta_shift = np.array(ParameterFreeDeltaShift(oracles, n, l0, l - l // (len(oracles) - 1)))
    write_array(abs(correct_ans - delta_shift) / scale, 'Deltshift.txt')

    print("2")

    # DeltaShiftRestart computes ONE step (previous oracle, current oracle,
    # step index, previous estimate, ...). The old call passed the whole oracle
    # list with 5 of its 8 arguments and raised TypeError; drive it per step.
    restart_q = 20
    restart_l = l - l // (restart_q - 1)
    restart_estimates = []
    last_estimate = 0
    for k, cur_oracle in enumerate(oracles):
        # On step 0 (a restart step) the previous oracle is never called.
        prev_oracle = oracles[k - 1] if k > 0 else None
        last_estimate = DeltaShiftRestart(prev_oracle, cur_oracle, k, last_estimate,
                                          n, l0, restart_l, restart_q)
        restart_estimates.append(last_estimate)
    delta_shift_r = np.array(restart_estimates)
    write_array(abs(correct_ans - delta_shift_r) / scale, 'Deltshiftrest.txt')


In [None]:

# Track trace / squared-Frobenius estimates of the mini-batch loss Hessian
# across consecutive training batches: a high-budget Hutchinson reference is
# compared with plain Hutchinson, adaptive DeltaShift, DeltaShift with
# restarts, and the DeltaShift-style Frobenius updates. Each estimator's state
# is carried from one batch to the next.
import datetime

# Experiment configuration (values identical to the original run).
NUM_BATCHES = 71          # batches processed: i = 0 .. 70
REFERENCE_PROBES = 1000   # probes for the reference ("true") estimates
PROBES = 50               # per-step probe budget of the cheap estimators
FIRST_STEP_PROBES = 100   # l0: probes on the first / restart steps
RESTART_PERIOD = 20       # q for DeltaShiftRestart
# Fewer probes per step for the restart variant, presumably to offset the
# cost of its periodic restarts (equals the original 50 - (50 // 19)).
RESTART_PROBES = PROBES - PROBES // (RESTART_PERIOD - 1)
SEED = 0

# Seed torch so the shuffled batches and the Rademacher probes are reproducible.
torch.manual_seed(SEED)

# get the model
model = ptcv_get_model("resnet20_cifar10", pretrained=True)
# change the model to eval mode to disable running stats update
model.eval()

# create loss function
criterion = torch.nn.CrossEntropyLoss()

# get dataset
train_loader, test_loader = getData('cifar10_without_dataaugmentation')

correct_ans = []   # one tuple of estimates per batch (same order as the print)
prev = None        # pyhessian object of the previous batch
varience = 0       # DeltaShift running variance
last_app = 0       # DeltaShift trace estimate
last_appr = 0      # DeltaShiftRestart trace estimate

last_ser = 0       # DeltaShiftFrob "easy" estimate
last_me = 0        # DeltaShiftFrob "hard" estimate

model = model.cuda()

for i, (inputs, targets) in enumerate(train_loader):
    if i >= NUM_BATCHES:
        break
    hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=True)

    now = datetime.datetime.now()

    # HVP oracles for this batch and the previous one. Each binds its pyhessian
    # object as a default argument so it keeps pointing at that batch. `prev`
    # is None on batch 0, where none of the estimators call the prev oracle.
    def cur_hvp(v, _h=hessian_comp):
        return hessian_vector_product(_h.gradsH, _h.params, v)

    def prev_hvp(v, _h=prev):
        return hessian_vector_product(_h.gradsH, _h.params, v)

    probe_shapes = hessian_comp.params

    true_trace, true_frob = GetTraceAndFrob(cur_hvp, probe_shapes, REFERENCE_PROBES)
    hutch_trace, hutch_frob = GetTraceAndFrob(cur_hvp, probe_shapes, PROBES)
    delt_shift, var = DeltaShift(prev_hvp, cur_hvp, varience, i, last_app,
                                 probe_shapes, FIRST_STEP_PROBES, PROBES)
    delt_shift_r = DeltaShiftRestart(prev_hvp, cur_hvp, i, last_appr, probe_shapes,
                                     FIRST_STEP_PROBES, RESTART_PROBES, RESTART_PERIOD)
    # sereja / me: "easy" and "hard" squared-Frobenius estimates.
    sereja, me = DeltaShiftFrob(prev_hvp, cur_hvp, i, last_ser, last_me,
                                probe_shapes, FIRST_STEP_PROBES, PROBES)

    # Carry every estimator's state to the next batch.
    prev = hessian_comp
    last_app = delt_shift
    last_appr = delt_shift_r
    varience = var
    last_ser = sereja
    last_me = me

    correct_ans.append((true_trace, hutch_trace, delt_shift, delt_shift_r, true_frob, hutch_frob, sereja, me))
    print(true_trace, hutch_trace, delt_shift, delt_shift_r, true_frob, hutch_frob, sereja, me, datetime.datetime.now() - now)

print(correct_ans)
write_array(correct_ans, "correct.txt")


Files already downloaded and verified
889.7223922424316 846.8184045410156 961.093782043457 927.6364559936524 0:00:38.750567


KeyboardInterrupt: 