In [None]:
%load_ext autoreload

In [None]:
import os as _os
_os.chdir(_os.environ['PROJECT_ROOT'])
_os.path.realpath(_os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, align_indexes, invert_mapping
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os

In [None]:
import sfacts as sf

In [None]:
import lib.thisproject.data

In [None]:
sns.set_context('talk')
plt.rcParams['figure.dpi'] = 50

In [None]:
# Default file path forming for interactive use.

group = 'xjin_hmp2'
species = '101346'
stemA = 'r.proc'
stemB = 'filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts43-s85-seed0'
stemC = 'sfacts42-seed0'
spgc_params = 'e100'
centroid = 95

path = dict(
    flag=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.midas_gene{centroid}.spgc-{spgc_params}.strain_files.flag",
    fit=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.world.nc",
    refit=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.refit-{stemC}.world.nc",
    strain_correlation=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_correlation.tsv",
    strain_depth_ratio=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_depth_ratio.tsv",
    strain_fraction=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.comm.tsv",
    species_gene_mean_depth=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.midas_gene{centroid}.spgc-{spgc_params}.species_depth.tsv",
    species_gtpro_depth=f"data/group/{group}/{stemA}.gtpro.species_depth.tsv",
    species_correlation=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.midas_gene{centroid}.spgc-{spgc_params}.species_correlation.tsv",
    strain_thresholds=f"data/group/{group}/species/sp-{species}/{stemA}.gtpro.{stemB}.midas_gene{centroid}.spgc-{spgc_params}.strain_gene_threshold.tsv",
    gene_annotations=f"ref/midasdb_uhgg_gene_annotations/sp-{species}.gene{centroid}_annotations.tsv",
    raw_gene_depth=f"data/group/{group}/species/sp-{species}/{stemA}.midas_gene{centroid}.depth.nc",
    reference_copy_number=f"ref/midasdb_uhgg_pangenomes/{species}/midas_gene{centroid}.reference_copy_number.nc",
    cluster_info=f"ref/midasdb_uhgg/pangenomes/{species}/cluster_info.txt",
    species_taxonomy="ref/gtpro/species_taxonomy_ext.tsv",
    midasdb_genomes="ref/uhgg_genomes_all_4644.tsv",
    gtpro_reference_genotype=f"data/species/sp-{species}/gtpro_ref.mgtp.nc",
)

# Record, for every configured file location, whether it is present on disk.
path_exists = {}
for p, _location in path.items():
    path_exists[_location] = os.path.exists(_location)

# Stop here, naming every absent input, rather than failing partway through
# the analysis.
_missing_files = [_location for _location, _found in path_exists.items() if not _found]
assert all(path_exists.values()), '\n'.join(["Missing files:"] + _missing_files)

In [None]:
path['flag']

In [None]:
species_taxonomy = lib.thisproject.data.load_species_taxonomy(path["species_taxonomy"])
species_taxonomy.loc[species]

In [None]:
all_species_gtpro_depth = lib.thisproject.data.load_species_depth(path["species_gtpro_depth"])
all_species_gtpro_rabund = all_species_gtpro_depth.divide(all_species_gtpro_depth.sum(1), axis=0) 

plt.hist(all_species_gtpro_rabund[species], bins=[0] + list(np.logspace(-7, 1, num=101)))
plt.xscale('symlog', linthresh=1e-7)
plt.yscale('log')
None

In [None]:
plt.hist(all_species_gtpro_depth[species], bins=np.logspace(-4, 4, num=51))
plt.xscale('log')
None

In [None]:
focal_species_core_depth = lib.thisproject.data.load_single_species_depth(path["species_gene_mean_depth"])

d = pd.DataFrame(dict(
    gtpro=all_species_gtpro_depth[species],
    gene=focal_species_core_depth,
))

plt.scatter('gtpro', 'gene', data=d, alpha=0.1)
plt.yscale('symlog', linthresh=1e-4)
plt.xscale('symlog', linthresh=1e-4)
plt.plot([1e-4, 1e2], [1e-4, 1e2])
plt.xlabel('GT-Pro depth')
plt.ylabel('Core gene depth')
None

In [None]:
focal_species_core_depth = lib.thisproject.data.load_single_species_depth(path["species_gene_mean_depth"])

d = pd.DataFrame(dict(
    gtpro=all_species_gtpro_depth[species],
    gene=focal_species_core_depth,
))

plt.scatter('gtpro', 'gene', data=d, alpha=0.1)
# plt.yscale('symlog', linthresh=1e-4)
# plt.xscale('symlog', linthresh=1e-4)
plt.plot([1e-4, 1e2], [1e-4, 1e2])
plt.xlabel('GT-Pro depth')
plt.ylabel('Core gene depth')
None

In [None]:
gene_depth = xr.load_dataarray(path["raw_gene_depth"])

plt.hist(np.log10(gene_depth.isel(sample=0) + 1e-5), bins=50)
plt.yscale('log')
None

In [None]:
# Read the header-less two-column species-correlation table as a Series.
# NOTE(review): the index column is named 'sample' here, but the values are
# later compared per gene (species_marker_gene, and alongside strain_corr) --
# presumably these are gene ids; confirm and consider renaming.
species_corr = pd.read_table(path["species_correlation"], names=['sample', 'correlation'], index_col='sample').squeeze()
fig, ax = plt.subplots()
# Plot 1 - correlation on log-spaced bins; the x-axis is inverted so that
# higher correlations sit to the right.
ax.hist(1 - species_corr, bins=np.logspace(-3, 0, num=101))
ax.set_yscale('log')
ax.set_xscale('log')
ax.invert_xaxis()


# Threshold = the smallest of the 700 largest correlations.
species_threshold = species_corr.sort_values(ascending=False).head(700).min()
# Strict '>' means entries exactly at the threshold are not selected.
species_marker_gene = idxwhere(species_corr > species_threshold)
print(species_threshold)
# Mark the threshold on the histogram (same 1 - correlation scale).
ax.axvline(1 - species_threshold, lw=1, linestyle='--', color='k')
None

In [None]:
species_depth = pd.read_table(path["species_gene_mean_depth"], names=['sample', 'depth'], index_col='sample').squeeze()

In [None]:
fit = sf.World.load(
    path["fit"]
).drop_low_abundance_strains(0.05)

refit = sf.World.load(
    path["refit"]
)

print(fit.sizes)

np.random.seed(0)
position_ss = fit.random_sample(position=min(fit.sizes['position'], 1000)).position

In [None]:
sf.plot.plot_metagenotype(
    fit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
    col_linkage_func=lambda w: w.community.linkage(),
)
sf.plot.plot_depth(
    fit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
    col_linkage_func=lambda w: w.community.linkage(),
)
sf.plot.plot_dominance(
    fit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
    col_linkage_func=lambda w: w.community.linkage(),
)
sf.plot.plot_community(
    fit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    col_linkage_func=lambda w: w.community.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)
sf.plot.plot_genotype(
    fit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    col_linkage_func=lambda w: w.metagenotype.linkage("position"),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
)
sf.plot.plot_genotype(
    refit.sel(position=position_ss),
    # scaley=0.2, scalex=0.3,
    col_linkage_func=lambda w: fit.sel(position=position_ss).metagenotype.linkage("position"),
)

In [None]:
ref_geno = sf.Genotype(
    sf.data.Metagenotype.load(path["gtpro_reference_genotype"])
    .to_estimated_genotype()
    .to_series()
    .unstack()
    .rename(lambda s: 'UHGG' + s[len('GUT_GENOME'):])
    .stack()
    .to_xarray()
    .sel(position=fit.position)
)
print(ref_geno.sizes)

In [None]:
sample_to_strain = (
    (fit.community.data > 0.95)
    .to_series()
    .unstack()
    .apply(idxwhere, axis=1)
    [lambda x: x.apply(bool)]
    .str[0]
    .rename('strain')
)
    
strain_to_sample_list = (
    sample_to_strain
    .rename('strain_id')
    .reset_index()
    .groupby('strain_id')
    .apply(lambda x: x['sample'].to_list())
)
strain_to_sample_list.apply(len).sort_values(ascending=False).head()

In [None]:
strain_thresholds = (
    pd.read_table(path["strain_thresholds"], index_col='strain')
    .rename(columns=dict(
        correlation_strict='corr_threshold_strict',
        correlation_moderate='corr_threshold_moderate',
        correlation_lenient='corr_threshold_lenient',
        depth_high='depth_thresh_high',
        depth_low='depth_thresh_low',
    ))
)

_strain_meta = (
    strain_thresholds
    .join(fit.genotype.entropy().to_series().rename('genotype_entropy'))
    .join(refit.genotype.entropy().to_series().rename('genotype_refit_entropy'))
    .join(fit.metagenotype.entropy().to_series().rename('metagenotype_entropy').groupby(sample_to_strain).mean().rename(int))
    .join(strain_to_sample_list.apply(len).rename('num_samples'))
    .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).std().rename('depth_stdev').rename(int))
    .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).max().rename('depth_max').rename(int))
    .join(species_depth.apply(lambda x: x**(1)).groupby(sample_to_strain).sum().rename('depth_sum').rename(int))
    .assign(power_index=lambda x: (x.depth_stdev * np.sqrt(x.num_samples)).fillna(0))
)
strain_meta = _strain_meta


