## Preamble

In [None]:
%load_ext autoreload

In [None]:
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp
from warnings import filterwarnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable

# from fastcluster import linkage
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform
from statsmodels.graphics.regressionplots import influence_plot
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
sns.set_context('notebook')
plt.rcParams["figure.dpi"] = 100

In [None]:
def _calculate_2tailed_pvalue_from_perm(obs, perms):
    """Two-tailed permutation p-value of ``obs`` against null draws ``perms``.

    Each one-sided p-value uses the add-one estimate
    ``(count + 1) / (n_perms + 1)``; the smaller of the two is doubled and the
    result is capped at 1.

    Parameters
    ----------
    obs
        Observed statistic (scalar).
    perms
        Array-like of statistics computed under permutation.

    Returns
    -------
    float
        The two-tailed p-value, never larger than 1.
    """
    # Accept plain lists/tuples as well as arrays and Series.
    perms = np.asarray(perms)
    n_perms = len(perms)
    p_perms_above = ((perms > obs).sum() + 1) / (n_perms + 1)
    p_perms_below = ((perms < obs).sum() + 1) / (n_perms + 1)
    # Doubling the smaller tail exceeds 1 when ``obs`` sits near the middle of
    # the permutation distribution, so cap the result at 1.
    return np.minimum(1.0, 2 * np.minimum(p_perms_above, p_perms_below))

In [None]:
def linkage_order(linkage, labels):
    """Return ``labels`` reordered to follow the leaves of a clustering tree.

    ``linkage`` is a SciPy linkage matrix. Its tree is walked in pre-order and
    the resulting leaf ids are used to index ``labels``, so ``labels`` must
    support indexing with a list of integer positions.
    """
    root = sp.cluster.hierarchy.to_tree(linkage)
    leaf_ids = root.pre_order(lambda node: node.id)
    return labels[leaf_ids]


def is_prime(n):
    """Return True when ``n`` is prime, by trial division up to ``sqrt(n)``.

    Values less than or equal to 1 are never prime.
    """
    if n <= 1:
        return False
    upper = int(n**0.5) + 1
    return all(n % divisor != 0 for divisor in range(2, upper))


def iterate_primes_up_to(n, return_index=False):
    """Yield the primes among ``range(ceil(n))`` in increasing order.

    The upper bound is exclusive: when ``n`` is itself an integer it is not
    tested. With ``return_index=True`` each item is a ``(position, prime)``
    tuple, where ``position`` counts the primes yielded so far from 0.
    """
    upper = int(np.ceil(n))
    primes = (candidate for candidate in range(upper) if is_prime(candidate))
    for position, prime in enumerate(primes):
        yield (position, prime) if return_index else prime


def maximally_shuffled_order(sorted_order):
    """Return the items of ``sorted_order`` in a deterministic, interleaved order.

    Each item's original position is reduced modulo every prime produced by
    ``iterate_primes_up_to(sqrt(n))``. Sorting the positions by those residues
    (first prime as the primary key) gives a sequence of positions, which is
    assigned back to the items, in their original order, as their new rank.
    The items are then returned sorted by that rank.

    Parameters
    ----------
    sorted_order
        Sequence of labels; used as the index of an internal DataFrame.

    Returns
    -------
    list
        The same labels in the shuffled order. When there are no primes to sort
        by (which includes empty input), the labels come back in their original
        order.
    """
    n = len(sorted_order)
    primes_list = list(iterate_primes_up_to(np.sqrt(n)))
    table = pd.DataFrame(np.arange(n), index=sorted_order, columns=["original_order"])
    if not primes_list:
        # Nothing to sort by: keep the input order instead of relying on how
        # pandas treats an empty list of sort keys.
        return table.index.to_list()
    for prime in primes_list:
        table[prime] = table.original_order % prime
    shuffled_positions = table.sort_values(primes_list).original_order.values
    table = table.assign(new_order=shuffled_positions)
    return table.sort_values("new_order").index.to_list()

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression


def cross_correlation(x, y, metric='correlation'):
    """Pairwise similarity, ``1 - distance``, between rows of ``x`` and rows of ``y``.

    Distances come from ``scipy.spatial.distance.cdist`` with the given
    ``metric``. The result is a DataFrame indexed by ``x.index`` with columns
    ``y.index``.
    """
    distances = sp.spatial.distance.cdist(x, y, metric=metric)
    similarity = 1 - distances
    return pd.DataFrame(similarity, index=x.index, columns=y.index)

def pls_pseudo_mapping(x, y, scale=False, tol=1e-7, **kwargs):
    """Build an ``x``-by-``y`` table from the rotations of a canonical PLS fit.

    A ``PLSCanonical`` model is fitted on the transposes of ``x`` and ``y``,
    using as many components as the smallest dimension of either input. The
    product ``y_rotations_ @ x_rotations_.T`` is labelled with ``y.index``
    (rows) and ``x.index`` (columns) and returned transposed.

    ``**kwargs`` is accepted but not used by this function.
    """
    n_components = min(*x.shape, *y.shape)
    model = PLSCanonical(n_components=n_components, scale=scale, tol=tol)
    model.fit(x.T, y.T)
    rotation_product = model.y_rotations_ @ model.x_rotations_.T
    return pd.DataFrame(rotation_product, index=y.index, columns=x.index).T

# Bind the function object now: later cells reuse the name ``cross_correlation``
# for a DataFrame, which would otherwise break the default similarity function
# when ``reciprocal_hits`` is called after them.
_cross_correlation_func = cross_correlation


def reciprocal_hits(
    x,
    y,
    cc_func=None,
    alpha=1,
    rank_thresh=1,
    corr_thresh=0.0,
    metric="correlation",
):
    """Find pairs of ``x`` and ``y`` columns that rank each other highly.

    The two tables are first aligned with ``align_indexes``. A similarity
    matrix between the columns of ``x`` (rows of the matrix) and the columns of
    ``y`` (columns of the matrix) is weighted by each column's mean raised to
    ``alpha``. Every pair is ranked within its row and within its column of
    the weighted score, and the two ranks are multiplied.

    Parameters
    ----------
    x, y
        DataFrames whose columns are the features to be matched.
    cc_func
        Callable ``cc_func(x.T, y.T)`` returning the similarity DataFrame.
        Defaults to ``cross_correlation`` with the given ``metric``.
    alpha
        Exponent applied to the column means used as weights.
    rank_thresh
        Keep pairs whose rank product is at most this value. The default of 1
        keeps only pairs that are each other's top-scoring partner.
    corr_thresh
        Keep pairs whose unweighted similarity is at least this value.
    metric
        Distance metric for the default ``cc_func``; ignored when ``cc_func``
        is supplied.

    Returns
    -------
    pandas.DataFrame
        One row per retained pair, with columns ``rank_product``, ``corr`` and
        ``score``.
    """
    if cc_func is None:

        def cc_func(a, b):
            return _cross_correlation_func(a, b, metric=metric)

    x, y = align_indexes(x, y)
    cc = cc_func(x.T, y.T)

    x_weight = x.mean().loc[cc.index] ** alpha
    y_weight = y.mean().loc[cc.columns] ** alpha
    score = cc.multiply(x_weight, axis=0).multiply(y_weight, axis=1)

    rank_within_row = score.rank(axis=1, ascending=False)
    rank_within_col = score.rank(axis=0, ascending=False)

    result = (
        (rank_within_row * rank_within_col)
        .stack()
        .to_frame("rank_product")
        .assign(corr=cc.stack(), score=score.stack())
    )
    keep = (result["rank_product"] <= rank_thresh) & (result["corr"] >= corr_thresh)
    return result[keep]

