# L4: VI and mixed-membership mixture models

Here we explore how to use Variational Inference to learn parameters in mixed-membership models.  
We compare with MLE + EM.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import networkx as nx
import scipy.special as sp

In [None]:
import sys
sys.path.append('../../../src/')
import tools as tl
import plot as viz



In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Categorical palette used to colour the communities.
colormap = plt.cm.tab10
# Wrap the index with the palette size: calling a listed colormap with an
# integer beyond its last entry returns the clipped (last) colour, so without
# the modulo every index >= colormap.N would share one colour.
colors = {i: colormap(i % colormap.N) for i in range(20)}

In [None]:
from probinet.input.loader import build_adjacency_from_file
from probinet.input.stats import print_graph_stats
from probinet.models.mtcov import MTCOV
from probinet.visualization.plot import plot_hard_membership, plot_soft_membership
from probinet.visualization.plot import extract_bridge_properties

In [None]:
import cv_tools as cvtl

In [None]:
import country_converter as coconv

In [None]:
outdir_fig = '../figures/'
lecture_id = 4

In [None]:
seed = 10
prng = np.random.RandomState(seed)

In [None]:
cc = coconv.CountryConverter()

# 1. Setup algorithms

## 1.1 MLE + EM

In [None]:
class PMF_EM(object):
    """Poisson matrix factorization fitted by maximum likelihood with EM.

    The expected value of ``A[i, j]`` is
    ``lambda_ij = sum_{k,q} u[i, k] * v[j, q] * C[k, q]``.
    ``C`` is initialised to the identity and is not updated by ``_em``
    (``update_C`` is available but not called).

    Parameters
    ----------
    A : ndarray of shape (N, N)
        Adjacency matrix (the data).
    K : int
        Number of communities.
    is_directed : bool
        If False, ``v`` is tied to ``u``; otherwise ``v`` is learned separately.
    """

    def __init__(self, A, K=3, is_directed = False):
        self.A = A                 # data
        self.K = K                 # number of communities
        self.N = self.A.shape[0]   # number of nodes
        self.is_directed = is_directed

    def _init(self, prng, u0=None):
        """Initialise ``u``, ``v`` and ``C``.

        ``u`` is drawn uniformly at random unless ``u0`` (shape ``(N, K)``)
        is given, in which case a float copy of ``u0`` is used.

        Raises
        ------
        ValueError
            If ``u0`` does not have shape ``(N, K)``.
        """
        if u0 is None:
            self.u = prng.random_sample((self.N, self.K))
        else:
            # Copy so that the caller's array is never modified or aliased.
            u_start = np.array(u0, dtype=float)
            if u_start.shape != (self.N, self.K):
                raise ValueError(
                    f"u0 must have shape {(self.N, self.K)}, got {u_start.shape}"
                )
            self.u = u_start
        if not self.is_directed:
            self.v = self.u
        else:
            self.v = prng.random_sample((self.N, self.K))
        # Assortative affinity matrix: identity, kept fixed during EM.
        self.C = np.eye(self.K)

    def fit(self, prng, N_real=15, max_iter=500, tol=0.1, decision=2, verbose=True, u0=None):
        """Run EM from ``N_real`` initialisations and keep the best one.

        Parameters
        ----------
        prng : numpy.random.RandomState
            Random generator used for the initialisations.
        N_real : int
            Number of realizations (restarts).
        max_iter : int
            Maximum number of EM iterations per realization.
        tol, decision : float, int
            Convergence is declared once the log-likelihood (evaluated every
            10 iterations) changed by less than ``tol`` more than ``decision``
            consecutive times.
        verbose : bool
            Print one summary line per realization.
        u0 : ndarray or None
            Optional initial value for ``u`` (see ``_init``).

        Returns
        -------
        tuple
            ``(u, v, C, loglik_values)`` of the realization with the highest
            log-likelihood.

        Raises
        ------
        RuntimeError
            If no realization produced a usable log-likelihood
            (e.g. ``N_real <= 0`` or the log-likelihood was NaN).
        """
        maxL = - 1e12  # initialization of the maximum likelihood
        best = None    # (u, v, C, loglik_values) of the best realization so far

        for r in range(N_real):
            # random initialization
            self._init(prng, u0=u0)

            # convergence local variables
            coincide, it = 0, 0
            convergence = False

            loglik_values = []  # keep track of the values of the loglik to plot
            loglik = - 1e12  # initialization of the loglik

            while not convergence and it < max_iter:
                self._em()
                it, loglik, coincide, convergence = self.check_for_convergence(
                    it, loglik, coincide, convergence, tolerance=tol, decision=decision
                )
                loglik_values.append(loglik)
            if verbose:
                print(f'Nreal = {r} - Loglikelihood = {tl.fl(loglik)} - Best Loglikelihood = {tl.fl(maxL)} - iterations = {it} - ')

            if maxL < loglik:
                best = (*self.update_optimal_parameters(), list(loglik_values))
                maxL = loglik

        if best is None:
            raise RuntimeError(
                "PMF_EM.fit: no realization produced a valid log-likelihood "
                f"(N_real={N_real})."
            )
        return best

    def _em(self):
        """Perform one EM sweep (``C`` is left untouched)."""
        # E-step followed by the M-step for u
        q = self.update_q()
        self.u = self.update_u(q)
        if not self.is_directed:
            self.v = self.u
        else:
            # Refresh the responsibilities with the new u before updating v.
            q = self.update_q()
            self.v = self.update_v(q)

    def update_q(self):
        """Return the responsibilities ``q[i, j, k, q]``, normalised over ``(k, q)``."""
        lambda_ij = np.einsum('ik,jq,kq->ijkq', self.u, self.v, self.C)
        lambda_ij_den = np.einsum('ijkq->ij', lambda_ij)
        # A pair with zero total rate has an all-zero numerator; divide by 1
        # instead of 0 so it yields 0 rather than NaN.
        lambda_ij_den[lambda_ij_den == 0] = 1
        return lambda_ij/lambda_ij_den[:,:,np.newaxis,np.newaxis]

    def update_u(self, q):
        """Return the M-step update of ``u`` (0.1 is added to numerator and denominator)."""
        numerator = 0.1 + np.einsum('ij,ijkq->ik', self.A, q)
        denominator = 0.1 + np.einsum('q,kq->k', self.v.sum(axis=0), self.C)[np.newaxis,:]
        u_temp = numerator / denominator
        return u_temp

    def update_v(self, q):
        """Return the M-step update of ``v``.

        NOTE(review): unlike ``update_u`` and ``update_C`` there is no 0.1
        offset here, so ``v`` can become exactly zero -- confirm this
        asymmetry is intended.
        """
        numerator = np.einsum('ij,ijkq->jq', self.A, q)
        denominator = np.einsum('k,kq->q', self.u.sum(axis=0), self.C)[np.newaxis,:]
        v_temp = numerator / denominator
        return v_temp

    def update_C(self, q):
        """Return the M-step update of ``C`` (currently not called by ``_em``)."""
        numerator = 0.1 + np.einsum('ij,ijkq->kq', self.A, q)
        denominator = 0.1 + np.einsum('k,q->kq', self.u.sum(axis=0), self.v.sum(axis=0))
        C_temp = numerator / denominator
        return C_temp

    def check_for_convergence(self, it, loglik, coincide, convergence, tolerance=0.1, decision=2):
        """Update the convergence bookkeeping.

        The log-likelihood is re-evaluated only every 10 iterations; in
        between, the previous value is carried over unchanged.

        Returns ``(it + 1, loglik, coincide, convergence)``.
        """
        if it % 10 == 0:
            old_L = loglik
            loglik = self.Likelihood(EPS = 1e-12)
            if abs(loglik - old_L) < tolerance:
                coincide += 1
            else:
                coincide = 0
        if coincide > decision:
            convergence = True
        it += 1
        return it, loglik, coincide, convergence

    def Likelihood(self, EPS = 1e-12):
        """Return ``sum(A * log(lambda + EPS)) - sum(lambda)``.

        The ``log(A!)`` term is not included.
        """
        lambda_ij = np.einsum('ik,jq,kq->ij', self.u, self.v, self.C)
        return (self.A * np.log(lambda_ij + EPS)).sum() - lambda_ij.sum()

    def update_optimal_parameters(self):
        """Return copies of the current ``(u, v, C)``."""
        u_f = np.copy(self.u)
        v_f = np.copy(self.v)
        C_f = np.copy(self.C)
        return u_f,v_f,C_f

