# Figure 3 - infer interactions then plot
Katharine Z. Coyte January 2020

Infer interactions between microbes by fitting the data to a generalized Lotka–Volterra model using Bayesian spike-and-slab variable selection.

In [None]:
import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, Circle
import matplotlib.cm as cm
from matplotlib.gridspec import GridSpec
import seaborn as sns; sns.set(color_codes=True)

import networkx as nx

#import scipy
#from scipy import stats
import statsmodels.api as sm
from sklearn.preprocessing import StandardScaler
import pymc3 as pm

import jan_miseq_functions as jnf
import infer_interactions_clean as iic
import drug_info as di

%matplotlib inline
%load_ext autoreload
sns.set_style('whitegrid')

## Load data 

In [None]:
# Analysis settings: the taxonomic level handed to the plotting pre-processing,
# and the two cut-offs handed to iic.filter_for_prevalence in the next cell.
# Note we filter out any genera with low prevalence and abundance.
tl = 'genus'
prevalence_threshold = 30
relative_abundance_threshold = 0.005

# Drug classification lists and the medication table.
antibacterials, antifungals, vaccines = di.load_drug_types()
all_meds = pd.read_excel('allMeds_jan2020.xlsx')

# Both amplicon tables live in the same workbook: sheet 'bac16S' is read first
# (-> data_bac / otu_table_bac), then sheet 'ITS1' (-> data_fun / otu_table_fun).
(data_bac, otu_table_bac), (data_fun, otu_table_fun) = (
    jnf.load_microbiome_data(file_name='20190207_NICU_rDNA_zOTU-table_mod2.xlsx',
                             sheet_name=sheet)
    for sheet in ('bac16S', 'ITS1')
)


## Various additional data processing - calculating geometric means, cleaning column names etc

In [None]:
### Clean and reorganize data ###
data_bacteria = jnf.process_NICU_data_for_plotting(data_bac, otu_table_bac, 'Bacteria', tl)
data_fungi = jnf.process_NICU_data_for_plotting(data_fun, otu_table_fun, 'Fungi', tl)

# Bacterial and fungal columns are placed side by side; the fungal copy of
# babyid/day is dropped so those columns are not duplicated.
# `axis` is always passed by keyword: the positional form (`drop(cols, 1)`,
# `concat(objs, 1)`) is rejected by pandas >= 2.0.
data = pd.concat([data_bacteria, data_fungi.drop(['babyid', 'day'], axis=1)],
                 axis=1, sort=False).fillna(0)
otu_table = pd.concat([otu_table_bac, otu_table_fun])

relative_bacteria = iic.get_relative_abundances(data_bacteria)
relative_fungi = iic.get_relative_abundances(data_fungi)

my_bacteria = iic.filter_for_prevalence(relative_bacteria, prevalence_threshold, relative_abundance_threshold)
my_fungi = iic.filter_for_prevalence(relative_fungi, prevalence_threshold, relative_abundance_threshold)

# Working table: sample identifiers plus only the taxa that passed the filter.
current_dataset = pd.concat([data.loc[:, ['day', 'babyid']],
                             data[my_bacteria],
                             data[my_fungi]], axis=1, sort=False)
current_dataset.day = current_dataset.day.astype(float)


### Calculate geometric means and dxdt ###
# NOTE(review): the meaning of the third positional argument (0) is defined in
# iic.calculate_dxdt and is not visible from this notebook - worth documenting here.
all_geo_mean_df, all_dlog_dt_df = iic.calculate_dxdt(current_dataset,
                                                     all_meds,
                                                     0,
                                                     antibacterials,
                                                     antifungals)

# Responses: rows with any missing value are discarded, identifiers removed.
all_Y = all_dlog_dt_df.dropna().drop(['babyid', 'day'], axis=1)
# Predictors: restricted to exactly the rows that survived in all_Y.
all_X = all_geo_mean_df.loc[all_Y.index, :]
all_X = iic.clean_column_names(all_X)
all_Y = iic.clean_column_names(all_Y)


### Remove brackets, drop any antibiotics that occur fewer than 5 times ###
all_X = all_X.rename(columns={'Zosyn_(Piperacillin/tazobactam)': 'Zosyn'})
# The filter is applied to every predictor column: a column is kept only if it
# is strictly positive in at least MIN_NONZERO_SAMPLES rows.
MIN_NONZERO_SAMPLES = 5
all_X = all_X.loc[:, (all_X > 0).sum() >= MIN_NONZERO_SAMPLES]


