# Preamble

## Project Template

In [None]:
# Load IPython's autoreload extension (this makes the %autoreload magic available).
# NOTE(review): no `%autoreload <mode>` line is visible in this notebook —
# confirm whether automatic reloading of `lib` is actually intended to be active.
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so the relative paths used throughout the notebook
# resolve. An unset PROJECT_ROOT environment variable raises KeyError here.
_project_root = _os.environ["PROJECT_ROOT"]
_os.chdir(_project_root)
# Display the resolved working directory as the cell output.
_os.path.realpath(_os.curdir)

## Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import mpltern
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import pdist, squareform
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

## Set Style

In [None]:
# Figure styling: seaborn "paper" context, rendered at 200 DPI.
sns.set_context("paper")
plt.rcParams.update({"figure.dpi": 200})

In [None]:
# Colour used for each genome type in the figures below.
genome_type_palette = dict(
    SPGC="tab:green",
    MAG="tab:orange",
    Isolate="tab:blue",
    Ref="black",
)

# Data Setup

## Metadata

In [None]:
# Unique species IDs (as strings) assigned to the "hmp2" species group.
_species_group = pd.read_table("meta/species_group.tsv")
_in_hmp2 = _species_group.species_group_id == "hmp2"
species_list = _species_group.loc[_in_hmp2, "species_id"].astype(str).unique()

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a ';'-delimited lineage string into a labelled Series.

    The seven fields are labelled, in order, "d__", "p__", "c__", "o__",
    "f__", "g__" and "s__". pandas raises ValueError when the string does
    not split into exactly seven fields.
    """
    rank_labels = "d__ p__ c__ o__ f__ g__ s__".split()
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_labels)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep only rows where the genome is its own species representative, then build
# the species ID as "1" + the third "-"-separated field of the MGnify accession.
_uhgg_genomes = pd.read_table(species_taxonomy_inpath)
_species_reps = _uhgg_genomes[_uhgg_genomes.Genome == _uhgg_genomes.Species_rep]
_species_ids = "1" + _species_reps.MGnify_accession.str.split("-").str[2]

# One row per species, one column per field of the lineage string.
species_taxonomy = (
    _species_reps.assign(species_id=_species_ids)
    .set_index("species_id")
    .Lineage.apply(parse_taxonomy_string)
)
species_taxonomy

In [None]:
# Phyla in the order used to assign colours and to lay out the figures.
phylum_order = [
    "p__Euryarchaeota",
    "p__Thermoplasmatota",
    "p__Firmicutes",
    "p__Firmicutes_A",
    "p__Firmicutes_C",
    # "p__Firmicutes_B", # None in species_list1
    # "p__Firmicutes_G", # B/G/I not sure how related to C or A
    # "p__Firmicutes_I", #
    # "p__Cyanobacteria", # None in species_list1
    "p__Actinobacteriota",
    "p__Synergistota",
    "p__Fusobacteriota",
    "p__Campylobacterota",
    "p__Proteobacteria",
    "p__Desulfobacterota_A",
    "p__Bacteroidota",
    "p__Verrucomicrobiota",
    # "dummy0", # 18
    # "dummy1", # 19
    # "dummy2", # 20
]

phylum_palette = lib.plot.construct_ordered_palette(
    phylum_order,
    cm="rainbow",
    desaturate_levels=[1.0, 0.5],
)

# Show the colour assigned to each phylum (once; both legends share the palette).
for p__ in phylum_order:
    print(p__, phylum_palette[p__])


def _save_phylum_legend(outpath, figsize, ncols):
    """Draw a stand-alone phylum colour legend and save it to `outpath`.

    Parameters
    ----------
    outpath : str
        File the figure is written to.
    figsize : tuple
        Figure size passed to `plt.figure`.
    ncols : int
        Number of legend columns.

    Returns
    -------
    The created matplotlib figure.
    """
    legend_fig = plt.figure(figsize=figsize, facecolor="none")
    for phylum in phylum_order:
        # Empty scatter: only the legend handle is wanted, no data are drawn.
        plt.scatter(
            [],
            [],
            color=phylum_palette[phylum],
            lw=0.5,
            edgecolor="k",
            label=phylum.replace("p__", ""),
        )
    plt.legend(ncols=ncols, markerscale=1.5, frameon=False)
    lib.plot.hide_axes_and_spines()
    legend_fig.savefig(outpath)
    return legend_fig


# Single-column legend, then a more compact two-column variant.
fig = _save_phylum_legend("fig/fig2_phylum_legend.pdf", figsize=(2, 3), ncols=1)
fig = _save_phylum_legend("fig/fig4_phylum_legend_2col.pdf", figsize=(2, 1.5), ncols=2)



# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

In [None]:
# HMP2 metadata tables, one per level of the sampling hierarchy.
mgen = pd.read_table("meta/hmp2/mgen.tsv", index_col="library_id")
preparation = pd.read_table("meta/hmp2/preparation.tsv", index_col="preparation_id")
stool = pd.read_table("meta/hmp2/stool.tsv", index_col="stool_id")
visit = pd.read_table("meta/hmp2/visit.tsv", index_col="visit_id")
subject = pd.read_table("meta/hmp2/subject.tsv", index_col="subject_id")

# Attach preparation, stool, visit and subject metadata to every library.
_meta_joined = (
    mgen.join(preparation.drop(columns="library_type"), on="preparation_id")
    .join(stool, on="stool_id")
    .join(visit, on="visit_id", rsuffix="_")
    .join(subject, on="subject_id")
)

# Readable sample name "<subject_id>_<week_number>_<library_id>"; a missing
# week number is written as 999.
_name_parts = _meta_joined[["subject_id", "week_number"]].assign(
    library_id=_meta_joined.index,
    week_number=lambda x: x.week_number.fillna(999).astype(int),
)
meta_all = _meta_joined.assign(
    new_name=_name_parts.apply(lambda row: "_".join(row.astype(str)), axis=1)
)

library_id_to_new_name = meta_all.new_name

# Every library must be traceable to a subject.
assert not meta_all.subject_id.isna().any()

# TODO: Rename samples based on subject and visit number
# TODO: Drop duplicate stools

## Species Depth

In [None]:
# Read the per-sample depth table of every species that has one on disk;
# species without a table are recorded in `_missing_species`.
_species_depth_tables = []
_missing_species = []

for species in tqdm(species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gene99_v20-v23-agg75.spgc_specgene-ref-filt-p95.species_depth.tsv"
    if os.path.exists(inpath):
        _species_depth_tables.append(
            pd.read_table(inpath, names=["sample", "depth"]).assign(species=species)
        )
    else:
        _missing_species.append(species)

# Pivot to a sample-by-species matrix; absent combinations get depth 0.
_species_depth_long = pd.concat(_species_depth_tables).set_index(["sample", "species"])
species_depth = _species_depth_long.depth.unstack(fill_value=0)

print(len(_missing_species), "out of", len(species_list), "species are missing.")

In [None]:
depth_thresh = 0.2

# A species is "found" in a sample when its depth exceeds the threshold.
species_found = species_depth.gt(depth_thresh)
# Prevalence: fraction of subjects with the species found in any of their samples.
species_prevalence = (
    species_found.groupby(meta_all.subject_id).any().mean().sort_values(ascending=False)
)

# Relative abundance per sample, set to NaN wherever the species was not found.
_total_depth = species_depth.sum(axis=1)
species_rabund_when_found = species_depth.divide(_total_depth, axis=0).where(
    species_found
)

# Summarise within each subject first, then across subjects.
_rabund_by_subject = species_rabund_when_found.groupby(meta_all.subject_id)
species_mean_rabund_when_found = (
    _rabund_by_subject.mean().mean().sort_values(ascending=False)
)
species_median_rabund_when_found = (
    _rabund_by_subject.median().median().sort_values(ascending=False)
)

# Preview: the 20 most prevalent species with abundance and taxonomy.
(
    species_prevalence.to_frame("prevalence")
    .assign(
        mean_rabund=species_mean_rabund_when_found,
        median_rabund=species_median_rabund_when_found,
    )
    .join(species_taxonomy)
    .head(20)
)

## Strain Statistics

In [None]:
def classify_genome(x):
    """Assign a genome-class label to one row of strain metadata.

    Parameters
    ----------
    x : row-like
        Object exposing ``genome_type`` and the boolean flags
        ``passes_filter`` and ``passes_geno_positions`` as attributes
        (for example a row passed by ``DataFrame.apply(..., axis=1)``).
        ``passes_geno_positions`` is only read for SPGC genomes that fail
        the filter.

    Returns
    -------
    str
        "isolate" / "isolate_fails_qc" for Isolate genomes,
        "mag" / "mag_fails_qc" for MAG genomes, and for SPGC genomes
        "spgc" (passes the filter), "sfacts_only" (fails the filter but
        passes the genotyped-positions check) or "sfacts_fails_qc".

    Raises
    ------
    ValueError
        If ``genome_type`` is not "Isolate", "MAG" or "SPGC".
    """
    # Plain truth-value tests are used rather than `&` / `~`: on a Python bool
    # `~` is integer bitwise-not (`~True == -2`), which is only accidentally
    # correct when combined with `&`.
    genome_type = x.genome_type
    passes_filter = bool(x.passes_filter)

    if genome_type == "Isolate":
        return "isolate" if passes_filter else "isolate_fails_qc"
    if genome_type == "MAG":
        return "mag" if passes_filter else "mag_fails_qc"
    if genome_type == "SPGC":
        if passes_filter:
            return "spgc"
        if x.passes_geno_positions:
            return "sfacts_only"
        return "sfacts_fails_qc"
    raise ValueError("Genome did not match classification criteria:", x)

In [None]:
# Gather the strain metadata table (inferred and reference genomes) of every
# species that has one; species without a table go to `missing_species`.
_filt_stats_tables = []
missing_species = []

_species_list = species_list
# _species_list = ["100003"]

for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv"
    if os.path.exists(inpath):
        _filt_stats_tables.append(
            pd.read_table(inpath).assign(species=species, inpath=inpath)
        )
    else:
        missing_species.append(species)

# Label every genome with its class and index rows by "<species>_<genome_id>".
_filt_stats_all = pd.concat(_filt_stats_tables)
filt_stats = _filt_stats_all.assign(
    genome_class=_filt_stats_all.apply(classify_genome, axis=1),
    species_strain=_filt_stats_all.species + "_" + _filt_stats_all.genome_id,
).set_index("species_strain")

print(len(missing_species), "out of", len(_species_list), "species are missing stats.")

# Analysis

## How many strains and species were detected at each filtering level.

In [None]:
# Define different subsets of the species (and of their inferred SPGC strains).

# Row masks over filt_stats shared by the subsets below.
_is_spgc = filt_stats.genome_type.isin(["SPGC"])
_has_geno = filt_stats.passes_geno_positions
_passes_filter = filt_stats.passes_filter
_sfacts_mask = _has_geno & _is_spgc
_spgc_mask = _passes_filter & _is_spgc


def _strain_ids(strain_mask, species_subset):
    """Index values of the rows selected by `strain_mask` whose species is in `species_subset`."""
    return filt_stats[strain_mask & filt_stats.species.isin(species_subset)].index.values


# Species of the strains selected by each mask, and strains per species.
_sfacts_species = filt_stats[_sfacts_mask].species
_spgc_species = filt_stats[_spgc_mask].species
_spgc_strains_per_species = _spgc_species.value_counts()

# All species:
# species_list
spgc_strain_list = filt_stats[_is_spgc].index.values

# All species with enough positions
species_list0 = filt_stats[_has_geno].species.unique()
spgc_strain_list0 = _strain_ids(_sfacts_mask, species_list0)

# All species with sf strains
species_list1 = _sfacts_species.unique()
spgc_strain_list1 = _strain_ids(_sfacts_mask, species_list1)

# All species with sf strains to talk about distributions (>=10)
species_list1b = idxwhere(_sfacts_species.value_counts() >= 10)
spgc_strain_list1b = _strain_ids(_sfacts_mask, species_list1b)

# All species with spgc strains
species_list2 = _spgc_species.unique()
spgc_strain_list2 = _strain_ids(_spgc_mask, species_list2)

# All species with enough spgc strains for pangenome analysis (>=10)
species_list3 = idxwhere(_spgc_strains_per_species >= 10)
spgc_strain_list3 = _strain_ids(_spgc_mask, species_list3)

# Species with large numbers of strains (>=20)
species_list4 = idxwhere(_spgc_strains_per_species >= 20)
spgc_strain_list4 = _strain_ids(_spgc_mask, species_list4)

# Report, for each subset, the number of species and strains and the phylum
# breakdown of its species.
_species_list_map = {
    "All considered species": (species_list, spgc_strain_list),
    "0: Species with enough genotyped positions": (species_list0, spgc_strain_list0),
    "1: With sfacts strains": (species_list1, spgc_strain_list1),
    "1b: With (>=10) sfacts strains": (species_list1b, spgc_strain_list1b),
    "2: With SPGC inferences": (species_list2, spgc_strain_list2),
    "3: With >=10 inferences": (species_list3, spgc_strain_list3),
    "4: With >=20 inferences": (species_list4, spgc_strain_list4),
}
for _species_list_name, (_species_list, _strain_list) in _species_list_map.items():
    print(_species_list_name, len(_species_list), len(_strain_list))
    print(species_taxonomy.loc[_species_list].p__.value_counts())
    print()

In [None]:
# Supplementary Table 1: per-strain statistics of the inferred strains that
# pass the filter, with the count columns cast to integers.
_table1_columns = [
    "species",
    "num_gene",
    "num_strain_sample",
    "sum_strain_depth",
    "species_gene_frac",
    "log_selected_gene_depth_ratio_std",
    "num_geno_positions",
]
_table1_count_columns = ["num_gene", "num_strain_sample", "num_geno_positions"]
d = filt_stats.loc[spgc_strain_list2, _table1_columns].astype(
    {column: int for column in _table1_count_columns}
)
d.to_csv("fig/hmp2_inferred_strains_supplementary_table1.tsv", sep="\t")
d

In [None]:
# Same per-strain table as Supplementary Table 1.
_strain_stat_columns = [
    "species",
    "num_gene",
    "num_strain_sample",
    "sum_strain_depth",
    "species_gene_frac",
    "log_selected_gene_depth_ratio_std",
    "num_geno_positions",
]
d = filt_stats.loc[spgc_strain_list2, _strain_stat_columns].astype(
    {"num_gene": int, "num_strain_sample": int, "num_geno_positions": int}
)

# Bin edges 0 followed by the powers of two from 1 to 256.
bins = [0, *np.logspace(0, 8, base=2, num=9).astype(int)]

# Print (count, left bin edge) pairs and the quartiles of samples per strain.
_bin_counts, _bin_edges = np.histogram(d.num_strain_sample, bins=bins)
print(list(zip(_bin_counts, _bin_edges)))
print(d.num_strain_sample.quantile([0.25, 0.5, 0.75]))

plt.hist(d.num_strain_sample, bins=bins, histtype='bar')
plt.xlim(0, 256)
# symlog keeps the [0, 1) bin visible on an otherwise logarithmic axis.
plt.xscale('symlog', linthresh=1, linscale=0.2)
plt.xticks(ticks=bins, labels=bins)
plt.xticks(ticks=[], labels=[], minor=True)
plt.xlabel('Num. Strain-Pure Samples')
plt.ylabel('Strains (count)')

### Figure 3E

In [None]:
# Annotations for Figure 3 species tree: number of filter-passing inferred
# strains per species.
_spgc_strain_counts = filt_stats.loc[spgc_strain_list2, "species"].value_counts()
_spgc_strain_counts.to_csv("fig/hmp2_spgc_strain_counts.tsv", sep="\t", header=True)

In [None]:
# Gather the strain dissimilarity table of every species with SPGC strains;
# species without a table go to `missing_species`.
_diss_stats_tables = []
missing_species = []

_species_list = species_list2
# _species_list = ["100003"]

for species in tqdm(_species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_diss_spgc_and_ref.tsv"
    if os.path.exists(inpath):
        _diss_stats_tables.append(
            pd.read_table(inpath).assign(species=species, inpath=inpath)
        )
    else:
        missing_species.append(species)

# Attach the strain metadata by (species, genome_id); metadata columns that
# clash with dissimilarity columns get the "_b" suffix.
_filt_stats_by_genome = filt_stats.set_index(["species", "genome_id"])
_diss_stats_joined = (
    pd.concat(_diss_stats_tables)
    .set_index(["species", "genome_id"])
    .join(_filt_stats_by_genome, rsuffix="_b")
    .reset_index()
)
# Index rows by "<species>_<genome_id>", matching filt_stats.
diss_stats = _diss_stats_joined.assign(
    species_strain=_diss_stats_joined.species + "_" + _diss_stats_joined.genome_id
).set_index("species_strain")

print(len(missing_species), "out of", len(_species_list), "species are missing.")

## How many Archaea (+Bacteria)

In [None]:
# Filter-passing inferred strains whose species belongs to the Archaea.
_spgc_strain_taxonomy = filt_stats.loc[spgc_strain_list2].join(
    species_taxonomy, on="species"
)
_spgc_strain_taxonomy[_spgc_strain_taxonomy.d__ == "d__Archaea"]

In [None]:
# Filter-passing inferred strains whose species belongs to the Bacteria.
_spgc_strain_taxonomy = filt_stats.loc[spgc_strain_list2].join(
    species_taxonomy, on="species"
)
_spgc_strain_taxonomy[_spgc_strain_taxonomy.d__ == "d__Bacteria"]

## Novelty of SPGC/SFacts strains

### Novel Branch Length Fraction

In [None]:
# For every species, compare the total branch length of an average-linkage
# (UPGMA) tree built from reference genomes only against a tree that also
# includes the inferred (SPGC) genomes.
_branch_length_records = []
# Species skipped because an input file is absent or because fewer than two
# reference genomes are available to build the reference-only tree.
missing_species2 = []

for species in tqdm(species_list):
    dmat_path = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.spgc_ss-all.geno_uhgg-v20_pdist-mask10-pseudo10.pkl"
    meta_path = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv"
    if not (os.path.exists(dmat_path) and os.path.exists(meta_path)):
        missing_species2.append(species)
        continue
    _meta = pd.read_table(meta_path, index_col="genome_id")
    dmat_all = lib.dissimilarity.load_dmat_as_pickle(dmat_path)
    without_spgc_list = idxwhere(
        _meta.passes_geno_positions & _meta.genome_type.isin(["Isolate", "MAG"])
    )
    with_spgc_list = idxwhere(
        _meta.passes_geno_positions & _meta.genome_type.isin(["Isolate", "MAG", "SPGC"])
    )
    if len(without_spgc_list) < 2:
        # With zero or one reference genome the condensed distance matrix is
        # empty: linkage() raises on it and the reference-only branch length
        # (the denominator of the later ratio) is undefined.
        missing_species2.append(species)
        continue
    dmat_without_spgc = dmat_all.loc[without_spgc_list, without_spgc_list]
    dmat_with_spgc = dmat_all.loc[with_spgc_list, with_spgc_list]

    # Condensed (upper-triangle) distances, reused for the tree and the mean.
    _cdmat_without_spgc = squareform(dmat_without_spgc)
    _cdmat_with_spgc = squareform(dmat_with_spgc)
    _tree_without_spgc = sp.cluster.hierarchy.linkage(
        _cdmat_without_spgc, method="average"
    )
    _tree_with_spgc = sp.cluster.hierarchy.linkage(_cdmat_with_spgc, method="average")
    _branch_length_records.append(
        dict(
            species=species,
            # Column 2 of a linkage matrix is the merge height of each step.
            length_with_spgc=_tree_with_spgc[:, 2].sum(),
            length_without_spgc=_tree_without_spgc[:, 2].sum(),
            num_ref=len(without_spgc_list),
            num_spgc=len(with_spgc_list) - len(without_spgc_list),
            mean_diss_with_spgc=_cdmat_with_spgc.mean(),
            mean_diss_without_spgc=_cdmat_without_spgc.mean(),
        )
    )

# Explicit columns keep set_index working even when no species was processed.
branch_lengths2 = pd.DataFrame(
    _branch_length_records,
    columns=[
        "species",
        "length_with_spgc",
        "length_without_spgc",
        "num_ref",
        "num_spgc",
        "mean_diss_with_spgc",
        "mean_diss_without_spgc",
    ],
).set_index("species")

In [None]:
# Number of species that the branch-length loop skipped (every entry appended
# to `missing_species2`).
len(missing_species2)

In [None]:
# Change in mean pairwise dissimilarity when inferred strains are added,
# absolute and relative to the reference-only value; the 50 smallest ratios.
_mean_diss_diff = (
    branch_lengths2.mean_diss_with_spgc - branch_lengths2.mean_diss_without_spgc
)
(
    branch_lengths2.assign(
        mean_diss_diff=_mean_diss_diff,
        mean_diss_ratio=_mean_diss_diff / branch_lengths2.mean_diss_without_spgc,
    )
    .sort_values("mean_diss_ratio")
    .head(50)
)

#### Figure 3D

In [None]:
# Figure 3D: per-species relative increase in total tree branch length when
# inferred (SPGC) strains are added to the reference genomes, laid out by phylum.
d = (
    # Take all genomes that have enough geno positions for relatedness estimation
    filt_stats[lambda x: x.passes_geno_positions]
    # Count the number of genomes of each type.
    [["species", "genome_id", "genome_type"]]
    .value_counts()
    .unstack(fill_value=0)
    # Tag each cluster by its "best type"
    .assign(best_genome_type=lambda x: x[["Isolate", "MAG", "SPGC"]].idxmax(1))
    # Count for each species the number of clusters with each tag.
    .groupby("species")
    .best_genome_type.value_counts()
    .unstack(fill_value=0)
    # Branch lengths with/without SPGC strains, plus their difference and ratio.
    .join(
        branch_lengths2.assign(
            branch_length_diff=lambda x: x.length_with_spgc - x.length_without_spgc,
            branch_length_ratio=lambda x: x.branch_length_diff / x.length_without_spgc,
        )
    )
    # Prevalence and relative-abundance summaries per species.
    .join(
        species_prevalence.to_frame("prevalence").assign(
            mean_rabund=species_mean_rabund_when_found,
            median_rabund=species_median_rabund_when_found,
        )
    )
    .join(species_taxonomy)
    # Restrict to species with at least one inferred strain passing the
    # genotyped-positions check.
    .loc[species_list1]
)

# Vertical position of each phylum: phylum_order reversed, so the first phylum
# gets the largest position.
p__meta = pd.DataFrame([], index=phylum_order[::-1]).assign(
    pos=lambda x: np.arange(len(x))
)

# Marker size grows with the square root of the strain count.
s_func = lambda x: 4 * x**0.5

# Fixed seed so the vertical jitter is reproducible.
np.random.seed(42)
fig, ax = plt.subplots(figsize=(3.8, 5), facecolor="none")
# The string arguments below that name columns of `data` are looked up there.
ax.scatter(
    "branch_length_ratio",
    "pos_jitter",
    data=d.join(p__meta, on="p__").assign(
        pos_jitter=lambda x: x.pos + np.random.uniform(-0.35, 0.35, size=len(x)),
        c=lambda x: x.p__.map(phylum_palette),
        s=lambda x: s_func(x["SPGC"]),  # Total SPGC genomes
    ),
    # c='c',
    marker="o",
    s="s",
    facecolors="c",
    edgecolors="k",
    lw=0.25,
    alpha=1.0,
    label="__nolegend__",
)
ax.set_yticks(p__meta.pos.unique())
ax.set_yticklabels(p__meta.index.to_series().str.replace("p__", ""))
# for ytick, p__, c in zip(ax.get_yticklabels(), p__meta.index.to_series(), p__meta.index.to_series().map(phylum_palette)):
#     print(ytick, p__, c)
#     ytick.set_color(c)


ax.set_xlabel("Branch Length Increase")
# symlog: linear within +/-0.01 (so zero can be shown), logarithmic beyond.
ax.set_xscale("symlog", linthresh=1e-2, linscale=0.2)
ax.set_xticks([0, 0.01, 0.1, 1.0, 3.0])
ax.set_xticklabels(["0%", "1%", "10%", "100%", "300%"])
ax.set_ylabel("Phylum")
ax.set_xlim(-0.01, 6.0)  # NOTE: Limits are set explicitly; points with a ratio above 6.0 (600%) fall outside the axes.

# # sns.stripplot(x='p__', y='Ref_branch_frac', data=d)

# Empty scatters provide size-legend handles for 1, 10 and 100 strains.
for count in [1, 10, 100]:
    ax.scatter(
        [],
        [],
        edgecolor="black",
        lw=0.25,
        facecolor="grey",
        s=s_func(count),
        label=count,
    )
plt.legend(title="Num.\ninferred\nstrains", fontsize=8, bbox_to_anchor=(0.755, 0.55))

plt.savefig("fig/fig3d_branch_length.pdf", bbox_inches="tight")

In [None]:
# Per-species number of genomes of each type, among genomes with enough
# genotyped positions for relatedness estimation.
_genome_type_counts = (
    filt_stats[lambda x: x.passes_geno_positions][
        ["species", "genome_id", "genome_type"]
    ]
    .value_counts()
    .unstack(fill_value=0)
    # Tag each genome by the type with the highest count.
    .assign(best_genome_type=lambda x: x[["Isolate", "MAG", "SPGC"]].idxmax(1))
    .groupby("species")
    .best_genome_type.value_counts()
    .unstack(fill_value=0)
)
# Branch lengths with/without SPGC strains, plus their difference and ratio.
_branch_length_stats = branch_lengths2.assign(
    branch_length_diff=lambda x: x.length_with_spgc - x.length_without_spgc,
    branch_length_ratio=lambda x: x.branch_length_diff / x.length_without_spgc,
)
_abundance_stats = species_prevalence.to_frame("prevalence").assign(
    mean_rabund=species_mean_rabund_when_found,
    median_rabund=species_median_rabund_when_found,
)
d = (
    _genome_type_counts.join(_branch_length_stats)
    .join(_abundance_stats)
    .join(species_taxonomy)
    .loc[species_list1]
)

# Branch-length increase against the number of reference genomes, coloured by
# the number of inferred strains.
plt.scatter(
    d.num_ref, d.branch_length_ratio, c=d.num_spgc, norm=mpl.colors.PowerNorm(1 / 3)
)
plt.colorbar(label="Num. inferred strains")
plt.yscale("symlog", linthresh=1e-2)
plt.xscale("log")
plt.xlabel("Num. reference strains")
plt.ylabel("Branch Length Increase")

# The 40 species with the largest relative branch-length increase.
_display_columns = [
    "Isolate",
    "MAG",
    "SPGC",
    "branch_length_ratio",
    "prevalence",
    "mean_rabund",
    "median_rabund",
    "d__",
    "p__",
    "c__",
    "o__",
    "f__",
    "g__",
    "s__",
]
d[_display_columns].sort_values("branch_length_ratio", ascending=False).head(40)

In [None]:
# Relative branch-length increase per species, and how many species exceed
# each threshold.
_branch_length_diff = (
    branch_lengths2.length_with_spgc - branch_lengths2.length_without_spgc
)
d = (_branch_length_diff / branch_lengths2.length_without_spgc).rename(
    "branch_length_ratio"
)
for thresh in [0.1, 0.2, 0.5]:
    print(thresh, d.gt(thresh).sum())

In [None]:
species = "101493"

# Strain metadata, SPGC fit and pairwise genotype dissimilarities for this species.
spgc_meta = pd.read_table(
    f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv",
    index_col="genome_id",
)
spgc_data = xr.load_dataset(
    f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v20-v23-agg75.spgc-fit.nc"
)
mgtp_diss_all = lib.dissimilarity.load_dmat_as_pickle(
    f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.spgc_ss-all.geno_uhgg-v20_pdist-mask10-pseudo10.pkl"
)

# Keep only genomes with enough genotyped positions, then cluster them by
# average linkage.
_keep_genome = spgc_meta.passes_geno_positions
mgtp_diss = mgtp_diss_all.loc[_keep_genome, _keep_genome]
mgtp_linkage = sp.cluster.hierarchy.linkage(
    squareform(mgtp_diss), method="average", optimal_ordering=False
)

# Annotation strips: genome type, and whether the genome passes the filter.
_colors = pd.DataFrame(
    {
        "type": spgc_meta.genome_type.map(genome_type_palette),
        "filt": spgc_meta.passes_filter.map({True: "black", False: "grey"}),
    }
)
sns.clustermap(
    mgtp_diss,
    row_linkage=mgtp_linkage,
    col_linkage=mgtp_linkage,
    row_colors=_colors,
    col_colors=_colors,
    figsize=(10, 10),
    xticklabels=0,
    yticklabels=0,
)

In [None]:
w = sf.data.World.load(
    f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.world.nc"
)
# Plot a random subsample of 1500 positions.
position_ss = w.random_sample(position=1500).position
_world_ss = w.sel(position=position_ss)


def _metagenotype_linkage(w):
    """Column linkage taken from the world's metagenotype."""
    return w.metagenotype.linkage()


