In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import time
from IPython.display import display, Latex
from torch.autograd.functional import jacobian

from torch.distributions.multivariate_normal import MultivariateNormal
from datetime import datetime

In [2]:
torch.set_default_dtype(torch.float64)

In [3]:
# Create model of linear NN with L hidden layers
# input dim = d, hidden dim = m, output dim = k
class Linear_NN(nn.Module):
    def __init__(self,d,m,k,L):
        """
            d: input dimension
            m: hidden layer dimension 
            k: output dimension
            L: number of hidden layers
        """
        super().__init__()
        
        self.L = L
        self.lin_out = nn.Linear(m, k, bias=False)
        self.lin_in = nn.Linear(d, m, bias=False)
        
        self.lin_hidden = nn.ModuleList([nn.Linear(m, m, bias=False) for i in range(self.L)])
        
        
    def forward(self, xb):
        xb = self.lin_in(xb)
        
        for i in range(self.L):
            xb = self.lin_hidden[i](xb)
            
        xb = self.lin_out(xb)
        
        return xb
    
    def init_weights(self, init_type):
        if init_type == 'kaiming_normal':
            torch.nn.init.kaiming_normal_(self.lin_in.weight, nonlinearity='linear')
            torch.nn.init.kaiming_normal_(self.lin_out.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.kaiming_normal_(self.lin_hidden[i].weight, nonlinearity='linear')
        elif init_type == 'kaiming_uniform':
            torch.nn.init.kaiming_uniform_(self.lin_in.weight, nonlinearity='linear')
            torch.nn.init.kaiming_uniform_(self.lin_out.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.kaiming_uniform_(self.lin_hidden[i].weight, nonlinearity='linear')
        elif init_type == 'xavier_normal':
            torch.nn.init.xavier_normal_(self.lin_in.weight)
            torch.nn.init.xavier_normal_(self.lin_out.weight)
            for i in range(self.L):
                torch.nn.init.xavier_normal_(self.lin_hidden[i].weight)
        elif init_type == 'xavier_uniform':
            torch.nn.init.xavier_uniform_(self.lin_in.weight, nonlinearity='linear')
            torch.nn.init.xavier_uniform_(self.lin_out.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.xavier_uniform_(self.lin_hidden[i].weight, nonlinearity='linear')
        else:
            print('Unknown initialization. Using Kaiming normal initialization')
            torch.nn.init.kaiming_normal_(self.lin1.weight, nonlinearity='linear')
            torch.nn.init.kaiming_normal_(self.lin2.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.kaiming_normal_(self.lin_hidden[i].weight, nonlinearity='linear')

In [4]:
# Create model of linear NN with L hidden layers
# input dim = d, hidden dim = m, output dim = k
class ReLU_NN(nn.Module):
    def __init__(self,d,m,k,L):
        """
            d: input dimension
            m: hidden layer dimension 
            k: output dimension
            L: number of hidden layers
        """
        super().__init__()
        
        self.L = L
        self.relu = nn.ReLU()
        self.lin_out = nn.Linear(m, k, bias=False)
        self.lin_in = nn.Linear(d, m, bias=False)
        
        self.lin_hidden = nn.ModuleList([nn.Linear(m, m, bias=False) for i in range(self.L)])
        
        self.sequential = nn.Sequential(self.lin_in)
        
        for i in range(self.L):
            self.sequential.append(self.relu)
            self.sequential.append(self.lin_hidden[i])
        
        self.sequential.append(self.relu)
        self.sequential.append(self.lin_out)                
        
        
    def forward(self, xb):
#         xb = self.lin_in(xb)
                
#         for i in range(self.L):
#             xb = self.relu(xb)
            
#             xb = self.lin_hidden[i](xb)
        
#         xb = self.relu(xb)
                        
#         xb = self.lin_out(xb)
        xb = self.sequential(xb)
        
        return xb
    
    def init_weights(self, init_type):
        if init_type == 'kaiming_normal':
            torch.nn.init.kaiming_normal_(self.lin_in.weight, nonlinearity='linear')
            torch.nn.init.kaiming_normal_(self.lin_out.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.kaiming_normal_(self.lin_hidden[i].weight, nonlinearity='linear')
        elif init_type == 'kaiming_uniform':
            torch.nn.init.kaiming_uniform_(self.lin_in.weight, nonlinearity='linear')
            torch.nn.init.kaiming_uniform_(self.lin_out.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.kaiming_uniform_(self.lin_hidden[i].weight, nonlinearity='linear')
        elif init_type == 'xavier_normal':
            torch.nn.init.xavier_normal_(self.lin_in.weight)
            torch.nn.init.xavier_normal_(self.lin_out.weight)
            for i in range(self.L):
                torch.nn.init.xavier_normal_(self.lin_hidden[i].weight)
        elif init_type == 'xavier_uniform':
            torch.nn.init.xavier_uniform_(self.lin_in.weight, nonlinearity='linear')
            torch.nn.init.xavier_uniform_(self.lin_out.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.xavier_uniform_(self.lin_hidden[i].weight, nonlinearity='linear')
        else:
            print('Unknown initialization. Using Kaiming normal initialization')
            torch.nn.init.kaiming_normal_(self.lin1.weight, nonlinearity='linear')
            torch.nn.init.kaiming_normal_(self.lin2.weight, nonlinearity='linear')
            for i in range(self.L):
                torch.nn.init.kaiming_normal_(self.lin_hidden[i].weight, nonlinearity='linear')

In [120]:
# Author: ludwigwinkler
# Source: https://discuss.pytorch.org/t/get-gradient-and-jacobian-wrt-the-parameters/98240/6

import future, sys, os, datetime, argparse
from typing import List, Tuple
import numpy as np
import matplotlib

matplotlib.rcParams["figure.figsize"] = [10, 10]

import torch, torch.nn
from torch import nn
from torch.nn import Sequential, Module, Parameter
from torch.nn import Linear, Tanh, ReLU
import torch.nn.functional as F

Tensor = torch.Tensor
FloatTensor = torch.FloatTensor

torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)

sys.path.append("../../..")  # Up to -> KFAC -> Optimization -> PHD

import copy

cwd = os.path.abspath(os.getcwd())
os.chdir(cwd)

# from Optimization.BayesianGradients.src.DeterministicLayers import GradBatch_Linear as Linear


def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    return them as a tuple as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    '''
        Make params regular Tensors instead of nn.Parameter
    '''
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names

def _set_nested_attr(obj: Module, names: List[str], value: Tensor) -> None:
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)

def load_weights(mod: Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)

def compute_jacobian(model, x):
    '''

    @param model: model with vector output (not scalar output!) the parameters of which we want to compute the Jacobian for
    @param x: input since any gradients requires some input
    @return: either store jac directly in parameters or store them differently

    we'll be working on a copy of the model because we don't want to interfere with the optimizers and other functionality
    '''

    jac_model = copy.deepcopy(model) # because we're messing around with parameters (deleting, reinstating etc)
    all_params, all_names = extract_weights(jac_model) # "deparameterize weights"
    load_weights(jac_model, all_names, all_params) # reinstate all weights as plain tensors

    def param_as_input_func(model, x, param):
        load_weights(model, [name], [param]) # name is from the outer scope
        out = model(x)
        return out

    jacobian=np.zeros(1)    
    for i, (name, param) in enumerate(zip(all_names, all_params)):
        jac = torch.autograd.functional.jacobian(lambda param: param_as_input_func(jac_model, x, param), param,
                             strict=True if i==0 else False, vectorize=False if i==0 else True)
        print(jac.shape)
        j = torch.reshape(jac,(n,k,jac.shape[-1]*jac.shape[-2]))
        print(j.shape)
        if i==0:
            jacobian = j
        else:
            jacobian = torch.cat([jacobian,j],dim=2)
#     print(jacobian.shape)        

    del jac_model # cleaning up
    
    return jacobian

In [137]:
d = 2 # input dimension
m1 = [5] # hidden layer dimension
k = 10 # output dimension
L = [0] # number of hidden layers of dim "m"
    
m_L_config = [] # keep track of the network configuration
num_param = [] # count the number of parameters in the model
Linear_Networks = [] # list of NN with different configurations
ReLU_Networks = [] # list of ReLU NN with different configurations


# initiate linear networks of given depth L[l] with m1 hidden units each
for m in m1:
    for l in L:
        m_L_config.append((m,l))
        Linear_Networks.append(Linear_NN(d,m,k,l))
        ReLU_Networks.append(ReLU_NN(d,m,k,l))
        num_param.append(sum(p.numel() for p in Linear_NN(d,m,k,l).parameters()))
        
print('num parameters: ', num_param)

num parameters:  [60]


In [138]:
n = 100
torch.manual_seed(314159)
x = torch.randn(n,d).requires_grad_()
cov_xx = x.detach().T @ x.detach() / n

In [139]:
x.shape

torch.Size([100, 2])

In [140]:
torch.linalg.eigvalsh(cov_xx)

tensor([0.8384, 1.1571])

In [141]:
Linear_Networks[0].init_weights('kaiming_normal')

jacob = compute_jacobian(Linear_Networks[0],x)
print('Jacobian shape', jacob.shape)

V_kaiming = Linear_Networks[0].lin_in.weight.detach()
W_kaiming = Linear_Networks[0].lin_out.weight.detach()

# calculate the outer Hessian product according to the expression derived by Sidak
H_o_tilde_lin = torch.kron( W_kaiming @ W_kaiming.T, cov_xx ) + \
                torch.kron( torch.eye(k), V_kaiming.T @ V_kaiming @ cov_xx )

torch.Size([100, 10, 10, 5])
torch.Size([100, 10, 50])
torch.Size([100, 10, 5, 2])
torch.Size([100, 10, 10])
Jacobian shape torch.Size([100, 10, 60])


In [142]:
# W_kaiming @ V_kaiming @ x.T.detach()

In [143]:
# Linear_Networks[0].forward(x.detach()).T

In [144]:
H_o_tilde_lin.shape

torch.Size([20, 20])

In [145]:
torch.linalg.matrix_rank(H_o_tilde_lin)

tensor(20)

In [146]:
(torch.linalg.eigvalsh(H_o_tilde_lin))

tensor([1.3715, 1.3715, 1.3715, 1.3715, 1.3715, 1.6049, 1.8692, 2.1158, 3.5034,
        4.6688, 5.2732, 5.2732, 5.2732, 5.2732, 5.2732, 5.5131, 5.7864, 6.0432,
        7.5188, 8.7981])

In [147]:
arr = [jacob[i,:,:] @ jacob[i,:,:].T for i in range(n)]

jac_jac_T = sum(arr)/n
torch.linalg.matrix_rank(jac_jac_T)

tensor(10)

In [148]:
jac_jac_T.shape

torch.Size([10, 10])

In [149]:
torch.linalg.eigvalsh(jac_jac_T)

tensor([ 6.6447,  6.6447,  6.6447,  6.6447,  6.6447,  7.1180,  7.6556,  8.1590,
        11.0222, 13.4668])