## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import mpltern
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from matplotlib_venn import venn2
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
def linkage_order(linkage, labels):
    """Return *labels* reordered to follow the leaf order of a dendrogram.

    Parameters
    ----------
    linkage : ndarray
        SciPy hierarchical-clustering linkage matrix.
    labels : array-like
        One label per original observation, in the order the linkage was
        computed on. NumPy arrays and pandas Index objects are indexed
        directly (so the return type follows the input); plain Python lists
        and tuples are converted to a NumPy array first.

    Returns
    -------
    The labels in the left-to-right (pre-order) leaf order of the tree.
    """
    leaf_ids = sp.cluster.hierarchy.to_tree(linkage).pre_order(lambda node: node.id)
    if isinstance(labels, (list, tuple)):
        # Python sequences cannot be fancy-indexed by a list of positions.
        labels = np.asarray(labels)
    return labels[leaf_ids]


def plot_stacked_barplot(data, x_var, order, palette=None, ax=None, **kwargs):
    """Draw a stacked bar chart with one segment per column listed in *order*.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the column *x_var* (bar positions) and every column named
        in *order* (segment heights).
    x_var : str
        Column giving each bar's x position; its values also become the x ticks.
    order : sequence of str
        Columns to stack, drawn bottom to top.
    palette : mapping, optional
        Column name -> color. Defaults to
        ``lib.plot.construct_ordered_palette(order)``.
    ax : matplotlib Axes, optional
        Axes to draw on; ``plt.subplot()`` is used when omitted.
    **kwargs
        Extra ``Axes.bar`` keyword arguments; they override the defaults
        (width 1, opaque, black edges of width 1).

    Returns
    -------
    The Axes that was drawn on.
    """
    ax = plt.subplot() if ax is None else ax
    palette = lib.plot.construct_ordered_palette(order) if palette is None else palette

    style = {"width": 1.0, "alpha": 1.0, "edgecolor": "k", "lw": 1, **kwargs}

    positions = data[x_var]
    running_height = 0  # top of the stack drawn so far
    for column in order:
        segment = data[column]
        ax.bar(
            x=positions,
            height=segment,
            bottom=running_height,
            label=column,
            color=palette[column],
            **style,
        )
        running_height += segment

    ax.set_xticks(positions.values)
    return ax


plot_stacked_barplot(
    pd.DataFrame(dict(t=[0, 1, 2], y1=[0.0, 0.5, 1.0], y2=[1.0, 0.5, 0.0])),
    x_var="t",
    order=["y1", "y2"],
)

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
sns.set_context("talk")
plt.rcParams["figure.dpi"] = 100

## Metadata

In [None]:
mgen_list = list(
    pd.read_table("meta/mgen_group.tsv")[
        lambda x: x.mgen_group_id == "ucfmt"
    ].mgen_id.values
)
len(mgen_list)

In [None]:
mgen = pd.read_table("meta/ucfmt/mgen.tsv", index_col="mgen_id")
sample = pd.read_table("meta/ucfmt/sample.tsv", index_col="sample_id")
subject = pd.read_table("meta/ucfmt/subject.tsv", index_col="subject_id")
assert mgen.sample_id.isin(sample.index).all()

mgen_meta = mgen.join(sample, on="sample_id").join(subject, on="subject_id")

In [None]:
sample.loc[mgen_meta.loc[mgen_list].sample_id].sample_type.value_counts()

In [None]:
mgen_list = list(mgen_meta.index)

In [None]:
subject[lambda x: ~x.remission.isna() & x.recipient].donor_subject_id.value_counts()

In [None]:
d97_mgen_list = idxwhere(mgen_meta.subject_id == "D0097")
d44_mgen_list = idxwhere(mgen_meta.subject_id == "D0044")

(len(d97_mgen_list), len(d44_mgen_list))

### StrainPGC Post-filtering Strains

In [None]:
spgc_meta = pd.read_table(
    f"data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv",
    index_col="genome_id",
)
spgc_strain_list = spgc_meta[
    lambda x: x.passes_filter & x.genome_type.isin(["SPGC"])
].index.values  # .astype(str)
len(spgc_strain_list)

In [None]:
gene_x_eggnog = pd.read_table(
    "data/species/sp-102506/midasdb_v20.emapper.gene75_x_eggnog.tsv"
)

## Strain tracking

In [None]:
strain_fit = sf.data.World.load(
    "data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.world.nc"
).rename_coords(strain=str)

np.random.seed(0)
position_ss = strain_fit.random_sample(position=1000).position
sample_linkage = strain_fit.unifrac_linkage(optimal_ordering=True)
sample_linkage_alt = strain_fit.metagenotype.linkage(optimal_ordering=True)
strain_linkage = strain_fit.genotype.linkage(optimal_ordering=True)
position_linkage = strain_fit.metagenotype.sel(position=position_ss).linkage("position")

print(sf.evaluation.metagenotype_error2(strain_fit, discretized=True)[0])

### Visualize strain tracking

In [None]:
sf.plot.plot_metagenotype(strain_fit, col_linkage_func=lambda w: sample_linkage_alt)
sf.plot.plot_community(strain_fit, col_linkage_func=lambda w: sample_linkage_alt)

In [None]:
w_all = strain_fit.drop_low_abundance_strains(0.01)
# Keep the World's own sample order. Intersecting two `set`s yields an order
# that changes between kernel sessions (str hashing is randomized), which made
# `w`, its linkage and its random subsample non-reproducible.
_mgen_set = set(mgen_list)
_mgen_list = [s for s in w_all.sample.values if s in _mgen_set]

w = w_all.sel(sample=_mgen_list).drop_low_abundance_strains(0.01)

sample_linkage = w.unifrac_linkage()
w_ss = w.random_sample(position=min(500, w.sizes["position"]))

# For each (donor, subject, sample class), count the samples in which each
# strain exceeds 10% relative abundance.
d = (
    w.community.to_series()[lambda x: x > 0.1]
    .to_frame()
    .reset_index()
    .rename(columns=dict(sample="mgen_id"))
    .join(mgen_meta, on="mgen_id")
    .assign(
        sample_class=lambda x: x.sample_type.replace(
            {
                # Leading spaces control the sort order: with the descending
                # sort below, rows come out as baseline, donor, other.
                "baseline": "baseline",
                "donor": "    donor",
                "maintenance": "     other",
                "followup": "     other",
                "post_antibiotic": "     other",
            }
        )
    )
    .groupby(["donor_subject_id", "subject_id", "sample_class"])
    .strain.value_counts()
    .unstack("strain", fill_value=0)
    .sort_index(ascending=[True, True, False])
)

