In [None]:
import pickle

import numpy as np
import pandas as pd
import scipy.io
import scipy.sparse
import seaborn as sns; sns.set(color_codes=True)
import torch
from scipy.stats.stats import pearsonr
from tqdm.notebook import tqdm, trange

from iGEM import iGEM
from iGEM_fixed_w import iGEM_fixed_w


In [None]:
def _feature_name_frame(gene_name_path):
    """Build the one-column ``Name`` frame used to label heatmap rows.

    ``gene_name_path`` may be:

    * one of the keywords ``'mRNA'``, ``'methylation'`` or ``'geneset'``, in
      which case the names are read from the notebook-level ``dataset`` dict
      under ``'<keyword>_ref'``;
    * any other string, which is read as a header-less, comma-separated file
      of strings;
    * an already-loaded single-column table of names (anything
      ``pd.DataFrame`` accepts). It is copied, so the caller's object is
      never modified.

    Lookup and read errors propagate unchanged instead of being swallowed.
    """
    dataset_keys = {
        'mRNA': 'mRNA_ref',
        'methylation': 'methylation_ref',
        'geneset': 'geneset_ref',
    }
    if isinstance(gene_name_path, str):
        if gene_name_path in dataset_keys:
            names = pd.DataFrame(dataset[dataset_keys[gene_name_path]])
            if gene_name_path == 'methylation':
                # Only the methylation reference is cast to str here.
                names = names.astype(str)
        else:
            names = pd.read_csv(gene_name_path, sep=',', dtype='str', header=None)
    else:
        names = pd.DataFrame(gene_name_path).copy()
    # Raises ValueError if the table does not have exactly one column.
    names.columns = ['Name']
    return names


def gene_topic(beta, gene_name_path, save_name, top_gene_num):
    """Draw a heatmap of the top-weighted features of every topic.

    For each column ("topic") of ``beta`` the ``top_gene_num`` rows with the
    largest value in that column are selected; the selections are stacked
    topic after topic and shown as an unclustered seaborn heatmap (diverging
    ``RdBu_r`` colour map centred at 0) with a horizontal line after each
    topic's group of rows. Topics are labelled 1..K on the x axis.

    Parameters
    ----------
    beta : array-like, shape (n_features, n_topics)
        Feature-by-topic weights; converted to a float DataFrame.
    gene_name_path : str or table
        Source of the row labels, see ``_feature_name_frame``.
    save_name : str
        Currently unused: the ``savefig`` call that consumed it is commented
        out at the end of the function.
    top_gene_num : int
        Number of rows to keep per topic.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``beta`` has no columns.

    Notes
    -----
    Calls ``sns.set(font_scale=1.4)``, which changes the global seaborn
    style. The figure size and separator positions are derived from ``beta``
    itself rather than from the notebook-level ``params_dict``.
    """
    weights = pd.DataFrame(beta).astype(float)
    n_topics = weights.shape[1]
    if n_topics == 0:
        raise ValueError("beta must contain at least one topic column")
    weights.columns = [f'{i}' for i in range(n_topics)]

    names = _feature_name_frame(gene_name_path)
    table = pd.concat([names['Name'], weights], axis=1)

    # Collect the per-topic selections first and concatenate once, instead of
    # growing a DataFrame inside the loop.
    top_blocks = []
    for topic in trange(n_topics):
        ranked = table.sort_values(by=[f'{topic}'], ascending=False)
        top_blocks.append(ranked.iloc[0:top_gene_num, :])
    # May be smaller than top_gene_num when there are fewer features.
    rows_per_topic = len(top_blocks[0])

    stacked = pd.concat(top_blocks).set_index('Name')
    #for names that are too long
    stacked.index = [
        label[:40] + "..." if len(label) > 40 else label
        for label in map(str, stacked.index)
    ]
    # Show topics 1-based on the x axis.
    stacked.columns = [int(col) + 1 for col in stacked.columns]

    sns.set(font_scale=1.4)
    #dendrogram_ratio for controlling space around
    g = sns.clustermap(
        stacked,
        figsize=(20, 3 * n_topics * rows_per_topic / 10),
        row_cluster=False,
        col_cluster=False,
        cmap="RdBu_r",
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_pos=(0, .2, .03, .4),
        dendrogram_ratio=(0.075, 0.001),
    )
    g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_ymajorticklabels(), fontsize=18)
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize=18)
    # One separator after each topic's group of rows.
    g.ax_heatmap.hlines(
        [(i + 1) * rows_per_topic for i in range(n_topics)],
        *g.ax_heatmap.get_xlim()
    )
    #g.savefig(save_name + '_feature_topic.pdf')
        