sf.plot.plot_metagenotype(_world_ss, col_linkage_func=_metagenotype_linkage)
sf.plot.plot_community(_world_ss, col_linkage_func=_metagenotype_linkage)

### Genotype Dissimilarity to Closest Reference

In [None]:
# This inlined calculation of genotype dissimilarity is necessary because the diss_stats are incomplete where no strains pass filtering.
_min_geno_diss_by_species = {}
missing_species = []

_species_list = species_list1  # NOT Species list #2, here we're looking at the overall diversity of genotypes identified.
# _species_list = ["100003"]

for species in tqdm(_species_list):
    # Genomes of this species with enough genotyped positions, split into
    # inferred (SPGC) strains and reference (Isolate/MAG) genomes.
    _species_genomes = filt_stats[
        (filt_stats.species == species) & filt_stats.passes_geno_positions
    ]
    _spgc_list = _species_genomes[
        _species_genomes.genome_type.isin(["SPGC"])
    ].genome_id
    _ref_list = _species_genomes[
        _species_genomes.genome_type.isin(["Isolate", "MAG"])
    ].genome_id
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.spgc_ss-all.geno_uhgg-v20_pdist-mask10-pseudo10.pkl"
    if not os.path.exists(inpath):
        missing_species.append(species)
        continue
    # Rows are references and columns are inferred strains, so the column-wise
    # minimum is each strain's dissimilarity to its closest reference.
    _ref_by_spgc_diss = load_dmat_as_pickle(inpath).loc[_ref_list, _spgc_list]
    _min_geno_diss_by_species[species] = (
        _ref_by_spgc_diss.min().to_frame().assign(species=species)
    )

