# Preamble

In [None]:
%load_ext autoreload

In [None]:
import os

# The cells below use paths relative to the notebook's parent directory
# (meta/, data/, ref/, lib/), so move there once.
# Guarded with a sentinel so that re-running this cell in a live kernel does not
# climb one more directory each time.
if '_PROJECT_ROOT' not in globals():
    os.chdir('..')
    _PROJECT_ROOT = os.path.realpath(os.path.curdir)
os.path.realpath(os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, repeated
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time

def align_indexes(*args):
    """Restrict several pandas objects to the index labels they all share.

    Parameters
    ----------
    *args
        One or more objects exposing ``.index`` and ``.reindex``
        (e.g. ``pd.Series`` or ``pd.DataFrame``).

    Returns
    -------
    list
        One reindexed object per input, in input order. All outputs carry the
        same labels in the same order: the order in which the shared labels
        first appear in the first argument.

    Raises
    ------
    AssertionError
        If the inputs have no index label in common.
    """
    first, *rest = args
    shared = set(first.index)
    for other in rest:
        shared &= set(other.index)

    assert shared, 'align_indexes: inputs share no index labels'

    # Reindexing by a bare set yields an arbitrary, hash-dependent row order that
    # can differ between kernel sessions. Keep the first argument's order instead
    # (dict.fromkeys drops repeated labels while preserving first appearance).
    ordered = [label for label in dict.fromkeys(first.index) if label in shared]
    return [a.reindex(ordered) for a in args]

# Set Parameters

In [None]:
# Focal species ID. Kept as a string: it is interpolated into file paths and
# used as a (string) column label of the species depth table.
species_id = '102506'
# Read length assumed when converting per-gene read counts into depth.
assumed_read_length = 150

# Load Data

## Metadata

### Samples

In [None]:
# Load the HMP2 metadata tables, each indexed by its own ID column.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
visit = pd.read_table('meta/hmp2/visit.tsv', index_col='visit_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

# Walk library -> preparation -> stool -> visit -> subject so that every
# metagenomic library row carries the metadata of all of its parent records.
mgen_meta = (
    mgen
    # library_type is dropped from preparation before joining.
    .join(preparation.drop(columns='library_type'), on='preparation_id')
    .join(stool, on='stool_id')
    # Columns of visit whose names clash with existing ones get a trailing '_'.
    .join(visit, on='visit_id', rsuffix='_')
    .join(subject, on='subject_id')
)

# Every library must resolve to a subject.
assert not any(mgen_meta.subject_id.isna())

# meta.columns

In [None]:
# One row per (subject, week). Where a subject has several visits in the same
# week, keep the visit row with the most non-missing fields.
_subject_week = (
    visit
    .join(subject, on='subject_id')
    .reset_index()
    .dropna(subset=['subject_id', 'week_number'])
    .groupby(['subject_id', 'week_number'])
    # Rows are sorted by their count of non-null values; the last label has the most.
    .apply(lambda d: d.loc[d.notna().sum(1).sort_values().index[-1]])
    # Key of the form '<subject_id>_<week_number>'.
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    # Attach the mean fecal calprotectin over the stool records of the kept visit.
    .join(stool.groupby('visit_id').fecal_calprotectin.mean(), on='visit_id')
    .sort_values(['subject_id', 'week_number'])
)

# Map each metagenomic library to its subject-week key ('<subject_id>_<week_number>').
# Built column-wise, the same way as the subject_week_id index of _subject_week,
# instead of with a slow row-by-row apply (which also returns a DataFrame rather
# than a Series when there are no rows).
mgen_to_subject_week = (
    mgen_meta
    .dropna(subset=['week_number'])
    .pipe(lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .rename('subject_week_id')
)
mgen_to_subject_week
#.groupby(['subject_id', 'week_number']).visit_id.count().sort_values(ascending=False)

### Genes

In [None]:
gene_clusters = pd.read_table(f'ref_temp/midasdb_uhgg/pangenomes/{species_id}/cluster_info.txt', index_col='centroid_99')

## Raw data needing alignment

### Species

In [None]:
# Per-library species depth in long format -> libraries x species matrix
# (missing combinations filled with 0), then summed over the libraries that map
# to the same subject-week.
_species_depth = (
    pd.read_table('data/hmp2.a.r.proc.gtpro.species_depth.tsv', index_col=['sample', 'species_id'])
    .squeeze()
    .unstack('species_id', fill_value=0)
    .groupby(mgen_to_subject_week)
    .sum()
)
# String column labels, so that columns can be selected with the string species_id.
_species_depth.columns = _species_depth.columns.astype(str)

### Strains

In [None]:
# Per-library strain depth in long format -> libraries x strains matrix, summed
# per subject-week in the same way as the species depth table.
_strain_depth = pd.read_table(
    'data_temp/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts18-s75-seed0.strain_depth.tsv',
    # names=['library_id', 'species_strain_id', 'depth'],
    index_col=['sample', 'strain'],
).squeeze().unstack('strain', fill_value=0).groupby(mgen_to_subject_week).sum()

# Sanity check: per-sample difference between total strain depth and total species depth.
plt.hist(_strain_depth.sum(1) - _species_depth.sum(1), bins=50)
None

### Genes

In [None]:
# Decompress the per-gene read-count table to a temporary file, then load it.
_inpath = f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_merge/genes/{species_id}/{species_id}.genes_reads.tsv.lz4'
_fd, tmp_path = mkstemp()

# os.fdopen takes ownership of the descriptor returned by mkstemp, so it is
# closed when the with-block exits. subprocess.run blocks until lz4cat has
# finished writing, and check=True raises CalledProcessError if decompression
# fails, so a partially written table is never read.
with os.fdopen(_fd, 'wb') as f:
    subprocess.run(('lz4cat', _inpath), stdout=f, check=True)
print(_inpath)
# NOTE: the temporary file is left on disk; its path is printed for inspection.
print(tmp_path)

# Sum read counts over libraries of the same subject-week; result is samples x genes.
_genes_reads = pd.read_table(tmp_path, index_col='gene_id').groupby(mgen_to_subject_week, axis='columns').sum().T

In [None]:
# Approximate per-gene depth: read count * assumed read length / gene length.
_genes_depth = _genes_reads * assumed_read_length / gene_clusters.centroid_99_length.loc[_genes_reads.columns]

### Final data alignment

In [None]:
species_depth, strain_depth, genes_depth = align_indexes(_species_depth, _strain_depth, _genes_depth)

In [None]:
subject_week = _subject_week.assign(has_mgen=lambda x: x.index.isin(species_depth.index))

## Other raw data

### Taxonomy

In [None]:
# Taxonomy strings keyed by (string) species ID, with a temporary column holding
# the ';'-separated fields of each string.
species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')
    [['taxonomy_string']]
    .assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))
)