# Columns: strains ordered by total sample count, most frequent first.
strain_order = d.sum().sort_values(ascending=False).index
d = d.loc[:, strain_order]

nrow, ncol = d.shape
fig, ax = plt.subplots(figsize=(0.4 * ncol + 0.5, 0.5 * nrow + 1))
sns.heatmap(d, norm=mpl.colors.PowerNorm(1 / 3), annot=True, cbar=False, ax=ax)

In [None]:
spgc_meta[lambda x: x.genome_type.isin(["SPGC"])].dropna(axis="columns")

In [None]:
ucfmt_strains = list(
    spgc_meta[lambda x: x.genome_type.isin(["SPGC"]) & x.passes_filter].index
)

In [None]:
len(ucfmt_strains)

In [None]:
focal_strains = [
    "6",
    "9",  #'33',
    # '38',  # Unfortunately low-quality gene content inferences. :-/
]  # Defined based on dominant, donor strains in followup samples.

In [None]:
# Strain display order: leaf order of the genotype dendrogram, with the
# aggregate "other" strain ("-1") moved to the very end.
strain_order = list(linkage_order(strain_linkage, strain_fit.strain.values))
if "-1" in strain_order:
    strain_order.remove("-1")
strain_order.append("-1")

strain_palette = lib.plot.construct_ordered_palette(
    strain_order,
    cm="turbo",
    vmin=0.05,
    vmax=0.95,
    extend={"-1": "silver"},
    desaturate_levels=[1.0],
)

# Emphasize the focal strains: they are drawn light (l=0.75) and fully
# saturated (s=1.0); every other strain is drawn dark (l=0.35) and slightly
# desaturated (s=0.7). The "-1" aggregate keeps its silver color.
for strain in strain_palette:
    if strain != "-1":
        hls = dict(l=0.75, s=1.0) if strain in focal_strains else dict(l=0.35, s=0.7)
        strain_palette[strain] = sns.set_hls_values(strain_palette[strain], **hls)

In [None]:
strain_colors = (
    pd.Series(strain_order, index=strain_order).map(strain_palette).to_frame()
)

sf.plot.plot_metagenotype(
    strain_fit.sel(position=position_ss), col_linkage=sample_linkage_alt
)
sf.plot.plot_community(
    strain_fit.sel(position=position_ss),
    col_linkage=sample_linkage_alt,
    row_linkage=strain_linkage,
    row_colors=strain_colors,
)

In [None]:
sf.plot.plot_genotype(
    strain_fit.sel(position=position_ss),
    row_linkage=strain_linkage,
    row_colors=strain_colors,
)

In [None]:
strain_samples = pd.read_table(
    "data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.spgc_ss-all.strain_samples.tsv",
    names=["sample", "strain"],
    dtype=str,
)
strain_samples[lambda x: x.strain.isin(focal_strains)].sort_values(
    "strain"
).strain.value_counts()

In [None]:
focal_strain_samples = idxwhere(
    (
        strain_fit.community.data.sel(strain=focal_strains).max("strain") > 0.75
    ).to_series()
)
w = strain_fit.sel(
    position=position_ss, sample=focal_strain_samples
).drop_low_abundance_strains(0.001)
sf.plot.plot_community(w, col_linkage_func=lambda w: w.metagenotype.linkage())
sf.plot.plot_metagenotype(
    w,
    row_linkage_func=lambda w: position_linkage,
    col_linkage_func=lambda w: w.metagenotype.linkage(),
)

In [None]:
sample_type_specific_relabel = {
    "donor_enema": "D",
    "donor_initial": "D",
    "donor_capsule": "D",
    "baseline": "B",
    "post_antibiotic": "pA",
    "pre_maintenance_1": "M1",
    "pre_maintenance_2": "M2",
    "pre_maintenance_3": "M3",
    "pre_maintenance_4": "M4",
    "pre_maintenance_5": "M5",
    "pre_maintenance_6": "M6",
    "followup_1": "F1",
    "followup_2": "F2",
    "followup_3": "F3",
}

In [None]:
for strain, c in strain_colors.squeeze().items():
    plt.scatter([], [], c=c, label=strain)
plt.legend(ncols=3)
lib.plot.hide_axes_and_spines()

#### Figure 5A

In [None]:
_meta = mgen_meta
subject_order = ["D0044", "D0097", "D0485"]

fig, axs = plt.subplots(
    1, len(subject_order), figsize=(5 * len(subject_order), 2), squeeze=False
)
for subject_id, ax in zip(subject_order, axs.flatten()):
    sample_list = (
        _meta[
            lambda x: (x.subject_id == subject_id)
            & (x.index.isin(strain_fit.sample.values))
        ]
        # .sort_values('collection_days_post_fmt')  # This will be useful for subjects, but not donors
        .index
    )

    if len(sample_list) < 2:
        sample_order = sample_list
    else:
        sample_order = list(
            linkage_order(
                strain_fit.sel(sample=sample_list).metagenotype.linkage(
                    optimal_ordering=True
                ),
                sample_list,
            )
        )

    if len(sample_list) < 1:
        subject_comm = pd.DataFrame([], columns=[-1])
    else:
        subject_comm = (
            strain_fit.sel(sample=sample_order).keep_only_strain_list(strain_order)
            # .drop_low_abundance_strains(
            #     0.0, agg_strain_coord=-1
            # )  # TODO: Check that this adds to the already-existing strain_-1
            .community.to_pandas()
        )

    d = (
        _meta.reindex(sample_order)
        # .dropna(subset=["collection_date_relative_een_end"])
        # .sort_values("collection_date_relative_een_end")
        .assign(
            t=lambda x: range(len(x)),
        )
    ).join(subject_comm)
    # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

    plot_stacked_barplot(
        data=d,
        x_var="t",
        order=[s for s in strain_order if s in subject_comm.columns],
        palette=strain_palette,
        ax=ax,
        width=0.8,
        lw=0.5,
    )

    ax.set_title(subject_id)
    ax.set_xticklabels(
        d.sample_type_specific.map(sample_type_specific_relabel),
        fontsize=12,
    )
    ax.set_aspect(9, anchor="NW")
    ax.set_ylim(0, 1.0)
    lib.plot.rotate_xticklabels(rotation=90, ax=ax, ha="center")
    ax.set_yticks(np.linspace(0, 1.0, num=3))
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol="%"))
    if not d.empty:
        ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
    ax.spines[["right", "top"]].set_visible(False)

fig.savefig('fig/fig5a_donor_panels.pdf', bbox_inches='tight')