# Stack all species into one Series indexed by (species, strain).
min_geno_diss = (
    pd.concat(_min_geno_diss_by_species.values())
    .rename(columns={0: "min_geno_diss"})
    .rename_axis(index="strain")
    .reset_index()
    .set_index(["species", "strain"])
    .min_geno_diss
)

In [None]:
# Distribution of each inferred strain's dissimilarity to its closest
# reference: dashed line at 0.05, solid line at the median.
bins = np.logspace(-5, 0, num=50)
_median_min_geno_diss = min_geno_diss.median()

plt.hist(min_geno_diss, bins=bins)
plt.axvline(0.05, lw=1, linestyle="--", color="k")
plt.axvline(_median_min_geno_diss, lw=1, linestyle="-", color="k")

plt.xscale("log")

# Median, and the fraction of strains more than 0.05 from every reference.
print(_median_min_geno_diss, min_geno_diss.gt(0.05).mean())

### Gene Content Dissimilarity to Closest Reference

In [None]:
# Median of `ref_nn_uhgg_diss` over the filter-passing inferred strains
# (spgc_strain_list2); strains absent from diss_stats or with a missing value
# are dropped first.
diss_stats.ref_nn_uhgg_diss.reindex(spgc_strain_list2).dropna().median()

## Relationship between Genotype and Gene Content Dissimilarity

