<a href="https://colab.research.google.com/github/fallnlove/dynamic-trace-estimation/blob/main/nn_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytorchcv
!pip install pyhessian
!pip install torchvision

Collecting pytorchcv
  Downloading pytorchcv-0.0.67-py2.py3-none-any.whl (532 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m532.4/532.4 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pytorchcv
Successfully installed pytorchcv-0.0.67
Collecting pyhessian
  Downloading PyHessian-0.1-py3-none-any.whl (6.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->pyhessian)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->pyhessian)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->pyhessian)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->pyhessian)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cubl

In [None]:
import torch
from torchvision import datasets, transforms
from pyhessian import hessian # Hessian computation
from pytorchcv.model_provider import get_model as ptcv_get_model # model
import numpy as np
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import expm, expm_multiply
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [None]:

def getData(name='cifar10', train_bs=128, test_bs=1000):
    """
    Get the CIFAR-10 train and test dataloaders.

    Input:
        name: 'cifar10' (random crop + horizontal flip on the training set) or
              'cifar10_without_dataaugmentation' (no augmentation)
        train_bs: batch size of the shuffled training loader
        test_bs: batch size of the unshuffled test loader
    Output:
        train_loader, test_loader
    Raises:
        ValueError: if name is not one of the two supported configurations
    """
    supported = ('cifar10', 'cifar10_without_dataaugmentation')
    # Fail early with a clear message; previously an unknown name fell through
    # to the return statement and raised an UnboundLocalError.
    if name not in supported:
        raise ValueError(
            "unknown dataset configuration %r, expected one of %r" % (name, supported))

    def base_steps():
        # Tensor conversion + per-channel normalization shared by both splits.
        return [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]

    # Augmentation is the only difference between the two configurations.
    augmentation = []
    if name == 'cifar10':
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    transform_train = transforms.Compose(augmentation + base_steps())
    transform_test = transforms.Compose(base_steps())

    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=True)

    # download=False: the files are expected to be in place after the
    # training-set download above.
    testset = datasets.CIFAR10(root='../data',
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=False)

    return train_loader, test_loader

def hessian_vector_product(gradsH, params, v):
    """
    Differentiate gradsH with respect to params, weighted by v.

    gradsH: gradients at the current point (they must still carry a graph),
    params: the variables the gradients were taken with respect to,
    v: the vector(s) to multiply with, passed as grad_outputs.

    Returns whatever torch.autograd.grad returns (a tuple with one tensor per
    entry of params). The graph is retained so the product can be evaluated
    again for further vectors.
    """
    return torch.autograd.grad(
        outputs=gradsH,
        inputs=params,
        grad_outputs=v,
        retain_graph=True,
        only_inputs=True,
    )
def group_product(xs, ys):
    """
    Inner product of two lists of tensors, accumulated pair by pair.

    :param xs: first list of tensors
    :param ys: second list of tensors, paired with xs by position
    :return: sum over all pairs of sum(x * y); the integer 0 when there is
             no pair to accumulate
    """
    total = 0
    for left, right in zip(xs, ys):
        total = total + torch.sum(left * right)
    return total

In [None]:

def RademacherVector(n):
    """
        Input:
            n: iterable of template tensors; one random tensor is drawn per
               template via torch.randint_like (same shape as the template)
        Output:
            vec: list of tensors whose entries are -1 or +1
    """

    signs = []
    for template in n:
        # Draw from {0, 1}, then move every 0 to -1.
        sample = torch.randint_like(template, high=2)
        sample[sample == 0] = -1
        signs.append(sample)

    return signs

