In [None]:
%load_ext autoreload

In [None]:
# Work from the project root so the relative paths used in the cells below resolve.
# Raises KeyError if PROJECT_ROOT is not set in the environment.
import os as _os

_os.chdir(_os.environ['PROJECT_ROOT'])
# Display the resolved working directory as the cell output.
_os.path.realpath(_os.curdir)

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import lib.thisproject.data
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable
import sfacts as sf

In [None]:
# Load the MetaPhlAn species table, transpose it so that clades become columns,
# and divide by 100 (presumably converting percentages to fractions - TODO confirm).
een_metaphlan_rabund = (
    pd.read_table('raw/een-mgen/2023-06-13_aritra.mahapatra@tum.de/6_species.tab')
    .set_index('clade_name')
    .T
    / 100
)
# Pull out the single E. coli column by its full clade string.
een_metaphlan_rabund_ecoli = een_metaphlan_rabund[
    'k__Bacteria|p__Proteobacteria|c__Gammaproteobacteria|o__Enterobacterales|f__Enterobacteriaceae|g__Escherichia|s__Escherichia_coli'
]
een_metaphlan_rabund_ecoli

In [None]:
# Recreation of one of Aritra's plots. Looks identical.
# Histogram of the per-row count of clades whose value exceeds 0.001.
plt.hist(een_metaphlan_rabund.gt(0.001).sum(axis=1), bins=20)

In [None]:
# Mean (over samples) of the total depth summed across all species.
(
    pd.read_table('data/group/een/r.proc.gtpro.species_depth.tsv', index_col=["sample", "species_id"])
    .depth
    .unstack(fill_value=0)
    .sum(axis=1)
    .mean()
)

In [None]:
# Per-sample depth of species 102506 (the column kept in the *_ecoli_* variables)
# for the two cohorts; sample/species pairs missing from the table count as depth 0.
hmp2_ecoli_depth = pd.read_table('data/group/xjin_hmp2/r.proc.gtpro.species_depth.tsv', index_col=["sample", "species_id"]).depth.unstack(fill_value=0)[102506]
een_ecoli_depth = pd.read_table('data/group/een/r.proc.gtpro.species_depth.tsv', index_col=["sample", "species_id"]).depth.unstack(fill_value=0)[102506]

fig, ax = plt.subplots()
# Leading edge at 0 so zero-depth samples fall into the first bin; remaining edges are log-spaced.
bins = [0] + list(np.logspace(-3, 3))

for (label, (x, color)) in dict(hmp2=(hmp2_ecoli_depth, 'tab:blue'), een=(een_ecoli_depth, 'tab:orange')).items():
    ax.hist(x, bins=bins, label=label, alpha=0.6, color=color)

ax.legend()
ax.set_xscale('symlog', linthresh=1e-3, linscale=0.1)
ax.set_yscale('log')
# The y axis is shared by every cohort drawn above, so its label must not be built
# from the loop variable (which would name only the last cohort plotted).
ax.set_ylabel('samples')
ax.set_xlabel('depth')
None

In [None]:
# Per-sample relative abundance of species 102506: each sample's depths are divided
# by that sample's total depth before the species column is selected.
hmp2_ecoli_rabund = pd.read_table('data/group/xjin_hmp2/r.proc.gtpro.species_depth.tsv', index_col=["sample", "species_id"]).depth.unstack(fill_value=0).apply(lambda x: x / x.sum(), axis=1)[102506]
een_ecoli_rabund = pd.read_table('data/group/een/r.proc.gtpro.species_depth.tsv', index_col=["sample", "species_id"]).depth.unstack(fill_value=0).apply(lambda x: x / x.sum(), axis=1)[102506]

fig, ax = plt.subplots()
# Leading edge at 0 so exact zeros fall into the first bin; remaining edges are log-spaced.
bins = [0] + list(np.logspace(-7, 1))

for (label, (x, color)) in dict(hmp2=(hmp2_ecoli_rabund, 'grey'), een=(een_ecoli_rabund, 'tab:blue'), metaphlan=(een_metaphlan_rabund_ecoli, 'tab:orange')).items():
    ax.hist(x, bins=bins, label=label, alpha=0.6, color=color)

