In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Work from the project root so the relative data/, ref/ and meta/ paths below resolve.
_os.chdir(_os.environ['PROJECT_ROOT'])
# Show the resolved working directory as a sanity check.
_os.path.realpath(_os.getcwd())

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os

In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

In [None]:
# Larger "talk" fonts, rendered at a low DPI so inline figures stay compact.
sns.set_context('talk')
mpl.rcParams.update({'figure.dpi': 50})

In [None]:
# Analysis parameters: which sample group, species, and pipeline outputs to load.
group = 'hmp2'
species = '102492'
stemA = 'r.proc'
stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts37-s85-seed0'
spgc_params = 'e100'
# Gene-clustering level; also selects the `centroid_{centroid}` column of cluster_info.
centroid = 75

# All input files used below, keyed by a short name.
# NOTE(review): cluster_info lives under ref/midasdb_uhgg/pangenomes while
# reference_copy_number uses ref/midasdb_uhgg_pangenomes — confirm both are intended.
path = {
    "flag": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_files.flag",
    "fit": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.world.nc",
    "strain_correlation": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_correlation.tsv",
    "strain_depth_ratio": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_depth_ratio.tsv",
    "strain_fraction": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.comm.tsv",
    "species_gene_mean_depth": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.midas_gene{centroid}.spgc-{spgc_params}.species_depth.tsv",
    "species_gtpro_depth": f"data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    "species_correlation": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.midas_gene{centroid}.spgc-{spgc_params}.species_correlation.tsv",
    "strain_thresholds": f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_gene_threshold.tsv",
    "gene_annotations": f"ref/midasdb_uhgg_gene_annotations/sp-{species}.gene{centroid}_annotations.tsv",
    "raw_gene_depth": f"data/group/{group}/species/sp-{species}/{stemA}.midas_gene{centroid}.depth.nc",
    "reference_copy_number": f"ref/midasdb_uhgg_pangenomes/{species}/midas_gene{centroid}.reference_copy_number.nc",
    "cluster_info": f"ref/midasdb_uhgg/pangenomes/{species}/cluster_info.txt",
    "species_taxonomy": "ref/gtpro/species_taxonomy_ext.tsv",
    "midasdb_genomes": "ref/uhgg_genomes_all_4644.tsv",
    "gtpro_reference_genotype": f"data/species/sp-{species}/gtpro_ref.mgtp.nc",
}

# Fail fast when any input is missing; the assertion message lists the missing paths.
path_exists = {_fpath: os.path.exists(_fpath) for _fpath in path.values()}

assert all(path_exists.values()), [_fpath for _fpath, _ok in path_exists.items() if not _ok]

In [None]:
# Species taxonomy table; display the focal species' entry.
species_taxonomy = lib.thisproject.data.load_species_taxonomy(
    path["species_taxonomy"]
)
species_taxonomy.loc[species]

In [None]:
all_species_gtpro_depth = lib.thisproject.data.load_species_depth(
    path["species_gtpro_depth"]
)
# Relative abundance: divide each row by its total depth across all species.
all_species_gtpro_rabund = all_species_gtpro_depth.div(
    all_species_gtpro_depth.sum(axis=1), axis=0
)

# Distribution of the focal species' relative abundance (log-scaled counts).
plt.hist(all_species_gtpro_rabund[species], bins=np.linspace(0, 1, num=101))
plt.yscale('log');

In [None]:
# Focal species depth on 51 log-spaced bin edges from 10**0 to 10**3.
plt.hist(
    all_species_gtpro_depth[species],
    bins=np.logspace(0, 3, num=51),
)
plt.xscale('log');

In [None]:
focal_species_core_depth = lib.thisproject.data.load_single_species_depth(
    path["species_gene_mean_depth"]
)

# Per-sample comparison of the GT-Pro depth with the gene-based depth for the focal species.
d = pd.DataFrame({
    'gtpro': all_species_gtpro_depth[species],
    'gene': focal_species_core_depth,
})

plt.scatter('gtpro', 'gene', data=d, alpha=0.1)
# symlog keeps zero-depth samples on the plot while log-scaling everything above 1e-4.
plt.yscale('symlog', linthresh=1e-4)
plt.xscale('symlog', linthresh=1e-4);

In [None]:
gene_depth = xr.load_dataarray(path["raw_gene_depth"])

# log10 gene depth in the first sample; the 1e-5 pseudocount keeps zero depths finite.
plt.hist(np.log10(gene_depth.isel(sample=0) + 1e-5), bins=50)
plt.yscale('log');

In [None]:
# Per-sample species correlation (headerless two-column table), squeezed to a Series.
species_corr = pd.read_table(
    path["species_correlation"],
    names=['sample', 'correlation'],
    index_col='sample',
).squeeze()

plt.hist(species_corr, bins=np.linspace(0, 1, num=101))
plt.yscale('log');

In [None]:
# HMP2 metadata tables.
# NOTE(review): these paths are hard-coded to hmp2 rather than derived from `group`.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

# Attach preparation -> stool -> subject metadata to each library, keeping only the
# samples present in the GT-Pro depth table (in that order).
sample_meta = (
    mgen
    .join(preparation, on='preparation_id', rsuffix='_')
    .join(stool, on='stool_id')
    .join(subject, on='subject_id')
    .loc[all_species_gtpro_depth.index]
)