# One column per rank prefix, taken from positions 1..6 of the split string.
for level_number, level_name in enumerate(['p__', 'c__', 'o__', 'f__', 'g__', 's__'], start=1):
    _rank_values = species_taxonomy.taxonomy_split.apply(lambda fields: fields[level_number])
    species_taxonomy = species_taxonomy.assign(**{level_name: _rank_values})
species_taxonomy = species_taxonomy.drop(columns=['taxonomy_split'])

# Strain labels have the form '<species_id>-<suffix>'; attach each strain's species taxonomy.
strain_taxonomy = (
    strain_depth.columns.to_series()
    .str.split('-').str[0]
    .to_frame(name='species_id')
    .join(species_taxonomy, on='species_id')
)

# Keep only the species that have at least one strain in the strain depth table.
species_taxonomy = (
    strain_taxonomy
    .drop_duplicates(subset=['species_id'])
    .set_index('species_id')
)

## Finalize data prep

### Species

In [None]:
species_rabund = species_depth.divide(species_depth.sum(1), axis=0)
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)

In [None]:
species_depth[species_id].sort_values()

### Genes

In [None]:
genes_depth_75 = genes_depth.groupby(gene_clusters.centroid_75, axis='columns').sum()

In [None]:
mean_marker_depth = genes_depth.groupby(gene_clusters.marker_id, axis='columns').sum().mean(1)

In [None]:
(species_depth[species_id] == 0).sum()

In [None]:
plt.scatter(species_depth[species_id], mean_marker_depth)
plt.yscale('log')
plt.xscale('log')
plt.plot([0, 1e2], [0, 1e2])

# Focal Species

In [None]:
focal_species_strain_depth = strain_depth.loc[:, strain_taxonomy.species_id == species_id]
focal_species_strain_rabund = focal_species_strain_depth.divide(focal_species_strain_depth.sum(1), axis=0)

In [None]:
sns.clustermap(focal_species_strain_depth, norm=mpl.colors.PowerNorm(1/10))

In [None]:
# A strain must exceed this within-species relative abundance for a sample to count as "pure".
strain_rabund_thresh = 0.95
# Depth above which the focal species/strain is treated as present in a sample ...
focal_sample_species_depth_thresh_pres = 0.5
# ... and depth below which the focal species is treated as absent.
focal_sample_species_depth_thresh_abs = 1e-3
# NOTE(review): idxwhere comes from lib.pandas_util; presumably it returns the
# index labels where the boolean Series is True — confirm.
pure_samples = idxwhere((focal_species_strain_rabund > strain_rabund_thresh).any(1))

# samples x strains mask: strain is both deep enough and dominant in the sample.
d0 = (
    (focal_species_strain_depth > focal_sample_species_depth_thresh_pres)
    & (focal_species_strain_rabund > strain_rabund_thresh)
)
# NOTE(review): num_samples / num_subjects are not read again in this cell; the
# same quantities are recomputed for strain_sample_counts just below.
num_samples = d0.sum().sort_values(ascending=False)
num_subjects = d0.groupby(subject_week.subject_id).any().sum().sort_values(ascending=False).head()

# Per strain: number of qualifying samples and number of subjects with at least one.
strain_sample_counts = pd.DataFrame(dict(
    num_samples=d0.sum(),
    num_subjects=d0.groupby(subject_week.subject_id).any().sum(),
)).sort_values('num_samples', ascending=False)

# The 20 strains with the most qualifying samples.
top_strains = strain_sample_counts.head(20).index

strain_sample_counts.head(8)

## Species Genes

### Select high-coverage and no-coverage samples

In [None]:
species_samples = idxwhere(species_depth[species_id] > focal_sample_species_depth_thresh_pres)
no_species_samples = idxwhere(species_depth[species_id] < focal_sample_species_depth_thresh_abs)
len(species_samples), len(no_species_samples)

### Comparison of GT-Pro and MIDAS depth estimates

In [None]:
genes_depth.loc[no_species_samples]

In [None]:
d = pd.DataFrame(dict(species=species_depth[species_id], genes=mean_marker_depth))
plt.scatter('species', 'genes', data=d, alpha=0.4)
plt.scatter('species', 'genes', data=d.loc[species_samples], s=2)
plt.scatter('species', 'genes', data=d.loc[no_species_samples], s=2)
plt.plot([0, 1e2], [0, 1e2])
plt.yscale('symlog', linthresh=focal_sample_species_depth_thresh_abs)
plt.xscale('symlog', linthresh=focal_sample_species_depth_thresh_abs)

### Identify species core genes

In [None]:
# Label each pure sample with its dominant strain (idxmax of within-species
# relative abundance) and sum depths over the samples sharing a label.
# NOTE(review): samples outside pure_samples have no group key and are presumably
# dropped by groupby — confirm.
gene_depth_strain_total = genes_depth.groupby(focal_species_strain_rabund.loc[pure_samples].idxmax(1)).sum()
gene_depth_75_strain_total = genes_depth_75.groupby(focal_species_strain_rabund.loc[pure_samples].idxmax(1)).sum()
species_depth_strain_total = species_depth.groupby(focal_species_strain_rabund.loc[pure_samples].idxmax(1)).sum()

In [None]:
mean_marker_depth_strain_total = gene_depth_strain_total.groupby(gene_clusters.marker_id, axis='columns').sum().mean(1)

In [None]:
d = pd.DataFrame(dict(species=species_depth_strain_total[species_id], genes=mean_marker_depth_strain_total))
plt.scatter('species', 'genes', data=d, alpha=0.4)
plt.plot([1e-2, 1e2], [1e-2, 1e2])
plt.yscale('symlog', linthresh=focal_sample_species_depth_thresh_abs)
plt.xscale('symlog', linthresh=focal_sample_species_depth_thresh_abs)

In [None]:
# TODO: Decide whether I want to work with the strain-dereplicated values

# x: focal species depth (samples x 1); y: depth per centroid_75 gene cluster.
x = species_depth[[species_id]]
y = genes_depth_75

# Cube-root transform applied to depths before they are compared.
_transf = np.cbrt

# Cosine similarity (1 - cosine distance) across samples between the species
# depth profile and each gene cluster's depth profile.
# Rows: species (only the focal one); columns: gene clusters.
gene_species_corr = pd.DataFrame(1 - sp.spatial.distance.cdist(_transf(x.T), _transf(y.T), metric='cosine'), index=x.columns, columns=y.columns)