## 1.2 VI

In [None]:
class PMF_VI(object):
    """Poisson matrix factorization fitted with coordinate-ascent VI (CAVI).

    The variational posterior of ``u[i, k]`` is Gamma with shape
    ``alpha_shp`` and rate ``alpha_rte``; that of ``v[j, k]`` is Gamma with
    shape ``beta_shp`` and rate ``beta_rte``.

    Parameters
    ----------
    A : ndarray of shape (N, N)
        Adjacency matrix (the data).
    K : int
        Number of communities.
    a, b : float
        Shape and rate of the Gamma prior on ``u`` (default 1, 1).
    c, d : float
        Shape and rate of the Gamma prior on ``v`` (default 1, 1).
    """

    def __init__(self, A, K=3, a=1, b=1, c=1, d=1):
        self.A = A                 # data
        self.K = K                 # number of communities
        self.N = self.A.shape[0]   # number of nodes
        # priors
        self.a = a
        self.b = b
        self.c = c
        self.d = d

    def _init(self, prng):
        """Randomly initialise the variational parameters (prior value + U(0, 1))."""
        size = (self.N, self.K)
        self.alpha_shp = prng.random_sample(size=size) + self.a
        self.alpha_rte = prng.random_sample(size=size) + self.b
        self.beta_shp = prng.random_sample(size=size) + self.c
        self.beta_rte = prng.random_sample(size=size) + self.d

    def fit(self, prng, N_real=15, max_iter=500, tol=0.1, decision=2, verbose=True):
        """Run CAVI from ``N_real`` initialisations and keep the best one.

        Parameters
        ----------
        prng : numpy.random.RandomState
            Random generator used for the initialisations.
        N_real : int
            Number of realizations (restarts).
        max_iter : int
            Maximum number of CAVI sweeps per realization.
        tol, decision : float, int
            Convergence is declared once the ELBO (evaluated every 10
            iterations) changed by less than ``tol`` more than ``decision``
            consecutive times.
        verbose : bool
            Print one summary line per realization.

        Returns
        -------
        tuple
            ``(alpha_shp, alpha_rte, beta_shp, beta_rte, elbo_values)`` of the
            realization with the highest ELBO.

        Raises
        ------
        RuntimeError
            If no realization produced a usable ELBO
            (e.g. ``N_real <= 0`` or the ELBO was NaN).
        """
        maxElbo = - 1e12  # initialization of the maximum elbo
        best = None       # parameters and ELBO trace of the best realization

        for r in range(N_real):
            # random initialization
            self._init(prng)

            # convergence local variables
            coincide, it = 0, 0
            convergence = False

            elbo_values = []  # keep track of the values of the elbo to plot
            elbo = - 1e12  # initialization of the elbo

            while not convergence and it < max_iter:
                self._cavi()

                Eu, Elogu = compute_expectations(self.alpha_shp, self.alpha_rte)
                Ev, Elogv = compute_expectations(self.beta_shp, self.beta_rte)

                it, elbo, coincide, convergence = self.check_for_convergence_cavi(
                    Eu, Elogu, Ev, Elogv, it, elbo, coincide, convergence,
                    tolerance=tol, decision=decision,
                )
                elbo_values.append(elbo)
            if verbose:
                print(f'Nreal = {r} - ELBO = {tl.fl(elbo)} - Best ELBO = {tl.fl(maxElbo)} - iterations = {it} - ')

            if maxElbo < elbo:
                best = (*self.update_optimal_parameters(), list(elbo_values))
                maxElbo = elbo

        if best is None:
            raise RuntimeError(
                "PMF_VI.fit: no realization produced a valid ELBO "
                f"(N_real={N_real})."
            )
        return best

    def _cavi(self):
        """Perform one CAVI sweep: update the ``u`` factors, then the ``v`` factors."""
        phi_ij = self.update_phi()
        self.update_alphas(phi_ij)
        # Recompute phi so that the v update sees the refreshed u factors.
        phi_ij = self.update_phi()
        self.update_betas(phi_ij)

    def update_phi(self):
        """Return ``phi[i, j, k]``, proportional to ``exp(E[log u_ik] + E[log v_jk])``.

        The result is normalised over ``k``; all-zero rows are left at zero.
        """
        phi_ijk = np.einsum('ik,jk->ijk',np.exp(sp.psi(self.alpha_shp) - np.log(self.alpha_rte)), np.exp(sp.psi(self.beta_shp) - np.log(self.beta_rte)))
        sumPhi = phi_ijk.sum(axis=-1)[:,:,np.newaxis]
        sumPhi[sumPhi == 0] = 1  # avoid 0/0
        return phi_ijk / sumPhi

    def update_alphas(self, phi_ij):
        """Update the Gamma parameters of ``q(u)``; ``alpha_rte`` gets shape ``(1, K)``."""
        self.alpha_shp = self.a + np.einsum('ij,ijk->ik', self.A,phi_ij)
        self.alpha_rte = self.b + (self.beta_shp / self.beta_rte).sum(axis=0)[np.newaxis,:]

    def update_betas(self, phi_ij):
        """Update the Gamma parameters of ``q(v)``; ``beta_rte`` gets shape ``(1, K)``."""
        self.beta_shp = self.c + np.einsum('ij,ijk->jk', self.A,phi_ij)
        self.beta_rte = self.d + (self.alpha_shp / self.alpha_rte).sum(axis=0)[np.newaxis,:]

    def check_for_convergence_cavi(self, Eu, Elogu, Ev, Elogv, it, elbo, coincide, convergence, tolerance=0.1,decision=2):
        """Update the convergence bookkeeping.

        The ELBO is re-evaluated only every 10 iterations; in between, the
        previous value is carried over unchanged.

        Returns ``(it + 1, elbo, coincide, convergence)``.
        """
        if it % 10 == 0:
            old_elbo = elbo
            elbo = self.Elbo(Eu, Elogu, Ev, Elogv)
            if abs(elbo - old_elbo) < tolerance:
                coincide += 1
            else:
                coincide = 0
        if coincide > decision:
            convergence = True
        it += 1
        return it, elbo, coincide, convergence

    def Elbo(self, Eu, Elogu, Ev, Elogv):
        """Return the objective monitored for convergence.

        It is the sum of the Poisson likelihood part (without a ``log(A!)``
        term) and the two Gamma prior/posterior parts from ``gamma_elbo_term``.
        """
        bound = (self.A * np.log(np.einsum('ik,jk->ij',np.exp(Elogu),np.exp(Elogv)))).sum() - Eu.dot(Ev.T).sum()
        bound += gamma_elbo_term(pa=self.a, pb=self.b, qa=self.alpha_shp, qb=self.alpha_rte).sum()
        bound += gamma_elbo_term(pa=self.c, pb=self.d, qa=self.beta_shp, qb=self.beta_rte).sum()
        return bound

    def update_optimal_parameters(self):
        """Return copies of the current variational parameters."""
        alpha_shp = np.copy(self.alpha_shp)
        alpha_rte = np.copy(self.alpha_rte)
        beta_shp = np.copy(self.beta_shp)
        beta_rte = np.copy(self.beta_rte)
        return alpha_shp,alpha_rte,beta_shp,beta_rte
    