In [None]:
# Per-sample species depth read directly from the same file as focal_species_core_depth.
species_depth = pd.read_table(
    path["species_gene_mean_depth"],
    names=['sample', 'depth'],
    index_col='sample',
).squeeze()

In [None]:
# Fitted strain model, dropping strains below the 0.05 low-abundance threshold
# (exact criterion defined by sfacts' drop_low_abundance_strains).
fit = sf.World.load(path["fit"]).drop_low_abundance_strains(0.05)
print(fit.sizes)

# Reproducible subsample of at most 1000 positions, used for plotting below.
np.random.seed(0)
position_ss = fit.random_sample(
    position=min(fit.sizes['position'], 1000)
).position

In [None]:
# Reference genotypes from the GT-Pro reference metagenotype, relabelled to UHGG ids and
# restricted to the positions present in the fit.
ref_geno = sf.Genotype(
    sf.data.Metagenotype.load(path["gtpro_reference_genotype"])
    .to_estimated_genotype()
    .to_series()
    # Widen so the row labels can be renamed; presumably rows are genomes (strain) and
    # columns are positions — TODO confirm the dim order of to_series().
    .unstack()
    # rename(callable) acts on the row index: replace the first len('GUT_GENOME')
    # characters of each label with 'UHGG'.
    .rename(lambda s: 'UHGG' + s[len('GUT_GENOME'):])
    .stack()
    .to_xarray()
    .sel(position=fit.position)
)

In [None]:
# Map each sample to the strain whose community fraction exceeds 0.95 in it; samples
# without such a strain are dropped. If several exceed 0.95, the first listed is taken.
sample_to_strain = (
    (fit.community.data > 0.95)
    .to_series()
    .unstack()
    .apply(idxwhere, axis=1)                        # per row: labels above threshold
    .loc[lambda strains: strains.apply(bool)]       # keep rows with a non-empty list
    .str[0]
    .rename('strain')
)
    
# Inverse mapping: strain -> list of samples it dominates.
strain_to_sample_list = (
    sample_to_strain
    .rename('strain_id')
    .reset_index()
    .groupby('strain_id')['sample']
    .apply(lambda samples: samples.to_list())
    .rename(None)
)
# Strains dominating the most samples.
strain_to_sample_list.apply(len).sort_values(ascending=False).head()

In [None]:
# Per-strain gene-calling thresholds, renamed to the column names used throughout below.
strain_thresholds = pd.read_table(path["strain_thresholds"], index_col='strain').rename(
    columns={
        'correlation_strict': 'corr_threshold_strict',
        'correlation_moderate': 'corr_threshold_moderate',
        'correlation_lenient': 'corr_threshold_lenient',
        'depth_high': 'depth_thresh_high',
        'depth_low': 'depth_thresh_low',
    }
)

# Std, max and sum of the cubed species depth over the samples each strain dominates.
# Computed once instead of three separate cube + groupby passes.
# NOTE(review): depth is cubed before aggregation; the rationale is not stated here.
_cubed_depth_by_strain = (
    (species_depth ** 3)
    .groupby(sample_to_strain)
    .agg(['std', 'max', 'sum'])
    .rename(columns={'std': 'depth_stdev', 'max': 'depth_max', 'sum': 'depth_sum'}, index=int)
)

# One row per strain: thresholds, genotype entropy, mean metagenotype entropy of its
# samples, sample count, depth summaries, and a power index (depth_stdev * sqrt(n),
# 0 where undefined).
_strain_meta = (
    strain_thresholds
    .join(fit.genotype.entropy().to_series().rename('genotype_entropy'))
    .join(
        fit.metagenotype.entropy().to_series()
        .rename('metagenotype_entropy')
        .groupby(sample_to_strain).mean()
        .rename(int)
    )
    .join(strain_to_sample_list.apply(len).rename('num_samples'))
    .join(_cubed_depth_by_strain)
    .assign(power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_samples)).fillna(0))
)
strain_meta = _strain_meta


# Strains kept for gene-content analysis: power index above threshold and genotype
# entropy below threshold.
power_index_thresh = 200
genotype_entropy_thresh = 0.2

high_power_strain_list = idxwhere(
    (strain_meta.power_index > power_index_thresh)
    & (strain_meta.genotype_entropy < genotype_entropy_thresh)
)
print(len(high_power_strain_list))
highest_power_strain_list = strain_meta.sort_values('power_index', ascending=False).head(3).index

plt.scatter(strain_meta.power_index, strain_meta.corr_threshold_moderate, c=strain_meta.genotype_entropy, alpha=0.5)
plt.axvline(power_index_thresh, lw=1, linestyle='--', color='k')
plt.colorbar(label='genotype_entropy')
plt.xscale('log')
plt.xlabel('power_index')
plt.ylabel('corr_threshold_moderate');

In [None]:
# Top-20 strains by power index, flagging those selected as high power.
(
    strain_meta
    .assign(high_power=strain_meta.index.isin(high_power_strain_list))
    .sort_values('power_index', ascending=False)
    .head(20)
)

