In [5]:
import jax

import jax.numpy as np
import jax.scipy as sp

import numpy as onp
import os
import paragami

import matplotlib.pyplot as plt
%matplotlib inline

from bnpmodeling_runjingdev import log_phi_lib, cluster_quantities_lib
import bnpmodeling_runjingdev.functional_sensitivity_lib as func_sens_lib
import bnpmodeling_runjingdev.exponential_families as ef

from structure_vb_lib import structure_model_lib, plotting_utils
from structure_vb_lib import structure_optimization_lib as s_optim_lib
from structure_vb_lib.data_utils import cluster_admix_get_indx

import re


In [2]:
import numpy as onp 
onp.random.seed(53453)

# File paths

In [7]:
# scratch_dir = '/scratch/users/genomic_times_series_bnp/structure/' # scratch directory
# out_folder = scratch_dir + 'hgdp_fits/' # folder where fits are found

# out_filename = 'huang2011_fit' # name of fit

data_file='../data/phased_HGDP+India+Africa_2810SNPs-regions1to36.bed'
out_filename='huang2011_fit'
out_folder='../fits/hgdp_fits/'

In [9]:
# Initial fit and linear response derivatives. Both files live in
# `out_folder` and share a prefix keyed by the DP prior parameter alpha0.
alpha0 = 6.0

fit_prefix = out_folder + out_filename + '_alpha' + str(alpha0)
init_fit_file = fit_prefix + '.npz'
lr_file = fit_prefix + '_lrderivatives.npz'

print(init_fit_file)
print(lr_file)

# Check the paths exist. Each assert names the missing file: a bare assert
# only produces an empty `AssertionError:` with no hint of which path failed.
for _path in (init_fit_file, lr_file): 
    assert os.path.isfile(_path), 'file not found: ' + _path

../fits/hgdp_fits/huang2011_fit_alpha6.0.npz
../fits/hgdp_fits/huang2011_fit_alpha6.0_lrderivatives.npz


AssertionError: 

# Load initial fit and meta data

In [None]:
print('initial fit file: ', init_fit_file)

# load_structure_fit returns seven objects; unpack them in one target list
(vb_opt_dict, 
 vb_params_paragami, 
 prior_params_dict, 
 prior_params_paragami, 
 gh_loc, 
 gh_weights, 
 init_fit_meta_data) = structure_model_lib.load_structure_fit(init_fit_file)

# from here on, alpha0 is the value stored with the fit
alpha0 = prior_params_dict['dp_prior_alpha']

# the initial free parameters, flattened into a single vector
vb_opt = vb_params_paragami.flatten(vb_opt_dict, free = True)

print('Init optim time: {:.3f}secs'.format(init_fit_meta_data['optim_time']))

`vb_opt_dict` contains the parameters at the initial ($\epsilon = 0$) fit. 

The parameter dictionary contains two keys: 
- `pop_freq_beta_params`: a (`n_loci` x `k_approx` x 2) array of Beta parameters that describe the approximate posterior on population allele frequencies. (these are the "cluster centroids").
- `ind_admix_params`: contains the logit-normal parameters on the stick-breaking proportions. This is itself a dictionary with two keys, `stick_means` and `stick_infos`. Each is an (`n_obs` x `k_approx - 1`) array. 

In [None]:
print(vb_params_paragami)

The `ind_admix_params` are of particular importance, as they define inferred admixture proportions. 

The expected individual mixture proportions is an array of shape `n_obs x k_approx`, where the $(n,k)$-th entry is the posterior probability of individual $n$ belonging to cluster $k$. 

In [None]:
def get_e_ind_admix(vb_free_params): 
    """Return the expected individual mixture proportions.

    `vb_free_params` is a flat vector of free variational parameters. It is
    folded with `vb_params_paragami`, and the stick means / infos stored
    under 'ind_admix_params' are passed, together with `gh_loc` and
    `gh_weights`, to `cluster_quantities_lib.get_e_cluster_probabilities`.
    """
    admix_params = \
        vb_params_paragami.fold(vb_free_params, free = True)['ind_admix_params']

    return cluster_quantities_lib.get_e_cluster_probabilities(
                admix_params['stick_means'], 
                admix_params['stick_infos'], 
                gh_loc, 
                gh_weights)
    