In [None]:
species_corr_prefilt_thresh = 0.9
species_gene_hits = idxwhere(gene_species_corr.loc[species_id] > species_corr_prefilt_thresh)
len(species_gene_hits)

In [None]:
bins = np.linspace(0, 1, num=101)
plt.hist(gene_species_corr.loc[species_id], bins=bins)
plt.hist(gene_species_corr.loc[species_id, gene_species_corr.idxmax() == species_id], bins=bins)
plt.hist(gene_species_corr.loc[species_id].reindex(gene_clusters.marker_id.dropna().index, axis='columns').dropna(), bins=bins)
plt.axvline(species_corr_prefilt_thresh, lw=1, linestyle='--', color='k')

plt.yscale('log')
None

## Why do many genes in the focal species pangenome correlate with other species as well?

Is it because the two species are themselves correlated?

Yeah, maybe...

### How does the distribution of depth-correlation compare to a null distribution?

In [None]:
# TODO: Decide whether I want to work with the strain-dereplicated values
x = species_depth[[species_id]]
y = genes_depth_75

# Null model: shuffle the species depths across samples by assigning the sample
# labels in random order and then restoring the original label order.
# A seeded generator keeps the null distribution identical across kernel restarts.
_rng = np.random.RandomState(0)
x_permute = x.copy()
x_permute.index = _rng.choice(x.index, size=len(x.index), replace=False)
x_permute = x_permute.loc[x.index]

_transf = np.cbrt

# Same statistic as gene_species_corr, computed on the shuffled species depths.
gene_species_corr_permute = pd.DataFrame(1 - sp.spatial.distance.cdist(_transf(x_permute.T), _transf(y.T), metric='cosine'), index=x.columns, columns=y.columns)

In [None]:
d = pd.DataFrame(dict(
    closest_species=gene_species_corr_permute.idxmax().value_counts(),
    hit_species=(gene_species_corr_permute > species_corr_prefilt_thresh).sum(1),
)).join(species_taxonomy[['f__', 'g__']]).sort_values('hit_species', ascending=False).head(5)

d

In [None]:
bins = np.linspace(0, 1, num=101)
plt.hist(gene_species_corr_permute.loc[species_id], bins=bins)
plt.hist(gene_species_corr_permute.loc[species_id, gene_species_corr_permute.idxmax() == species_id], bins=bins)
plt.hist(gene_species_corr_permute.loc[species_id].reindex(gene_clusters.marker_id.dropna().index, axis='columns').dropna(), bins=bins)
plt.axvline(species_corr_prefilt_thresh, lw=1, linestyle='--', color='k')

plt.yscale('log')
None

### Assess core gene hits

In [None]:
d = gene_clusters.loc[species_gene_hits].marker_id.dropna().value_counts()
print(len(d))
d

In [None]:
sns.clustermap(genes_depth_75[species_gene_hits].loc[species_samples], norm=mpl.colors.SymLogNorm(linthresh=0.5), metric='cosine')

In [None]:
# Per-sample species depth estimated from the species gene hits: a trimmed mean
# across those gene clusters, cutting 20% off each tail.
mean_depth_species_genes = genes_depth_75[species_gene_hits].apply(sp.stats.trim_mean, proportiontocut=0.2, axis=1)
mean_depth_species_genes.sort_values(ascending=False).head(10)

