# Implementation of IHVP w.r.t. Frozen Network

Gradient와 Hessian을 구할 때 대상 weight의 범위를 한정하여 inverse HVP(IHVP)를 구함.

앞단 network를 freeze 시키는 데에는 두 가지 이유가 있음.

1. Convexity
    - Network가 깊어질수록 convexity가 망가질 가능성이 있음.
    - 때문에 CG, NCG, SE 방법론을 사용할 때 발산하는 등 여러 문제가 발생함.
    - 이를 해결하기 위한 극단적인 예시로는 feature weight를 전부 고정하고 최종 layer만을 사용하여, convexity가 보장된 logistic regression 문제로 만드는 방법이 있음.
    - 보장된 건 아니지만, 비슷한 맥락으로 네트워크의 weight를 고정시키면 좀 더 안정적으로 IHVP를 얻을 가능성이 있음.
2. Computational Complexity
    - Weight가 많을수록 계산이 복잡해지고, precision loss가 발생할 가능성이 늘어남.

In [1]:
import cntk as C
from cntk.device import try_set_default_device, gpu
try_set_default_device(gpu(0))

import numpy as np
import time

from torch.utils.data import DataLoader

In [20]:
# Hessian Vector Product

def grad_inner_product(grad1, grad2):
    """Return the inner product of two dictionary-format gradients.

    For every key of ``grad1`` the element-wise product with the entry of
    ``grad2`` stored under the same key is summed, and the per-key sums are
    accumulated into a single scalar. ``grad2`` must contain every key of
    ``grad1``.
    """
    total = 0
    for key, lhs in grad1.items():
        total += np.sum(np.multiply(lhs, grad2[key]))
    return total

def weight_update(w, v, r):
    """Shift every parameter in ``w`` in place by ``r`` times its entry in ``v``.

    w: iterable of parameter objects exposing a ``value`` attribute
       (e.g. the weights of a neural network)
    v: mapping from each parameter to its update direction
       (e.g. a gradient dictionary)
    r: scalar multiplier applied to the direction
    """
    for param in w:
        step = r * v[param]
        param.value += step

def HVP(y, x, v, r=1e-2):
    """Approximate the Hessian-vector product H v by central differences.

    The gradient of ``y`` is evaluated with the parameters shifted by
    ``+r * v`` and by ``-r * v``; the difference of the two gradients divided
    by ``2 * r`` is returned.

    y: scalar function to be differentiated; must provide ``grad(x, wrt=...)``
       (e.g. a cross entropy loss)
    x: feed dictionary for the network
       (e.g. {model.X: image_batch, model.y: label_batch})
    v: dictionary mapping each parameter to differentiate to the matching
       slice of the vector multiplied by the Hessian (e.g. g(z_test)).
       Only the parameters listed here are perturbed and differentiated.
    r: finite-difference step size

    Returns a dictionary keyed like the gradient with the approximate H v.
    Raises TypeError when ``x`` is not a dictionary.
    """
    if not isinstance(x, dict):
        raise TypeError("Input of HVP is wrong. this should be dictionary")

    params = tuple(v)

    # Snapshot the weights so they can be restored exactly; undoing the shift
    # arithmetically (+r, -2r, +r) accumulates floating-point drift over the
    # many HVP calls an iterative solver makes.
    originals = {p: np.array(p.value, copy=True) for p in params}

    def _shifted_grad(scale):
        # Evaluate the gradient of y at (original weights + scale * v).
        for p in params:
            base = originals[p]
            p.value = (base + scale * np.asarray(v[p])).astype(base.dtype)
        grads = y.grad(x, wrt=params)
        # grad() hands back a bare array instead of a dictionary when only one
        # parameter is requested; normalise so both cases look the same.
        if not isinstance(grads, dict) and len(params) == 1:
            grads = {params[0]: grads}
        return grads

    try:
        g_plus = _shifted_grad(+r)
        g_minus = _shifted_grad(-r)
    finally:
        # Put the weights back even if a gradient evaluation failed.
        for p in params:
            p.value = originals[p]

    return {ks: (g_plus[ks] - g_minus[ks]) / (2 * r) for ks in g_plus}

# Conjugate Gradient

from scipy.optimize import fmin_cg

def dic2vec(dic):
    """Flatten a dictionary of arrays into one 1D vector.

    The values are flattened one by one and joined in the dictionary's
    iteration order (e.g. gradient of a network -> 1D vector).
    """
    flat_parts = [arr.ravel() for arr in dic.values()]
    return np.concatenate(flat_parts)

