In [None]:
%matplotlib inline

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import os
import matplotlib.pyplot as plt

In [None]:
# Select seaborn's built-in "deep" colour palette for the plots in this notebook.
sns.set_palette('deep')

In [None]:
# Patient metadata table; dtype=str keeps every column (including the patient
# id) as text instead of letting pandas infer numeric types.
meta = pd.read_csv("../data/elinav_patients.tsv", sep="\t", dtype=str)

# Prefix the ids with "p_" -- the same key format the shotgun table builds for
# its own "patient" column, which is what the two tables are later merged on.
meta["patient"] = meta["patient"].map(lambda patient_id: "p_" + patient_id)

## Load shotgun data

In [None]:
# Raw shotgun table in wide layout. The first column has no header, so pandas
# names it 'Unnamed: 0'; it is used below as the row identifier when melting.
# Presumably one remaining column per sample -- TODO confirm against the sheet.
dshot_wide = pd.read_excel("../data/Probiotics_Elinav_Shotgun.xlsx")

In [None]:
# Reshape to long format (one row per row-label / sample-column pair) and keep
# only strictly positive values. The explicit .copy() makes the filtered result
# an independent frame, so the column assignments below do not trigger pandas'
# SettingWithCopyWarning.
dshot = dshot_wide.melt(id_vars='Unnamed: 0').query("value > 0").copy()

# Taxon labels: drop the first three characters of the row label (presumably a
# rank prefix -- TODO confirm), then use the first two "_"-separated tokens as
# the species and the first token as the genus.
dshot["strain"] = dshot['Unnamed: 0'].apply(lambda x: x[3:])
dshot["species"] = dshot['strain'].apply(lambda x: "_".join(x.split("_")[:2]))
dshot["genus"] = dshot['species'].apply(lambda x: x.split("_")[0])

# Sample column names are split on ".": field 0 becomes the patient key
# ("p_" prefix) and field 2 the location key ("l_" prefix, left-padded with
# "0" to width 2 so the keys sort correctly as strings).
dshot["patient"] = dshot["variable"].apply(lambda x: "p_" + x.split(".")[0])
dshot["loc"] = dshot["variable"].apply(lambda x: "l_" + x.split(".")[2].rjust(2, '0'))

# Rescale by 1/100 (presumably percent -> fraction; TODO confirm units).
dshot["value"] = dshot["value"] / 100
dshot = dshot.query("genus != 'unclassified'")

# The raw label columns are no longer needed once the keys are derived.
dshot = dshot.drop(columns=["Unnamed: 0", "variable"])

# Collapse rows to one per species and sample. Only the grouping keys and the
# summed "value" survive, so the "strain" column is dropped here.
# The string "sum" selects pandas' own group-wise sum; passing the builtin
# `sum` is deprecated in recent pandas releases and emits a FutureWarning.
dshot = dshot.groupby(["species", "genus", "patient", "loc"], as_index=False).agg({"value": "sum"})

# Keep only samples (patient, loc) that contain more than one species row.
dshot = dshot.groupby(["patient", "loc"]).filter(lambda x: len(x) > 1)

# Attach patient metadata; merge's default inner join drops rows whose patient
# key has no match in `meta`.
dshot = pd.merge(dshot, meta, on="patient")
dshot["log_value"] = np.log10(dshot["value"])

# Reassign instead of sorting in place so the cell does not mutate the merged
# frame behind the reader's back.
dshot = dshot.sort_values("value", ascending=False)

## Map communities to models

In [None]:
# Only the fifth column (position 4) of the model list is read; the code below
# relies on it being named "file_path".
models = pd.read_csv('../data/model_list.tsv', sep='\t', usecols=[4])

# Model name = file name without its last 7 characters (presumably a
# two-part extension -- TODO confirm against the listed files).
models['strain'] = models['file_path'].map(lambda path: os.path.basename(path)[:-7])

# Species = first two "_"-separated tokens of the model name, the same rule
# used for the shotgun table.
models["species"] = models['strain'].map(lambda name: "_".join(name.split("_")[:2]))

models = models.drop(columns="file_path")

In [None]:
# Species names observed in the shotgun data (the set holds species-level
# names despite the variable's name).
gut_strains = set(dshot["species"])

# For every observed species that has at least one model, keep a single
# representative: the first model listed for that species.
species = (
    models[models["species"].isin(gut_strains)]
    .groupby("species", as_index=False)
    .agg({"strain": lambda members: members.iloc[0]})
)