### Overall

In [None]:
# Inferred strains of species with >= 10 strains that have both dissimilarity
# measures available.
d0 = (
    diss_stats.reindex(spgc_strain_list3)  # >= 10 strains
    .dropna(subset=["min_ref_geno_diss", "ref_nn_uhgg_diss"])
    .assign(log_min_ref_geno_diss=lambda x: np.log10(x.min_ref_geno_diss))
)

# Per species: one-sided (positive) Spearman correlation between genotype and
# gene content dissimilarity, the spread of genotype dissimilarity, and the
# number of strains.
# NOTE(review): this reads `min_ref_eggnog_diss`, whereas the other cells of
# this notebook read `ref_nn_eggnog_diss` — confirm the column name.
_diss_corr_records = []
for species in d0.species.unique():
    d1 = d0[d0.species == species]
    _rho, _pvalue = sp.stats.spearmanr(
        d1.min_ref_geno_diss, d1.min_ref_eggnog_diss, alternative="greater"
    )
    _diss_corr_records.append(
        {
            "species": species,
            "spearmanr": _rho,
            "pvalue": _pvalue,
            "geno_diss_stdev": d1.min_ref_geno_diss.std(),
            "num_strains": len(d1),
        }
    )

diss_corr = pd.DataFrame(
    _diss_corr_records,
    columns=["species", "spearmanr", "pvalue", "geno_diss_stdev", "num_strains"],
)
diss_corr.spearmanr.quantile([0.05, 0.25, 0.5, 0.75, 0.95])