In [None]:
donor_subject_order = ["D0044", "D0097", "D0485"]
_meta = mgen_meta.loc[strain_fit.sample.values][
    lambda x: x.recipient
    # & ~x.sra_accession.isna()
]

for donor in donor_subject_order:
    subject_order = (
        _meta[lambda x: x.donor_subject_id == donor].subject_id.value_counts().index
    )
    fig, axs = plt.subplots(
        1,
        len(subject_order),
        figsize=(3 * len(subject_order), 2),
        squeeze=False,
        sharey=True,
    )
    for subject_id, ax in zip(subject_order, axs.flatten()):
        sample_order = (
            _meta[
                lambda x: (x.subject_id == subject_id)
                & (x.index.isin(strain_fit.sample.values))
            ]
            .sort_values("collection_days_post_fmt")
            .index
        )

        if len(sample_order) < 1:
            subject_comm = pd.DataFrame([], columns=[-1])
        else:
            subject_comm = (
                strain_fit.sel(sample=sample_order).keep_only_strain_list(strain_order)
                # .drop_low_abundance_strains(
                #     0.0, agg_strain_coord=-1
                # )  # TODO: Check that this adds to the already-existing strain_-1
                .community.to_pandas()
            )

        d = (
            _meta.reindex(sample_order)
            # .dropna(subset=["collection_date_relative_een_end"])
            # .sort_values("collection_date_relative_een_end")
            .assign(
                t=lambda x: range(len(x)),
            )
        ).join(subject_comm)
        # d.loc[d.index[:num_offset_samples], 't'] -= 0.7  # Offset width

        plot_stacked_barplot(
            data=d,
            x_var="t",
            order=[s for s in strain_order if s in subject_comm.columns],
            palette=strain_palette,
            ax=ax,
            width=0.8,
            lw=0.5,
        )

        ax.set_title(subject_id)
        ax.set_xticklabels(
            d.sample_type_specific.map(sample_type_specific_relabel),
            fontsize=12,
        )
        ax.set_aspect(10, anchor="NW")
        ax.set_ylim(0, 1.0)
        lib.plot.rotate_xticklabels(rotation=90, ax=ax, ha="center")
        ax.set_yticks(np.linspace(0, 1.0, num=3))
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, symbol="%"))
        if not d.empty:
            ax.set_xlim(d.t.min() - 0.5, d.t.max() + 0.5)
        ax.spines[["right", "top"]].set_visible(False)
    fig.savefig(f'fig/fig5a_recipient_{donor}_panels.pdf', bbox_inches='tight')

## Dominant strains from donors / recipients

In [None]:
spgc_meta[lambda x: x.passes_filter & (x.genome_type == "SPGC")]

In [None]:
spgc_meta.loc[np.array(focal_strains).astype(str)]

## Compare to reference database

In [None]:
pd.read_table("ref/midasdb_uhgg_v20_all/metadata/genomes-all_metadata.tsv").columns

In [None]:
ref_meta = pd.read_table(
    "ref/midasdb_uhgg_v20_all/metadata/genomes-all_metadata.tsv",
).set_index("New_Genome_accession")
ref_meta

In [None]:
geno_pdmat = lib.dissimilarity.load_dmat_as_pickle(
    "data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.spgc_ss-all.geno_uhgg-v20_pdist-mask10-pseudo10.pkl"
)
d0 = geno_pdmat.loc[spgc_meta.genome_type.isin(["Isolate"]), focal_strains]
d1 = pd.DataFrame(dict(
    min_diss=d0.min(), idxmin_isolate=d0.idxmin()
))

d1.join(ref_meta, on='idxmin_isolate')

In [None]:
spgc_meta.loc[['6', '9', 'GUT_GENOME288864', 'GUT_GENOME140932']]
# TODO: Figure out

In [None]:
midas_mgtp_inpath = f"data/species/sp-102506/midasdb_v15.gtpro.mgtp.nc"
midas_mgtp_geno = (
    sf.Metagenotype.load(midas_mgtp_inpath).to_estimated_genotype()
)

In [None]:
ucfmt_infer_inpath = f"data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.spgc_ss-all.mgtp.nc"
ucfmt_infer_geno = (
    sf.Metagenotype.load(ucfmt_infer_inpath).to_estimated_genotype()
)

In [None]:
ambiguity_threshold = 0.1

g_ucfmt = ucfmt_infer_geno.discretized(max_ambiguity=ambiguity_threshold)
g_midas = midas_mgtp_geno.discretized(max_ambiguity=ambiguity_threshold)

geno_ucfmt_and_midas = sf.data.Genotype.concat(
    {
        "ucfmt": g_ucfmt,
        "midas": g_midas,
    },
    dim="strain",
    # rename=False,
)

In [None]:
_strain_list = ['ucfmt_6', 'midas_GUT_GENOME288864']

assert len(_strain_list) == 2
g = geno_ucfmt_and_midas.sel(strain=_strain_list).data
shared_positions = idxwhere(~g.pipe(np.isnan).any("strain").to_series())

print(
    (g.sel(strain=_strain_list[0]) != g.sel(strain=_strain_list[1])).sel(position=shared_positions).sum(),
    len(shared_positions)
)

In [None]:
_strain_list = ['ucfmt_9', 'midas_GUT_GENOME140932']

assert len(_strain_list) == 2
g = geno_ucfmt_and_midas.sel(strain=_strain_list).data
shared_positions = idxwhere(~g.pipe(np.isnan).any("strain").to_series())

print(
    (g.sel(strain=_strain_list[0]) != g.sel(strain=_strain_list[1])).sel(position=shared_positions).sum(),
    len(shared_positions)
)

In [None]:
# Closest-match genotype dissimilarity among reference genomes (Isolates and
# MAGs): for each genome, the minimum dissimilarity to any *other* genome.
_is_reference = spgc_meta.genome_type.isin(["Isolate", "MAG"])
dmat = geno_pdmat.loc[_is_reference, _is_reference]
# Set self-comparisons (the diagonal; assumes rows and columns share one
# order, as in a square pdist matrix) to 1, presumably the maximum
# dissimilarity, so they never win the minimum. `DataFrame.mask` builds a new
# frame instead of writing through `.values`, which is a read-only view when
# pandas Copy-on-Write is enabled.
dmat = dmat.mask(np.eye(len(dmat), dtype=bool), 1)
d = dmat.min()

# Symmetric-log x axis: linear below `linthresh` so exact matches (0) show up.
linthresh = 1e-5
bins = [0] + list(np.logspace(np.log10(linthresh), 0, num=50))
plt.hist(d, bins=bins)
plt.xscale("symlog", linthresh=linthresh, linscale=0.1)
plt.xlabel("Closest Match Genotype Dissimilarity")