In [None]:
e_ind_admix0 = get_e_ind_admix(vb_opt)

print(e_ind_admix0.shape)

In [None]:
# plot expected mixture proportions

fig, ax = plt.subplots(1, 1, figsize = (6, 3))

# top 6 clusters are colored ... remaining clusters are grey
# height of bar represents mixture proportions
plotting_utils.plot_top_clusters(e_ind_admix0, ax, n_top_clusters = 6); 

# Define perturbation

This class contains a collection of perturbations. 

In [None]:
f_obj_all = log_phi_lib.LogPhiPerturbations(vb_params_paragami, 
                                            prior_params_dict['dp_prior_alpha'],
                                            gh_loc, 
                                            gh_weights,
                                            stick_key = 'ind_admix_params')

Choose a perturbation. Available perturbations are: 

- 'worst-case', the worst-case L-inf perturbation
- 'sigmoidal', the perturbation where log-phi is a sigmoid function
- 'sigmoidal_neg', the perturbation where log-phi is the negative sigmoid function
- 'alpha_pert_pos', the multiplicative functional perturbation to move the DP alpha parameter by +5
- 'alpha_pert_neg', the multiplicative functional perturbation to move the DP alpha parameter by -5
- 'alpha_pert_pos_xflip', same as alpha_pert_pos but reflected about x = 0
- 'alpha_pert_neg_xflip', same as alpha_pert_neg but reflected about x = 0
- 'gauss_pert1_pos', a positive gaussian bump at logit_v = 0
- 'gauss_pert2_pos', a positive gaussian bump at logit_v = -3
- 'gauss_pert1_neg', a negative gaussian bump at logit_v = 0
- 'gauss_pert2_neg', a negative gaussian bump at logit_v = -3

In [None]:
# choose a perturbation
perturbation = 'alpha_pert_pos'
f_obj = getattr(f_obj_all, 'f_obj_' + perturbation)

`f_obj` contains a few important methods for fitting and computing the linear response. For this notebook, the important one (really just for plotting) is the `log_phi` method, which returns the log of the multiplicative perturbation. 

(Other methods include expectations of log-phi.) 

In [None]:
log_phi = f_obj.log_phi

In [None]:
# The log-phi
# (we will plot the effect on the prior later)
logit_v_grid = np.linspace(-8, 8, 100)
plt.plot(logit_v_grid, log_phi(logit_v_grid))

# Define linear sensitivity

In [None]:
print('loading derivatives from: ', lr_file)
lr_data = np.load(lr_file)

# Sanity checks: the derivatives must have been computed at the initial fit
# loaded above. Each assert has a message so a failure names the mismatch.
assert lr_data['alpha0'] == alpha0, \
    'alpha0 in the derivative file does not match the initial fit'
assert np.abs(lr_data['vb_opt'] - vb_opt).max() < 1e-12, \
    'optimum in the derivative file does not match the initial fit'
assert np.abs(lr_data['kl'] - init_fit_meta_data['final_kl']) < 1e-8, \
    'KL in the derivative file does not match the initial fit'

In [None]:
dinput_dfun = lr_data['dinput_dfun_' + perturbation]

In [None]:
print('Derivative time: {:.02f} secs'.format(lr_data['lr_time_' + perturbation]))
print('CG tolerance: ', lr_data['cg_tol'])

# Load refits

In [None]:
# Refit files are named <out_filename>_<perturbation><digits>.npz
# - raw string: '\d' in a plain string literal is an invalid escape sequence
# - re.escape: the name parts must match literally, as must the '.' in '.npz'
# - re.fullmatch: reject names with trailing text (e.g. '...1.npz.bak')
match_crit = re.escape(out_filename + '_' + perturbation) + r'\d+\.npz'
refit_files = [f for f in os.listdir(out_folder) if re.fullmatch(match_crit, f)]
    