power_index_thresh = 5
genotype_entropy_thresh = 0.2
genotype_refit_entropy_thresh = 1.0

high_power_strain_list = idxwhere(
    (strain_meta.power_index > power_index_thresh)
    & (strain_meta.genotype_entropy < genotype_entropy_thresh)
    & (strain_meta.genotype_refit_entropy < genotype_refit_entropy_thresh)
)
print(len(high_power_strain_list))
highest_power_strain_list = strain_meta.sort_values('power_index', ascending=False).head(3).index

plt.scatter(strain_meta.power_index, strain_meta.corr_threshold_moderate, c=strain_meta.genotype_refit_entropy, alpha=0.5)
plt.axvline(power_index_thresh, lw=1, linestyle='--', color='k')
plt.colorbar()
plt.xscale('log')

In [None]:
plt.scatter('metagenotype_entropy', 'genotype_entropy', data=strain_meta, c='genotype_refit_entropy')

In [None]:
strain_meta.assign(high_power=lambda x: x.index.isin(high_power_strain_list)).sort_values('power_index', ascending=False).head(20)

In [None]:
strain_meta.loc[high_power_strain_list].sort_values('corr_threshold_moderate').head(5)

In [None]:
high_power_strain_palette = lib.plot.construct_ordered_palette(high_power_strain_list, mpl.cm.Spectral)

In [None]:
# Load the long-format (gene_id, strain) tables and pivot each into a wide
# gene x strain table.
# NOTE(review): missing (gene, strain) pairs are filled with 0 for the
# correlation table but left as NaN for the depth-ratio table. Downstream,
# `~(strain_depth < threshold)` evaluates to True for NaN -- confirm that this
# asymmetry is intended.
strain_corr = pd.read_table(path["strain_correlation"], index_col=['gene_id', 'strain']).squeeze().unstack('strain', fill_value=0)
strain_depth = pd.read_table(
    path["strain_depth_ratio"],
    index_col=['gene_id', 'strain']
).squeeze().unstack()
# Put both tables on a common index (project helper; exact join semantics are
# defined in lib.pandas_util).
strain_corr, strain_depth = align_indexes(strain_corr, strain_depth)
# Keep only the strains described in strain_meta, in strain_meta's order.
strain_corr = strain_corr[strain_meta.index]
strain_depth = strain_depth[strain_meta.index]

In [None]:
_strain_list = high_power_strain_list

fig, axs = lib.plot.subplots_grid(ncols=3, naxes=len(_strain_list), ax_width=5, ax_height=4, sharex=True)

for strain, ax in zip(_strain_list, axs.flatten()):
    d = pd.DataFrame(dict(
        corr=1 - strain_corr[strain],
        depth=strain_depth[strain],
        marker=species_corr > species_threshold,
        species_corr=1 - species_corr,
    ))
    ax.scatter('corr', 'depth', data=d, s=1, alpha=0.1, c='species_corr', cmap='winter_r', norm=mpl.colors.LogNorm())
    # ax.scatter('corr', 'depth', data=d[d.marker], s=1, alpha=0.05, c='tab:orange')
    ax.axvline(1 - strain_meta['corr_threshold_moderate'][strain], color='k', linestyle='--', lw=0.5)
    ax.axhline(strain_meta['depth_thresh_low'][strain], color='k', linestyle='--', lw=0.5)
    ax.set_xscale('symlog', linthresh=1e-3)
    ax.set_xlim(left=-1e-4, right=1)
    ax.invert_xaxis()
    ax.set_yscale('log')
    ax.set_title(strain)
fig.tight_layout()

In [None]:
# Boolean gene x strain tables. Each one compares a gene x strain table with a
# per-strain threshold column taken from strain_meta.

# Correlation above the strict / lenient / moderate per-strain threshold.
strict_corr_hit = strain_corr > strain_meta.corr_threshold_strict
lenient_corr_hit = strain_corr > strain_meta.corr_threshold_lenient
moderate_corr_hit = strain_corr > strain_meta.corr_threshold_moderate
# Correlation below even the lenient threshold.
low_corr =  strain_corr < strain_meta.corr_threshold_lenient

# Depth-ratio calls.
# NOTE(review): depth_hit is the negation of low_depth, so a NaN depth ratio
# (which compares False) is counted as a depth hit -- confirm this is intended.
low_depth = (strain_depth < strain_meta.depth_thresh_low)
depth_hit = ~low_depth
high_depth = (strain_depth > strain_meta.depth_thresh_high)
# Combined calls: depth evidence together with a given correlation stringency.
high_confidence_hit = depth_hit & strict_corr_hit
moderate_hit = depth_hit & moderate_corr_hit
maybe_hit = depth_hit & lenient_corr_hit
# Strict correlation but depth below the low / above the high threshold.
low_depth_hit = low_depth & strict_corr_hit
high_depth_hit = high_depth & strict_corr_hit
# Exactly one of the two lines of evidence (depth, strict correlation) holds.
ambiguous_hit = depth_hit ^ strict_corr_hit
# Both lines of evidence point against the gene being present.
high_confidence_not_hit = low_depth & low_corr