def compute_expectations(alpha, beta):
    '''
    Return the pair (E[x], E[log x]) for x ~ Gam(alpha, beta),
    where alpha is the shape and beta is the rate.
    '''
    expected_value = alpha / beta
    expected_log = sp.psi(alpha) - np.log(beta)
    return (expected_value, expected_log)

def gamma_elbo_term(pa, pb, qa, qb):
    '''
    Element-wise E_q[log p(x)] - E_q[log q(x)] for a Gam(pa, pb) prior and a
    Gam(qa, qb) variational factor (shape, rate), omitting the additive terms
    that depend only on pa and pb.
    '''
    log_gamma_part = sp.gammaln(qa)
    log_rate_part = pa * np.log(qb)
    digamma_part = (pa - qa) * sp.psi(qa)
    mean_part = qa * (1 - pb / qb)
    return log_gamma_part - log_rate_part + digamma_part + mean_part

# 2. Import data

In [None]:
indir = '../../../data/outdir/wto/'
filename = 'wto_aob.csv'
infile = f"{indir}{filename}"
df = pd.read_csv(infile)

In [None]:
source = 'reporter_name'
target = 'partner_name'
weight = 'weight'

In [None]:
# Loader switches (all enabled for this dataset).
undirected, force_dense, binary = True, True, True