In [None]:
d = pd.DataFrame(dict(species=species_depth[species_id], genes=mean_depth_species_genes))
plt.scatter('species', 'genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
plt.yscale('symlog', linthresh=1e-2)
plt.xscale('symlog', linthresh=1e-2)

#### Why are species depths lower when estimated from genes?

I believe this may be happening due to read stealing by similar genes.

In [None]:
d = pd.DataFrame(dict(species=species_depth[species_id], genes=mean_marker_depth))
plt.scatter('species', 'genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
plt.yscale('symlog')
plt.xscale('symlog')

In [None]:
d = pd.DataFrame(dict(species=species_depth[species_id], genes=mean_depth_species_genes))
plt.scatter('species', 'genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
# plt.yscale('symlog')
# plt.xscale('symlog')

In [None]:
# Gene depth relative to the per-sample species depth estimated from species genes,
# at centroid_75-cluster resolution and at individual-gene resolution.
depth_ratio_75 = genes_depth_75.divide(mean_depth_species_genes, axis=0)
depth_ratio = genes_depth.divide(mean_depth_species_genes, axis=0)

## Strain Genes

In [None]:
strain_rank = 0
strain_id = top_strains[strain_rank]
strain_id

### Select focal samples

In [None]:
# Samples where the species is present, the focal strain is the deepest strain,
# and one strain dominates (pure).
pure_samples_with_strain = list(set(species_samples) & set(idxwhere((focal_species_strain_depth.idxmax(1) == strain_id))) & set(pure_samples))
# Focal set: those samples plus the samples where the species is absent.
focal_samples = list(set(pure_samples_with_strain) | set(no_species_samples))
len(pure_samples_with_strain), len(focal_samples)

### Depth estimate comparison

In [None]:
d = pd.DataFrame(dict(species=species_depth[species_id], genes=mean_depth_species_genes))
plt.scatter('species', 'genes', data=d.loc[focal_samples])
plt.plot([0, 1e2], [0, 1e2])
plt.yscale('symlog', linthresh=1e-2)
plt.xscale('symlog', linthresh=1e-2)

### Identify depth correlated genes in focal samples

In [None]:
# FIXME: Decide if I want to use species depth or species_gene depth.
# x: gene-based species depth in the focal samples, as a one-column frame named by species_id.
x = mean_depth_species_genes.loc[focal_samples].to_frame(species_id)
# x = species_depth.loc[focal_samples, [species_id]]
# y: depth of every individual gene in the focal samples.
y = genes_depth.loc[focal_samples]

_transf = np.cbrt

# 1 - correlation distance between the cube-root species depth and each gene's
# cube-root depth across the focal samples.
# NOTE(review): this uses metric='correlation', whereas the species-level
# gene_species_corr uses metric='cosine' — confirm the difference is intended.
gene_strain_corr = pd.DataFrame(1 - sp.spatial.distance.cdist(_transf(x.T), _transf(y.T), metric='correlation'), index=x.columns, columns=y.columns)

In [None]:
strain_corr_prefilt_thresh = 0.7
strain_gene_maybe_hits = idxwhere(gene_strain_corr.loc[species_id] > strain_corr_prefilt_thresh)
len(strain_gene_maybe_hits)

In [None]:
# Per row of gene_strain_corr: how many genes are most similar to it, and how
# many exceed the prefilter threshold; annotated with family and genus.
# gene_strain_corr is a similarity (higher = closer), so the closest row is the
# idxmax, matching the species-level summary table.
pd.DataFrame(dict(
    closest_species=gene_strain_corr.idxmax().value_counts(),
    hit_species=(gene_strain_corr > strain_corr_prefilt_thresh).sum(1),
)).join(species_taxonomy[['f__', 'g__']]).sort_values('hit_species', ascending=False).head(5)

In [None]:
bins = np.linspace(0, 1, num=101)
plt.hist(gene_strain_corr.loc[species_id], bins=bins)
plt.hist(gene_strain_corr.loc[species_id, gene_strain_corr.idxmax() == species_id], bins=bins)
plt.hist(gene_strain_corr.loc[species_id].reindex(gene_clusters.marker_id.dropna().index, axis='columns').dropna(), bins=bins)
plt.axvline(strain_corr_prefilt_thresh, lw=1, linestyle='--', color='k')

plt.yscale('log')
None

### How does the distribution of depth-correlation compare to a null distribution?

In [None]:
# FIXME: Decide if I want to use species depth or species_gene depth.
x = mean_depth_species_genes.loc[focal_samples].to_frame(species_id)
# x = species_depth.loc[focal_samples, [species_id]]
y = genes_depth.loc[focal_samples]

# Null model: shuffle the species depths across the focal samples by assigning
# the sample labels in random order and then restoring the original label order.
# A seeded generator keeps the null distribution identical across kernel restarts.
_rng = np.random.RandomState(0)
x_permute = x.copy()
x_permute.index = _rng.choice(x.index, size=len(x.index), replace=False)
x_permute = x_permute.loc[x.index]

# Use the same transform and the same metric ('correlation') as the observed
# gene_strain_corr; a null built with a different statistic is not comparable
# against the same threshold.
_transf = np.cbrt

gene_strain_corr_permute = pd.DataFrame(1 - sp.spatial.distance.cdist(_transf(x_permute.T), _transf(y.T), metric='correlation'), index=x.columns, columns=y.columns)

In [None]:
bins = np.linspace(0, 1, num=101)
plt.hist(gene_strain_corr_permute.loc[species_id], bins=bins)
plt.hist(gene_strain_corr_permute.loc[species_id, gene_strain_corr_permute.idxmax() == species_id], bins=bins)
plt.hist(gene_strain_corr_permute.loc[species_id].reindex(gene_clusters.marker_id.dropna().index, axis='columns').dropna(), bins=bins)
plt.axvline(strain_corr_prefilt_thresh, lw=1, linestyle='--', color='k')


plt.yscale('log')
None

### Assess depth-correlated genes

In [None]:
d = gene_clusters.loc[strain_gene_maybe_hits].marker_id.dropna().value_counts()
print(len(d))
d

In [None]:
d = genes_depth.loc[pure_samples_with_strain, strain_gene_maybe_hits]

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine'
)

In [None]:
d = genes_depth.loc[species_samples, strain_gene_maybe_hits]
e0 = np.log(focal_species_strain_depth[strain_id] + 1e-3)
e1 = ((e0 - e0.min()) / (e0.max() - e0.min()))

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
    row_colors=pd.DataFrame(dict(
        focal=d.index.to_series().isin(pure_samples_with_strain).astype(float).map(mpl.cm.viridis),
        depth=e1.map(mpl.cm.viridis),
        rabund=focal_species_strain_rabund[strain_id].map(mpl.cm.viridis),
    )),
)

### Compare gene depth-ratio estimates

In [None]:
# Depth of each gene summed over the pure samples of the focal strain.
total_strain_gene_depth = genes_depth.loc[pure_samples_with_strain].sum()
# Normalize by a reference depth: the candidate (depth-correlated) genes are
# summed within each centroid_75 cluster, and the mean of those cluster totals
# is used as the denominator.
strain_mean_depth_ratio = total_strain_gene_depth / total_strain_gene_depth[strain_gene_maybe_hits].groupby(gene_clusters.centroid_75).sum().mean()

In [None]:
d0 = np.log2(strain_mean_depth_ratio)
d1 = np.log2(depth_ratio.loc[pure_samples_with_strain].mean())

bins = np.linspace(-10, 5, num=51)

d2 = pd.DataFrame(dict(mean_depth_ratio=d1, wmean_depth_ratio=d0)).assign(hit=lambda x: x.index.isin(strain_gene_maybe_hits)).sort_values('hit')

g = sns.jointplot(data=d2, x='wmean_depth_ratio', y='mean_depth_ratio', hue='hit')
g.ax_joint.plot([-6, 6], [-6, 6], color='k')
# plt.hist(d0, bins=bins)
# plt.hist(d0.reindex(idxwhere(gene_clusters.marker_id.notna())).dropna(), bins=bins)

# plt.hist(d1, bins=bins)
# plt.hist(d1.reindex(idxwhere(gene_clusters.marker_id.notna())).dropna(), bins=bins)
# plt.yscale('log')
None

#### Why does a weighted mean result in a higher estimate?

I believe there is a bias due to truncating the coverage when few reads map to a gene.

## Pick strain genes

In [None]:
# For each centroid_75 cluster, the member gene with the highest strain
# correlation (missing correlations count as 0).
best_strain_gene_match = gene_strain_corr.loc[species_id].fillna(0).groupby(gene_clusters.centroid_75).idxmax()
# The mapping must be one-to-one: one best gene per cluster, one cluster per best gene.
assert best_strain_gene_match.is_unique and best_strain_gene_match.index.is_unique
# Invert it: Series indexed by the best gene (centroid_99) giving its centroid_75 cluster.
strain_to_species_gene = best_strain_gene_match.to_frame('centroid_99').reset_index().set_index('centroid_99', drop=False)['centroid_75']
strain_to_species_gene

In [None]:
# Thresholds used to call a gene a "hit" for the focal strain.
depth_ratio_thresh = 0.05
strain_corr_thresh = 0.7
species_corr_thresh = 0.7

# One row per gene, combining depth ratio, correlations and cluster annotations.
strain_gene_info = (
    strain_mean_depth_ratio
    .to_frame('depth_ratio')
    .assign(
        marker_id=gene_clusters.marker_id,
        strain_corr=gene_strain_corr.loc[species_id].fillna(0),
        # species_corr=gene_species_corr.loc[species_id].fillna(0),
        # Only the best-matching gene of each centroid_75 cluster gets a value here.
        species_gene_id=strain_to_species_gene,
    )
    # species_corr is looked up through species_gene_id, so genes that are not
    # their cluster's best match get NaN, replaced by 0 on the next line.
    .join(gene_species_corr.loc[species_id].to_frame('species_corr'), on='species_gene_id')
    .assign(species_corr=lambda x: x.species_corr.fillna(0))
    # Hit: sufficient depth ratio AND correlated with either the strain or the species.
    .assign(hit=lambda x: (x.depth_ratio > depth_ratio_thresh) & ((x.strain_corr > strain_corr_thresh) | (x.species_corr > species_corr_thresh)))
    .join(gene_clusters.drop(columns=['marker_id']))
    # _species_corr: cluster-level species correlation for every gene, via its centroid_75.
    .join(gene_species_corr.loc[species_id].to_frame('_species_corr'), on='centroid_75')
    # log2 of (1 - species correlation) / (1 - strain correlation): positive when
    # the gene tracks the strain more closely than its cluster tracks the species.
    .assign(strain_species_log_ratio=lambda x: np.log2((1 - x._species_corr) / (1 - x.strain_corr)))
)

strain_gene_info[strain_gene_info.hit]

In [None]:
bins = np.linspace(-4, 4, num=101)
plt.hist(strain_gene_info.strain_species_log_ratio, bins=bins)
plt.hist(strain_gene_info[strain_gene_info.hit].strain_species_log_ratio, bins=bins)
plt.yscale('log')
None

## Assess strain genes

### How often would we miss "species genes" in the strain-specific analysis?

In [None]:
d = strain_gene_info.sort_values('depth_ratio')

fig = plt.figure(figsize=(15, 10))
# Hit genes, coloured by depth ratio on a log scale.
plt.scatter('strain_corr', 'species_corr', data=d[d.hit], c='depth_ratio', s=5, norm=mpl.colors.LogNorm())
# Density contours from a subsample of all genes. The sample size is capped at
# len(d) so that a table with fewer than 1000 rows does not raise, and the draw
# is seeded so that the figure is reproducible.
sns.kdeplot(x='strain_corr', y='species_corr', data=d.sample(n=min(1000, len(d)), random_state=0), color='black', alpha=1.0, linewidths=1, zorder=1)


plt.axvline(strain_corr_thresh, lw=1, linestyle='--', color='k')
plt.axhline(species_corr_thresh, lw=1, linestyle='--', color='k')
# plt.yscale('log')
# plt.xscale('log')
plt.colorbar()
print('Fraction of hits uncorrelated with strain:', (d[d.hit].strain_corr < strain_corr_thresh).mean())
print('Fraction of hits uncorrelated with species:', (d[d.hit].species_corr < species_corr_thresh).mean())
print('Fraction of hits correlated with strain AND species:', ((d[d.hit].species_corr > species_corr_thresh) & (d[d.hit].strain_corr > strain_corr_thresh)).mean())


# Per marker: summed depth ratio and number of hit genes.
d1 = d[d.hit].groupby(gene_clusters.marker_id)[['depth_ratio', 'hit']].sum()
print(d[d.hit].depth_ratio.sum().round(), d.hit.sum(), d1.depth_ratio.sum().round(), d1.shape[0])
print(d1.sort_values('depth_ratio', ascending=False))

In [None]:
# Per marker gene: summed depth ratio and number of hit genes (0 for markers
# without any hit), plus the median length of the genes annotated with that
# marker, ordered by summed depth ratio.
d = (
    strain_gene_info
    [strain_gene_info.hit]
    .groupby(gene_clusters.marker_id)
    [['depth_ratio', 'hit']]
    .sum()
    .reindex(gene_clusters.marker_id.dropna().unique())
    .fillna(0)
    .join(gene_clusters.groupby('marker_id').centroid_99_length.median())
    .sort_values('depth_ratio')
    .reset_index()
)
# A bare expression in the middle of a cell is not displayed.
d
plt.plot(d.index, d.depth_ratio)
# Point colour: number of hit genes; point size: median gene length / 10.
plt.scatter(d.index, d.depth_ratio, c=d.hit, s=d.centroid_99_length / 10)
# Reference line at a summed depth ratio of 1.
plt.axhline(1, lw=1, linestyle='--', color='k')
plt.xticks(range(len(d.index)), d.marker_id, rotation=45, ha='right')
None

### Why multiple hits to the same marker genes? ("contamination" / "redundancy")

TODO: Update this now that I'm better correcting for read-stealing.

In [None]:
marker_id = 'B000062'

# All genes annotated with this marker, most strain-correlated first.
d0 = strain_gene_info[strain_gene_info.marker_id == marker_id].sort_values('strain_corr', ascending=False)
gene_hits = idxwhere(d0.hit)
# Depth of each hit gene, their per-sample sum ('total'), and the gene-based
# species depth ('species').
d1 = genes_depth[gene_hits].assign(total=lambda x: x.sum(1)).assign(species=mean_depth_species_genes)

_transf = np.cbrt
_ncol = 5
npanels = len(gene_hits) + 1  # one panel per hit gene, plus one for 'total'
nrow = int(np.ceil(npanels / _ncol))
ncol = min(_ncol, npanels)

linthresh = 1e-1
xx = np.logspace(-3, 2, num=1000)

# squeeze=False always returns a 2-D array of Axes, so .flatten() also works
# when the grid is 1x1 (a marker without hit genes leaves only the 'total' panel).
fig, axs = plt.subplots(nrow, ncol, figsize=(3 * ncol, 3 * nrow), sharex=True, sharey=True, squeeze=False)
for gene_id, ax in zip(gene_hits + ['total'], axs.flatten()):
    ax.scatter('species', gene_id, data=d1, s=2)
    # Cosine similarity between cube-root species depth and cube-root gene depth
    # over the focal samples; shown in the panel title.
    corr = 1 - sp.spatial.distance.cdist(
        _transf(species_depth.loc[focal_samples, [species_id]].T),
        _transf(d1.loc[focal_samples, [gene_id]].T),
        metric='cosine'
    )[0, 0]
    if gene_id != 'total':
        _depth_ratio = d0.loc[gene_id].depth_ratio
    else:
        _depth_ratio = 1.0
    # Dashed: identity line. Solid: identity scaled by the depth ratio.
    ax.plot(xx, xx, lw=1, linestyle='--', color='k')
    ax.plot(xx, xx * _depth_ratio, lw=1, linestyle='-', color='k')
    ax.set_title(f'{gene_id} ({corr:0.2f})', fontdict=dict(fontsize=8))
    ax.set_aspect('equal')
# The panels share both axes, so setting the scale on the last one is enough.
ax.set_yscale('symlog', linthresh=linthresh)
ax.set_xscale('symlog', linthresh=linthresh)
d0.head(len(gene_hits) + 3)

In [None]:
marker_id = 'B000081'

# All genes annotated with this marker, most strain-correlated first.
d0 = strain_gene_info[strain_gene_info.marker_id == marker_id].sort_values('strain_corr', ascending=False)
gene_hits = idxwhere(d0.hit)
# Depth of each hit gene, their per-sample sum ('total'), and the gene-based
# species depth ('species').
d1 = genes_depth[gene_hits].assign(total=lambda x: x.sum(1)).assign(species=mean_depth_species_genes)

_transf = np.cbrt
_ncol = 5
npanels = len(gene_hits) + 1  # one panel per hit gene, plus one for 'total'
nrow = int(np.ceil(npanels / _ncol))
ncol = min(_ncol, npanels)

linthresh = 1e-1
xx = np.logspace(-3, 2, num=1000)

# squeeze=False always returns a 2-D array of Axes, so .flatten() also works
# when the grid is 1x1 (a marker without hit genes leaves only the 'total' panel).
fig, axs = plt.subplots(nrow, ncol, figsize=(3 * ncol, 3 * nrow), sharex=True, sharey=True, squeeze=False)
for gene_id, ax in zip(gene_hits + ['total'], axs.flatten()):
    ax.scatter('species', gene_id, data=d1, s=2)
    # Cosine similarity between cube-root species depth and cube-root gene depth
    # over the focal samples; shown in the panel title.
    corr = 1 - sp.spatial.distance.cdist(
        _transf(species_depth.loc[focal_samples, [species_id]].T),
        _transf(d1.loc[focal_samples, [gene_id]].T),
        metric='cosine'
    )[0, 0]
    if gene_id != 'total':
        _depth_ratio = d0.loc[gene_id].depth_ratio
    else:
        # For the 'total' panel use the summed depth ratio of all hit genes.
        _depth_ratio = d0[d0.hit].depth_ratio.sum()
    # Dashed: identity line. Solid: identity scaled by the depth ratio.
    ax.plot(xx, xx, lw=1, linestyle='--', color='k')
    ax.plot(xx, xx * _depth_ratio, lw=1, linestyle='-', color='k')
    ax.set_title(f'{gene_id} ({corr:0.2f})', fontdict=dict(fontsize=8))
    ax.set_aspect('equal')

# The panels share both axes, so setting the scale on the last one is enough.
ax.set_yscale('symlog', linthresh=linthresh)
ax.set_xscale('symlog', linthresh=linthresh)
d0.head(len(gene_hits) + 3)

In [None]:
marker_id = 'B000103'

# All genes annotated with this marker, most strain-correlated first.
d0 = strain_gene_info[strain_gene_info.marker_id == marker_id].sort_values('strain_corr', ascending=False)
gene_hits = idxwhere(d0.hit)
# Depth of each hit gene, their per-sample sum ('total'), and the gene-based
# species depth ('species').
d1 = genes_depth[gene_hits].assign(total=lambda x: x.sum(1)).assign(species=mean_depth_species_genes)

_transf = np.cbrt
_ncol = 5
npanels = len(gene_hits) + 1  # one panel per hit gene, plus one for 'total'
nrow = int(np.ceil(npanels / _ncol))
ncol = min(_ncol, npanels)

linthresh = 1e-1
xx = np.logspace(-3, 2, num=1000)

# squeeze=False always returns a 2-D array of Axes, so .flatten() also works
# when the grid is 1x1 (a marker without hit genes leaves only the 'total' panel).
fig, axs = plt.subplots(nrow, ncol, figsize=(3 * ncol, 3 * nrow), sharex=True, sharey=True, squeeze=False)
for gene_id, ax in zip(gene_hits + ['total'], axs.flatten()):
    ax.scatter('species', gene_id, data=d1, s=2)
    # Cosine similarity between cube-root species depth and cube-root gene depth
    # over the focal samples; shown in the panel title.
    corr = 1 - sp.spatial.distance.cdist(
        _transf(species_depth.loc[focal_samples, [species_id]].T),
        _transf(d1.loc[focal_samples, [gene_id]].T),
        metric='cosine'
    )[0, 0]
    if gene_id != 'total':
        _depth_ratio = d0.loc[gene_id].depth_ratio
    else:
        # For the 'total' panel use the summed depth ratio of all hit genes.
        _depth_ratio = d0[d0.hit].depth_ratio.sum()
    # Dashed: identity line. Solid: identity scaled by the depth ratio.
    ax.plot(xx, xx, lw=1, linestyle='--', color='k')
    ax.plot(xx, xx * _depth_ratio, lw=1, linestyle='-', color='k')
    ax.set_title(f'{gene_id} ({corr:0.2f})', fontdict=dict(fontsize=8))
    ax.set_aspect('equal')

# The panels share both axes, so setting the scale on the last one is enough.
ax.set_yscale('symlog', linthresh=linthresh)
ax.set_xscale('symlog', linthresh=linthresh)
d0.head(len(gene_hits) + 3)

### Are the strains really the same across subjects?

In [None]:
mean_depth_species_genes.loc[pure_samples_with_strain].sort_values(ascending=False).head(20)

In [None]:
x, y = 'M2064_53', 'H4040_22'
d = np.log2(depth_ratio.loc[[x, y]] + 1e-2).dropna().T.join(strain_gene_info)
jp = sns.jointplot(x=x, y=y, data=d.sort_values('hit'), hue='hit', s=3, marginal_kws=dict(common_norm=False))
jp.ax_joint.plot([-6, 4], [-6, 4], lw=1, linestyle='--', color='k')

In [None]:
d = genes_depth.loc[pure_samples_with_strain, strain_gene_info.hit]

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
)
d.shape