def Hutchinson(oracle, n, l: int):
    """
        Hutchinson trace estimator: average of g^T A g over l random sign vectors.

        Input:
            oracle: oracle for implicit matrix-vector multiplication with A;
                    called with the list of tensors produced by RademacherVector
            n: iterable of template tensors forwarded to RademacherVector
            l: number of iteration to approximate the trace of the matrix A,
               must be positive
        Output:
            approximation: approximation to the trace of A
        Raises:
            ValueError: if l is not positive
    """

    # l == 0 used to pass the old `l >= 0` assertion and then fail on the
    # final division; reject it up front with a clear message.
    if l <= 0:
        raise ValueError("l must be a positive number of samples, got %r" % (l,))

    approximation = 0

    for _ in range(l):
        g = RademacherVector(n)
        approximation += group_product(oracle(g), g).cpu().item()

    return approximation / l

def SimpleHutchinson(oracles, n: int, l: int):
    """
        Run an independent Hutchinson estimate for every oracle.

        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: indexed per oracle (n[i] is forwarded to Hutchinson for
               oracles[i]), despite the int annotation
            l: number of iteration to approximate the trace of each matrix
        Output:
            list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    return [
        Hutchinson(oracle, n[position], l)
        for position, oracle in enumerate(oracles)
    ]

def DeltaShiftRestart(prev, cur, iter, last_app, n: int, l0: int, l: int, q: int):
    """
        One step of the DeltaShift algorithm with a restart every q iterations.

        Input:
            prev: oracle for matrix-vector multiplication with the previous
                  matrix; only called on non-restart steps
            cur: oracle for matrix-vector multiplication with the current matrix
            iter: index of the current step; a restart happens when iter % q == 0
            last_app: approximation returned for the previous step
            n: forwarded to RademacherVector / Hutchinson to draw random vectors
            l0: number of samples used on a restart step
            l: number of samples used on a non-restart step
            q: number of iterations between restarts
        Output:
            approximation to the trace of the current matrix
    """
    # The former `assert len(oracles) > 0` was removed: `oracles` is not a
    # parameter of this function, so the check read an unrelated global.

    if iter % q == 0:
        # Restart: a fresh Hutchinson estimate that ignores last_app.
        return Hutchinson(cur, n, l0)

    # Otherwise shift the previous estimate by an estimate of tr(cur - prev),
    # evaluating both oracles on the same random vectors.
    t = last_app
    for _ in range(l):
        g = RademacherVector(n)
        t += (group_product(cur(g), g).cpu().item() - group_product(prev(g), g).cpu().item()) / l
    return t


def DeltShift(prev, cur, var, iter, last_app, n: int, l0: int, l: int):
    """
        One step of the parameter-free DeltaShift trace estimator.

        Input:
            prev: oracle for matrix-vector multiplication with the previous
                  matrix; not called when iter == 0
            cur: oracle for matrix-vector multiplication with the current matrix
            var: variance estimate returned by the previous step
            iter: index of the current step; 0 selects the initial estimate
            last_app: approximation returned by the previous step
            n: forwarded to RademacherVector to draw random vectors
            l0: number of samples for the initial step
            l: number of samples for every later step
        Output:
            (approximation, variance): trace estimate for the current matrix
            and the variance estimate to pass to the next step
    """
    if iter == 0:
        # Initial step: plain Hutchinson on `cur`, with the variance proxy
        # 2 / l0 * mean(||cur g||^2).
        approximation = 0
        N = 0

        for _ in range(l0):
            g = RademacherVector(n)
            z = cur(g)
            approximation += group_product(z, g).cpu().item()
            N += group_product(z, z).cpu().item()

        approximation /= l0
        N /= l0
        variance = 2 / l0 * N
        return approximation, variance

    # Later steps need only scalar inner products, so reduce each sample at
    # once instead of keeping l parameter-sized copies of g, prev(g), cur(g).
    N = 0  # mean of <prev g, prev g>
    M = 0  # mean of <cur g, cur g>
    C = 0  # mean of <cur g, prev g>
    g_dot_w = []
    g_dot_z = []
    for _ in range(l):
        g = RademacherVector(n)
        z = prev(g)
        w = cur(g)
        N += group_product(z, z).cpu().item() / l
        M += group_product(w, w).cpu().item() / l
        C += group_product(w, z).cpu().item() / l
        g_dot_w.append(group_product(g, w).cpu().item())
        g_dot_z.append(group_product(g, z).cpu().item())

    denominator = l * var + 2 * N
    if denominator == 0:
        # Degenerate case (var == 0 and prev g == 0 for every sample):
        # gamma = 1 drops the previous estimate, leaving plain Hutchinson on cur.
        gamma = 1.0
    else:
        gamma = 1 - (2 * C) / denominator

    t = (1 - gamma) * last_app
    for j in range(l):
        t += (g_dot_w[j] - (1 - gamma) * g_dot_z[j]) / l

    variance = (1 - gamma)**2 * var + 2 / l * (M + (1 - gamma)**2 * N - 2 * (1 - gamma) * C)
    return t, variance

def ParameterFreeDeltaShift(oracles, n: int, l0: int, l: int):
    """
        Parameter-free DeltaShift over a whole sequence of matrices.

        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: indexed per oracle (n[i] is forwarded to RademacherVector for
               step i), despite the int annotation
            l0: number of samples for the first matrix
            l: sample budget per later matrix; it is halved, since every
               sample on a later step queries two oracles (previous and current)
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
        Raises:
            ValueError: if there is more than one oracle and l // 2 is not positive
    """
    per_step = l // 2
    # Without at least one sample per later step the update divides by zero.
    if len(oracles) > 1 and per_step <= 0:
        raise ValueError("l must be at least 2 when several oracles are given, got %r" % (l,))

    # Step 0 is the initial branch of DeltShift; prev / var / last_app are
    # not read there.
    estimate, variance = DeltShift(None, oracles[0], 0, 0, 0, n[0], l0, per_step)
    approximation = [estimate]

    # Every later step is exactly one DeltShift update from the previous matrix.
    for i in range(1, len(oracles)):
        estimate, variance = DeltShift(oracles[i - 1], oracles[i], variance, i,
                                       approximation[i - 1], n[i], l0, per_step)
        approximation.append(estimate)

    return approximation

def write_array(arr, file_name):
    """
    Write the items of arr to file_name on a single line.

    Each item is converted with str() and followed by one space; the line is
    terminated by a newline. An existing file is overwritten.
    """
    # The context manager closes the file even if str() or write() raises.
    with open(file_name, 'w') as f:
        for item in arr:
            f.write(str(item) + " ")
        f.write("\n")

In [None]:

def CheckAlgorithms(oracles, n, correct_ans, title = None):
    """
        Run the three estimators on a sequence of matrices and write their
        errors, relative to max(correct_ans), to text files.

        Input:
            oracles: list of oracle to compute matrix-vector multiplication
            n: indexed per oracle; n[i] is forwarded to draw the random vectors
            correct_ans: list of the trace of the given matrixes
            title: currently unused
        Output:
            - (writes 'Hutchinson.txt', 'Deltshift.txt' and 'Deltshiftrest.txt')
    """

    l0 = 100
    l = 50

    correct_ans = np.array(correct_ans)
    scale = max(correct_ans)

    simpl_hutchinson = np.array(SimpleHutchinson(oracles, n, l))
    write_array(abs(correct_ans - simpl_hutchinson) / scale, 'Hutchinson.txt')

    print("1")

    # Guard the divisor so a single oracle does not divide by zero.
    later_steps = max(len(oracles) - 1, 1)
    delta_shift = np.array(ParameterFreeDeltaShift(oracles, n, l0, l - 50 // later_steps))
    write_array(abs(correct_ans - delta_shift) / scale, 'Deltshift.txt')

    print("2")

    # DeltaShiftRestart handles one step at a time, so walk the sequence and
    # feed each estimate into the next step.
    restart_estimates = []
    estimate = 0
    for idx, cur in enumerate(oracles):
        prev = oracles[idx - 1] if idx > 0 else None  # unused on the restart at idx == 0
        estimate = DeltaShiftRestart(prev, cur, idx, estimate, n[idx], l0, l - 50 // 19, 20)
        restart_estimates.append(estimate)
    delta_shift_r = np.array(restart_estimates)
    write_array(abs(correct_ans - delta_shift_r) / scale, 'Deltshiftrest.txt')


In [None]:

# Track the Hessian trace of the loss over consecutive training batches with
# four estimators and log them side by side.
import datetime


# get the model
model = ptcv_get_model("resnet20_cifar10", pretrained=True)
# change the model to eval mode to disable running stats update
model.eval()

# create loss function
criterion = torch.nn.CrossEntropyLoss()

# get dataset
train_loader, test_loader = getData()

# Not used by this cell; kept so that the notebook-level names stay defined.
computable = []
oracles = []
n = []

correct_ans = []   # one (a, b, c, d) tuple per processed batch
prev = None        # hessian object of the previous batch
varience = 0       # variance estimate carried between DeltShift steps
last_app = 0       # previous DeltShift estimate
last_appr = 0      # previous DeltaShiftRestart estimate
results_path = "correct.txt"

model = model.cuda()


def _hvp_oracle(comp):
    """Return the map x -> H x for the Hessian held by the hessian object comp."""
    return lambda x: hessian_vector_product(comp.gradsH, comp.params, x)


try:
    for i, (inputs, targets) in enumerate(train_loader):
        if i > 100:  # batches 0..100 are processed
            break
        hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=True)

        now = datetime.datetime.now()

        cur_oracle = _hvp_oracle(hessian_comp)
        # While prev is None (i == 0) this oracle is never called: both
        # estimators take their initial / restart branch on that step.
        prev_oracle = _hvp_oracle(prev)

        # a: reference estimate with 1000 samples, b: plain Hutchinson with 50
        a = Hutchinson(cur_oracle, hessian_comp.params, 1000)
        b = Hutchinson(cur_oracle, hessian_comp.params, 50)
        # DeltShift(prev, cur, var, iter, last_app, n, l0, l)
        c, var = DeltShift(prev_oracle, cur_oracle,
                           varience, i, last_app, hessian_comp.params, 100, 50)
        # DeltaShiftRestart(prev, cur, iter, last_app, n, l0, l, q)
        d = DeltaShiftRestart(prev_oracle, cur_oracle,
                              i, last_appr, hessian_comp.params, 100, 50 - (50 // 19), 20)
        prev = hessian_comp
        last_app = c
        last_appr = d
        varience = var

        correct_ans.append((a, b, c, d))
        print(a, b, c, d, datetime.datetime.now() - now)
finally:
    # Each batch takes minutes; persist what was computed even when the run
    # is interrupted or fails part-way.
    print(correct_ans)
    write_array(correct_ans, results_path)

# The download helper only exists on Google Colab.
try:
    from google.colab import files
except ImportError:
    files = None
if files is not None:
    files.download(results_path)


Files already downloaded and verified
1016.102216430664 1004.3760083007812 1037.2214050292969 0:01:37.344413
1657.5451815185547 1561.9753857421874 1732.5197295675084 0:03:08.085710
2316.594673828125 2210.121203613281 2366.6816897662798 0:04:42.468303
3652.5810687255857 3725.522126464844 3648.8112411813004 0:06:17.149740
4603.965220947266 4681.194135742187 4551.139233795997 0:07:51.899252
5190.488402832031 5184.853696289062 5129.135526561055 0:09:26.159801
6018.601088623047 5916.0640234375 5945.104070452603 0:11:00.347810
6955.381074707031 6503.560244140625 6842.339725218982 0:12:34.340973
7556.683923339844 7247.705903320312 7507.508256970812 0:14:08.517283
8243.149080566407 8041.317470703125 8237.571609283194 0:15:44.177495
9120.014256347657 8869.059091796875 8892.109327512677 0:17:17.173275


KeyboardInterrupt: 