assert len(refit_files) > 0, 'no refit files found'

In [None]:
def load_refit_files(refit_files): 
    """Load every refit named in `refit_files` (file names inside `out_folder`).

    Returns `(vb_refit_list, epsilon_vec, optim_time_vec, delta)`. The first
    three are arrays sorted by increasing epsilon; they include the initial
    fit (`vb_opt`, entered with epsilon = 0) alongside the refits. `delta`
    holds the unique 'delta' values found in the refit metadata and is
    asserted to have exactly one entry.
    """
    # every list starts with the initial fit, which is the epsilon = 0 entry
    refit_free_params = [vb_opt]
    epsilons = [0.]
    optim_times = [init_fit_meta_data['optim_time']]
    deltas = []

    for refit_file in refit_files: 

        print('loading fit from: ', refit_file)
        refit_dict, refit_paragami, refit_meta = \
            paragami.load_folded(out_folder + refit_file)

        refit_free = refit_paragami.flatten(refit_dict, free = True)

        # save the metadata we need
        optim_times.append(refit_meta['optim_time'])
        epsilons.append(refit_meta['epsilon'])
        refit_free_params.append(refit_free)

        # for the sigmoidal perturbation, the original function was scaled 
        # by a factor delta; record it for each refit
        deltas.append(refit_meta['delta'])

    optim_time_vec = np.array(optim_times)
    epsilon_vec = np.array(epsilons)
    vb_refit_list = np.array(refit_free_params)

    # all the delta's should be the same ... 
    delta = np.unique(np.array(deltas))
    print('delta = ', delta)
    assert len(delta) == 1

    # reorder everything by increasing epsilon
    order = np.argsort(epsilon_vec)

    return vb_refit_list[order], epsilon_vec[order], optim_time_vec[order], delta

In [None]:
vb_refit_list, epsilon_vec, optim_time_vec, delta = \
    load_refit_files(refit_files)

`vb_refit_list` contains the free parameters from each refit. 

`vb_refit_list[i]` corresponds to the refit with `epsilon = epsilon_vec[i]`. 

# Get linear approximation to the refits

In [None]:
def predict_opt_par_from_hyper_par(epsilon): 
    """Linear-response prediction of the free parameters at `epsilon`."""
    # first-order step away from the initial optimum vb_opt
    lr_step = dinput_dfun * epsilon * delta
    return vb_opt + lr_step

# one linear-response prediction per entry of epsilon_vec, in the same order
lr_list = []
for epsilon in epsilon_vec: 
    lr_list += [predict_opt_par_from_hyper_par(epsilon)]


`lr_list` contains the free parameters from the linear approximation. 

`lr_list[i]` corresponds to the predicted free parameters for `epsilon = epsilon_vec[i]`. 

# Plot priors 

In [None]:
prior_perturbation = func_sens_lib.PriorPerturbation(
                                    alpha0 = alpha0,
                                    log_phi = lambda x : f_obj.log_phi(x) * delta, 
                                    logit_v_ub = 10, 
                                    logit_v_lb = -10)

prior_perturbation.plot_perturbation();

# Plot vb free parameters

In [None]:
def print_diff_plot(x, y, x0, ax, color = 'red', alpha = 1.): 
    """Plot `y - x0` against `x - x0` on `ax`, together with the identity line.

    The points are drawn as '+' markers in `color`. The identity line, 
    `x - x0` against itself, is drawn in blue. Both use transparency `alpha`.
    """
    x_diff = x - x0
    y_diff = y - x0

    ax.plot(x_diff, y_diff, '+', color = color, alpha = alpha)
    # reference line: points on it are those where x and y agree
    ax.plot(x_diff, x_diff, '-', color = 'blue', alpha = alpha)