## Construct Metadata

In [None]:
# Colors keyed by pair type.
pair_type_palette = {"Transition": "plum", "EEN": "pink", "PostEEN": "lightblue"}

# Colors keyed by diet / medium label.
diet_palette = dict(
    EEN="lightgreen",
    PostEEN="lightblue",
    InVitro="plum",
    PreEEN="lightpink",
)

# Display order of subject ids (note "H" sits between "B" and "C").
subject_order = list("ABHCDEFGKLMNOPQRSTU")

# NOTE: Requires a dummy value because I want exactly 20 items.
subject_palette = lib.plot.construct_ordered_palette(
    [*subject_order, *(f"dummy{i}" for i in range(20 - len(subject_order)))],
    cm="tab20",
)
subject_palette["X"] = "black"

# Marker and line style per pair type, listed in plotting order.
pair_type_order = ["EEN", "Transition", "PostEEN"]
pair_type_marker_palette = dict(zip(pair_type_order, ["s", ">", "o"]))
pair_type_linestyle_palette = dict(zip(pair_type_order, [":", "-.", "-"]))

In [None]:
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        label=lambda x: x[
            ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
        ].apply(tuple, axis=1)
    )
    .set_index("sample_id")
)
subject = pd.read_table("meta/een-mgen/subject.tsv", index_col="subject_id")

In [None]:
# Load the zOTU count table: rows are zOTUs (indexed by the "#OTU ID" column),
# columns are samples plus one `taxonomy` column.
rotu_counts = pd.read_table(
    "data/group/een/a.proc.zotu_counts.tsv", index_col="#OTU ID"
).rename_axis(index="zotu", columns="sample_id")
# Keep the raw taxonomy strings, indexed by zOTU.
rotu_taxonomy0 = rotu_counts.taxonomy
# Drop the taxonomy column and transpose so that rows are samples.
rotu_counts = rotu_counts.drop(columns=["taxonomy"]).T
# Divide each row by its total so every sample sums to 1.
rotu_rabund = rotu_counts.divide(rotu_counts.sum(1), axis=0)

# Average-linkage clustering of samples on Bray-Curtis dissimilarity.
sample_rotu_bc_linkage = sp.cluster.hierarchy.linkage(
    rotu_rabund, method="average", metric="braycurtis", optimal_ordering=True
)

In [None]:
rotu_taxonomy = rotu_taxonomy0.str.split(';').apply(lambda x: pd.Series(x, index=['d__', 'p__', 'c__', 'o__', 'f__', 'g__', 's__']))

In [None]:
missing_samples = sorted(idxwhere(~rotu_counts.index.to_series().isin(sample.index)))
print(len(missing_samples), ", ".join(missing_samples))

In [None]:
# Matrix to plot: per-sample zOTU relative abundances.
x = rotu_rabund

# Row annotations: subject color, plus a black flag on samples CF_11 and CF_15.
row_colors = pd.DataFrame(
    {
        "subj": sample.subject_id.map(subject_palette),
        "swap": sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .map({False: "grey", True: "black"}),
    }
)
row_linkage = sample_rotu_bc_linkage

sns.clustermap(
    x,
    row_linkage=row_linkage,
    row_colors=row_colors,
    norm=mpl.colors.PowerNorm(1 / 5),
)

In [None]:
pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv")

In [None]:
# Read the long-format depth table and pivot it to a sample-by-species matrix,
# filling missing (sample, species) pairs with 0.
gtpro_depth = (pd.read_table(
    "data/group/een/r.proc.gtpro.species_depth.tsv",
    index_col=['sample', "species_id"],
    )
    .depth.unstack(fill_value=0)
    # Species ids become strings; sample labels are rebuilt as "CF_<int>" from
    # the second "_"-separated field of the original label.
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    # Exchange the row labels CF_11 and CF_15.
    .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})  # Sample swap
)
# Divide each row by its total so every sample sums to 1.
gtpro_rabund = gtpro_depth.divide(gtpro_depth.sum(1), axis=0)

gtpro_rabund

In [None]:
motu_depth = (pd.read_table(
    "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
    names=['sample', "species_id", 'depth'], index_col=['sample', "species_id"],
    )
    .depth.unstack(fill_value=0)
    .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
    .rename({'CF_15': 'CF_11', 'CF_11': 'CF_15'})  # Sample swap
)
motu_rabund = motu_depth.divide(motu_depth.sum(1), axis=0)

motu_rabund

In [None]:
x, y = align_indexes(motu_rabund, rotu_rabund)


x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)

x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
x, y = align_indexes(motu_rabund, gtpro_rabund)


x_linkage = linkage(x, method="average", metric="braycurtis", optimal_ordering=True)
y_linkage = linkage(y, method="average", metric="braycurtis", optimal_ordering=True)
colors = pd.DataFrame(
    dict(
        subj=sample.subject_id.map(subject_palette),
        swap=sample.index.to_series()
        .isin(["CF_11", "CF_15"])
        .replace({False: "grey", True: "black"}),
    )
)

x_pdist = pd.DataFrame(
    squareform(pdist(x, metric="braycurtis")), index=x.index, columns=x.index
)
sns.clustermap(
    x_pdist,
    row_linkage=y_linkage,
    col_linkage=x_linkage,
    row_colors=colors,
    col_colors=colors,
)

In [None]:
# Log-spaced bin edges, matching the log-scaled x axis set below. The edges are
# built once and shared by both histograms.
bins = np.logspace(-1, 5, num=100)

fig, axs = plt.subplots(2, sharex=True)

# One histogram of total depth per sample and one per species.
depth_totals = dict(
    total_depth_by_sample=motu_depth.sum(1),
    total_depth_by_species=motu_depth.sum(0),
)
for (title, x), ax in zip(depth_totals.items(), axs.flatten()):
    ax.hist(x, bins=bins)
    ax.set_title(title)
    ax.set_xscale("log")
