# Validate Implicit Fisher Computation

The natural gradient algorithm used in the paper relies on implicitly inverting the Fisher and preconditioning gradients with it. To validate that these computations are correct, the results are compared to natural gradients computed explicitly in small networks.

The first cell below lets you decide which algorithm's correctness you want to verify:
- For 'optimizer' you can choose 'natural' or 'natural_bd' (the latter being the block diagonal version).
- For 'mc_fisher' you can choose 0 or 1. If 1, the Fisher is approximated by sampling one label per input. If 0, the Fisher is computed exactly (summing over all labels).
- For 'network_type' you can choose either 'fc' for a standard fully connected net, or 'conv' for a mix of convolutional and fully connected layers.

In [1]:
# Configuration: choose which algorithm's correctness to verify.
optimizer = 'natural'  # 'natural' or 'natural_bd' (block-diagonal version)
mc_fisher = 1  # 1: sample one label per input (MC Fisher); 0: full Fisher
network_type = 'fc'  # 'fc' (fully connected) or 'conv' (conv + fc layers)

# Fail fast on a typo. Later cells branch on these values with plain `if`
# statements, so an unsupported value would otherwise fail late and
# confusingly (e.g. `X` / `net` never being defined for an unknown
# network_type, or an unknown optimizer being compared to the full natural
# gradient).
if optimizer not in ('natural', 'natural_bd'):
    raise ValueError(
        f"optimizer must be 'natural' or 'natural_bd', got {optimizer!r}")
if mc_fisher not in (0, 1):
    raise ValueError(f"mc_fisher must be 0 or 1, got {mc_fisher!r}")
if network_type not in ('fc', 'conv'):
    raise ValueError(
        f"network_type must be 'fc' or 'conv', got {network_type!r}")

In [2]:
import torch
import architectures
import optimization_modules as om

# Device on which the synthetic data, the network and the gradient buffers
# are placed.
device = 'cpu'

# Seed torch's global RNG so the synthetic data, the network initialisation
# and the sampled labels are the same on every run.
seed = 0
torch.manual_seed(seed)

# Report the environment and the configuration chosen in the first cell.
print('Torch version', torch.__version__)
print('\n\nUsing optimizer:', optimizer)
print('Using MC Fisher:', mc_fisher, '\n\n')