In [None]:
reference_copy_number = xr.load_dataarray(path["reference_copy_number"])
reference_hit = pd.DataFrame(
    reference_copy_number.T > 0,
    columns=reference_copy_number.genome_id,
    index=reference_copy_number.gene_id,
)

print(reference_copy_number.sizes)

In [None]:
d = pd.DataFrame(dict(
    tally_moderate_hits=moderate_hit.sum(),
    sum_gene_ratio=strain_depth[moderate_hit].sum(),
)).assign(ratio=lambda x: x.sum_gene_ratio / x.tally_moderate_hits)

fig, ax = plt.subplots(figsize=(15, 5))
bins = np.linspace(0, 10000, num=21)
# sns.kdeplot(reference_hit.sum())
# sns.kdeplot(d.tally_moderate_hits)
# sns.kdeplot(d.tally_moderate_hits.loc[high_power_strain_list])


plt.hist(reference_hit.sum(), bins=bins, density=True, alpha=0.5)
# plt.hist(d.tally_moderate_hits, bins=bins, density=True, alpha=0.5)
plt.hist(d.tally_moderate_hits.loc[high_power_strain_list], bins=bins, density=True, alpha=0.5)

d.loc[high_power_strain_list].sort_values('ratio')

In [None]:
d = pd.DataFrame(dict(
    tally_maybe_hits=maybe_hit.sum(),
    sum_gene_ratio=strain_depth[maybe_hit].sum(),
)).assign(ratio=lambda x: x.sum_gene_ratio / x.tally_maybe_hits)

fig, ax = plt.subplots(figsize=(15, 5))
bins = np.linspace(0, 10000, num=21)
# sns.kdeplot(reference_hit.sum())
# sns.kdeplot(d.tally_moderate_hits)
# sns.kdeplot(d.tally_moderate_hits.loc[high_power_strain_list])


plt.hist(reference_hit.sum(), bins=bins, density=True, alpha=0.5)
# plt.hist(d.tally_moderate_hits, bins=bins, density=True, alpha=0.5)
plt.hist(d.tally_maybe_hits.loc[high_power_strain_list], bins=bins, density=True, alpha=0.5)

d.loc[high_power_strain_list].sort_values('ratio')

In [None]:
d = pd.DataFrame(dict(
    tally_moderate_hits=moderate_hit.sum(),
    sum_gene_ratio=strain_depth[moderate_hit].sum(),
)).assign(ratio=lambda x: x.sum_gene_ratio / x.tally_moderate_hits)

fig, ax = plt.subplots(figsize=(15, 5))
bins = np.linspace(0, 10000, num=21)
# sns.kdeplot(reference_hit.sum())
# sns.kdeplot(d.tally_moderate_hits)
# sns.kdeplot(d.tally_moderate_hits.loc[high_power_strain_list])


plt.hist(reference_copy_number.sum("gene_id"), bins=bins, density=True, alpha=0.5)
# plt.hist(d.sum_gene_ratio, bins=bins, density=True, alpha=0.5)
plt.hist(d.sum_gene_ratio.loc[high_power_strain_list], bins=bins, density=True, alpha=0.5)
None

In [None]:
# Per-sample community value summed over the high-power strains.
_high_power_total = (
    fit.community.data
    .sel(strain=high_power_strain_list)
    .sum("strain")
    .to_series()
)

# Split samples on whether that total is above or below one half; a sample at
# exactly 0.5 lands in neither list.
samples_with_high_power_strains = idxwhere(_high_power_total > 0.5)
samples_without_high_power_strains = idxwhere(_high_power_total < 0.5)
len(samples_with_high_power_strains), len(samples_without_high_power_strains)

In [None]:
w0 = fit.sel(strain=high_power_strain_list, position=position_ss)
w1 = refit.sel(strain=high_power_strain_list, position=position_ss)


sf.plot.plot_metagenotype(
    w0,
    row_linkage_func=lambda w: w0.metagenotype.linkage("position"),
    col_linkage_func=lambda w: w0.community.linkage("sample"),
    scaley=0.0095,
    scalex=0.004,
    transpose=True,
    xticklabels=0,
)

sf.plot.plot_genotype(
    w0,
    col_linkage_func=lambda w: w0.metagenotype.linkage("position"),
    row_linkage_func=lambda w: w1.genotype.linkage("strain"),
)

sf.plot.plot_genotype(
    w1,
    col_linkage_func=lambda w: w0.metagenotype.linkage("position"),
    row_linkage_func=lambda w: w1.genotype.linkage("strain"),
)

In [None]:
sf.plot.plot_metagenotype(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss).drop_low_abundance_strains(0.05),
    col_linkage_func=lambda w: w.community.linkage(),
    row_linkage_func=lambda w: w.metagenotype.linkage("position"),
    col_colors=fit.sel(sample=samples_with_high_power_strains, position=position_ss).sample.to_series().map(sample_to_strain).map(high_power_strain_palette),
)

In [None]:
sf.plot.plot_community(
    fit.sel(sample=samples_with_high_power_strains, position=position_ss).drop_low_abundance_strains(0.05),
    col_linkage_func=lambda w: w.community.linkage(),
    row_linkage_func=lambda w: w.genotype.linkage("strain"),
    col_colors=fit.sel(sample=samples_with_high_power_strains, position=position_ss).sample.to_series().map(sample_to_strain).map(high_power_strain_palette),
)

In [None]:
gene_cluster = pd.read_table(
    path["cluster_info"]
).set_index('centroid_99', drop=False).rename_axis(index='gene_id')
gene_annotation = pd.read_table(
    path["gene_annotations"],
    names=['locus_tag', 'ftype', 'length_bp', 'gene', 'EC_number', 'COG', 'product'],
    index_col='locus_tag',
).rename(columns=str.lower)

gene_meta = gene_cluster.loc[gene_cluster[f'centroid_{centroid}'].unique()].join(gene_annotation)

In [None]:
_cog_meta = pd.read_table(
    'ref/cog-20.meta.tsv',
    names=['cog', 'categories', 'description', 'gene', 'pathway', '_1', '_2'],
    index_col=['cog']
)
cog_meta = _cog_meta.drop(columns=['categories', '_1', '_2'])
cog_x_category = _cog_meta.categories.apply(tuple).apply(pd.Series).unstack().to_frame(name='category').reset_index()[['cog', 'category']].dropna()

In [None]:
cog_category = pd.read_table('ref/cog-20.categories.tsv', names=['category', 'description'], index_col='category')

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(moderate_hit[strain_list].sum(1) > 0)

x = strain_depth.loc[gene_list, strain_list]

# _gene_linkage = sp.cluster.hierarchy.linkage(x, metric='cosine')

if len(gene_list) < 2e4:
    sns.clustermap(
        x,
        metric='cosine',
        norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
        yticklabels=0,
        xticklabels=0,
        # row_linkage=_gene_linkage,
        col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
    )
else:
    print("Too many genes for clustermap:", len(gene_list))

# Number of selected genes, and how many of them are NOT annotated as
# 'hypothetical protein'. `.get(..., 0)` avoids a KeyError when none of the
# selected genes has that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere(moderate_hit[strain_list].mean(1) > 0.8)

x = strain_depth.loc[gene_list, strain_list]