### Rename for fitting ###
# Bring the genus names in line with the underscore spellings used for fitting.
genus_renames = {'Escherichia-Shigella': 'Escherichia_Shigella',
                 'Clostridium sensu stricto 1': 'Clostridium_sensu_stricto_1',
                 'Clostridiaceae 1': 'Clostridiaceae_1'}
my_bacteria = [genus_renames.get(x, x) for x in my_bacteria]

# Taxa drawn in the interaction-network figure, in node order.
my_list = ['Escherichia_Shigella',
           'Klebsiella',
           'Staphylococcus',
           'Enterococcus',
           'Candida']

## Infer microbe–microbe interactions

Use spike-and-slab variable selection to identify the interactions most consistently apparent in our data.


In [None]:
save_all_X = all_X.copy()
save_all_Y = all_Y.copy()

# Filled row by row inside the loop:
#   store_all_interactions.loc[focal, predictor] = posterior mean of that coefficient
#   store_intercepts.loc[focal, 'ic']            = posterior mean of the intercept
store_all_interactions = pd.DataFrame()
store_intercepts = pd.DataFrame()

# Fixed sampler seed so that re-running the notebook reproduces the same
# posterior summaries.
RANDOM_SEED = 2020


def _prior_covariance(design):
    """Return inv(G + diag(G)) with G = 0.5 * X'X for the design matrix ``design``.

    Used as the (unscaled) covariance of the multivariate-normal prior on the
    coefficients. Author's note: an identity matrix of the same size can be
    used instead (results same).
    """
    gram = .5 * np.matmul(design.T.values, design.values)
    gram += np.diag(np.diag(gram))
    return np.linalg.inv(gram)


def _label_with_family(name):
    """Return ``name`` suffixed with its family from ``otu_table``.

    Falls back to the bare ``name`` when it has no entry in ``otu_table``
    (e.g. drug columns) or when the family entry is not a string.
    """
    try:
        return name + '_' + otu_table.loc[name, 'family']
    except (KeyError, TypeError):
        return name