In [None]:
# Per-species correlation against the spread of genotype dissimilarity; colour
# encodes the p-value (log scale) and marker size the number of strains.
_scatter_style = dict(
    c="pvalue",
    s="num_strains",
    norm=mpl.colors.LogNorm(),
    cmap="viridis_r",
)
plt.scatter("geno_diss_stdev", "spearmanr", data=diss_corr, **_scatter_style)
plt.colorbar()

# Is the correlation strength itself related to the spread?
sp.stats.spearmanr(diss_corr.geno_diss_stdev, diss_corr.spearmanr)

In [None]:
# Number of species whose per-species correlation has p < 0.05 (no
# multiple-testing correction is applied here).
(diss_corr.pvalue < 0.05).sum()

### E. coli

In [None]:
# Filter-passing inferred strains that have both dissimilarity measures.
d0 = (
    diss_stats.reindex(spgc_strain_list2)
    .dropna(subset=["min_ref_geno_diss", "ref_nn_uhgg_diss"])
    .assign(log_min_ref_geno_diss=lambda x: np.log10(x.min_ref_geno_diss))
)
# Restrict to species 102506 (the species this section's header names E. coli).
d1 = d0[d0.species == "102506"]
print(len(d1))
plt.scatter(d1.log_min_ref_geno_diss, d1.ref_nn_eggnog_diss)
sp.stats.spearmanr(d1.min_ref_geno_diss, d1.ref_nn_eggnog_diss)