_gene_linkage = sp.cluster.hierarchy.linkage(x, metric='cosine')

if len(gene_list) < 2e4:
    try:
        sns.clustermap(
            x,
            metric='cosine',
            norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
            yticklabels=0,
            xticklabels=0,
            row_linkage=_gene_linkage,
            col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
        )
    except RecursionError as err:
        print("Problem with row_linkage?", err)
else:
    print("Too many genes for clustermap:", len(gene_list))

# Number of selected genes, and how many of them are NOT annotated as
# 'hypothetical protein'. `.get(..., 0)` avoids a KeyError when none of the
# selected genes has that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = strain_depth.loc[gene_list, strain_list]

_gene_linkage = sp.cluster.hierarchy.linkage(x, metric='euclidean')

if len(gene_list) < 2e4:
    try:
        sns.clustermap(
            x,
            metric='cosine',
            norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
            yticklabels=0,
            xticklabels=0,
            row_linkage=_gene_linkage,  # FIXME: Why does this fail sometimes?
            col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
        )
    except RecursionError as err:
        print("Problem with row_linkage?", err)
else:
    print("Too many genes for clustermap:", len(gene_list))

# Number of selected genes, and how many of them are NOT annotated as
# 'hypothetical protein'. `.get(..., 0)` avoids a KeyError when none of the
# selected genes has that product.
print(len(gene_list), len(gene_list) - gene_annotation.loc[gene_list]['product'].value_counts().get('hypothetical protein', 0))
print()
print(
    gene_annotation
    .loc[gene_list]
    .cog.to_frame()
    .join(cog_meta, on='cog')
    .pathway
    .value_counts()
    .sort_values(ascending=False)
    .head(10)
)
print()
print(
    gene_meta
    .loc[gene_list]
    ['product']
    .value_counts()
    .head(10)
)
print()
print(pd.merge(
    gene_annotation.loc[gene_list].cog.dropna().to_frame(),
    cog_x_category,
    on='cog',
).category.value_counts().to_frame().join(cog_category).head(10))

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

x = (
    moderate_hit.loc[gene_list, strain_list].astype(float) * 3
    + (1 - high_confidence_not_hit.loc[gene_list, strain_list].astype(float)) * 2
)

if len(gene_list) < 2e4:
    try:
        sns.clustermap(
            x,
            metric='cosine',
            # norm=mpl.colors.SymLogNorm(linthresh=1e-4, vmin=0.1, vmax=10),
            yticklabels=0,
            xticklabels=0,
            row_linkage=_gene_linkage,  # FIXME: Why does this fail sometimes?
            col_linkage=fit.genotype.discretized().sel(strain=strain_list).linkage("strain"),
            cmap='gray',
        )
    except RecursionError as err:
        print("Problem with row_linkage?", err)
else:
    print("Too many genes for clustermap:", len(gene_list))

In [None]:
# Per-gene fraction of reference genomes carrying the gene (x) versus the
# per-gene fraction of high-power inferred strains with a moderate hit (y).
x = reference_hit.mean(1)
y = moderate_hit[high_power_strain_list].mean(1)

# Put both series on the union of their gene ids.
_all_genes = list(set(x.index) | set(y.index))

# A gene absent from one table has fraction 0 there. Fill with the float 0.0,
# not the boolean False: filling a float Series with a bool upcasts it to
# object dtype, which the plotting and statistics calls below then have to
# coerce back.
x = x.reindex(_all_genes, fill_value=0.0)
y = y.reindex(_all_genes, fill_value=0.0)

plt.scatter(x, y, alpha=0.2, s=1)
sp.stats.pearsonr(x, y)

In [None]:
# The subset with ref_geno look similar to the full set of genomes.

x = reference_hit[ref_geno.strain].mean(1)
y = reference_hit.mean(1)

plt.scatter(x, y, alpha=0.2, s=1)
sp.stats.pearsonr(x, y)

In [None]:
# Per-gene fraction of the reference genomes that also have a reference
# genotype (x) versus the per-gene fraction of high-power inferred strains
# with a moderate hit (y).
x = reference_hit[ref_geno.strain].mean(1)
y = moderate_hit[high_power_strain_list].mean(1)

# Put both series on the union of their gene ids.
_all_genes = list(set(x.index) | set(y.index))

# A gene absent from one table has fraction 0 there. Fill with the float 0.0,
# not the boolean False: filling a float Series with a bool upcasts it to
# object dtype, which the plotting and statistics calls below then have to
# coerce back.
x = x.reindex(_all_genes, fill_value=0.0)
y = y.reindex(_all_genes, fill_value=0.0)

plt.scatter(x, y, alpha=0.2, s=1)
sp.stats.pearsonr(x, y)

# Broad strokes characterization of gene sets

## Phylogenetic conservation

In [None]:
strain_list = high_power_strain_list
gene_list = moderate_hit.index

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = moderate_hit.loc[gene_list, strain_list]

fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
gdist = fit.genotype.discretized().sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
strain_list = high_power_strain_list
gene_list = idxwhere((moderate_hit[strain_list].mean(1) > 0.05) & (high_confidence_not_hit[strain_list].mean(1) > 0.2))

m = gene_meta.join(cog_meta, on='cog', rsuffix='_cog')
x = moderate_hit.loc[gene_list, strain_list]

fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)
gdist = fit.genotype.discretized().sel(strain=strain_list).pdist()

d = pd.DataFrame(dict(
    genotype_distance=sp.spatial.distance.squareform(gdist),
    gene_content_distance=sp.spatial.distance.squareform(fdist)
))
plt.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)

In [None]:
# Tally, per COG, the genes with a moderate hit in at least one strain of
# `strain_list`.
_cog_tally = (
    gene_meta
    # `axis` is passed by keyword: positional arguments to DataFrame.any() are
    # rejected by pandas >= 2.0.
    .loc[idxwhere(moderate_hit[strain_list].any(axis=1))]
    .cog
    .value_counts()
    # Name the index and the counts explicitly rather than renaming whatever
    # reset_index() produces: the default column names differ between pandas
    # 1.x ('index' / 'cog') and pandas 2.x ('cog' / 'count').
    .rename_axis('cog')
    .rename('tally')
    .reset_index()
)

# Expand each COG to its category letter(s) and total the gene tallies per
# category, largest first. A COG listed under several categories contributes
# to each of them.
tally_cog_category_reps = (
    pd.merge(_cog_tally, cog_x_category, on='cog')
    .groupby('category')
    .tally
    .sum()
    .sort_values(ascending=False)
)
tally_cog_category_reps

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(15, 19), sharex=True, sharey=True)

# The strain set, and therefore the pairwise genotype distance, is the same for
# every panel: compute it once instead of once per COG category.
strain_list = high_power_strain_list
gdist = fit.genotype.discretized().sel(strain=strain_list).pdist()
_genotype_distance = sp.spatial.distance.squareform(gdist)