def normalize_sum(x):
    """Divide ``x`` by its sums along dimension 0 and zero out NaN results.

    The sums are taken over dimension 0 and broadcast back across ``x``.
    Entries that come out as NaN (for example 0/0 when a sum is zero) are
    replaced by 0. A new tensor is returned; ``x`` itself is not modified.
    """
    totals = x.sum(0, keepdim=True)[0]
    ratio = x / totals
    ratio.masked_fill_(torch.isnan(ratio), 0)
    return ratio

In [None]:
# Load the preprocessed dataset.
# NOTE: pickle.load can execute arbitrary code; only use a file you trust.
data_path = 'processed_dataset.pkl'
with open(data_path, 'rb') as handle:
    dataset = pickle.load(handle)

#mouse rnaseq single section
#params_dict = set_model_param()
params_dict = {}
params_dict['rho1'] = dataset['rho1'].to(torch.float)
#params_dict['rho2'] = mouse_omic['rho2'].to(torch.float)
true_rho = params_dict['rho1']

# Model inputs as float tensors, plus NumPy copies of the same two matrices.
X1 = dataset['mRNA'].to(torch.float)
X2 = dataset['methylation'].to(torch.float)

mRNA = dataset['mRNA'].detach().numpy()
met = dataset['methylation'].detach().numpy()

#define mode flags
params_dict['omic_num'] = 2
params_dict['use_alpha'] = [True, False, False]
params_dict['use_poisson'] = [False, False, False]

# Hyperparameters passed through to iGEM; their meaning is defined in the
# iGEM module, which is not part of this notebook.
params_dict['k'] = 10
params_dict['max_iter'] = 3000
params_dict['tol'] = 1e-8
params_dict['pi_aa'] = 0
params_dict['pi_ab'] = 0.01
params_dict['pi_bb'] = 0.01
params_dict['pi_xr'] = 2e-6
params_dict['pi_xc'] = 0.2
params_dict['xc_h2_coef'] = 0.01
params_dict['xc_alpha1_coef'] = 0.05
params_dict['prev_model'] = None

#mode week8 NRES tf-gene 0.4/0.4
# Optional pickled interaction matrices. All three are disabled (None) here,
# so all-zero matrices of the matching shape are passed to the model. To
# enable one, replace None with the commented path above it.
#aa_path = 'SNMNMF/rnaseq_interaction_sparse_matrix_deseq1639.pkl'
aa_path = None
#ab_path = 'SNMNMF/bader_met_embedding_omic_data_0610.pkl'
ab_path = None
#bb_path = 'SNMNMF/rnaseq_interaction_sparse_matrix_deseq5000.pkl'
bb_path = None

if aa_path is None:
    aa = torch.zeros(X1.shape[1], X1.shape[1])
else:
    with open(aa_path, 'rb') as handle:
        aa = pickle.load(handle).to_dense()
if ab_path is None:
    ab = torch.zeros(X1.shape[1], X2.shape[1])
else:
    with open(ab_path, 'rb') as handle:
        ab = pickle.load(handle).to_dense()
if bb_path is None:
    bb = torch.zeros(X2.shape[1], X2.shape[1])
else:
    with open(bb_path, 'rb') as handle:
        bb = pickle.load(handle).to_dense()

model = iGEM(X1, X2, aa, ab, bb, params_dict)
model.train()

# Heatmap of the 10 highest-weighted 'mRNA' features for each column of the
# column-normalised, transposed model.H1.
gene_topic(normalize_sum(model.H1.t().detach()).numpy(), 'mRNA', 'test', 10)

In [None]:
# Heatmap of the 10 highest-weighted 'geneset' features for each column of
# the column-normalised, transposed model.alpha1.
gene_topic(
    normalize_sum(model.alpha1.t().detach()).numpy(),
    'geneset',
    'test',
    10,
)

In [None]:
# Heatmap of the 10 highest-weighted 'mRNA' features for each column of the
# column-normalised, transposed model.H1.
gene_topic(
    normalize_sum(model.H1.t().detach()).numpy(),
    'mRNA',
    'test',
    10,
)

In [None]:
# Heatmap of the 5 highest-weighted 'geneset' features for each column of
# the column-normalised, transposed model.alpha1.
gene_topic(
    normalize_sum(model.alpha1.t().detach()).numpy(),
    'geneset',
    'test',
    5,
)