def brute_fisher_MC():
    """
    Explicitly compute the Monte-Carlo Fisher Information of ``net`` on ``X``.

    Uses global variables: the mini-batch ``X``, the sampled labels
    ``y_sampled`` (one per input), the network ``net``, ``n_data`` and
    ``device``.

    Column ``i`` of ``G`` is the gradient (flattened by ``stacked_grad``) of
    the cross-entropy loss of input ``i`` w.r.t. its sampled label, scaled by
    ``1/sqrt(n_data)``. As in the paper, the Fisher can then be written as
    ``G @ G.T``.

    Returns:
        The triple ``(G @ G.T, G.T @ G, G)``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    columns = []
    for i in range(X.size()[0]):
        output = net(X[[i], :])
        loss = criterion(output, y_sampled[i])
        loss.backward()
        columns.append(stacked_grad())
        # Clear accumulated gradients so the next column only holds the
        # gradient of the next sample.
        net.zero_grad()
    # Stacking the gradient vectors directly keeps G's row count consistent
    # with stacked_grad()'s layout (no separate parameter count needed).
    G = torch.stack(columns, dim=1).to(device)
    G /= torch.sqrt(torch.tensor(n_data))
    Fisher = torch.mm(G, G.t())
    Gram = torch.mm(G.t(), G)
    return Fisher, Gram, G

def brute_fisher_full():
    """
    Explicitly compute the full (exact) Fisher Information of ``net`` on ``X``.

    Analogous to ``brute_fisher_MC``, but instead of one sampled label per
    input it uses every label in the global ``label_list``. Also reads the
    globals ``X``, ``net``, ``n_data`` and ``device``.

    For input ``i`` and label ``y``, the gradient of the cross-entropy loss is
    scaled by ``sqrt(p(y | x_i))``, and all columns by ``1/sqrt(n_data)``, so
    that ``G @ G.T = (1/n_data) * sum_i sum_y p(y | x_i) g_iy g_iy^T``.

    Returns:
        The triple ``(G @ G.T, G.T @ G, G)``. ``G`` has one column per
        (input, label) pair, ordered input-major.
    """
    criterion = torch.nn.CrossEntropyLoss()
    columns = []
    for i in range(X.size()[0]):
        for label in label_list:
            # A fresh forward pass per label, since backward() frees the graph.
            output = net(X[[i], :])
            target = label.reshape(1)
            # The probabilities are constant weights: no gradient through them.
            probs = torch.nn.functional.softmax(output, dim=1).detach()
            # Weight by the probability of the label actually used as target.
            weight = torch.sqrt(probs[0, label.item()])
            loss = criterion(output, target) * weight
            loss.backward()
            columns.append(stacked_grad())
            net.zero_grad()
    # Stacking the gradient vectors keeps G's row count consistent with
    # stacked_grad()'s layout.
    G = torch.stack(columns, dim=1).to(device)
    G /= torch.sqrt(torch.tensor(n_data))
    Fisher = torch.mm(G, G.t())
    Gram = torch.mm(G.t(), G)
    return Fisher, Gram, G

def block_diagonal(Fisher, model=None):
    """
    Replace every entry of ``Fisher`` outside the per-layer diagonal blocks by 0.

    The rows/columns of ``Fisher`` follow the layout of ``stacked_grad``: the
    flattened ``weight`` of each layer returned by ``model.my_modules()``,
    concatenated in order. One diagonal block is kept per such weight.

    Args:
        Fisher: square matrix indexed like ``stacked_grad()``.
        model: network whose ``my_modules()`` defines the blocks; defaults to
            the global ``net``.

    Returns:
        A new tensor equal to ``Fisher`` on the diagonal blocks, 0 elsewhere.
    """
    if model is None:
        model = net
    # Same shape, dtype and device as Fisher.
    mask = torch.zeros_like(Fisher)
    start = 0
    for layer in model.my_modules():
        stop = start + layer.weight.numel()
        mask[start:stop, start:stop] = 1.
        start = stop
    return Fisher * mask

def stacked_grad(model=None):
    """
    Return the gradient of the network as a 1-dimensional vector.

    The ``weight.grad`` of every layer in ``model.my_modules()`` is flattened
    and concatenated in order. This layout defines the row/column order of the
    explicitly computed Fisher matrices.

    Args:
        model: network to read gradients from; defaults to the global ``net``.

    Returns:
        1-D tensor (empty if ``model.my_modules()`` yields no layers).
    """
    if model is None:
        model = net
    pieces = [layer.weight.grad.reshape(layer.weight.numel())
              for layer in model.my_modules()]
    if not pieces:
        return torch.tensor([])
    # One concatenation on the gradients' own device, instead of repeatedly
    # growing a CPU tensor.
    return torch.cat(pieces)

def stacked_update(model=None):
    """
    Return the update direction of the network as a 1-dimensional vector.

    Each layer's ``update_direction`` from ``model.my_modules()`` is flattened
    to the size of that layer's ``weight`` and concatenated in order. This is
    the same layout as ``stacked_grad``, so the two can be compared entry by
    entry. [This relies on using the FOOFSequential class defined in
    optimization_modules.]

    Args:
        model: network to read update directions from; defaults to the global
            ``net``.

    Returns:
        1-D tensor (empty if ``model.my_modules()`` yields no layers).
    """
    if model is None:
        model = net
    # reshape(numel) doubles as a check that each update matches its weight's size.
    pieces = [layer.update_direction.reshape(layer.weight.numel())
              for layer in model.my_modules()]
    if not pieces:
        return torch.tensor([])
    return torch.cat(pieces)

# Specify (synthetic) data: n_data random inputs with random class targets.
input_dim, output_dim, n_data = 4, 3, 7
if network_type == 'fc':
    # Flat feature vectors of shape (n_data, input_dim).
    X = torch.normal(0, 1, size=(n_data, input_dim)).to(device)
elif network_type == 'conv':
    # Single-channel input_dim x input_dim "images".
    X = torch.normal(0, 1, size=(n_data, 1, input_dim, input_dim)).to(device)
y_target = torch.randint(low=0, high=output_dim, size=(n_data, 1)).long()
# Every class label 0..output_dim-1 (enumerated by the full-Fisher computation).
label_list = torch.arange(output_dim)

# Damping term added to the Fisher.
lam = 0.3

# Specify the network.
if network_type == 'fc':
    width, depth = 3, 2
    net = architectures.SimpleFCNet(
        width=width,
        depth=depth,
        input_dim=input_dim,
        output_dim=output_dim,
        lr=0.0,
        damp=lam,
        optimizer=None,  # optimizer will be set later
        mc_fisher=mc_fisher,
    )
elif network_type == 'conv':
    # A simple conv net built from scratch: two 3x3 same-padding conv layers,
    # 2x2 average pooling, then a linear read-out. For even input_dim the
    # pooled map has 2 channels of (input_dim/2)**2 pixels, hence
    # 2*input_dim**2//4 input features for the linear layer.
    module_list = [
        om.FOOFConv2d(in_channels=1, out_channels=2,
                      kernel_size=3, padding=1, bias=False),
        torch.nn.ReLU(),
        om.FOOFConv2d(in_channels=2, out_channels=2,
                      kernel_size=3, padding=1, bias=False),
        torch.nn.ReLU(),
        torch.nn.AvgPool2d(kernel_size=2),
        torch.nn.Flatten(),
        om.FOOFLinear(n_in_features=2*input_dim**2//4,
                      n_out_features=output_dim, bias=False),
    ]
    net = om.FOOFSequential(module_list,
                            lr=0.0,
                            damp=lam,
                            optimizer=None,  # optimizer will be set later
                            mc_fisher=mc_fisher,
                            output_dim=output_dim)

print('Network:', net)
net.to(device)
criterion = torch.nn.CrossEntropyLoss()

# Sample a set of labels. Note that for the purpose of this validation,
# the labels do not need to be sampled from the model distribution.
# This variable is only used for the MC Fisher (mc_fisher == 1); with
# mc_fisher == 0 the Fisher is summed over all labels instead.
y_sampled = torch.randint(low=0, high=output_dim, size=(n_data,1)).long()


# Compute the Fisher explicitly.
if mc_fisher:
    Fisher, Gram, G = brute_fisher_MC()
else:
    Fisher, Gram, G = brute_fisher_full()
# And make it block diagonal, if needed.
if optimizer == 'natural_bd':
    Fisher = block_diagonal(Fisher)

# Compute the gradient of the mini-batch loss explicitly.
output = net(X)
loss = criterion(output, y_target[:,0])
loss.backward()
grad = stacked_grad()
net.zero_grad()

# Compute the natural gradient update explicitly: (lam*I + F)^{-1} grad.
identity = torch.eye(Fisher.shape[0])
natural_update_brute_1 = torch.mv(torch.inverse(lam*identity + Fisher), grad)

# Compute the natural gradient update explicitly in a different way, relying
# on the matrix inversion lemma (Woodbury identity) with F = G G^T:
#   (lam*I + G G^T)^{-1} = 1/lam*I - 1/lam^2 * G (I + 1/lam * G^T G)^{-1} G^T,
# which only inverts a matrix whose size is the number of columns of G.
# It uses the full G, so it only equals the first version for 'natural'
# (not for the block-diagonal 'natural_bd'), and is only printed then.
A = 1/lam * identity \
    - (1/lam**2) * G @ torch.inverse(torch.eye(Gram.shape[0]) + 1/lam*Gram) @ G.t()
natural_update_brute_2 = torch.mv(A, grad)

print('\n\nExplcitly computed natural gradient, first version:\n', 
      natural_update_brute_1)
if optimizer == 'natural':
    print('\nExplicitly computed natural gradient, second version\n\
(should be the same as first version):\n', 
     natural_update_brute_2)

# Now compute the natural update implicitly, using the method described in the paper.
net.set_optimizer(optimizer)
# First carry out the implicit Fisher computations.
# Note that we need to use the same sampled labels as before.
# Note that y_sampled is ignored if mc_fisher==0.
net.update_fisher(X, y=y_sampled[:,0])
# Now compute the parameter update.
net.parameter_update(X, y_target[:,0])
update_implicit = stacked_update()

print('\nImplicity computed natural gradient. If this agrees with the explicit\ncomputation(s)\
 above, the algorithm is implemented correctly.\n', 
 update_implicit)

print('\nSquared norm of difference between implictly and explicitly computed updates\n', 
      (update_implicit-natural_update_brute_1).pow(2).sum().item())
print('Squared norms of implictly and explicitly computed update\n', 
      (update_implicit).pow(2).sum().item(), (natural_update_brute_1).pow(2).sum().item())


Torch version 1.8.1


Using optimizer: natural
Using MC Fisher: 1 


Network: FOOFSequential(
  (0): FOOFLinear(in_features=4, out_features=3, bias=False)
  (1): ReLU()
  (2): FOOFLinear(in_features=3, out_features=3, bias=False)
  (3): ReLU()
  (4): FOOFLinear(in_features=3, out_features=3, bias=False)
)


Explcitly computed natural gradient, first version:
 tensor([-0.2828, -0.0660, -0.4187,  0.0682,  0.0874, -0.3425,  0.0949,  0.0835,
        -0.4336, -0.3140, -0.2616,  0.1682,  0.0000,  0.0388,  0.0000,  0.6727,
         0.0000,  0.0916,  0.0521,  0.1551,  0.0755,  0.0497,  0.3377,  0.0411,
        -0.1117, -0.0448, -0.0900,  0.0619, -0.2930,  0.0489])

Explicitly computed natural gradient, second version
(should be the same as first version):
 tensor([-0.2828, -0.0660, -0.4187,  0.0682,  0.0874, -0.3425,  0.0949,  0.0835,
        -0.4336, -0.3140, -0.2616,  0.1682,  0.0000,  0.0388,  0.0000,  0.6727,
         0.0000,  0.0916,  0.0521,  0.1551,  0.0755,  0.0497,  0.3377,  0.0411,
 