def vec2dic(vec, fmt):
    """Split a 1D vector into a dictionary of arrays laid out like ``fmt``.

    vec: 1D array holding the flattened values back to back
    fmt: mapping from key to target shape, i.e.
         fmt = {key: val.shape for (key, val) in dict}

    Consecutive slices of ``vec`` are taken in the iteration order of ``fmt``
    and reshaped to the shape stored under each key.
    """
    dic = {}
    offset = 0
    for key, shape in fmt.items():
        # np.prod(()) is the float 1.0; force an int so scalar-shaped entries
        # still produce valid slice bounds.
        size = int(np.prod(shape))
        dic[key] = vec[offset:offset + size].reshape(shape)
        offset += size

    return dic

def get_inverse_hvp_cg(model, y, v, data_set, **kwargs):
    """Solve for H^-1 v with conjugate gradient.

    Minimises the quadratic 0.5 * x^T H x - v^T x with ``fmin_cg``, where the
    product H x is the dataset-averaged Hessian-vector product (plus an
    optional damping term).

    model: network object exposing the input variables ``X`` and ``y``
    y: scalar function whose Hessian is used (e.g. model.loss)
    v: dictionary mapping parameters to arrays (e.g. v_test); also used as
       the starting point of the solver
    data_set: dataset yielding (image, label) pairs
    kwargs: hyperparameters for conjugate gradient
        batch_size (default 50): mini-batch size used for the HVP
        damping (default 0.0): coefficient of the ``damping * x`` term
        maxiter (default 50): maximum number of CG iterations
        num_workers (default 6): worker processes of the DataLoader

    Returns a dictionary with the keys and shapes of ``v``.
    Raises ValueError when ``data_set`` yields no samples.
    """
    batch_size = kwargs.pop('batch_size', 50)
    damping = kwargs.pop('damping', 0.0)
    # fmin_cg takes an integer iteration cap; float inputs such as 5e1 are coerced.
    maxiter = int(kwargs.pop('maxiter', 50))
    num_workers = kwargs.pop('num_workers', 6)

    dataloader = DataLoader(data_set, batch_size, shuffle=True, num_workers=num_workers)

    # Layout of v, needed to convert between flat vectors and dictionaries.
    fmt = {key: val.shape for (key, val) in v.items()}
    v_flat = dic2vec(v)

    def HVP_minibatch_val(y, v):
        # Average Hessian-vector product over the whole dataset.
        # y: scalar function output of the neural network (e.g. model.loss)
        # v: dictionary-format vector to be multiplied by the Hessian
        #
        # The division by the batch size in the original code implies HVP
        # yields a per-batch sum, so the batch results are added up and divided
        # once by the total sample count. Averaging per-batch means would
        # over-weight a smaller final batch.
        hvp_sum = None
        n_samples = 0

        for img, lb in dataloader:
            img = img.numpy()
            lb = lb.numpy()
            hvp = HVP(y, {model.X: img, model.y: lb}, v)
            if hvp_sum is None:
                # float64 accumulator keeps precision across many batches
                hvp_sum = {ks: np.array(val, dtype=np.float64) for ks, val in hvp.items()}
            else:
                for ks, val in hvp.items():
                    hvp_sum[ks] += val
            n_samples += img.shape[0]

        if n_samples == 0:
            raise ValueError("data_set is empty; cannot average the Hessian-vector product")

        return {ks: hvp_sum[ks] / n_samples + damping * v[ks] for ks in hvp_sum}

    def get_fmin_loss(x):
        # Quadratic objective 0.5 * x^T H x - v^T x for a flat vector x.
        x_dic = vec2dic(x, fmt)
        hvp_val = HVP_minibatch_val(y, x_dic)

        return 0.5 * grad_inner_product(hvp_val, x_dic) - grad_inner_product(v, x_dic)

    def get_fmin_grad(x):
        # Gradient of the objective, H x - v, as a flat vector.
        x_dic = vec2dic(x, fmt)
        hvp_val = HVP_minibatch_val(y, x_dic)

        return dic2vec(hvp_val) - v_flat

    fmin_results = fmin_cg(f=get_fmin_loss, x0=v_flat, fprime=get_fmin_grad, maxiter=maxiter)

    return vec2dic(fmin_results, fmt)

In [None]:
def IF_val(net, ihvp, data_set):
    """Calculate the influence function value of every sample in ``data_set``.

    This is done sample-wise (batch size 1), since the gradient operation
    sums over everything in the feed dictionary.

    net: network object exposing ``X``, ``y`` and ``loss``
    ihvp: inverse Hessian-vector product (dictionary keyed by parameter);
          the gradient is taken w.r.t. exactly these parameters
    data_set: dataset fed sample by sample to the gradient operation

    Returns a list with one value per sample: the inner product of ``ihvp``
    with that sample's loss gradient, divided by the number of samples.
    """
    params = tuple(ihvp)

    dataloader = DataLoader(data_set, 1, shuffle=False, num_workers=6)
    # With batch size 1 the number of batches equals the number of samples.
    n_samples = len(dataloader)

    IF_list = []
    for img, lb in dataloader:
        img = img.numpy()
        lb = lb.numpy()
        gd = net.loss.grad({net.X: img, net.y: lb}, wrt=params)
        # grad() hands back a bare array instead of a dictionary when only one
        # parameter is requested; wrap it so the inner product can index it.
        if not isinstance(gd, dict) and len(params) == 1:
            gd = {params[0]: gd}
        IF_list.append(grad_inner_product(ihvp, gd) / n_samples)

    return IF_list