In [None]:
(d > 0.030).mean(), (d > 0.077).mean()

## Gene Content Comparison

### Gene Annotations

In [None]:
gene_content_uhgg = pd.read_table(
    "data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.uhgg-strain_gene.tsv",
    index_col="gene_id",
).astype(bool)
gene_content_uhgg = gene_content_uhgg.drop(idxwhere(gene_content_uhgg.sum(1) == 0))

In [None]:
gene_content_uhgg[focal_strains][lambda x: x.sum(1) > 0].value_counts().sort_index()

In [None]:
gene_prevalence_ref_uhgg = pd.read_table(
    "data/species/sp-102506/midasdb.gene75_v20.uhgg-strain_gene.ref_prevalence.tsv",
    names=["gene_id", "prevalence"],
    index_col="gene_id",
).prevalence
gene_prevalence_hmp2_uhgg = pd.read_table(
    "data/group/hmp2/species/sp-102506/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.uhgg-strain_gene.prevalence.tsv",
    names=["gene_id", "prevalence"],
    index_col="gene_id",
).prevalence

In [None]:
d0 = (
    gene_prevalence_hmp2_uhgg.to_frame()
    .join(gene_content_uhgg)
    .groupby(["6", "9"])
    .prevalence
)

fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
bins = np.linspace(0, 1, num=11)
for (focal_strain_gene_type, (k, c)), ax in zip(
    {
        "neither": ((False, False), "silver"),
        "D44": ((True, False), strain_palette["6"]),
        "D97": ((False, True), strain_palette["9"]),
        "both": ((True, True), "black"),
    }.items(),
    axs.T.flatten(),
):
    d1 = d0.get_group(k)
    ax.set_title(focal_strain_gene_type)
    ax.hist(
        d1,
        bins=bins,
        density=True,
        label=focal_strain_gene_type,
        color=c,
        alpha=0.5,
        histtype="stepfilled",
    )
    ax.hist(
        d1,
        bins=bins,
        density=True,
        label=focal_strain_gene_type,
        color=c,
        alpha=0.85,
        histtype="step",
    )
    if not focal_strain_gene_type == "neither":
        ax.annotate(len(d1), xy=(0.5, 0.5), xycoords="axes fraction")

fig.tight_layout()

In [None]:
gene_clust = pd.read_table(
    "data/group/hmp2/species/sp-102506/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.uhgg-strain_gene.gene_clust-t10.tsv",
    names=["gene_id", "clust"],
    index_col="gene_id",
).clust
gene_clust.value_counts()

In [None]:
gene_annotation = pd.read_table(
    "data/species/sp-102506/midasdb_v20.gene75_meta.tsv", index_col="centroid_75"
)

In [None]:
gene_x_cog = pd.read_table(
    "data/species/sp-102506/midasdb_v20.emapper.gene75_x_cog.tsv"
).drop_duplicates()
cog_x_cog_pathway = pd.read_table(
    "ref/cog-20.meta.tsv",
    encoding="latin1",
    names=[
        "cog",
        "cog_categories",
        "description",
        "preferred_name",
        "cog_pathway",
        "_5",
        "color_hex",
    ],
    index_col="cog",
).cog_pathway

In [None]:
hmp2_strain_meta = pd.read_table(
    "data/group/hmp2/species/sp-102506/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv",
    index_col="genome_id",
)
hmp2_strain_list = idxwhere(
    (hmp2_strain_meta.genome_type == "SPGC") & hmp2_strain_meta.passes_filter
)
hmp2_strain_rename_list = [f"hmp2_{s}" for s in hmp2_strain_list]
len(hmp2_strain_list)

In [None]:
hmp2_geno = (
    sf.Metagenotype.load(
        "data/group/hmp2/species/sp-102506/r.proc.gtpro.sfacts-fit.spgc_ss-all.mgtp.nc"
    )
    .rename_coords(sample=str)
    .sel(sample=hmp2_strain_list)
    .to_estimated_genotype()
    .discretized(max_ambiguity=0.1)
    .rename_coords(strain=lambda s: f"hmp2_{s}")
)
ucfmt_geno = (
    sf.Metagenotype.load(
        "data/group/ucfmt/species/sp-102506/r.proc.gtpro.sfacts-fit.spgc_ss-all.mgtp.nc"
    )
    .rename_coords(sample=str)
    .to_estimated_genotype()
    .discretized(max_ambiguity=0.1)
)

combined_geno = sf.Genotype.concat(
    dict(hmp2=hmp2_geno, ucfmt=ucfmt_geno.sel(strain=ucfmt_strains)),
    dim="strain",
    rename=False,
)

In [None]:
from scripts.calculate_pairwise_genotype_masked_hamming_distance import (
    native_masked_hamming_distance_pdist,
)

combined_mgtp_geno_pdist = pd.DataFrame(
    squareform(native_masked_hamming_distance_pdist(combined_geno.values, pseudo=1)),
    index=combined_geno.strain,
    columns=combined_geno.strain,
)
combined_mgtp_geno_linkage = sp.cluster.hierarchy.linkage(
    squareform(combined_mgtp_geno_pdist), method="average", optimal_ordering=True
)

In [None]:
sf.plot.plot_genotype(
    combined_geno.sel(position=position_ss),
    row_linkage_func=lambda w: combined_mgtp_geno_linkage,
    col_linkage_func=lambda w: position_linkage,
    row_colors=strain_colors,
)

In [None]:
# HMP2 strain x gene presence table, restricted to post-filter SPGC strains and
# prefixed with "hmp2_" to keep strain names distinct from UCFMT strains.
gene_content_hmp2_uhgg = (
    pd.read_table(
        "data/group/hmp2/species/sp-102506/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.uhgg-strain_gene.tsv",
        index_col="gene_id",
    )[hmp2_strain_list]
    .astype(bool)
    .rename(columns=lambda s: f"hmp2_{s}")
)
# Reuses the UCFMT uhgg-strain_gene table loaded earlier, limited to the
# post-filter UCFMT strains.
gene_content_ucfmt_uhgg = gene_content_uhgg[ucfmt_strains]
# Genes missing from one cohort's table are NaN after the column-wise concat;
# treat them as absent and restore a boolean dtype explicitly (fillna on the
# object-dtype result no longer reliably downcasts).
gene_content_combined = (
    pd.concat([gene_content_hmp2_uhgg, gene_content_ucfmt_uhgg], axis="columns")
    .fillna(False)
    .astype(bool)
)

#### Figure 5B

