## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Every path in this notebook is relative to the project root, which is taken
# from the PROJECT_ROOT environment variable (KeyError if it is not set).
_os.chdir(_os.environ["PROJECT_ROOT"])
# Cell output: the resolved working directory, as a sanity check.
_os.path.realpath(_os.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import squareform
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping
from lib.thisproject.genotype_dissimilarity import native_masked_hamming_distance_cdist

## Style

In [None]:
# Global plotting defaults for the whole notebook.
sns.set_context("paper")
plt.rcParams.update({"figure.dpi": 100})

## Analysis Parameters

In [None]:
# Passed as `max_ambiguity` to `.discretized(...)` wherever genotypes are
# discretized in the genotype-comparison cells below.
ambiguity_threshold = 0.1

# Load and Check Data

## Select some GTDB genomes missing from UHGG

### Pick some example genomes and generate the necessary metadata files.

In [None]:
# Assembly listing for BioProject PRJNA938932. The first two lines of the file
# are skipped and column names are supplied explicitly; the last column is
# named "_" and is not used anywhere else in this notebook.
prjna938932_genomes = pd.read_table(
    "raw/PRJNA938932_AssemblyDetails.txt",
    skiprows=2,
    names=["assembly_id", "level", "wgs", "biosample", "strain", "taxonomy", "_"],
)
prjna938932_genomes

In [None]:
def table_description(df):
    """Summarize the columns of `df` as a small lookup table.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to describe.

    Returns
    -------
    pandas.DataFrame
        One row per column of `df`, indexed by `col_idx` (the column's
        position), with `name` (the column label) and `example_val` (the value
        in the first row, or NA when `df` has no rows).
    """
    if len(df) > 0:
        example_val = df.iloc[0].values
    else:
        # An empty table has no first row; `df.iloc[0]` would raise IndexError.
        example_val = [pd.NA] * df.shape[1]
    return pd.DataFrame(
        dict(col_idx=range(df.shape[1]), name=df.columns, example_val=example_val)
    ).set_index("col_idx")


# Column position, name and first-row value for the assembly table.
table_description(prjna938932_genomes)

In [None]:
# Existing genome metadata table; its columns define the layout of the genome
# TSVs written further down (`d[genome_meta.columns]`).
genome_meta = pd.read_table("meta/genome.tsv")

table_description(genome_meta)

In [None]:
# Existing genome-group table; its columns define the layout of the
# genome-group TSVs written further down (`d[genome_group_meta.columns]`).
genome_group_meta = pd.read_table("meta/genome_group.tsv")


table_description(genome_group_meta)

In [None]:
# Existing metagenome-to-reads table; its columns define the layout of the
# mgen_to_reads TSVs written further down (`d[mgen_meta.columns]`).
mgen_meta = pd.read_table("meta/mgen_to_reads.tsv")


table_description(mgen_meta)

In [None]:
# Existing metagenome-group table; its columns define the layout of the
# mgen_group TSVs written further down (`d[mgen_group_meta.columns]`).
mgen_group_meta = pd.read_table("meta/mgen_group.tsv")


table_description(mgen_group_meta)

In [None]:
# Display the assembly table again for reference.
prjna938932_genomes

In [None]:
# Derive one metadata row per PRJNA938932 assembly, carrying every column that
# the four project metadata tables need, then write one TSV per table using
# that table's own column list.
d = prjna938932_genomes.assign(
    # "GCA" becomes "GCF" and a trailing ".1" becomes "-1". The dot is escaped
    # so that only a literal "." matches; unescaped, ".1$" would also rewrite
    # identifiers that merely end in "<any character>1".
    genome_id=lambda x: "Escherichia-coli-"
    + x.assembly_id.str.replace("GCA", "GCF").str.replace(
        r"\.1$", "-1", regex=True
    ),
    mgen_id=lambda x: x.genome_id,
    species_id="102506",
    genome_path=lambda x: "raw/genomes/ncbi/" + x.genome_id + "/assembly.fa",
    r1_path=lambda x: "raw/genomes/ncbi/" + x.genome_id + "/r1.fq.gz",
    r2_path=lambda x: "raw/genomes/ncbi/" + x.genome_id + "/r2.fq.gz",
    genome_group_id="potential_spikein_benchmark",
    mgen_group_id="potential_spikein_benchmark",
    preprocessing="noop",
    _old_genome_id="",
    ncbi_assembly_name=lambda x: x.assembly_id,
    ncbi_assembly_biosample=lambda x: x.biosample,
    comments=lambda x: "From PRJNA938932 " + x.assembly_id + ";",
).sort_values("genome_id")


d[genome_meta.columns].to_csv(
    "meta/prjna938932_ecoli_genome.tsv", index=False, sep="\t"
)
d[genome_group_meta.columns].to_csv(
    "meta/prjna938932_ecoli_genome_group.tsv", index=False, sep="\t"
)
d[mgen_group_meta.columns].to_csv(
    "meta/prjna938932_ecoli_mgen_group.tsv", index=False, sep="\t"
)
d[mgen_meta.columns].to_csv(
    "meta/prjna938932_ecoli_mgen_to_reads.tsv", index=False, sep="\t"
)

## Load metadata

In [None]:
# Per-genome metadata from the MIDAS UHGG database directory, indexed by the
# `New_Genome_accession` column.
uhgg_genome_meta = pd.read_table(
    "ref/midasdb_uhgg_v20/metadata/genomes-all_metadata.tsv",
    index_col="New_Genome_accession",
)

## Compare Genotypes

In [None]:
# Species identifier interpolated into the `sp-{species}` data paths below.
species = "102506"

In [None]:
# Genotypes of the MIDAS DB reference genomes for this species. They are kept
# undiscretized here; `.discretized(max_ambiguity=ambiguity_threshold)` is
# applied later, at comparison time.
midas_assembly_inpath = f"data/species/sp-{species}/midasdb_v15.gtpro.mgtp.nc"
midas_assembly_mgtp = sf.Metagenotype.load(midas_assembly_inpath)
midas_assembly_geno = midas_assembly_mgtp.to_estimated_genotype()
del midas_assembly_mgtp

In [None]:
# Genotypes of the candidate spike-in genomes. They are kept undiscretized
# here; `.discretized(max_ambiguity=ambiguity_threshold)` is applied later, at
# comparison time.
spikein_assembly_inpath = f"data/group/potential_spikein_benchmark/species/sp-{species}/strain_genomes.gtpro.mgtp.nc"
spikein_assembly_mgtp = sf.Metagenotype.load(spikein_assembly_inpath)
spikein_assembly_geno = spikein_assembly_mgtp.to_estimated_genotype()
del spikein_assembly_mgtp

In [None]:
# Genotypes loaded from the HMP2 sfacts-fit output. They are kept undiscretized
# here; `.discretized(max_ambiguity=ambiguity_threshold)` is applied later, at
# comparison time.
hmp2_inferred_inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.spgc_ss-all.mgtp.nc"
hmp2_inferred_mgtp = sf.Metagenotype.load(hmp2_inferred_inpath)
hmp2_inferred_geno = hmp2_inferred_mgtp.to_estimated_genotype()
del hmp2_inferred_mgtp

In [None]:
# Discretize both genotype sets with the shared ambiguity threshold and stack
# them along the strain dimension.
g1 = spikein_assembly_geno.discretized(max_ambiguity=ambiguity_threshold)
g2 = hmp2_inferred_geno.discretized(max_ambiguity=ambiguity_threshold)

spikein_and_hmp2_sample_geno = sf.data.Genotype.concat(
    {"spike": g1, "hmp2": g2},
    dim="strain",
    rename=False,
)

# Subsample 10,000 positions; the distances below use only this subsample.
# NOTE(review): no seed is set here, so the sampled positions (and everything
# derived from them) may differ between runs - confirm how sfacts draws them.
spikein_and_hmp2_sample_geno_ss = spikein_and_hmp2_sample_geno.random_sample(
    position=10_000
)
g0 = spikein_and_hmp2_sample_geno_ss

# Rows are spike-in strains, columns are HMP2 strains.
spikein_to_hmp2_gdist = pd.DataFrame(
    native_masked_hamming_distance_cdist(
        g0.sel(strain=g1.strain).values,
        g0.sel(strain=g2.strain).values,
    ),
    index=g1.strain,
    columns=g2.strain,
)

In [None]:
# Discretize both genotype sets with the shared ambiguity threshold and stack
# them along the strain dimension.
g1 = spikein_assembly_geno.discretized(max_ambiguity=ambiguity_threshold)
g2 = midas_assembly_geno.discretized(max_ambiguity=ambiguity_threshold)

spikein_and_midas_geno = sf.data.Genotype.concat(
    {"spike": g1, "midas": g2},
    dim="strain",
    rename=False,
)
g0 = spikein_and_midas_geno

# Restriction to the positions sampled for the HMP2 comparison. This subsample
# is stored but not used for the distances below, which are computed on `g0`,
# i.e. on all positions of the concatenated genotype.
spikein_and_midas_geno_ss = g0.sel(position=spikein_and_hmp2_sample_geno_ss.position)

# Rows are spike-in strains, columns are MIDAS reference strains.
spikein_to_midas_gdist = pd.DataFrame(
    native_masked_hamming_distance_cdist(
        g0.sel(strain=g1.strain).values,
        g0.sel(strain=g2.strain).values,
    ),
    index=g1.strain,
    columns=g2.strain,
)

In [None]:
# UHGG genomes that are isolates and whose lineage string ends in the
# "s__Escherichia coli_D" species label.
midas_ecoli_isolate_list = idxwhere(
    (uhgg_genome_meta.Genome_type == "Isolate")
    & uhgg_genome_meta.Lineage.str.endswith("s__Escherichia coli_D")
)

# Same distances as spikein_to_midas_gdist, restricted to those columns.
spikein_to_midas_ecoli_isolate_gdist = spikein_to_midas_gdist.loc[
    :, midas_ecoli_isolate_list
]

In [None]:
# For every spike-in strain (row), record the nearest column label and the
# corresponding distance in each of the three distance tables.
spikein_closest_match = pd.DataFrame(
    {
        "idxmin_midas": spikein_to_midas_gdist.idxmin(axis=1),
        "min_dist_midas": spikein_to_midas_gdist.min(axis=1),
        "idxmin_isolate": spikein_to_midas_ecoli_isolate_gdist.idxmin(axis=1),
        "min_dist_isolate": spikein_to_midas_ecoli_isolate_gdist.min(axis=1),
        "idxmin_hmp2": spikein_to_hmp2_gdist.idxmin(axis=1),
        "min_dist_hmp2": spikein_to_hmp2_gdist.min(axis=1),
    }
)

# Distribution of the distance to the nearest MIDAS reference.
plt.hist(spikein_closest_match["min_dist_midas"])

# Cell output: strains ordered from most to least distant from any MIDAS reference.
spikein_closest_match.sort_values(by="min_dist_midas", ascending=False)

In [None]:
# Rank spike-in strains from most to least distant from their nearest MIDAS
# reference, keeping only the first (most distant) strain per nearest reference.
unique_closest_match_spikein = (
    spikein_closest_match
    .sort_values("min_dist_midas", ascending=False)
    .drop_duplicates(subset=["idxmin_midas"])
)

# Benchmark set: the first 20 strains of that ranking plus the last 5 of its
# first 50 (ranks 46-50 when the ranking has at least 50 rows).
spikein_benchmark_isolate_genomes = [
    *unique_closest_match_spikein.head(20).index,
    *unique_closest_match_spikein.head(50).tail(5).index,
]

spikein_closest_match.loc[spikein_benchmark_isolate_genomes]

In [None]:
# All-against-all distances among the spike-in strains, taken from the
# concatenated spike-in + MIDAS genotype (all positions).
_spikein_strains = spikein_assembly_geno.strain
spikein_to_spikein_gdist = pd.DataFrame(
    native_masked_hamming_distance_cdist(
        spikein_and_midas_geno.sel(strain=_spikein_strains).values,
        spikein_and_midas_geno.sel(strain=_spikein_strains).values,
    ),
    index=_spikein_strains,
    columns=_spikein_strains,
)

In [None]:
def _dist_to_color(dist):
    """Return the viridis colour for `dist` after scaling it by 20."""
    return mpl.cm.viridis(dist * 20)


# Annotation strips for the clustermap: nearest-reference distances as colours
# and a grey/black flag marking the strains chosen for the benchmark.
_colors = pd.DataFrame(
    {
        "min_dist_midas": spikein_closest_match.min_dist_midas.map(_dist_to_color),
        "min_dist_hmp2": spikein_closest_match.min_dist_hmp2.map(_dist_to_color),
        "is_selected": (
            spikein_closest_match.index.to_series()
            .isin(spikein_benchmark_isolate_genomes)
            .map({False: "grey", True: "black"})
        ),
    }
)

sns.clustermap(spikein_to_spikein_gdist, col_colors=_colors, row_colors=_colors)

In [None]:
# Same derived metadata as for the full PRJNA938932 listing, restricted to the
# genomes selected for the benchmark, then written as one TSV per project
# metadata table using that table's own column list.
d = prjna938932_genomes.assign(
    # "GCA" becomes "GCF" and a trailing ".1" becomes "-1". The dot is escaped
    # so that only a literal "." matches; unescaped, ".1$" would also rewrite
    # identifiers that merely end in "<any character>1".
    genome_id=lambda x: "Escherichia-coli-"
    + x.assembly_id.str.replace("GCA", "GCF").str.replace(
        r"\.1$", "-1", regex=True
    ),
    mgen_id=lambda x: x.genome_id,
    species_id="102506",
    genome_path=lambda x: "raw/genomes/ncbi/" + x.genome_id + "/assembly.fa",
    r1_path=lambda x: "raw/genomes/ncbi/" + x.genome_id + "/r1.fq.gz",
    r2_path=lambda x: "raw/genomes/ncbi/" + x.genome_id + "/r2.fq.gz",
    # NOTE(review): the group ids keep the "potential_spikein_benchmark" label
    # even though the output files are named bench_ecoli_* - confirm intended.
    genome_group_id="potential_spikein_benchmark",
    mgen_group_id="potential_spikein_benchmark",
    preprocessing="noop",
    _old_genome_id="",
    ncbi_assembly_name=lambda x: x.assembly_id,
    ncbi_assembly_biosample=lambda x: x.biosample,
    comments=lambda x: "From PRJNA938932 " + x.assembly_id + ";",
)[lambda x: x.genome_id.isin(spikein_benchmark_isolate_genomes)].sort_values(
    "genome_id"
)


d[genome_meta.columns].to_csv("meta/bench_ecoli_genome.tsv", index=False, sep="\t")
d[genome_group_meta.columns].to_csv(
    "meta/bench_ecoli_genome_group.tsv", index=False, sep="\t"
)
d[mgen_group_meta.columns].to_csv(
    "meta/bench_ecoli_mgen_group.tsv", index=False, sep="\t"
)
d[mgen_meta.columns].to_csv("meta/bench_ecoli_mgen_to_reads.tsv", index=False, sep="\t")

## Select a subject with 5+ samples and no E. coli

In [None]:
# Distinct species ids (as strings) belonging to the "hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda x: x.species_group_id == "hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
# HMP2 metadata tables, each indexed by its own id column.
mgen = pd.read_table("meta/hmp2/mgen.tsv", index_col="library_id")
preparation = pd.read_table("meta/hmp2/preparation.tsv", index_col="preparation_id")
stool = pd.read_table("meta/hmp2/stool.tsv", index_col="stool_id")
visit = pd.read_table("meta/hmp2/visit.tsv", index_col="visit_id")
subject = pd.read_table("meta/hmp2/subject.tsv", index_col="subject_id")

# One row per library (the index stays `library_id`), widened by chaining
# library -> preparation -> stool -> visit -> subject. `library_type` is
# dropped from `preparation` before joining, and columns of `visit` that clash
# with existing ones receive a "_" suffix.
meta_all = (
    mgen.join(preparation.drop(columns="library_type"), on="preparation_id")
    .join(stool, on="stool_id")
    .join(visit, on="visit_id", rsuffix="_")
    .join(subject, on="subject_id")
    .assign(
        # new_name = "<subject_id>_<week_number>_<library_id>", where a
        # missing week_number is written as 999.
        new_name=lambda x: (
            x[["subject_id", "week_number"]]
            .assign(library_id=x.index)
            .assign(week_number=lambda x: x.week_number.fillna(999).astype(int))
            .apply(lambda x: "_".join(x.astype(str)), axis=1)
        )
    )
    # .reset_index()
    # .set_index('new_name')
)

# Series mapping each library_id (the index) to its new_name.
library_id_to_new_name = meta_all.new_name

# Every library must resolve to a subject.
assert not any(meta_all.subject_id.isna())

# TODO: Rename samples based on subject and visit number
# TODO: Drop duplicate stools

In [None]:
species_depth_inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gene99_v20-v23-agg75.spgc_specgene-ref-filt-p95.species_depth.tsv"

# Column names are supplied explicitly, so the first line of the file is read
# as data; the single remaining column is squeezed down from a one-column frame.
species_depth = pd.read_table(
    species_depth_inpath,
    names=["library_id", "species_depth"],
    index_col="library_id",
).squeeze()

# Per subject: fraction ("mean") and count ("sum") of libraries whose depth for
# this species is exactly 0. Libraries with no entry in `species_depth` get NaN
# after index alignment, and NaN == 0 is False, so they count as not missing.
d = (
    meta_all.assign(
        species_depth=species_depth,
        species_missing=lambda x: x.species_depth == 0,
    )
    .groupby("subject_id")["species_missing"]
    .agg(["mean", "sum"])
)

plt.scatter("sum", "mean", data=d)

In [None]:
# Subjects ranked by how many of their libraries have zero depth for the species.
d.sort_values("sum", ascending=False)

In [None]:
# Libraries of subject C3022 in which the species has zero depth.
meta_all.assign(
    species_depth=species_depth,
    species_missing=lambda x: x.species_depth == 0,
).loc[lambda x: (x.subject_id == "C3022") & (x.species_depth == 0)]

In [None]:
# Series mapping each sample (the index) to the value in the `strain` column
# of the spike-in sample sheet.
spikein_mapping = pd.read_table(
    "data/group/hmp2_spikein_benchmark/species/sp-102506/ecoli-spiked.strain_samples.tsv",
    names=["sample", "strain"],
    index_col="sample",
)["strain"]
spikein_mapping

In [None]:
# Closest-match summary for the distinct strains listed in spikein_mapping.
spikein_closest_match.loc[spikein_mapping.unique()]

In [None]:
# UHGG metadata of the nearest MIDAS reference genome for each distinct strain
# listed in spikein_mapping.
uhgg_genome_meta.loc[spikein_closest_match.loc[spikein_mapping.unique()].idxmin_midas]

In [None]:
# Same lookup as the preceding cell (duplicate display).
uhgg_genome_meta.loc[spikein_closest_match.loc[spikein_mapping.unique()].idxmin_midas]

In [None]:
# Fitted sfacts world for the spike-in benchmark.
w = sf.data.World.load(
    "data/group/hmp2_spikein_benchmark/species/sp-102506/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.world.nc"
)

# For each value of spikein_mapping (one group per spiked genome), average the
# community table over that group's samples and keep the column label with the
# highest mean.
# NOTE(review): presumably `community` holds per-sample strain abundances, so
# this picks the inferred strain that dominates each genome's samples - confirm.
genome_to_strain = (
    w.sel(sample=spikein_mapping.index)
    .community.to_pandas()
    .groupby(spikein_mapping)
    .mean()
    .idxmax(axis=1)
)
genome_to_strain

In [None]:
# Index labels of genome_to_strain (the group keys taken from spikein_mapping).
genome_to_strain.index.values

In [None]:
# One record per spiked genome: the accuracy row of its matched strain merged
# with that genome's closest-match summary (the latter wins on key clashes).
spikein_performance = []

for genome, strain in genome_to_strain.items():
    inpath = f"data/group/hmp2_spikein_benchmark/species/sp-102506/r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0.gene99_v20-v23-agg75.spgc_specgene-ref-filt-p95_ss-all_t-10_thresh-corr400-depth200.{genome}.eggnog-reconstruction_accuracy.tsv"
    benchmark = pd.read_table(inpath, index_col="strain").assign(genome_id=genome)
    print(inpath)
    accuracy_record = benchmark.loc[strain].to_dict()
    closest_record = spikein_closest_match.loc[genome].to_dict()
    spikein_performance.append({**accuracy_record, **closest_record})

spikein_performance = pd.DataFrame(spikein_performance)

print(spikein_performance[["precision", "recall", "f1"]].mean())
spikein_performance

In [None]:
# Markdown rendering of the key benchmark columns, for pasting into notes.
print(
    spikein_performance[
        ["genome_id", "idxmin_midas", "min_dist_midas", "precision", "recall", "f1"]
    ].to_markdown()
)

In [None]:
# Median precision, recall and F1 over the rows of spikein_performance.
spikein_performance[['precision', 'recall', 'f1']].median()

In [None]:
# UHGG metadata of the nearest isolate reference for each benchmarked genome.
uhgg_genome_meta.loc[spikein_performance.idxmin_isolate]

In [None]:
# Pairwise relationships between the accuracy metrics and the square root of
# the distance to the nearest MIDAS reference.
sns.pairplot(
    spikein_performance.assign(
        sqrt_min_dist_midas=lambda x: np.sqrt(x.min_dist_midas)
    )[["f1", "precision", "recall", "sqrt_min_dist_midas"]]
)

In [None]:
# For each benchmarked genome and its nearest MIDAS reference, count the
# positions where the two discretized genotypes disagree, considering only
# positions that are non-missing in both.
for genome_id, midas_id in spikein_performance[["genome_id", "idxmin_midas"]].values:
    g = spikein_and_midas_geno.sel(strain=[genome_id, midas_id]).data
    # Positions where neither of the two strains is NaN.
    shared_positions = idxwhere(~g.pipe(np.isnan).any("strain").to_series())
    num_shared = len(shared_positions)
    # `.item()` turns the 0-d reduction result into a plain Python number, so
    # the line printed below contains numbers rather than array reprs.
    num_mismatched = (
        (g.sel(strain=genome_id) != g.sel(strain=midas_id))
        .sel(position=shared_positions)
        .sum()
        .item()
    )
    print(
        genome_id,
        midas_id,
        num_mismatched,
        num_shared,
        # Mismatch fraction with a pseudocount of 1 in numerator and
        # denominator, which also avoids 0/0 when no positions are shared.
        (num_mismatched + 1) / (num_shared + 1),
    )