## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Run every relative path in this notebook from the project root, taken from
# the PROJECT_ROOT environment variable (a KeyError here means it is unset).
_os.chdir(_os.environ["PROJECT_ROOT"])
# Show the resolved working directory as a sanity check.
_os.path.realpath(_os.getcwd())

### Imports

In [None]:
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm

import lib.plot

### Set Style

In [None]:
# Use seaborn's "talk" scaling for fonts/lines; the low figure DPI presumably
# keeps inline figures compact in the notebook.
sns.set_context(context="talk")
plt.rcParams.update({"figure.dpi": 50})

In [None]:
# Fixed plotting color for each genome type.
genome_type_palette = dict(
    SPGC="tab:green",
    MAG="tab:orange",
    Isolate="tab:blue",
)

## Data Setup

### Metadata

In [None]:
# Unique species IDs (as strings) that belong to the xjin_ucfmt_hmp2 species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda d: d["species_group_id"].eq("xjin_ucfmt_hmp2"), "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(
    taxonomy_string, ranks=("d__", "p__", "c__", "o__", "f__", "g__", "s__")
):
    """Split a ';'-delimited lineage string into a Series indexed by rank.

    Parameters
    ----------
    taxonomy_string : str
        Lineage such as ``"d__Bacteria;p__Firmicutes;...;s__..."``. Values are
        returned as-is, so each keeps its rank prefix (e.g. ``"p__Firmicutes"``).
    ranks : sequence of str, optional
        Index labels, one per ';'-delimited field. Defaults to the seven
        GTDB-style rank prefixes ``d__`` through ``s__``.

    Returns
    -------
    pandas.Series
        One entry per rank, indexed by ``ranks``.

    Raises
    ------
    ValueError
        If the number of fields does not equal ``len(ranks)``.
    """
    values = taxonomy_string.split(";")
    if len(values) != len(ranks):
        # Name the offending lineage; pandas' own length-mismatch error does not.
        raise ValueError(
            f"Expected {len(ranks)} ';'-delimited ranks but found {len(values)} "
            f"in lineage {taxonomy_string!r}"
        )
    return pd.Series(values, index=list(ranks))

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# One row per UHGG species representative, indexed by species_id ("1" + the
# numeric part of the MGnify accession). Columns: the raw Lineage string plus
# one column per rank, parsed by parse_taxonomy_string.
species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda d: d["Genome"] == d["Species_rep"]]
    .assign(species_id=lambda d: "1" + d["MGnify_accession"].str.split("-").str[2])
    .set_index("species_id")
    .loc[:, ["Lineage"]]
    .pipe(lambda d: d.join(d["Lineage"].apply(parse_taxonomy_string)))
)
species_taxonomy

In [None]:
# One color per phylum, assigned in alphabetical phylum order from tab10.
phylum_palette = lib.plot.construct_ordered_palette(
    sorted(species_taxonomy["p__"].unique()), cm="tab10"
)

### Strain Statistics

In [None]:
def classify_genome(x):
    """Re-label genomes to a mutually exclusive vocabulary for manuscript.

    Parameters
    ----------
    x : pandas.Series
        One row of the strain-metadata table (e.g. passed by
        ``DataFrame.apply(..., axis=1)``). Always reads ``genome_type`` and
        ``passes_filter``. For SPGC genomes that fail ``passes_filter``, it also
        reads ``passes_geno_positions`` and ``passes_in_sample_list``.

    Returns
    -------
    str
        One of "isolate", "isolate_fails_qc", "mag", "mag_fails_qc", "spgc",
        "sfacts", or "sfacts_fails_qc".

    Raises
    ------
    ValueError
        If ``genome_type`` is not "Isolate", "MAG", or "SPGC", or if a QC flag
        that is consulted is missing (NA).
    """

    def flag(name):
        # Evaluate QC flags by truthiness, not with bitwise ``~``/``&``: row values
        # can be Python bools, where ``~True == -2`` is truthy (and ``~`` on bool is
        # deprecated since Python 3.12). Missing values are rejected rather than
        # silently counted as passing.
        value = x[name]
        if pd.isna(value):
            raise ValueError("Genome did not match classification criteria:", x)
        return bool(value)

    genome_type = x["genome_type"]
    if genome_type == "Isolate":
        return "isolate" if flag("passes_filter") else "isolate_fails_qc"
    if genome_type == "MAG":
        return "mag" if flag("passes_filter") else "mag_fails_qc"
    if genome_type == "SPGC":
        if flag("passes_filter"):
            return "spgc"
        # SPGC genomes failing the final filter are graded on the sfacts QC flags.
        if flag("passes_geno_positions") and flag("passes_in_sample_list"):
            return "sfacts"
        return "sfacts_fails_qc"
    raise ValueError("Genome did not match classification criteria:", x)