In [None]:
d = genes_depth.loc[pure_samples_with_strain, strain_gene_info.hit].groupby(gene_clusters.centroid_75, axis=1).sum()

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
)
d.shape

In [None]:
d = genes_depth.loc[species_samples, strain_gene_info.hit]
e0 = np.log(focal_species_strain_depth[strain_id] + 1e-3)
e1 = ((e0 - e0.min()) / (e0.max() - e0.min()))

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
    row_colors=pd.DataFrame(dict(
        focal=d.index.to_series().isin(pure_samples_with_strain).astype(float).map(mpl.cm.viridis),
        depth=e1.map(mpl.cm.viridis),
        rabund=focal_species_strain_rabund[strain_id].map(mpl.cm.viridis),
    )),
)
d.shape

### Do multiple hits to the same gene cluster make sense?

#### How often do we get multiple hits to clust_75

Much less frequently than in the database.
And, the fact that most genes have a summed depth-ratio of ~1x suggests that these can be explained by read-stealing.

In [None]:
# Per centroid_75 cluster: number of non-hit (False) and hit (True) genes, the
# summed depth ratio of its hit genes, and the length of the gene whose ID
# equals the cluster ID.
d = (
    strain_gene_info
    .groupby(['centroid_75', 'hit'])
    .size()
    .unstack(fill_value=0)
    .join(
        strain_gene_info
        [strain_gene_info.hit]
        .groupby('centroid_75')
        .depth_ratio
        .sum()
    )
    .join(gene_clusters.centroid_99_length)
)