In [None]:
# Shell genes: HMP2 prevalence strictly between 15% and 90%, restricted to
# genes present in at least one strain of the combined table.
d = gene_content_combined[
    lambda x: (gene_prevalence_hmp2_uhgg > 0.15)
    & (gene_prevalence_hmp2_uhgg < 0.9)
    & (x.mean(1) > 0)
]

cg = sns.clustermap(
    d,
    col_linkage=combined_mgtp_geno_linkage,  # strains ordered by genotype
    metric="cosine",  # used for the (hidden) gene/row clustering
    cmap="Grays",
    xticklabels=False,
    yticklabels=False,
    figsize=(10, 5),
    tree_kws=dict(lw=1),
    dendrogram_ratio=0.4,
    rasterized=True,
)
cg.ax_heatmap.set_ylabel("")
cg.ax_row_dendrogram.set_visible(False)
cg.ax_cbar.set_visible(False)

# (genes, strains) shown in the heatmap; a bare mid-cell expression would
# never be displayed by Jupyter.
print(d.shape)

# Save the clustermap's own figure rather than pyplot's "current" figure.
cg.savefig("fig/fig5b_heatmap.pdf", bbox_inches="tight")

In [None]:
gene_content_combined[
    lambda x: (gene_prevalence_hmp2_uhgg > 0.15)
    & (gene_prevalence_hmp2_uhgg < 0.9)
    & (gene_content_combined.mean(1) > 0)
].dropna(axis="columns").dropna(axis="rows").shape

In [None]:
# Helper figure to place pointers.

d = gene_content_combined[
    lambda x: (gene_prevalence_hmp2_uhgg > 0.15)
    & (gene_prevalence_hmp2_uhgg < 0.9)
    & (gene_content_combined.mean(1) > 0)
]

cg = sns.clustermap(
    d,
    col_linkage=combined_mgtp_geno_linkage,
    col_colors=strain_colors,
    metric="cosine",
    xticklabels=False,
    yticklabels=False,
    figsize=(10, 4),
    tree_kws=dict(lw=1),
    dendrogram_ratio=0.3,
)
cg.ax_row_dendrogram.set_visible(False)
cg.ax_cbar.set_visible(False)

In [None]:
# Shell-gene content dissimilarity (cosine) between all strains, using genes
# with HMP2 prevalence strictly between 15% and 90%.
combined_gene_content_pdist = lib.dissimilarity.dmatrix(
    gene_content_combined[
        lambda x: (gene_prevalence_hmp2_uhgg > 0.15) & (gene_prevalence_hmp2_uhgg < 0.9)
    ].T,
    metric="cosine",
)

_strain_list = combined_gene_content_pdist.index

# Code each strain by cohort (UCFMT -> 2, otherwise HMP2 -> 1) so the product
# for a pair identifies its type: 1 = HMP2-HMP2, 2 = HMP2-UCFMT,
# 4 = UCFMT-UCFMT. Membership must be tested against the UCFMT strain list:
# every strain is trivially in `_strain_list`, which would code all pairs as 4.
strain_pair_coded = lib.dissimilarity.dmatrix(
    _strain_list.to_frame().isin(ucfmt_strains).astype(int) + 1,
    metric=lambda x, y: x * y,
)
print(strain_pair_coded.stack().value_counts())

d = pd.DataFrame(
    dict(
        geno_diss=squareform(combined_mgtp_geno_pdist.loc[_strain_list, _strain_list]),
        gene_diss=squareform(
            combined_gene_content_pdist.loc[_strain_list, _strain_list]
        ),
        pair_type_coded=squareform(strain_pair_coded.loc[_strain_list, _strain_list]),
    )
).assign(
    pair_type=lambda x: x.pair_type_coded.map(
        {0.0: "self", 1.0: "hmp2", 2.0: "inter", 4.0: "ucfmt"}
    )
)
print(sp.stats.pearsonr(d["geno_diss"], d["gene_diss"]))

# HMP2-HMP2 and HMP2-UCFMT pairs
plt.scatter(
    "geno_diss",
    "gene_diss",
    data=d[lambda x: x.pair_type.isin(["hmp2", "inter"])],
    s=30,
    facecolor="none",
    edgecolor="darkgrey",
    label="Other pairs",
)

# UCFMT-UCFMT pairs
plt.scatter(
    "geno_diss",
    "gene_diss",
    data=d[lambda x: x.pair_type.isin(["ucfmt"])],
    s=30,
    facecolor="none",
    edgecolor="tab:blue",
    label="UCFMT pairs",
)

# Highlight the focal-strain comparison.
_x = squareform(combined_mgtp_geno_pdist.loc[focal_strains, focal_strains])
_y = squareform(combined_gene_content_pdist.loc[focal_strains, focal_strains])
plt.scatter(_x, _y, s=40, c="red", label="9 and 6")

plt.legend()
plt.xlabel("Genotype Dissimilarity", fontsize=14)
plt.ylabel("Shell Gene Content Dissimilarity", fontsize=14)

In [None]:
plt.hist(
    squareform(combined_mgtp_geno_pdist.loc[spgc_strain_list, spgc_strain_list]),
    bins=20,
)
plt.axvline(
    squareform(combined_mgtp_geno_pdist.loc[focal_strains, focal_strains]), color="r"
)
print(squareform(combined_mgtp_geno_pdist.loc[focal_strains, focal_strains]))
print(
    (
        squareform(combined_mgtp_geno_pdist.loc[spgc_strain_list, spgc_strain_list])
        < squareform(combined_mgtp_geno_pdist.loc[focal_strains, focal_strains])
    ).mean()
)
np.quantile(
    squareform(combined_mgtp_geno_pdist.loc[spgc_strain_list, spgc_strain_list]),
    [0.25, 0.5, 0.75],
)

In [None]:
sns.clustermap(
    combined_mgtp_geno_pdist,
    row_linkage=combined_mgtp_geno_linkage,
    col_linkage=combined_mgtp_geno_linkage,
    col_colors=strain_colors,
    row_colors=strain_colors,
)

In [None]:
d0 = gene_content_combined[
    lambda x: (gene_prevalence_hmp2_uhgg > 0.15) & (gene_prevalence_hmp2_uhgg < 0.9)
]

pdist = lib.dissimilarity.dmatrix(d0.T, metric="jaccard")

sns.clustermap(
    pdist,
    row_linkage=combined_mgtp_geno_linkage,
    col_linkage=combined_mgtp_geno_linkage,
    col_colors=strain_colors,
    row_colors=strain_colors,
)