ax.legend()
ax.set_xscale('symlog', linthresh=1e-7, linscale=0.1)
ax.set_yscale('log')
# The y axis is shared by all three series, so its label must not be built from the
# loop variable (which would name only the last series plotted).
ax.set_ylabel('samples')
ax.set_xlabel('relative abundance')
None

In [None]:
# Restrict both E. coli relative-abundance series to a common index before comparing them.
x, y = align_indexes(een_ecoli_rabund, een_metaphlan_rabund_ecoli)

# Leading edge at 0 so exact zeros get their own bin; remaining edges are log-spaced.
bins = [0] + list(np.logspace(-7, 0, num=20))

fig, ax = plt.subplots()
# hist2d returns (counts, xedges, yedges, image); only the image is needed for the colorbar.
*_, cbar_artist = ax.hist2d(x, y, bins=bins, norm=mpl.colors.PowerNorm(1/2), cmap='magma_r')
ax.set_aspect(1)

ax.set_xscale('symlog', linthresh=1e-7, linscale=1)
ax.set_yscale('symlog', linthresh=1e-7, linscale=1)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
# Label the axes the same way as the scatter plot of these two series.
ax.set_xlabel('relative abundance (GT-Pro)')
ax.set_ylabel('relative abundance (MetaPhlAn)')

# Make room on the right for a dedicated colorbar axes.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
fig.colorbar(cbar_artist, cax=cbar_ax, label="count samples")

In [None]:
# Scatter of the two E. coli relative-abundance estimates on a common index.
x, y = align_indexes(een_ecoli_rabund, een_metaphlan_rabund_ecoli)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.5)

ax.set_xlabel('relative abundance (GT-Pro)')
ax.set_ylabel('relative abundance (MetaPhlAn)')

# Identical symlog scaling on both axes.
for _set_scale in (ax.set_xscale, ax.set_yscale):
    _set_scale('symlog', linthresh=1e-7, linscale=0.1)

# Dotted guides at 1e-3 on each axis.
for _draw_guide in (ax.axvline, ax.axhline):
    _draw_guide(1e-3, linestyle=':', color='k', lw=1)

ax.set_aspect(1)
# Dashed identity line.
ax.plot([0, 1], [0, 1], linestyle='--', color='k')

In [None]:
# Sample-by-species depth matrix; sample/species pairs missing from the table count as depth 0.
species_depth = pd.read_table('data/group/een/r.proc.gtpro.species_depth.tsv', index_col=["sample", "species_id"]).depth.unstack(fill_value=0)

# Log-spaced bin edges, built once and shared by both panels.
bins = np.logspace(-1, 5, num=100)

fig, axs = plt.subplots(2, sharex=True)

# Row sums give total depth per sample; column sums give total depth per species.
for (title, x), ax in zip(dict(total_depth_by_sample=species_depth.sum(1), total_depth_by_species=species_depth.sum(0)).items(), axs.flatten()):
    ax.hist(x, bins=bins)
    ax.set_title(title)
    ax.set_xscale('log')
fig.tight_layout()

In [None]:
# Total depth (summed over species) for four hand-picked samples.
species_depth.loc[
    ["CF_1", "CF_11", "CF_15", "CF_89"]
].sum(axis=1)

In [None]:
# Divide each sample's depths by that sample's total depth.
species_rabund = species_depth.div(species_depth.sum(axis=1), axis=0)
# The 20 species with the highest mean relative abundance.
(
    species_rabund
    .mean()
    .sort_values(ascending=False)
    .head(20)
)

In [None]:
# Distribution of every (sample, species) depth value; the leading 0 edge catches exact zeros.
bins = [0, *np.logspace(-6, 4, num=100)]
plt.hist(species_depth.to_numpy().ravel(), bins=bins)
plt.xscale('symlog', linthresh=1e-6, linscale=0.1)
plt.yscale('log')