In [None]:
# Color per high-power strain, used to annotate samples in the sfacts plots below.
high_power_strain_palette = lib.plot.construct_ordered_palette(
    high_power_strain_list,
    mpl.cm.Spectral,
)

In [None]:
# Gene x strain tables. Missing correlations become 0; missing depth ratios stay NaN.
strain_corr = (
    pd.read_table(path["strain_correlation"], index_col=['gene_id', 'strain'])
    .squeeze()
    .unstack('strain', fill_value=0)
)
strain_depth = (
    pd.read_table(path["strain_depth_ratio"], index_col=['gene_id', 'strain'])
    .squeeze()
    .unstack()
)
# Restrict both tables to shared labels: first with align_indexes' default axis
# (presumably rows), then on the strain columns.
strain_corr, strain_depth = align_indexes(strain_corr, strain_depth)
strain_corr, strain_depth = align_indexes(strain_corr, strain_depth, axis="columns")

In [None]:
reference_copy_number = xr.load_dataarray(path["reference_copy_number"])

# Boolean gene (rows) x reference genome (columns) table: True where the copy number is positive.
reference_hit = pd.DataFrame(
    (reference_copy_number > 0).T,
    index=reference_copy_number.gene_id,
    columns=reference_copy_number.genome_id,
)

In [None]:
_strain_list = high_power_strain_list

fig, axs = lib.plot.subplots_grid(ncols=3, naxes=len(_strain_list), ax_width=5, ax_height=4, sharex=True)

# One panel per strain: (1 - correlation) vs depth ratio per gene, with the moderate
# correlation threshold and the low depth threshold drawn as dashed lines.
for strain, ax in zip(_strain_list, axs.flatten()):
    ax.scatter(1 - strain_corr[strain], strain_depth[strain], s=1, alpha=0.05)
    ax.axvline(1 - strain_meta.loc[strain, 'corr_threshold_moderate'], color='k', linestyle='--')
    ax.axhline(strain_meta.loc[strain, 'depth_thresh_low'], color='k', linestyle='--')
    # symlog x, then inverted so higher correlation sits toward the right.
    ax.set_xscale('symlog', linthresh=1e-4)
    ax.set_xlim(left=-1e-5, right=1)
    ax.invert_xaxis()
    ax.set_yscale('log')
    ax.set_title(strain)
fig.tight_layout()

In [None]:
# Boolean gene x strain masks. Each comparison aligns the per-strain thresholds in
# strain_meta with the strain columns.
strict_corr_hit = strain_corr > strain_meta.corr_threshold_strict
lenient_corr_hit = strain_corr > strain_meta.corr_threshold_lenient
moderate_corr_hit = strain_corr > strain_meta.corr_threshold_moderate
low_corr = strain_corr < strain_meta.corr_threshold_lenient

low_depth = strain_depth < strain_meta.depth_thresh_low
# Compare directly instead of negating low_depth: strain_depth may contain NaN (it is
# unstacked without a fill value), and ~(NaN < t) is True, which would count genes with
# an unknown depth ratio as depth hits.
depth_hit = strain_depth >= strain_meta.depth_thresh_low
high_depth = strain_depth > strain_meta.depth_thresh_high

# Combined calls.
high_confidence_hit = depth_hit & strict_corr_hit
moderate_hit = depth_hit & moderate_corr_hit
maybe_hit = depth_hit & lenient_corr_hit
low_depth_hit = low_depth & strict_corr_hit
high_depth_hit = high_depth & strict_corr_hit
# Exactly one of the depth and strict-correlation criteria is met.
ambiguous_hit = depth_hit ^ strict_corr_hit
high_confidence_not_hit = low_depth & low_corr

In [None]:
# Per strain: number of moderate-confidence gene hits, summed depth ratio over those
# hits, and their ratio.
d = pd.DataFrame(dict(
    tally_moderate_hits=moderate_hit.sum(),
    sum_gene_ratio=strain_depth[moderate_hit].sum(),
)).assign(ratio=lambda x: x.sum_gene_ratio / x.tally_moderate_hits)

# Genes per reference genome vs genes called per high-power inferred strain.
fig, ax = plt.subplots(figsize=(15, 5))
bins = np.linspace(0, 10000, num=21)
ax.hist(reference_hit.sum(), bins=bins, density=True, alpha=0.5, label='reference genomes')
ax.hist(d.tally_moderate_hits.loc[high_power_strain_list], bins=bins, density=True, alpha=0.5, label='inferred high-power strains')
ax.set_xlabel('number of genes')
ax.set_ylabel('density')
ax.legend()

d.loc[high_power_strain_list].sort_values('ratio')

In [None]:
# Per strain: number of moderate-confidence gene hits, summed depth ratio over those
# hits, and their ratio.
d = pd.DataFrame(dict(
    tally_moderate_hits=moderate_hit.sum(),
    sum_gene_ratio=strain_depth[moderate_hit].sum(),
)).assign(ratio=lambda x: x.sum_gene_ratio / x.tally_moderate_hits)

# Total gene copy number per reference genome vs summed depth ratio per high-power strain.
fig, ax = plt.subplots(figsize=(15, 5))
bins = np.linspace(0, 10000, num=21)
ax.hist(reference_copy_number.sum("gene_id"), bins=bins, density=True, alpha=0.5, label='reference genomes (total copy number)')
ax.hist(d.sum_gene_ratio.loc[high_power_strain_list], bins=bins, density=True, alpha=0.5, label='inferred high-power strains (summed depth ratio)')
ax.set_xlabel('per-genome total')
ax.set_ylabel('density')
ax.legend();