### Figure 3F

In [None]:
import matplotlib.ticker as mtick

# Figure 3F: joint distribution of genotype dissimilarity (log10) and gene
# content dissimilarity to the closest reference, for reference genomes
# (2-D histogram) and inferred strains (density contours), with marginal
# histograms on the top and right.
d = diss_stats[lambda x: x.passes_filter].assign(
    log_min_ref_geno_diss=lambda x: np.log10(x.min_ref_geno_diss)
)
d_ref = d[lambda x: x.genome_type.isin(["Isolate", "MAG"])]
d_spgc = d[lambda x: x.genome_type.isin(["SPGC"])]
focal_species, focal_species_label = "102506", "E. coli"
d_focal = d_spgc[lambda x: x.species == focal_species]
x_bins = np.linspace(-6, 0, num=30)
y_bins = np.linspace(0, 1)


def _plot_marginal_hists(ax, column, bins, orientation):
    """Draw outlined and filled density histograms of `column` for references, then inferred strains.

    Each group gets a "step" outline (kept out of the legend) followed by a
    translucent "stepfilled" body carrying the group's legend label.
    """
    groups = [(d_ref, "Ref", "Reference"), (d_spgc, "SPGC", "Novel")]
    for group_data, palette_key, group_label in groups:
        group_color = genome_type_palette[palette_key]
        ax.hist(
            column,
            data=group_data,
            orientation=orientation,
            bins=bins,
            histtype="step",
            alpha=0.8,
            lw=1,
            color=group_color,
            density=True,
            label="__nolegend__",
        )
        ax.hist(
            column,
            data=group_data,
            orientation=orientation,
            bins=bins,
            histtype="stepfilled",
            alpha=0.5,
            color=group_color,
            density=True,
            label=group_label,
        )