# One panel per COG category, in the order of tally_cog_category_reps; zip()
# stops once the 5 x 3 grid of axes is used up.
for this_cog_category, ax in zip(tally_cog_category_reps.index, axs.flatten()):
    ax.set_title(this_cog_category)
    cog_list = cog_x_category[cog_x_category.category == this_cog_category].cog.unique()
    # Boolean mask over gene_meta's rows: genes whose COG is in this category.
    gene_list = gene_meta.cog.isin(cog_list)

    # Jaccard distance between strains, using only this category's genes.
    x = moderate_hit.loc[gene_list, strain_list]
    fdist = pd.DataFrame(sp.spatial.distance.squareform(sp.spatial.distance.pdist(x.T, metric='jaccard')), index=x.columns, columns=x.columns)

    # One row per strain pair (condensed form of both square distance matrices).
    d = pd.DataFrame(dict(
        genotype_distance=_genotype_distance,
        gene_content_distance=sp.spatial.distance.squareform(fdist)
    ))
    ax.scatter('genotype_distance', 'gene_content_distance', data=d, s=5)
    # Annotations: Spearman correlation, the summed per-gene hit fraction
    # (truncated to an integer), and the category's description.
    ax.annotate(np.round(sp.stats.spearmanr(d.genotype_distance, d.gene_content_distance)[0], 2), xy=(0.8, 0.9), xycoords='axes fraction')
    ax.annotate(int(x.mean(1).sum()), xy=(0.8, 0.7), xycoords='axes fraction')
    ax.annotate(cog_category.loc[this_cog_category].description, xy=(0.0, 0.8), xycoords='axes fraction')

# Enrichment Analysis in Reference Genomes

In [None]:
x = reference_hit
y = moderate_hit

_all_genes = list(set(x.index) | set(y.index))

x = x.reindex(_all_genes, fill_value=False)
y = y.reindex(_all_genes, fill_value=False)

jac_cdist_inf_moderate = pd.DataFrame(sp.spatial.distance.cdist(x.T, y.T, metric='jaccard'), index=x.columns, columns=y.columns)

In [None]:
moderate_confidence_diff = pd.DataFrame(x[jac_cdist_inf_moderate.idxmin()].values * 2 - y[jac_cdist_inf_moderate.columns].values * 3, index=x.index, columns=y.columns).replace({-3: 'only_inf', -1: 'shared_genes', 0: 'both_lacking', 2: 'only_ref'})
moderate_confidence_diff.apply(lambda x: x.value_counts())[high_power_strain_list]

In [None]:
x = reference_hit
y = high_confidence_hit

_all_genes = list(set(x.index) | set(y.index))

x = x.reindex(_all_genes, fill_value=False)
y = y.reindex(_all_genes, fill_value=False)

jac_cdist_inf_confident = pd.DataFrame(sp.spatial.distance.cdist(x.T, y.T, metric='jaccard'), index=x.columns, columns=y.columns)

In [None]:
high_confidence_diff = pd.DataFrame(x[jac_cdist_inf_confident.idxmin()].values * 2 - y[jac_cdist_inf_confident.columns].values * 3, index=x.index, columns=y.columns).replace({-3: 'only_inf', -1: 'shared_genes', 0: 'both_lacking', 2: 'only_ref'})
high_confidence_diff.apply(lambda x: x.value_counts())[high_power_strain_list]

In [None]:
ref_geno_pdist = ref_geno.pdist()
ref_hits_pdist = pd.DataFrame(
    sp.spatial.distance.squareform(sp.spatial.distance.pdist(reference_hit[ref_geno.strain].T, metric='jaccard')),
    index=ref_geno_pdist.index,
    columns=ref_geno_pdist.columns
)

In [None]:
x = sp.spatial.distance.squareform(ref_geno_pdist)
y = sp.spatial.distance.squareform(ref_hits_pdist)
plt.scatter(x, y, alpha=0.1, s=1)
# plt.xscale('symlog', linthresh=1e-3, linscale=0.1)

In [None]:
geno_cdist_to_ref = pd.DataFrame(sf.math.genotype_cdist(fit.genotype.discretized().values, ref_geno.values), index=fit.strain, columns=ref_geno.strain)
geno_cdist_to_ref.shape

In [None]:
hits_cdist_to_ref = pd.DataFrame(sp.spatial.distance.cdist(moderate_hit.reindex(reference_hit.index, fill_value=False).T, reference_hit[ref_geno.strain].T, metric='jaccard'), index=moderate_hit.columns, columns=ref_geno.strain)
hits_cdist_to_ref.shape

In [None]:
depth_only_hits_cdist_to_ref = pd.DataFrame(sp.spatial.distance.cdist(depth_hit.reindex(reference_hit.index, fill_value=False).T, reference_hit[ref_geno.strain].T, metric='jaccard'), index=moderate_hit.columns, columns=ref_geno.strain)
depth_only_hits_cdist_to_ref.shape

In [None]:
_geno_cdist, _hits_cdist, _depth_hits_cdist = align_indexes(geno_cdist_to_ref, hits_cdist_to_ref, depth_only_hits_cdist_to_ref)

best_match_inf_geno = _geno_cdist.idxmin(1)
best_match_inf_geno_diss = {}
best_match_inf_hits_diss = {}
best_match_depth_inf_hits_diss = {}
for s in _geno_cdist.index:
    best_match_inf_geno_diss[s] = _geno_cdist.loc[s, best_match_inf_geno[s]]
    best_match_inf_hits_diss[s] = _hits_cdist.loc[s, best_match_inf_geno[s]]
    best_match_depth_inf_hits_diss[s] = _depth_hits_cdist.loc[s, best_match_inf_geno[s]]

In [None]:
# Reference-vs-reference distance matrices with 1 added on the diagonal, so a
# genome's zero self-distance cannot be picked as its nearest neighbour.
_geno_cdist, _hits_cdist = (ref_geno_pdist + np.eye(ref_geno_pdist.shape[0])), (ref_hits_pdist + np.eye(ref_hits_pdist.shape[0]))

# For each reference genome, the other reference genome at the smallest
# genotype distance.
best_match_ref_geno = _geno_cdist.idxmin(1)
# By adding 1 to the diagonal, we should make the best match NOT itself.
assert not (best_match_ref_geno.index.to_series() == best_match_ref_geno).any()

# Look up, for each genome, the genotype distance and the gene-content distance
# to that genotype-based best match.
best_match_ref_geno_diss = {}
best_match_ref_hits_diss = {}
for s in _geno_cdist.index:
    best_match_ref_geno_diss[s] = _geno_cdist.loc[s, best_match_ref_geno[s]]
    best_match_ref_hits_diss[s] = _hits_cdist.loc[s, best_match_ref_geno[s]]

In [None]:
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
))

d_ref_to_ref = pd.DataFrame(dict(
    geno=best_match_ref_geno_diss,
    hits=best_match_ref_hits_diss,
    power_index=strain_meta.power_index,
))

d_all_by_all = pd.DataFrame(dict(
    geno=sp.spatial.distance.squareform(ref_geno_pdist),
    hits=sp.spatial.distance.squareform(ref_hits_pdist),
))


fig = plt.figure(figsize=(15, 10))
xbins = np.logspace(-3, 0, num=101)
ybins = np.linspace(0, 1.0, num=101)
plt.hist2d('geno', 'hits', data=d_all_by_all, bins=(xbins, ybins), norm=mpl.colors.SymLogNorm(linthresh=1), cmap=mpl.cm.Greys)
plt.colorbar()
plt.xscale('symlog', linthresh=1e-3, linscale=0.2)
# plt.scatter('geno', 'hits', data=d_ref_to_ref, s=20)
# plt.scatter('geno', 'hits', data=d_inf_to_ref.loc[high_power_strain_list], s=20)


