In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from copy import deepcopy
from sklearn.metrics import pairwise_distances
from collections import Counter
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr,kendalltau
from scipy.cluster.hierarchy import leaves_list
import warnings
warnings.filterwarnings('ignore')

# Load microbiome data and metadata

In [2]:
# Read sample metadata and drop samples that have no transplant-relative day.
df_sample = pd.read_csv('tblASVsamples.csv', index_col=0)
df_sample = df_sample[df_sample.DayRelativeToNearestHCT.notnull()]

# Read the long-format count table and pivot it to samples (rows) x ASVs (columns).
# aggfunc='sum' is the string spelling of the former np.sum: same result, but it
# avoids the deprecation of NumPy callables as pandas aggregation functions.
df_count_stacked = pd.read_csv('tblcounts_asv_melt.csv')
df_count_stacked = pd.pivot_table(df_count_stacked, index='SampleID', columns='ASV', values='Count', aggfunc='sum').fillna(0)
# Keep samples with at least 1000 total counts, then drop ASVs that are zero in
# every remaining sample.
df_count_stacked = df_count_stacked[df_count_stacked.sum(axis=1)>=1000]
df_count_stacked = df_count_stacked.loc[:, (df_count_stacked != 0).any(axis=0)]
# Convert counts to per-sample relative abundance (each row sums to 1).
df_relab_asv =  df_count_stacked.div(df_count_stacked.sum(axis=1), axis=0)

# Find common samples (present in both the metadata and the abundance table).
# `common_samples` stays a set, but the frames are indexed with an ordered list:
# pandas >= 2.0 raises TypeError for set indexers, and a set's iteration order
# would make the row order differ from run to run.
common_samples = set(df_sample.index).intersection(set(df_relab_asv.index))
common_sample_order = [sample_id for sample_id in df_relab_asv.index if sample_id in common_samples]
df_sample = df_sample.loc[common_sample_order]
df_relab_asv = df_relab_asv.loc[common_sample_order]

# Pairwise Bray-Curtis distance between samples, labelled by sample on both axes.
df_pdist_asv = pd.DataFrame(
    pairwise_distances(df_relab_asv.values, metric="braycurtis", n_jobs=-1),
    index=df_relab_asv.index,
    columns=df_relab_asv.index)

# Oral bacterial fraction: per-sample summed relative abundance of every ASV that
# appears as a query in the BLAST table.
# NOTE(review): the file name suggests these are 100%-identity hits against the
# HMP v35 oral reference -- confirm against how the BLAST table was produced.
df_blast_100 = pd.read_csv("blast_HMPv35oral/blast_HMPv35oral_p100.txt", sep="\t", comment="#", header=None)
df_blast_100.columns = ['query_accver', 'subject_accver', 'perc_identity', 'alignment_length', 'mismatches', 'gap_opens', 'qstart', 'qend', 'sstart', 'send', 'evalue', 'bitscore']
# A boolean column mask replaces the former set-based column selection.
is_oral_asv = df_relab_asv.columns.isin(df_blast_100.query_accver)
df_oral_total = df_relab_asv.loc[:, is_oral_asv].sum(axis=1).to_frame()
df_oral_total.columns = ['OralFrac_HMPv35oral']
# Order samples by increasing oral fraction, breaking ties by SampleID.
df_oral_total = df_oral_total.reset_index('SampleID').sort_values(['OralFrac_HMPv35oral','SampleID']).set_index('SampleID')

# Read taxonomy and align it to the ASV columns of the abundance table.
df_tax = pd.read_csv('tblASVtaxonomy_silva138_v4v5_filter.csv', index_col=0)
df_tax = df_tax.loc[df_relab_asv.columns]
df_tax.index.name = 'ASV'

# Collapse ASVs into taxonomy-color groups: one output column per unique
# (TaxonomyColor, TaxonomyColorOrder) pair, ordered by TaxonomyColorOrder.
unique_color = df_tax[['TaxonomyColor','TaxonomyColorOrder']].drop_duplicates().sort_values(by='TaxonomyColorOrder').reset_index(drop=True)
relab_asv_grouped = np.zeros((len(df_relab_asv.index),len(unique_color.index)))
for k,o in enumerate(unique_color.TaxonomyColorOrder):
    # ASVs assigned to this color order; boolean mask instead of a set indexer.
    asvs_in_group = df_tax.index[df_tax.TaxonomyColorOrder==o]
    relab_asv_grouped[:,k] = df_relab_asv.loc[:, df_relab_asv.columns.isin(asvs_in_group)].sum(axis=1).values
df_relab_asv_grouped = pd.DataFrame(relab_asv_grouped, index=df_relab_asv.index, columns=unique_color.TaxonomyColor)
# Put the grouped table in the same sample order as the oral-fraction table.
df_relab_asv_grouped = df_relab_asv_grouped.loc[df_oral_total.index]

# Fig. 4a

In [3]:
%%capture

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(30,3))

# set n_samples_to_plot a small number for debugging
n_samples_to_plot = int(1e10)
total_sample_number = len(df_relab_asv_grouped)

# find clusters
Y = sch.linkage(squareform(df_pdist_asv.loc[df_oral_total.index,df_oral_total.index].values), method='average')

# plot stacked bars of microbiota composition
df_relab_asv_grouped_reordered = df_relab_asv_grouped.loc[df_oral_total.index].iloc[leaves_list(Y)]
_ = df_relab_asv_grouped_reordered.iloc[0:np.min([n_samples_to_plot, total_sample_number])].plot.bar(
    stacked=True, 
    color=df_relab_asv_grouped_reordered.columns, 
    legend=None, 
    width=1.0, 
    ax=ax[0], 
    ylim=[0,1]
)
_ = ax[0].set_ylabel('')
_ = ax[0].set_yticks([])
_ = ax[0].set_yticks([], minor=True)
_ = ax[0].set_xlabel('')
_ = ax[0].set_xticks([])
_ = ax[0].set_xticks([], minor=True)

# plot stacked bars of oral bacterial fraction
_ = df_oral_total.iloc[leaves_list(Y)].iloc[0:np.min([n_samples_to_plot, total_sample_number])].OralFrac_HMPv35oral.plot.bar(
    color=(0.4980392156862745, 0.4980392156862745, 0.4980392156862745), 
    legend=None, 
    width=1.0, 
    ax=ax[1], 
    ylim=[0,1]
)
_ = ax[1].set_ylabel('')
_ = ax[1].set_yticks([])
_ = ax[1].set_yticks([], minor=True)
_ = ax[1].set_xlabel('')
_ = ax[1].set_xticks([])
_ = ax[1].set_xticks([], minor=True)

plt.tight_layout()
plt.savefig('fig4a.png', dpi=600, bbox_inches='tight')
plt.close()