loader_options = dict(
    ego=source,
    alter=target,
    sep=",",
    undirected=undirected,
    force_dense=force_dense,
    binary=binary,
    header=0,
)
data = build_adjacency_from_file(f"{indir}{filename}", **loader_options)

# Show which fields the returned namedtuple exposes
print(data._fields)

# Two-way lookup between node labels and their integer positions
nodeId2Label = dict(enumerate(data.nodes))
nodeLabel2Id = {label: node_id for node_id, label in nodeId2Label.items()}

Y = data.adjacency_tensor

plt.figure(figsize=(2,2))

# Display at most nmax nodes, sorted by decreasing row sum of the first layer
nmax = 500
row_sums = Y[0].sum(axis=1)
node_order = np.argsort(-row_sums)
viz.plot_matrix(Y, node_order=node_order[:nmax], title="Y")

plt.tight_layout()

In [None]:
# Scale factor applied to the log-degree when sizing nodes.
ms = 10
# Node size = log(degree) * ms + 20, with degrees read from the first graph in
# data.graph_list. NOTE(review): a node with degree 0 would give log(0) = -inf.
node_size = [np.log(data.graph_list[0].degree[i]) * ms + 20 for i in data.nodes]
# Node coordinates passed as `position=` to the network plots further down.
position = tl.get_custom_node_positions(data.graph_list[0])

In [None]:
A = data.adjacency_tensor[0]
A.shape

Let's add a node attribute based on each country's continent.

In [None]:
# Continent lookup table taken from the country converter
continent_table = cc.continent
macro_area = continent_table['continent'].unique()
nameShort2region = dict(zip(continent_table['name_short'], continent_table['continent']))
nameShort2region['European Union'] = 'Europe'  # manual entry for the EU node

# Map the raw node labels to the converter's short names
names_short = coconv.convert(names=data.nodes, to='name_short', not_found=None)
nameRaw2Short = dict(zip(data.nodes, names_short))

# One-hot design matrix: one column per continent plus a last "unknown" column
C = len(macro_area) + 1 # if 2: binary
X_reg = np.zeros((len(data.nodes), C), dtype=int)

for i, n in enumerate(data.nodes):
    short_name = nameRaw2Short[n]
    if short_name not in nameShort2region:
        # no continent found: report the node and use the last column
        print(n)
        X_reg[i, -1] = 1
        continue
    r = nameShort2region[short_name]
    idx = np.where(macro_area == r)[0]
    X_reg[i, idx] = 1

# every node must belong to exactly one category
row_totals = np.sum(X_reg, axis=1)
assert np.all(row_totals == 1), np.where(row_totals != 1)
X_reg.shape

# 3. Run inference

In [None]:
K = 6
u = {}

## 3.1 EM + MLE

In [None]:
# Fit the EM model on A with K communities (undirected by default, so v is tied to u).
pmf_em = PMF_EM(A, K=K)
# fit() keeps the realization with the highest log-likelihood among its restarts.
u_em, v_em, C_em, best_loglik_values = pmf_em.fit(prng)

# Store the raw memberships and a normalised version for the plots below.
# NOTE(review): presumably rows are rescaled to sum to 1 -- confirm in tl.normalize_nonzero_membership.
u['em'] = u_em
u['norm_em']= tl.normalize_nonzero_membership(u_em)


In [None]:
# plot_L lives in the plotting module imported as `viz`; the bare name is not
# defined in this notebook.
viz.plot_L(best_loglik_values, int_ticks=True)