In [None]:
gene_x_cog_category_matrix = (
    pd.read_table(
        "data/species/sp-102506/midasdb_v20.emapper.gene75_x_cog_category.tsv"
    )
    .assign(flag=True)
    .set_index(["centroid_75", "cog_category"])
    .flag.unstack("cog_category", fill_value=False)
)

#### Figure 5C

In [None]:
d = gene_content_ucfmt_uhgg

v = venn2(
    [set(idxwhere(d["6"])), set(idxwhere(d["9"]))],
    set_labels=["", ""],
    set_colors=[strain_palette["6"], strain_palette["9"]],
)

for text in v.subset_labels:
    text.set_fontsize(30)

plt.savefig('fig/fig5c_venn.pdf', bbox_inches='tight')

#### Figure 5D

In [None]:
# Figure 5D

drop_clusters = set(idxwhere(gene_clust.value_counts() < 2) + [-4, -3, -2, -1])
d = (gene_content_ucfmt_uhgg.groupby(gene_clust).mean() >= 0.75).drop(
    drop_clusters, errors="ignore"
)

v = venn2(
    [set(idxwhere(d["6"])), set(idxwhere(d["9"]))],
    set_labels=["", ""],
    set_colors=[strain_palette["6"], strain_palette["9"]],
)

for text in v.subset_labels:
    text.set_fontsize(30)

plt.savefig('fig/fig5d_venn.pdf', bbox_inches='tight')

### Focal Strain Shared Gene Content

In [None]:
_strain_status = pd.DataFrame(
    dict(
        shared=gene_content_combined[focal_strains].all(1),
        ucfmt_prevalence=gene_content_combined[ucfmt_strains].mean(1),
        hmp2_prevalence=gene_content_combined[hmp2_strain_rename_list].mean(1),
    )
)[lambda x: x.ucfmt_prevalence > 0]

_strain_status

In [None]:
cluster_content = (gene_content_combined.groupby(gene_clust).mean() >= 0.75).drop(
    drop_clusters, errors="ignore"
)

_strain_status = pd.DataFrame(
    dict(
        shared=cluster_content[focal_strains].all(1),
        ucfmt_prevalence=cluster_content[ucfmt_strains].mean(1),
        hmp2_prevalence=cluster_content[hmp2_strain_rename_list].mean(1),
        cluster_size=gene_clust.value_counts(),
    )
)[lambda x: x.ucfmt_prevalence > 0]

shared_low_prevalence_clusters = idxwhere(
    (_strain_status.ucfmt_prevalence < 5 / 18) & (_strain_status.shared)
)

_strain_status.loc[shared_low_prevalence_clusters].sort_values(
    ["cluster_size", "ucfmt_prevalence"], ascending=False
)

In [None]:
clust = shared_low_prevalence_clusters
_clust_annotations = (
    gene_annotation.loc[idxwhere(gene_clust.isin(clust))]
    .assign(ucfmt_prevalence=gene_content_combined[ucfmt_strains].mean(1))
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
    .sort_values(["clust", "Preferred_name"])
)

# _clust_annotations[['eggNOG_OGs', 'COG_category', 'Description', 'Preferred_name', 'PFAMs', 'ucfmt_prevalence', 'clust', '6', '9']]

gene_x_cog_category_matrix.reindex(_clust_annotations.index).fillna(False).assign(
    no_category=lambda d: (d.drop(columns=["no_category"]).sum(1) == 0)
).sum()[lambda x: x > 0]

In [None]:
gene_x_cog_category_matrix.reindex(_clust_annotations.index).fillna(False).assign(
    no_category=lambda d: (d.drop(columns=["no_category"]).sum(1) == 0)
).groupby(gene_clust).sum().loc[:, lambda x: x.sum() > 0]

In [None]:
clust = [624]
_clust_annotations = (
    gene_annotation.loc[idxwhere(gene_clust.isin(clust))]
    .assign(ucfmt_prevalence=gene_content_combined[ucfmt_strains].mean(1))
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
    .sort_values("Preferred_name")
)

_clust_annotations[
    [
        "eggNOG_OGs",
        "COG_category",
        "Description",
        "Preferred_name",
        "PFAMs",
        "ucfmt_prevalence",
        "clust",
        "6",
        "9",
    ]
]

### Focal Strain Differences

In [None]:
gene_x_amr = pd.read_table("data/species/sp-102506/midasdb_v20.gene75_x_amr.tsv")
amr_gene_list = gene_x_amr.centroid_75.unique()
len(amr_gene_list)

In [None]:
gene_content_ucfmt_uhgg.reindex(
    index=amr_gene_list, columns=focal_strains, fill_value=False
).sum()

In [None]:
gene_content_ucfmt_uhgg.reindex(
    index=amr_gene_list, columns=focal_strains, fill_value=False
)[lambda x: x.sum(1) > 0]

In [None]:
focal_strain_amr_hits = gene_x_amr[
    lambda x: x.centroid_75.isin(
        idxwhere(gene_content_ucfmt_uhgg[focal_strains].any(axis=1))
    )
]
focal_strain_amr_list = focal_strain_amr_hits.accession_no.values
focal_strain_amr_hits.sort_values("centroid_75")

In [None]:
resfinder = pd.read_table(
    "ref/midasdb_uhgg_v20/pangenomes/102506/annotation/resfinder.tsv"
)[["resistance_gene", "phenotype", "accession_no"]].drop_duplicates()
resfinder

In [None]:
resfinder[lambda x: x.accession_no.isin(focal_strain_amr_list)].sort_values(
    "accession_no"
)

In [None]:
resfinder[lambda x: x.accession_no.isin(focal_strain_amr_list)].sort_values(
    "accession_no"
).phenotype.str.split(", ").explode().value_counts()

In [None]:
d = gene_content_combined.loc[amr_gene_list][lambda x: x.sum(1) > 0]

cg = sns.clustermap(
    d,
    col_linkage=combined_mgtp_geno_linkage,
    col_colors=strain_colors,
    metric="cosine",
    xticklabels=False,
    yticklabels=False,
    figsize=(10, 4),
    tree_kws=dict(lw=1),
    dendrogram_ratio=0.3,
)
cg.ax_row_dendrogram.set_visible(False)
cg.ax_cbar.set_visible(False)

In [None]:
gene_content_combined

In [None]:
# Per-cluster gene content: for each gene cluster (rows) and strain (columns),
# the fraction of the cluster's genes present in that strain. Clusters with
# fewer than two genes, and the negative labels -4..-1 (presumably special
# "not clustered" codes -- confirm against the clustering step), are dropped.
_cluster_sizes = gene_clust.value_counts()
drop_clusters = set(idxwhere(_cluster_sizes < 2)) | {-4, -3, -2, -1}
_content_by_clust = gene_content_combined.reindex(
    gene_clust.index, fill_value=False  # genes without a content row count as absent
).groupby(gene_clust)
clust_content_combined = _content_by_clust.mean().drop(drop_clusters, errors="ignore")
d = clust_content_combined

