## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Run from the project root so the relative paths used below (meta/, ref/, data/) resolve.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Large "talk" fonts paired with a low inline dpi keep figures compact in the notebook.
sns.set_context(context="talk")
plt.rcParams.update({"figure.dpi": 50})

In [None]:
# Fixed colors per genome source, so every figure uses the same legend.
genome_type_palette = dict(
    SPGC="tab:green",
    MAG="tab:orange",
    Isolate="tab:blue",
)

## Parameters

In [None]:
# Species-group id: selects rows of meta/species_group.tsv and forms the
# data/group/{group_id}/... paths used throughout.
group_id = "xjin_ucfmt_hmp2"
# Species examined below: 101337, 101433, 100236.

## Data Setup

### Metadata

#### Taxonomy

In [None]:
# Unique species ids (as strings) belonging to the selected species group.
_species_group_table = pd.read_table("meta/species_group.tsv")
_in_group = _species_group_table["species_group_id"] == group_id
species_list = _species_group_table.loc[_in_group, "species_id"].astype(str).unique()

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a Series keyed by rank prefix.

    Args:
        taxonomy_string: Lineage such as "d__...;p__...;...;s__...". It must have
            exactly seven ';'-separated fields.

    Returns:
        pd.Series indexed by "d__", "p__", "c__", "o__", "f__", "g__", "s__",
        holding each field unchanged (prefix included).

    Raises:
        ValueError: if the field count is not seven (raised by pandas).
    """
    rank_labels = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_labels)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep one row per species (its representative genome) and key it by a species id
# built as "1" + the third '-'-separated field of MGnify_accession.
_uhgg_genomes = pd.read_table(species_taxonomy_inpath)
_species_reps = _uhgg_genomes[_uhgg_genomes["Genome"] == _uhgg_genomes["Species_rep"]]
_rep_species_ids = "1" + _species_reps["MGnify_accession"].str.split("-").str[2]
species_taxonomy = (
    _species_reps.assign(species_id=_rep_species_ids)
    .set_index("species_id")["Lineage"]
    # One column per rank (d__ ... s__).
    .apply(parse_taxonomy_string)
)

#### Samples

In [None]:
# Per-table metadata, each indexed by its own id column.
mgen = pd.read_table("meta/ucfmt/mgen.tsv", index_col="mgen_id")
sample = pd.read_table("meta/ucfmt/sample.tsv", index_col="sample_id")
subject = pd.read_table("meta/ucfmt/subject.tsv", index_col="subject_id")

# Every metagenome must reference a known sample before joining.
assert mgen.sample_id.isin(sample.index).all()

# Metagenome -> sample (via sample_id) -> subject (via subject_id).
_mgen_with_sample = mgen.join(sample, on="sample_id")
mgen_meta = _mgen_with_sample.join(subject, on="subject_id")

In [None]:
# Subjects linked to the two selected donors, ordered by donor id.
_selected_donor_ids = ["D0044", "D0097"]
_has_selected_donor = mgen_meta.donor_subject_id.isin(_selected_donor_ids)

# Donors and recipients alike (anyone whose donor_subject_id is selected).
subject_and_donor_list = list(
    mgen_meta[_has_selected_donor].sort_values(["donor_subject_id"]).subject_id.unique()
)
# Recipients only.
subject_list = list(
    mgen_meta[mgen_meta.recipient & _has_selected_donor]
    .sort_values(["donor_subject_id"])
    .subject_id.unique()
)
subject_palette = lib.plot.construct_ordered_palette(subject_and_donor_list)

In [None]:
# All metagenome ids in metadata order.
mgen_list = mgen_meta.index.to_list()

In [None]:
# Metagenomes sampled directly from each donor, and how many there are.
d97_mgen_list, d44_mgen_list = (
    idxwhere(mgen_meta.subject_id == donor_id) for donor_id in ("D0097", "D0044")
)

(len(d97_mgen_list), len(d44_mgen_list))

### Quick result for Jacqueline Moltzau

In [None]:
# For each species: count, per donor / subject / sample class, the samples in which
# each strain exceeds 0.1 relative abundance, and show the counts as a heatmap.
_sample_class_labels = {
    # Padding makes the descending sort below order classes as baseline, donor, other.
    "baseline": "baseline",
    "donor": "    donor",
    "maintenance": "     other",
    "followup": "     other",
    "post_antibiotic": "     other",
}

for _species in ["101337", "101433", "100236"]:
    world_all = sf.data.World.load(
        f"data/group/{group_id}/species/sp-{_species}/r.proc.gtpro.sfacts-fit.world.nc"
    ).drop_low_abundance_strains(0.01)
    # Sorted so the sample order is reproducible; iterating a set of strings
    # depends on the per-session hash seed.
    _mgen_list = sorted(set(mgen_list) & set(world_all.sample.values))
    if not _mgen_list:
        print(f"Missing data for {_species}.")
        continue

    world = world_all.sel(sample=_mgen_list).drop_low_abundance_strains(0.01)

    d = (
        world.community.to_series()[lambda x: x > 0.1]
        .to_frame()
        .reset_index()
        .rename(columns=dict(sample="mgen_id"))
        .join(mgen_meta, on="mgen_id")
        .assign(sample_class=lambda x: x.sample_type.replace(_sample_class_labels))
        .groupby(["donor_subject_id", "subject_id", "sample_class"])
        .strain.value_counts()
        .unstack("strain", fill_value=0)
        .sort_index(ascending=[True, True, False])
    )

    # Most frequently detected strains first.
    strain_order = d.sum().sort_values(ascending=False).index
    d = d.loc[:, strain_order]

    nrow, ncol = d.shape
    fig, ax = plt.subplots(figsize=(0.4 * ncol + 0.5, 0.5 * nrow + 1))
    # Cube-root color scaling keeps small counts visible next to large ones.
    sns.heatmap(d, norm=mpl.colors.PowerNorm(1 / 3), annot=True, cbar=False, ax=ax)
    ax.set_title(f"{_species} ({species_taxonomy.loc[_species].s__})")

### Closest references

In [None]:
# Strain-detection counts per donor / subject / sample class for one focal species.
_species = "101337"

world_all = sf.data.World.load(
    f"data/group/{group_id}/species/sp-{_species}/r.proc.gtpro.sfacts-fit.world.nc"
).drop_low_abundance_strains(0.01)
# Sorted so the sample order is reproducible; iterating a set of strings
# depends on the per-session hash seed.
_mgen_list = sorted(set(mgen_list) & set(world_all.sample.values))

assert _mgen_list, f"No samples of mgen_list found in the fit for species {_species}."

world = world_all.sel(sample=_mgen_list).drop_low_abundance_strains(0.01)

d = (
    world.community.to_series()[lambda x: x > 0.1]
    .to_frame()
    .reset_index()
    .rename(columns=dict(sample="mgen_id"))
    .join(mgen_meta, on="mgen_id")
    .assign(
        sample_class=lambda x: x.sample_type.replace(
            {
                # Padding makes the descending sort below order classes as
                # baseline, donor, other.
                "baseline": "baseline",
                "donor": "    donor",
                "maintenance": "     other",
                "followup": "     other",
                "post_antibiotic": "     other",
            }
        )
    )
    .groupby(["donor_subject_id", "subject_id", "sample_class"])
    .strain.value_counts()
    .unstack("strain", fill_value=0)
    .sort_index(ascending=[True, True, False])
)

# Most frequently detected strains first.
strain_order = d.sum().sort_values(ascending=False).index
d = d.loc[:, strain_order]

nrow, ncol = d.shape
fig, ax = plt.subplots(figsize=(0.4 * ncol + 0.5, 0.5 * nrow + 1))
# Cube-root color scaling keeps small counts visible next to large ones.
sns.heatmap(d, norm=mpl.colors.PowerNorm(1 / 3), annot=True, cbar=False, ax=ax)
ax.set_title(f"{_species} ({species_taxonomy.loc[_species].s__})")

In [None]:
# Focal species for the reference / SPGC / isolate genotype comparison below.
# NOTE(review): same value as `_species` above; keep the two in sync.
species_id = "101337"

In [None]:
# Metagenotypes for the focal species from three sources: MIDAS v1.5 reference
# genomes, SPGC strains from this group's fit, and isolate strain genomes.
_ref_mgtp_path = f"data/species/sp-{species_id}/midasdb_v15.gtpro.mgtp.nc"
_spgc_mgtp_path = f"data/group/{group_id}/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.spgc_ss-all.mgtp.nc"
_iso_mgtp_path = f"data/group/ucfmt/species/sp-{species_id}/strain_genomes.gtpro.mgtp.nc"

ref_mgtp = sf.data.Metagenotype.load(_ref_mgtp_path)
# Cast SPGC sample labels to str so they index like the other sources.
spgc_mgtp = sf.data.Metagenotype.load(_spgc_mgtp_path).rename_coords(sample=str)
isolate_mgtp = sf.data.Metagenotype.load(_iso_mgtp_path)

all_mgtp = sf.data.Metagenotype.concat(
    dict(ref=ref_mgtp, spgc=spgc_mgtp, iso=isolate_mgtp),
    dim="sample",
    rename=False,
)
all_geno = all_mgtp.to_estimated_genotype().discretized(max_ambiguity=0.1)

In [None]:
# Pairwise masked Hamming distances among all reference, SPGC, and isolate genotypes.
# `squareform` must be imported here: it was previously only imported in a later
# cell, so this cell raised NameError on a fresh kernel.
from scipy.spatial.distance import squareform

from scripts.calculate_ref_and_spgc_pairwise_genotype_masked_hamming_distance import (
    native_masked_hamming_distance_pdist,
)

# Presumably a condensed distance vector (it is expanded with squareform below).
geno_cdmat = native_masked_hamming_distance_pdist(all_geno.values, pseudo=1.0)

all_geno_pdist = pd.DataFrame(
    squareform(geno_cdmat), index=all_geno.strain, columns=all_geno.strain
)
all_geno_pdist

In [None]:
# The 50 genotypes closest to "Bacteroides-fragilis-VPI-2553" (presumably an isolate genome).
all_geno_pdist["Bacteroides-fragilis-VPI-2553"].sort_values().head(50)

In [None]:
# Distances from "Bacteroides-fragilis-VPI-2553" to strains "1" and "3".
all_geno_pdist["Bacteroides-fragilis-VPI-2553"].loc[["1", "3"]]

In [None]:
from scipy.spatial.distance import squareform

# Distribution of all pairwise genotype distances (the condensed upper triangle).
_pairwise_distances = squareform(all_geno_pdist.to_numpy())
plt.hist(_pairwise_distances, bins=100);

In [None]:
# Average-linkage tree over all genotypes, built from the condensed distances.
_genome_tree_options = dict(method="average", optimal_ordering=True)
all_geno_linkage = sp.cluster.hierarchy.linkage(geno_cdmat, **_genome_tree_options)

In [None]:
# Seeded subsample of 1000 positions keeps the genotype heatmap tractable.
np.random.seed(0)
all_geno_ss = all_geno.random_sample(position=1000)

# Cluster positions by correlation; missing values are filled with 0.5.
_ss_filled_by_position = np.nan_to_num(all_geno_ss.values, nan=0.5).T
all_geno_ss_pos_linkage = sp.cluster.hierarchy.linkage(
    _ss_filled_by_position, method="average", metric="correlation", optimal_ordering=True
)

# Rows use the precomputed genotype tree; columns use the position tree.
sf.plot.plot_genotype(
    all_geno_ss,
    row_linkage_func=lambda w: all_geno_linkage,
    col_linkage_func=lambda w: all_geno_ss_pos_linkage,
)

In [None]:
# Near-identical pairs (distance < 0.01) between any genotype and the strains of
# interest. Self-pairs (distance 0 on the diagonal) are excluded because they
# would otherwise fill the top of the sorted list.
# NOTE(review): `ucfmt_strain_list` is not defined anywhere in this notebook;
# define it before running this cell.
_near_pairs = all_geno_pdist[ucfmt_strain_list].stack()
_is_self_pair = _near_pairs.index.get_level_values(0) == _near_pairs.index.get_level_values(1)
_near_pairs[(_near_pairs < 0.01) & ~_is_self_pair].sort_values().head(50)

In [None]:
# "GUT_GENOME" reference genomes within distance 0.01 of strain "3".
_dist_to_3 = all_geno_pdist["3"]
_dist_to_3[(_dist_to_3 < 0.01) & _dist_to_3.index.str.startswith("GUT_GENOME")]

In [None]:
# MIDAS UHGG v1.5 genome metadata, restricted to genomes of the focal species.
# NOTE(review): `names=` is passed without `header=0`; if the file has a header
# row it will be read as data. Confirm against the file.
_reference_meta_columns = [
    "Genome",
    "Genome_type",
    "Length",
    "N_contigs",
    "N50",
    "GC_content",
    "Completeness",
    "Contamination",
    "rRNA_5S",
    "rRNA_16S",
    "rRNA_23S",
    "tRNAs",
    "Genome_accession",
    "Species_rep",
    "Lineage",
    "Sample_accession",
    "Study_accession",
    "Country",
    "Continent",
    "FTP_download",
    "_20",
    "_21",
]
_reference_meta_all = pd.read_table(
    "ref/midasdb_uhgg_v15/metadata/2023-11-11-genomes-all_metadata.tsv",
    names=_reference_meta_columns,
    index_col=["Genome_accession"],
)
# Species id = "1" + the Species_rep accession with its leading "MGYG0000" dropped.
_reference_meta_all = _reference_meta_all.assign(
    species_id="1" + _reference_meta_all["Species_rep"].str[len("MGYG0000") :]
)
reference_genome_meta = (
    _reference_meta_all[_reference_meta_all["species_id"] == species_id]
    .rename_axis(index="genome_id")
    .rename(
        columns={
            "Genome_type": "genome_type",
            "Completeness": "completeness",
            "Contamination": "contamination",
        }
    )
)
del _reference_meta_all

In [None]:
# All genotypes within distance 0.01 of strain "1".
all_geno_pdist["1"].loc[lambda dist: dist < 0.01]

In [None]:
# Available reference-metadata columns (after the lower-case renames).
reference_genome_meta.columns

In [None]:
# Isolate reference genomes ranked by genotype distance to strain "1".
_isolate_report_columns = [
    "genome_type",
    "Length",
    "N_contigs",
    "N50",
    "completeness",
    "contamination",
    "Sample_accession",
    "Study_accession",
    "distance_to_strain_1",
]
(
    reference_genome_meta.assign(distance_to_strain_1=all_geno_pdist["1"])
    .sort_values("distance_to_strain_1")
    .loc[:, _isolate_report_columns]
    .loc[lambda x: x.genome_type == 'Isolate']
)

In [None]:
# Isolate reference genomes ranked by genotype distance to strain "3".
# (The distance column was previously mislabeled `distance_to_strain_1`.)
(
    reference_genome_meta.assign(distance_to_strain_3=all_geno_pdist["3"])
    .sort_values("distance_to_strain_3")
    .loc[
        :,
        [
            "genome_type",
            "Length",
            "N_contigs",
            "N50",
            "completeness",
            "contamination",
            "Sample_accession",
            "Study_accession",
            "distance_to_strain_3",
        ],
    ]
    .loc[lambda x: x.genome_type == "Isolate"]
)

In [None]:
# Isolate reference genomes ranked by SNP dissimilarity to strain "3".
# NOTE(review): `dmat` is not defined anywhere in this notebook (its index appears
# to use "UHGG"-prefixed ids); load it before running this cell.
# The filter uses `genome_type`, the name `Genome_type` is renamed to when
# reference_genome_meta is built; `Genome_type` raised AttributeError.
(
    dmat["3"]
    .to_frame("snp_diss")
    .rename(lambda s: "GUT_GENOME" + s[len("UHGG") :])
    .join(reference_genome_meta)
    .dropna()
    .loc[lambda x: x.genome_type == "Isolate"]
    .sort_values("snp_diss")
)

In [None]:
# Isolate reference genomes ranked by SNP dissimilarity to strain "1".
# NOTE(review): `dmat` is not defined anywhere in this notebook (its index appears
# to use "UHGG"-prefixed ids); load it before running this cell.
# The filter uses `genome_type`, the name `Genome_type` is renamed to when
# reference_genome_meta is built; `Genome_type` raised AttributeError.
(
    dmat["1"]
    .to_frame("snp_diss")
    .rename(lambda s: "GUT_GENOME" + s[len("UHGG") :])
    .join(reference_genome_meta)
    .dropna()
    .loc[lambda x: x.genome_type == "Isolate"]
    .sort_values("snp_diss")
)

In [None]:
# Metadata for "UHGG" genomes within dissimilarity 0.01 of strain "3",
# after converting their ids to the "GUT_GENOME" form.
# NOTE(review): `dmat` is not defined anywhere in this notebook; load it first.
_dmat_near_3 = dmat["3"][lambda x: (x < 0.01) & x.index.str.startswith("UHGG")]
_gut_genome_ids = _dmat_near_3.rename(lambda s: "GUT_GENOME" + s[len("UHGG") :]).index
reference_genome_meta.loc[_gut_genome_ids]

In [None]:
# Print, one per line, "MGYG-HGUT-0..." ids for "UHGG" genomes within
# dissimilarity 0.01 of strain "1".
# NOTE(review): `dmat` is not defined anywhere in this notebook; load it first.
_dmat_near_1 = dmat["1"][lambda x: (x < 0.01) & x.index.str.startswith("UHGG")]
print("\n".join("MGYG-HGUT-0" + s[len("UHGG") :] for s in _dmat_near_1.index))

In [None]:
# Metagenotype heatmap for a chosen subset of the combined ref/SPGC/isolate samples.
# NOTE(review): `strain_list` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel. TODO: define the samples to plot.
sf.plot.plot_metagenotype(all_mgtp.sel(sample=strain_list), scalex=0.3)

In [None]:
# Alternative sfacts fit for `_species`: the 10 samples with the highest
# abundance of strain "4" (community columns are presumably strains).
# Uses `group_id` rather than a hard-coded group name, so the path follows the
# Parameters cell.
w = sf.data.World.load(
    f"data/group/{group_id}/species/sp-{_species}/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
).rename_coords(strain=str)
w.community.to_pandas()["4"].sort_values(ascending=False).head(10)

### Strains

In [None]:
# Focal species for the strain plots and strain-sharing counts below.
species = "102506"

In [None]:
# Lineage of the focal species, one field per rank (d__ ... s__).
species_taxonomy.loc[species]

In [None]:
# Load the sfacts fit for `species`, restrict it to this notebook's metagenomes,
# and plot metagenotype and community heatmaps colored by subject and donor.
# Fixes: the path used the undefined name `group` (now `group_id`), and the assert
# checked `mgen_list` rather than the overlap actually used.
world_all = sf.data.World.load(
    f"data/group/{group_id}/species/sp-{species}/r.proc.gtpro.sfacts-fit.world.nc"
).drop_low_abundance_strains(0.01)
# Sorted so the sample order is reproducible; iterating a set of strings
# depends on the per-session hash seed.
_mgen_list = sorted(set(mgen_list) & set(world_all.sample.values))
assert _mgen_list, f"No samples of mgen_list found in the fit for species {species}."
world = world_all.sel(sample=_mgen_list).drop_low_abundance_strains(0.01)
print(world_all.sizes)
print(world.sizes)

sample_linkage = world.unifrac_linkage()
# Seeded like the all_geno_ss subsample so the plotted positions are reproducible.
np.random.seed(0)
world_ss = world.random_sample(position=500)
mgen_colors = pd.DataFrame(
    dict(
        subject=mgen_meta.subject_id.map(subject_palette),
        donor=mgen_meta.donor_subject_id.map(subject_palette),
    )
)

sf.plot.plot_metagenotype(
    world_ss,
    col_linkage_func=lambda w: sample_linkage,
    row_linkage_func=lambda w: w.genotype.linkage("position"),
    col_colors=mgen_colors,
)
sf.plot.plot_community(
    world_ss,
    col_linkage_func=lambda w: sample_linkage,
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
    col_colors=mgen_colors,
)

In [None]:
# For each (donor, strain): the number of recipient subjects whose follow-up
# samples ever show that strain above 0.2 community fraction. Sorted descending.
d = (
    world.community.to_pandas()
    # rename_axis names the index; the previous `.rename(index=...)` only renames
    # labels and was a no-op.
    .rename_axis(index="mgen_id")
    # Keep only recipient follow-up samples.
    .drop(
        idxwhere((~mgen_meta.recipient) | (mgen_meta.sample_type != "followup")),
        errors="ignore",
    )
    .gt(0.2)
    # Is the strain detected in any follow-up sample of each subject?
    .groupby(mgen_meta.subject_id)
    .any()
    # Count detecting subjects per donor.
    .groupby(subject.donor_subject_id)
    .sum()
    .stack()
    .sort_values(ascending=False)
)

d.head()

In [None]:
# Per-strain gene table from the SPGC gene fit for `species`. The path previously
# used the undefined name `group` instead of `group_id`.
strain_gene = pd.read_table(
    f"data/group/{group_id}/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.tsv"
)
# Joint counts of the values in columns "8" and "1" (presumably per-gene calls for strains 8 and 1).
strain_gene[["8", "1"]].value_counts()