for focal_species in all_Y.columns:

    # Drop samples without focal species
    absent = save_all_X.index[save_all_X[focal_species] == 0]
    my_all_X = save_all_X.drop(absent)
    my_all_Y = save_all_Y.loc[my_all_X.index, :].copy()

    # Too few samples to fit anything for this species.
    if len(my_all_Y) <= 10:
        continue

    # Standardize all X variables
    my_all_X = pd.DataFrame(StandardScaler().fit_transform(my_all_X),
                            index=my_all_X.index,
                            columns=my_all_X.columns)

    # Predictors are split into taxa (every column that is also a response)
    # and drugs (everything else).
    X_taxa = my_all_X[my_all_Y.columns]
    X_drugs = my_all_X.drop(X_taxa.columns, axis=1)

    # Discard predictors that carry no information in this subset (all zeros
    # after scaling). The test is "every entry is zero" rather than
    # "column sum is zero": a standardized column always sums to ~0 and can sum
    # to exactly 0.0 (e.g. a drug given in exactly half the samples), which
    # would wrongly discard a perfectly informative predictor.
    X_taxa = X_taxa.loc[:, (X_taxa != 0).any()]
    X_drugs = X_drugs.loc[:, (X_drugs != 0).any()]

    y = my_all_Y[focal_species].copy()

    # Set up priors for model
    Sigma_taxa = _prior_covariance(X_taxa)
    Sigma_drugs = _prior_covariance(X_drugs)

    # For calculating growth rate initialisation: ordinary least squares of the
    # response on all standardized predictors plus a constant.
    results = sm.OLS(y, sm.add_constant(my_all_X)).fit()
    # NOTE(review): init_r is computed (and floored at 0) but never used below;
    # the intercept prior is centred on the constant 1.0 instead - confirm intent.
    init_r = results.params['const']
    if init_r < 0:
        init_r = 0
    init_r_std = np.std(results.params)  # inflate initial prior for intercept

    with pm.Model() as model:

        # Spike-and-slab: each coefficient is a Bernoulli inclusion indicator
        # (xi) times a normally distributed effect size (beta).
        xi_taxa = pm.Bernoulli('xi_taxa', .5, shape=X_taxa.shape[1])
        tau_taxa = pm.HalfCauchy('tau_taxa', 1)
        beta_taxa = pm.MvNormal('beta_taxa', 0, tau_taxa * Sigma_taxa, shape=X_taxa.shape[1])
        mean_taxa = pm.math.dot(X_taxa, xi_taxa * beta_taxa)

        xi_drugs = pm.Bernoulli('xi_drugs', .5, shape=X_drugs.shape[1])
        tau_drugs = pm.HalfCauchy('tau_drugs', 1)
        beta_drugs = pm.MvNormal('beta_drugs', 0, tau_drugs * Sigma_drugs, shape=X_drugs.shape[1])
        mean_drugs = pm.math.dot(X_drugs, xi_drugs * beta_drugs)

        my_sigma = pm.HalfNormal('my_sigma', 10)

        # Intercept constrained to be non-negative.
        # NOTE(review): `tau` is a precision in PyMC3, so a larger init_r_std
        # makes this prior tighter, not wider - confirm this is the intended
        # "inflated" prior.
        intercp = pm.Bound(pm.Normal, lower=0.0)('intercp', mu=1.0, tau=(init_r_std**2)*1e2)

        my_var = pm.Normal('my_var', mean_drugs + mean_taxa + intercp, my_sigma, observed=y)

        # NOTE(review): PyMC3 documents `cores` as the number of chains to run
        # in parallel; -1 is not a documented "use all CPUs" value - confirm
        # the intended setting.
        lasso_normal_trace_s = pm.sample(draws=10000, tune=2500, init='adapt_diag',
                                         cores=-1, random_seed=RANDOM_SEED)

    my_summary = pm.summary(lasso_normal_trace_s)

    store_intercepts.loc[focal_species, 'ic'] = my_summary.loc['intercp', 'mean']

    # Collect posterior mean / sd of every coefficient. This relies on the
    # summary table labelling the i-th entry of a vector variable
    # '<name>__<i>', in the same order as the design-matrix columns.
    new_summary = pd.DataFrame()
    for prefix, design in (('beta_taxa__', X_taxa), ('beta_drugs__', X_drugs)):
        for ix, val in enumerate(design.columns):
            cur_index = prefix + str(ix)
            cur_mean = my_summary.loc[cur_index, 'mean']
            cur_sd = my_summary.loc[cur_index, 'sd']
            rename_val = _label_with_family(val)
            new_summary.loc[rename_val, 'cur_mean'] = cur_mean
            new_summary.loc[rename_val, 'cur_sd'] = cur_sd
            store_all_interactions.loc[focal_species, val] = cur_mean

    # Heading for the figure that follows: species name plus family if known.
    try:
        print(focal_species, otu_table.loc[focal_species, 'family'])
    except KeyError:
        print(focal_species)

    # One figure per focal species: posterior mean (dot) with +/- one posterior
    # sd (error bar; the bar itself has zero height) for every predictor.
    sns.set_style('white')
    plt.figure(figsize=(10, 10))

    plt.scatter(x=new_summary.cur_mean, y=new_summary.index, c='orange')
    plt.barh(y=new_summary.index,
             width=new_summary.cur_mean,
             height=0.0,
             xerr=new_summary.cur_sd)
    plt.plot([0, 0], [-0.5, len(new_summary) + 0.0], 'gray')
    plt.show()


## Plot interaction network 

In [None]:
f, ax = plt.subplots(1, figsize=(8, 6.5))

### Split interactions into positive and negative for plotting ###
cur_matrix = store_all_interactions.copy()
cur_matrix[abs(cur_matrix) < 0.01] = 0  # filter out very weak interactions
# Rows of store_all_interactions are the affected (focal) species and columns
# the predictors; after transposing, entry [i, j] becomes an edge i -> j, i.e.
# from the acting species to the affected one. Interactions that were never
# estimated (NaN) are treated as absent, otherwise they would become edges.
cur_matrix = cur_matrix.loc[my_list, my_list].T.fillna(0)
cm_pos = cur_matrix.clip(lower=0)
cm_neg = cur_matrix.clip(upper=0)


def _graph_from_weights(weights):
    """Build a directed graph with one weighted edge per non-zero entry of ``weights``."""
    # from_numpy_array supersedes from_numpy_matrix, which networkx 3.0 removed.
    return nx.from_numpy_array(np.array(weights), create_using=nx.MultiDiGraph())


def _draw_edges(graph, layout, cmap, colour_gain, arrowstyle, mutation_scale):
    """Add one curved arrow per edge of ``graph`` to ``ax``.

    The arrow colour is ``cmap(|weight| * colour_gain)``.
    """
    for source, target, weight in graph.edges(data='weight'):
        ax.add_patch(FancyArrowPatch(layout[source],
                                     layout[target],
                                     arrowstyle=arrowstyle,
                                     connectionstyle='arc3,rad=0.25',
                                     mutation_scale=mutation_scale,
                                     lw=7.5,
                                     color=cmap(abs(weight) * colour_gain),
                                     shrinkA=25,
                                     shrinkB=25))