fig, axs = plt.subplots(
    2,
    2,
    gridspec_kw=dict(
        width_ratios=(3, 0.5), height_ratios=(0.5, 3), hspace=0.15, wspace=0.15
    ),
    sharex="col",
    sharey="row",
    figsize=(4, 4),
)

# Main panel: references as a 2-D histogram, inferred strains as contours, and
# the focal species as individual points.
axs[1, 0].hist2d(
    "log_min_ref_geno_diss",
    "ref_nn_eggnog_diss",
    data=d_ref,
    bins=(x_bins, y_bins),
    norm=mpl.colors.PowerNorm(1 / 2, vmax=3000),
    cmap="binary",
)
sns.kdeplot(
    x="log_min_ref_geno_diss",
    y="ref_nn_eggnog_diss",
    data=d_spgc,
    color=genome_type_palette["SPGC"],
    ax=axs[1, 0],
    linewidths=1,
    levels=np.linspace(0, 1, num=11),
    alpha=0.7,
)
axs[1, 0].scatter(
    "log_min_ref_geno_diss",
    "ref_nn_eggnog_diss",
    data=d_focal,
    color="darkgreen",
    s=20,
    edgecolor="white",
    linewidth=0.5,
    label="__nolegend__",
)
axs[1, 0].set_xlabel("Genotype Dissimilarity")
axs[1, 0].set_ylabel("Gene Content Dissimilarity")
axs[1, 0].set_ylim(0, 0.39)