In [None]:
# Summed community fraction of the high-power strains in each sample (computed once).
_high_power_frac = fit.community.data.sel(strain=high_power_strain_list).sum("strain").to_series()
samples_with_high_power_strains = idxwhere(_high_power_frac > 0.5)
# <= makes the two lists complementary; with < a fraction of exactly 0.5 fell in neither.
samples_without_high_power_strains = idxwhere(_high_power_frac <= 0.5)
len(samples_with_high_power_strains), len(samples_without_high_power_strains)

In [None]:
# High-power strain genotypes on the subsampled positions; strains ordered by genotype
# linkage, positions by metagenotype linkage.
sf.plot.plot_genotype(
    fit.sel(strain=high_power_strain_list, position=position_ss),
    row_linkage_func=lambda world: world.genotype.linkage("strain"),
    col_linkage_func=lambda world: world.metagenotype.linkage("position"),
)

In [None]:
# Metagenotypes of samples dominated by high-power strains, on the subsampled positions.
sf.plot.plot_metagenotype(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss),
    row_linkage_func=lambda world: world.metagenotype.linkage("position"),
    col_linkage_func=lambda world: world.metagenotype.linkage(),
    # Color each sample by its dominant strain; samples without a mapped color get NaN.
    col_colors=(
        fit.sel(sample=samples_with_high_power_strains, position=position_ss)
        .sample.to_series()
        .map(sample_to_strain)
        .map(high_power_strain_palette)
    ),
)

In [None]:
# Community composition of samples dominated by high-power strains.
sf.plot.plot_community(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss).drop_low_abundance_strains(0.05),
    row_linkage_func=lambda world: world.genotype.linkage("strain"),
    col_linkage_func=lambda world: world.metagenotype.linkage(),
    # Color each sample by its dominant strain; samples without a mapped color get NaN.
    col_colors=(
        fit.sel(sample=samples_with_high_power_strains, position=position_ss)
        .sample.to_series()
        .map(sample_to_strain)
        .map(high_power_strain_palette)
    ),
)

In [None]:
# Gene clustering table indexed by centroid_99 (kept as a column too), index named gene_id.
gene_cluster = (
    pd.read_table(path["cluster_info"])
    .set_index('centroid_99', drop=False)
    .rename_axis(index='gene_id')
)
# Headerless gene annotation table, with column names lower-cased.
gene_annotation = pd.read_table(
    path["gene_annotations"],
    index_col='locus_tag',
    names=['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product'],
).rename(columns=str.lower)

# One row per centroid at the chosen clustering level, joined to its annotation.
gene_meta = gene_cluster.loc[gene_cluster[f'centroid_{centroid}'].unique()].join(gene_annotation)

In [None]:
_cog_meta = pd.read_table(
    'ref/cog-20.meta.tsv',
    index_col=['cog'],
    names=['cog', 'categories', 'description', 'gene', 'pathway', '_1', '_2'],
)
cog_meta = _cog_meta.drop(columns=['categories', '_1', '_2'])

# Long (cog, category) table: each character of a COG's 'categories' string becomes
# one row (presumably one single-letter COG functional category each).
cog_x_category = (
    _cog_meta.categories
    .apply(tuple)            # 'ABC' -> ('A', 'B', 'C')
    .apply(pd.Series)        # one column per character position
    .unstack()
    .to_frame(name='category')
    .reset_index()
    [['cog', 'category']]
    .dropna()                # padding from shorter category strings
)

In [None]:
# Descriptions of the COG category codes, indexed by category.
cog_category = pd.read_table(
    'ref/cog-20.categories.tsv',
    index_col='category',
    names=['category', 'description'],
)

In [None]:
# Genes called (moderate confidence) in at least one high-power strain: depth-ratio
# heatmap plus annotation summaries.
strain_list = high_power_strain_list
gene_list = idxwhere(moderate_hit[strain_list].sum(1) > 0)

x = strain_depth.loc[gene_list, strain_list]

# Skip the clustermap for very large gene sets (threshold: 2e4 genes).
if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        # Strains ordered by their discretized genotypes.
        col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Number of genes, and number not annotated as 'hypothetical protein'. Counting with ==
# avoids the KeyError value_counts()['hypothetical protein'] raises when none are present.
print(len(gene_list), len(gene_list) - (gene_annotation.loc[gene_list]['product'] == 'hypothetical protein').sum())
print()
# Most common COG pathways (value_counts already sorts descending).
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .head(10)
)
print()
# Most common product annotations.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Most common COG categories (a gene counts once per category of its COG).
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
# Genes called (moderate confidence) in more than 80% of high-power strains.
strain_list = high_power_strain_list
gene_list = idxwhere(moderate_hit[strain_list].mean(1) > 0.8)

x = strain_depth.loc[gene_list, strain_list]