In [None]:
def plot_params_diff_plot(refit_free_params, lr_free_params, init_free_params): 
    """Compare refit and linear-response parameters, both relative to the initial fit.

    Draws one panel per parameter block (population frequency parameters, 
    stick means, stick infos). Each panel plots `refit - init` against 
    `lr - init` via `print_diff_plot`.
    """
    # Store the free parameters into dictionaries. 
    # free = False: the free values themselves are plotted (no constraining).
    init_dict, lr_dict, refit_dict = [
        vb_params_paragami.fold(free_params, free = False)
        for free_params in (init_free_params, lr_free_params, refit_free_params)]

    # (panel title, accessor that picks the block out of a folded dictionary)
    panels = [
        ('pop freq. parameters', lambda d: d['pop_freq_beta_params']), 
        ('stick means', lambda d: d['ind_admix_params']['stick_means']), 
        ('stick infos', lambda d: d['ind_admix_params']['stick_infos'])]

    fig, ax = plt.subplots(1, 3, figsize = (12, 3))

    for axis, (title, get_block) in zip(ax, panels): 
        axis.set_title(title)
        print_diff_plot(get_block(lr_dict).flatten(), 
                        get_block(refit_dict).flatten(), 
                        get_block(init_dict).flatten(), 
                        axis)
        axis.set_xlabel('lr - init')
        axis.set_ylabel('refit - init')

    fig.tight_layout()

### At smallest epsilon

In [None]:
# the smallest epsilon
epsilon_vec[1]

In [None]:
plot_params_diff_plot(vb_refit_list[1], lr_list[1], vb_refit_list[0])

### at epsilon = 1

In [None]:
# if all the refits completed, 
# the last epsilon should be epsilon = 1
epsilon_vec[-1]

In [None]:
plot_params_diff_plot(vb_refit_list[-1], lr_list[-1], vb_refit_list[0])

# Expected number of clusters

In [None]:
def get_e_num_clusters(vb_free_params): 
    """Expected number of clusters implied by the free parameters.

    The expected individual mixture proportions from `get_e_ind_admix` are 
    passed to `cluster_quantities_lib.get_e_num_clusters_from_ez`.
    """
    # get expected mixture proportions
    e_ind_admix = get_e_ind_admix(vb_free_params)

    # get number of clusters from expected mixture proportions
    return cluster_quantities_lib.get_e_num_clusters_from_ez(e_ind_admix)

In [None]:
def get_all_stats(param_list, get_stat): 
    """Apply `get_stat` to every entry of `param_list` and stack the results.

    Returns an array whose leading axis follows the order of `param_list`.
    """
    return np.array([get_stat(params) for params in param_list])

In [None]:
refit_color = '#d95f02'
lr_color = '#1b9e77'

In [None]:
refit_e_num_clusters_vec = get_all_stats(vb_refit_list, get_e_num_clusters)
lr_e_num_clusters_vec = get_all_stats(lr_list, get_e_num_clusters)

# expected number of clusters against epsilon: refits, then linear approximation
for _e_num_clusters_vec, _color, _label in [
        (refit_e_num_clusters_vec, refit_color, 'refit'), 
        (lr_e_num_clusters_vec, lr_color, 'linear approx.')]: 
    plt.plot(epsilon_vec, 
             _e_num_clusters_vec, 
             'o-', 
             color = _color, 
             label = _label)

plt.legend()

plt.xlabel('Epsilon')
plt.ylabel('Expected n. clusters')

# Plot inferred admixtures

In [None]:
n_obs = e_ind_admix0.shape[0]
k_approx = e_ind_admix0.shape[1] 

In [None]:
# plot expected mixture proportions

fig, ax = plt.subplots(3, 1, figsize = (6, 9))

# if full data, subsample 
if n_obs > 25: 
    n_indx = onp.random.choice(n_obs, 25)
    n_indx = np.sort(n_indx)
else: 
    n_indx = np.arange(25)