None

In [None]:
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
    genotype_entropy=strain_meta.genotype_entropy,
    genotype_refit_entropy=strain_meta.genotype_refit_entropy,
))

d_ref_to_ref = pd.DataFrame(dict(
    geno=best_match_ref_geno_diss,
    hits=best_match_ref_hits_diss,
    power_index=strain_meta.power_index,
))

d_all_by_all = pd.DataFrame(dict(
    geno=sp.spatial.distance.squareform(ref_geno_pdist),
    hits=sp.spatial.distance.squareform(ref_hits_pdist),
))


fig = plt.figure(figsize=(15, 10))
xbins = np.logspace(-3, 0, num=101)
ybins = np.linspace(0, 1.0, num=101)
plt.hist2d('geno', 'hits', data=d_all_by_all, bins=(xbins, ybins), norm=mpl.colors.SymLogNorm(linthresh=1), cmap=mpl.cm.Greys)
plt.colorbar()
plt.xscale('symlog', linthresh=1e-3, linscale=0.2)
plt.scatter('geno', 'hits', data=d_ref_to_ref, s=20, color='tab:blue')
# plt.scatter('geno', 'depth_hits', data=d_inf_to_ref.loc[high_power_strain_list], s=40, color='tab:purple')
plt.scatter('geno', 'hits', data=d_inf_to_ref.loc[high_power_strain_list], s=40, c='genotype_refit_entropy', cmap=mpl.cm.YlOrRd_r)


None

In [None]:
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
    genotype_entropy=strain_meta.genotype_entropy,
    genotype_refit_entropy=strain_meta.genotype_refit_entropy,
))

d_ref_to_ref = pd.DataFrame(dict(
    geno=best_match_ref_geno_diss,
    hits=best_match_ref_hits_diss,
    power_index=strain_meta.power_index,
))

d_all_by_all = pd.DataFrame(dict(
    geno=sp.spatial.distance.squareform(ref_geno_pdist),
    hits=sp.spatial.distance.squareform(ref_hits_pdist),
))


fig = plt.figure(figsize=(15, 10))
xbins = np.linspace(0, 1.0, num=101)
ybins = np.linspace(0, 1.0, num=101)
plt.hist2d('geno', 'hits', data=d_all_by_all, bins=(xbins, ybins), norm=mpl.colors.SymLogNorm(linthresh=1), cmap=mpl.cm.Greys)
plt.colorbar()
plt.scatter('geno', 'hits', data=d_ref_to_ref, s=20, color='tab:blue')
# plt.scatter('geno', 'depth_hits', data=d_inf_to_ref.loc[high_power_strain_list], s=40, color='tab:purple')
plt.scatter('geno', 'hits', data=d_inf_to_ref.loc[high_power_strain_list], s=40, c='genotype_refit_entropy', cmap=mpl.cm.YlOrRd_r)


None

In [None]:
d_inf_to_ref = pd.DataFrame(dict(
    geno=best_match_inf_geno_diss,
    hits=best_match_inf_hits_diss,
    depth_hits=best_match_depth_inf_hits_diss,
    power_index=strain_meta.power_index,
    genotype_entropy=strain_meta.genotype_entropy,
    genotype_refit_entropy=strain_meta.genotype_refit_entropy,
))

plt.scatter('hits', 'depth_hits', data=d_inf_to_ref, c='genotype_refit_entropy', cmap=mpl.cm.YlOrRd_r)
plt.colorbar()
plt.plot([0, 1], [0, 1], lw=1, linestyle='--', color='k')

In [None]:
g = sf.Genotype.concat(dict(
    ref=ref_geno.sel(strain=best_match_inf_geno.unique()),
    fit=refit.genotype.sel(strain=high_power_strain_list),
), dim='strain')
sf.plot_genotype(g.sel(position=position_ss), row_linkage_func=lambda w: w.genotype.discretized().linkage("strain"))

In [None]:
def pairwise_similarity_scores(pdist_matrix, binary_matrix, pdist_transformation=np.sqrt):
    """Score each binary feature by how it is shared between close vs. distant strains.

    Each row of `binary_matrix` is recoded from {0, 1} to {-1, +1} (call it
    ``x``) and scored as ``sum over a, c of x[a] * x[c] * w[a, c]`` where
    ``w = 1 / pdist_transformation(pdist_matrix)``. Pairs of strains that carry
    the same value add their weight; pairs that differ subtract it.

    Parameters
    ----------
    pdist_matrix : array-like or DataFrame, shape (n_strains, n_strains)
        Pairwise distances between strains, in the same strain order as the
        columns of `binary_matrix`.
    binary_matrix : array-like or DataFrame, shape (n_features, n_strains)
        Presence/absence (bool or 0/1) of each feature in each strain.
    pdist_transformation : callable
        Applied to `pdist_matrix` before taking the reciprocal.

    Returns
    -------
    numpy.ndarray, shape (n_features,)
        One score per row of `binary_matrix`.
    """
    # Shift and scale the binary matrix so it spans -1 to +1.
    x = np.asarray(binary_matrix, dtype=float) * 2 - 1
    transformed = np.asarray(pdist_transformation(pdist_matrix), dtype=float)
    # Zero distances (the diagonal, and any identical pair) give an infinite
    # reciprocal; those weights are set to 0, so such pairs do not contribute.
    # The division warning is expected and therefore silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = np.nan_to_num(1 / transformed, posinf=0, neginf=0)
    # (x @ weights)[b, c] = sum_a x[b, a] * w[a, c]; multiplying by x[b, c] and
    # summing over c gives the same double sum as the full three-way product,
    # without allocating an (n_strains, n_strains, n_features) intermediate.
    return ((x @ weights) * x).sum(axis=1)

def null_pairwise_similarity_scores(pdist_matrix, n_permutations=10_000, pdist_transformation=np.sqrt, rng=None):
    """Build permutation null distributions for `pairwise_similarity_scores`.

    A template matrix with one row for every possible number of positive
    strains (0 through nstrains) is scored after its strain columns are put in
    a random order; this is repeated `n_permutations` times.

    Parameters
    ----------
    pdist_matrix : square array or DataFrame
        Pairwise strain distances; only `.shape[0]` is read here, the matrix
        itself is passed through to `pairwise_similarity_scores`.
    n_permutations : int
        Number of random column orders to score.
    pdist_transformation : callable
        Forwarded to `pairwise_similarity_scores`.
    rng : object with a ``permutation(n)`` method, optional
        For example a ``numpy.random.Generator`` or ``RandomState``. Pass a
        seeded generator for reproducible nulls. When omitted, numpy's global
        random state is used, so ``np.random.seed`` still applies.

    Returns
    -------
    pandas.DataFrame
        `n_permutations` rows and ``nstrains + 1`` columns; column ``k`` holds
        the null scores for a feature present in exactly ``k`` strains.
    """
    if rng is None:
        rng = np.random
    nstrains = pdist_matrix.shape[0]
    # Construct a matrix where row i (0-indexed) has exactly i strains with
    # that binary feature (a strictly lower triangular matrix works for this).
    unpermuted_x = np.tri(nstrains + 1, M=nstrains, k=-1)

    # Use column permutations of this matrix to construct the nulls,
    # parameterized by the number of strains carrying the feature.
    null_scores = []
    for _ in range(n_permutations):
        x = unpermuted_x[:, rng.permutation(nstrains)]
        null_scores.append(pairwise_similarity_scores(pdist_matrix, x, pdist_transformation=pdist_transformation))
    return pd.DataFrame(null_scores)

