In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import pandas as pd
from glob import glob
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

In [None]:
# Use seaborn's "deep" palette for the plots that follow.
sns.set_palette(palette='deep')

## Load data and metadata

In [None]:
# Community memberships for the bin_rnd_01 set: one row per (community, org_id).
# Files are read in sorted order so the row order is reproducible, and an empty
# match fails with a clear message instead of pandas' "No objects to concatenate".
_pattern_01 = '../communities/bin_rnd_01/*.tsv'
_files_01 = sorted(glob(_pattern_01))
if not _files_01:
    raise FileNotFoundError(f"no community files match {_pattern_01!r}")
orgs01 = pd.concat((pd.read_csv(fn, sep='\t', header=None, names=['community', 'org_id'])
                    for fn in _files_01),
                   ignore_index=True)  # avoid duplicate index labels across files
# Community size = integer in the second '_'-separated field of the community name.
orgs01["size"] = orgs01["community"].apply(lambda x: int(x.split("_")[1]))

In [None]:
# Community memberships for the bin_rnd_001 set: one row per (community, org_id).
# Files are read in sorted order so the row order is reproducible, and an empty
# match fails with a clear message instead of pandas' "No objects to concatenate".
_pattern_001 = '../communities/bin_rnd_001/*.tsv'
_files_001 = sorted(glob(_pattern_001))
if not _files_001:
    raise FileNotFoundError(f"no community files match {_pattern_001!r}")
orgs001 = pd.concat((pd.read_csv(fn, sep='\t', header=None, names=['community', 'org_id'])
                     for fn in _files_001),
                    ignore_index=True)  # avoid duplicate index labels across files
# Community size = integer in the second '_'-separated field of the community name.
orgs001["size"] = orgs001["community"].apply(lambda x: int(x.split("_")[1]))

In [None]:
# Long (org_id, sample) table plus a constant marker column, pivoted into an
# org_id x sample matrix: 1 where the pair occurs in the file, 0 otherwise.
samples = pd.read_csv('../data/emp_150bp_filtered.tsv', sep='\t').assign(value=1)
samples_wide = samples.pivot_table(index='org_id', columns='sample',
                                   values='value', fill_value=0)

In [None]:
# Per-sample metadata; the id column is renamed to `sample` so it can be merged
# on that key later. errors='raise' fails fast if '#SampleID' is absent instead
# of silently leaving no `sample` column (which would only surface as a KeyError
# inside the merge).
metadata = (
    pd.read_csv('../data/emp_qiime_mapping_qc_filtered.tsv', sep='\t')
    .rename(columns={'#SampleID': 'sample'}, errors='raise')
)

## Merge co-occurrence and metadata

In [None]:
def merge_metadata(cooc, presence=None, sample_metadata=None):
    """Summarise, per community, the samples in which the whole community occurs.

    A community counts as present in a sample only when every one of its members
    (``org_id``) is present there. Members missing from ``presence`` altogether
    still count toward the community size, so such a community is never present.

    Parameters
    ----------
    cooc : DataFrame
        Long table with ``community`` and ``org_id`` columns (one row per member).
        It is not modified.
    presence : DataFrame, optional
        org_id x sample presence matrix (index = org ids, columns = sample ids).
        Defaults to the notebook-level ``samples_wide``.
    sample_metadata : DataFrame, optional
        Per-sample table with ``sample``, ``empo_1``, ``empo_3`` and ``title``
        columns. Defaults to the notebook-level ``metadata``.

    Returns
    -------
    DataFrame
        One row per community present in at least one sample, with these columns:

        - ``community``: the community name.
        - ``empo_1``: a ``(n_host_associated, n_free_living)`` tuple over those samples.
        - ``empo_3`` and ``title``: the number of distinct values.
        - ``sample``: the number of samples.
        - ``size``: the integer in the second ``_``-separated field of the name.
    """
    if presence is None:
        presence = samples_wide
    if sample_metadata is None:
        sample_metadata = metadata

    # Binary org x community membership; assign() leaves the caller's frame untouched.
    membership = pd.pivot_table(cooc.assign(value=1), index='org_id', columns='community',
                                values='value', fill_value=0)

    common = sorted(set(presence.index) & set(membership.index))
    # community x sample: how many of each community's members appear in each sample.
    members_seen = (
        membership.loc[common, :].T.dot(presence.loc[common, :])
        .rename_axis(index='community', columns='sample')
    )
    # Fully present <=> members seen == total community size.
    fully_present = members_seen.eq(membership.sum(axis=0), axis=0)

    pairs = fully_present.unstack().reset_index()
    pairs = pairs.loc[pairs[0].astype(bool), ['sample', 'community']]

    col_funcs = {
        'empo_1': lambda x: ((x == "Host-associated").sum(), (x == "Free-living").sum()),
        'empo_3': lambda x: len(set(x)),
        'title': lambda x: len(set(x)),
        'sample': len,
    }
    columns = ["sample", "empo_1", "empo_3", "title"]
    annotated = pd.merge(pairs, sample_metadata[columns], on='sample')
    grouped = annotated.groupby("community", as_index=False).agg(col_funcs)
    grouped["size"] = grouped["community"].apply(lambda x: int(x.split("_")[1]))

    return grouped