### Plot positive edges ###
G = _graph_from_weights(cm_pos)
pos = nx.layout.circular_layout(G)
# Hand-tuned offset for the last node (index 4, the final entry of my_list).
pos[4] = (pos[4][0] + 0.08, pos[4][1] - 0.7)
_draw_edges(G, pos, cm.Blues, 5, arrowstyle='-|>', mutation_scale=35.0)


### Plot negative edges ###
# Same node positions; bracket-headed arrows in red.
G = _graph_from_weights(cm_neg)
_draw_edges(G, pos, cm.Reds, 8, arrowstyle='-[', mutation_scale=12.0)


### Plot nodes ###
# First four nodes as circles, the fifth as a triangle.
node_colors = ['red', 'dodgerblue', 'gold', 'green', 'brown']
nx.draw_networkx_nodes(G, pos, nodelist=list(range(4)), node_shape='o', node_size=250,
                       node_color=node_colors[0:4], ax=ax)
nx.draw_networkx_nodes(G, pos, nodelist=[4], node_shape='^', node_size=250,
                       node_color=['brown'], ax=ax)


### Fix axis ###
ax.set_xlim([-1.1, 1.1])
ax.set_ylim([-1.1, 1.1])
plt.axis('off')
# Widen the view so the arrows and the displaced fifth node are not clipped.
plot_margin = 0.5
x0, x1, y0, y1 = plt.axis()
plt.axis((x0 - plot_margin,
          x1 + plot_margin,
          y0 - plot_margin - 0.3,
          y1 + 0.09))

plt.show()

## Generate histogram of antibiotic and antifungal interactions, split by kingdom / drug type


In [None]:
# Get antimicrobial interactions: removing every taxon column from the
# coefficient table leaves only the drug columns.
abx_matrix = store_all_interactions.copy()
abx_matrix = abx_matrix.drop(my_bacteria, axis=1)
abx_matrix = abx_matrix.drop(my_fungi, axis=1).fillna(0)
abx_matrix[abs(abx_matrix) < 0.01] = 0


def _nonzero_effects(effects, taxa, drugs, skip_taxa=()):
    """Collect the non-zero entries of the ``taxa`` x ``drugs`` block of ``effects``.

    Labels in ``taxa`` / ``drugs`` that are absent from ``effects`` yield NaN
    entries (``reindex``; ``.loc`` raises on missing labels in pandas >= 1.0),
    and any drug column containing a NaN is then discarded. Rows named in
    ``skip_taxa`` are removed before that NaN check.

    Returns a single-column DataFrame (column 0) with a fresh integer index.
    """
    block = effects.reindex(index=taxa, columns=drugs)
    if skip_taxa:
        block = block.drop(list(skip_taxa))
    block = block.dropna(axis=1)
    flat = block.values.T.ravel()  # column after column
    return pd.DataFrame(flat[flat != 0])


## Split up antimicrobial matrix
# NOTE(review): 'Aspergillus' is removed from the fungal blocks before the NaN
# check - presumably it has no fitted row, which would otherwise blank every
# column. Confirm.
store_abx_tmp = _nonzero_effects(abx_matrix, my_bacteria, antibacterials)      # bacteria / antibiotics
store_anf_tmp = _nonzero_effects(abx_matrix, my_bacteria, antifungals)         # bacteria / antifungals
store_abx_tmp_fun = _nonzero_effects(abx_matrix, my_fungi, antibacterials,
                                     skip_taxa=['Aspergillus'])                # fungi / antibiotics
store_anf_tmp_fun = _nonzero_effects(abx_matrix, my_fungi, antifungals,
                                     skip_taxa=['Aspergillus'])                # fungi / antifungals

# Clamp extreme values to +/- thresh_val.
thresh_val = 0.05
store_abx_tmp = store_abx_tmp.clip(-thresh_val, thresh_val)
store_anf_tmp = store_anf_tmp.clip(-thresh_val, thresh_val)
store_abx_tmp_fun = store_abx_tmp_fun.clip(-thresh_val, thresh_val)
store_anf_tmp_fun = store_anf_tmp_fun.clip(-thresh_val, thresh_val)

my_bins = np.linspace(-3, 3, 51) * 0.1
ANTIBIOTIC_COLOR = [179/255, 129/255, 176/255]
ANTIFUNGAL_COLOR = [182/255, 212/255, 146/255]