In [None]:
# Mean null score per feature count (columns of the null table) under three
# distance transformations, labelled so the curves can be told apart.
# NOTE(review): within the visible part of the notebook `gdist` is first assigned
# in a later cell -- confirm it is defined by this point on a fresh kernel.
for _label, _transformation in [
    ('sqrt(d)', np.sqrt),
    ('d', lambda x: x),
    ('d**2', lambda x: x**2),
]:
    plt.plot(
        null_pairwise_similarity_scores(gdist, n_permutations=100, pdist_transformation=_transformation).mean(),
        label=_label,
    )
plt.legend(title='pdist_transformation')
None

In [None]:
# Inputs for the permutation null: the strains to analyse and how many there are.
strain_list = high_power_strain_list
nstrains = len(strain_list)
# Result of the genotype's pdist() for the selected strains, kept as a plain array
# (.values); presumably a square strain-by-strain distance matrix.
gdist = refit.genotype.sel(strain=strain_list).pdist().values
n_permutations = 1_000
# Slow if n_permutations is much higher than 10_000
# NOTE(review): no random seed is set in this cell; unless one is set earlier in the
# notebook, this table will differ from run to run.
permutation_scores = null_pairwise_similarity_scores(gdist, n_permutations=n_permutations, pdist_transformation=np.sqrt)

In [None]:
# One density curve per column of the null table, split over two panels and
# coloured by the column label as a fraction of nstrains.
fig, axs = plt.subplots(2)
for i in permutation_scores.columns:
    # Columns up to half of nstrains go on the top panel, the rest on the bottom one.
    ax = axs[0] if i / nstrains <= 0.5 else axs[1]
    sns.kdeplot(permutation_scores[i], c=mpl.cm.coolwarm(i / nstrains), ax=ax)

# for i in permutation_scores.columns:
#     plt.plot([], [], color=mpl.cm.coolwarm(i / nstrains), label=i)
# plt.colorbar()

In [None]:
# Keep genes hit in at least one but not all of the selected strains; the
# per-gene hit count is computed once and tested against both bounds.
gene_list = idxwhere(
    moderate_hit[strain_list]
    .sum(1)
    .pipe(lambda n_hits: (n_hits > 0) & (n_hits < (nstrains)))
)
hits = moderate_hit.loc[gene_list, strain_list]
# One similarity score per retained gene, indexed like the rows of hits.
scores = pd.Series(
    pairwise_similarity_scores(gdist, hits, pdist_transformation=np.sqrt),
    index=hits.index,
)

In [None]:
def permutations_to_pvalues(empirical_scores, key, permutation_scores, decimals=100, test='two-sided'):
    """Convert observed scores into permutation p-values.

    Parameters
    ----------
    empirical_scores : pandas.Series
        Observed score per feature; its index becomes the index of the result.
    key : iterable
        For each observed score, in the same order, the column of
        `permutation_scores` that holds its null distribution.
    permutation_scores : pandas.DataFrame
        One row per permutation, one column per key value.
    decimals : int
        Observed and permuted scores are rounded to this many decimals before
        comparison, so that values differing only by floating-point noise
        count as ties.
    test : {'two-sided', 'lower', 'higher'}
        'two-sided' counts permuted scores with ``|perm| >= |score|``;
        'lower' counts ``perm <= -|score|``; 'higher' counts
        ``perm >= |score|``. The two one-sided counts are the two tails of the
        two-sided count.

    Returns
    -------
    pandas.Series
        ``(count + 1) / (n_permutations + 1)`` for each observed score.

    Raises
    ------
    ValueError
        If `test` is not one of the three accepted values.
    """
    if test not in ('two-sided', 'lower', 'higher'):
        raise ValueError(f"`test` parameter must be one of ['two-sided', 'lower', 'higher'], not '{test}'.")

    n_permutations = permutation_scores.shape[0]

    # NOTE: I may need to round to some number of decimals for numerical reasons.
    empirical_scores = np.round(empirical_scores, decimals=decimals)
    permutation_scores = np.round(permutation_scores, decimals=decimals)

    pvalues = []
    for k, score_empir in zip(key, empirical_scores):
        scores_perm = permutation_scores[k]
        # Got some tips on how to calculate p-values from <https://stats.stackexchange.com/a/25929>
        if test == 'two-sided':
            n_extreme = (np.abs(scores_perm) >= abs(score_empir)).sum()
        elif test == 'lower':
            # Signed comparison: taking abs() of the permuted scores here would
            # compare a non-negative value against a non-positive threshold.
            n_extreme = (scores_perm <= min(score_empir, -score_empir)).sum()
        else:  # test == 'higher'
            # Signed comparison, otherwise this is identical to 'two-sided'.
            n_extreme = (scores_perm >= max(score_empir, -score_empir)).sum()
        # The +1 in numerator and denominator keeps the p-value strictly positive.
        pvalues.append((n_extreme + 1) / (n_permutations + 1))
    return pd.Series(pvalues, empirical_scores.index)
    
# Two-sided permutation p-value per gene. Each gene is compared against the null
# column selected by its row sum in `hits`, after rounding scores to 4 decimals.
pvalues = permutations_to_pvalues(scores, hits.sum(1), permutation_scores, decimals=4, test='two-sided')

# pvalues_old = []
# decimals = 4
# for n, s in tqdm(list(zip(hits.sum(1), scores))):
#     # NOTE: I need to round to some number of decimals for numerical reasons.
#     scores_perm = np.round(permutation_scores[n], decimals=decimals)
#     scores_empir = np.round(s, decimals=decimals)
#     # Got some tips on how to calculate p-values from <https://stats.stackexchange.com/a/25929>
#     # Choose one:
#     # p_left = ((permutation_scores[n] <= min(s, -s)).sum() + 1) / (n_permutations + 1)
#     # p_right = ((permutation_scores[n] >= max(s, -s)).sum() + 1) / (n_permutations + 1)
#     p_ts = ((np.abs(scores_perm) >= abs(scores_empir)).sum() + 1) / (n_permutations + 1)
#     pvalues_old.append(p_ts)
# pvalues_old = pd.Series(pvalues_old, scores.index)
                             


In [None]:
# P-values in ascending order, drawn on a logit-scaled y axis.
plt.plot(np.sort(pvalues.values))
plt.yscale('logit')

In [None]:
# Look up, for every gene, the mean null score of the column matching its row sum in hits.
expected_scores = hits.sum(axis=1).map(permutation_scores.mean(axis=0))

In [None]:
# Observed score minus its null expectation, aligned on the gene index.
adjusted_scores = scores.sub(expected_scores)

In [None]:
# Histogram of the permutation p-values in 100 bins, with a log-scaled count axis.
plt.hist(pvalues, bins=100)
plt.yscale('log')