In [None]:
# Per-community sample/habitat summary for the bin_rnd_01 (plotted as "Cooperative") set; %time reports the runtime.
%time meta01 = merge_metadata(orgs01)

In [None]:
# Per-community sample/habitat summary for the bin_rnd_001 (plotted as "Competitive") set; %time reports the runtime.
%time meta001 = merge_metadata(orgs001)

In [None]:
def func(x):
    """Pooled host-associated / free-living log10 ratio, clipped to [-1, 1].

    ``x`` is an iterable of ``(n_host, n_free)`` pairs, such as the ``empo_1``
    tuples produced by ``merge_metadata``. Both counts are summed over all pairs
    and ``log10(sum_host / sum_free)`` is returned, clipped to ``[-1, 1]``.

    A zero total on one side saturates to the matching bound (1.0 if there are no
    free-living samples, -1.0 if there are no host-associated ones). This avoids
    a ZeroDivisionError or numpy divide-by-zero warnings. If both totals are zero,
    the ratio is undefined and NaN is returned.
    """
    pairs = list(x)
    n_host = sum(p[0] for p in pairs)
    n_free = sum(p[1] for p in pairs)
    if n_host == 0 and n_free == 0:
        return np.nan
    if n_free == 0:
        return 1.0
    if n_host == 0:
        return -1.0
    ratio = np.log10(n_host / n_free)
    return float(min(max(ratio, -1.0), 1.0))

# Largest community size kept for plotting.
MAX_SIZE = 40


def summarize_by_size(meta, max_size=MAX_SIZE):
    """Aggregate a ``merge_metadata`` result by community size.

    Returns
    -------
    empo3 : DataFrame
        One row per (size, empo_3) pair, restricted to even sizes <= ``max_size``.
        ``community`` is the number of communities; ``empo_1`` is their pooled,
        clipped host/free-living log10 ratio (see ``func``).
    n_samples : DataFrame
        One row per size <= ``max_size``, holding the mean of the per-community
        ``sample`` counts.
    """
    empo3 = (
        meta.groupby(["size", "empo_3"], as_index=False)
        .agg({"community": len, "empo_1": func})
        .query("size % 2 == 0 and size <= @max_size")
    )
    # The string "mean" (rather than np.mean) selects the groupby mean directly and
    # avoids the pandas >= 2.1 FutureWarning about passing numpy callables.
    n_samples = (
        meta.groupby("size", as_index=False)
        .agg({"sample": "mean"})
        .query("size <= @max_size")
    )
    return empo3, n_samples


empo3_01, samples_01 = summarize_by_size(meta01)
empo3_001, samples_001 = summarize_by_size(meta001)

In [None]:
# Figure 3: number of distinct habitats (empo_3) reached by communities of each size,
# for the bin_rnd_01 ("Cooperative") and bin_rnd_001 ("Competitive") sets.
# Marker area encodes the number of communities; colour encodes the clipped
# host-associated / free-living log10 ratio stored in `empo_1`.
HUE_NORM = (-1, 0.6)  # colour range for the empo_1 ratio, shared by both panels

f = plt.figure(figsize=(12, 6))

gs1 = GridSpec(2, 20, hspace=0.4, wspace=2.0)
ax1 = f.add_subplot(gs1[0, :-1])
ax2 = f.add_subplot(gs1[1, :-1])
ax3 = f.add_subplot(gs1[:, -1])  # one-column strip used as a hand-drawn colour bar
axs = [ax1, ax2]

panels = [
    (ax1, empo3_01, "Cooperative", ""),
    (ax2, empo3_001, "Competitive", "Community size"),
]
for ax, data, title, xlabel in panels:
    sns.scatterplot(data=data, x="size", y="empo_3", hue="empo_1", hue_norm=HUE_NORM,
                    size="community", sizes=(20, 800), palette="BrBG", legend=False, ax=ax)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_title(title)
    ax.set_ylim(-1, 16)
    ax.set_xlim(1, 41.3)
    ax.set_yticks([0, 5, 10, 15])
    ax.set_ylabel("Number of habitats", fontsize=12)

# Colour bar: imshow draws row 0 at the top, so the reversed palette puts the
# high (host-associated) end of BrBG next to the "Host-associated" label.
n = 100
pal = sns.color_palette("BrBG_r", n)
ax3.imshow(np.arange(n).reshape(n, 1), cmap=mpl.colors.ListedColormap(list(pal)),
           interpolation="nearest", aspect="auto")
ax3.set_xticks([])
ax3.set_yticks([])
for spine in ax3.spines.values():
    spine.set_alpha(0.5)

f.text(0.91, 0.85, "Host-associated", fontsize=12, rotation="vertical")
f.text(0.91, 0.25, "Free-living", fontsize=12, rotation="vertical")

f.savefig("../figures/fig_3.png", dpi=300)