In [1]:
import pandas as pd
import numpy as np
import os
import pickle as pkl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import sys
sys.path.append(os.path.abspath(".."))
import seaborn as sns

try:
    import pubchempy as pcp
    from rdkit import Chem
    from rdkit.Chem import Draw
    import ete3
    import ete4
except:
    pass
import scipy
from scipy.spatial.distance import cityblock, squareform, pdist
from matplotlib.colors import LogNorm, Normalize
from matplotlib.backends.backend_pdf import PdfPages
from scipy.special import expit
from plot_ete_tree import *
from plot_results import *
import torch
from model_helper import *
import re
import torch.nn.functional as F

In [2]:
# Get the fold change of the learned rules
# 1. For each dataset, find the seed(s) with the lowest total loss
# 2. For each such seed, find the rule/detector pair(s) (sequence or metabolite) with the highest absolute log odds
# 3. Compute the fold change between cases and controls for each selected rule

In [3]:
def convert_to_log_normal(mu,var):
    """Moment-match a log-normal distribution to a given mean and variance.

    If X is log-normal with E[X] = mu and Var[X] = var, then log(X) is normal
    with

        var2 = ln(1 + var / mu**2)
        mu2  = ln(mu) - var2 / 2

    Parameters
    ----------
    mu : float or array-like
        Mean on the natural (untransformed) scale. Must be > 0 for a finite
        result; non-positive values propagate as nan/-inf from numpy.
    var : float or array-like
        Variance on the natural scale.

    Returns
    -------
    tuple
        ``(mu2, var2)``: the mean and the *variance* (not the standard
        deviation) of log(X).
    """
    # log1p keeps full precision when var/mu**2 is tiny; log(1 + x) would
    # round 1 + x to exactly 1.0 and report a variance of 0.
    var2 = np.log1p(var/(mu**2))
    mu2 = np.log(mu) - var2/2
    return mu2, var2

In [2]:
# Root directory that is walked for run outputs. The first path component
# below it is used as the dataset name in rd_new.
path = '/Users/jendawk/logs/mditre-real-june6/'


def detector_dtype(seed_path, root):
    """Classify a seed directory from its path.

    Returns 'metabs' when the path contains 'pubchem', 'taxa' when it contains
    '_ra/' or '_cts/', and None when the directory should be skipped (no
    marker matched, or a 'pubchem' path whose run directory contains 'seqs').
    """
    if 'pubchem' in seed_path:
        return None if 'seqs' in root else 'metabs'
    if '_ra/' in seed_path or '_cts/' in seed_path:
        return 'taxa'
    return None


def fold_change_stats(feats, y, case_label, is_metab):
    """Fold change and per-group variances of one detector's feature values.

    Parameters
    ----------
    feats : pandas.Series
        One aggregated feature value per sample.
    y : pandas.Series
        0/1 labels used to build boolean masks for ``feats.loc``.
    case_label : int
        Label (0 or 1) whose samples form the numerator group; the other label
        forms the denominator group.
    is_metab : bool
        When True the ratio is passed through ``np.exp``.

    Returns
    -------
    tuple
        ``(fold_change, var0, var1)`` where ``var0`` is the variance of the
        denominator group and ``var1`` the variance of the numerator group.
    """
    ctrl_label = 1 - case_label
    case_vals = feats.loc[y == case_label]
    ctrl_vals = feats.loc[y == ctrl_label]

    fold_change = abs(np.mean(case_vals)/np.mean(ctrl_vals))
    # Report the magnitude of the change regardless of its direction.
    if fold_change < 1:
        fold_change = 1/fold_change
    if is_metab:
        # NOTE(review): this exponentiates a *ratio* of means. If metabolite
        # features are on a log scale, exp(mean difference) may be what is
        # intended -- confirm before relying on these values.
        fold_change = np.exp(fold_change)
    return fold_change, np.var(ctrl_vals), np.var(case_vals)