# top 6 clusters are colored ... remaining clusters are grey
# height of bar represents mixture proportions
ax[0].set_title('initial fit')
plotting_utils.plot_top_clusters(e_ind_admix0[n_indx], ax[0], n_top_clusters = 6);

ax[1].set_title('refit at epsilon = 1')
e_ind_admix_refit = get_e_ind_admix(vb_refit_list[-1])
plotting_utils.plot_top_clusters(e_ind_admix_refit[n_indx], ax[1], n_top_clusters = 6);

ax[2].set_title('LR at epsilon = 1')
e_ind_admix_lr = get_e_ind_admix(lr_list[-1])
plotting_utils.plot_top_clusters(e_ind_admix_lr[n_indx], ax[2], n_top_clusters = 6);

fig.tight_layout()


### Trace plots for some subsampled individuals

In [None]:
def make_admix_trace_plots(stick_fun, n_indx, k_indx, ylab = ''): 
    """Trace plots, across epsilon, of a per-individual per-cluster statistic.

    One figure is drawn per individual in `n_indx`, with one panel per 
    cluster in `k_indx`. Each panel overlays the linear-response values 
    (`lr_list`) and the refit values (`vb_refit_list`) against `epsilon_vec`.

    Parameters
    ----------
    stick_fun : callable
        Takes a vector of vb free parameters and returns an 
        n_obs x k_approx array of summary statistics 
        (e.g. stick means or mixture proportions).
    n_indx : sequence of int
        Individuals to plot.
    k_indx : sequence of int
        Clusters to plot. These need not be 0, 1, 2, ...
    ylab : str
        y-axis label, drawn on the first panel of each figure only.
    """
    
    lr_array = get_all_stats(lr_list, stick_fun)
    refit_array = get_all_stats(vb_refit_list, stick_fun)

    n_figs = len(n_indx)

    for fig_num, n in enumerate(n_indx): 
        # squeeze = False keeps `ax` two-dimensional even for a single panel
        fig, ax = plt.subplots(1, len(k_indx), figsize = (18, 3), squeeze = False)

        # panels are addressed by position, not by the cluster label k
        for col, k in enumerate(k_indx): 
            panel = ax[0, col]

            panel.plot(epsilon_vec, lr_array[:, n, k], '-x', color = lr_color, label = 'lr')
            panel.plot(epsilon_vec, refit_array[:, n, k], '-x', color = refit_color, label = 'refit')
            
            panel.set_title('n = {}; k = {}'.format(n,k))

            if col == 0: 
                panel.legend()
                panel.set_ylabel(ylab)
                
            # only the last figure gets x-axis labels; decided by position 
            # so that a repeated value of n cannot label several figures
            if fig_num == n_figs - 1: 
                panel.set_xlabel('epsilon')
            else: 
                panel.set_xticklabels([])
        
        fig.tight_layout()

In [None]:
# subsample a few individuals, without replacement so nobody is plotted twice
# (sizes are capped by the data so small fits do not raise)
n_obs = e_ind_admix0.shape[0]
n_indx = onp.random.choice(n_obs, min(5, n_obs), replace = False)
k_indx = np.arange(min(6, e_ind_admix0.shape[1]))

make_admix_trace_plots(get_e_ind_admix, n_indx, k_indx)

In [None]:
# refit vs. lr difference in individual admixtures
fig, ax = plt.subplots(1, 2, figsize = (8, 3))
fig.tight_layout(); 

# at smallest epsilon
e_ind_admix_refit = get_e_ind_admix(vb_refit_list[1])
e_ind_admix_lr = get_e_ind_admix(lr_list[1])

print_diff_plot(e_ind_admix_lr, 
                e_ind_admix_refit, 
                e_ind_admix0, 
                ax[0])

ax[0].set_title('Epsilon = {:.05f}'.format(epsilon_vec[1]))


# at epsilon = 1
e_ind_admix_refit = get_e_ind_admix(vb_refit_list[-1])
e_ind_admix_lr = get_e_ind_admix(lr_list[-1])