# Distribution of the number of non-hit and hit genes per cluster.
bins = np.linspace(0, 20, 21)
plt.hist(d[False], bins=bins)
plt.hist(d[True], bins=bins, alpha=0.5)

plt.yscale('log')
# plt.xlim(0, 10)
# A bare expression in the middle of a cell is not displayed.
d.sort_values(True, ascending=False).head()

# Distribution of log2 summed depth ratio per cluster, with reference lines at 1x and 0.2x.
plt.figure()
bins = np.linspace(-5, 5, 100)
plt.hist(np.log2(d['depth_ratio'].dropna()), bins=bins)
# plt.yscale('log')
plt.axvline(np.log2(1.0), lw=1, linestyle='-', color='k')
plt.axvline(np.log2(0.2), lw=1, linestyle='--', color='k')


d.sort_values('depth_ratio', ascending=False).head(10)

#### Are these groups of homologous genes sensible?

In [None]:
centroid_75 = 'UHGG141962_04670'

# Hit genes that belong to this centroid_75 cluster.
d1 = strain_gene_info[lambda x: (x.centroid_75 == centroid_75) & x.hit]
nhits = d1.shape[0]
# Summed depth ratio over the cluster's hit genes.
print(d1.depth_ratio.sum())
# The cluster's genes ranked by strain correlation; show the top nhits + 2.
strain_gene_info[strain_gene_info.centroid_75 == centroid_75].sort_values(['strain_corr'], ascending=False).head(nhits + 2)