In [None]:
# Per-sample count of species with depth above 0.1.
x = species_depth.gt(1e-1).sum(axis=1)
print(x.quantile([0.05, 0.25, 0.5, 0.75, 0.95]))

plt.hist(x, bins=10)
plt.xlabel("Number of species with depth >0.1x")
plt.ylabel("Number of samples")

In [None]:
# Boolean sample-by-species matrix: True where depth exceeds 0.1.
_species_detected = species_depth > 1e-1
# Fraction of samples in which each species is detected.
species_prevalence = _species_detected.mean()

# Number of species detected in more than half of the samples ...
print((species_prevalence > 0.5).sum())
# ... and number of species detected in at least two samples.
print((_species_detected.sum() >= 2).sum())

plt.hist(species_prevalence, bins=np.linspace(0, 1, num=51))
plt.xlabel("Fraction of samples with depth >0.1x")
plt.ylabel("Number of species")
None

In [None]:
# Per-sample relative abundance: each row divided by its row total.
species_rabund = species_depth.div(species_depth.sum(axis=1), axis=0)

In [None]:
# Per-sample count of species above 0.1% relative abundance.
plt.hist(species_rabund.gt(0.001).sum(axis=1), bins=20)
plt.xlabel("Number of species with relative abundance >0.1%")
plt.ylabel("Number of samples")

In [None]:
# Median per-sample number of species with relative abundance above 0.001.
species_rabund.gt(0.001).sum(axis=1).median()

In [None]:
# Median per-sample number of species with depth above 0.1.
species_depth.gt(1e-1).sum(axis=1).median()

In [None]:
# Median per-row number of MetaPhlAn clades with value above 0.001.
een_metaphlan_rabund.gt(0.001).sum(axis=1).median()

In [None]:
# Stacked per-species histograms of relative abundance for the species that exceed
# 1e-5 relative abundance in the largest number of samples.
n_species = 10
_is_present = species_rabund > 1e-5
top_species = (
    _is_present
    .sum()
    .sort_values(ascending=False)
    .head(n_species)
    .index
)

fig, axs = plt.subplots(n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True)

bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_species, axs):
    _species_values = species_rabund[species_id]
    ax.hist(_species_values, bins=bins, alpha=0.7)
    ax.set_xscale('log')
    # Fraction of samples in which this species exceeds the threshold.
    prevalence = (_species_values > 1e-5).mean()
    ax.set_title("")
    # Strip every decoration so the panels read as one stacked figure.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for _side in ('left', 'right', 'top', 'bottom'):
        ax.spines[_side].set_visible(False)
    ax.annotate(
        f'{species_id} ({prevalence:0.0%})',
        xy=(0.05, 0.1),
        ha='left',
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted guide at the threshold.
    ax.axvline(1e-5, lw=1, linestyle=':', color='k')

# `ax` is still the last axes visited by the loop: restore its x axis and label it.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Taxonomy table, loaded through the project-local helper.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    "ref/gtpro/species_taxonomy_ext.tsv"
)

In [None]:
# Print the taxonomy string of each top species; ids are cast to str for the lookup.
_taxonomy_strings = species_taxonomy.taxonomy_string
for _species_id in top_species.astype(str):
    print(_species_id, ":", _taxonomy_strings.loc[_species_id])

In [None]:
# Fraction of entries in which the MetaPhlAn E. coli value exceeds 1e-3.
een_metaphlan_rabund_ecoli.gt(1e-3).mean()

In [None]:
# Stacked per-species histograms of relative abundance for the species that exceed
# the threshold in the largest number of samples.
n_species = 20
# One threshold drives the ranking, the printed prevalence and the dotted guide line,
# so that the line always marks the cutoff the prevalence figure refers to.
min_rabund = 1e-3
top_species = (species_rabund > min_rabund).sum().sort_values(ascending=False).head(n_species).index

fig, axs = plt.subplots(n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True)

bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_species, axs):
    ax.hist(species_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale('log')
    # Fraction of samples in which this species exceeds the threshold.
    prevalence = (species_rabund[species_id] > min_rabund).mean()
    ax.set_title("")
    # Strip every decoration so the panels read as one stacked figure.
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ['left', 'right', 'top', 'bottom']:
        ax.spines[spine].set_visible(False)
    ax.annotate(f'{species_id} ({prevalence:0.0%})', xy=(0.05, 0.1), ha='left', xycoords="axes fraction")
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    # Dotted guide at the prevalence threshold.
    ax.axvline(min_rabund, lw=1, linestyle=':', color='k')

# `ax` is still the last axes visited by the loop: restore its x axis and label it.
ax.xaxis.set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
# Print the taxonomy string of each top species; ids are cast to str for the lookup.
_taxonomy_strings = species_taxonomy.taxonomy_string
for _species_id in top_species.astype(str):
    print(_species_id, ":", _taxonomy_strings.loc[_species_id])

In [None]:
# Focal species for the cells below. Note this is a string here, whereas the plotting
# loops above leave `species_id` bound to a column label of `species_rabund`.
species_id = "102506"
# Show the taxonomy record for the focal species.
species_taxonomy.loc[species_id]

In [None]:
# Seed numpy's global RNG (presumably so the position subsample below is repeatable -
# TODO confirm that `random_sample` draws from it).
np.random.seed(0)

# Full metagenotype and the previously fitted world for the focal species.
mgtp_all = sf.data.Metagenotype.load(
    f"data/group/een/species/sp-{species_id}/r.proc.gtpro.mgtp.nc"
)
world = sf.data.World.load(
    f"data/group/een/species/sp-{species_id}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts46-s85-seed0.world.nc"
)

# Subsample at most 1000 positions; only the position coordinate is kept.
_n_position_ss = min(1000, world.sizes["position"])
position_ss = world.random_sample(position=_n_position_ss).position

In [None]:
# Metagenotype, community and genotype views of the world, restricted to the position subsample.
_world_ss = world.sel(position=position_ss)


def _metagenotype_linkage(w):
    """Linkage computed from the metagenotype of the world passed in."""
    return w.metagenotype.linkage()


def _genotype_linkage(w):
    """Linkage computed from the genotype of the world passed in."""
    return w.genotype.linkage()


sf.plot.plot_metagenotype(
    _world_ss,
    col_linkage_func=_metagenotype_linkage,
    dwidth=0.01,
    scalex=0.147,
    col_colors_func=None,
)
sf.plot.plot_community(
    _world_ss,
    col_linkage_func=_metagenotype_linkage,
    row_linkage_func=_genotype_linkage,
    dwidth=0.01,
    row_colors_func=None,
)
sf.plot.plot_genotype(
    _world_ss,
    row_linkage_func=_genotype_linkage,
    row_colors_func=lambda w: w.genotype.entropy(),
)

In [None]:
# Plot the fitted strain genotypes together with the genotypes estimated from each
# metagenotype, concatenated along the strain dimension.
_strain_and_mgen_genotype = sf.Genotype.concat(
    dict(
        strain=world.genotype,
        mgen=world.metagenotype.to_estimated_genotype(),
    ),
    dim="strain",
)
sf.plot.plot_genotype(
    _strain_and_mgen_genotype.sel(position=position_ss),
    transpose=True,
)

In [None]:
# Sample inspected in the cells that follow.
sample = "CF_107"

In [None]:
# Mean metagenotype depth for the selected sample (list selection keeps the sample dimension).
world.metagenotype.mean_depth().sel(sample=[sample])

In [None]:
# The four largest community values for the selected sample. Uses the `sample`
# variable so this cell follows whichever sample was chosen above.
world.community.sel(sample=[sample]).to_series().sort_values(ascending=False).head(4)

In [None]:
# Metagenotype frequency spectrum for the selected sample, drawn with a log-scaled y axis.
sf.plot.plot_metagenotype_frequency_spectrum(world, sample, bins=100)
plt.yscale('log')

In [None]:
# Frequency-spectrum comparison for two specific samples (semantics per the sfacts plotting helper).
sf.plot.plot_metagenotype_frequency_spectrum_compare_samples(world, sample_list=["CF_94", "CF_060"])

In [None]:
# Genotype frequency spectrum for strain 3, drawn with a log-scaled y axis.
sf.plot.plot_genotype_frequency_spectrum(world, strain=3, bins=100)
plt.yscale('log')

In [None]:
# Entropy view and genotype view of the subsampled world, both with rows colored by
# genotype entropy computed with norm=2.
sf.plot.plot_genotype_entropy(
    world.sel(position=position_ss),
    row_colors_func=lambda w: w.genotype.entropy(norm=2),
)
sf.plot.plot_genotype(
    world.sel(position=position_ss),
    row_colors_func=lambda w: w.genotype.entropy(norm=2),
)

In [None]:
# Distribution of per-strain genotype entropy (norm=2).
plt.hist(
    world.genotype.entropy(norm=2)
    .to_series()
    .sort_values(ascending=False)
)

In [None]:
# Strains with the highest genotype entropy (norm=2).
(
    world.genotype.entropy(norm=2)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Strains whose genotype entropy (norm=2) exceeds 0.25.
_strain_entropy = world.genotype.entropy(norm=2).to_series()
high_entropy_genotype_list = idxwhere(_strain_entropy > 0.25)
high_entropy_genotype_list

In [None]:
# Strains whose maximum community value across samples stays below 0.05.
_strain_max_fraction = world.community.max("sample").to_series()
low_representation_strain_list = idxwhere(_strain_max_fraction < 0.05)
low_representation_strain_list

In [None]:
# Union of the two strain lists. Sorting gives a stable order, whereas iterating a
# bare set yields an arbitrary one, so everything built from this list is repeatable.
replace_strain_list = sorted(set(low_representation_strain_list + high_entropy_genotype_list))
len(replace_strain_list)

In [None]:
# Initial genotypes: start from a copy of the fitted values and overwrite the rows of
# the strains being replaced with 0.5.
_geno_init_table = world.genotype.data.to_pandas().copy()
_geno_init_table.loc[replace_strain_list] = 0.5
geno_init = sf.data.Genotype(_geno_init_table.stack().to_xarray())

# Rows are ordered by the linkage of the original fit (the lambda ignores its argument).
sf.plot.plot_genotype(
    geno_init.sel(position=position_ss),
    row_linkage_func=lambda w: world.sel(position=position_ss).genotype.linkage(),
)

In [None]:
# Initial community: start from a copy of the fitted values, then spread the summed
# value of the replaced strains evenly across those strains within each sample.
_comm_init_table = world.community.data.to_pandas().copy()
# Computed once: the per-sample even share is the same for every replaced strain.
_replaced_share = _comm_init_table[replace_strain_list].sum(1) / len(replace_strain_list)
for _strain in replace_strain_list:
    _comm_init_table[_strain] = _replaced_share
comm_init = sf.data.Community(_comm_init_table.stack().to_xarray()).fuzzed(eps=1e-3)

# Rows are ordered by the linkage of the original fit (the lambda ignores its argument).
sf.plot.plot_community(comm_init, row_linkage_func=lambda w: world.sel(position=position_ss).genotype.linkage())

In [None]:
# Imports needed only by the refit cells below.
import logging
import torch

# Set the root logger to INFO (presumably so the fitting routine's progress messages
# are shown - TODO confirm).
logging.getLogger().setLevel(logging.INFO)

In [None]:
# Check the combined initial genotype/community world against the constraints before fitting.
sf.World.from_combined(geno_init, comm_init).validate_constraints()

In [None]:
# Refit the model to the same metagenotype, initializing genotype and community from
# the modified starting point built above.
world_init = sf.World.from_combined(geno_init, comm_init)
# world1.data["rho"] = np.ones(world.sizes["strain"]) / world.sizes["strain"]  # TODO: Determine if this is necessary.

_refit_hyperparameters = dict(
    gamma_hyper=1e-10,
    pi_hyper=0.01,
    pi_hyper2=0.01,
    rho_hyper=10.0,
    rho_hyper2=10.0,
    alpha=10.0,
)

_refit_estimation_kwargs = dict(
    seed=0,
    jit=True,
    ignore_jit_warnings=True,
    maxiter=1_000_000,
    lagA=50,
    lagB=100,
    optimizer_name="Adamax",
    optimizer_kwargs=dict(lr=0.05),
    optimizer_clip_kwargs=dict(clip_norm=0.001),
    minimum_lr=1e-2,
)

world2, history = sf.workflow.fit_metagenotype_complex(
    structure=sf.model_zoo.NAMED_STRUCTURES["model8"],
    metagenotype=world.metagenotype,
    nstrain=world.sizes["strain"],
    init_from=world_init,
    init_vars=["genotype", "community"],
    hyperparameters=_refit_hyperparameters,
    device="cuda",
    dtype=torch.float64,  # FIXME: Run with float32 and new model without constraint validation.
    estimation_kwargs=_refit_estimation_kwargs,
)

In [None]:
# Community plots for the original fit, the initialization and the refit. All three
# use the ORIGINAL world's linkages so rows and columns appear in the same order.
_community_plot_kwargs = dict(
    col_linkage_func=lambda w: world.metagenotype.linkage(),
    row_linkage_func=lambda w: world.genotype.linkage(),
    dwidth=0.01,
    row_colors_func=None,
)
sf.plot.plot_community(world.sel(position=position_ss), **_community_plot_kwargs)
sf.plot.plot_community(world_init.sel(position=position_ss), **_community_plot_kwargs)
sf.plot.plot_community(world2.sel(position=position_ss), **_community_plot_kwargs)

In [None]:
# Hand-picked samples used to inspect the replaced strains in the refit.
_sample_list = [
    "CF_057", "CF_059", "CF_060", "CF_039", "CF_045",
    "CF_044", "CF_058", "CF_96", "CF_94", "CF_95",
    "CF_018", "CF_082", "CF_019", "CF_087", "CF_092",
]


def _position_linkage(w):
    """Linkage over positions from the original world's subsampled metagenotype (ignores `w`)."""
    return world.metagenotype.sel(position=position_ss).linkage("position")


# Refit restricted to the replaced strains and the selected samples.
_world2_replaced = world2.sel(position=position_ss, strain=replace_strain_list, sample=_sample_list)

# sf.plot.plot_community(world2.sel(position=position_ss, sample=_sample_list).drop_low_abundance_strains(0.05), row_linkage_func=lambda w: w.genotype.linkage())
sf.plot.plot_community(
    _world2_replaced,
    row_linkage_func=lambda w: w.genotype.linkage(),
    col_linkage_func=lambda w: w.metagenotype.linkage(),
)
sf.plot.plot_genotype(
    _world2_replaced,
    col_linkage_func=_position_linkage,
    row_linkage_func=lambda w: w.genotype.linkage(),
)
sf.plot.plot_metagenotype(
    world.sel(position=position_ss, sample=_sample_list),
    row_linkage_func=_position_linkage,
    transpose=True,
)

In [None]:
# Genotype plots for the original fit, the initialization and the refit. All three
# use linkages from the ORIGINAL world so rows and columns appear in the same order.
_genotype_plot_kwargs = dict(
    col_linkage_func=lambda w: world.metagenotype.sel(position=position_ss).linkage("position"),
    row_linkage_func=lambda w: world.genotype.linkage(),
)
sf.plot.plot_genotype(world.sel(position=position_ss), **_genotype_plot_kwargs)
sf.plot.plot_genotype(world_init.sel(position=position_ss), **_genotype_plot_kwargs)
sf.plot.plot_genotype(world2.sel(position=position_ss), **_genotype_plot_kwargs)

In [None]:
# Change in community values from the original fit to the refit.
diff = world2.community.data - world.community.data

# The cube root keeps the sign while expanding small differences; rows and columns
# are ordered by the original world's genotype and metagenotype linkages.
sns.clustermap(
    np.cbrt(diff.to_pandas().T),
    xticklabels=1,
    yticklabels=1,
    col_linkage=world.metagenotype.linkage(),
    row_linkage=world.genotype.linkage(),
    vmin=-1,
    vmax=1,
    center=0,
)