print_diff_plot(e_ind_admix_lr, 
                e_ind_admix_refit, 
                e_ind_admix0, 
                ax[1])

ax[1].set_title('Epsilon = {:.05f}'.format(epsilon_vec[-1]))

# Plot cluster weights

In [None]:
# this is len(epsilon_vec) x n_obs x k_approx
e_ind_admix_lr_array = get_all_stats(lr_list, get_e_ind_admix)
e_ind_admix_refit_array = get_all_stats(vb_refit_list, get_e_ind_admix)

cluster_weights_lr = e_ind_admix_lr_array.sum(1)
cluster_weights_refit = e_ind_admix_refit_array.sum(1)

In [None]:
# One panel per cluster, laid out row by row on a grid with n_cols columns.
# The number of rows follows k_approx instead of being fixed at 5.
k_approx = e_ind_admix0.shape[1]
n_cols = 4
n_rows = -(-k_approx // n_cols) # ceil(k_approx / n_cols)

fig, ax = plt.subplots(n_rows, n_cols,
                       figsize = (18, 3 * n_rows),
                       sharex = True, 
                       squeeze = False)

for k in range(k_approx): 
    x0 = k // n_cols
    x1 = k % n_cols
    
    # each series is drawn with its own color and label
    # (previously the lr weights carried the refit color / label and vice versa)
    ax[x0, x1].plot(epsilon_vec,
                    cluster_weights_refit[:, k],
                    '-x', 
                    color = refit_color, 
                    label = 'refit')
    
    ax[x0, x1].plot(epsilon_vec,
                    cluster_weights_lr[:, k],
                    '-x', 
                    color = lr_color, 
                    label = 'lr')
    ax[x0, x1].set_title('k = {}'.format(k))
    if x0 == n_rows - 1: 
        ax[x0, x1].set_xlabel('epsilon')
    if x1 == 0: 
        ax[x0, x1].set_ylabel('Expected n.ind')
        
    if k == 0: 
        ax[x0, x1].legend()

fig.tight_layout()

# Check out co-clustering

In [None]:
# sort based on initial fit
indx = cluster_admix_get_indx(e_ind_admix0)

def get_co_clustering_mat(ind_admix): 
    """Co-clustering matrix of the rows of `ind_admix`.

    Rows are first reordered by the notebook-level `indx`. Entry (i, j) of 
    the result is the inner product of reordered rows i and j.
    """
    sorted_admix = ind_admix[indx]
    return np.dot(sorted_admix, sorted_admix.transpose())


def plot_co_clustering_mat(coclust_mat, ax, fig): 
    """Draw `coclust_mat` on `ax` as a 'Blues' heatmap with a colorbar.

    The color scale is fixed to [0, 1]. Ticks are removed from both axes.
    """
    blues = plt.get_cmap('Blues')
    im = ax.matshow(coclust_mat, vmax = 1, vmin = 0, cmap = blues)
    fig.colorbar(im, ax = ax)

    # hide the tick marks on both axes
    for set_ticks in (ax.set_xticks, ax.set_yticks): 
        set_ticks([])

### Plot co-clustering

In [None]:
fig, ax = plt.subplots(1, 3, figsize = (12, 3))
fig.tight_layout()

ax[0].set_title('init')
coclust_mat0 = get_co_clustering_mat(e_ind_admix0)
plot_co_clustering_mat(coclust_mat0, ax[0], fig)

ax[1].set_title('refit')
e_ind_admix_refit = get_e_ind_admix(vb_refit_list[-1])
coclust_mat_refit = get_co_clustering_mat(e_ind_admix_refit)
plot_co_clustering_mat(coclust_mat_refit, ax[1], fig)

ax[2].set_title('lr')
e_ind_admix_lr = get_e_ind_admix(lr_list[-1])
coclust_mat_lr = get_co_clustering_mat(e_ind_admix_lr)
plot_co_clustering_mat(coclust_mat_lr, ax[2], fig)

### Print differences in co-clustering

In [None]:
fig, ax = plt.subplots(1, 2, figsize = (8, 3))
fig.tight_layout()

vmax = 0.1

ax[0].set_title('refit - init')
im0 = ax[0].matshow(coclust_mat_refit - coclust_mat0, 
              cmap = plt.get_cmap('bwr'), 
              vmax = vmax, vmin = -vmax)
fig.colorbar(im0, ax = ax[0])


ax[1].set_title('lr - init')
im0 = ax[1].matshow(coclust_mat_lr - coclust_mat0, 
              cmap = plt.get_cmap('bwr'), 
              vmax = vmax, vmin = -vmax)
fig.colorbar(im0, ax = ax[1])

In [None]:
fig, ax = plt.subplots(1, 1, figsize = (4, 3))
fig.tight_layout(); 

print_diff_plot(coclust_mat_lr.flatten(), 
                coclust_mat_refit.flatten(), 
                coclust_mat0.flatten(), 
                ax = ax)

# Study the population frequencies

In [None]:
def get_e_pop_freq(vb_free_params): 
    """Expected population frequencies from the free parameters.

    Folds `vb_free_params` and, for every entry of 'pop_freq_beta_params', 
    returns a / (a + b), where a and b are the two values along the last axis.
    """
    beta_params = \
        vb_params_paragami.fold(vb_free_params, free = True)['pop_freq_beta_params']

    beta_a = beta_params[:, :, 0]
    beta_b = beta_params[:, :, 1]

    return beta_a / (beta_a + beta_b)

In [None]:
# inferred population frequency at initial fit
e_pop_freq0 = get_e_pop_freq(vb_refit_list[0])

In [None]:
# refit vs. lr difference in expected population frequencies
fig, ax = plt.subplots(1, 2, figsize = (8, 3))
fig.tight_layout(); 

# left panel: smallest epsilon (index 1; index 0 is the initial fit)
# right panel: largest epsilon (epsilon = 1 if all the refits completed)
for _panel, _eps_indx in enumerate([1, -1]): 
    e_pop_freq_refit = get_e_pop_freq(vb_refit_list[_eps_indx])
    e_pop_freq_lr = get_e_pop_freq(lr_list[_eps_indx])

    print_diff_plot(e_pop_freq_lr, 
                    e_pop_freq_refit, 
                    e_pop_freq0, 
                    ax[_panel])

    ax[_panel].set_title('Epsilon = {:.05f}'.format(epsilon_vec[_eps_indx]))



### For plotting centroids,  run PCA on population frequencies

In [None]:
# Import the submodule explicitly: a bare `import sklearn` is not guaranteed
# to make `sklearn.decomposition` available as an attribute.
import sklearn.decomposition

# two principal components, fit on the initial expected population 
# frequencies (transposed before fitting)
pca_model = sklearn.decomposition.PCA(n_components = 2)
pca_model.fit(e_pop_freq0.transpose())

In [None]:
# Plot initial centroids
e_pop_freq0_transformed = pca_model.transform(e_pop_freq0.transpose())
plt.scatter(e_pop_freq0_transformed[:, 0], 
            e_pop_freq0_transformed[:, 1])

In [None]:
e_pop_freq0_transformed.shape

### Plot centroids across epsilon

In [None]:
def get_pca_centroids_from_free(vb_free):
    """Project the expected population frequencies onto the fitted PCA axes.

    The frequencies from `get_e_pop_freq` are transposed before being passed 
    to `pca_model.transform`, matching how `pca_model` was fit.
    """
    return pca_model.transform(get_e_pop_freq(vb_free).transpose())

In [None]:
# this is len(epsilon_vec) x k_approx x n_pca_components
centroids_refit_array = get_all_stats(vb_refit_list, get_pca_centroids_from_free)
centroids_lr_array = get_all_stats(lr_list, get_pca_centroids_from_free)

In [None]:
fig, ax = plt.subplots(5, 4,
                       figsize = (18, 15))

k_approx = e_ind_admix0.shape[1]

for k in range(k_approx): 
    x0 = k // 4
    x1 = k % 4
    
    ax[x0, x1].plot(centroids_refit_array[:, k, 0], 
                    centroids_refit_array[:, k, 1], 
                    '-x', 
                    color = refit_color, 
                    label = 'refit')
    
    ax[x0, x1].plot(centroids_lr_array[:, k, 0], 
                    centroids_lr_array[:, k, 1], 
                    '-x', 
                    color = lr_color, 
                    label = 'lr')
    
    ax[x0, x1].set_title('k = {}'.format(k))
    ax[x0, x1].set_ylabel('PC2')
    ax[x0, x1].set_xlabel('PC1')

ax[0, 0].legend()
fig.tight_layout(); 

In [None]:
def get_cosine_sim_from_vb_free(vb_free): 
    """Pairwise cosine similarity between columns of the expected population frequencies.

    Every column of `get_e_pop_freq(vb_free)` is scaled to unit Euclidean 
    norm; the result is the matrix of inner products between scaled columns.
    """
    e_pop_freq = get_e_pop_freq(vb_free)

    # Euclidean norm of each column, kept 2-d so that it broadcasts below
    col_norms = np.sqrt(np.sum(e_pop_freq**2, axis = 0, keepdims = True))
    unit_cols = e_pop_freq / col_norms

    return np.dot(unit_cols.transpose(), unit_cols)

In [None]:
fig, ax = plt.subplots(1, 3, figsize = (12, 3))
fig.tight_layout()

# cosine similarity at the initial fit, at the last refit, and at the 
# linear-response prediction for the last epsilon
# (the unused `e_pop_freq_lr` computation has been dropped)
_cos_sim_panels = [('init', vb_refit_list[0]), 
                   ('refit', vb_refit_list[-1]), 
                   ('lr', lr_list[-1])]

for _axis, (_title, _vb_free) in zip(ax, _cos_sim_panels): 
    _axis.set_title(_title)
    plot_co_clustering_mat(get_cosine_sim_from_vb_free(_vb_free), _axis, fig)


### Trace plot of how first row of this similarity matrix changes across epsilon

In [None]:
# first dimension is of length epsilon_vec
cos_sim_refit_array = get_all_stats(vb_refit_list, get_cosine_sim_from_vb_free)
cos_sim_lr_array = get_all_stats(lr_list, get_cosine_sim_from_vb_free)

In [None]:
coclust_row = 0

In [None]:
# One panel per cluster k: the cosine similarity between cluster k and 
# cluster `coclust_row`, traced across epsilon. 
# The number of rows follows k_approx instead of being fixed at 5.
k_approx = e_ind_admix0.shape[1]
n_cols = 4
n_rows = -(-k_approx // n_cols) # ceil(k_approx / n_cols)

fig, ax = plt.subplots(n_rows, n_cols, 
                       figsize = (18, 3 * n_rows), 
                       squeeze = False)

for k in range(k_approx): 
    x0 = k // n_cols
    x1 = k % n_cols
    
    ax[x0, x1].plot(epsilon_vec, 
                    cos_sim_refit_array[:, k, coclust_row], 
                    '-x', 
                    color = refit_color, 
                    label = 'refit')
    
    ax[x0, x1].plot(epsilon_vec, 
                    cos_sim_lr_array[:, k, coclust_row], 
                    '-x', 
                    color = lr_color, 
                    label = 'lr')
    ax[x0, x1].set_title('k = {}'.format(k))
    if x0 == n_rows - 1: 
        ax[x0, x1].set_xlabel('epsilon')
    if x1 == 0: 
        # these panels show cosine similarities, not expected counts
        ax[x0, x1].set_ylabel('cosine similarity')

# a single legend, on the first panel
ax[0, 0].legend()
fig.tight_layout(); 