def summarize_best_rules(rd, path):
    """Collapse per-seed lists to the entry with the largest fold change.

    ``rd`` maps data type -> seed path -> {'fold change': [...], 'var 0': [...],
    'var 1': [...]}. The result maps data type -> dataset name (first path
    component after ``path``) -> the fold change and variances of the best rule.
    """
    rd_new = {}
    for d, per_seed in rd.items():
        rd_new.setdefault(d, {})
        for ro, stats in per_seed.items():
            r = ro.split(path)[-1].split('/')[0]
            best = np.argmax(stats['fold change'])
            # NOTE(review): when several seeds of one dataset are present
            # (ties for the lowest loss) the last one processed wins.
            entry = rd_new[d].setdefault(r, {})
            entry['fold change'] = np.max(stats['fold change'])
            entry['var 0'] = stats['var 0'][best]
            entry['var 1'] = stats['var 1'][best]
    return rd_new


rd = {}
for root, dirs, files in os.walk(path):
    if 'results_last.csv' not in files:
        continue

    # For each dataset, find the seed(s) with the lowest total loss. The
    # aggregate rows are dropped so that only per-seed rows remain.
    df = pd.read_csv(os.path.join(root, 'results_last.csv'), index_col=0).drop(['Mean','StDev','Median','25% Quantile','75% Quantile'])
    ls = df.index.values[df["Total Loss"]==df["Total Loss"].min()]

    # For 'cdi'/'eraw' runs label 1 is the numerator group, otherwise label 0.
    case_label = 1 if ('cdi' in root or 'eraw' in root) else 0

    for sst in ls:
        sst = re.findall(r'\d+', sst)[0]
        seed_path = os.path.join(root, 'seed_'+sst+'/EVAL/')

        # Seeds without a readable rules.csv (missing, empty or unparseable)
        # are skipped; other errors are allowed to surface.
        try:
            rules = pd.read_csv(os.path.join(seed_path, 'rules.csv'), index_col=[0,1])
        except (OSError, ValueError):
            continue
        # Rule/detector pair(s) with the largest absolute log odds.
        highest_odds = rules.index.values[rules['Rule Log Odds'].abs()==rules['Rule Log Odds'].abs().max()]

        # The data type depends only on the path, so decide before doing any
        # further I/O for this seed.
        dtype = detector_dtype(seed_path, root)
        if dtype is None:
            continue

        # SECURITY: pickle.load can execute arbitrary code; only point `path`
        # at log directories you trust.
        with open(seed_path + 'plotting_data/detector_params.pkl','rb') as f:
            det_params = pkl.load(f)

        for h in highest_odds:
            # Index labels carry the numeric ids, e.g. ('Rule 4', 'Detector 13').
            rule = int(re.findall(r'\d+', h[0])[0])
            det = int(re.findall(r'\d+', h[1])[0])

            detd = det_params[rule][det]
            # Taxa features are summed across axis 1, metabolite features averaged.
            if dtype=='taxa':
                feats = detd['features'].sum(1)
            else:
                feats = detd['features'].mean(1)

            fold_change, var0, var1 = fold_change_stats(feats, detd['y'], case_label, dtype=='metabs')

            seed_stats = rd.setdefault(dtype, {}).setdefault(
                seed_path, {'fold change':[],'var 0':[],'var 1':[]})
            seed_stats['fold change'].append(fold_change)
            seed_stats['var 0'].append(var0)
            seed_stats['var 1'].append(var1)


rd_new = summarize_best_rules(rd, path)

In [80]:
# Load the processed WANG sequence and metabolite datasets (pickled objects).
# Only unpickle files you trust: unpickling can execute arbitrary code.
seqs_wang, mets_wang = (
    pd.read_pickle(pickle_file)
    for pickle_file in (
        '/Users/jendawk/Dropbox (MIT)/microbes-metabolites/datasets/WANG/processed/wang_ra/seqs.pkl',
        '/Users/jendawk/Dropbox (MIT)/microbes-metabolites/datasets/WANG/processed/wang_pubchem/mets.pkl',
    )
)

In [81]:
# Boolean mask of the samples labelled 0 in the sequence dataset.
label0_rows = seqs_wang['y'] == 0