In [None]:
# Strain metadata tables are written per species; concatenate them all so the
# whole set can be analyzed together.
# NOTE: All paths are relative to /pollard/data/projects/bsmith/strain-corr
_species_list = species_list

_per_species_stats = []  # one DataFrame per species with a stats table
missing_species = []

for species in tqdm(_species_list):
    inpath = f"data/group/xjin_ucfmt_hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.strain_meta_for_analysis.tsv"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    data = pd.read_table(inpath, index_col="genome_id").assign(species=species)
    _per_species_stats.append(data)


print(
    len(missing_species),
    "out of",
    len(_species_list),
    "species are missing stats.",
)

if not _per_species_stats:
    # pd.concat([]) would fail with an uninformative "No objects to concatenate".
    raise FileNotFoundError(
        "No per-species strain metadata tables were found; "
        "check the working directory and species group."
    )

filt_stats = pd.concat(_per_species_stats).assign(
    genome_class=lambda x: x.apply(classify_genome, axis=1)
)

In [None]:
# Genome classes in priority order. For a cluster that contains several
# classes, the first class in this list that is present becomes its "highest"
# class (idxmax on a boolean frame returns the first True column).
GENOME_CLASS_PRIORITY = [
    "isolate",
    "mag",
    "spgc",
    "sfacts",
    "isolate_fails_qc",
    "mag_fails_qc",
    "sfacts_fails_qc",
]

d0 = (
    # For each dereplication cluster, tally number of genomes in each "genome class".
    filt_stats.groupby(["species", "clust"])
    .genome_class.value_counts()
    .unstack("genome_class", fill_value=0)
    .assign(
        # For each cluster, find the highest-priority genome class present.
        # reindex() supplies zero counts for classes that never occur in the data;
        # plain column selection would raise a KeyError for them.
        highest_genome_class=lambda x: (
            x.reindex(columns=GENOME_CLASS_PRIORITY, fill_value=0) > 0
        ).idxmax(axis=1)
    )
)
d1 = (
    # For each species, tally the number of clusters with each highest genome class.
    d0.groupby("species")
    .highest_genome_class.value_counts()
    .unstack(fill_value=0)
    # Order columns by priority. A class that is never "highest" for any cluster
    # gets an all-zero column instead of raising a KeyError.
    .reindex(columns=GENOME_CLASS_PRIORITY, fill_value=0)
)

# Append taxonomy.
d2 = d1.join(species_taxonomy).sort_values(list(species_taxonomy.columns))

# Heatmap of cluster counts (clustering disabled), rows annotated by phylum color.
row_colors = d2.p__.map(phylum_palette)
sns.clustermap(
    d2[["isolate", "mag", "spgc", "sfacts"]],
    norm=mpl.colors.SymLogNorm(linthresh=1),
    col_cluster=False,
    row_cluster=False,
    row_colors=row_colors,
);

In [None]:
# Per-species counts of clusters by highest genome class, annotated with the
# full UHGG lineage and ordered by it, then saved as a group-level TSV.
d3 = (
    d1.assign(taxonomy=species_taxonomy["Lineage"])
    .sort_values(by="taxonomy")
)

d3.to_csv(
    path_or_buf="data/group/xjin_ucfmt_hmp2/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc-fit.uhgg-strain_gene.strain_meta_for_analysis.tsv",
    sep="\t",
)
d3