# Skip the clustermap for very large gene sets (threshold: 2e4 genes).
if len(gene_list) < 2e4:
    # Gene linkage is computed only when it is drawn: hierarchical clustering of a very
    # large gene set is itself expensive, which the size guard is meant to avoid.
    _gene_linkage = sp.cluster.hierarchy.linkage(x, metric='cosine')
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        row_linkage=_gene_linkage,
        col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Number of genes, and number not annotated as 'hypothetical protein'. Counting with ==
# avoids the KeyError value_counts()['hypothetical protein'] raises when none are present.
print(len(gene_list), len(gene_list) - (gene_annotation.loc[gene_list]['product'] == 'hypothetical protein').sum())
print()
# Most common COG pathways (value_counts already sorts descending).
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .head(10)
)
print()
# Most common product annotations.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Most common COG categories (a gene counts once per category of its COG).
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
# Variable genes: moderate hits in >5% of high-power strains AND classified as
# high_confidence_not_hit in >20% of them.
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = strain_depth.loc[gene_list, strain_list]

# Skip the clustermap for very large gene sets (threshold: 2e4 genes).
if len(gene_list) < 2e4:
    # Gene linkage is computed only when it is drawn: hierarchical clustering of a very
    # large gene set is itself expensive, which the size guard is meant to avoid.
    _gene_linkage = sp.cluster.hierarchy.linkage(x, metric='cosine')
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        row_linkage=_gene_linkage,
        col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Number of genes, and number not annotated as 'hypothetical protein'. Counting with ==
# avoids the KeyError value_counts()['hypothetical protein'] raises when none are present.
print(len(gene_list), len(gene_list) - (gene_annotation.loc[gene_list]['product'] == 'hypothetical protein').sum())
print()
# Most common COG pathways (value_counts already sorts descending).
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .head(10)
)
print()
# Most common product annotations.
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
# Most common COG categories (a gene counts once per category of its COG).
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
# Moderate-hit calls (0/1) for the variable gene set, rows ordered like the depth-ratio
# heatmap of the same gene set.
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = moderate_hit.loc[gene_list, strain_list].astype(float)

if len(gene_list) < 2e4:
    # Recompute the gene linkage from the depth ratios here instead of reusing a
    # _gene_linkage left over from whichever cell ran last (it may belong to a
    # different gene list).
    _gene_linkage = sp.cluster.hierarchy.linkage(strain_depth.loc[gene_list, strain_list], metric='cosine')
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        row_linkage=_gene_linkage,
        col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

In [None]:
# Combined call code for the variable gene set: 3 * moderate_hit + 2 * (not
# high_confidence_not_hit), giving values 0, 2, 3 or 5 (gray scale).
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = (
    moderate_hit.loc[gene_list, strain_list].astype(float) * 3
    + (1 - high_confidence_not_hit.loc[gene_list, strain_list].astype(float)) * 2
)

if len(gene_list) < 2e4:
    # Recompute the depth-ratio gene linkage locally rather than relying on a
    # _gene_linkage left over from another cell.
    _gene_linkage = sp.cluster.hierarchy.linkage(strain_depth.loc[gene_list, strain_list], metric='cosine')
    sns.clustermap(
        x,
        metric='cosine',
        yticklabels=0,
        xticklabels=0,
        row_linkage=_gene_linkage,
        col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
        cmap='gray_r',
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Broad-strokes characterization of gene sets

## Phylogenetic conservation

In [None]:
strain_list = high_power_strain_list
gene_list = moderate_hit.index

# NOTE(review): m is not used in this cell.
m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = moderate_hit.loc[gene_list, strain_list]

# Pairwise strain distances: Jaccard on gene presence vs. genotype distance.
fdist = pd.DataFrame(
    sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')),
    index=x.columns,
    columns=x.columns,
)
gdist = fit.genotype.discretized().sel(strain=strain_list).pdist()

# Condensed vectors: one entry per strain pair (assumes gdist is in strain_list order).
d = pd.DataFrame({
    'genotype_distance': sp.spatial.distance.squareform(gdist),
    'gene_content_distance': sp.spatial.distance.squareform(fdist),
})
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d['genotype_distance'], d['gene_content_distance'])

In [None]:
# Same comparison, restricted to the variable gene set.
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

# NOTE(review): m is not used in this cell.
m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = moderate_hit.loc[gene_list, strain_list]

# Pairwise strain distances: Jaccard on gene presence vs. genotype distance.
fdist = pd.DataFrame(
    sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')),
    index=x.columns,
    columns=x.columns,
)
gdist = fit.genotype.discretized().sel(strain=strain_list).pdist()

# Condensed vectors: one entry per strain pair (assumes gdist is in strain_list order).
d = pd.DataFrame({
    'genotype_distance': sp.spatial.distance.squareform(gdist),
    'gene_content_distance': sp.spatial.distance.squareform(fdist),
})
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d['genotype_distance'], d['gene_content_distance'])

In [None]:
# Explicit strain set instead of whichever value an earlier cell left in strain_list.
strain_list = high_power_strain_list