## 3.2 VI

In [None]:
pmf_vi = PMF_VI(A, K=K)
alpha_shp_vi, alpha_rte_vi, beta_shp_vi, beta_rte_vi, best_elbo_values = pmf_vi.fit(prng)

In [None]:
viz.plot_L(best_elbo_values, int_ticks=True, ylab='ELBO')

## 3.3 MLE + EM + node attributes (MTCOV)

In [None]:
# Keyword arguments forwarded to MTCOV.fit via **config_dict below.
config_dict = {
    "assortative": True,
    "end_file": "_mtcov",
    "out_folder": '../../../data/outdir/wto/',
    "out_inference": True,
    "undirected": True,
    "rseed": 10
}

# Settings passed to the MTCOV constructor.
plot_loglik = False
num_realizations = 20
max_iter = 500
decision = 1
convergence_tol = 1e-3
# Attach the one-hot continent matrix as the node-attribute design matrix.
data = data._replace(design_matrix=X_reg)

# NOTE(review): gamma presumably balances attribute vs. network information -- confirm in the MTCOV docs.
gamma = 0.7
model = MTCOV(num_realizations=num_realizations, plot_loglik=plot_loglik,max_iter=max_iter,decision=decision,convergence_tol=convergence_tol)
# The random generator is seeded with the same "rseed" stored in config_dict.
params_mtcov = model.fit(data, K=K, gamma=gamma, rng=np.random.default_rng(config_dict["rseed"]), **config_dict)
    
# The first returned element is used as the membership matrix in the plots below.
u['mtcov'] = params_mtcov[0]

# 4. Analyze results

In [None]:
attrib_label = 'continent'
figsize = (16, 10)

fig, axs = plt.subplots(2, 3, figsize=(16, 6))

# One row per algorithm: attribute | mixed membership | hard membership
for n_row, algo in enumerate(['em', 'mtcov']):
    ax_attribute, ax_mixed, ax_hard = axs[n_row, 0], axs[n_row, 1], axs[n_row, 2]
    viz.plot_network(data, X_reg, ax=ax_attribute, title=f'Attribute {attrib_label}')
    viz.plot_network(data, u[algo], ax=ax_mixed, title=algo, plot_labels=False, threshold=0.1)
    q = tl.from_mixed_to_hard(u[algo])
    viz.plot_network(data, q, ax=ax_hard, title=f'{algo} (hard)')

# The computed file name is immediately overridden with None before saving
filename = tl.get_filename(f'wto_attribute_{attrib_label}_EM', lecture_id=lecture_id)
filename = None
tl.savefig(plt, outfile=filename, outdir=outdir_fig)

## 4.1 What about VI?

Recall that VI does not automatically output point estimates!  
We need to extract them from the posterior distributions.  

For instance, we can get them from taking **expectations** over the posteriors.


In [None]:
Eu_vi, Elogu_vi = compute_expectations(alpha_shp_vi,alpha_rte_vi)
Ev_vi, Elogv_vi = compute_expectations(beta_shp_vi,beta_rte_vi)

u['vi'] = Eu_vi
u['norm_vi'] = tl.normalize_nonzero_membership(Eu_vi)
assert np.all(np.allclose(np.sum(u['norm_vi'],axis=1),1))

q_vi = np.argmax(u['norm_vi'], axis=1)  # extract hard communities

In [None]:
selected_nodes = ['Norway','United Kingdom','European Union','Albania','Other Countries, n.e.s.']
node_labels = {n: n for n in selected_nodes}

In [None]:
attrib_label = 'continent'
# NOTE(review): this variable is not used below; plt.subplots receives its own (16,6).
figsize= (16,10)

fig, axs = plt.subplots(1,3, figsize=(16,6))

# Left: node attribute; middle: VI posterior-mean memberships; right: hard assignment.
viz.plot_network(data,X_reg,position=position,ax=axs[0], title=f'Attribute {attrib_label}')
viz.plot_network(data,u['vi'],position=position,ax=axs[1], title = r'VI', plot_labels = True, threshold=0.1,node_labels=node_labels)
q = tl.from_mixed_to_hard(u['vi'])
viz.plot_network(data,q,position=position,ax=axs[2], title = 'VI (hard)',node_labels=node_labels)


# The computed file name is immediately overridden with None before saving.
filename = tl.get_filename(f'wto_attribute_{attrib_label}_VI', lecture_id=lecture_id)
filename = None
tl.savefig(plt, outfile=filename, outdir=outdir_fig)

In [None]:
attrib_label = 'continent'
# NOTE(review): this variable is not used below; plt.subplots receives its own (24,18).
figsize= (16,10)

# 3x3 grid: one row per algorithm (vi, em, mtcov); columns are attribute | mixed | hard.
fig, axs = plt.subplots(3,3, figsize=(24,18))

# The first column repeats the same attribute plot in every row.
viz.plot_network(data,X_reg,position=position,ax=axs[0,0], title=f'Attribute {attrib_label}', plot_labels = False)
viz.plot_network(data,X_reg,position=position,ax=axs[1,0], title=f'Attribute {attrib_label}', plot_labels = False)
viz.plot_network(data,X_reg,position=position,ax=axs[2,0], title=f'Attribute {attrib_label}', plot_labels = False)