In [None]:
# Attach the representative model ("strain") to each abundance row. The default
# inner join keeps only species that have a model.
dshot = dshot.merge(species, on="species")

In [None]:
# Header-less table, so columns are labelled 0, 1, ...; column 1 holds the
# names from which the species is derived.
bq = pd.read_csv("../communities/top/bq_50.tsv", sep="\t", header=None)

# Same species rule as elsewhere: first two "_"-separated tokens.
bq["species"] = bq[1].map(lambda name: "_".join(name.split("_")[:2]))

# Boolean flag: does this abundance row's species appear in the bq list?
in_bq = dshot["species"].isin(bq["species"])
dshot["bq"] = in_bq

In [None]:
# Step 1 -- per sample (type, loc, patient), restricted to species flagged in
# `bq`: summed abundance and number of species rows.
# The string "sum" replaces the builtin `sum`, which recent pandas deprecates
# as an aggregation alias.
per_sample = (
    dshot[dshot["bq"]]
    .groupby(["type", "loc", "patient"], as_index=False)
    .agg({"value": "sum", "species": len})
)

# Step 2 -- average those per-sample figures over patients within each
# (type, loc). The numeric columns are selected explicitly: `per_sample` still
# carries the string column "patient", and averaging the whole frame raises a
# TypeError on pandas >= 2.0, where non-numeric columns are no longer dropped
# silently.
grouped = (
    per_sample
    .groupby(["type", "loc"], as_index=False)[["value", "species"]]
    .mean()
)

# Wide tables for plotting: rows = location, columns = type. Abundance is
# multiplied by 100 (presumably fraction -> percent, mirroring the earlier
# division by 100). Combinations with no data are filled with 0.
abundance = grouped.pivot_table(index="loc", columns="type", values="value", fill_value=0) * 100
counts = grouped.pivot_table(index="loc", columns="type", values="species", fill_value=0)

## Plot

In [None]:
# Desired layout: locations in reverse of their current (sorted) order and the
# three groups in a fixed left-to-right order.
col_order = ['Permissive', 'Resistant', 'Placebo']
loc_order = list(abundance.index[::-1])

# Short display labels, assigned positionally -- they must line up one-to-one
# with `loc_order` / `col_order` (seven locations, three groups).
index = ['TI', 'Ce', 'AC', 'TC', 'DC', 'SC', 'Re']
columns = ['P', 'R', 'C']

# Reorder both tables identically, then swap in the display labels.
abundance, counts = (
    table.loc[loc_order, col_order] for table in (abundance, counts)
)
for table in (abundance, counts):
    table.index = index
    table.columns = columns

In [None]:
# Left panel: species counts per location (rows) and group (columns).
# The trailing spaces in the colour-bar format string are part of the tick
# labels and are kept as-is.
ax1 = plt.subplot(1, 2, 1)
sns.heatmap(counts, cmap="BuGn", cbar_kws={"format": '%i        '}, ax=ax1)
ax1.set(xlabel='# Species', ylabel='LGI location')
plt.setp(ax1.get_yticklabels(), rotation=0)  # keep row labels horizontal

# Bar-and-asterisk marker drawn above the panel in axes-fraction coordinates
# (presumably flagging a significant comparison -- TODO confirm).
for text, x_pos in (("___________", 0.1), ("*", 0.33)):
    ax1.annotate(text, (x_pos, 1.06), xycoords='axes fraction')

# Right panel: abundance on a colour scale fixed to 0-1, ticks shown as "%".
ax2 = plt.subplot(1, 2, 2)
sns.heatmap(abundance, cmap="OrRd", vmin=0, vmax=1, cbar_kws={"format": '%.1f %%'}, ax=ax2)
plt.setp(ax2.get_yticklabels(), rotation=0)
ax2.set_xlabel('Abundance')

# Two stacked markers: a short bar at y=1.06 and a longer one at y=1.15.
for text, x_pos, y_pos in (
    ("___________", 0.1, 1.06),
    ("*", 0.33, 1.06),
    ("____________________", 0.1, 1.15),
    ("*", 0.5, 1.15),
):
    ax2.annotate(text, (x_pos, y_pos), xycoords='axes fraction')

plt.tight_layout()
plt.savefig("../figures/supp_fig_8.png", dpi=300)