# Legend: empty proxy artists; the sketch style makes the line handle wavy so
# it resembles a contour.
with mpl.rc_context({"path.sketch": (5, 50, 1)}):
    axs[1, 0].plot([], [], lw=1, label="All Novel", color=genome_type_palette["SPGC"])
axs[1, 0].scatter(
    [],
    [],
    color="darkgreen",
    s=20,
    edgecolor="silver",
    linewidth=0.5,
    label=focal_species_label,
)
axs[1, 0].legend(loc="upper left")

# Right margin: gene content dissimilarity.
_plot_marginal_hists(axs[1, 1], "ref_nn_eggnog_diss", y_bins, "horizontal")
axs[1, 1].set_xlabel("Strains (density)")

# Top margin: genotype dissimilarity.
_plot_marginal_hists(axs[0, 0], "log_min_ref_geno_diss", x_bins, "vertical")

# The x data are log10 values; label the ticks as powers of ten.
axs[1, 0].xaxis.set_major_formatter(mtick.FormatStrFormatter("$10^{%d}$"))

lib.plot.hide_axes_and_spines(axs[0, 1])

# Correlation within the focal species. Printed explicitly: a bare expression
# in the middle of a cell is evaluated but never displayed.
focal_spearman = sp.stats.spearmanr(
    d_focal.min_ref_geno_diss,
    d_focal.ref_nn_eggnog_diss,
)
print(focal_species_label, focal_spearman)

fig.savefig("fig/fig3f_closest_reference.pdf", bbox_inches="tight")

In [None]:
# Stand-alone legends for Figure 3F: contour/point entries on ax1 and
# histogram entries on its twin ax2, drawn from empty proxy artists.
fig, ax1 = plt.subplots(figsize=(3, 3), facecolor="none")

ax2 = ax1.twinx()

# Legend
with mpl.rc_context({"path.sketch": (5, 50, 1)}):
    ax1.plot([], [], lw=2, label="All Novel", color=genome_type_palette["SPGC"])
ax1.scatter(
    [],
    [],
    color="darkgreen",
    s=150,
    edgecolor="silver",
    linewidth=2,
    label=focal_species_label,
)
ax1.legend(loc="upper left", facecolor="none")
lib.plot.hide_axes_and_spines(ax=ax1)

# One empty filled histogram per genome group, purely for its legend handle.
for _palette_key, _legend_label in [("Ref", "Reference"), ("SPGC", "Novel")]:
    ax2.hist(
        [],
        histtype="stepfilled",
        alpha=0.5,
        color=genome_type_palette[_palette_key],
        label=_legend_label,
    )
ax2.legend(loc="lower left", facecolor="none")
lib.plot.hide_axes_and_spines(ax=ax2)

In [None]:
# Stand-alone horizontal colorbar: an empty 2-D histogram supplies the mappable.
fig = plt.figure(figsize=(4, 2), facecolor="none")
_count_norm = mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=3000)
x, _, _, _ = plt.hist2d([], [], norm=_count_norm, cmap="binary")
ax2 = plt.colorbar(orientation="horizontal", label="Reference Strains (count)")
lib.plot.hide_axes_and_spines()
lib.plot.rotate_xticklabels(ax2.ax, rotation=30)
# plt.tight_layout()
# plt.tight_layout()

In [None]:
# Stand-alone vertical colorbar with three ticks, saved as SVG; an empty 2-D
# histogram supplies the mappable.
fig = plt.figure(figsize=(2, 4), facecolor="none")
_count_norm = mpl.colors.PowerNorm(1 / 2, vmin=0, vmax=3000)
x, _, _, _ = plt.hist2d([], [], norm=_count_norm, cmap="binary")
ax2 = plt.colorbar(ticks=[0, 500, 2500], label="Reference Strains (count)")
lib.plot.hide_axes_and_spines()
lib.plot.rotate_xticklabels(ax2.ax, rotation=30)
# plt.tight_layout()

_colorbar_outpath = "fig/spgc_genotype_gene_content_joint_distributions_colorbar.svg"
fig.savefig(_colorbar_outpath, bbox_inches="tight")