In [None]:
#test section
# Synthetic check: build X = W @ alpha @ rho from known factors, then fit
# iGEM_fixed_w with W held at its known value so the learned alpha1 can be
# compared with source_alpha in the cells below.

# Seed torch so source_w (and any torch-based randomness inside the model) is
# the same on every Restart & Run All; 2020 matches the scipy random_state.
torch.manual_seed(2020)

params_dict = {}
source_w = torch.rand(111, 10)

# 100 x 100 matrix with 10% of its entries set to 1 (data_rvs=np.ones).
# An earlier permutation-based source_rho was overwritten immediately by this
# assignment and has been removed.
source_rho = scipy.sparse.random(100, 100, density=0.1, random_state=2020, data_rvs=np.ones)
source_rho = torch.from_numpy(source_rho.toarray()).float()
params_dict['rho1'] = source_rho

# Sparse (10% dense) alpha with shape (source_w.shape[1], source_rho.shape[0]).
source_alpha = scipy.sparse.random(source_w.shape[1], source_rho.shape[0], density=0.1, random_state=2020)
source_alpha = torch.from_numpy(source_alpha.toarray()).float()

source_X = torch.mm(torch.mm(source_w, source_alpha), source_rho)
# The same synthetic matrix is passed for both model inputs.
X1 = source_X
X2 = source_X

fixed_model_w = source_w

#define mode flags
params_dict['omic_num'] = 1
params_dict['use_alpha'] = [True, False, True]
params_dict['use_poisson'] = [False, False, False]

# Hyperparameters passed through to iGEM_fixed_w; their meaning is defined in
# that module, which is not part of this notebook.
params_dict['k'] = source_w.shape[1]  # 10, kept in sync with source_w
params_dict['max_iter'] = 500
params_dict['tol'] = 1e-8
params_dict['pi_aa'] = 0
params_dict['pi_ab'] = 0.01
params_dict['pi_bb'] = 0.01
params_dict['pi_xr'] = 2e-6
params_dict['pi_xc'] = 0.2
params_dict['xc_h2_coef'] = 0.01
params_dict['xc_alpha1_coef'] = 0.05

params_dict['prev_model'] = None

# No interaction matrices in the synthetic test: every path is None, so
# all-zero matrices of the matching shape are passed to the model.
aa_path = None
ab_path = None
bb_path = None
if aa_path is None:
    aa = torch.zeros(X1.shape[1], X1.shape[1])
else:
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(aa_path, 'rb') as handle:
        aa = pickle.load(handle).to_dense()
if ab_path is None:
    ab = torch.zeros(X1.shape[1], X2.shape[1])
else:
    with open(ab_path, 'rb') as handle:
        ab = pickle.load(handle).to_dense()
if bb_path is None:
    bb = torch.zeros(X2.shape[1], X2.shape[1])
else:
    with open(bb_path, 'rb') as handle:
        bb = pickle.load(handle).to_dense()

model = iGEM_fixed_w(X1, X2, aa, ab, bb, params_dict, fixed_model_w)
model.train()

In [None]:
# Visual comparison of the alpha used to generate source_X (alpha1_source)
# with the alpha1 learned by the fitted model. Despite its name, alpha1_true
# holds the learned estimate. Both are transposed before plotting.
alpha1_source = source_alpha.t().detach().numpy()
alpha1_true = model.alpha1.t().detach().numpy()
g_source = sns.clustermap(
    alpha1_source,
    figsize=(6, 8),
    row_cluster=False,
    col_cluster=False,
    cmap="RdBu_r",
    xticklabels=True,
    yticklabels=False,
    center=0,
)
g_true = sns.clustermap(
    alpha1_true,
    figsize=(6, 8),
    row_cluster=False,
    col_cluster=False,
    cmap="RdBu_r",
    xticklabels=True,
    yticklabels=False,
    center=0,
)

In [None]:
# Check that the learned alpha1 correlates with the alpha used to generate
# the data: Pearson correlation between every column of alpha1_source
# (topic1) and every column of alpha1_true (topic2), best matches first.
source_matrix, true_matrix = alpha1_source, alpha1_true
score = []
for i in trange(source_matrix.shape[1]):
    for j in range(true_matrix.shape[1]):
        stat, p = pearsonr(source_matrix[:, i], true_matrix[:, j])
        score.append([i, j, stat, p])
score = (
    pd.DataFrame(score, columns=['topic1', 'topic2', 'coef', 'pval'])
    .sort_values(by=['coef'], ascending=False)
    .reset_index(drop=True)
)
score.head(15)