def _plot_effect_hists(axis, antibiotic_effects, antifungal_effects):
    """Overlay the antibiotic and antifungal effect histograms on ``axis``.

    Returns the twin axis created in between.
    NOTE(review): nothing is ever drawn on that twin axis, it only adds an
    empty right-hand scale - confirm whether the antifungal histogram was
    meant to go there, or drop it.
    """
    axis.hist(antibiotic_effects.values,
              density=False,
              bins=my_bins,
              histtype='stepfilled',
              color=ANTIBIOTIC_COLOR)
    twin = axis.twinx()
    axis.hist(antifungal_effects.values,
              density=False,
              bins=my_bins,
              alpha=0.75,
              histtype='stepfilled',
              color=ANTIFUNGAL_COLOR)
    axis.plot([0, 0], [0, 6], 'k--')  # zero-effect reference line
    axis.set_xlim([-0.075, 0.075])
    return twin


# Plot separate histograms: top panel bacteria, bottom panel fungi.
fig, ax = plt.subplots(2, 1, figsize=(5, 5))
ax0_2 = _plot_effect_hists(ax[0], store_abx_tmp, store_anf_tmp)
ax1_2 = _plot_effect_hists(ax[1], store_abx_tmp_fun, store_anf_tmp_fun)

plt.show()

## Calculate proportion of different interaction types

In [None]:
# Get just taxa interactions: drop the drug columns, then order the columns
# like the rows so the matrix is square (focal species x species). A species
# that never appeared as a predictor yields a NaN column and is simply not
# counted below.
cur_matrix = store_all_interactions.copy()
cur_matrix = cur_matrix.drop(abx_matrix.columns, axis=1)
cur_matrix = cur_matrix.reindex(columns=cur_matrix.index)
cur_matrix[abs(cur_matrix) < 0.01] = 0


# Interaction class for each (sign of effect on A, sign of effect on B) pair.
# (0, 0) is deliberately absent: pairs with no interaction are not counted.
PAIR_CLASS = {( 1,  1): '+/+',
              ( 1, -1): '+/-',
              (-1,  1): '+/-',
              ( 1,  0): '+/o',
              ( 0,  1): '+/o',
              (-1, -1): '-/-',
              (-1,  0): '-/o',
              ( 0, -1): '-/o'}

# Go through and count interactions. Every unordered pair is visited twice
# (once from each side) and lands in the same class both times, so the double
# counting cancels when the counts are normalised below.
count_ints = pd.DataFrame(0, index=['interaction'], columns=['-/-', '-/o', '+/-', '+/o', '+/+'])
signs = np.sign(cur_matrix)
for ix_sp in signs.index:
    for col_sp in signs.columns:
        if col_sp == ix_sp:
            continue
        pair = (signs.loc[ix_sp, col_sp], signs.loc[col_sp, ix_sp])
        pair_class = PAIR_CLASS.get(pair)  # None for (0, 0) and for NaN entries
        if pair_class is not None:
            count_ints.loc['interaction', pair_class] += 1


# Plot result as a single stacked bar of proportions.
fig, ax = plt.subplots(1, 1, figsize=(2, 5))
colors = ['darkred', 'tomato', 'khaki', 'royalblue', 'darkblue']
total = count_ints.sum(axis=1).values[0]
if total > 0:  # avoid 0/0 when no interaction passed the threshold
    count_ints = count_ints / total
count_ints.plot(kind='bar',
                stacked=True,
                color=colors,  # `colors=` is not accepted by current pandas plotting
                ax=ax)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.6, 1.00),
          ncol=1, fancybox=True)
plt.show()

## Plot matrices of all interaction types

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(11.5, 7.5))

# Left panel (5/8 of the width): taxon-taxon coefficients.
# Right panel (3/8): drug coefficients.
gs = GridSpec(1, 8, figure=fig)
ax1 = fig.add_subplot(gs[0, 0:5])
ax2 = fig.add_subplot(gs[0, 5:8])

# Both panels share the colour scale and grid styling.
heatmap_style = dict(cmap='RdBu',
                     center=0,
                     vmin=-0.05,
                     vmax=0.05,
                     square=False,
                     cbar=True,
                     linewidths=0.05,
                     linecolor='darkslategray')

sns.heatmap(cur_matrix.fillna(0), ax=ax1, yticklabels=True, **heatmap_style)
# Row labels are shown on the left panel only.
sns.heatmap(abx_matrix.fillna(0), ax=ax2, yticklabels=False, **heatmap_style)

# Keep the rows of the two panels level. The previous
# `ax1.get_shared_y_axes().join(ax2)` joined ax2 with nothing (no effect) and
# raises on matplotlib >= 3.8, where the returned view has no `join`.
ax2.set_ylim(ax1.get_ylim())
plt.show()