In [None]:
!seqtk subseq ref_temp/midasdb_uhgg/pangenomes/{species_id}/centroids.ffn <(echo UHGG141962_04670)

Maybe! For instance, clust_75 'UHGG141962_04670' has 3 hits,
with 3 different centroid_95s
(suggesting that they're not doing too too much read-stealing),
and many have depth_ratios of > 0.5, which suggests that
the gene family gets a ton of coverage across samples
(total of 10.6x across all of the hits).

Interestingly, it seems to be a mobile element:

> ```
> UHGG146925_03529: transposase [Escherichia coli 101-1]
> ```

In [None]:
centroid_75 = 'UHGG145420_04357'

# Hit genes that belong to this centroid_75 cluster.
d1 = strain_gene_info[lambda x: (x.centroid_75 == centroid_75) & x.hit]
nhits = d1.shape[0]
# Summed depth ratio over the cluster's hit genes.
print(d1.depth_ratio.sum())
# The cluster's genes ranked by strain correlation; show the top nhits + 2.
strain_gene_info[strain_gene_info.centroid_75 == centroid_75].sort_values(['strain_corr'], ascending=False).head(nhits + 2)

In [None]:
!seqtk subseq ref_temp/midasdb_uhgg/pangenomes/{species_id}/centroids.ffn <(echo UHGG144565_01670)

In [None]:
centroid_75 = 'UHGG144776_03235'

# Hit genes that belong to this centroid_75 cluster.
d1 = strain_gene_info[lambda x: (x.centroid_75 == centroid_75) & x.hit]
nhits = d1.shape[0]
# Summed depth ratio over the cluster's hit genes.
print(d1.depth_ratio.sum())
# The cluster's genes ranked by strain correlation; show the top nhits + 2.
strain_gene_info[strain_gene_info.centroid_75 == centroid_75].sort_values(['strain_corr'], ascending=False).head(nhits + 2)

In [None]:
!seqtk subseq ref_temp/midasdb_uhgg/pangenomes/{species_id}/centroids.ffn <(echo UHGG144776_03235)

Maybe! For instance, clust_75 'UHGG144776_03235' has 5 hits,
but many have different centroid_95
(suggesting that they're not doing too too much read-stealing),
and a total gene dose of 2.7x across all of the hits.

Interestingly, it's also a Type IV secretion system (or maybe conjugative machinery) according
to the nr annotation:

> ```
> UHGG146925_03529 (WP_113394767.1): type IV conjugative transfer system protein TraL [Shigella sonnei]
> Function assigned by HMM accession: TIGR02762.1
> ```

## What can we say about these genes?

### What functions do some species genes have?

In [None]:
d = strain_gene_info[strain_gene_info.hit].sort_values('species_corr', ascending=False).head(5)

gene_id_list_string = ' '.join(d.index)
d

In [None]:
!seqtk subseq ref_temp/midasdb_uhgg/pangenomes/{species_id}/centroids.ffn <(echo {gene_id_list_string} | tr ' ' '\n')

- UHGG143797_05500 is assigned "Nickel/cobalt homeostasis protein RcnB precursor [Escherichia coli]"
- UHGG146766_03472 is assigned "regulatory protein mokC [Escherichia coli B088]" by blastx
- UHGG146766_00719 is assigned "hypothetical protein [Escherichia coli]" by blastx
- UHGG143923_02826 is assigned "alkyl hydroperoxide reductase subunit F [Escherichia coli]" by blastx

### What functions do some strain-only genes have?

In [None]:
d = strain_gene_info[lambda x: x.hit & x.strain_corr.gt(0.95) & x.species_corr.lt(0.6)].sort_values('strain_corr', ascending=False).head(5)

gene_id_list_string = ' '.join(d.index)
d

In [None]:
!seqtk subseq ref_temp/midasdb_uhgg/pangenomes/{species_id}/centroids.ffn <(echo {gene_id_list_string} | tr ' ' '\n')

- UHGG001882_03784 is assigned "hypothetical protein [Escherichia coli]" by blastx
- UHGG051562_01340 is assigned "unnamed protein product [Klebsiella pneumoniae]" by blastx
- UHGG212210_02494 is assigned "glutamine-hydrolyzing GMP synthase [Escherichia coli]" by blastx
- UHGG000026_01858 is assigned "type I-F CRISPR-associated protein Csy3 [Escherichia coli]"
- UHGG000317_01658 = dihydrolipoyl dehydrogenase [Escherichia coli]
- UHGG000026_00542 = colibactin biosynthesis acyltransferase ClbG [Enterobacterales]
- UHGG027725_01181 = ATP-dependent RNA helicase DbpA [Escherichia coli]
- UHGG001882_03728 = hypothetical protein NMECO18_13515 [Escherichia coli]
- UHGG051562_01340 = unnamed protein product [Klebsiella pneumoniae]

## Various Characterizations of strains

In [None]:
# Moderate positive relationship between gene length and correlation.
# (Among hits, highly correlated genes tend to be longer.)

d = strain_gene_info[strain_gene_info.hit].sort_values('strain_corr', ascending=False)
# d1 = d0.assign(rank=range(1, len(d0) + 1))
plt.scatter(x='strain_corr', y='centroid_99_length', data=d, s=1)
sns.kdeplot(x='strain_corr', y='centroid_99_length', data=d, log_scale=True, linewidths=1, color='k')

sp.stats.spearmanr(d['centroid_99_length'], d['strain_corr'])

In [None]:
# Strong negative relationship between gene length and estimated depth.
d = strain_gene_info[strain_gene_info.hit]

plt.scatter('centroid_99_length', 'depth_ratio', data=d, s=1)
plt.yscale('log')
plt.xscale('log')

sp.stats.spearmanr(d['centroid_99_length'], d['depth_ratio'])

In [None]:
# Strong positive relationship between "strain-ness" and estimated depth.
d = strain_gene_info[strain_gene_info.hit]

plt.scatter('strain_species_log_ratio', 'depth_ratio', data=d, s=1)
plt.yscale('log')
# plt.xscale('log')

sp.stats.spearmanr(d['strain_species_log_ratio'], d['depth_ratio'])

# Cross-strain Comparison

## Write gene table

In [None]:
# Write the gene table for the currently selected strain to '<strain_id>.tsv'
# in the current working directory.
path = f'{strain_id}.tsv'
print(path)
strain_gene_info.to_csv(path, sep='\t')

## Load tables and compare

In [None]:
# Load the per-strain gene tables of the seven top-ranked strains.
# NOTE(review): only the table of the currently selected strain_id is written
# by this notebook run; the other '<strain>.tsv' files must already exist,
# presumably from earlier runs with a different strain_rank — confirm.
strain_list = top_strains[:7]  # ['102506-2', '102506-5', '102506-6', '102506-11', '102506-28', '102506-24', '102506-4']
all_strains_gene_info = {}
for _strain_id in strain_list:
    all_strains_gene_info[_strain_id] = pd.read_table(f'{_strain_id}.tsv', index_col='gene_id')

In [None]:
all_strains_gene_hits = pd.DataFrame({k: all_strains_gene_info[k].hit for k in all_strains_gene_info})
all_strains_gene_hits.value_counts().sort_index()

In [None]:
all_strains_gene_hits_75 = pd.DataFrame({k: all_strains_gene_info[k].groupby(gene_clusters.centroid_75).hit.any() for k in all_strains_gene_info})
all_strains_gene_hits_75.value_counts().sort_index()

In [None]:
# Samples where the species is present, one strain dominates, and the deepest
# strain is one of the strains in strain_list.
pure_samples_with_strains = list(
    set(species_samples)
    & set(idxwhere((focal_species_strain_depth.idxmax(1).isin(strain_list))))
    & set(pure_samples)
)

# Genes called as a hit for at least one of the strains.
strain_hit = idxwhere(
    all_strains_gene_hits.any(1)
)

# Row annotations: dominant strain per sample, and colour palettes for strains and subjects.
sample_dominant_strain = focal_species_strain_rabund.idxmax(1)
strain_sample_palette = lib.plot.construct_ordered_palette(strain_list, cm='tab20')
subject_palette = lib.plot.construct_ordered_palette(subject_week.sort_values(['site', 'subject_id']).subject_id, cm='Spectral')


d = genes_depth.loc[pure_samples_with_strains, strain_hit]

# Clustered heatmap of gene depth (samples x hit genes), with rows annotated by
# subject, dominant strain, and species depth.
sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
    # col_colors=pd.DataFrame(dict(
    #     a=strain_gene_infoA.hit.astype(float).map(mpl.cm.viridis),
    #     b=strain_gene_infoB.hit.astype(float).map(mpl.cm.viridis),
    # )),
    row_colors=pd.DataFrame(dict(
        subject=subject_week.subject_id.map(subject_palette),
        pure=sample_dominant_strain.map(strain_sample_palette),
        # Cube-root species depth divided by 3 before being mapped to a colour.
        depth=(np.cbrt(species_depth[species_id]) / 3).map(mpl.cm.viridis),
    )),
)

In [None]:
# Standard deviation of depth for the last strain in strain_list, over the
# samples where its depth exceeds the presence threshold. The strain is named
# explicitly instead of relying on a loop variable left over from an earlier cell.
strain_depth[strain_list[-1]][lambda x: x > focal_sample_species_depth_thresh_pres].std()
# Since we know this ahead of time, we could select our cutoff for samples based on this relationship.

In [None]:
fig = plt.figure(figsize=(10, 5))

# Genes called as a hit in every strain. This does not depend on the loop
# variable, so compute it once.
_hit_in_all_strains = all_strains_gene_hits.all(1)

for _strain_id in strain_list:
    num_sample = strain_sample_counts.loc[_strain_id, 'num_samples']
    stdev_depth = strain_depth[_strain_id][lambda x: x > focal_sample_species_depth_thresh_pres].std()
    # Line width grows with depth spread and number of samples.
    quality_index = stdev_depth * np.sqrt(num_sample)
    _info = all_strains_gene_info[_strain_id]
    # Solid: strain correlation of the genes that are hits in all strains.
    sns.kdeplot(_info.loc[_hit_in_all_strains, 'strain_corr'], label=f"{_strain_id}", c=strain_sample_palette[_strain_id], lw=np.cbrt(quality_index))
    # Dashed: strain correlation of the genes matched to species gene hits.
    sns.kdeplot(_info.loc[lambda x: x.species_gene_id.isin(species_gene_hits), 'strain_corr'], c=strain_sample_palette[_strain_id], linestyle='--')

plt.axvline(strain_corr_thresh, linestyle='--', lw=1, color='grey')
plt.legend()