# Despite the "mean_" names, each value is the median (across features) of the
# per-feature mean over the label-0 samples.
mean_taxa = seqs_wang['X'].loc[label0_rows].mean(axis=0).median()
# NOTE(review): the metabolite table is masked with the *sequence* dataset's
# labels; this presumably relies on both tables sharing the same sample index
# -- confirm, or use the metabolite dataset's own labels.
mean_mets = mets_wang['X'].loc[label0_rows].mean(axis=0).median()

In [82]:
# One column per metabolite dataset; rows are 'fold change', 'var 0', 'var 1'.
dfm = pd.DataFrame.from_dict(rd_new['metabs'])

In [83]:
# Show the metabolite datasets ordered by ascending fold change.
dfm.sort_values('fold change', axis='columns')

Unnamed: 0,erawijantari_pubchem,he_pubchem,cdi_pubchem,ibmdb_pubchem,wang_pubchem,franzosa_pubchem
fold change,3.617251,5.078419,10.560735,20.873106,26.67114,35.121855
var 0,0.826792,0.671964,0.952256,0.194703,1.430507e-13,0.486856
var 1,0.691695,0.078243,0.159149,0.221652,3.094257e-13,1.143771


In [84]:
# One column per taxa dataset, displayed in order of ascending fold change.
dft = pd.DataFrame.from_dict(rd_new['taxa'])
dft.sort_values('fold change', axis='columns')

Unnamed: 0,franzosa_ra,ibmdb_ra,he_cts,wang_ra,cdi_cts
fold change,1.191989,1.289438,1.314907,1.660129,22.23713
var 0,0.027414,0.003368,0.033415,0.005394,0.00344
var 1,0.041582,0.013027,0.012109,0.00833,1e-06


In [37]:
# Easiest Case

In [38]:
# After an ascending sort on fold change, the last column is the dataset with
# the largest fold change.
taxa_by_fc = pd.DataFrame(rd_new['taxa']).sort_values(by='fold change', axis=1)
mets_by_fc = pd.DataFrame(rd_new['metabs']).sort_values(by='fold change', axis=1)
p_taxa = taxa_by_fc.iloc[:, -1]
p_mets = mets_by_fc.iloc[:, -1]

# Per data type: (mean, variance) for controls, then for cases, where the case
# mean is the control mean scaled by the fold change.
for heading, ctrl_mean, params in (('taxa', mean_taxa, p_taxa),
                                   ('metabs', mean_mets, p_mets)):
    print(heading)
    print('controls: ', (ctrl_mean, params['var 0']))
    print('cases: ', (ctrl_mean*params['fold change'], params['var 1']))

taxa
controls:  (0.00015121409090909, 0.0034398890991365312)
cases:  (0.0033625674508433504, 1.3933597708743422e-06)
metabs
controls:  (0.03082741400725913, 0.4868558803402221)
cases:  (1.0827159556444839, 1.143770731556085)


In [39]:
# Log-normal (mean, variance) parameters for controls and then cases of each
# data type; the case mean is the control mean scaled by the fold change.
for heading, ctrl_mean, params in (('taxa', mean_taxa, p_taxa),
                                   ('metabs', mean_mets, p_mets)):
    print(heading)
    print(convert_to_log_normal(ctrl_mean, params['var 0']))
    print(convert_to_log_normal(ctrl_mean*params['fold change'], params['var 1']))

taxa
(-14.757473109665252, 11.921318409894646)
(-5.75315531366812, 0.11620967828622322)
metabs
(-6.599783307872394, 6.240864776334834)
(-0.26098508616570754, 0.6809154885889437)


In [49]:
# smallest fold change

In [50]:
# After an ascending sort on fold change, the first column is the dataset with
# the smallest fold change.
taxa_by_fc = pd.DataFrame(rd_new['taxa']).sort_values(by='fold change', axis=1)
mets_by_fc = pd.DataFrame(rd_new['metabs']).sort_values(by='fold change', axis=1)
p_taxa = taxa_by_fc.iloc[:, 0]
p_mets = mets_by_fc.iloc[:, 0]

In [51]:
# Display the currently selected taxa parameters (fold change, var 0, var 1).
p_taxa

fold change    1.191989
var 0          0.027414
var 1          0.041582
Name: franzosa_ra, dtype: float64