_clustermap_kwargs = dict(
    col_linkage=combined_mgtp_geno_linkage,
    col_colors=strain_colors,
    metric="cosine",
    xticklabels=False,
    yticklabels=False,
    figsize=(10, 4),
    tree_kws=dict(lw=1),
    dendrogram_ratio=0.3,
)
cg = sns.clustermap(d, **_clustermap_kwargs)
# Hide the row dendrogram and colour bar; keep heatmap + strain dendrogram.
for _ax in (cg.ax_row_dendrogram, cg.ax_cbar):
    _ax.set_visible(False)

## Comprehensive Results Table (Supplementary Table 2)

In [None]:
# Supplementary Table 2: one row per gene present in any focal strain, with
# annotation fields, per-strain presence ("<strain>_gene_present"), cluster
# label, eggNOG OGs, AMR accessions, COG categories, and the fraction of the
# gene's cluster present in each focal strain ("<strain>_clust_frac").
# Negative cluster labels are not reported as clusters: they become NaN in the
# gene_clust column and are removed from the per-cluster fractions.
_excluded_clusters = [-4, -3, -2, -1]
d = (
    gene_annotation[["nlength", "Description", "Preferred_name"]]
    .reindex(idxwhere(gene_content_uhgg[focal_strains].any(axis=1)))
    .join(
        gene_content_combined[focal_strains].rename(
            columns=lambda s: f"{s}_gene_present"
        )
    )
    .assign(
        eggnog_og=gene_x_eggnog.groupby("centroid_75").eggnog.apply(
            lambda x: ",".join(x)
        ),
        gene_clust=gene_clust.replace({c: np.nan for c in _excluded_clusters}),
        amr_hit=gene_x_amr.groupby("centroid_75").accession_no.apply(
            lambda s: ",".join(s)
        ),
        cog_category=gene_x_cog_category_matrix.stack()
        .rename("is_category")
        .reset_index()[lambda x: x.is_category & (x.cog_category != "no_category")]
        .groupby("centroid_75")
        .cog_category.apply(lambda s: ",".join(s)),
    )
    .join(
        gene_content_combined[focal_strains]
        .reindex(gene_clust.index, fill_value=False)
        .groupby(gene_clust)
        .mean()
        .rename(columns=lambda s: f"{s}_clust_frac")
        # errors="ignore": not every excluded label need occur in gene_clust;
        # a plain drop would raise KeyError for a missing one.
        .drop(_excluded_clusters, errors="ignore"),
        on="gene_clust",
    )
)
d.to_csv("fig/ucfmt_focal_strain_genes_supplementary_table2.tsv", sep="\t")
d

In [None]:
# Every gene's annotation joined with its cluster label and focal-strain presence.
_with_clust = gene_annotation.join(gene_clust)
d = _with_clust.join(gene_content_uhgg[focal_strains])

d

## Strain-9

In [None]:
# Clusters with >75% of member genes present in strain 9 but not in strain 6,
# with total gene counts and COG-category composition; 20 largest shown.
d = clust_content_combined > 0.75

strain9_only_clust_list = idxwhere(d["9"] & ~d["6"])
_cog_by_clust = gene_x_cog_category_matrix.groupby(gene_clust).sum()
_cog_mix = _cog_by_clust.apply(
    lambda counts: counts[counts > 0].sort_values(ascending=False).to_dict(),
    axis=1,
)
_summary = gene_clust.value_counts().to_frame("total_genes").assign(cat=_cog_mix)
_summary.loc[strain9_only_clust_list].sort_values("total_genes", ascending=False).head(20)

### Clust-37

In [None]:
# Inspect gene cluster 37: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [37]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### Type IV pili / pilins (pil*, "pilin", "type IV secretion")

In [None]:
# Genes named pil*, or described as pilin / type IV secretion, that are present
# in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_desc = gene_annotation["Description"].fillna("").str.lower()
_is_hit = (
    _pref.str.startswith("pil")
    | _desc.str.contains("pilin")
    | _desc.str.contains("type iv secre")
)
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-418

In [None]:
# Inspect gene cluster 418: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [418]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### sgc / yjh operons

In [None]:
# Genes named sgc* or yjh* that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_is_hit = _pref.str.startswith("sgc") | _pref.str.startswith("yjh")
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched sgc*/yjh* gene (id, cluster, focal strains carrying it,
# preferred name, description, eggNOG OGs), each followed by a blank line.
# The row label must be bound to `gene_id` (not `_`): otherwise the printed
# `gene_id` is a stale value left over from an earlier loop, not this row's gene.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x.clust,
        idxwhere(x[focal_strains]),
        x.Preferred_name,
        x.Description,
        x.eggNOG_OGs,
    )
    print()

### Clust-1485

In [None]:
# Inspect gene cluster 1485: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [1485]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

### Clust-861

In [None]:
# Inspect gene cluster 861: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [861]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### T6SS

In [None]:
# Genes named imp* or vas*, or described as type VI secretion, that are present
# in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_desc = gene_annotation["Description"].fillna("").str.lower()
_is_hit = (
    _pref.str.startswith("imp")
    | _pref.str.startswith("vas")
    | _desc.str.contains("type vi secret")
)
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-789

In [None]:
# Inspect gene cluster 789: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [789]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### fab operon

In [None]:
# Genes named fab* that are present in strain 6 or strain 9.
_related_annotations = (
    gene_annotation[
        lambda x: x.Preferred_name.fillna("").str.lower().str.startswith("fab")
    ]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])[
        lambda x: x[["6", "9"]].fillna(False).any(axis=1)
    ]
    .sort_values("Preferred_name")
)