# Number of genes per COG among genes hit in any high-power strain. rename_axis/rename
# yield columns ('cog', 'tally') on pandas 1.x and 2.x alike; pandas 2 names the
# value_counts result 'count', which broke rename(columns=dict(index=..., cog=...)).
_cog_tally = (
    gene_meta.loc[idxwhere(moderate_hit[strain_list].any(axis=1))]
    .cog.value_counts()
    .rename_axis('cog')
    .rename('tally')
    .reset_index()
)
# Sum per COG category (a COG counts toward each of its categories), largest first.
tally_cog_category_reps = (
    pd.merge(_cog_tally, cog_x_category, on='cog')
    .groupby('category').tally.sum()
    .sort_values(ascending=False)
)
tally_cog_category_reps

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(15, 19), sharex=True, sharey=True)

# The strain set and its genotype distances do not depend on the COG category, so they
# are computed once rather than on every iteration.
strain_list = high_power_strain_list
gdist = fit.genotype.discretized().sel(strain=strain_list).pdist()
_genotype_distance = sp.spatial.distance.squareform(gdist)

# One panel per COG category (the first 15 by gene tally): genotype vs gene-content
# distance for each pair of high-power strains.
for this_cog_category, ax in zip(tally_cog_category_reps.index, axs.flatten()):
    ax.set_title(this_cog_category)
    cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
    # Label-based selection of this category's genes that have hit calls; a boolean
    # Series from gene_meta cannot be used on moderate_hit if their gene sets differ.
    gene_list = gene_meta.index[gene_meta.cog.isin(cog_list)].intersection(moderate_hit.index)

    x = moderate_hit.loc[gene_list, strain_list]
    fdist = pd.DataFrame(
        sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')),
        index=x.columns,
        columns=x.columns,
    )

    d = pd.DataFrame(dict(
        genotype_distance=_genotype_distance,
        gene_content_distance=sp.spatial.distance.squareform(fdist),
    ))
    ax.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
    # Spearman rho between the two distances.
    ax.annotate(np.round(sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)[0], 2), xy=(0.8, 0.9), xycoords='axes fraction')
    # Mean number of this category's genes hit per strain.
    ax.annotate(int(x.mean(1).sum()), xy=(0.8, 0.7), xycoords='axes fraction')
    ax.annotate(cog_category.loc[this_cog_category].description, xy=(0.0, 0.8), xycoords='axes fraction')

# Comparison of Inferred Gene Content with Reference Genomes

In [None]:
x = reference_hit
y = moderate_hit

# Union of gene ids in a deterministic (sorted) order. Iterating a Python set gives a
# session-dependent order when the ids are strings (hash randomization).
_all_genes = x.index.union(y.index)

x = x.reindex(_all_genes, fill_value=False)
y = y.reindex(_all_genes, fill_value=False)

# Jaccard distance between every reference genome (rows) and inferred strain (columns).
jac_cdist_inf_moderate = pd.DataFrame(sp.spatial.distance.cdist(x.T, y.T, metric='jaccard'), index=x.columns, columns=y.columns)

In [None]:
# Gene-by-gene comparison of each inferred strain (moderate hits) with its closest
# reference genome by Jaccard distance. The aligned tables are rebuilt here instead of
# reusing the generic x / y globals, which later cells overwrite.
_genes = reference_hit.index.union(moderate_hit.index)
_ref = reference_hit.reindex(_genes, fill_value=False)
_inf = moderate_hit.reindex(_genes, fill_value=False)
# ref * 2 - inf * 3 encodes: -1 both present, -3 only inferred, 2 only reference, 0 neither.
moderate_confidence_diff = pd.DataFrame(
    _ref[jac_cdist_inf_moderate.idxmin()].values * 2 - _inf[jac_cdist_inf_moderate.columns].values * 3,
    index=_genes,
    columns=jac_cdist_inf_moderate.columns,
).replace({-3: 'only_inf', -1: 'shared_genes', 0: 'both_lacking', 2: 'only_ref'})
moderate_confidence_diff.apply(lambda col: col.value_counts())[high_power_strain_list]

In [None]:
x = reference_hit
y = high_confidence_hit

# Union of gene ids in a deterministic (sorted) order. Iterating a Python set gives a
# session-dependent order when the ids are strings (hash randomization).
_all_genes = x.index.union(y.index)

x = x.reindex(_all_genes, fill_value=False)
y = y.reindex(_all_genes, fill_value=False)

# Jaccard distance between every reference genome (rows) and inferred strain (columns).
jac_cdist_inf_confident = pd.DataFrame(sp.spatial.distance.cdist(x.T, y.T, metric='jaccard'), index=x.columns, columns=y.columns)

In [None]:
# Gene-by-gene comparison of each inferred strain (high-confidence hits) with its closest
# reference genome by Jaccard distance. The aligned tables are rebuilt here instead of
# reusing the generic x / y globals, which other cells overwrite.
_genes = reference_hit.index.union(high_confidence_hit.index)
_ref = reference_hit.reindex(_genes, fill_value=False)
_inf = high_confidence_hit.reindex(_genes, fill_value=False)
# ref * 2 - inf * 3 encodes: -1 both present, -3 only inferred, 2 only reference, 0 neither.
high_confidence_diff = pd.DataFrame(
    _ref[jac_cdist_inf_confident.idxmin()].values * 2 - _inf[jac_cdist_inf_confident.columns].values * 3,
    index=_genes,
    columns=jac_cdist_inf_confident.columns,
).replace({-3: 'only_inf', -1: 'shared_genes', 0: 'both_lacking', 2: 'only_ref'})
high_confidence_diff.apply(lambda col: col.value_counts())[high_power_strain_list]