algo = 'vi'
viz.plot_network(data,u[algo],position=position,ax=axs[0,1], title = algo, plot_labels = False, threshold=0.1)
q = tl.from_mixed_to_hard(u[algo])
viz.plot_network(data,q,position=position,ax=axs[0,2], title = f'{algo} (hard)')

algo = 'em'
viz.plot_network(data,u[algo],position=position,ax=axs[1,1], title = algo, plot_labels = False, threshold=0.1)
q = tl.from_mixed_to_hard(u[algo])
viz.plot_network(data,q,position=position,ax=axs[1,2], title = f'{algo} (hard)')

algo = 'mtcov'
viz.plot_network(data,u[algo],position=position,ax=axs[2,1], title = algo, plot_labels = False, threshold=0.1)
q = tl.from_mixed_to_hard(u[algo])
# NOTE(review): only this hard-membership panel passes plot_labels = False; the two above do not -- confirm intended.
viz.plot_network(data,q,position=position,ax=axs[2,2], title = f'{algo} (hard)', plot_labels = False)


# The computed file name is immediately overridden with None before saving.
filename = tl.get_filename(f"WTO_MM",lecture_id=lecture_id)
filename = None
tl.savefig(plt,outfile = filename,outdir = outdir_fig)

## 4.2 Posterior distribution
We can use the posterior distributions to assess the uncertainty of the estimates.

In [None]:
import seaborn as sns

In [None]:
idx = 17 # example node id
SAMPLE = 1000  # number of posterior samples drawn per community

L = alpha_shp_vi.shape[1]
# squeeze=False keeps a 2-D axes array, so indexing also works when L == 1
fig, axes_grid = plt.subplots(1, L, figsize=(16, 4), sharey=True, sharex=True, squeeze=False)
ax = axes_grid[0]
for i in range(L):
    shape_ik = alpha_shp_vi[idx, i]
    rate_k = alpha_rte_vi[0, i]   # row 0 is the only row indexed for the rate
    mean = shape_ik / rate_k      # posterior mean of a Gamma(shape, rate)
    # numpy's gamma sampler takes a scale argument, i.e. 1 / rate
    samples = prng.gamma(shape_ik, 1. / rate_k, size=SAMPLE)
    sns.histplot(samples, color=colors[i], kde=True, line_kws={'ls':'--'}, alpha=0.1, ax=ax[i])
    ax[i].axvline(x=mean, color=colors[i], ls='-', lw=1, alpha=0.8)
    ylim = ax[i].get_ylim()
    # ax[i].set_ylabel('P(k)')
    ax[i].set_ylabel(r'$P(u_{ik})$')
    ax[i].set_xlabel(r'$u_{ik}$')
    ax[i].text(0.5, ylim[1]*0.85, f'k = {i}')

title = nodeId2Label[idx]
# Title the whole figure: plt.title would only label the last panel.
fig.suptitle(f"node = {title}")

In [None]:
nodeId2Label

In [None]:
idx = 1 # example node id
SAMPLE = 1000  # number of posterior samples drawn per community

L = alpha_shp_vi.shape[1]


# The figure is drawn twice: without and with the summary-statistics text box.
for with_text in [False,True]:
    # Re-seed so that both figures show exactly the same samples.
    prng_tmp = np.random.RandomState(seed=seed)
    fig, ax = plt.subplots(1,1,figsize=(8, 4))
    for i in range(L):
        shape_ik = alpha_shp_vi[idx, i]
        rate_k = alpha_rte_vi[0, i]   # row 0 is the only row indexed for the rate
        mean = shape_ik / rate_k      # posterior mean of a Gamma(shape, rate)
        # numpy's gamma sampler takes a scale argument, i.e. 1 / rate
        samples = prng_tmp.gamma(shape_ik, 1. / rate_k, size=SAMPLE)
        sns.histplot(samples, color=colors[i], kde=True, line_kws={'ls':'--'}, alpha=0.3, ax=ax, label=f"k={i}")
        ax.axvline(x=mean, color=colors[i], ls='-', lw=1, alpha=0.8)

    # Read the limits once, after every histogram has been drawn.
    ylim = ax.get_ylim()
    xlim = ax.get_xlim()
    ax.set_ylabel(r'$P(u_{ik})$')
    ax.set_xlabel(r'$u_{ik}$')

    if with_text:
        # Summary statistics of the community with the largest posterior mean.
        k = np.argmax(alpha_shp_vi[idx,:] / alpha_rte_vi[0,:])
        mean = alpha_shp_vi[idx,k] / alpha_rte_vi[0,k]
        # Gamma(shape, rate): variance = shape / rate**2, std = sqrt(variance)
        std = np.sqrt(alpha_shp_vi[idx,k] / (alpha_rte_vi[0,k]*alpha_rte_vi[0,k]))
        msg = f"mean u_ik = {mean:.2f}"
        msg = f"{msg}\nstd u_ik = {std:.2f}"
        msg = f"{msg}\nVMR u_ik = {std*std/mean:.2f}"
        ax.text(xlim[1]* 0.5, ylim[1]*0.7, msg)

    ax.set_xlim(0.01,xlim[1])
    # ax.set_ylim(0.0,500)

    title = nodeId2Label[idx]
    plt.title(f"n = {title}")
    plt.legend(loc='best')

    # The computed file name is immediately overridden with None before saving.
    filename = tl.get_filename(f"WTO_{title}_VI_{with_text}",lecture_id=lecture_id)
    filename = None
    tl.savefig(plt,outfile = filename,outdir = outdir_fig)