In [52]:
# Display the currently selected metabolite parameters (fold change, var 0, var 1).
p_mets

fold change    3.617251
var 0          0.826792
var 1          0.691695
Name: erawijantari_pubchem, dtype: float64

In [53]:
# Per data type: mean and variance for controls, then for cases, where the
# case mean is the control mean scaled by the fold change.
for heading, ctrl_mean, params in (('taxa', mean_taxa, p_taxa),
                                   ('metabs', mean_mets, p_mets)):
    print(heading)
    print(ctrl_mean, params['var 0'])
    print(ctrl_mean*params['fold change'], params['var 1'])

taxa
0.00015121409090909 0.027414180189987602
0.00018024558402570857 0.04158196365833808
metabs
0.03082741400725913 0.8267920298727679
0.11151048751249222 0.691695089124096


In [54]:
# Log-normal (mean, variance) parameters for controls and then cases of each
# data type; the case mean is the control mean scaled by the fold change.
for heading, ctrl_mean, params in (('taxa', mean_taxa, p_taxa),
                                   ('metabs', mean_mets, p_mets)):
    print(heading)
    print(convert_to_log_normal(ctrl_mean, params['var 0']))
    print(convert_to_log_normal(ctrl_mean*params['fold change'], params['var 1']))

taxa
(-15.795280789339152, 13.996933769242444)
(-15.652336567363516, 14.062292572200601)
metabs
(-6.864175172784861, 6.7696485061597675)
(-4.2118768991171, 4.036480530039848)


In [8]:
res_dict = {'taxa':{'ctrl mea

dict_keys(['metabs'])

In [5]:
# Median across metabolite datasets of each statistic (row).
pd.DataFrame(rd_new['metabs']).quantile(q=0.5, axis='columns')

fold change    3.777695
var 0          0.234616
var 1          0.216853
Name: 0.5, dtype: float64

In [6]:
# Median across taxa datasets of each statistic (row).
pd.DataFrame(rd_new['taxa']).quantile(q=0.5, axis='columns')

fold change    0.538469
var 0          0.003358
var 1          0.010197
Name: 0.5, dtype: float64

In [51]:
# Log-normal parameters for a mean of 1e-2 and a variance of 0.009757.
# NOTE(review): both literals are hard-coded and 0.009757 is not among the
# outputs visible in this notebook -- presumably copied from an earlier run;
# confirm.
convert_to_log_normal(1e-2,0.009757)

(-6.900553663826765, 4.590766955677347)

In [52]:
# Log-normal parameters for a mean of 1e-2 + 0.540406 and a variance of 0.014164.
# NOTE(review): the offset is *added* to the base mean here, whereas the
# print cells in this notebook multiply the control mean by the fold change;
# the literals are also not among the visible outputs -- confirm both.
convert_to_log_normal((1e-2)+0.540406,0.014164)

(-0.6199461073942916, 0.04569403226932391)

In [46]:
# 75th percentile across metabolite datasets of each statistic (row).
pd.DataFrame(rd_new['metabs']).quantile(q=0.75, axis='columns')

fold change    9.832707
var 0          0.402572
var 1          0.276570
Name: 0.75, dtype: float64

In [47]:
# Log-normal parameters for a mean of 1 and a variance of 0.402572, the
# 'var 0' value shown by the 0.75-quantile cell for metabolites.
convert_to_log_normal(1,0.402572)

(-0.16915384699770827, 0.33830769399541655)

In [48]:
# Log-normal parameters for a mean of 1 + 9.832707 and a variance of 0.276570
# (the 'fold change' and 'var 1' values shown by the 0.75-quantile cell).
# NOTE(review): the fold change is *added* to the base mean here, whereas the
# print cells in this notebook multiply the control mean by it -- confirm.
convert_to_log_normal(1+9.832707,0.276570)

(2.381392947464934, 0.0023540722933171874)

In [93]:
# Inspect the value left in `highest_odds` by the last seed processed in the
# rule-extraction loop (a leftover loop variable, so it depends on run order).
highest_odds

array([('Rule 4', 'Detector 13'), ('Rule 4', 'Detector 19')], dtype=object)