In [1]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import expm, expm_multiply
from pylanczos import PyLanczos
%matplotlib inline

Реализуем все алгоритмы из статьи [Dynamic trace estimation](https://arxiv.org/abs/2110.13752). Сначала все же начнем с вспомогательных функций

In [2]:
def RademacherVector(n: int):
    """
        Input:
            n: num of components
        Output:
            vec: random vector from Rademacher distribution
    """
    assert n > 0

    vec = np.random.binomial(1, 1/2, (n, 1))
    vec[vec == 0] = -1

    return vec

def RademacherMatrix(shape):
    """
        Input:
            shape: shape of the matrix
        Output:
            mat: random matrix from Rademacher distribution
    """
    mat = np.random.binomial(1, 1/2, shape)
    mat[mat == 0] = -1

    return mat

def Hutchinson(oracle, n: int, l: int):
    """
        Input:
            oracle: oracle for implicit matrix-vector multiplication with A
            n: size of the matrix A
            l: number of iteration to approximate the trace of the matrix A
        Output:
            approximation: approximation to the trace of A
    """

    assert l >= 0

    approximation = 0
    
    for iter in range(l):
        g = RademacherVector(n)
        approximation += g.T @ oracle(g)
    
    return approximation[0, 0] / l

def HutchinsonATA(oracle1, oracle2, n: int, l: int):
    """
        Input:
            oracle1: oracle for implicit matrix-vector multiplication with A
            oracle2: oracle for implicit matrix-vector multiplication with B
            n: size of the matrix A
            l: number of iteration to approximate the trace of the matrix A.T @ B
        Output:
            approximation: approximation to the trace of the matrix A.T @ B
    """

    assert l >= 0

    approximation = 0
    
    for iter in range(l // 2):
        g = RademacherVector(n)
        approximation += (oracle1(g)).T @ oracle2(g)
    
    return approximation[0, 0] / l

def SquaredFrobenius(oracle, n: int, l: int):
    """
        Input:
            oracle: oracle for implicit matrix-vector multiplication with A
            n: size of the matrix A
            l: number of iteration to approximate the frobenius norm of the matrix A
        Output:
            approximation: approximation to the frobenius norm of A
    """

    assert l >= 0

    approximation = 0
    
    for iter in range(l):
        g = oracle(RademacherVector(n))
        approximation += g.T @ g
    
    return approximation[0, 0] / l

def HutchinsonPlusPlus(oracle, n: int, l: int):
    """
        Input:
            oracle: oracle for implicit matrix-vector multiplication with A
            n: size of the matrix A
            l: number of iteration to approximate the trace of the matrix A
        Output:
            approximation: approximation to the trace of A
    """

    S = RademacherMatrix((n, l // 3))

    Q, _ = np.linalg.qr(oracle(S))

    return np.trace(Q.T @ oracle(Q)) + HutchinsonATA((lambda x: x - Q @ (Q.T @ x)), (lambda x: oracle(x - Q @ (Q.T @ x))), n, l - l // 3)

def HutchinsonPlusPlusWithQ(oracle, n: int, l: int):
    """
        Input:
            oracle: oracle for implicit matrix-vector multiplication with A
            n: size of the matrix A
            l: number of iteration to approximate the trace of the matrix A
        Output:
            approximation: approximation to the trace of A
            Q: matrix appended in the algorithm
    """
    
    S = RademacherMatrix((n, l // 3))

    Q, _ = np.linalg.qr(oracle(S))

    return np.trace(Q.T @ oracle(Q)) + HutchinsonATA((lambda x: x - Q @ (Q.T @ x)), (lambda x: oracle(x - Q @ (Q.T @ x))), n, l - l // 3), Q

Тут напишем для сравнения алгоритмы, не использующие информацию о близости соседних матриц.

In [3]:
def SimpleHutchinson(oracles, n: int, l: int):
    """
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l: number of iteration to approximate the trace of the matrixes A1, ..., Am
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    
    for i in range(len(oracles)):
        approximation.append(Hutchinson(oracles[i], n, l))
    
    return approximation

def SimpleHutchinsonPlusPlus(oracles, n: int, l: int):
    """
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l: number of iteration to approximate the trace of the matrixes A1, ..., Am
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    
    for i in range(len(oracles)):
        approximation.append(HutchinsonPlusPlus(oracles[i], n, l))
    
    return approximation

Напишем все возможные алгоритмы DeltaShift на основе обычного Хатчинсона.

In [4]:
def DeltaShift(oracles, n: int, l0: int, l: int, gamma: float):
    """
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l0: number of iteration to approximate the trace of the matrix A1
            l: number of iteration to approximate the trace of the other matrixes
            gamma: damping factor
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    approximation.append(Hutchinson(oracles[0], n, l0))
    
    for i in range(1, len(oracles)):
        approximation.append((1 - gamma) * approximation[i - 1]
                             +
                             Hutchinson((lambda x: oracles[i](x) - (1 - gamma) * oracles[i - 1](x)), n, l))
    
    return approximation

In [5]:
def ParameterFreeDeltaShift(oracles, n: int, l0: int, l: int):
    """
        Parameter free version of DeltaShift algorithm, gamma estimates inplace
        
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l0: number of iteration to approximate the trace of the matrix A1
            l: number of iteration to approximate the trace of the other matrixes
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = [0]
    N = 0

    for _ in range(l0):
        g = RademacherVector(n)
        z = oracles[0](g)
        approximation[0] += (g.T @ z)[0, 0]
        N += (z.T @ z)[0, 0]

    approximation[0] /= l0
    N /= l0
    variance = 2 / l0 * N

    l //= 2
    
    for i in range(1, len(oracles)):
        z = []
        w = []
        g = []
        for j in range(l):
            cur = RademacherVector(n)
            z.append(oracles[i - 1](cur))
            w.append(oracles[i](cur))
            g.append(cur)

        N = 0
        M = 0
        C = 0
        for j in range(l):
            N += (z[j].T @ z[j])[0, 0] / l
            M += (w[j].T @ w[j])[0, 0] / l
            C += (w[j].T @ z[j])[0, 0] / l

        gamma = 1 - (2 * C) / (l * variance + 2 * N)

        t = (1 - gamma) * approximation[i - 1]
        for j in range(l):
            t += (g[j].T @ (w[j] - (1 - gamma) * z[j]))[0, 0] / l

        approximation.append(t)
        variance = (1 - gamma)**2 * variance + 2 / l * (M + (1 - gamma)**2 * N - 2 * (1 - gamma) * C)
    
    return approximation

In [6]:
def DeltaShiftRestart(oracles, n: int, l0: int, l: int, q: int):
    """
        DeltaShift algorithm, but restart every q iterations
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l0: number of iteration to approximate the trace of the matrix A1
            l: number of iteration to approximate the trace of the other matrixes
            q: number of iterations to restart
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    
    for i in range(len(oracles)):
        if i % q == 0:
            approximation.append(Hutchinson(oracles[i], n, l0))
        else:
            approximation.append(approximation[i - 1] +
                                 Hutchinson((lambda x: oracles[i](x) - oracles[i - 1](x)), n, l))
    
    return approximation

Теперь напишем все возможные алгоритмы DeltaShift++ на основе обычного Хатчинсона++.

In [7]:
def DeltaShiftPlusPlus(oracles, n: int, l0: int, l: int, gamma: float):
    """
        Use Hutch++ instead of simple Hutchinson
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l0: number of iteration to approximate the trace of the matrix A1
            l: number of iteration to approximate the trace of the other matrixes
            gamma: damping factor
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    approximation.append(HutchinsonPlusPlus(oracles[0], n, l0))
    
    for i in range(1, len(oracles)):
        approximation.append(gamma * HutchinsonPlusPlus(oracles[i], n, l // 2)
                             +
                             (1 - gamma) *
                             (approximation[i - 1] + HutchinsonPlusPlus((lambda x: oracles[i](x) - oracles[i - 1](x)), n, l // 2)))
    
    return approximation

In [8]:
def DeltaShiftPlusPlusRestart(oracles, n: int, l0: int, l: int, q: int):
    """
        DeltaShift++ algorithm, but restart every q iterations
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l0: number of iteration to approximate the trace of the matrix A1
            l: number of iteration to approximate the trace of the other matrixes
            q: number of iterations to restart
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    
    for i in range(len(oracles)):
        if i % q == 0:
            approximation.append(HutchinsonPlusPlus(oracles[i], n, l0))
        else:
            approximation.append(approximation[i - 1] + HutchinsonPlusPlus((lambda x: oracles[i](x) - oracles[i - 1](x)), n, l // 2))
    
    return approximation

In [9]:
def ParameterFreeDeltaShiftPlusPlus(oracles, n: int, l0: int, l: int):
    """
        Parameter free version of DeltaShift algorithm, gamma estimates inplace
        
        Input:
            oracles: oracles for implicit matrix-vector multiplication with A1, ..., Am
            n: matrix size
            l0: number of iteration to approximate the trace of the matrix A1
            l: number of iteration to approximate the trace of the other matrixes
        Output:
            approximation: list of approximations to the trace of A1, ..., Am
    """
    assert len(oracles) > 0

    approximation = []
    appr, Q = HutchinsonPlusPlusWithQ(oracles[0], n, l0)
    approximation.append(appr)

    variance = 2 / l0 * SquaredFrobenius((lambda x: oracles[0](x) - Q @ (Q.T @ oracles[0](x))), n, l0)

    
    for i in range(1, len(oracles)):
        appr1, Q1 = HutchinsonPlusPlusWithQ(oracles[i], n, l // 2)
        appr2, Q2 = HutchinsonPlusPlusWithQ((lambda x: oracles[i](x) - oracles[i - 1](x)), n, l // 2)
        K_A= SquaredFrobenius((lambda x: oracles[i](x) - Q1 @ (Q1.T @ oracles[i](x))), n, l // 2)
        K_delta = SquaredFrobenius((lambda x:
                                (oracles[i](x) - oracles[i - 1](x))
                                - Q2 @ (Q2.T @ (oracles[i](x) - oracles[i - 1](x)))
                                ), n, l // 2)

        gamma = (8 * K_delta + l // 2 * variance) / (8 * K_A + l // 2 * variance + 8 * K_delta)

        approximation.append(gamma * appr1
                             +
                             (1 - gamma) *
                             (approximation[i - 1] + appr2))
        variance = gamma**2 * 16 * K_A / l + (1  - gamma)**2 * (variance + 16 * K_delta / l)
    
    return approximation

Перейдем к тестированию алгоритмов. Возьмем граф города Иннополис и посмотрим как он менялся со временем.

In [10]:
def CheckAlgorithms(oracles, n, correct_ans, title = None):
    """
        Generate graphic with relative error
        Input:
            oracles: list of oracle to compute matrix-vector multiplication
            n: matrix size
            correct_ans: list of the trace of the given matrixes
            title: title of the graphic
        Output:
            -
    """

    l0 = 100
    l = 50

    
    x = np.arange(len(oracles))
    correct_ans = np.array(correct_ans)

    simpl_hutchinson = np.array(SimpleHutchinson(oracles, n, l))
    plt.plot(x, abs(correct_ans - simpl_hutchinson) / max(correct_ans), label="Hutchinson")
    simpl_hutchinsonpp = np.array(SimpleHutchinsonPlusPlus(oracles, n, l))
    plt.plot(x,  abs(correct_ans - simpl_hutchinsonpp) / max(correct_ans), label="Hutchinson++")

    print("1")

    delta_shift = np.array(ParameterFreeDeltaShift(oracles, n, l0, l - 50 // 19))
    plt.plot(x, abs(correct_ans - delta_shift) / max(correct_ans), label="DeltaShift")
    delta_shift_r = np.array(DeltaShiftRestart(oracles, n, l0, l - 50 // 19, 20))
    plt.plot(x,  abs(correct_ans - delta_shift_r) / max(correct_ans), label="DeltaShift Restart")

    print("2")

    delta_shiftpp = np.array(ParameterFreeDeltaShiftPlusPlus(oracles, n, l0, l - 50 // 19))
    plt.plot(x, abs(correct_ans - delta_shiftpp) / max(correct_ans), label="DeltaShift++")
    delta_shiftpp_r = np.array(DeltaShiftPlusPlusRestart(oracles, n, l0, l - 50 // 19, 20))
    plt.plot(x,  abs(correct_ans - delta_shiftpp_r) / max(correct_ans), label="DeltaShift++ Restart")

    plt.legend()
    plt.xlabel("Time step(i)")
    plt.ylabel("$\\frac{|tr(A_i^3) - t_i|}{\max_i tr(A_i^3)}$")

    if title != None:
        plt.title(title)

    plt.show()

In [11]:
oracles = []
correct_ans = []
n = 100

mat = np.random.random((2000, 2000))

for i in range(n):
    oracles.append((lambda x: mat @ x))
    correct_ans.append(np.trace(mat))
    mat += 5 * np.exp(-5) * RademacherVector(2000) @ np.random.random((1, 2000))

print(correct_ans)

# CheckAlgorithms(oracles, 2000, correct_ans, "Very small perturbation")

[987.1539785096313, 987.298315448448, 987.5692829348611, 987.5580724142576, 988.1985167713103, 986.8388245256483, 986.5726798713958, 986.2245726845852, 986.6973692335946, 985.6355406721506, 984.6405891204512, 983.5608013400863, 985.40444892322, 984.7850205301846, 984.1285644351038, 983.9632105760058, 984.3381662758758, 985.8186953112939, 985.6866961344067, 983.9527670513942, 984.3605190237101, 983.5996428764108, 983.1447904589936, 984.952076413195, 984.6234101270534, 985.215294386016, 985.6801091670517, 986.2901755640581, 986.4830723684111, 986.554583283631, 986.7757618940561, 987.698387504001, 987.853010106202, 987.6783797527366, 987.1463097468295, 986.2218350864389, 985.7487351944203, 986.2082296038964, 985.4078759266165, 984.1705154012529, 983.3225526717958, 983.9639058094808, 982.8374598304224, 981.982727250434, 983.0646238023911, 983.3864218703688, 983.3790084828588, 983.3489252853067, 983.7385591005634, 984.5281759820216, 983.5626228945044, 983.0483136134052, 982.8726625389191, 9

In [12]:
import torch 
from torchvision import datasets, transforms
from data import * # get the dataset
from pyhessian import hessian # Hessian computation
from pytorchcv.model_provider import get_model as ptcv_get_model # model

In [13]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [14]:
#*
# @file Different utility functions
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
# This file is part of PyHessian library.
#
# PyHessian is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# PyHessian is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with PyHessian.  If not, see <http://www.gnu.org/licenses/>.
#*

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable


def getData(name='cifar10', train_bs=128, test_bs=1000):
    """
    Get the dataloader
    """
    if name == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = datasets.CIFAR10(root='../data',
                                    train=True,
                                    download=True,
                                    transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=True)

        testset = datasets.CIFAR10(root='../data',
                                   train=False,
                                   download=False,
                                   transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False)
    if name == 'cifar10_without_dataaugmentation':
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = datasets.CIFAR10(root='../data',
                                    train=True,
                                    download=True,
                                    transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=True)

        testset = datasets.CIFAR10(root='../data',
                                   train=False,
                                   download=False,
                                   transform=transform_test)
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False)

    return train_loader, test_loader


def test(model, test_loader, cuda=True):
    """
    Get the test performance
    """
    model.eval()
    correct = 0
    total_num = 0
    for data, target in test_loader:
        if cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)
        pred = output.data.max(
            1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        total_num += len(data)
    print('testing_correct: ', correct / total_num, '\n')
    return correct / total_num


In [15]:
def hessian_vector_product(gradsH, params, v):
    """
    compute the hessian vector product of Hv, where
    gradsH is the gradient at the current point,
    params is the corresponding variables,
    v is the vector.
    """
    hv = torch.autograd.grad(gradsH,
                             params,
                             grad_outputs=v,
                             only_inputs=True,
                             retain_graph=True)
    return hv
def group_product(xs, ys):
    """
    the inner product of two lists of variables xs,ys
    :param xs:
    :param ys:
    :return:
    """
    return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)])

In [34]:
# get the model 
model = ptcv_get_model("preresnet20_cifar10", pretrained=True)
# change the model to eval mode to disable running stats upate
model.eval()

# create loss function
criterion = torch.nn.CrossEntropyLoss()

# get dataset 
train_loader, test_loader = getData()

# for illustrate, we only use one batch to do the tutorial
for inputs, targets in train_loader:
    hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=False)
    # a, b = hessian_comp.dataloader_hv_product(np.zeros(100))
    v = [
        torch.randint_like(p, high=2)
        for p in hessian_comp.params
    ]
    # b, _ = hessian_comp.eigenvalues()
    # print(b)
    Hv = hessian_vector_product(hessian_comp.gradsH, hessian_comp.params, v)
    print(group_product(Hv, Hv).cpu().item())


Files already downloaded and verified
546863.0625


KeyboardInterrupt: 