In [None]:
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import pandas as pd
from glob import glob
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter

In [None]:
# Use seaborn's "deep" palette as the default colour cycle for every plot below.
sns.set_palette(palette='deep')

### Load simulation results

In [None]:
# Community types to load; each one has its own folder under ../simulation/.
types = [
    "bin_rnd_01",
    "bin_rnd_001",
    "random",
]

In [None]:
dfs = []
for commtype in types:
    # glob() returns files in arbitrary, filesystem-dependent order; sort so the
    # concatenated row order (and the 'index' column below) is reproducible.
    filenames = sorted(glob(f"../simulation/{commtype}/*_debug.tsv"))
    if not filenames:
        # Fail with a clear message instead of pandas' "No objects to concatenate".
        raise FileNotFoundError(
            f"no *_debug.tsv files found in ../simulation/{commtype}/")
    dfi = pd.concat(pd.read_csv(filename, sep='\t') for filename in filenames)
    dfi['type'] = commtype
    dfs.append(dfi)
# reset_index() (without drop) keeps each row's position within its source file
# as an 'index' column; dropna() then discards every row with any missing value.
df = pd.concat(dfs).reset_index().dropna()

In [None]:
# 'size' is the integer held in the second '_'-separated field of the community
# name; 'data' is turned from a comma-separated string into a set of identifiers.
df["size"] = df["community"].map(lambda name: int(name.split("_")[1]))
df["data"] = df["data"].map(lambda csv: set(csv.split(",")))

### Load compound classes (BiGG identifier → chemical sub-class)

In [None]:
classes = pd.read_csv("../data/bigg_classes.tsv", sep="\t")
# Entries without a sub-class fall back to their broader class label.
nan = classes["sub_class"].isna()
classes["sub_class"] = classes["sub_class"].where(~nan, classes["class"])
# Lookup table: BiGG identifier -> (sub-)class label.
classes_dict = dict(zip(classes["bigg"], classes["sub_class"]))

In [None]:
# Compound sub-classes that get a dedicated colour in the pie charts; any other
# label falls back to grey in plot_pie.
most_common = [
    'Amino acids, peptides, and analogues',
    'Carbohydrates and carbohydrate conjugates',
    'Pyrimidines and pyrimidine derivatives',
    'Benzoic acids and derivatives',
    'Homogeneous other non-metal compounds',
    'Other non-metal oxides',
    'Tricarboxylic acids and derivatives',
    'Sugar acids and derivatives',
    'Alcohols and polyols',
    'Carbonyl compounds'
]

# One palette colour per listed class (size follows the list, not a magic 10).
colors = dict(zip(most_common, sns.color_palette("deep", len(most_common))))

### Plotting and counting helpers

In [None]:
def plot_pie(values, ax=None, min_frac=0.02):
    """Draw a donut chart of the compound classes found in ``values``.

    Parameters
    ----------
    values : iterable of iterables
        Each element is a collection of compound identifiers; every occurrence
        is counted once after mapping it through ``classes_dict`` (identifiers
        missing from it are used as their own label).
    ax : matplotlib Axes, optional
        Axes to draw on; if None a new figure with a single Axes is created.
    min_frac : float, default 0.02
        Slices whose relative frequency is below this value are drawn without
        a label and therefore in neutral grey.

    Returns
    -------
    The Axes drawn on. Nothing is drawn when ``values`` contains no identifiers.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    counts = Counter(classes_dict.get(cid, cid) for group in values for cid in group)
    if not counts:
        # Nothing to plot; normalising would divide by a zero total.
        return ax
    table = pd.DataFrame(list(counts.items()), columns=["class", "freq"])
    table = table.sort_values("freq", ascending=False)
    table["freq"] = table["freq"] / table["freq"].sum()
    # Blank labels for minor slices; '' is not in `colors`, so they turn grey below.
    table.loc[table["freq"] < min_frac, "class"] = ''
    table["color"] = table["class"].apply(lambda c: colors.get(c, [0.8, 0.8, 0.8, 1]))
    ax.pie(table["freq"], labels=table["class"], rotatelabels=False, startangle=90,
           labeldistance=1.15, colors=table["color"],
           wedgeprops=dict(width=0.6, edgecolor='w', linewidth=1))
    return ax

In [None]:
def count_classes(values, ax=None):
    """Return the relative frequency of each compound class in ``values``.

    ``values`` is an iterable of identifier collections. Each identifier is
    mapped through ``classes_dict`` (unknown identifiers keep their own name),
    and every occurrence counts once. The result is a DataFrame with columns
    ``class`` and ``freq`` (fractions of the total), sorted by decreasing
    frequency. ``ax`` is accepted for call compatibility but is not used.
    """
    labels = (classes_dict.get(cid, cid) for group in values for cid in group)
    tally = Counter(labels)
    summary = pd.DataFrame(list(tally.items()), columns=["class", "freq"])
    summary = summary.sort_values("freq", ascending=False)
    total = summary["freq"].sum()
    summary["freq"] = summary["freq"] / total
    return summary

### Compound sharing

In [None]:
# MIP results come as two compound sets per community: key2 == 'ni' and
# key2 == 'i' (presumably the non-interacting / interacting minimal media --
# confirm against the simulation code). A compound counts as "shared" when it
# is in the 'ni' set but not in the 'i' set.
mip_columns = ["community", "medium", "type", "size", "data"]
df_ni = df.query("key1 == 'mip' and key2 == 'ni'")[mip_columns]
df_i = df.query("key1 == 'mip' and key2 == 'i'")[mip_columns]
# 'medium' must be part of the join: df carries a medium column (it is also a
# grouping key for competition), and joining on community alone would pair
# 'ni'/'i' sets from different media whenever a community has several.
df_mip = pd.merge(df_ni, df_i, on=["community", "medium", "size", "type"],
                  suffixes=("_x", "_y"))
df_mip["shared"] = df_mip["data_x"] - df_mip["data_y"]

### Compound competition

In [None]:
def func(xs):
    """Return the identifiers that occur in more than one of the sets in ``xs``.

    Used below as the aggregation of the ``data`` column, so ``xs`` holds the
    compound sets of one (community, medium, type, size) group; a compound
    listed in two or more of them is treated as competed for.
    """
    tally = Counter()
    for compound_set in xs:
        tally.update(compound_set)
    return set(cid for cid, hits in tally.items() if hits > 1)


# MRO rows other than the community-level 'comm' entry (presumably one row per
# member -- confirm against the simulation output).
df_tro = df.query("key1 == 'mro' and key2 != 'comm'")
df_tro = df_tro.groupby(["community", "medium", "type", "size"], as_index=False).agg({"data": func})

### Plot competed (a, b) and shared (c, d) compounds by class

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(18, 8))

# Top row: competed compounds; bottom row: shared compounds.
# Left column: 'bin_rnd_01' communities; right column: 'bin_rnd_001'.
competed1 = df_tro.query("type == 'bin_rnd_01'")["data"]
competed2 = df_tro.query("type == 'bin_rnd_001'")["data"]
shared1 = df_mip.query("type == 'bin_rnd_01'")["shared"]
shared2 = df_mip.query("type == 'bin_rnd_001'")["shared"]

# axs.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1].
panels = {"a": competed1, "b": competed2, "c": shared1, "d": shared2}
for panel_ax, (letter, compounds) in zip(axs.flat, panels.items()):
    plot_pie(compounds, ax=panel_ax)
    panel_ax.set_title(letter, fontdict={'fontsize': 16})

fig.tight_layout()
fig.savefig("../figures/supp_fig_4.png", dpi=300)