fig.tight_layout()

In [None]:
motu_rabund.mean().sort_values(ascending=False).head(20)

In [None]:
n_species = 10
top_motus = (
    (motu_rabund > 1e-5).sum().sort_values(ascending=False).head(n_species).index
)

fig, axs = plt.subplots(
    n_species, figsize=(5, 0.3 * n_species), sharex=True, sharey=True
)

bins = np.logspace(-8, 1, num=51)

for species_id, ax in zip(top_motus, axs):
    # ax.hist(rabund_subset[species_id], bins=bins, alpha=0.7)
    ax.hist(motu_rabund[species_id], bins=bins, alpha=0.7)
    ax.set_xscale("log")
    prevalence = (motu_rabund[species_id] > 1e-5).mean()
    ax.set_title("")
    # ax.set_xticks()
    # ax.set_yticks()
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.patch.set_alpha(0.0)
    for spine in ["left", "right", "top", "bottom"]:
        ax.spines[spine].set_visible(False)
    ax.annotate(
        f"{species_id} ({prevalence:0.0%})",
        xy=(0.05, 0.1),
        ha="left",
        xycoords="axes fraction",
    )
    ax.set_xlim(left=1e-9)
    ax.set_ylim(top=20)
    ax.axvline(1e-5, lw=1, linestyle=":", color="k")

ax.xaxis.set_visible(True)
ax.spines["bottom"].set_visible(True)
ax.set_xticks([1e-4, 1e-2, 1e-0])
ax.set_xticklabels(["0.01%", "1%", "100%"])
ax.set_xlabel("Relative Abundance")

# fig.subplots_adjust(hspace=-0.75)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    The string must contain exactly seven fields, one for each of the labels
    ``d__``, ``p__``, ``c__``, ``o__``, ``f__``, ``g__`` and ``s__``.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
motu_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

_motu_taxonomy = (
    pd.read_table(motu_taxonomy_inpath)[lambda x: x.Genome == x.Species_rep]
    .assign(species_id=lambda x: "1" + x.MGnify_accession.str.split("-").str[2])
    .set_index("species_id")
)

# motu_lineage_string = _motu_taxonomy.Lineage

motu_taxonomy = _motu_taxonomy.Lineage.apply(
    parse_taxonomy_string
)  # .assign(taxonomy_string=motu_lineage_string)
motu_taxonomy

In [None]:
for _species_id in top_motus.astype(str):
    print(_species_id, ":", ";".join(motu_taxonomy.loc[_species_id].values))

## Statistical Linkage Between zOTUs and GT-Pro Species

### "Enterobacteriaceae"

In [None]:
focal_family_list = ["Enterobacteriaceae"]
# rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson correlation between every column of `x` (rows of the result) and
# every column of `y` (columns of the result).
# NOTE(review): this assignment rebinds the name `cross_correlation`, replacing
# the function of the same name defined earlier in the notebook with a DataFrame.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        # The custom metric returns 1 - r, so subtracting from 1 recovers r.
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1,
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
).join(motu_rabund.mean().rename("motu_mean")).join(rotu_rabund.mean().rename("rotu_mean"))

In [None]:
# Reciprocal best hits

motu_list = cross_correlation.columns
zotu_list = cross_correlation.index

motu_rank = cross_correlation.multiply(motu_rabund.mean().loc[motu_list], axis=1).rank(1, ascending=False)
zotu_rank = cross_correlation.multiply(rotu_rabund.mean().loc[zotu_list], axis=0).rank(0, ascending=False)

(motu_rank * zotu_rank).unstack()[lambda x: x==1]

In [None]:
z, m = ('Zotu4', '102506')

plt.scatter('x', 'y', data=pd.DataFrame(dict(x=rotu_rabund[z], y=motu_rabund[m])).dropna())
cross_correlation.loc[z, m]

In [None]:
# Pass the correlation-based similarity explicitly so this call does not depend
# on the global name `cross_correlation`, which an earlier cell rebinds to a
# DataFrame.
reciprocal_hits(
    motu_rabund[y_taxa],
    rotu_rabund[x_taxa],
    cc_func=lambda a, b: pd.DataFrame(
        1 - sp.spatial.distance.cdist(a, b, metric="correlation"),
        index=a.index,
        columns=b.index,
    ),
)

In [None]:
focal_family_list = ["Enterobacteriaceae"]
# rotu_taxonomy.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

#### Zotu4

In [None]:
focal_zotu = "Zotu4"

