## Preamble

### Template Utils

In [None]:
# Load IPython's autoreload extension (notebook template boilerplate).
# NOTE(review): loading the extension by itself may not turn reloading on;
# a mode such as `%autoreload 2` is usually set as well — confirm intent.
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so every relative path in later cells resolves;
# the trailing expression echoes the resolved working directory.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
_os.path.realpath(_os.path.curdir)

### Imports

In [None]:
import os
import subprocess
import time
from itertools import chain, product
from tempfile import mkstemp

import fastcluster
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.thisproject.data
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
# Genome metadata table (tab-separated), indexed by genome_id.
genome = pd.read_csv('meta/genome.tsv', sep='\t', index_col='genome_id')

In [None]:
# The 20 species with the most genomes in the metadata table.
genome["species_id"].value_counts().head(n=20)

In [None]:
# Species to analyze. Kept as a string: it is interpolated into the input
# paths and compared against `genome.species_id` in later cells.
species_id = "102478"

In [None]:
# Inputs for this species: the benchmark sample table, the reference strain
# genomes' GT-Pro metagenotypes, and the sfacts fit for the
# xjin_ucfmt_hmp2 group.
xjin_sample_inpath = "meta/XJIN_BENCHMARK/mgen.tsv"
xjin_strain_geno_inpath = (
    f"data/species/sp-{species_id}/strain_genomes.gtpro.mgtp.nc"
)
spgc_strain_geno_inpath = (
    f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.world.nc"
)

In [None]:
# Benchmark metagenome IDs, read from the `mgen_id` column of the sample table.
xjin_sample_list = pd.read_csv(xjin_sample_inpath, sep="\t")["mgen_id"].tolist()

In [None]:
# "Reference GT-Pro genotype": load the strain genomes' metagenotype and
# convert it with `to_estimated_genotype()`, then show its sizes.
xjin_geno = (
    sf.Metagenotype
    .load(xjin_strain_geno_inpath)
    .to_estimated_genotype()
)
xjin_geno.sizes

In [None]:
# Load the sfacts fit (an `sf.World`) for this species and display its
# `.sizes`.
sfacts_fit = sf.World.load(spgc_strain_geno_inpath)
sfacts_fit.sizes

In [None]:
# Samples present in both the sfacts fit and the benchmark sample table
# (order is arbitrary: the list comes from a set).
_fit_samples = set(sfacts_fit.sample.values)
matched_sample_list = list(_fit_samples.intersection(set(xjin_sample_list)))
len(matched_sample_list)

In [None]:
# Map each reference genome of this species to its reconstruction-accuracy
# table.
# `species_id` is a string, but pandas parses an all-numeric `species_id`
# column as integers; without the cast the comparison is all-False and this
# dict is silently empty. `astype(str)` is a no-op when the column is already
# string-typed.
accuracy_inpaths = {
    genome_id: f"data/group/XJIN_BENCHMARK/species/sp-{species_id}/r.proc.gene99_new-v22-agg75.spgc-fit.{genome_id}.uhggtiles-reconstruction_accuracy.tsv"
    for genome_id in genome[lambda x: x.species_id.astype(str) == species_id].index
}
accuracy_inpaths

In [None]:
# Pairwise dissimilarity between reference genotypes (rows, indexed by
# genome_id in later cells) and the fitted sfacts strain genotypes (columns).
# Later cells treat lower values as closer (they take `idxmin()` as best hit).
strain_cdist = xjin_geno.cdist(sfacts_fit.genotype)

In [None]:
# For each reference genome, scatter its genotype dissimilarity to every fitted
# strain against that strain's gene F1. Strains absent from the accuracy table
# are drawn with F1 = 0.
for genome_id, _accuracy_path in accuracy_inpaths.items():
    _spgc_accuracy = pd.read_table(_accuracy_path, index_col="strain")["f1"]
    _geno_diss = strain_cdist.loc[genome_id]
    # Closest fitted strain; it should have an entry in the accuracy table.
    top_strain = _geno_diss.idxmin()
    if top_strain not in _spgc_accuracy.index:
        print(f"ERROR: Best hit to {genome_id} ({top_strain}) not found in accuracy table.")
    d = pd.DataFrame({"geno_diss": _geno_diss, "gene_f1": _spgc_accuracy})
    d["gene_f1"] = d["gene_f1"].fillna(0)
    plt.scatter('geno_diss', 'gene_f1', data=d.sort_values(by='geno_diss'), label=genome_id)

# Symlog x-axis keeps zero / near-zero dissimilarities visible.
plt.xscale('symlog', linthresh=1e-3, linscale=0.1)
plt.ylim(-0.05, 1.05)
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# Text report per reference genome: the 10 closest fitted strains by genotype
# dissimilarity, the head of its accuracy table, and the accuracy row for the
# closest strain (or an error if that strain has no row).
for genome_id in accuracy_inpaths:
    print(genome_id)
    spgc_accuracy = pd.read_table(accuracy_inpaths[genome_id], index_col="strain")
    # Look up this genome's dissimilarity row once; reused for ranking and best hit.
    _geno_diss = strain_cdist.loc[genome_id]
    print(_geno_diss.sort_values().head(10))
    print()
    top_strain = _geno_diss.idxmin()
    print(spgc_accuracy.head(5))
    print()
    if top_strain not in spgc_accuracy.index:
        # f-prefix was missing, so the placeholders were printed literally.
        print(f"ERROR: Best hit to {genome_id} ({top_strain}) not found in accuracy table.")
    else:
        print(spgc_accuracy.loc[top_strain])
    print("\n")