## Figure 2
Katharine Z. Coyte January 2020

Code to analyse and plot the data for Figure 2:
load the data, calculate total bacterial and fungal loads, fit linear mixed-effects models, perform metric MDS ordination (Bray–Curtis dissimilarity) with DBSCAN clustering, and plot all results.

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns; sns.set(color_codes=True)

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.lines import Line2D
from matplotlib.ticker import NullFormatter

from scipy import stats
from scipy.stats import sem
import statsmodels.api as sm
import statsmodels.formula.api as smf
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn import manifold, datasets
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_blobs
from sklearn import decomposition
from sklearn.metrics import euclidean_distances


import microbiome_data_processing_functions as mdpf
import baby_color_tables as bct
import august_nexseq_functions as anf
import jan_miseq_functions as jnf
import drug_info as di

%matplotlib inline
%load_ext autoreload
sns.set_style('white')

# Load and process microbiome data – first at OTU level (NextSeq runs)

In [None]:
%autoreload
data_bac, otu_table_bac = jnf.load_microbiome_data(file_name = './20190307_CR_NICU/201903_NICU_combined_bac16SV4_L150_otu-table.xlsx',
                         sheet_name = 'bac16S_L150_mod3', is_nextseq=1)

data_fun, otu_table_fun = jnf.load_microbiome_data(file_name = './20190307_CR_NICU/201903_NICU_combined_ITS1_L100_otu-table.xlsx',
                         sheet_name = 'ITS1_L100_mod3', is_nextseq=1)

tl = 'otu'

data_bacteria = jnf.process_NICU_data_for_plotting(data_bac, otu_table_bac, 'Bacteria(1.0000)', tl)
data_fungi = jnf.process_NICU_data_for_plotting(data_fun, otu_table_fun, 'Fungi(1.0000)', tl)

data = pd.concat([data_bacteria, data_fungi.drop(['babyid', 'day'],1)],1, sort=False).fillna(0)
otu_table = pd.concat([otu_table_bac, otu_table_fun])
data = data.iloc[:-2,:]


data_bacteria = data_bacteria.drop(data_bacteria.loc[data_bacteria.day.str.contains('re', regex=False),:].index)
data_fungi = data_fungi.drop(data_fungi.loc[data_fungi.day.str.contains('re', regex=False),:].index)
data = data.drop(data.loc[data.day.str.contains('re', regex=False),:].index)

# Strip confidence values
for col in otu_table.columns:
    otu_table[col] = otu_table[col].map(lambda x: x.rstrip('(1.023456789)'))
    



# Load metadata, drop any topical antibiotics / antifungals

In [None]:
# Read in clinical metadata
antibacterials, antifungals, vaccines = di.load_drug_types()
all_meds = pd.read_excel('allMeds.xlsx')
all_weights = pd.read_excel('daily_weights.xlsx')
all_baby_info = pd.read_csv('all_baby_info_oct2.csv')

subset_baby_info = pd.read_csv('baby_info.csv')
delivery_df = subset_baby_info[['delivery', 'baby_id']].set_index('baby_id').rename(columns={"delivery": "val"})

# Drop topical antibiotics (eye and skin medication).
# The drug lists loaded above are reused here; they were previously re-read for no effect.
# Erythromycin recorded on postnatal days 0-2 is removed together with all
# 'Erythromycin ophthalmic' entries.
# NOTE(review): treating day 0-2 Erythromycin as topical follows the heading above - confirm.
is_early_erythromycin = (all_meds.Med == 'Erythromycin') & all_meds.pn_day.isin([0, 1, 2])
is_ophthalmic_erythromycin = all_meds.Med == 'Erythromycin ophthalmic'
all_meds = all_meds.loc[~(is_early_erythromycin | is_ophthalmic_erythromycin)]

# Topical agents are also excluded from the systemic drug lists used for exposure counts.
for topical_drug in ['Bacitracin', 'Gentamicin ophthalmic', 'Erythromycin ophthalmic', 'Mupirocin (Bactroban)']:
    antibacterials.remove(topical_drug)
for topical_drug in ['Miconazole Powder', 'Nystatin Ointment', 'Nystatin']:
    antifungals.remove(topical_drug)


# Calculate sample diversity and total bacterial / fungal loads

In [None]:
# Calculate inverse Simpson diversity per sample; taxa columns start at position 2
# (after the 'babyid' / 'day' metadata columns).
inv_simpson_df = pd.DataFrame()
for ix in data_bacteria.index:
    inv_simpson_df.loc[ix, 'inv_s'] = jnf.inverse_simpson_di(dict(data_bacteria.loc[ix, data_bacteria.columns[2]:]))

f_inv_simpson_df = pd.DataFrame()
for ix in data_fungi.index:
    f_inv_simpson_df.loc[ix, 'inv_s_f'] = jnf.inverse_simpson_di(dict(data_fungi.loc[ix, data_fungi.columns[2]:]))

inv_simpson_df = pd.concat([data_bacteria.iloc[:, :2], inv_simpson_df, f_inv_simpson_df], axis=1)


# Total bacterial / fungal load per sample, replacing zero values with the limit of detection.
use_logged = 1
tb = data_bacteria.iloc[:, 2:].sum(axis=1)
tf = data_fungi.iloc[:, 2:].sum(axis=1)
# Samples with no fungal reads must be identified from the raw totals: once `tf` is
# clamped to the detection limit below it can never be 0, so the previous check
# (made after clamping) never fired.
no_fungi = tf[tf == 0].index
if use_logged:
    tb = tb.clip(lower=10**3)
    tf = tf.clip(lower=1)

    total_bacteria = pd.DataFrame(np.log10(tb), columns=['total_bacteria'])
    total_fungi = pd.DataFrame(np.log10(tf), columns=['total_fungi'])
    # Fungal diversity of a sample with no fungal reads is set to 0 (instead of
    # whatever the diversity helper returns for an all-zero profile).
    inv_simpson_df.loc[no_fungi, 'inv_s_f'] = 0.0
else:
    total_bacteria = pd.DataFrame(tb, columns=['total_bacteria'])
    total_fungi = pd.DataFrame(tf, columns=['total_fungi'])
data_for_mixed_model = pd.concat([inv_simpson_df, total_bacteria, total_fungi], axis=1, sort=False).dropna()

# Strip repeated measurements to prevent counting same sample twice
data_for_mixed_model = data_for_mixed_model.drop(data_for_mixed_model.loc[data_for_mixed_model.day.str.contains('re', regex=False), :].index)

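## Calculate antibiotics and antifungals per day

Here we count the antimicrobial agents administered on each sampling day: the per-day entries of each infant's drug table are summed. This equals the number of unique agents if that table is 0/1-encoded.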

In [None]:
# Annotate each sample with the antimicrobial exposure on its sampling day.
# NOTE(review): store_abx / store_fung are never filled in the loop below, so the
# concatenations at the end of this cell add no columns; kept so these names stay defined.
store_abx=pd.DataFrame(columns=antibacterials)
store_fung=pd.DataFrame(columns=antifungals)

N_DRUG_DAYS = 50  # width of the per-infant drug-by-day tables (day columns 0..49)

for baby in np.unique(data_for_mixed_model.babyid):

    # Get all drugs recorded for this infant (project helper).
    individual_weight, individual_drugs = mdpf.get_baby_weights_and_drugs(int(baby), all_meds, all_weights)
    try:
        # Select only the listed drugs this infant actually has rows for. A plain
        # `.loc[list]` raises KeyError in pandas >= 1.0 when any listed drug is absent,
        # which would silently route the infant into the all-zero fallback below.
        individual_abx = individual_drugs.loc[individual_drugs.index.intersection(antibacterials), :]
        individual_fung = individual_drugs.loc[individual_drugs.index.intersection(antifungals), :]
    except Exception:
        # Best-effort fallback (e.g. no usable drug table for this infant): treat as unexposed.
        individual_abx = pd.DataFrame(0, index=antibacterials, columns=range(0, N_DRUG_DAYS))
        individual_fung = pd.DataFrame(0, index=antifungals, columns=range(0, N_DRUG_DAYS))

    # Keep the first 49 day columns and discard drugs never given in that window.
    # NOTE(review): the full tables below span 50 days, so day 49 always stays zero - confirm intended.
    individual_abx = individual_abx.iloc[:, :N_DRUG_DAYS - 1]
    individual_abx = individual_abx.drop(individual_abx.loc[individual_abx.sum(axis=1) == 0, :].index)

    individual_fung = individual_fung.iloc[:, :N_DRUG_DAYS - 1]
    individual_fung = individual_fung.drop(individual_fung.loc[individual_fung.sum(axis=1) == 0, :].index)

    # Re-embed into complete drug x day tables (zeros elsewhere) so every lookup below is defined.
    tmp_ind_abx = pd.DataFrame(0, index=antibacterials, columns=range(0, N_DRUG_DAYS))
    tmp_ind_abx.loc[individual_abx.index, individual_abx.columns] = individual_abx
    individual_abx = tmp_ind_abx

    tmp_ind_fung = pd.DataFrame(0, index=antifungals, columns=range(0, N_DRUG_DAYS))
    tmp_ind_fung.loc[individual_fung.index, individual_fung.columns] = individual_fung
    individual_fung = tmp_ind_fung

    # Sum of antibacterial / antifungal entries on each of this infant's sampling days.
    cur_days = data_for_mixed_model.loc[data_for_mixed_model.babyid == baby, :]
    for ix in cur_days.index:
        day = int(cur_days.loc[ix, 'day'])
        data_for_mixed_model.loc[ix, 'total_antibiotics'] = individual_abx[day].sum()
        data_for_mixed_model.loc[ix, 'total_antifungals'] = individual_fung[day].sum()


store_abx = store_abx.loc[:, store_abx.sum() != 0].fillna(0)
data_for_mixed_model = pd.concat([data_for_mixed_model, store_abx], axis=1)
inv_simpson_df = pd.concat([inv_simpson_df, store_abx], axis=1)

data_for_mixed_model = data_for_mixed_model.rename(columns={'Zosyn (Piperacillin/tazobactam)':'Zosyn'})
data_for_mixed_model.day = data_for_mixed_model.day.astype(int)


## Fit linear mixed effects model

In [None]:
# Z-score every column (outcomes and covariates alike) so the fixed-effect
# coefficients of both models are expressed in standard-deviation units.
tmp = StandardScaler().fit_transform(data_for_mixed_model)
my_all_X = pd.DataFrame(tmp, columns=data_for_mixed_model.columns, index=data_for_mixed_model.index)

# One random intercept per infant. 'babyid' is standardised along with everything
# else; a linear rescaling keeps one distinct value per infant, so the grouping holds.
infant_groups = my_all_X["babyid"]

md_bac = smf.mixedlm("total_bacteria ~ total_antibiotics + total_antifungals  + total_fungi + day",
                     my_all_X, groups=infant_groups)
md_fun = smf.mixedlm("total_fungi ~ total_antibiotics + total_antifungals  + total_bacteria + day",
                     my_all_X, groups=infant_groups)

md_bac_fit, md_fun_fit = md_bac.fit(), md_fun.fit()

for fitted_model in (md_bac_fit, md_fun_fit):
    print(fitted_model.summary())
    


## Plot results of linear mixed effects model

In [None]:
def plot_mixed_effects(fit, covariates, title, savefig_path=None):
    """Forest plot of selected fixed-effect estimates from a fitted mixed model.

    Parameters
    ----------
    fit : statsmodels MixedLMResults
        Fitted model; ``fit.params`` and ``fit.conf_int()`` are indexed by term name.
        Because all variables were standardised before fitting, estimates are in
        standard-deviation units.
    covariates : list of str
        Fixed-effect terms to plot, listed bottom to top.
    title : str
        Figure title.
    savefig_path : str, optional
        If given, the figure is also written to this path.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``covariates`` with columns 'my_mean' (point estimate) and 'sd'.
        Despite its historical name, 'sd' is the distance from the estimate to the
        lower bound of the 95% confidence interval, i.e. the CI half-width.
    """
    coefs = pd.DataFrame({'my_mean': fit.params,
                          'sd': fit.params - fit.conf_int()[0]}).loc[covariates, :]

    fig, ax = plt.subplots(figsize=(7, 5))
    # Zero-height bars, so only the horizontal error bars (the CIs) are visible.
    ax.barh(y=coefs.index, width=coefs.my_mean, height=0.0, lw=0.25, xerr=coefs.sd, align='center')
    ax.scatter(y=coefs.index, x=coefs.my_mean, s=100)
    # Dashed reference line at zero effect spanning all plotted rows.
    ax.plot([0, 0], [-0.5, len(coefs) - 0.5], 'gray', ls='--', lw=0.5)
    ax.set_title(title)
    # Lay out after the title is set so it is accounted for.
    fig.tight_layout()
    if savefig_path is not None:
        fig.savefig(savefig_path)
    plt.show()
    return coefs


# Bacterial-load model (to save: savefig_path='combined_miseq_nextseq_mixed_effects_bacteria.pdf')
tmp = plot_mixed_effects(md_bac_fit,
                         ['total_fungi', 'total_antifungals', 'total_antibiotics', 'day'],
                         'impact on log10(total bacterial density)')

# Fungal-load model; previously mis-titled as bacterial density
# (to save: savefig_path='combined_miseq_nextseq_mixed_effects_fungi.pdf')
tmp = plot_mixed_effects(md_fun_fit,
                         ['total_bacteria', 'total_antifungals', 'total_antibiotics', 'day'],
                         'impact on log10(total fungal density)')

# Now load at genus level for the MDS ordination plot

In [None]:
%autoreload
data_bac, otu_table_bac = jnf.load_microbiome_data(file_name = './20190307_CR_NICU/201903_NICU_combined_bac16SV4_L150_otu-table.xlsx',
                         sheet_name = 'bac16S_L150_mod3', is_nextseq=1)

data_fun, otu_table_fun = jnf.load_microbiome_data(file_name = './20190307_CR_NICU/201903_NICU_combined_ITS1_L100_otu-table.xlsx',
                         sheet_name = 'ITS1_L100_mod3', is_nextseq=1)

tl = 'genus'

data_bacteria = jnf.process_NICU_data_for_plotting(data_bac, otu_table_bac, 'Bacteria(1.0000)', tl)
data_fungi = jnf.process_NICU_data_for_plotting(data_fun, otu_table_fun, 'Fungi(1.0000)', tl)

data = pd.concat([data_bacteria, data_fungi.drop(['babyid', 'day'],1)],1, sort=False).fillna(0)
otu_table = pd.concat([otu_table_bac, otu_table_fun])
data = data.iloc[:-2,:]


data_bacteria = data_bacteria.drop(data_bacteria.loc[data_bacteria.day.str.contains('re', regex=False),:].index)
data_fungi = data_fungi.drop(data_fungi.loc[data_fungi.day.str.contains('re', regex=False),:].index)
data = data.drop(data.loc[data.day.str.contains('re', regex=False),:].index)

# Strip confidence values
for col in otu_table.columns:
    otu_table[col] = otu_table[col].map(lambda x: x.rstrip('(1.023456789)'))
    
    
save_day_baby = data_bacteria[['day','babyid']].copy()
for col in data_bacteria.columns:
    new_col = col.rstrip('(1.023456789)')
    data_bacteria = data_bacteria.rename(columns={col:new_col}) 
data_bacteria = data_bacteria.groupby(by=data_bacteria.columns, axis=1).sum()
data_bacteria = data_bacteria.drop(['day','babyid'],1)
data_bacteria = pd.concat([save_day_baby, data_bacteria],1)

save_day_baby_f = data_fungi[['day','babyid']].copy()
for col in data_fungi.columns:
    new_col = col.rstrip('(1.023456789)')
    data_fungi = data_fungi.rename(columns={col:new_col})
data_fungi = data_fungi.groupby(by=data_fungi.columns, axis=1).sum()
data_fungi = data_fungi.drop(['day','babyid'],1)
data_fungi = pd.concat([save_day_baby_f, data_fungi],1)


In [None]:
# INV SIMPSON

%autoreload
inv_simpson_df_bac=pd.DataFrame()
for ix in data_bacteria.index:
    inv_simpson_df_bac.loc[ix, 'inv_s'] = jnf.inverse_simpson_di(dict(data_bacteria.loc[ix,data_bacteria.columns[2]:]))
inv_simpson_df_bac = pd.concat([data_bacteria.iloc[:,:2], inv_simpson_df_bac],1)

inv_simpson_df_fun=pd.DataFrame()
for ix in data_fungi.index:
    inv_simpson_df_fun.loc[ix, 'inv_s'] = jnf.inverse_simpson_di(dict(data_fungi.loc[ix,data_fungi.columns[2]:]))
inv_simpson_df_fun = pd.concat([data_fungi.iloc[:,:2], inv_simpson_df_fun],1)



In [None]:
# Sample age in days, rescaled by a fixed constant for the cluster age summaries.
# NOTE(review): 43 presumably approximates the latest sampling day - confirm.
AGE_SCALE_DAYS = 43
age_df_bac = data_bacteria.loc[:, 'day'].astype(int) / AGE_SCALE_DAYS
# Fungal ages come from the fungal table (previously copied from the bacterial one).
age_df_fun = data_fungi.loc[:, 'day'].astype(int) / AGE_SCALE_DAYS



# Create colortable for given taxonomic level - BACTERIA

In [None]:
%autoreload
tl2 = 'genus'
color_table_bac = bct.create_data_colortable(data_bacteria, otu_table, 'bacteria', tl2)

slim_color_table = pd.DataFrame(columns = ['c'])
copy_slim_color_table = pd.DataFrame(0, index = color_table_bac.index, columns = ['count'])

taxa_df_bac= pd.DataFrame(index = data_bacteria.index, columns = ['c'])
copy_taxa_df_bac= pd.DataFrame(index = data_bacteria.index, columns = ['c'])

for ix in data_bacteria.index:
    valval = data_bacteria.loc[ix,'Acidovorax':].values.argmax()
    cur_family = data_bacteria.columns[2+valval]
    taxa_df_bac.loc[ix, :] = color_table_bac.loc[cur_family,:].values
    copy_taxa_df_bac.loc[ix, :] = cur_family
    slim_color_table.loc[cur_family, 'c'] = color_table_bac.loc[cur_family,:].values[0]
    copy_slim_color_table.loc[cur_family, 'count'] = copy_slim_color_table.loc[cur_family, 'count']+1
    
    
# Color as other if only occurs fewer than 5 times
copy_slim_color_table = copy_slim_color_table.loc[slim_color_table.index,:]
for ix in copy_slim_color_table.index:
    cur_count = copy_slim_color_table.loc[ix, 'count']
    #print(cur_count, ix)
    if cur_count < 5:
        slim_color_table = slim_color_table.drop([ix])
        cur_indexes = copy_taxa_df_bac[copy_taxa_df_bac.loc[:, 'c'] == ix].index
        for new_ix in cur_indexes:
            taxa_df_bac.loc[new_ix, 'c'] = [0.6,0.6,0.6]
slim_color_table.loc['Other', 'c'] = [0.6,0.6,0.6]      
slim_color_table_bac = slim_color_table.copy()

# Create colortable for given taxonomic level - FUNGI


In [None]:
%autoreload
tl2 = 'genus'
color_table_fun = bct.create_data_colortable(data_fungi, otu_table, 'fungi', tl2)

slim_color_table = pd.DataFrame(columns = ['c'])
copy_slim_color_table = pd.DataFrame(0, index = color_table_fun.index, columns = ['count'])


taxa_df_fun= pd.DataFrame(index = data_fungi.index, columns = ['c'])
copy_taxa_df_fun= pd.DataFrame(index = data_fungi.index, columns = ['c'])

for ix in data_fungi.index:
    valval = data_fungi.loc[ix,'Acremonium':].values.argmax()
    cur_family = data_fungi.columns[2+valval]
    taxa_df_fun.loc[ix, :] = color_table_fun.loc[cur_family,:].values
    copy_taxa_df_fun.loc[ix, :] = cur_family
    slim_color_table.loc[cur_family, 'c'] = color_table_fun.loc[cur_family,:].values[0]
    copy_slim_color_table.loc[cur_family, 'count'] = copy_slim_color_table.loc[cur_family, 'count']+1
    
    
# Color as other if only occurs fewer than 5 times
copy_slim_color_table = copy_slim_color_table.loc[slim_color_table.index,:]
for ix in copy_slim_color_table.index:
    cur_count = copy_slim_color_table.loc[ix, 'count']
    #print(cur_count, ix)
    if cur_count < 5:
        slim_color_table = slim_color_table.drop([ix])
        cur_indexes = copy_taxa_df_fun[copy_taxa_df_fun.loc[:, 'c'] == ix].index
        for new_ix in cur_indexes:
                taxa_df_fun.loc[new_ix, 'c'] = [0.6,0.6,0.6]
            
slim_color_table.loc['Other', 'c'] = [0.6,0.6,0.6]      
slim_color_table_fun = slim_color_table.copy()


# Creating legend for figure
# for ii, ix in enumerate(slim_color_table_fun.index):
#     plt.scatter(1,ii, color=slim_color_table_fun.loc[ix, 'c'])
#     plt.text(1.0025,ii,ix)
# plt.savefig('fun_legend.svg')

# MDS For Bacteria

In [None]:
### Calculate ###
# Metric MDS on Bray-Curtis dissimilarities between bacterial genus profiles,
# followed by DBSCAN clustering of the 2-D embedding.
MIN_CLUSTER_SIZE_BAC = 25  # DBSCAN clusters with fewer samples are relabelled as noise (-1)

X = data_bacteria.iloc[:,2:].copy()
X = sklearn.metrics.pairwise.pairwise_distances(X, metric='braycurtis')
mds = manifold.MDS(n_components=2, eps=1e-12, dissimilarity="precomputed", max_iter=5000, random_state=3)
Y = mds.fit_transform(X)
Y_df = pd.DataFrame(Y, index = data_bacteria.index)
Y_df = pd.concat([data_bacteria.iloc[:,0:2], Y_df], axis=1)

df = pd.DataFrame(Y[:,0:2], columns=['x','y'])


dbscan = DBSCAN(eps = 0.05, min_samples = 3)
cls = dbscan.fit_predict(Y)
# Plot labels use the raw DBSCAN output shifted by one (0 = DBSCAN noise); clusters
# below the size threshold are greyed out in the scatter loop instead.
df['label'] = cls + 1

for ix in np.unique(cls):
    if len(cls[cls==ix]) < MIN_CLUSTER_SIZE_BAC:
        cls[cls==ix] = -1

############
### Plot ###
############

fig3 = plt.figure(constrained_layout=True, figsize=(10,5))
gs = fig3.add_gridspec(2, 6)
tsne_ft = fig3.add_subplot(gs[:, :3])
tsne_ft.set_title('Bacteria - colored by cluster')
tsne_fa = fig3.add_subplot(gs[:1, 3:6])
tsne_fa.set_title('Cluster size and composition')
tsne_fa2 = fig3.add_subplot(gs[1:2, 3:6])
tsne_fa2.set_title('Average age in cluster')


colors = plt.cm.PiYG(np.linspace(0, 1, len(df.label.unique())))

for color, label in zip(colors, df.label.unique()):
    tempdf = df[df.label == label]
    # DBSCAN noise and clusters below the size threshold are drawn grey.
    if label == 0 or len(tempdf) < MIN_CLUSTER_SIZE_BAC:
        color = [0.5,0.5,0.5]
    tsne_ft.scatter(tempdf.x, tempdf.y, color=color, s=150)

############
## Determine size and composition of clusters
############

save_cls = pd.DataFrame(cls, index = copy_taxa_df_bac.index)

size_per_cluster = pd.DataFrame(index=np.unique(save_cls[0]))
for ix in size_per_cluster.index:
    member_index = save_cls.loc[save_cls[0]==ix].index
    size_per_cluster.loc[ix,'my_count'] = len(member_index)

    # Number of cluster members dominated by each genus.
    members_of_current = copy_taxa_df_bac.loc[member_index,:]
    for t_ix in np.unique(members_of_current.c):
        size_per_cluster.loc[ix,t_ix] = len(members_of_current[members_of_current.c==t_ix])

size_per_cluster = size_per_cluster.fillna(0)

# `>=` matches the relabelling threshold above; the previous `> 25` silently discarded
# clusters of exactly 25 samples even though they were not relabelled as noise.
size_per_cluster = size_per_cluster.loc[size_per_cluster.my_count >= MIN_CLUSTER_SIZE_BAC,:]
size_per_cluster = size_per_cluster.sort_values(by='my_count', ascending=False)
short_size_per_cluster = size_per_cluster.iloc[:,1:]


###########
# Average age per cluster, plus per-sample ages with cluster labels for the R test below
###########

get_ages = pd.DataFrame(index=np.unique(save_cls[0]))
age_frames = []
for ix in size_per_cluster.index:
    cur_ages = age_df_bac.loc[save_cls.loc[save_cls[0]==ix].index]
    get_ages.loc[ix,'mean_age'] = np.mean(cur_ages)
    get_ages.loc[ix,'std_age'] = np.std(cur_ages)

    cur_ages = pd.DataFrame(cur_ages)
    cur_ages['c_group'] = ix
    age_frames.append(cur_ages)


slim_ages = get_ages.loc[size_per_cluster.index,:]
slim_ages = slim_ages.sort_values(by='mean_age')
# The noise group (-1) is only present when it is itself at least the threshold size.
slim_ages = slim_ages.drop([-1], errors='ignore')


short_size_per_cluster = short_size_per_cluster.loc[slim_ages.index,:]


short_size_per_cluster.plot(kind='bar',
                            stacked=True,
                            ax=tsne_fa,
                            color=np.array(color_table_bac.loc[short_size_per_cluster.columns,:]))
tsne_fa.get_legend().set_visible(False)
tsne_fa.set_xticklabels('')


tsne_fa2.errorbar(range(len(slim_ages)), slim_ages.mean_age, yerr=slim_ages.std_age, ls='', marker='o')
# Dashed line: mean age over all bacterial samples.
tsne_fa2.plot([-0.1,len(short_size_per_cluster)+0.1-1], [np.mean(age_df_bac), np.mean(age_df_bac)],
             ls='--',
             c = 'gray')

tsne_fa2.set_xlabel('Cluster ID')

#plt.savefig('new_clustering_analysis_bacteria.svg')


# Prep data for R: per-sample ages labelled by cluster, for a Kruskal-Wallis test
# using the kruskal function from the agricolae package.
save_all_ages = pd.concat(age_frames, axis=0) if age_frames else pd.DataFrame()
save_all_ages = save_all_ages.drop(save_all_ages.loc[save_all_ages.c_group==-1,:].index)

#save_all_ages.to_csv('all_ages_bac.csv')


# MDS For Fungi

In [None]:
# Drop samples without any fungal reads (Bray-Curtis is 0/0 between two all-zero profiles).
data_fungi2 = data_fungi.copy().drop(data_fungi[data_fungi.loc[:,'Acremonium':].sum(axis=1)==0].index)

MIN_CLUSTER_SIZE_FUN = 24  # DBSCAN clusters with fewer samples are relabelled as noise (-1)

X = data_fungi2.iloc[:,2:].copy()
X = sklearn.metrics.pairwise.pairwise_distances(X, metric='braycurtis')
X = np.nan_to_num(X)  # guard against any remaining undefined (NaN) dissimilarities

mds = manifold.MDS(n_components=2, eps=1e-12, dissimilarity="precomputed", max_iter=15000, random_state=0)
Y_fun = mds.fit_transform(X)


df = pd.DataFrame(Y_fun, columns=['x','y'])


dbscan = DBSCAN(eps = 0.06, min_samples = 4)
cls_fun = dbscan.fit_predict(Y_fun)
for ix in np.unique(cls_fun):
    if len(cls_fun[cls_fun==ix]) < MIN_CLUSTER_SIZE_FUN:
        cls_fun[cls_fun==ix] = -1

# Shift by one so that 0 marks noise / relabelled small clusters.
df['label'] = cls_fun +1

# Plot stuff #
fig3 = plt.figure(constrained_layout=True, figsize=(10,5))
gs = fig3.add_gridspec(2, 6)
tsne_ft = fig3.add_subplot(gs[:, :3])
tsne_ft.set_title('Fungi - colored by cluster')

tsne_fa = fig3.add_subplot(gs[:1, 3:6])
tsne_fa.set_title('Cluster size and composition')

tsne_fa2 = fig3.add_subplot(gs[1:2, 3:6])
tsne_fa2.set_title('Average age in cluster')


# Plot MDS embedding coloured by cluster (noise in grey) #
colors = plt.cm.PiYG(np.linspace(0, 1, len(df.label.unique())))

for color, label in zip(colors, df.label.unique()):
    if label == 0:
        color = [0.5,0.5,0.5]
    tempdf = df[df.label == label]
    tsne_ft.scatter(tempdf.x, tempdf.y, color=color, s=150)


## Determine size and composition of clusters

save_cls = pd.DataFrame(cls_fun, index = copy_taxa_df_fun.loc[data_fungi2.index,:].index)
number_of_unclassified_fun = len(save_cls.loc[save_cls[0]==-1])

size_per_cluster = pd.DataFrame(index=np.unique(save_cls[0]))
for ix in size_per_cluster.index:
    member_index = save_cls.loc[save_cls[0]==ix].index
    size_per_cluster.loc[ix,'my_count'] = len(member_index)

    # Number of cluster members dominated by each genus.
    members_of_current = copy_taxa_df_fun.loc[member_index,:]
    for t_ix in np.unique(members_of_current.c):
        size_per_cluster.loc[ix,t_ix] = len(members_of_current[members_of_current.c==t_ix])

size_per_cluster = size_per_cluster.fillna(0)
# `>=` matches the relabelling threshold above; the previous `> 24` silently discarded
# clusters of exactly 24 samples even though they were not relabelled as noise.
size_per_cluster = size_per_cluster.loc[size_per_cluster.my_count >= MIN_CLUSTER_SIZE_FUN,:]
size_per_cluster = size_per_cluster.sort_values(by='my_count', ascending=False)
short_size_per_cluster = size_per_cluster.iloc[:,1:]


# Average age per cluster (fungal ages), plus per-sample ages with cluster labels for R.
get_ages = pd.DataFrame(index=np.unique(save_cls[0]))
age_frames = []
for ix in size_per_cluster.index:
    cur_ages = age_df_fun.loc[save_cls.loc[save_cls[0]==ix].index]
    get_ages.loc[ix,'mean_age'] = np.mean(cur_ages)
    get_ages.loc[ix,'std_age'] = np.std(cur_ages)

    cur_ages = pd.DataFrame(cur_ages)
    cur_ages['c_group'] = ix
    age_frames.append(cur_ages)

slim_ages = get_ages.loc[size_per_cluster.index,:]
slim_ages = slim_ages.sort_values(by='mean_age')
# The noise group (-1) is only present when it is itself at least the threshold size.
slim_ages = slim_ages.drop([-1], errors='ignore')


short_size_per_cluster = short_size_per_cluster.loc[slim_ages.index,:]

short_size_per_cluster.plot(kind='bar',
                            stacked=True,
                            ax=tsne_fa,
                            color=np.array(color_table_fun.loc[short_size_per_cluster.columns,:]))
tsne_fa.get_legend().set_visible(False)
tsne_fa.set_xticklabels('')


tsne_fa2.errorbar(range(len(slim_ages)), slim_ages.mean_age, yerr=slim_ages.std_age, ls='', marker='o')
# Dashed line: mean age over all fungal samples (previously taken from the bacterial ages).
tsne_fa2.plot([-0.1,len(short_size_per_cluster)+0.1-1], [np.mean(age_df_fun), np.mean(age_df_fun)],
             ls='--',
             c = 'gray')
tsne_fa2.set_xlabel('Cluster ID')
# Pin one tick per plotted cluster so each label sits under its error bar.
tsne_fa2.set_xticks(range(len(slim_ages)))
tsne_fa2.set_xticklabels(short_size_per_cluster.index)

plt.show()


# Prep data for R (Kruskal-Wallis via agricolae::kruskal), excluding the noise group.
save_all_ages = pd.concat(age_frames, axis=0) if age_frames else pd.DataFrame()
save_all_ages = save_all_ages.drop(save_all_ages.loc[save_all_ages.c_group==-1,:].index)
#save_all_ages.to_csv('all_ages_fun.csv')

# Load data for the individual-infant panels

In [None]:
%autoreload
data_bac_single, otu_table_bac_single = jnf.load_microbiome_data(file_name = '20190207_NICU_rDNA_zOTU-table_mod2.xlsx',
                         sheet_name = 'bac16S')

data_fun_single, otu_table_fun_single = jnf.load_microbiome_data(file_name = '20190207_NICU_rDNA_zOTU-table_mod2.xlsx',
                         sheet_name = 'ITS1')

tl = 'genus'

data_bacteria_single = jnf.process_NICU_data_for_plotting(data_bac_single, otu_table_bac_single, 'Bacteria', tl)
data_fungi_single = jnf.process_NICU_data_for_plotting(data_fun_single, otu_table_fun_single, 'Fungi', tl)


# Read in old colormaps
my_cmap = bct.load_color_variables(kingdom = 'bacteria')
for ix in my_cmap.keys():
    foo = my_cmap[ix]
    for ixx in [0,1,2]:
        foo[ixx] = foo[ixx]/255
        
my_cmap2 = bct.load_color_variables(kingdom = 'fungi')
for ix in my_cmap2.keys():
    foo = my_cmap2[ix]
    for ixx in [0,1,2]:
        foo[ixx] = foo[ixx]/255

In [None]:
%autoreload
tl2 = 'genus'
cur_baby='260'

# Create colortable for given taxonomic level - BACTERIA
cur_data = data_bacteria_single.copy()
cur_otu = otu_table_bac_single.copy()
relative = 1
if relative:
    my_sum = cur_data.iloc[:,2:].sum(1)
    for cix, ix in enumerate(cur_data.index):
        cur_data.loc[ix,cur_data.columns[2]:] = cur_data.loc[ix,cur_data.columns[2]:].astype(float)/my_sum[ix]
        
    

color_table_bac_single = bct.create_data_colortable(cur_data, otu_table_bac_single, 'bacteria', tl2)
baby_260_data = cur_data.loc[cur_data.babyid==cur_baby,'Staphylococcus':].copy()
color_table_genera = baby_260_data.columns[baby_260_data.loc[:,'Staphylococcus':].max()>0.1].values
bacteria_legend = color_table_bac_single.loc[color_table_genera,:]


# Create colortable for given taxonomic level - FUNGI
cur_data_fun = data_fungi_single.copy()
cur_otu_fun = otu_table_fun_single.copy()
relative = 1
if relative:
    my_sum = cur_data_fun.iloc[:,2:].sum(1)
    for cix, ix in enumerate(cur_data_fun.index):
        cur_data_fun.loc[ix,cur_data_fun.columns[2]:] = cur_data_fun.loc[ix,cur_data_fun.columns[2]:].astype(float)/my_sum[ix]


color_table_fun_single = bct.create_data_colortable(cur_data_fun, otu_table_fun_single, 'fungi', tl2)
baby_260_data_fun = cur_data_fun.loc[cur_data_fun.babyid==cur_baby,'Candida':].copy()
color_table_genera_fun = baby_260_data_fun.columns[baby_260_data_fun.loc[:,'Candida':].max()>0.3].values
fungi_legend = color_table_fun_single.loc[color_table_genera_fun,:]


# Get specific data for baby of interest

In [None]:
%autoreload
cur_baby = '260'

baby_260 = data_bacteria_single.loc[data_bacteria_single.babyid==cur_baby,:].copy()
baby_260[baby_260==0]=np.nan
baby_260 = baby_260.fillna(1)
baby_260['day'] = baby_260['day'].astype(float)
baby_260 = baby_260.set_index('day')

baby_260_rel = data_bacteria_single.loc[data_bacteria_single.babyid==cur_baby,:].copy()
baby_260_rel['day'] = baby_260_rel['day'].astype(float)
baby_260_rel = baby_260_rel.set_index('day')
for ix in baby_260_rel.index:
    row_tmp = baby_260_rel.loc[ix, 'Staphylococcus':]
    baby_260_rel.loc[ix, 'Staphylococcus':] = baby_260_rel.loc[ix, 'Staphylococcus':] / sum(row_tmp)
    
baby_260_fun = data_fungi_single.loc[data_fungi_single.babyid==cur_baby,:].copy()
baby_260_fun[baby_260_fun==0]=np.nan
baby_260_fun = baby_260_fun.fillna(1)
baby_260_fun['day'] = baby_260_fun['day'].astype(float)
baby_260_fun = baby_260_fun.set_index('day')

baby_260_rel_fun = data_fungi_single.loc[data_fungi_single.babyid==cur_baby,:].copy()
baby_260_rel_fun['day'] = baby_260_rel_fun['day'].astype(float)
baby_260_rel_fun = baby_260_rel_fun.set_index('day')
for ix in baby_260_rel_fun.index:
    row_tmp = baby_260_rel_fun.loc[ix, 'Candida':]
    baby_260_rel_fun.loc[ix, 'Candida':] = baby_260_rel_fun.loc[ix, 'Candida':] / sum(row_tmp)
    
max_day = int(max(data_bacteria_single.day.astype(int)))+0.5  
    

In [None]:
# Discretize days for visualization: weekly bins mapped to grey-scale levels
# (day < 8 -> 0.02, < 15 -> 0.2, < 22 -> 0.4, < 29 -> 0.6, < 36 -> 0.8, otherwise 1).

AGE_BIN_UPPER_EDGES = [8, 15, 22, 29, 36]
AGE_BIN_LEVELS = [0.02, 0.2, 0.4, 0.6, 0.8]
AGE_LEVEL_LATEST = 1.0


def discretize_age(days):
    """Map sample ages (days, int-convertible) to grey-scale levels in weekly bins.

    Returns a float Series with the same index and name as ``days``. Built in one step
    rather than by writing floats into an int Series element by element; that
    incompatible-dtype setitem is deprecated in pandas >= 2.1.
    """
    days_int = days.astype(int)
    # np.select takes the first true condition, so ascending upper edges give the bin.
    levels = np.select([days_int < edge for edge in AGE_BIN_UPPER_EDGES],
                       AGE_BIN_LEVELS,
                       default=AGE_LEVEL_LATEST)
    return pd.Series(levels, index=days_int.index, name=days_int.name)


age_df_bac = discretize_age(data_bacteria.loc[:,'day'])
age_df_fun = discretize_age(data_fungi.loc[:,'day'])
    

# Fig 2 - alternative arrangement

In [None]:
# Figure 2 layout on a 9 x 9 grid: MDS panels (taxa / age) on the left,
# single-infant time series on the right, model panel and legends along the bottom.
# NOTE(review): this cell continues beyond the visible code; some axes created here
# (e.g. dba_l_klebs, dba_l_cand, dba_l_alt, glmm_bac, lgd_axis, color_axis) are
# presumably filled further down - confirm.
fig3 = plt.figure(constrained_layout=True, figsize=(17,17))
gs = fig3.add_gridspec(9, 9)

# All bacteria
dba_bt = fig3.add_subplot(gs[0:3, :3])
dba_bt.set_title('Bacteria - taxa')

dba_ba = fig3.add_subplot(gs[0:3, 3:6])
dba_ba.set_title('Bacteria - age')

# Right column, rows 0-2: bacterial time series for the representative infant.
dba_l_relbac = fig3.add_subplot(gs[0:1, 6:])
dba_l_relbac.set_title('Dynamics for representative infant')

dba_l_staph = fig3.add_subplot(gs[1:2, 6:])

dba_l_klebs = fig3.add_subplot(gs[2:3, 6:])

# All fungi

dba_ft = fig3.add_subplot(gs[3:6, :3])
dba_ft.set_title('Fungi - taxa')

dba_fa = fig3.add_subplot(gs[3:6, 3:6])
dba_fa.set_title('Fungi - age')

# Right column, rows 3-5: fungal time series for the representative infant.
dba_l_relfun = fig3.add_subplot(gs[3:4, 6:])

dba_l_cand = fig3.add_subplot(gs[4:5, 6:])

dba_l_alt = fig3.add_subplot(gs[5:6, 6:])

# Shared symmetric x/y limits for the four MDS scatter panels.
my_lims = 0.95

# GLMM panel (the models fitted above are linear mixed models via smf.mixedlm)

glmm_bac = fig3.add_subplot(gs[6:9, 6:])
glmm_bac.set_title('Effect on total bacteria')


# Legends

lgd_axis = fig3.add_subplot(gs[6:9, :4])
color_axis = fig3.add_subplot(gs[6:9, 4:5])



# Bacteria - by taxa
# One point per sample in MDS space. Samples with inverse Simpson > 4 are drawn white;
# Staphylococcus-dominated samples are collected and drawn last, in gold; every other
# sample takes the colour of its dominant genus.
df = pd.DataFrame(Y[:,0:2], columns=['x','y'])
taxa_df = taxa_df_bac
save_staph = pd.DataFrame()

for ii, ix in enumerate(df.index):
    point_x = df.loc[ii, 'x']
    point_y = df.loc[ii, 'y']
    if inv_simpson_df_bac.iloc[ii,2] > 4:
        sc = dba_bt.scatter(point_x, point_y, s=150, color='w', edgecolors='k', linewidth=0.2)
    elif copy_taxa_df_bac.iloc[ii,0] == 'Staphylococcus':
        # Deferred so these points are drawn on top in a single call below.
        save_staph.loc[ii, 'x'] = point_x
        save_staph.loc[ii, 'y'] = point_y
    else:
        sc = dba_bt.scatter(point_x, point_y, s=150, color=[taxa_df.iloc[ii,0]], edgecolors='k', linewidth=0.2)

sc = dba_bt.scatter(save_staph.x, save_staph.y, s=150, color='gold', edgecolors='k', linewidth=0.2)
dba_bt.set_xticks([], [])
dba_bt.set_xlim([-my_lims,my_lims])
dba_bt.set_ylim([-my_lims,my_lims])

# Bacteria - by age: grey level encodes the discretized age; later samples are drawn
# first so first-week samples (level <= 0.1) end up on top.
cur_age_df = age_df_bac.loc[taxa_df_bac.index]
tmp_age_2 = pd.DataFrame(cur_age_df).reset_index()
early_samples = tmp_age_2[tmp_age_2.day <= 0.1]
late_samples = tmp_age_2[tmp_age_2.day > 0.1]
my_hot = plt.get_cmap('gray_r')

for age_subset in (late_samples, early_samples):
    sc = dba_ba.scatter(df.loc[age_subset.index,'x'],
                        df.loc[age_subset.index,'y'],
                        s=150,
                        color=my_hot(age_subset.day),
                        edgecolors='k',
                        linewidth=0.3)

dba_ba.set_xticks([], [])
dba_ba.set_yticks([], [])
dba_ba.set_xlim([-my_lims,my_lims])
dba_ba.set_ylim([-my_lims,my_lims])
    
# Fungi - by taxa
# One point per fungal sample (coordinates from Y_fun). Samples whose inverse Simpson value (column 2 of
# inv_simpson_df_fun, presumably) exceeds 4 are drawn white; the rest use the colour in column 0 of taxa_df.
df = pd.DataFrame(Y_fun, columns=['x', 'y'])
taxa_df = taxa_df_fun.loc[data_fungi2.index, :]  # rows follow data_fungi2, i.e. the rows of df

for ii, ix in enumerate(df.index):
    if inv_simpson_df_fun.iloc[ii, 2] > 4:
        sc = dba_ft.scatter(df.loc[ii, 'x'], df.loc[ii, 'y'], s=150, color='w', edgecolors='k', linewidth=0.2)
    else:
        sc = dba_ft.scatter(df.loc[ii, 'x'], df.loc[ii, 'y'], s=150, color=[taxa_df.iloc[ii, 0]], edgecolors='k', linewidth=0.2)

dba_ft.set_xlim([-my_lims, my_lims])
dba_ft.set_ylim([-my_lims, my_lims])
    
# Fungi - by age
# Same shading scheme as the bacterial age panel (my_hot, gray_r), with the early/late cut at day 0.17.
# Late samples are drawn first so early samples sit on top. No `cmap` is passed: explicit `color=` values
# are used, so matplotlib would ignore a cmap and emit a warning.
cur_age_df = age_df_fun.loc[taxa_df.index]
tmp_age_2 = pd.DataFrame(cur_age_df).reset_index()
early_samples = tmp_age_2[tmp_age_2.day <= 0.17]
late_samples = tmp_age_2[tmp_age_2.day > 0.17]

for age_subset in (late_samples, early_samples):
    sc = dba_fa.scatter(df.loc[age_subset.index, 'x'],
                        df.loc[age_subset.index, 'y'],
                        s=150,
                        color=my_hot(age_subset.day),
                        edgecolors='k',
                        linewidth=0.3)

dba_fa.set_yticks([], [])
dba_fa.set_xlim([-my_lims, my_lims])
dba_fa.set_ylim([-my_lims, my_lims])



#####################
# Individual baby   #
#####################

# Bacterial relative-abundance time series for the selected infant (cur_baby), drawn by the project
# helper anf.plot_baby_timeseries. NOTE(review): the meaning of the 1e-6 argument is defined inside that
# helper (presumably a minimum-abundance cut-off) -- confirm.
relbac_ax = dba_l_relbac
anf.plot_baby_timeseries(cur_data, cur_otu, cur_baby, tl, 1e-6, 'bacteria', relbac_ax)
relbac_ax.legend('')  # empty label argument, presumably to blank out the helper's legend
relbac_ax.set_ylabel('Bacteria \n Relative Abundance')
relbac_ax.set_ylim([0, 1])
relbac_ax.set_xticks([], [])


# Fungal relative-abundance time series for the same infant, drawn by the same project helper.
relfun_ax = dba_l_relfun
anf.plot_baby_timeseries(cur_data_fun, cur_otu_fun, cur_baby, tl, 1e-6, 'fungi', relfun_ax)
relfun_ax.legend('')  # empty label argument, presumably to blank out the helper's legend
relfun_ax.set_ylabel('Fungi \n Relative Abundance')
relfun_ax.set_ylim([0, 1])
relfun_ax.set_xticks([], [])


# Escherichia-Shigella panel for baby 260: log10 absolute abundance in grey on the left axis, with the
# relative abundance on a twin right axis. The dashed line marks log10 abundance = 3.
# NOTE(review): the axis variable is named dba_l_staph, but the data plotted is the Escherichia-Shigella
# column, so the y-label and colours now name that taxon instead of 'Staphylococcus'.
escherichia_colour = my_cmap['Escherichia-Shigella']

dba_l_staph.plot(np.log10(baby_260[['Escherichia-Shigella']]), c='gray', marker='o', linewidth=1.0)
dba_l_staph.set_ylabel('Escherichia-Shigella \n log$_{10}$ Absolute \n Abundance', color='gray')
dba_l_staph.tick_params(axis='y', labelcolor='gray')
dba_l_staph.plot([3.5, 42.5], [3, 3], 'k--')
dba_l_staph.set_ylim([0, 11])
dba_l_staph.set_xlim([3.5, 42.5])
dba_l_staph.set_xticks([], [])

ax32 = dba_l_staph.twinx()
ax32.plot(baby_260_rel[['Escherichia-Shigella']], c=escherichia_colour, marker='o', linewidth=1.0)
# Use the same colour key as the plotted line (previously 'Escherichia', which may differ or be missing).
ax32.set_ylabel('Relative', color=escherichia_colour)
ax32.tick_params(axis='y', labelcolor=escherichia_colour)
ax32.set_ylim([0, 1.05])
ax32.set_xticks([], [])

# Enterococcus panel for baby 260: log10 absolute abundance in grey on the left axis, with the relative
# abundance on a twin right axis. The dashed line marks log10 abundance = 3.
# NOTE(review): the axis variable is named dba_l_klebs, but the data plotted is the Enterococcus column,
# so the y-label now names Enterococcus instead of 'Klebsiella'.
enterococcus_colour = my_cmap['Enterococcus']

dba_l_klebs.plot(np.log10(baby_260[['Enterococcus']]), c='gray', marker='o', linewidth=1.0)
dba_l_klebs.set_ylabel('Enterococcus \n log$_{10}$ Absolute \n Abundance', color='gray')
dba_l_klebs.tick_params(axis='y', labelcolor='gray')
dba_l_klebs.plot([3.5, 42.5], [3, 3], 'k--')
dba_l_klebs.set_ylim([0, 11])
dba_l_klebs.set_xlim([3.5, 42.5])
dba_l_klebs.set_xticks([], [])

ax52 = dba_l_klebs.twinx()
ax52.plot(baby_260_rel[['Enterococcus']], c=enterococcus_colour, marker='o', linewidth=1.0)
ax52.set_ylabel('Relative', color=enterococcus_colour)
ax52.tick_params(axis='y', labelcolor=enterococcus_colour)
ax52.set_ylim([0, 1.05])
ax52.set_xticks([], [])

# Candida panel for baby 260: log10 absolute abundance in grey on the left axis (days 3.5-42.5),
# and relative abundance on a twin right axis in the Candida colour.
candida_colour = my_cmap2['Candida']
candida_log_abundance = np.log10(baby_260_fun[['Candida']])

dba_l_cand.plot(candida_log_abundance, c='gray', marker='o', linewidth=1.0)
dba_l_cand.set_ylabel('Candida \n log$_{10}$ Absolute \n Abundance', color='gray')
dba_l_cand.tick_params(axis='y', labelcolor='gray')
dba_l_cand.set_ylim([0, 9])
dba_l_cand.set_xlim([3.5, 42.5])
dba_l_cand.set_xticks([], [])

ax72 = dba_l_cand.twinx()
ax72.plot(baby_260_rel_fun[['Candida']], c=candida_colour, marker='o', linewidth=1.0)
ax72.set_ylabel('Relative', color=candida_colour)
ax72.tick_params(axis='y', labelcolor=candida_colour)
ax72.set_ylim([0, 1.05])
ax72.set_xticks([], [])


# Malassezia panel for baby 260 (bottom row, so it keeps its x ticks and carries the 'Day of life' label):
# log10 absolute abundance in grey on the left axis, and relative abundance on a twin right axis.
malassezia_colour = my_cmap2['Malassezia']
malassezia_log_abundance = np.log10(baby_260_fun[['Malassezia']])

dba_l_alt.plot(malassezia_log_abundance, c='gray', marker='o', linewidth=1.0)
dba_l_alt.set_ylabel('Malassezia \n log$_{10}$ Absolute \n Abundance', color='gray')
dba_l_alt.tick_params(axis='y', labelcolor='gray')
dba_l_alt.set_ylim([0, 9])
dba_l_alt.set_xlim([3.5, 42.5])

ax82 = dba_l_alt.twinx()
ax82.plot(baby_260_rel_fun[['Malassezia']], c=malassezia_colour, marker='o', linewidth=1.0)
ax82.set_ylabel('Relative', color=malassezia_colour)
ax82.tick_params(axis='y', labelcolor=malassezia_colour)
ax82.set_ylim([0, 1.05])
dba_l_alt.set_xlabel('Day of life')


## Add bacteria legend
# Each genus gets an empty scatter handle carrying its colour, so the legend shows one marker per genus.
for genus_name in bacteria_legend.index:
    swatch_colour = bacteria_legend.loc[genus_name, 'c']
    lgd_axis.scatter([], [], s=75, edgecolors='k', linewidth=0.2, color=swatch_colour, label=genus_name)

bacteria_legend_kwargs = dict(scatterpoints=1, frameon=True, labelspacing=1,
                              loc='upper left', fontsize=15, ncol=1)
lgd_axis.legend(**bacteria_legend_kwargs)

# The legend axis only hosts the legend: blank its tick labels and hide its frame.
for tick_axis in (lgd_axis.xaxis, lgd_axis.yaxis):
    tick_axis.set_major_formatter(NullFormatter())
for side in ('top', 'right', 'left', 'bottom'):
    lgd_axis.spines[side].set_visible(False)


# Fungi legend lives on a twin of the bacteria legend axis so both legends share the same panel.
tmp_legn_fun = lgd_axis.twinx()

for fungus_name in fungi_legend.index:
    swatch_colour = fungi_legend.loc[fungus_name, 'c']
    tmp_legn_fun.scatter([], [], s=75, edgecolors='k', linewidth=0.2, color=swatch_colour, label=fungus_name)

fungi_legend_kwargs = dict(scatterpoints=1, frameon=True, labelspacing=1,
                           loc='upper right', fontsize=15, ncol=1)
tmp_legn_fun.legend(**fungi_legend_kwargs)

# Hide the frame and tick labels of the twin axis as well.
for side in ('top', 'right', 'left', 'bottom'):
    tmp_legn_fun.spines[side].set_visible(False)
for tick_axis in (tmp_legn_fun.xaxis, tmp_legn_fun.yaxis):
    tick_axis.set_major_formatter(NullFormatter())



# GLMM panel: fixed-effect estimates from the bacterial mixed model (md_bac_fit) for four predictors,
# with error bars spanning estimate minus lower confidence bound (statsmodels conf_int() defaults to a
# 95% interval). Zero-width bars are used only to draw the error bars; the scatter marks the estimate.
glmm_terms = ['day', 'total_antibiotics', 'total_antifungals', 'total_fungi']
glmm_lower_bound = md_bac_fit.conf_int()[0]

# Build the frame with explicit column names rather than pd.concat([...], 1): pandas 2.x makes
# every concat argument after `objs` keyword-only, so the positional axis raises TypeError.
glmm_coefs = pd.DataFrame({
    'my_mean': md_bac_fit.params,
    'ci_half_width': md_bac_fit.params - glmm_lower_bound,
}).loc[glmm_terms, :]

glmm_bac.bar(x=glmm_coefs.index,
             width=0,
             height=glmm_coefs.my_mean,
             lw=0.25,
             yerr=glmm_coefs.ci_half_width,
             align='center',
             )
glmm_bac.scatter(x=glmm_coefs.index, y=glmm_coefs.my_mean, s=150)
# Dashed zero reference spanning all categorical positions (0 .. n-1, padded by 0.5).
glmm_bac.plot([-0.5, len(glmm_terms) - 0.5], [0, 0], 'gray', ls='--', lw=0.5)
glmm_bac.set_ylim(-0.2, 0.6)
plt.setp(glmm_bac.get_xticklabels(), rotation=45)

import matplotlib as mpl  # mpl.colorbar is also used by the SI figure cells below

# Vertical age colour bar using the same gray_r map as the age panels.
# NOTE(review): ticks at 0.02 and 1 are labelled '<= 7 days' and 42, which suggests ages were rescaled
# by 42 days before colour-mapping -- confirm.
col_map = plt.get_cmap('gray_r')
cbar = mpl.colorbar.ColorbarBase(color_axis, cmap=col_map, orientation='vertical')
cbar.set_ticks([0.02, 1])
# Raw string: '\l' is not a valid Python escape (DeprecationWarning; SyntaxWarning from Python 3.12).
cbar.set_ticklabels([r'$\leq$ 7 days', 42])
color_axis.set_ylabel('Infant Age (days)', fontsize=15)

#plt.savefig('figure_2_jan8.svg')
plt.show()

# SI figure: clustering of bacterial communities

In [None]:
# Assign every bacterial sample a diet colour for the SI "Diet" panels. For each sample, use the
# all_baby_info diet record whose day is closest to the sample's day, ignoring records with diet '0'.
# Unrecognised diet codes are coloured black and printed so they can be checked.
diet_to_colour = {'F': 'blue', 'BM': 'pink', 'NPO': 'yellow', 'MIX': 'orange'}

feeding_df = pd.DataFrame(index=data_bacteria.index, columns=['f_color'])
copy_data_bacteria = data_bacteria.copy()

for sample_id in copy_data_bacteria.index:
    baby = copy_data_bacteria.loc[sample_id, 'babyid']
    day = copy_data_bacteria.loc[sample_id, 'day']

    # A day string containing 're' exactly once keeps only the part before it.
    day_pieces = day.split('re')
    if len(day_pieces) == 2:
        day = day_pieces[0]

    # Diet records for this infant (numeric id = first three characters of babyid), minus diet '0'.
    baby_diets = all_baby_info.loc[all_baby_info.id == int(baby[:3]), ['day', 'diet']]
    baby_diets = baby_diets.loc[baby_diets.diet != '0', :]
    nearest_records = baby_diets.iloc[(baby_diets['day'] - int(day)).abs().argsort()[:2]]
    nearest_diet = nearest_records.iloc[0, 1]

    if nearest_diet in diet_to_colour:
        feeding_df.loc[sample_id, 'f_color'] = diet_to_colour[nearest_diet]
    else:
        feeding_df.loc[sample_id, 'f_color'] = 'k'
        print(sample_id, nearest_diet)
        

In [None]:
# Layout for the bacterial SI clustering figure: a 6x6 grid holding the embedding panels, the
# cluster-summary panels, three legend axes and a colour-bar axis.
fig3 = plt.figure(constrained_layout=True, figsize=(15, 15))
gs = fig3.add_gridspec(6, 6)

dot_size = 100

# (grid slot, title) for each titled panel, listed in creation order.
titled_panels = [
    (gs[0:2, :2], 'Taxa'),
    (gs[0:2, 2:4], 'Age'),
    (gs[2:4, :2], 'Cluster'),
    (gs[4:6, 2:4], 'Delivery mode'),
    (gs[4:6, 4:6], 'Gender'),
    (gs[0:2, 4:6], 'Diet'),
    (gs[4:5, 0:2], 'Cluster size and composition'),
    (gs[5:6, 0:2], 'Average age in cluster'),
]
panel_axes = []
for grid_slot, panel_title in titled_panels:
    new_ax = fig3.add_subplot(grid_slot)
    new_ax.set_title(panel_title)
    panel_axes.append(new_ax)
(tsne_bt, tsne_ba, tsne_ft, dba_deliv, dba_sex, dba_food, tsne_fa, tsne_fa2) = panel_axes

# First two embedding dimensions of Y, one row per bacterial sample.
df = pd.DataFrame(Y[:, 0:2], columns=['x', 'y'])

# Untitled helper axes: legends and the age colour bar.
dba_deliv_lgd = fig3.add_subplot(gs[3, 2:4])
dba_sex_lgd = fig3.add_subplot(gs[3, 4:6])
dba_food_lgd = fig3.add_subplot(gs[2, 4:6])
color_axis = fig3.add_subplot(gs[2, 2:4])


taxa_df = taxa_df_bac


### TAXA ###
# Samples whose inverse Simpson value (column 2 of inv_simpson_df_bac, presumably) exceeds 4 are drawn as
# white points; every other sample takes the colour stored in column 0 of taxa_df.
for row in range(len(df)):
    mixed_community = inv_simpson_df_bac.iloc[row, 2] > 4
    point_colour = 'w' if mixed_community else [taxa_df.iloc[row, 0]]
    sc = tsne_bt.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, color=point_colour,
                         edgecolors='k', linewidth=0.2)


### AGE ###
# Colour each bacterial sample by its age value using the gray_r colormap.
cur_age_df = age_df_bac.loc[taxa_df_bac.index]

age_scatter_style = dict(s=dot_size, edgecolors='k', cmap='gray_r', linewidth=0.3)
sc = tsne_ba.scatter(df.x, df.y, c=cur_age_df.values, **age_scatter_style)
    
    
### CLUSTER ###
# DBSCAN on the bacterial embedding Y. Labels are shifted by +1 so that DBSCAN noise (-1) becomes 0.
dbscan = DBSCAN(eps=0.05, min_samples=3)
cls = dbscan.fit_predict(Y)

df = pd.DataFrame(Y[:, 0:2], columns=['x', 'y'])
df['label'] = cls + 1

cluster_labels = df.label.unique()
colors = plt.cm.PiYG(np.linspace(0, 1, len(cluster_labels)))

# Noise points and clusters with fewer than 25 members are drawn grey; other clusters get a PiYG colour.
for color, label in zip(colors, cluster_labels):
    tempdf = df[df.label == label]
    if label == 0 or len(tempdf) < 25:
        color = [0.5, 0.5, 0.5]
    tsne_ft.scatter(tempdf.x, tempdf.y, color=color, s=150)
 

### DELIVERY ###
# Colour each sample by its infant's delivery code: 1 -> green, anything else -> orange
# (the legend below labels green 'Vaginal' and orange 'C-section').
for row in range(len(df)):
    baby = int(inv_simpson_df_bac.iloc[row, 1])
    del_type = subset_baby_info.loc[subset_baby_info.baby_id == baby, 'delivery'].values[0]
    face_col = edge_col = 'green' if del_type == 1 else 'orange'
    sc = dba_deliv.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, c=face_col,
                           edgecolors=edge_col, alpha=0.75)

dba_deliv.set_title("Delivery mode")


### FOOD ###
# Colour each sample with its diet colour from feeding_df.
# NOTE(review): feeding_df is read by position (row ix of df <-> row ix of feeding_df), which assumes the
# rows of Y follow the order of data_bacteria -- confirm.
for ix in df.index:
    diet_colour = feeding_df.iloc[ix, 0]
    sc = dba_food.scatter(df.loc[ix, 'x'], df.loc[ix, 'y'], s=dot_size, c=diet_colour, edgecolors=diet_colour)

### GENDER ###
# Colour each sample by its infant's sex code: 1 -> darkblue, anything else -> goldenrod.
for row in range(len(df)):
    baby = int(inv_simpson_df_bac.iloc[row, 1])
    sex_code = subset_baby_info.loc[subset_baby_info.baby_id == baby, 'sex'].values[0]
    face_col = edge_col = 'darkblue' if sex_code == 1 else 'goldenrod'
    sc = dba_sex.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, c=face_col,
                         edgecolors=edge_col, alpha=0.75)

dba_sex.set_title("Gender")



## Determine size and composition of clusters

# DBSCAN label (-1 = noise) per bacterial sample, keyed by the sample index of copy_taxa_df_bac.
save_cls = pd.DataFrame(cls, index=copy_taxa_df_bac.index)

# One row per label: total member count ('my_count') plus a count per taxon colour (column `c`).
size_per_cluster = pd.DataFrame(index=np.unique(save_cls[0]))
for ix in size_per_cluster.index:
    member_ids = save_cls.index[save_cls[0] == ix]
    size_per_cluster.loc[ix, 'my_count'] = len(member_ids)

    members_of_current = copy_taxa_df_bac.loc[member_ids, :]
    for t_ix in np.unique(members_of_current.c):
        size_per_cluster.loc[ix, t_ix] = len(members_of_current[members_of_current.c == t_ix])

size_per_cluster = size_per_cluster.fillna(0)

# Only clusters with more than 44 members are summarised, largest first.
size_per_cluster = size_per_cluster.loc[size_per_cluster.my_count > 44, :]
size_per_cluster = size_per_cluster.sort_values(by='my_count', ascending=False)
short_size_per_cluster = size_per_cluster.iloc[:, 1:]  # drop 'my_count', keep per-taxon counts

# Mean / std of sample age for each retained cluster.
get_ages = pd.DataFrame(index=np.unique(save_cls[0]))
for ix in size_per_cluster.index:
    cur_ages = age_df_bac.loc[save_cls.index[save_cls[0] == ix]]
    get_ages.loc[ix, 'mean_age'] = np.mean(cur_ages)
    get_ages.loc[ix, 'std_age'] = np.std(cur_ages)

slim_ages = get_ages.loc[size_per_cluster.index, :]
slim_ages = slim_ages.sort_values(by='mean_age')
# Remove the DBSCAN noise label; errors='ignore' because noise may fall below the size cut-off,
# in which case a plain drop([-1]) raises KeyError.
slim_ages = slim_ages.drop([-1], errors='ignore')

# Order the composition bars by mean cluster age.
short_size_per_cluster = short_size_per_cluster.loc[slim_ages.index, :]

short_size_per_cluster.plot(kind='bar',
                            stacked=True,
                            ax=tsne_fa,
                            color=np.array(color_table_bac.loc[short_size_per_cluster.columns, :]))
tsne_fa.get_legend().set_visible(False)
tsne_fa.set_xticklabels('')

tsne_fa2.errorbar(range(len(slim_ages)), slim_ages.mean_age, yerr=slim_ages.std_age, ls='', marker='o')
# Dashed reference line at the mean age over all bacterial samples.
tsne_fa2.plot([-0.1, len(short_size_per_cluster) + 0.1 - 1], [np.mean(age_df_bac), np.mean(age_df_bac)],
              ls='--',
              c='gray')

tsne_fa2.set_xlabel('Cluster ID')
#plt.savefig('new_clustering_analysis_bacteria.svg')




####### LEGENDS
# Each legend axis is filled with empty scatter handles (one per category) so the legend shows
# coloured markers. Colours mirror the scatter panels above.

dba_deliv_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='orange', label='C-section')
dba_deliv_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='green', label='Vaginal')
dba_deliv_lgd.legend(scatterpoints=1,
                     frameon=True,
                     labelspacing=1,
                     title='Delivery Mode',
                     loc='lower center',
                     fontsize=15,
                     ncol=2)

# NOTE(review): the GENDER panel draws sex code 1 in darkblue; which sex that code denotes is not
# visible here -- confirm the Male/Female labels.
dba_sex_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='goldenrod', label='Male')
dba_sex_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='darkblue', label='Female')
dba_sex_lgd.legend(scatterpoints=1,
                   frameon=True,
                   labelspacing=1,
                   title='Gender',  # was a copy-pasted 'Delivery Mode'
                   loc='lower center',
                   fontsize=15,
                   ncol=2)

# Diet colours match the f_color values written into feeding_df.
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='blue', label='Formula')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='pink', label='Breast-milk')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='yellow', label='NPO')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='orange', label='Mix')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='black', label='Unknown')
dba_food_lgd.legend(scatterpoints=1,
                    frameon=True,
                    labelspacing=1,
                    title='Diet',  # was a copy-pasted 'Delivery Mode'
                    loc='upper center',
                    fontsize=15,
                    ncol=2)

# Legend axes only host legends: blank their tick labels and hide their frames.
for legend_ax in (dba_deliv_lgd, dba_sex_lgd, dba_food_lgd):
    legend_ax.xaxis.set_major_formatter(NullFormatter())
    legend_ax.yaxis.set_major_formatter(NullFormatter())
    for side in ('top', 'right', 'left', 'bottom'):
        legend_ax.spines[side].set_visible(False)

# Horizontal age colour bar with the same gray_r map as the Age panel.
col_map = plt.get_cmap('gray_r')
cbar = mpl.colorbar.ColorbarBase(color_axis, cmap=col_map, orientation='horizontal')
color_axis.set_xlabel('Infant age')

#plt.savefig('SI_fig_bacterial_composition.svg')
plt.show()


# SI figure: clustering of fungal communities

In [None]:
# Layout for the fungal SI clustering figure (same grid as the bacterial one): embedding panels,
# cluster-summary panels, three legend axes and a colour-bar axis.
fig3 = plt.figure(constrained_layout=True, figsize=(15, 15))
gs = fig3.add_gridspec(6, 6)

dot_size = 100

tsne_bt = fig3.add_subplot(gs[0:2, :2])
tsne_bt.set_title('Taxa')

tsne_ba = fig3.add_subplot(gs[0:2, 2:4])
tsne_ba.set_title('Age')

tsne_ft = fig3.add_subplot(gs[2:4, :2])
tsne_ft.set_title('Cluster')

dba_deliv = fig3.add_subplot(gs[4:6, 2:4])
dba_deliv.set_title('Delivery mode')

dba_sex = fig3.add_subplot(gs[4:6, 4:6])
dba_sex.set_title('Gender')

dba_food = fig3.add_subplot(gs[0:2, 4:6])
dba_food.set_title('Diet')

tsne_fa = fig3.add_subplot(gs[4:5, 0:2])
tsne_fa.set_title('Cluster size and composition')

tsne_fa2 = fig3.add_subplot(gs[5:6, 0:2])
tsne_fa2.set_title('Average age in cluster')

dba_deliv_lgd = fig3.add_subplot(gs[3, 2:4])
dba_sex_lgd = fig3.add_subplot(gs[3, 4:6])
dba_food_lgd = fig3.add_subplot(gs[2, 4:6])
color_axis = fig3.add_subplot(gs[2, 2:4])

# Fungal embedding coordinates, one row per sample of data_fungi2 (taxa_df rows follow the same order).
# The bacterial-embedding df that used to be built first was overwritten here unused, so it is gone.
df = pd.DataFrame(Y_fun, columns=['x', 'y'])
taxa_df = taxa_df_fun.loc[data_fungi2.index, :]



### TAXA ###
# Fungal samples whose inverse Simpson value (column 2 of inv_simpson_df_fun, presumably) exceeds 4 are
# drawn white; every other sample takes the colour stored in column 0 of taxa_df.
for row in range(len(df)):
    mixed_community = inv_simpson_df_fun.iloc[row, 2] > 4
    point_colour = 'w' if mixed_community else [taxa_df.iloc[row, 0]]
    sc = tsne_bt.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, color=point_colour,
                         edgecolors='k', linewidth=0.2)


### AGE ###
# Colour each fungal sample by its age value using the gray_r colormap.
cur_age_df = age_df_fun.loc[taxa_df.index]

age_scatter_style = dict(s=dot_size, edgecolors='k', cmap='gray_r', linewidth=0.3)
sc = tsne_ba.scatter(df.x, df.y, c=cur_age_df.values, **age_scatter_style)
    
    
### CLUSTER ###
# DBSCAN on the fungal embedding Y_fun with eps=0.06, min_samples=4 (settings eps=0.05/min_samples=3
# and eps=0.06/min_samples=6 were also tried). Labels are shifted +1 so DBSCAN noise (-1) becomes 0.
dbscan = DBSCAN(eps=0.06, min_samples=4)
cls = dbscan.fit_predict(Y_fun)

df = pd.DataFrame(Y_fun[:, 0:2], columns=['x', 'y'])
df['label'] = cls + 1

cluster_labels = df.label.unique()
colors = plt.cm.PiYG(np.linspace(0, 1, len(cluster_labels)))

# Noise points and clusters with fewer than 25 members are drawn grey; other clusters get a PiYG colour.
for color, label in zip(colors, cluster_labels):
    tempdf = df[df.label == label]
    if label == 0 or len(tempdf) < 25:
        color = [0.5, 0.5, 0.5]
    tsne_ft.scatter(tempdf.x, tempdf.y, color=color, s=150)
     


### DELIVERY ###
# Colour each fungal sample by its infant's delivery code: 1 -> green, anything else -> orange.
for row in range(len(df)):
    baby = int(inv_simpson_df_fun.iloc[row, 1])
    del_type = subset_baby_info.loc[subset_baby_info.baby_id == baby, 'delivery'].values[0]
    face_col = edge_col = 'green' if del_type == 1 else 'orange'
    sc = dba_deliv.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, c=face_col,
                           edgecolors=edge_col, alpha=0.75)

dba_deliv.set_title("Delivery mode")


### FOOD ###
# Colour each fungal sample by diet. feeding_df is indexed by *bacterial* sample ID, so reading it by
# position (feeding_df.iloc[ix]) would pair fungal sample ix with whichever bacterial sample happens to
# sit in row ix. Look the colour up by sample ID instead (taxa_df rows follow df's rows). Samples with no
# bacterial diet entry fall back to black, the 'Unknown' legend colour.
# NOTE(review): assumes fungal and bacterial samples share sample IDs -- confirm.
fungal_diet_colours = feeding_df['f_color'].reindex(taxa_df.index).fillna('k')

for row in range(len(df)):
    diet_colour = fungal_diet_colours.iloc[row]
    sc = dba_food.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, c=diet_colour, edgecolors=diet_colour)

### GENDER ###
# Colour each fungal sample by its infant's sex code: 1 -> darkblue, anything else -> goldenrod.
for row in range(len(df)):
    baby = int(inv_simpson_df_fun.iloc[row, 1])
    sex_code = subset_baby_info.loc[subset_baby_info.baby_id == baby, 'sex'].values[0]
    face_col = edge_col = 'darkblue' if sex_code == 1 else 'goldenrod'
    sc = dba_sex.scatter(df.loc[row, 'x'], df.loc[row, 'y'], s=dot_size, c=face_col,
                         edgecolors=edge_col, alpha=0.75)

dba_sex.set_title("Gender")



## Determine size and composition of clusters

# DBSCAN label (-1 = noise) per fungal sample. The code assumes cls rows follow data_fungi2's order.
save_cls = pd.DataFrame(cls, index=copy_taxa_df_fun.loc[data_fungi2.index, :].index)
# Count of samples DBSCAN labelled as noise (fungal samples, despite the `_bac` suffix).
number_of_unclassified_bac = len(save_cls.loc[save_cls[0] == -1])

# One row per label: total member count ('my_count') plus a count per taxon colour (column `c`).
size_per_cluster = pd.DataFrame(index=np.unique(save_cls[0]))
for ix in size_per_cluster.index:
    member_ids = save_cls.index[save_cls[0] == ix]
    size_per_cluster.loc[ix, 'my_count'] = len(member_ids)

    members_of_current = copy_taxa_df_fun.loc[member_ids, :]
    for t_ix in np.unique(members_of_current.c):
        size_per_cluster.loc[ix, t_ix] = len(members_of_current[members_of_current.c == t_ix])

size_per_cluster = size_per_cluster.fillna(0)

# Only clusters with more than 24 members are summarised (matching the < 25 grey-out in the Cluster panel).
size_per_cluster = size_per_cluster.loc[size_per_cluster.my_count > 24, :]
size_per_cluster = size_per_cluster.sort_values(by='my_count', ascending=False)
short_size_per_cluster = size_per_cluster.iloc[:, 1:]  # drop 'my_count', keep per-taxon counts

# Mean / std of sample age for each retained cluster.
get_ages = pd.DataFrame(index=np.unique(save_cls[0]))
for ix in size_per_cluster.index:
    cur_ages = age_df_fun.loc[save_cls.index[save_cls[0] == ix]]
    get_ages.loc[ix, 'mean_age'] = np.mean(cur_ages)
    get_ages.loc[ix, 'std_age'] = np.std(cur_ages)

slim_ages = get_ages.loc[size_per_cluster.index, :]
slim_ages = slim_ages.sort_values(by='mean_age')
# Remove the DBSCAN noise label; errors='ignore' because noise may fall below the size cut-off.
slim_ages = slim_ages.drop([-1], errors='ignore')

# Order the composition bars by mean cluster age.
short_size_per_cluster = short_size_per_cluster.loc[slim_ages.index, :]

short_size_per_cluster.plot(kind='bar',
                            stacked=True,
                            ax=tsne_fa,
                            color=np.array(color_table_fun.loc[short_size_per_cluster.columns, :]))
tsne_fa.get_legend().set_visible(False)
tsne_fa.set_xticklabels('')

tsne_fa2.errorbar(range(len(slim_ages)), slim_ages.mean_age, yerr=slim_ages.std_age, ls='', marker='o')
# Dashed reference line at the mean age over all *fungal* samples (previously used age_df_bac).
tsne_fa2.plot([-0.1, len(short_size_per_cluster) + 0.1 - 1], [np.mean(age_df_fun), np.mean(age_df_fun)],
              ls='--',
              c='gray')

tsne_fa2.set_xlabel('Cluster ID')




####### LEGENDS
# Each legend axis is filled with empty scatter handles (one per category) so the legend shows
# coloured markers. Colours mirror the scatter panels above.

dba_deliv_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='orange', label='C-section')
dba_deliv_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='green', label='Vaginal')
dba_deliv_lgd.legend(scatterpoints=1,
                     frameon=True,
                     labelspacing=1,
                     title='Delivery Mode',
                     loc='lower center',
                     fontsize=15,
                     ncol=2)

# NOTE(review): the GENDER panel draws sex code 1 in darkblue; which sex that code denotes is not
# visible here -- confirm the Male/Female labels.
dba_sex_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='goldenrod', label='Male')
dba_sex_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='darkblue', label='Female')
dba_sex_lgd.legend(scatterpoints=1,
                   frameon=True,
                   labelspacing=1,
                   title='Gender',  # was a copy-pasted 'Delivery Mode'
                   loc='lower center',
                   fontsize=15,
                   ncol=2)

# Diet colours match the f_color values written into feeding_df.
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='blue', label='Formula')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='pink', label='Breast-milk')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='yellow', label='NPO')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='orange', label='Mix')
dba_food_lgd.scatter([], [], s=dot_size, edgecolors='k', linewidth=0.2, color='black', label='Unknown')
dba_food_lgd.legend(scatterpoints=1,
                    frameon=True,
                    labelspacing=1,
                    title='Diet',  # was a copy-pasted 'Delivery Mode'
                    loc='upper center',
                    fontsize=15,
                    ncol=2)

# Legend axes only host legends: blank their tick labels and hide their frames.
for legend_ax in (dba_deliv_lgd, dba_sex_lgd, dba_food_lgd):
    legend_ax.xaxis.set_major_formatter(NullFormatter())
    legend_ax.yaxis.set_major_formatter(NullFormatter())
    for side in ('top', 'right', 'left', 'bottom'):
        legend_ax.spines[side].set_visible(False)

# Horizontal age colour bar with the same gray_r map as the Age panel.
col_map = plt.get_cmap('gray_r')
cbar = mpl.colorbar.ColorbarBase(color_axis, cmap=col_map, orientation='horizontal')
color_axis.set_xlabel('Infant age')

#plt.savefig('SI_fig_fungal_composition.svg')
plt.show()