# Top 10 candidate species for the focal zOTU, ranked by the PLS weight scaled
# by the mean relative abundance of both partners.
_selection_data = (
    # `d` may already carry `rotu_mean` / `motu_mean` (this section's version of
    # `d` joins them when it is built). Drop them first so the joins below do
    # not fail on overlapping column names.
    d.drop(columns=["rotu_mean", "motu_mean"], errors="ignore")
    .reorder_levels(["zotu", "species_id"])
    .loc[[focal_zotu]]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

_selection_data

In [None]:
d.assign(score=lambda x: np.log10(x.pls * x.motu_mean * x.rotu_mean)).score.groupby(level='zotu').apply(lambda x: {'max_score': x.max(), 'motu': x.idxmax()[0]}).unstack().sort_values('max_score', ascending=False)

In [None]:
focal_motu = _selection_data.pls_score.idxmax()[1]
focal_motu

In [None]:
plt.scatter(x[focal_zotu], y[focal_motu])
print(sp.stats.pearsonr(x[focal_zotu], y[focal_motu]))
plt.plot([0, 0.5], [0, 0.5])

#### 102351

In [None]:
focal_motu = '102351'

# All zOTU partners of the focal species, ranked by the PLS weight scaled by
# the mean relative abundance of both partners.
(
    # `d` may already carry `rotu_mean` / `motu_mean` (this section's version of
    # `d` joins them when it is built). Drop them first so the joins below do
    # not fail on overlapping column names.
    d.drop(columns=["rotu_mean", "motu_mean"], errors="ignore")
    .reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    # Append the last 20 characters of each zOTU's taxonomy string.
    .join(rotu_taxonomy0.str[-20:])
    .loc[focal_motu]
)

### "Eggerthellaceae" (Zotu172)

In [None]:
focal_zotu = ["Zotu172"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Lachnospiraceae"

#### Zotu114

In [None]:
focal_zotu = ["Zotu114"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu25

In [None]:
focal_zotu = ["Zotu25"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu5

In [None]:
focal_zotu = ["Zotu5"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu10

In [None]:
focal_zotu = ["Zotu10"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu160

In [None]:
focal_zotu = ["Zotu160"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu31

In [None]:
focal_zotu = ["Zotu31"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu155

In [None]:
focal_zotu = ["Zotu155"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu75

In [None]:
focal_zotu = ["Zotu75"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu13

In [None]:
focal_zotu = ["Zotu13"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu155 (repeated)

In [None]:
focal_zotu = ["Zotu155"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu27

In [None]:
focal_zotu = ["Zotu27"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu104

In [None]:
focal_zotu = ["Zotu104"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

##### Alternative zOTUs

In [None]:
(
    d
    .loc[['100205']]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu167

In [None]:
focal_zotu = ["Zotu167"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("cc", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu26

In [None]:
focal_zotu = ["Zotu26"]

(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

In [None]:
focal_zotu = ["Zotu100"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

### "Erysipelatoclostridiaceae" (Zotu46)

In [None]:
focal_zotu = ["Zotu46"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Oscillospiraceae" (Zotu49)

In [None]:
focal_zotu = ["Zotu49"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Erysipelotrichaceae" (Zotu34)

In [None]:
focal_zotu = ["Zotu34"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
motu_taxonomy.loc[y_taxa]

In [None]:
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Ruminococcaceae"

#### Zotu78

In [None]:
# Focal zOTU and its family, taken as the third-from-last ";"-separated field
# of its rotu_taxonomy0 string; the assert requires a single distinct value.
focal_zotu = ["Zotu78"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# x_taxa: entries of rotu_taxonomy0 containing ";<family>;".
# y_taxa: motu_taxonomy rows whose ";"-joined fields contain "f__<family>;"
# and whose label is also a column of motu_rabund.
# NOTE(review): idxwhere presumably returns the labels where the mask is True.
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
# Taxonomy strings for x_taxa, shown as an array of (label, string) rows.
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
# Taxonomy rows for y_taxa.
motu_taxonomy.loc[y_taxa]

In [None]:
# Align the two abundance tables (NOTE(review): align_indexes presumably keeps
# the shared row labels; confirm in lib.pandas_util) and fit an unscaled
# canonical PLS with as many components as the smallest dimension of x or y.
# PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Product of the y and x rotation matrices, labelled by y's columns (rows)
# and x's columns (columns).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
# Histogram of every pseudo_map entry (1000 bins), with counts on a log axis.
# The cell ends in None so that no value is echoed.
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson r between every column of x (rows of the result) and every column of
# y (columns of the result): cdist is given the distance 1 - r and the outer
# "1 -" converts it back to r. The lambda's x, y are its own parameters.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
# Clustered heatmaps of both matrices with NaN replaced by 0 and the colour
# scale centred on 0. pseudo_map is transposed so its rows match those of
# cross_correlation (x's columns).
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# One row per (y column, x column) pair: the PLS and correlation values plus
# their descending ranks (largest value ranks first) taken along axis 0
# (*_rank_A) and along axis 1 (*_rank_B). cross_correlation is transposed to
# share pseudo_map's orientation.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# For the focal zOTU: attach the column means of both abundance tables, scale
# the PLS value by both means (pls_score), keep the 10 highest-scoring rows and
# add the s__ column of motu_taxonomy.
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

#### Zotu16

In [None]:
# Ten best-scoring species for Zotu16, reusing the table `d` and the taxa
# lists computed by the preceding cells. pls_score is the PLS value scaled by
# the column means of both abundance tables.
focal_zotu = ["Zotu16"]
(
    d.reorder_levels(order=["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean(axis=0).rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean(axis=0).rename("motu_mean"))
    .assign(
        pls_score=lambda tbl: tbl["pls"] * tbl["rotu_mean"] * tbl["motu_mean"]
    )
    .sort_values(by="pls_score", ascending=False)
    .head(n=10)
    .join(motu_taxonomy["s__"])
)

In [None]:
# g__ entry of motu_taxonomy for species id '102040'.
motu_taxonomy.g__.loc['102040']

#### Zotu9

In [None]:
# Ten best-scoring species for Zotu9, reusing the table `d` and the taxa
# lists computed by the preceding cells. pls_score is the PLS value scaled by
# the column means of both abundance tables.
focal_zotu = ["Zotu9"]
(
    d.reorder_levels(order=["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean(axis=0).rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean(axis=0).rename("motu_mean"))
    .assign(
        pls_score=lambda tbl: tbl["pls"] * tbl["rotu_mean"] * tbl["motu_mean"]
    )
    .sort_values(by="pls_score", ascending=False)
    .head(n=10)
    .join(motu_taxonomy["s__"])
)

In [None]:
# Full motu_taxonomy row for species id '103166'.
motu_taxonomy.loc['103166']

### "Oscillospiraceae" (Zotu13)

In [None]:
# Focal zOTU and its family, taken as the third-from-last ";"-separated field
# of its rotu_taxonomy0 string; the assert requires a single distinct value.
focal_zotu = ["Zotu13"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# x_taxa: entries of rotu_taxonomy0 containing ";<family>;".
# y_taxa: motu_taxonomy rows whose ";"-joined fields contain "f__<family>;"
# and whose label is also a column of motu_rabund.
# NOTE(review): idxwhere presumably returns the labels where the mask is True.
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
# Taxonomy strings for x_taxa, shown as an array of (label, string) rows.
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
# Taxonomy rows for y_taxa.
motu_taxonomy.loc[y_taxa]

In [None]:
# Align the two abundance tables (NOTE(review): align_indexes presumably keeps
# the shared row labels; confirm in lib.pandas_util) and fit an unscaled
# canonical PLS with as many components as the smallest dimension of x or y.
# PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Product of the y and x rotation matrices, labelled by y's columns (rows)
# and x's columns (columns).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
# Histogram of every pseudo_map entry (1000 bins), with counts on a log axis.
# The cell ends in None so that no value is echoed.
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson r between every column of x (rows of the result) and every column of
# y (columns of the result): cdist is given the distance 1 - r and the outer
# "1 -" converts it back to r. The lambda's x, y are its own parameters.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
# Clustered heatmaps of both matrices with NaN replaced by 0 and the colour
# scale centred on 0. pseudo_map is transposed so its rows match those of
# cross_correlation (x's columns).
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# One row per (y column, x column) pair: the PLS and correlation values plus
# their descending ranks (largest value ranks first) taken along axis 0
# (*_rank_A) and along axis 1 (*_rank_B). cross_correlation is transposed to
# share pseudo_map's orientation.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# For the focal zOTU: attach the column means of both abundance tables, scale
# the PLS value by both means (pls_score), keep the 10 highest-scoring rows and
# add the s__ column of motu_taxonomy.
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Bacteroidaceae"

####  Zotu74

In [None]:
# Focal zOTU and its family, taken as the third-from-last ";"-separated field
# of its rotu_taxonomy0 string; the assert requires a single distinct value.
focal_zotu = ["Zotu74"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# x_taxa: entries of rotu_taxonomy0 containing ";<family>;".
# y_taxa: motu_taxonomy rows whose ";"-joined fields contain "f__<family>;"
# and whose label is also a column of motu_rabund.
# NOTE(review): idxwhere presumably returns the labels where the mask is True.
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
# Taxonomy strings for x_taxa, shown as an array of (label, string) rows.
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
# Taxonomy rows for y_taxa.
motu_taxonomy.loc[y_taxa]

In [None]:
# Align the two abundance tables (NOTE(review): align_indexes presumably keeps
# the shared row labels; confirm in lib.pandas_util) and fit an unscaled
# canonical PLS with as many components as the smallest dimension of x or y.
# PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Product of the y and x rotation matrices, labelled by y's columns (rows)
# and x's columns (columns).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
# Histogram of every pseudo_map entry (1000 bins), with counts on a log axis.
# The cell ends in None so that no value is echoed.
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson r between every column of x (rows of the result) and every column of
# y (columns of the result): cdist is given the distance 1 - r and the outer
# "1 -" converts it back to r. The lambda's x, y are its own parameters.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
# Clustered heatmaps of both matrices with NaN replaced by 0 and the colour
# scale centred on 0. pseudo_map is transposed so its rows match those of
# cross_correlation (x's columns).
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# One row per (y column, x column) pair: the PLS and correlation values plus
# their descending ranks (largest value ranks first) taken along axis 0
# (*_rank_A) and along axis 1 (*_rank_B). cross_correlation is transposed to
# share pseudo_map's orientation.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# For the focal zOTU: attach the column means of both abundance tables, scale
# the PLS value by both means (pls_score), keep the 10 highest-scoring rows and
# add the s__ column of motu_taxonomy.
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

##### Deeper dive into alternative Zotus for the same species

In [None]:
# Rows of `d` for species id '102549' (first index level), with the column
# means of both abundance tables attached. Unlike the per-zOTU tables, the
# ten rows kept here are the ones with the largest raw `pls` value, not the
# largest abundance-weighted pls_score.
(
    d.loc[['102549']]
    .join(rotu_rabund[x_taxa].mean(axis=0).rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean(axis=0).rename("motu_mean"))
    .assign(
        pls_score=lambda tbl: tbl["pls"] * tbl["rotu_mean"] * tbl["motu_mean"]
    )
    .sort_values(by="pls", ascending=False)
    .head(n=10)
    .join(motu_taxonomy["s__"])
)

#### Zotu6

In [None]:
# Ten best-scoring species for Zotu6, reusing the table `d` and the taxa
# lists computed by the preceding cells. pls_score is the PLS value scaled by
# the column means of both abundance tables.
focal_zotu = ["Zotu6"]
(
    d.reorder_levels(order=["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean(axis=0).rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean(axis=0).rename("motu_mean"))
    .assign(
        pls_score=lambda tbl: tbl["pls"] * tbl["rotu_mean"] * tbl["motu_mean"]
    )
    .sort_values(by="pls_score", ascending=False)
    .head(n=10)
    .join(motu_taxonomy["s__"])
)

In [None]:
# g__ and s__ entries of motu_taxonomy for species id '101346'.
motu_taxonomy[['g__', 's__']].loc['101346']

#### Zotu12

In [None]:
# Ten best-scoring species for Zotu12, reusing the table `d` and the taxa
# lists computed by the preceding cells. pls_score is the PLS value scaled by
# the column means of both abundance tables.
focal_zotu = ["Zotu12"]
(
    d.reorder_levels(order=["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean(axis=0).rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean(axis=0).rename("motu_mean"))
    .assign(
        pls_score=lambda tbl: tbl["pls"] * tbl["rotu_mean"] * tbl["motu_mean"]
    )
    .sort_values(by="pls_score", ascending=False)
    .head(n=10)
    .join(motu_taxonomy["s__"])
)

In [None]:
# g__ and s__ entries of motu_taxonomy for species id '101337'.
motu_taxonomy[['g__', 's__']].loc['101337']

#### 102478

In [None]:
# Species-centred view of `d`: print the focal species' taxonomy row, then
# list every zOTU paired with it.
focal_motu = '102478'

print(motu_taxonomy.loc[focal_motu])

# Attach the column means of both abundance tables, scale the PLS value by
# both means (pls_score), sort by it, add the last 20 characters of each
# rotu_taxonomy0 string, and keep only the rows for the focal species.
(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy0.str[-20:])
    .loc[focal_motu]
)

#### 101378

In [None]:
# Species-centred view of `d`: print the focal species' taxonomy row, then
# list every zOTU paired with it.
focal_motu = '101378'

print(motu_taxonomy.loc[focal_motu])

# Attach the column means of both abundance tables, scale the PLS value by
# both means (pls_score), sort by it, add the last 20 characters of each
# rotu_taxonomy0 string, and keep only the rows for the focal species.
(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy0.str[-20:])
    .loc[focal_motu]
)

In [None]:
# Same taxonomy row as printed in the previous cell, shown as cell output.
motu_taxonomy.loc[focal_motu]

### "Peptostreptococcaceae" (Zotu100)

In [None]:
# Focal zOTU and its family, taken as the third-from-last ";"-separated field
# of its rotu_taxonomy0 string; the assert requires a single distinct value.
focal_zotu = ["Zotu100"]
focal_family_list = rotu_taxonomy0.loc[focal_zotu].str.split(";").str[-3].unique()
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# x_taxa: entries of rotu_taxonomy0 containing ";<family>;".
# y_taxa: motu_taxonomy rows whose ";"-joined fields contain "f__<family>;"
# and whose label is also a column of motu_rabund.
# NOTE(review): idxwhere presumably returns the labels where the mask is True.
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
# Taxonomy strings for x_taxa, shown as an array of (label, string) rows.
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
# Taxonomy rows for y_taxa.
motu_taxonomy.loc[y_taxa]

In [None]:
# Align the two abundance tables (NOTE(review): align_indexes presumably keeps
# the shared row labels; confirm in lib.pandas_util) and fit an unscaled
# canonical PLS with as many components as the smallest dimension of x or y.
# PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Product of the y and x rotation matrices, labelled by y's columns (rows)
# and x's columns (columns).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
# Histogram of every pseudo_map entry (1000 bins), with counts on a log axis.
# The cell ends in None so that no value is echoed.
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson r between every column of x (rows of the result) and every column of
# y (columns of the result): cdist is given the distance 1 - r and the outer
# "1 -" converts it back to r. The lambda's x, y are its own parameters.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
# Clustered heatmaps of both matrices with NaN replaced by 0 and the colour
# scale centred on 0. pseudo_map is transposed so its rows match those of
# cross_correlation (x's columns).
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# One row per (y column, x column) pair: the PLS and correlation values plus
# their descending ranks (largest value ranks first) taken along axis 0
# (*_rank_A) and along axis 1 (*_rank_B). cross_correlation is transposed to
# share pseudo_map's orientation.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# For the focal zOTU: attach the column means of both abundance tables, scale
# the PLS value by both means (pls_score), keep the 10 highest-scoring rows and
# add the s__ column of motu_taxonomy.
(
    d.reorder_levels(["zotu", "species_id"])
    .loc[focal_zotu]
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .head(10)
    .join(motu_taxonomy.s__)
)

### "Enterococcaceae" (100323)

In [None]:
# Focal species and its family: the f__ column of motu_taxonomy with its first
# three characters removed. The bare focal_family_list line has no effect
# because it is not the cell's last statement.
focal_motu = ["100323"]
focal_family_list = list(motu_taxonomy.loc[focal_motu].f__.str[3:].values)
focal_family_list
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# x_taxa: entries of rotu_taxonomy0 containing ";<family>;".
# y_taxa: motu_taxonomy rows whose ";"-joined fields contain "f__<family>;"
# and whose label is also a column of motu_rabund.
# NOTE(review): idxwhere presumably returns the labels where the mask is True.
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
# Taxonomy strings for x_taxa, shown as an array of (label, string) rows.
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
# Taxonomy rows for y_taxa.
motu_taxonomy.loc[y_taxa]

In [None]:
# Align the two abundance tables (NOTE(review): align_indexes presumably keeps
# the shared row labels; confirm in lib.pandas_util) and fit an unscaled
# canonical PLS with as many components as the smallest dimension of x or y.
# PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Product of the y and x rotation matrices, labelled by y's columns (rows)
# and x's columns (columns).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
# Histogram of every pseudo_map entry (1000 bins), with counts on a log axis.
# The cell ends in None so that no value is echoed.
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson r between every column of x (rows of the result) and every column of
# y (columns of the result): cdist is given the distance 1 - r and the outer
# "1 -" converts it back to r. The lambda's x, y are its own parameters.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
# Clustered heatmaps of both matrices with NaN replaced by 0 and the colour
# scale centred on 0. pseudo_map is transposed so its rows match those of
# cross_correlation (x's columns).
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# One row per (y column, x column) pair: the PLS and correlation values plus
# their descending ranks (largest value ranks first) taken along axis 0
# (*_rank_A) and along axis 1 (*_rank_B). cross_correlation is transposed to
# share pseudo_map's orientation.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# Attach the column means of both abundance tables, scale the PLS value by
# both means (pls_score), sort by it, add the last 20 characters of each
# rotu_taxonomy0 string, and keep only the rows for the focal species.
(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy0.str[-20:])
    .loc[focal_motu]
)

### "Desulfovibrionaceae" (101359)

In [None]:
# Focal species and its family: the f__ column of motu_taxonomy with its first
# three characters removed. The bare focal_family_list line has no effect
# because it is not the cell's last statement.
focal_motu = ["101359"]
focal_family_list = list(motu_taxonomy.loc[focal_motu].f__.str[3:].values)
focal_family_list
assert len(focal_family_list) == 1
focal_family = focal_family_list[0]
print(focal_family)

In [None]:
# x_taxa: entries of rotu_taxonomy0 containing ";<family>;".
# y_taxa: motu_taxonomy rows whose ";"-joined fields contain "f__<family>;"
# and whose label is also a column of motu_rabund.
# NOTE(review): idxwhere presumably returns the labels where the mask is True.
x_taxa = idxwhere(rotu_taxonomy0.str.contains(f";{focal_family};"))
y_taxa = idxwhere(
    motu_taxonomy.apply(lambda x: ";".join(x), axis=1).str.contains(
        f"f__{focal_family};"
    )
    & motu_taxonomy.index.to_series().isin(motu_rabund.columns)
)

In [None]:
# Taxonomy strings for x_taxa, shown as an array of (label, string) rows.
rotu_taxonomy0.loc[x_taxa].reset_index().values

In [None]:
# Taxonomy rows for y_taxa.
motu_taxonomy.loc[y_taxa]

In [None]:
# Align the two abundance tables (NOTE(review): align_indexes presumably keeps
# the shared row labels; confirm in lib.pandas_util) and fit an unscaled
# canonical PLS with as many components as the smallest dimension of x or y.
# PLSRegression is imported but not used in this cell.
from sklearn.cross_decomposition import PLSCanonical, PLSRegression

x, y = align_indexes(rotu_rabund[x_taxa], motu_rabund[y_taxa])
pls = PLSCanonical(n_components=min(*x.shape, *y.shape), scale=False, tol=1e-7).fit(
    x, y
)

In [None]:
# Product of the y and x rotation matrices, labelled by y's columns (rows)
# and x's columns (columns).
pseudo_map = pd.DataFrame(
    (pls.y_rotations_ @ pls.x_rotations_.T), index=y.columns, columns=x.columns
)
pseudo_map.shape

In [None]:
# Histogram of every pseudo_map entry (1000 bins), with counts on a log axis.
# The cell ends in None so that no value is echoed.
plt.hist(pseudo_map.values.flatten(), bins=1000)
plt.yscale("log")
None

In [None]:
# Pearson r between every column of x (rows of the result) and every column of
# y (columns of the result): cdist is given the distance 1 - r and the outer
# "1 -" converts it back to r. The lambda's x, y are its own parameters.
cross_correlation = pd.DataFrame(
    1
    - sp.spatial.distance.cdist(
        x.T, y.T, metric=lambda x, y: 1 - sp.stats.pearsonr(x, y)[0]
    ),
    columns=y.columns,
    index=x.columns,
)

In [None]:
# Clustered heatmaps of both matrices with NaN replaced by 0 and the colour
# scale centred on 0. pseudo_map is transposed so its rows match those of
# cross_correlation (x's columns).
sns.clustermap(
    pseudo_map.fillna(0).T, center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)
sns.clustermap(
    cross_correlation.fillna(0), center=0, figsize=(5, 5), xticklabels=1, yticklabels=1
)

In [None]:
# One row per (y column, x column) pair: the PLS and correlation values plus
# their descending ranks (largest value ranks first) taken along axis 0
# (*_rank_A) and along axis 1 (*_rank_B). cross_correlation is transposed to
# share pseudo_map's orientation.
d = pd.DataFrame(
    dict(
        pls=pseudo_map.stack(),
        cc=cross_correlation.T.stack(),
        pls_rank_A=pseudo_map.rank(axis=0, ascending=False).stack(),
        pls_rank_B=pseudo_map.rank(axis=1, ascending=False).stack(),
        cc_rank_A=cross_correlation.T.rank(axis=0, ascending=False).stack(),
        cc_rank_B=cross_correlation.T.rank(axis=1, ascending=False).stack(),
    )
)

In [None]:
# Attach the column means of both abundance tables, scale the PLS value by
# both means (pls_score), sort by it, add the last 20 characters of each
# rotu_taxonomy0 string, and keep only the rows for the focal species.
(
    d.reorder_levels(["species_id", "zotu"])
    .join(rotu_rabund[x_taxa].mean().rename("rotu_mean"))
    .join(motu_rabund[y_taxa].mean().rename("motu_mean"))
    .assign(pls_score=lambda x: x.pls * x.rotu_mean * x.motu_mean)
    .sort_values("pls_score", ascending=False)
    .join(rotu_taxonomy0.str[-20:])
    .loc[focal_motu]
)

## All-by-all

In [None]:
# All-by-all matching: run reciprocal_hits on the full abundance tables,
# annotate each hit with taxonomy columns from both sides, and plot the
# distribution of the 'corr' column.
motu_list = motu_rabund.columns
rotu_list = rotu_rabund.columns
print(len(motu_list), len(rotu_list))

reciprocal_hits_results = (
    reciprocal_hits(
        motu_rabund,
        rotu_rabund,
        alpha=1,
        rank_thresh=2,
        corr_thresh=0,
    )
    .join(motu_taxonomy[['f__', 'g__', 's__']], on='species_id')
    # Overlapping taxonomy column names are disambiguated with the
    # 'mgen' / 'zotu' suffixes.
    .join(rotu_taxonomy[['f__', 'g__']], lsuffix='mgen', rsuffix='zotu')
    .sort_values(by='zotu')
)

# np.linspace defaults to 50 points, i.e. 49 equal-width bins on [0, 1].
plt.hist(reciprocal_hits_results['corr'], bins=np.linspace(0, 1, num=50))
plt.xlabel('Pearson Correlation')
plt.ylabel('Count')
plt.title('Distribution of Correlations Between Matches')
plt.savefig('fig/een_species_matching_histogram.pdf')


reciprocal_hits_results.sort_values(by='score')

In [None]:
# Write the annotated all-by-all matches to a tab-separated file.
reciprocal_hits_results.to_csv('fig/een_zotu_species_matching.tsv', sep='\t')

In [None]:
# Summed abundance of Zotu1 + Zotu11 against species 102478; rows with a
# missing value in either column are dropped before correlating/plotting.
d = pd.DataFrame(
    {
        'rotu': rotu_rabund[['Zotu1', 'Zotu11']].sum(axis=1),
        'motu': motu_rabund[['102478']].sum(axis=1),
    }
).dropna()
print(sp.stats.pearsonr(d['motu'], d['rotu']))
plt.scatter(x='rotu', y='motu', data=d)

In [None]:
# Abundance of Zotu4 against species 102506; rows with a missing value in
# either column are dropped before correlating/plotting.
d = pd.DataFrame(
    {
        'rotu': rotu_rabund[['Zotu4']].sum(axis=1),
        'motu': motu_rabund[['102506']].sum(axis=1),
    }
).dropna()
print(sp.stats.pearsonr(d['motu'], d['rotu']))
plt.scatter(x='rotu', y='motu', data=d)

In [None]:
# Zotu5 against species 101386, 101493 and their sum, on log-log axes with
# a y = x reference line; the figure is saved to fig/.
d = pd.DataFrame(
    {
        'rotu': rotu_rabund[['Zotu5']].sum(axis=1),
        'motu_A': motu_rabund[['101386']].sum(axis=1),
        'motu_B': motu_rabund[['101493']].sum(axis=1),
        'motu_both': motu_rabund[['101493', '101386']].sum(axis=1),
    }
).dropna()

# Correlation of Zotu5 with each species column in turn.
print(sp.stats.pearsonr(d['motu_A'], d['rotu']))
print(sp.stats.pearsonr(d['motu_B'], d['rotu']))
print(sp.stats.pearsonr(d['motu_both'], d['rotu']))

plt.scatter(x='rotu', y='motu_A', data=d, label='C. clostridioforme')
plt.scatter(x='rotu', y='motu_B', data=d, label='C. bolteae')
plt.scatter(x='rotu', y='motu_both', data=d, label='Combined')
plt.plot([0, 0.5], [0, 0.5])  # identity line
plt.xlabel('zOTU5 relative abundance')
plt.ylabel('Species relative abundance')
plt.yscale('log')
plt.xscale('log')
plt.legend()

plt.savefig('fig/een_zotu5_species_matching.pdf')

In [None]:
# Zotu79, Zotu80 and their sum against species 100107; rows with a missing
# value in any column are dropped.
d = pd.DataFrame(dict(
    rotu_A=rotu_rabund[['Zotu79']].sum(1),
    rotu_B=rotu_rabund[['Zotu80']].sum(1),
    rotu_both=rotu_rabund[['Zotu79', 'Zotu80']].sum(1),
    motu_A=motu_rabund[['100107']].sum(1),

)).dropna()
# Correlation of the species with each zOTU column in turn.
print(sp.stats.pearsonr(d['motu_A'], d['rotu_A']))
print(sp.stats.pearsonr(d['motu_A'], d['rotu_B']))
print(sp.stats.pearsonr(d['motu_A'], d['rotu_both']))
plt.scatter('motu_A', 'rotu_A', data=d, label='Zotu79')
plt.scatter('motu_A', 'rotu_B', data=d, label='Zotu80')
plt.scatter('motu_A', 'rotu_both', data=d, label='Combined')
plt.plot([0, 0.5], [0, 0.5])  # identity line
plt.ylabel('zOTU relative abundance')
plt.xlabel('Species 100107 relative abundance')
# Log axes are currently disabled for this plot.
# plt.yscale('log')
# plt.xscale('log')
plt.legend()

# NOTE(review): the disabled savefig below reuses the Zotu5 figure's file
# name; pick a new name before re-enabling it.
# plt.savefig('fig/een_zotu5_species_matching.pdf')

In [None]:
# Zotu79 against Zotu80 directly; rows with a missing value in any column
# (including the species column) are dropped.
d = pd.DataFrame(dict(
    rotu_A=rotu_rabund[['Zotu79']].sum(1),
    rotu_B=rotu_rabund[['Zotu80']].sum(1),
    rotu_both=rotu_rabund[['Zotu79', 'Zotu80']].sum(1),
    motu_A=motu_rabund[['100107']].sum(1),

)).dropna()
# First the species-vs-Zotu79 correlation, then Zotu79-vs-Zotu80.
print(sp.stats.pearsonr(d['motu_A'], d['rotu_A']))
print(sp.stats.pearsonr(d['rotu_A'], d['rotu_B']))
plt.scatter('rotu_A', 'rotu_B', data=d)
plt.plot([0, 0.5], [0, 0.5])  # identity line
plt.ylabel('Zotu80 relative abundance')
plt.xlabel('Zotu79 relative abundance')
# Log axes are currently disabled for this plot.
# plt.yscale('log')
# plt.xscale('log')

# NOTE(review): the disabled savefig below reuses the Zotu5 figure's file
# name; pick a new name before re-enabling it.
# plt.savefig('fig/een_zotu5_species_matching.pdf')

## PLS-based

In [None]:
# Same all-by-all matching, but scoring pairs with pls_pseudo_mapping
# instead of reciprocal_hits' default cc_func.
motu_list = motu_rabund.columns
rotu_list = rotu_rabund.columns
print(len(motu_list), len(rotu_list))

reciprocal_hits_results2 = (
    reciprocal_hits(
        motu_rabund,
        rotu_rabund,
        cc_func=pls_pseudo_mapping,
        rank_thresh=2,
    )
    .join(motu_taxonomy[['f__', 'g__', 's__']], on='species_id')
    # Overlapping taxonomy column names are disambiguated with the
    # 'mgen' / 'zotu' suffixes.
    .join(rotu_taxonomy[['f__', 'g__']], lsuffix='mgen', rsuffix='zotu')
    .sort_values(by='zotu')
)


reciprocal_hits_results2.sort_values(by='zotu')

In [None]:
# The 50 matches with the smallest 'corr' value.
reciprocal_hits_results2.sort_values('corr').head(50)

In [None]:
# All matches involving species id '100032'.
reciprocal_hits_results2.xs('100032', level='species_id')

In [None]:
# All matches involving Zotu62.
reciprocal_hits_results2.xs('Zotu62', level='zotu')

In [None]:
# Summed abundance of species 100076 + 100217 against Zotu35, with a y = x
# reference line; rows with a missing value in either column are dropped.
d = pd.DataFrame(
    {
        'x': motu_rabund[['100076', '100217']].sum(axis=1),
        'y': rotu_rabund[['Zotu35']].sum(axis=1),
    }
).dropna()
print(sp.stats.pearsonr(d['x'], d['y']))
plt.scatter(x='x', y='y', data=d)
# Axes are left linear here; the log-scale calls stay disabled.
# plt.yscale('log')
# plt.xscale('log')
plt.plot([0, 0.5], [0, 0.5])

## Non-negative PLS?

In [None]:
motu_list = motu_rabund.columns
rotu_list = rotu_rabund.columns
print(len(motu_list), len(rotu_list))

# Rebinds the notebook-level x and y: here x is the motu table and y the rotu
# table, the reverse of the per-family cells above.
x, y = lib.pandas_util.align_indexes(motu_rabund, rotu_rabund)

# Clustered heatmap of the cross-product x.T @ y (motu columns by rotu
# columns) on a symmetric-log colour scale with linear threshold 1e-5.
sns.clustermap((x.T @ y), norm=mpl.colors.SymLogNorm(1e-5))

In [None]:
# Factorise the cross-product z with the multiplicative-update solver, using
# as many components as z's larger dimension. The regularisation is applied
# to W only (alpha_W=0.1, alpha_H=0.0) and is pure L1 (l1_ratio=1.0).
from sklearn.decomposition import non_negative_factorization
z = (x.T @ y)
w, h, n_iter = non_negative_factorization(z, n_components=max(*z.shape), verbose=2, solver='mu', alpha_W=0.1, alpha_H=0.0, l1_ratio=1.0)

In [None]:
# Shape of the W factor.
w.shape

In [None]:
# Show the W factor as a table labelled by the rows of z (the columns of the
# motu table): non_negative_factorization returns W with one row per row of
# its input. The original expression was left unterminated, which made this
# cell a SyntaxError.
pd.DataFrame(w, index=z.index)

In [None]:
# Clustered heatmap of the W factor on a symmetric-log colour scale with
# linear threshold 1e-10.
sns.clustermap(w, norm=mpl.colors.SymLogNorm(1e-10))

## Family-wise

In [None]:
# Select the taxa of one family on both sides. Matching is by prefix: the
# motu side's f__ values are compared against "f__<family>", the rotu side's
# against "<family>".
family = 'Lachnospiraceae'

motu_list = idxwhere(motu_taxonomy.f__.str.startswith(f'f__{family}'))
rotu_list = idxwhere(rotu_taxonomy.f__.str.startswith(f'{family}'))

print(len(motu_list), len(rotu_list))

In [None]:
# Distinct family labels present in rotu_taxonomy.
rotu_taxonomy.f__.unique()

In [None]:
# Run reciprocal_hits separately within each family label found in
# rotu_taxonomy, then concatenate and annotate the per-family results.
reciprocal_hits_results3 = []
# for family in ['Clostridiaceae']:
for family in rotu_taxonomy.f__.unique():
    # Skip zOTUs without a family label.
    if family == '':
        continue
    # Prefix match on both sides, as in the single-family cell above.
    motu_list = idxwhere(motu_taxonomy.f__.str.startswith(f'f__{family}'))
    rotu_list = idxwhere(rotu_taxonomy.f__.str.startswith(f'{family}'))
    print(family, len(motu_list), len(rotu_list))
    # reindex creates all-NaN columns for taxa absent from the abundance
    # table; dropna(axis='columns') then removes every column with a NaN.
    # NOTE(review): earlier cells in this notebook rebind `cross_correlation`
    # to a DataFrame; confirm the name refers to the intended callable when
    # this cell runs.
    reciprocal_hits_results3.append(reciprocal_hits(
        motu_rabund.reindex(columns=motu_list).dropna(axis='columns'),
        rotu_rabund.reindex(columns=rotu_list).dropna(axis='columns'),
        cc_func=cross_correlation,
        rank_thresh=2,
    ))
    
    
reciprocal_hits_results3 = pd.concat(reciprocal_hits_results3).join(motu_taxonomy[['f__', 'g__', 's__']], on='species_id').join(rotu_taxonomy[['f__', 'g__']], lsuffix='mgen', rsuffix='zotu').sort_values('zotu')
reciprocal_hits_results3.sort_values('species_id')

In [None]:
# Summed abundance of Zotu10 + Zotu15 against species 100032, with a y = x
# reference line; rows with a missing value in either column are dropped.
d = pd.DataFrame(
    {
        'rotu': rotu_rabund[['Zotu10', 'Zotu15']].sum(axis=1),
        'motu': motu_rabund[['100032']].sum(axis=1),
    }
).dropna()
print(sp.stats.pearsonr(d['motu'], d['rotu']))
plt.scatter(x='rotu', y='motu', data=d)
plt.plot([0, 0.5], [0, 0.5])