In [None]:
# Adjusted score (x) against -log10 p-value (y) for every gene; point colour is
# the gene's row sum in `hits`.
plt.scatter(
    adjusted_scores,
    -np.log10(pvalues),
    s=5,
    c=hits.sum(1),
    cmap=mpl.cm.viridis,
    # alpha=0.1
)
plt.colorbar()
# Dashed vertical line at an adjusted score of zero.
plt.axvline(0, linestyle='--', color='k')

In [None]:
# Genes with a permutation p-value below 0.01.
_gene_list = idxwhere(pvalues < 1e-2)
print(len(_gene_list))
x = hits.loc[_gene_list]
s = adjusted_scores.loc[_gene_list]
# Largest absolute adjusted score among the selected genes; used to map scores
# symmetrically onto [0, 1] so that 0.5 corresponds to an adjusted score of zero.
vrange = s.abs().max()
c = (s + vrange) / (2 * vrange)

# Columns follow the genotype linkage of the selected strains; the row colour
# strip encodes the adjusted score.
sns.clustermap(
    x,
    col_linkage=refit.genotype.sel(strain=strain_list).linkage(),
    row_colors=mpl.cm.coolwarm(c),
)

In [None]:
# Sorted p-values against evenly spaced positions on [0, 1], with the identity
# line drawn for comparison.
plt.plot(
    np.linspace(0, 1, num=len(pvalues)),
    np.sort(pvalues.values),
)
plt.plot([0, 1], [0, 1])

In [None]:
# Genes with a permutation p-value above 0.1.
_gene_list = idxwhere(pvalues > 1e-1)
print(len(_gene_list))
x = hits.loc[_gene_list]
s = adjusted_scores.loc[_gene_list]
# Largest absolute adjusted score among the selected genes; used to map scores
# symmetrically onto [0, 1] so that 0.5 corresponds to an adjusted score of zero.
vrange = s.abs().max()
c = (s + vrange) / (2 * vrange)

# Columns follow the genotype linkage of the selected strains; the row colour
# strip encodes the adjusted score.
sns.clustermap(
    x,
    col_linkage=refit.genotype.sel(strain=strain_list).linkage(),
    row_colors=mpl.cm.coolwarm(c),
)

In [None]:
# Repeat the permutation null using the distinct strains listed in best_match_inf_geno.
ref_strain_list = list(best_match_inf_geno.unique())
# NOTE: this rebinds `nstrains`, which an earlier cell set to len(strain_list);
# cells that use it read whichever assignment ran most recently.
nstrains = len(ref_strain_list)
# Sub-table of ref_geno_pdist restricted to those strains (rows and columns).
ref_gdist = ref_geno_pdist.loc[ref_strain_list, ref_strain_list]
n_permutations = 10_000
# Slow if n_permutations is much higher than 10_000
# NOTE(review): no random seed is set in this cell; unless one is set earlier in the
# notebook, this table will differ from run to run.
ref_permutation_scores = null_pairwise_similarity_scores(ref_gdist, n_permutations=n_permutations, pdist_transformation=np.sqrt)

In [None]:
# One density curve per column of the reference null table, on two panels that
# share both axes; curve colour is the column label as a fraction of nstrains.
fig, axs = plt.subplots(2, sharex=True, sharey=True)
for i in ref_permutation_scores.columns:
    # Columns up to half of nstrains go on the top panel, the rest on the bottom one.
    ax = axs[0] if i / nstrains <= 0.5 else axs[1]
    sns.kdeplot(ref_permutation_scores[i], c=mpl.cm.coolwarm(i / nstrains), ax=ax)

# Log-scale the density axis on both panels.
axs[0].set_yscale('log')
axs[1].set_yscale('log')
# for i in permutation_scores.columns:
#     plt.plot([], [], color=mpl.cm.coolwarm(i / nstrains), label=i)
# plt.colorbar()

In [None]:
# NOTE: `hits` is rebound here to the reference_hit sub-table (same genes,
# ref_strain_list columns); later cells that read `hits` see this table.
hits = reference_hit.loc[gene_list, ref_strain_list]
ref_scores = pd.Series(
    pairwise_similarity_scores(ref_gdist, hits, pdist_transformation=np.sqrt),
    index=hits.index,
)
ref_pvalues = permutations_to_pvalues(
    ref_scores,
    hits.sum(axis=1),
    ref_permutation_scores,
    decimals=4,
    test='two-sided',
)

# Null expectation for each gene (mean of the null column matching its row sum)
# and the observed score relative to it.
ref_expected_scores = hits.sum(axis=1).map(ref_permutation_scores.mean(axis=0))
ref_adjusted_scores = ref_scores.sub(ref_expected_scores)

In [None]:
# Reference analysis: adjusted score (x) against -log10 p-value (y) for every
# gene; point colour is the gene's row sum in the current `hits` table.
plt.scatter(
    ref_adjusted_scores,
    -np.log10(ref_pvalues),
    s=5,
    c=hits.sum(1),
    cmap=mpl.cm.viridis,
    # alpha=0.1
)
plt.colorbar()
# Dashed vertical line at an adjusted score of zero.
plt.axvline(0, linestyle='--', color='k')

In [None]:
# Per-gene adjusted scores from the reference analysis (x) against those from
# the earlier analysis on strain_list (y).
plt.scatter(ref_adjusted_scores, adjusted_scores, s=2, alpha=0.5)
# Spearman rank correlation between the two score vectors; shown as the cell output.
sp.stats.spearmanr(ref_adjusted_scores, adjusted_scores)

In [None]:
# Genes with a reference-analysis permutation p-value below 0.01.
_gene_list = idxwhere(ref_pvalues < 1e-2)
print(len(_gene_list))
x = hits.loc[_gene_list]
s = ref_adjusted_scores.loc[_gene_list]
# Largest absolute adjusted score among the selected genes; used to map scores
# symmetrically onto [0, 1] so that 0.5 corresponds to an adjusted score of zero.
vrange = s.abs().max()
c = (s + vrange) / (2 * vrange)

# Columns follow the genotype linkage of the reference strains; the row colour
# strip encodes the adjusted score.
sns.clustermap(
    x,
    col_linkage=ref_geno.sel(strain=ref_strain_list).genotype.linkage(),
    row_colors=mpl.cm.coolwarm(c),
)

In [None]:
# Entries with p < 0.01 in each of the two analyses, collected as sets
# (presumably gene ids, as returned by idxwhere).
# NOTE: `hits` is rebound here from a table to a set, so earlier cells that use
# `hits` as a table cannot be re-run after this one without rebuilding it first.
hits = set(idxwhere(pvalues < 1e-2))
ref_hits = set(idxwhere(ref_pvalues < 1e-2))

# Sizes of the two sets and of their intersection; shown as the cell output.
len(hits), len(ref_hits), len(hits & ref_hits)

In [None]:
# Display the gene_annotation table (not assigned in this part of the notebook).
gene_annotation

In [None]:
# Annotation rows for entries present in both `hits` and `ref_hits`. The ids are
# sorted first because set iteration order is arbitrary, and head(50) would
# otherwise show a different subset from one run to the next.
gene_annotation.loc[sorted(hits & ref_hits)].head(50)