# Number of matched genes carried by each focal strain. Cluster labels are
# identifiers, so the "clust" column is not summed.
print(_related_annotations[focal_strains].sum())
# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
print(
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-129

In [None]:
# Inspect gene cluster 129: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [129]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### glc*

In [None]:
# Genes named glc* that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_joined = (
    gene_annotation[_pref.str.startswith("glc")]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-351

In [None]:
# Inspect gene cluster 351: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [351]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### pdu*

In [None]:
# Genes named pdu* that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_joined = (
    gene_annotation[_pref.str.startswith("pdu")]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-2275

In [None]:
# Inspect gene cluster 2275: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [2275]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### rfb*

In [None]:
# Genes named rfb* that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_joined = (
    gene_annotation[_pref.str.startswith("rfb")]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-172

In [None]:
# Inspect gene cluster 172: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [172]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### ecp*/mat*/common pilus

In [None]:
# Genes named ecp* or mat*, or described as a common pilus, that are present in
# strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_desc = gene_annotation["Description"].fillna("").str.lower()
_is_hit = (
    _pref.str.startswith("ecp")
    | _pref.str.startswith("mat")
    | _desc.str.contains("common pilus")
)
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

#### fim* / fimbrial / pap

In [None]:
# Genes named fim* or pap*, or described as fimbrial, that are present in
# strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_desc = gene_annotation["Description"].fillna("").str.lower()
_is_hit = (
    _pref.str.startswith("fim")
    | _pref.str.startswith("pap")
    | _desc.str.contains("fimbrial")
)
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-2618

In [None]:
# Inspect gene cluster 2618: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [2618]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### hpa meta-operon

In [None]:
# Genes named hpa* that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_joined = (
    gene_annotation[_pref.str.startswith("hpa")]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-2478

In [None]:
# Inspect gene cluster 2478: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [2478]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### yih*

In [None]:
# Genes named yih* that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_joined = (
    gene_annotation[_pref.str.startswith("yih")]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-1500

In [None]:
# Inspect gene cluster 1500: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [1500]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

## Strain-6

In [None]:
# Clusters with >75% of member genes present in strain 6 but not in strain 9,
# with total gene counts and COG-category composition; 20 largest shown.
d = clust_content_combined > 0.75

strain6_only_clust_list = idxwhere(d["6"] & ~d["9"])
_cog_by_clust = gene_x_cog_category_matrix.groupby(gene_clust).sum()
_cog_mix = _cog_by_clust.apply(
    lambda counts: counts[counts > 0].sort_values(ascending=False).to_dict(),
    axis=1,
)
_summary = gene_clust.value_counts().to_frame("total_genes").assign(cat=_cog_mix)
_summary.loc[strain6_only_clust_list].sort_values("total_genes", ascending=False).head(20)

### Clust-1575

In [None]:
# Inspect gene cluster 1575: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [1575]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

### Clust-941

In [None]:
# Inspect gene cluster 941: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [941]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### Thioredoxin (demonstration query only)

In [None]:
# Genes whose description mentions thioredoxin and that are present in strain 6
# or strain 9 (example of an annotation-based query).
_related_annotations = (
    gene_annotation[
        lambda x: x.Description.fillna("").str.lower().str.contains("thioredoxin")
    ]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])[
        lambda x: x[["6", "9"]].fillna(False).any(axis=1)
    ]
    .sort_values("Preferred_name")
)

# Number of matched genes carried by each focal strain. Cluster labels are
# identifiers, so the "clust" column is not summed.
print(_related_annotations[focal_strains].sum())
# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
print(
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-1322

In [None]:
# Inspect gene cluster 1322: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [1322]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

#### Ethanolamine

In [None]:
# Genes whose description mentions ethanolamine and that are present in
# strain 6 or strain 9.
_desc = gene_annotation["Description"].fillna("").str.lower()
_joined = (
    gene_annotation[_desc.str.contains("ethanolamine")]
    .join(gene_clust)
    .join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Joint focal-strain presence patterns among the matched genes.
_presence_patterns = gene_content_uhgg.loc[
    _related_annotations.index, focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

### Clust-1857

In [None]:
# Inspect gene cluster 1857: member annotations, focal-strain presence patterns,
# and the COG categories represented in the cluster.
clust = [1857]
_member_mask = gene_clust.isin(clust)
_member_genes = idxwhere(_member_mask)
_member_presence = gene_content_uhgg[focal_strains].reindex(
    _member_genes, fill_value=False
)  # genes absent from gene_content_uhgg are recorded as not present
_clust_annotations = (
    gene_annotation.loc[_member_genes]
    .join(gene_clust)
    .join(_member_presence)
    .sort_values("Preferred_name")
)

# Gene count per selected cluster label.
print(gene_clust[_member_mask].value_counts())

# Joint focal-strain presence patterns. Reindexed without a fill value, so genes
# missing from gene_content_uhgg become NaN rows, which value_counts drops.
_presence_patterns = gene_content_uhgg.reindex(_clust_annotations.index)[
    focal_strains
].value_counts()
print(_presence_patterns.unstack(fill_value=0))

# COG categories with at least one member gene, most common first.
_cog_counts = (
    gene_x_cog_category_matrix.reindex(_member_genes).sum().sort_values(ascending=False)
)
print(_cog_counts[_cog_counts > 0])
# _clust_annotations  # uncomment to display the member table

In [None]:
# One line per member gene (id, preferred name, description, eggNOG OGs, focal
# strains carrying it), each followed by a blank line.
for gene_id, x in _clust_annotations.iterrows():
    print(
        gene_id,
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        idxwhere(x[focal_strains]),
        end="\n\n",
    )

### Related annotations

#### Capsule polysaccharide

In [None]:
# Genes named kps*, or described as capsule polysaccharide, that are present in
# strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_desc = gene_annotation["Description"].fillna("").str.lower()
_is_hit = _pref.str.startswith("kps") | _desc.str.contains("capsule polysac")
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )

#### frv*

In [None]:
# Genes named frv*, or described as part of a phosphoenolpyruvate-dependent
# sugar phosphotransferase system, that are present in strain 6 or strain 9.
_pref = gene_annotation["Preferred_name"].fillna("").str.lower()
_desc = gene_annotation["Description"].fillna("").str.lower()
_is_hit = _pref.str.startswith("frv") | _desc.str.contains(
    "phosphoenolpyruvate-dependent sugar phosphotransferase system"
)
_joined = (
    gene_annotation[_is_hit].join(gene_clust).join(gene_content_uhgg[focal_strains])
)
_in_focal = _joined[["6", "9"]].fillna(False).any(axis=1)
_related_annotations = _joined[_in_focal].sort_values("Preferred_name")

# Gene counts per cluster (rows) and focal-strain presence pattern (columns).
_pattern_table = (
    _related_annotations[["clust", *focal_strains]]
    .value_counts()
    .unstack("clust", fill_value=0)
    .T
)
print(_pattern_table)

In [None]:
# One line per matched gene (id, cluster, focal strains carrying it, preferred
# name, description, eggNOG OGs), each followed by a blank line.
for gene_id, x in _related_annotations.iterrows():
    print(
        gene_id,
        x["clust"],
        idxwhere(x[focal_strains]),
        x["Preferred_name"],
        x["Description"],
        x["eggNOG_OGs"],
        end="\n\n",
    )