In [24]:
# toy example for inverse HVP (CG and SE)

class SimpleNet(object):
    """Toy CNTK network: three stacked 1-unit Dense layers and a squared-error loss."""

    def __init__(self):
        def linear():
            # Fresh 1-unit Dense layer without activation or bias, init=C.uniform(1).
            return C.layers.Dense(1, activation=None, init=C.uniform(1), bias=False)

        # Input variable and the chain X -> h -> h2 -> pred.
        self.X = C.input_variable(shape=(1,))
        self.h = linear()(self.X)
        self.h2 = linear()(self.h)
        self.pred = linear()(self.h2)

        # Target variable and squared error between prediction and target.
        self.y = C.input_variable(shape=(1,))
        self.loss = C.squared_error(self.pred, self.y)
        
class SimpleDataset(object):
    """Minimal indexable dataset pairing images with labels."""

    def __init__(self, images, labels):
        self._images = images
        self._labels = labels

    def __len__(self):
        # The dataset size is the number of images.
        return len(self._images)

    def __getitem__(self, index):
        # Return the (image, label) pair stored at ``index``.
        return self._images[index], self._labels[index]


net = SimpleNet()
params = net.pred.parameters

# One-sample feed dictionary: input 2.0 with target 1.0.
x_feed = {
    net.X: np.array([[2.]], dtype=np.float32),
    net.y: np.array([[1.]], dtype=np.float32),
}

# Show the parameters, their current values and the loss on the sample.
print(params)
for pr in params:
    print(pr, pr.value)
print('loss = \n', net.loss.eval(x_feed))

# The same single sample wrapped as a dataset.
images = np.asarray([[2.]], dtype=np.float32)
labels = np.asarray([[1.]], dtype=np.float32)
train_set = SimpleDataset(images, labels)

p1 = net.h.parameters  # w1
p2 = net.h2.parameters  # w1, w2
print(p2)
p3 = tuple(set(params) - set(p1))  # w2, w3

# HVP of the loss along the all-ones direction over every parameter.
v_p = {p: np.ones_like(p.value) for p in params}
print('hvp_p(w1, w2, w3)\n', HVP(net.loss, x_feed, v_p))

# All-ones direction restricted to p3; note this call differentiates
# net.pred rather than net.loss.
v_p3 = {p: np.ones_like(p.value) for p in p3}
print('hvp_p3(w2, w3)\n', HVP(net.pred, x_feed, v_p3))


(Parameter('W', [], [1 x 1]), Parameter('W', [], [1 x 1]), Parameter('W', [], [1 x 1]))
Parameter('W', [], [1 x 1]) [[-0.52869767]]
Parameter('W', [], [1 x 1]) [[-0.55389237]]
Parameter('W', [], [1 x 1]) [[ 0.2377166]]
loss = 
 [ 0.7409308]
(Parameter('W', [], [1 x 1]), Parameter('W', [], [1 x 1]))
hvp_p(w1, w2, w3)
 {Parameter('W', [], [1 x 1]): array([[ 1.05137527]], dtype=float32), Parameter('W', [], [1 x 1]): array([[ 0.96631497]], dtype=float32), Parameter('W', [], [1 x 1]): array([[ 3.81159782]], dtype=float32)}
hvp_p3(w2, w3)
 {Parameter('W', [], [1 x 1]): array([[ 0.47543347]], dtype=float32), Parameter('W', [], [1 x 1]): array([[ 0.47543347]], dtype=float32)}


In [9]:
# Split the parameters into those of net.h and the remaining entries of params.
p1 = net.h.parameters
p2 = tuple(set(params) - set(p1))
print(p1, p2)

# Gradient of the loss w.r.t. all parameters, then w.r.t. each subset.
v_p, v_p1, v_p2 = [net.loss.grad(x_feed, wrt=group) for group in (params, p1, p2)]
print(v_p, v_p1, v_p2)

(Parameter('W', [], [1 x 2]),) (Parameter('W', [], [2 x 1]),)
{Parameter('W', [], [2 x 1]): array([[ 3.90574503],
       [-5.65007734]], dtype=float32), Parameter('W', [], [1 x 2]): array([[ 4.20328236,  4.17423201]], dtype=float32)} [[ 4.20328236  4.17423201]] [[ 3.90574503]
 [-5.65007734]]
