In [132]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [133]:
def compute_true_info(hessian, hessian_unp, num_removed):
    """Brute-force (O(N^2)) reference computation of the per-element "true information".

    For each tensor index ``i`` the ratio ``r = hessian_unp[i] / hessian[i]``
    (taken as 1 where both entries are 0) is turned into two coefficients::

        A = log(r) / (2 * num_removed)
        B = (1 - r) / (2 * num_removed)

    and any inf/NaN coefficient is replaced by 0. Every output element ``j`` is
    then obtained by explicitly summing its pairwise terms against all N
    elements ``k`` of every tensor in the list::

        info[j] = sum_k (A_j*A_k + B_j*B_k + A_j*B_k + B_j*A_k) + 2 * B_j**2

    The work is done on the CPU with a Python loop over output elements, so
    this is only practical for small inputs (e.g. as a cross-check for a
    faster closed-form implementation).

    Args:
        hessian: list of tensors.
        hessian_unp: list of tensors indexed like ``hessian``; assumed to have
            the same shapes as the matching ``hessian`` entries (TODO confirm).
        num_removed: scalar divisor applied to both coefficients.

    Returns:
        List of CPU tensors shaped like the entries of ``hessian``. The dtype is
        the coefficients' floating dtype, but never narrower than float32.
        An empty input list yields an empty list.
    """
    cpu = torch.device("cpu")
    # Move everything to the CPU up front so the per-element loop below never
    # mixes devices (a CUDA 0-dim tensor times a CPU tensor raises).
    hessian = [h.to(cpu) for h in hessian]
    hessian_unp = [h.to(cpu) for h in hessian_unp]
    torch.cuda.empty_cache()  # release cached GPU memory before the long CPU loop

    A_list, B_list = [], []
    for i, h in enumerate(hessian):
        h_unp = hessian_unp[i]
        # 0/0 would give NaN; treat "both zero" as an unchanged ratio of 1.
        ratio = torch.where(torch.logical_and(h == 0, h_unp == 0), 1, h_unp / h)
        A = 1/2 * torch.log(ratio) / num_removed
        B = 1/2 * (1 - ratio) / num_removed
        # A zero entry in either Hessian gives +-inf; a negative ratio gives a
        # NaN log. All of these contribute nothing.
        A_list.append(torch.nan_to_num(A, nan=0.0, posinf=0.0, neginf=0.0))
        B_list.append(torch.nan_to_num(B, nan=0.0, posinf=0.0, neginf=0.0))

    if not A_list:
        return []

    A_flat = torch.cat([A.reshape(-1) for A in A_list])
    B_flat = torch.cat([B.reshape(-1) for B in B_list])
    # Keep float64 inputs in float64 (so results stay comparable with a
    # float64 closed-form result), but never go below float32.
    out_dtype = torch.promote_types(A_flat.dtype, torch.float32)

    information_true = []
    for h, A, B in zip(hessian, A_list, B_list):
        info = torch.empty(h.shape, dtype=out_dtype, device=cpu)
        info_view = info.view(-1)  # writes through to `info`
        a_vec = A.reshape(-1)
        b_vec = B.reshape(-1)
        for j in range(info_view.numel()):
            a_j, b_j = a_vec[j], b_vec[j]
            pairwise = a_j * A_flat + b_j * B_flat + a_j * B_flat + b_j * A_flat
            info_view[j] = torch.sum(pairwise) + 2 * torch.pow(b_j, 2)
        information_true.append(info)

    return information_true


In [134]:
def compute_true_info_new(hessian, hessian_unp, num_removed):
    """Closed-form (O(N)) computation of the per-element "true information".

    Coefficients per tensor index ``i``::

        r = hessian_unp[i] / hessian[i]
        A = log(r) / (2 * num_removed)
        B = (1 - r) / (2 * num_removed)

    with every inf/NaN coefficient set to 0 (this also covers 0/0 entries).
    The pairwise sum over all N elements factorises::

        sum_k (A_j*A_k + B_j*B_k + A_j*B_k + B_j*A_k) = (A_j + B_j) * C
        C = sum over all elements of all tensors of (A + B)

    so each element is ``(A_j + B_j) * C + 2 * B_j**2``.

    Args:
        hessian: list of tensors.
        hessian_unp: list of tensors indexed like ``hessian``.
        num_removed: scalar divisor applied to both coefficients.

    Returns:
        List of tensors, one per entry of ``hessian``, computed on the
        inputs' device.
    """
    def _finite_or_zero(t):
        # Zero Hessian entries / non-positive ratios produce +-inf or NaN here.
        return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)

    coeffs = []
    for idx in range(len(hessian)):
        ratio = hessian_unp[idx] / hessian[idx]
        coeffs.append((
            _finite_or_zero(1/2 * torch.log(ratio) / num_removed),
            _finite_or_zero(1/2 * (1 - ratio) / num_removed),
        ))

    # C is accumulated as a Python float, one tensor at a time.
    total = 0
    for a, b in coeffs:
        total += torch.sum(a + b).item()

    return [(a + b) * total + 2 * b ** 2 for a, b in coeffs]

In [135]:
# Small hand-made example: a length-10 vector and a 2x6 matrix. `hessian_half`
# is passed as the `hessian_unp` argument of both implementations.
tesnor1 = torch.tensor([2, 1, 3, 4, 5, 6, 7, 8, 9, 10])
tesnor2 = torch.tensor([[11, 12, 13, 14, 15, 16], [17, 18, 19, 20, 21, 22]])
tesnor3 = torch.tensor([6, 5, 7, 8, 9, 10, 11, 12, 13, 14])
tesnor4 = torch.tensor([[15, 16, 17, 18, 19, 20], [21, 22, 23, 24, 25, 26]])
hessian = [tesnor1, tesnor2]
hessian_half = [tesnor3, tesnor4]

# Brute-force reference vs. closed-form implementation.
information_true = compute_true_info(hessian, hessian_half, 1)
information_true_new = compute_true_info_new(hessian, hessian_half, 1)

assert len(information_true) == len(information_true_new)
for i in range(len(information_true)):
    print(information_true[i])
    print(information_true_new[i])
    print(torch.isclose(information_true[i], information_true_new[i]))
    # Fail loudly instead of relying on someone reading the boolean mask above.
    assert torch.allclose(information_true[i], information_true_new[i]), (
        f"closed-form result disagrees with brute-force reference for tensor {i}"
    )


tensor([ 3.1607, 11.0783,  1.5148,  0.8951,  0.5933,  0.4229,  0.3171,  0.2467,
         0.1976,  0.1618])
tensor([ 3.1607, 11.0783,  1.5148,  0.8951,  0.5933,  0.4229,  0.3171,  0.2467,
         0.1976,  0.1618])
tensor([True, True, True, True, True, True, True, True, True, True])
tensor([[0.1350, 0.1143, 0.0981, 0.0851, 0.0745, 0.0658],
        [0.0586, 0.0524, 0.0472, 0.0428, 0.0389, 0.0355]])
tensor([[0.1350, 0.1143, 0.0981, 0.0851, 0.0745, 0.0658],
        [0.0586, 0.0524, 0.0472, 0.0428, 0.0389, 0.0355]])
tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True]])