## 4.3 Model selection
Which model performs the best?

We need to apply model selection criteria to decide. Here we use cross-validation.

In [None]:
# Seed for the cross-validation split (same value as the notebook-level seed set earlier).
seed = 10
# cv_mask is indexed by fold: its keys and the shape of fold 0's mask are displayed below.
cv_mask = cvtl.extract_mask(data.adjacency_tensor.shape, seed = seed )
cv_mask.keys(), cv_mask[0].shape

In [None]:
# params_cv[fold][algo] holds whatever the corresponding fit() call returned.
params_cv = {f: {} for f in cv_mask.keys()}
K = 6
for fold in cv_mask:

    # Dataset for this fold.
    # NOTE(review): presumably the held-out entries are masked out -- confirm in cvtl.get_df_train_test.
    data_cv = cvtl.get_df_train_test(df, data, cv_mask, fold=fold)

    algo = 'em'
    pmf_em = PMF_EM(data_cv.adjacency_tensor[0], K=K)
    params_cv[fold][algo] = pmf_em.fit(prng, verbose=False)

    algo = 'vi'
    pmf_vi = PMF_VI(data_cv.adjacency_tensor[0], K=K)
    # Store this fold's variational parameters
    # (alpha_shp, alpha_rte, beta_shp, beta_rte, elbo_values). Expectations are
    # not computed here: the previous code derived them from the full-data fit
    # (alpha_shp_vi, ...) rather than from this fold, and never used them.
    params_cv[fold][algo] = pmf_vi.fit(prng, verbose=False)

    algo = 'mtcov'
    gamma = 0.7
    data_cv = data_cv._replace(design_matrix=X_reg)
    model = MTCOV(num_realizations=num_realizations, plot_loglik=plot_loglik, max_iter=max_iter, decision=decision, convergence_tol=convergence_tol)
    params_cv[fold][algo] = model.fit(data_cv, K=K, gamma=gamma, rng=np.random.default_rng(config_dict["rseed"]), **config_dict)
   


In [None]:
def compute_mean_vi(alpha_shp_u,alpha_rte_u,alpha_shp_v,alpha_rte_v, method='geometric'):
    '''
    Point estimate of the Poisson rates lambda_ij = sum_k u_ik v_jk from the
    Gamma variational posteriors of u and v.

    alpha_shp_u, alpha_rte_u : shape and rate of q(u) (rate must broadcast against shape)
    alpha_shp_v, alpha_rte_v : shape and rate of q(v)
    method : 'geometric' uses exp(E[log u]) and exp(E[log v]);
             any other value uses the arithmetic means E[u] and E[v]

    Returns an array with one row per row of alpha_shp_u and one column per
    row of alpha_shp_v.
    '''
    # Use the parameters passed in (e.g. those of a single CV fold), not the
    # notebook-level full-data posterior.
    if method == 'geometric':
        # exp(E[log x]) for x ~ Gam(shape, rate) is exp(psi(shape) - log(rate))
        geo_u = np.exp(sp.psi(alpha_shp_u) - np.log(alpha_rte_u))
        geo_v = np.exp(sp.psi(alpha_shp_v) - np.log(alpha_rte_v))
        return np.einsum('ik,jk->ij', geo_u, geo_v)

    # E[x] for x ~ Gam(shape, rate) is shape / rate
    mean_u = alpha_shp_u / alpha_rte_u
    mean_v = alpha_shp_v / alpha_rte_v
    return np.einsum('ik,jk->ij', mean_u, mean_v)

def compute_mean_lambda0_em(u,v,w):
    '''
    Expected adjacency tensor (layer, i, j) from memberships u, v and affinity w.

    w of rank 3            : full affinity per layer, indexed (a, k, q)
    w square (rows == cols): a single (k, q) affinity; the result has one layer
    otherwise              : w is indexed (a, k), i.e. one value per layer and community
    '''
    if w.ndim == 3:
        return np.einsum('ik,jq,akq->aij', u, v, w)

    n_rows, n_cols = w.shape[0], w.shape[1]
    if n_rows != n_cols:
        return np.einsum('ik,jk,ak->aij', u, v, w)

    # square matrix: wrap the single layer in a leading axis of length 1
    single_layer = np.zeros((1, u.shape[0], v.shape[0]))
    single_layer[0, :] = np.einsum('ik,jq,kq->ij', u, v, w)
    return single_layer


In [None]:
params_cv.keys()

In [None]:

