In [None]:
import seaborn as sns
import pickle
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from sklearn.decomposition import PCA

import statsmodels.api as sm
import statsmodels.formula.api as smf

# Global matplotlib font configuration; applies to every figure drawn after this cell runs.
# NOTE(review): 'font.family' is set to a concrete font name (canonical spelling is
# 'DejaVu Serif') rather than the generic 'serif' family, so the 'font.serif' preference
# list below is presumably never consulted — confirm which font is actually intended.
plt.rcParams['font.family'] = 'DeJavu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']

In [None]:
def _cov_logdet(samples):
    '''Return log(det(cov)) of the rows of `samples` ([num_obs, dim]), computed stably.'''
    dim = samples.shape[-1]
    # torch.cov may hand back a squeezed result when dim == 1; slogdet needs a square matrix.
    cov = torch.cov(samples.transpose(0, 1)).reshape(dim, dim)
    # slogdet works in log space, so it does not under/overflow the way
    # log(det(cov)) does when dim is large or the variances are small.
    sign, logabsdet = torch.linalg.slogdet(cov)
    # Keep the semantics of log(det): NaN for a negative determinant, -inf for zero.
    return torch.where(sign < 0, torch.full_like(logabsdet, float('nan')), logabsdet)


def compute_entropy_mi(emb):
    '''
    Print (and return) Gaussian entropy estimates of the embeddings and the
    mutual information I(s; l) = H(s) - H(s|l) between embedding and learner.

    Entropies use the Gaussian formula H = 0.5 * log det(cov) + dim/2 * (1 + log(2*pi)),
    once with the full covariance matrix ("full") and once with only the
    per-dimension variances ("diag").

    emb: [num_learner, num_sample, time_step, dim]

    Returns a dict with keys 'hs_full', 'hs_l_full', 'hs_diag', 'hs_l_diag',
    'I_sl_full_full', 'I_sl_full_diag', 'I_sl_diag_diag' (0-d tensors) and
    'num_invalid' (int: learners whose covariance log-determinant was NaN/inf
    and were therefore left out of the H(s|l) average).
    '''
    num_learner, _, _, dim = emb.shape
    const = dim / 2 * (1 + float(np.log(2 * np.pi)))

    pooled = emb.reshape(-1, dim)                    # all learners together
    per_learner = emb.reshape(num_learner, -1, dim)  # one sample set per learner

    # H(s)
    hs = _cov_logdet(pooled) / 2 + const

    # H(s|l): average over the learners whose covariance is usable
    logdets = torch.stack([_cov_logdet(samples) for samples in per_learner])
    valid = torch.isfinite(logdets)
    num_invalid = int((~valid).sum())
    if num_invalid < num_learner:
        hs_l = logdets[valid].mean() / 2 + const
    else:
        # No usable covariance at all: report NaN instead of dividing by zero.
        hs_l = logdets.new_tensor(float('nan'))

    print('hs_full: ', hs)
    print('hs_l_full: ', hs_l)
    print('number of invalid covariance matrix: ', num_invalid)

    # Diagonal approximation: log det reduces to the sum of log variances.
    # Note: zero-variance dimensions are not filtered here and give -inf.
    hs_diag = torch.log(torch.var(pooled, 0)).sum() / 2 + const
    hs_l_diag = torch.log(torch.var(per_learner, 1)).sum(1).mean() / 2 + const

    print('hs_diag: ', hs_diag)
    print('hs_l_diag: ', hs_l_diag)

    mi_full_full = hs - hs_l
    mi_full_diag = hs - hs_l_diag
    mi_diag_diag = hs_diag - hs_l_diag

    print('I_sl_full_full', mi_full_full)
    print('I_sl_full_diag', mi_full_diag)
    print('I_sl_diag_diag', mi_diag_diag)

    return {
        'hs_full': hs,
        'hs_l_full': hs_l,
        'hs_diag': hs_diag,
        'hs_l_diag': hs_l_diag,
        'I_sl_full_full': mi_full_full,
        'I_sl_full_diag': mi_full_diag,
        'I_sl_diag_diag': mi_diag_diag,
        'num_invalid': num_invalid,
    }