In [None]:
# Per gene: fraction of reference genomes carrying it vs fraction of high-power strains.
x = reference_hit.mean(1)
y = moderate_hit[high_power_strain_list].mean(1)

# Deterministic (sorted) union of gene ids.
_all_genes = x.index.union(y.index)

# A gene missing from one table has prevalence 0 there. A float fill value keeps these
# Series float64; fill_value=False can upcast them to object dtype.
x = x.reindex(_all_genes, fill_value=0.0)
y = y.reindex(_all_genes, fill_value=0.0)

plt.scatter(x, y, alpha=0.2, s=1)
sp.stats.pearsonr(x, y)

In [None]:
ref_geno_pdist = ref_geno.pdist()

# Jaccard distance between reference genomes' gene content, labelled like ref_geno_pdist.
# Assumes ref_geno_pdist rows/columns are in ref_geno.strain order — TODO confirm.
ref_hits_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(
        sp.spatial.distance.pdist(reference_hit[ref_geno.strain].T, metric='jaccard')
    ),
    index=ref_geno_pdist.index,
    columns=ref_geno_pdist.columns,
)

In [None]:
# One point per pair of reference genomes.
x = sp.spatial.distance.squareform(ref_geno_pdist)  # genotype distance
y = sp.spatial.distance.squareform(ref_hits_pdist)  # gene-content (Jaccard) distance
plt.scatter(x, y, s=1, alpha=0.01)

In [None]:
# Genotype distance from each fitted strain (rows) to each reference genome (columns).
geno_cdist_to_ref = pd.DataFrame(
    sf.math.genotype_cdist(fit.genotype.discretized().values, ref_geno.values),
    index=fit.strain,
    columns=ref_geno.strain,
)
geno_cdist_to_ref.shape

In [None]:
# Jaccard distance between each inferred strain's moderate hits (restricted to genes in
# the reference table) and each reference genome's gene content.
hits_cdist_to_ref = pd.DataFrame(
    sp.spatial.distance.cdist(
        moderate_hit.reindex(reference_hit.index, fill_value=False).T,
        reference_hit[ref_geno.strain].T,
        metric='jaccard',
    ),
    index=moderate_hit.columns,
    columns=ref_geno.strain,
)
hits_cdist_to_ref.shape

In [None]:
# Like hits_cdist_to_ref, but using depth-only calls (depth_hit, ignoring correlation).
# Rows are labelled with depth_hit's own columns, since those are the strains the
# distances were computed for.
depth_only_hits_cdist_to_ref = pd.DataFrame(
    sp.spatial.distance.cdist(
        depth_hit.reindex(reference_hit.index, fill_value=False).T,
        reference_hit[ref_geno.strain].T,
        metric='jaccard',
    ),
    index=depth_hit.columns,
    columns=ref_geno.strain,
)
depth_only_hits_cdist_to_ref.shape

In [None]:
_geno_cdist, _hits_cdist, _depth_hits_cdist = align_indexes(geno_cdist_to_ref, hits_cdist_to_ref, depth_only_hits_cdist_to_ref)

# Pick each inferred strain's best reference genome by genotype distance, then look up
# the gene-content distances to that same reference genome.
best_match_inf_geno = _geno_cdist.idxmin(axis=1)
best_match_inf_geno_diss = {s: _geno_cdist.loc[s, best_match_inf_geno[s]] for s in _geno_cdist.index}
best_match_inf_hits_diss = {s: _hits_cdist.loc[s, best_match_inf_geno[s]] for s in _geno_cdist.index}
best_match_depth_inf_hits_diss = {s: _depth_hits_cdist.loc[s, best_match_inf_geno[s]] for s in _geno_cdist.index}

In [None]:
# Nearest *other* reference genome for each reference genome. Self-pairs are masked with
# +inf so a genome can never be its own best match (adding 1 to the diagonal could tie
# with a genuine distance of 1.0 and let idxmin return the genome itself).
_geno_cdist = ref_geno_pdist.mask(np.eye(ref_geno_pdist.shape[0], dtype=bool), np.inf)
_hits_cdist = ref_hits_pdist.mask(np.eye(ref_hits_pdist.shape[0], dtype=bool), np.inf)

best_match_ref_geno = _geno_cdist.idxmin(axis=1)
assert not (best_match_ref_geno.index.to_series() == best_match_ref_geno).any()

# Genotype and gene-content distances to that nearest genome.
best_match_ref_geno_diss = {}
best_match_ref_hits_diss = {}
for s in _geno_cdist.index:
    best_match_ref_geno_diss[s] = _geno_cdist.loc[s, best_match_ref_geno[s]]
    best_match_ref_hits_diss[s] = _hits_cdist.loc[s, best_match_ref_geno[s]]

In [None]:
# Best-match distances of each inferred strain to the reference genomes.
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
))

# Each reference genome's distances to its nearest other reference genome. power_index
# is omitted: it belongs to inferred strains (strain_meta's index), not reference
# genomes, so joining it only appended rows with NaN distances.
d_ref_to_ref = pd.DataFrame(dict(
    geno=best_match_ref_geno_diss,
    hits=best_match_ref_hits_diss,
))