# Predicted adjacency for the two point-estimate models, per fold.
Y_pred = {
    fold: {
        algo: compute_mean_lambda0_em(
            params_cv[fold][algo][0],
            params_cv[fold][algo][1],
            params_cv[fold][algo][2],
        )
        for algo in ['em', 'mtcov']
    }
    for fold in params_cv.keys()
}

method = 'geometric'
for f in Y_pred.keys():
    Y_pred[f]['vi'] = np.zeros_like(Y_pred[f]['mtcov'])
    # Index with the loop variable f: the name `fold` still holds the last
    # fold of an earlier cell and would reuse one fold's parameters everywhere.
    vi_params = params_cv[f]['vi']
    Y_pred[f]['vi'][0,:] = compute_mean_vi(vi_params[0], vi_params[1], vi_params[2], vi_params[3], method=method)

In [None]:
fold = 2  # fold shown in the figure
algos_shown = list(params_cv[fold].keys())

# One column per algorithm; top row: observed data, bottom row: prediction.
fig, axarr = plt.subplots(2, len(algos_shown), figsize=(8,6))

# The same node ordering (from the EM memberships of this fold) is used in
# every panel, so compute it once instead of once per algorithm.
node_order = tl.extract_node_order(params_cv[fold]['em'][0])
for i, algo in enumerate(algos_shown):
    viz.plot_matrix(data.adjacency_tensor, node_order=node_order, ax=axarr[0,i], title=f"True: {algo}", vmax=1e-3, vmin=0)
    viz.plot_matrix(Y_pred[fold][algo], node_order=node_order, ax=axarr[1,i], title=f"Pred: {algo}", vmin=0)

plt.tight_layout()

In [None]:
df_pred = pd.concat([cvtl.get_prediction_results(data, params_cv[fold], cv_mask,fold=fold,Y_pred=Y_pred[fold]) for fold in cv_mask.keys()])
df_pred.head(n=10)

In [None]:
df_pred_mean = df_pred.groupby(by='algo').agg('mean').drop(columns=['fold']).reset_index()
df_pred_std = df_pred.groupby(by='algo').agg('std').drop(columns=['fold']).reset_index()

metrics = ['auc_test', 	'logL_test', 	'bce_test']
df_pred_mean.style.background_gradient(subset=metrics,cmap=plt.cm.RdYlGn)

In [None]:
c = viz.default_colors[0]
L = len(metrics)

algos = list(df_pred_mean['algo'].unique())
xticks = np.arange(len(algos))

# One panel per metric: mean across folds, with the std as error bar
fig, axs = plt.subplots(1, L, figsize=(12, 4), sharex=True)
for panel, m in zip(axs, metrics):
    fold_mean = df_pred_mean[m]
    fold_std = df_pred_std[m]
    panel.scatter(xticks, fold_mean, s=200, c=c, edgecolor='black')
    panel.errorbar(xticks, fold_mean, yerr=fold_std, linewidth=1, capsize=4, capthick=1, color=c)
    panel.set_xlabel('Model')
    panel.set_ylabel(m)
    panel.set_xticks(xticks, algos)
plt.tight_layout()


# The computed file name is immediately overridden with None before saving
filename = tl.get_filename('wto_cv_example', lecture_id=lecture_id)
filename = None
tl.savefig(plt, outfile=filename, outdir=outdir_fig)


In [None]:
c = viz.default_colors[0]

m = 'auc_test'     # metric compared fold by fold
method1 = 'vi'     # reference method (x axis of every panel)
# Keep the order in which the algorithms appear in df_pred: going through a
# set made the panel order vary from run to run.
methods = [a for a in df_pred.algo.unique() if a != method1]
L = len(methods)

xlim = (df_pred[m].min() * 0.9, df_pred[m].max() * 1.05)
mask1 = df_pred.algo == method1
y_ref = df_pred[mask1].reset_index()

# squeeze=False keeps a 2-D axes array, so this also works when L == 1
fig, axs = plt.subplots(1, L, figsize=(8,3), sharex=True, squeeze=False)
for ax, other in zip(axs[0], methods):
    mask2 = df_pred.algo == other
    y_comp = df_pred[mask2].reset_index()

    # NOTE(review): assumes both frames list the folds in the same order -- confirm.
    # Blue: the reference method is at least as good on that fold; red: it is worse.
    mask_c = y_ref[m] >= y_comp[m]
    ax.scatter(y_ref[m][mask_c], y_comp[m][mask_c], s=100, c='b', edgecolor='black')
    ax.scatter(y_ref[m][~mask_c], y_comp[m][~mask_c], s=100, c='r', edgecolor='black')
    ax.set_xlabel(f"{m} {method1}")
    ax.set_ylabel(f"{m} {other}")

    ax.set_xlim(xlim)
    ax.set_ylim(xlim)

    # diagonal y = x as visual reference
    xs = np.linspace(xlim[0], xlim[1])
    ax.plot(xs, xs, ls='--', alpha=0.8, color='darkgrey', lw=1)

plt.tight_layout()

# filename = tl.get_filename(f'wto_cv_example_fold_by_fold', lecture_id=lecture_id)
filename = None
tl.savefig(plt, outfile=filename, outdir=outdir_fig)