# Distances for every pair of reference genomes.
d_all_by_all = pd.DataFrame(dict(
    geno=sp.spatial.distance.squareform(ref_geno_pdist),
    hits=sp.spatial.distance.squareform(ref_hits_pdist),
))

fig = plt.figure(figsize=(15, 10))
# Log-spaced genotype bins; pairs with genotype distance below 1e-3 fall outside them.
xbins = np.logspace(-3, 0, num=101)
ybins = np.linspace(0, 1.0, num=101)
plt.hist2d('geno', 'hits', data=d_all_by_all, bins=(xbins, ybins), norm=mpl.colors.SymLogNorm(linthresh=1), cmap=mpl.cm.Greys)
plt.colorbar(label='reference genome pairs')
plt.xscale('symlog', linthresh=1e-3, linscale=0.2)
plt.xlabel('genotype distance')
plt.ylabel('gene content distance (Jaccard)')

None

In [None]:
# Best-match distances of each inferred strain to the reference genomes.
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
))

# Each reference genome's distances to its nearest other reference genome. power_index
# is omitted: it belongs to inferred strains (strain_meta's index), not reference
# genomes, so joining it only appended rows with NaN distances.
d_ref_to_ref = pd.DataFrame(dict(
    geno=best_match_ref_geno_diss,
    hits=best_match_ref_hits_diss,
))

# Distances for every pair of reference genomes.
d_all_by_all = pd.DataFrame(dict(
    geno=sp.spatial.distance.squareform(ref_geno_pdist),
    hits=sp.spatial.distance.squareform(ref_hits_pdist),
))

fig = plt.figure(figsize=(15, 10))
# Log-spaced genotype bins; pairs with genotype distance below 1e-3 fall outside them.
xbins = np.logspace(-3, 0, num=101)
ybins = np.linspace(0, 1.0, num=101)
plt.hist2d('geno', 'hits', data=d_all_by_all, bins=(xbins, ybins), norm=mpl.colors.SymLogNorm(linthresh=1), cmap=mpl.cm.Greys)
plt.colorbar(label='reference genome pairs')
plt.xscale('symlog', linthresh=1e-3, linscale=0.2)
plt.scatter('geno', 'hits', data=d_ref_to_ref, s=20, color='tab:blue', label='reference: nearest other reference')
plt.scatter('geno', 'hits', data=d_inf_to_ref.loc[high_power_strain_list], s=40, color='tab:orange', label='high-power strain: best reference match')
plt.xlabel('genotype distance')
plt.ylabel('gene content distance (Jaccard)')
plt.legend()

None

In [None]:
# Best-match distances of each inferred strain to the reference genomes.
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
))

# Each reference genome's distances to its nearest other reference genome. power_index
# is omitted: it belongs to inferred strains (strain_meta's index), not reference
# genomes, so joining it only appended rows with NaN distances.
d_ref_to_ref = pd.DataFrame(dict(
    geno=best_match_ref_geno_diss,
    hits=best_match_ref_hits_diss,
))

# Distances for every pair of reference genomes.
d_all_by_all = pd.DataFrame(dict(
    geno=sp.spatial.distance.squareform(ref_geno_pdist),
    hits=sp.spatial.distance.squareform(ref_hits_pdist),
))

# Same overlay as above, on linear genotype-distance bins.
fig = plt.figure(figsize=(15, 10))
xbins = np.linspace(0, 1.0, num=101)
ybins = np.linspace(0, 1.0, num=101)
plt.hist2d('geno', 'hits', data=d_all_by_all, bins=(xbins, ybins), norm=mpl.colors.SymLogNorm(linthresh=1), cmap=mpl.cm.Greys)
plt.colorbar(label='reference genome pairs')
plt.scatter('geno', 'hits', data=d_ref_to_ref, s=20, color='tab:blue', label='reference: nearest other reference')
plt.scatter('geno', 'hits', data=d_inf_to_ref.loc[high_power_strain_list], s=40, color='tab:orange', label='high-power strain: best reference match')
plt.xlabel('genotype distance')
plt.ylabel('gene content distance (Jaccard)')
plt.legend()

None

In [None]:
d_inf_to_ref = pd.DataFrame({
    'geno': best_match_inf_geno_diss,
    'hits': best_match_inf_hits_diss,
    'depth_hits': best_match_depth_inf_hits_diss,
    'power_index': strain_meta.power_index,
})

# Distance to the best-matching reference using moderate hits (x) vs depth-only hits (y),
# colored by power index; the dashed line is y = x.
plt.scatter('hits', 'depth_hits', data=d_inf_to_ref, c='power_index', norm=mpl.colors.LogNorm())
plt.plot([0, 1], [0, 1], color='k', linestyle='--', lw=1)

In [None]:
# Genotypes of the best-matching reference genomes next to the high-power inferred
# strains, stacked along the strain dimension and plotted on the subsampled positions.
g = sf.Genotype.concat(
    {
        'ref': ref_geno.sel(strain=best_match_inf_geno.unique()),
        'fit': fit.genotype.discretized().sel(strain=high_power_strain_list),
    },
    dim='strain',
)
sf.plot_genotype(g.sel(position=position_ss))