# Preamble

In [None]:
%load_ext autoreload

In [None]:
import os

# Move from the notebook's directory to its parent (presumably the project root) so the
# relative data paths used below resolve. The starting directory is remembered on first
# execution, so re-running this cell does not keep climbing one level higher each time.
_NOTEBOOK_DIR = globals().get('_NOTEBOOK_DIR', os.path.realpath(os.path.curdir))
os.chdir(os.path.dirname(_NOTEBOOK_DIR))
os.path.realpath(os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, repeated
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

In [None]:
import sfacts as sf

## Set Parameters

In [None]:
# Focal species id, kept as a string because depth-table column labels are converted to str.
species_id = "102506"

# Load Data

## Sample Metadata

In [None]:
# HMP2 metadata tables, each indexed by its own id.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
visit = pd.read_table('meta/hmp2/visit.tsv', index_col='visit_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

# Follow the foreign-key chain library -> preparation -> stool -> visit -> subject.
mgen_meta = mgen.join(preparation.drop(columns='library_type'), on='preparation_id')
mgen_meta = mgen_meta.join(stool, on='stool_id')
mgen_meta = mgen_meta.join(visit, on='visit_id', rsuffix='_')
mgen_meta = mgen_meta.join(subject, on='subject_id')

# Every library must resolve to a subject.
assert mgen_meta.subject_id.notna().all()

In [None]:
_subject_week = (
    visit
    .join(subject, on='subject_id')
    .reset_index()
    .dropna(subset=['subject_id', 'week_number'])
    .groupby(['subject_id', 'week_number'])
    # When several visits share a (subject, week), keep the row with the most non-missing fields.
    .apply(lambda d: d.loc[d.notna().sum(1).sort_values().index[-1]])
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    .join(stool.groupby('visit_id').fecal_calprotectin.mean(), on='visit_id')
    .sort_values(['subject_id', 'week_number'])
)

# Map each metagenomic library to its "<subject_id>_<week>" key (libraries without a week are
# dropped). Built column-wise, the same way as `subject_week_id` above, instead of a row-wise
# apply (slow, and it can return a DataFrame rather than a Series when no rows remain).
_mgen_with_week = mgen_meta.dropna(subset=['week_number'])
mgen_to_subject_week = (
    _mgen_with_week.subject_id + '_' + _mgen_with_week.week_number.astype(int).astype(str)
).rename('subject_week_id')
mgen_to_subject_week

## Species

In [None]:
# GT-Pro species depth: samples x species, summed over libraries from the same subject-week.
# Column labels are converted to str so they match string species ids.
species_depth = (
    pd.read_table('data/hmp2.a.r.proc.gtpro.species_depth.tsv', index_col=['sample', 'species_id'])
    .squeeze()
    .unstack('species_id', fill_value=0)
    .groupby(mgen_to_subject_week)
    .sum()
    .rename(columns=str)
)

In [None]:
# Flag subject-weeks that have at least one metagenome.
subject_week = _subject_week.assign(has_mgen=_subject_week.index.isin(species_depth.index))

## Strains

In [None]:
# sfacts strain depth: subject-weeks x strains, summed over libraries from the same subject-week.
strain_depth = (
    pd.read_table(
        'data_temp/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts18-s75-seed0.strain_depth.tsv',
        index_col=['sample', 'strain'],
    )
    .squeeze()
    .unstack('strain', fill_value=0)
    .groupby(mgen_to_subject_week)
    .sum()
)

# Per-sample difference between total strain depth and total species depth.
plt.hist(strain_depth.sum(1) - species_depth.sum(1), bins=50)
None

In [None]:
# Relative abundances: each row divided by its per-sample total.
species_rabund = species_depth.div(species_depth.sum(axis=1), axis='index')
strain_rabund = strain_depth.div(strain_depth.sum(axis=1), axis='index')

## Taxonomy

In [None]:
# Taxonomy for every GT-Pro species, with one column per rank prefix (p__ ... s__).
species_taxonomy = (
    pd.read_table('ref/gtpro/species_taxonomy_ext.tsv', names=['genome_id', 'species_id', 'taxonomy_string'])
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')[['taxonomy_string']]
)
_taxonomy_levels = [('p__', 1), ('c__', 2), ('o__', 3), ('f__', 4), ('g__', 5), ('s__', 6)]
_taxonomy_split = species_taxonomy.taxonomy_string.str.split(';')
# `.str[i]` yields NaN for missing or short taxonomy strings instead of raising IndexError/TypeError.
species_taxonomy = species_taxonomy.assign(**{
    level_name: _taxonomy_split.str[level_number]
    for level_name, level_number in _taxonomy_levels
})

# Strain columns are named "<species_id>-<...>"; attach the species taxonomy to each strain.
strain_taxonomy = (
    strain_depth.columns.to_series().str.split('-').str[0]
    .to_frame(name='species_id')
    .join(species_taxonomy, on='species_id')
)

# NOTE: this narrows `species_taxonomy` to species that have at least one fitted strain.
species_taxonomy = strain_taxonomy.drop_duplicates(subset=['species_id']).set_index('species_id')

## Genes

In [None]:
# MIDAS/UHGG pangenome cluster table for the focal species, indexed by centroid_99.
gene_clusters = pd.read_table(
    f'ref_temp/midasdb_uhgg/pangenomes/{species_id}/cluster_info.txt',
    index_col='centroid_99',
)

In [None]:
# FIXME: Install python-lz4 into the sfacts module so I can open this file as 'data_temp/sp-102506.hmp2.a.r.proc.midas_merge/genes/102506/102506.genes_depth.tsv.lz4' instead?
# Per-gene MIDAS depth with rows = subject-weeks (libraries of the same subject-week summed).
# Transposing before grouping replaces `groupby(..., axis='columns')`, deprecated in pandas 2.1.
genes_depth = (
    pd.read_table(f'data_temp/sp-{species_id}.hmp2.a.r.proc.midas_genes.depth.tsv', index_col='gene_id')
    .T
    .groupby(mgen_to_subject_week)
    .sum()
)

In [None]:
# Collapse per-gene depth to centroid_75 clusters by summing member genes (rows stay subject-weeks).
# Grouping the transpose is equivalent to the deprecated `groupby(..., axis='columns')`.
genes_depth_75 = genes_depth.T.groupby(gene_clusters.centroid_75).sum().T

In [None]:
# Per-sample median depth over detected (non-zero) centroid_75 clusters.
# NOTE: despite the name this is a median; the variable is reassigned later in the notebook.
mean_depth_present_genes = genes_depth_75.T.apply(lambda x: x[x > 0].median())
# Per-sample mean, over marker ids, of the summed depth of each marker's genes. Grouping the
# transpose replaces the deprecated `groupby(..., axis='columns')`.
mean_marker_depth = genes_depth.T.groupby(gene_clusters.marker_id).sum().mean()

In [None]:
# Focal-species GT-Pro depth across samples, ascending.
species_depth.loc[:, species_id].sort_values()

In [None]:
# log10 gene-cluster depth in one sample (small pseudocount), with the GT-Pro species depth marked.
plt.hist(np.log10(genes_depth_75.loc['C3009_10'].add(1e-3)), bins=100)
plt.axvline(np.log10(species_depth.at['C3009_10', species_id]))
plt.yscale('log')

In [None]:
def trim_gmean_nonzero(x, proportiontocut, axis=0):
    """Trimmed geometric mean of the strictly positive entries of `x`.

    Non-positive entries (and NaN) are dropped first, because log(0) is undefined. The
    remaining values are log-transformed, averaged with ``scipy.stats.trim_mean``, and
    exponentiated back.

    Parameters
    ----------
    x : array-like
        1-D values, e.g. a pandas Series (as passed by ``DataFrame.apply``) or a numpy array.
    proportiontocut : float
        Fraction trimmed from *each* end, forwarded to ``scipy.stats.trim_mean``.
    axis : int, default 0
        Forwarded to ``scipy.stats.trim_mean``. After filtering the data are 1-D.

    Returns
    -------
    float
        The trimmed geometric mean, or NaN when `x` has no positive entries.
    """
    values = np.asarray(x, dtype=float)
    # Boolean masking works for numpy arrays as well as Series (callable indexing did not).
    positive = values[values > 0]
    if positive.size == 0:
        return np.nan
    return np.exp(sp.stats.trim_mean(np.log(positive), proportiontocut, axis=axis))

In [None]:
def trim_mean_top_n(x, n, proportiontocut, axis=0):
    """Trimmed mean of the `n` largest values of `x` along `axis`.

    Parameters
    ----------
    x : array-like
        Values, e.g. a pandas Series (as passed by ``DataFrame.apply``) or a numpy array.
    n : int
        Number of largest values to keep. If `n` exceeds the length, all values are used.
    proportiontocut : float
        Fraction trimmed from *each* end of the kept values (``scipy.stats.trim_mean``).
    axis : int, default 0
        Axis along which values are sorted, selected, and averaged.

    Returns
    -------
    float or ndarray
        The trimmed mean; NaN when ``n == 0``.

    Raises
    ------
    ValueError
        If `n` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if n == 0:
        # A plain `[-n:]` slice would silently select *every* value when n == 0.
        return np.nan
    values = np.sort(np.asarray(x), axis=axis)
    length = values.shape[axis]
    top = np.take(values, np.arange(max(length - n, 0), length), axis=axis)
    return sp.stats.trim_mean(top, proportiontocut, axis=axis)

In [None]:
# Candidate per-sample "typical gene depth" estimators over centroid_75 clusters:
#   B: median of detected (non-zero) clusters;
#   A: geometric mean of detected clusters, 10% trimmed from each end;
#   (unsuffixed): mean of the 2000 deepest clusters, 30% trimmed from each end.
mean_depth_present_genesB = genes_depth_75.T.apply(lambda depths: depths[depths > 0].median())
mean_depth_present_genesA = genes_depth_75.T.apply(trim_gmean_nonzero, proportiontocut=0.1)
mean_depth_present_genes = genes_depth_75.T.apply(trim_mean_top_n, n=2000, proportiontocut=0.3)

# Focal Species

In [None]:
# Strains of the focal species only, plus their within-species relative abundance.
focal_species_strain_depth = strain_depth.loc[:, strain_taxonomy.species_id.eq(species_id)]
focal_species_strain_rabund = focal_species_strain_depth.div(focal_species_strain_depth.sum(axis=1), axis='index')

In [None]:
# Strain depth heatmap; a strong power-law norm (gamma = 0.1) makes low depths visible.
sns.clustermap(focal_species_strain_depth, norm=mpl.colors.PowerNorm(gamma=0.1))

In [None]:
# Thresholds for calling a focal-species strain "present and dominant" in a sample, and for
# calling the species absent.
strain_rabund_thresh = 0.95
focal_sample_species_depth_thresh_pres = 0.5
focal_sample_species_depth_thresh_abs = 1e-3

# True where a strain is both deep enough and nearly the only focal-species strain in the sample.
d0 = (focal_species_strain_depth > focal_sample_species_depth_thresh_pres) & (
    focal_species_strain_rabund > strain_rabund_thresh
)
num_samples = d0.sum().sort_values(ascending=False)
num_subjects = (
    d0.groupby(subject_week.subject_id).any().sum().sort_values(ascending=False).head()
)

# Per strain: number of dominated samples and number of distinct subjects among them.
strain_sample_counts = pd.DataFrame({
    'num_samples': d0.sum(),
    'num_subjects': d0.groupby(subject_week.subject_id).any().sum(),
}).sort_values('num_samples', ascending=False)

top_strains = strain_sample_counts.index[:20]

strain_sample_counts.head(8)

## Species Genes

### Select high-coverage and no-coverage samples

In [None]:
# Samples where the focal species is clearly present vs. effectively absent (GT-Pro depth).
species_samples = idxwhere(species_depth[species_id].gt(focal_sample_species_depth_thresh_pres))
no_species_samples = idxwhere(species_depth[species_id].lt(focal_sample_species_depth_thresh_abs))
(len(species_samples), len(no_species_samples))

### Comparison of GT-Pro and MIDAS depth estimates

In [None]:
# GT-Pro species depth vs. mean marker-gene depth; present and absent samples overplotted.
d = pd.DataFrame({'species': species_depth[species_id], 'genes': mean_marker_depth})
plt.scatter('species', 'genes', data=d, alpha=0.4)
plt.scatter('species', 'genes', data=d.loc[species_samples], s=2)
plt.scatter('species', 'genes', data=d.loc[no_species_samples], s=2)
plt.plot([0, 1e2], [0, 1e2])  # 1:1 reference line
plt.xscale('symlog', linthresh=focal_sample_species_depth_thresh_abs)
plt.yscale('symlog', linthresh=focal_sample_species_depth_thresh_abs)

### Identify species core genes

In [None]:
# Union of clearly-present and clearly-absent samples. Sorted so the row order (and anything
# ordered or seeded downstream) is identical on every run; iterating a set of strings is not.
focal_samples = sorted(set(species_samples) | set(no_species_samples))

In [None]:
# Cosine distance between every species' depth profile and every centroid_75 cluster's depth
# profile across the focal samples, after a cube-root transform.
x = species_depth.loc[focal_samples]
y = genes_depth_75.loc[focal_samples]

_transf = np.cbrt

gene_species_cos_dist = pd.DataFrame(
    data=sp.spatial.distance.cdist(_transf(x).T, _transf(y).T, metric='cosine'),
    index=x.columns,
    columns=y.columns,
)

In [None]:
# Candidate species-core clusters: cosine distance to the focal species below the prefilter threshold.
species_prefilt_thresh = 0.1
species_gene_hits = idxwhere(gene_species_cos_dist.loc[species_id].lt(species_prefilt_thresh))
len(species_gene_hits)

In [None]:
# Per species: clusters for which it is the closest species, and clusters under the prefilter
# threshold; annotated with family/genus, top 5 by threshold hits.
pd.DataFrame({
    'closest_species': gene_species_cos_dist.idxmin().value_counts(),
    'hit_species': gene_species_cos_dist.lt(species_prefilt_thresh).sum(axis=1),
}).join(species_taxonomy[['f__', 'g__']]).sort_values('hit_species', ascending=False).head(5)

In [None]:
# Cosine distance to the focal species for: all clusters, clusters whose closest species is the
# focal one, and clusters containing a marker gene.
bins = np.linspace(0, 1, num=101)
plt.hist(gene_species_cos_dist.loc[species_id], bins=bins)
plt.hist(gene_species_cos_dist.loc[species_id, gene_species_cos_dist.idxmin() == species_id], bins=bins)
# Columns here are centroid_75 labels, so marker genes are looked up through their centroid_75.
# Reindexing by the centroid_99 index (as before) presumably only matched clusters whose
# representative gene is itself a marker. `axis='columns'` is dropped: Series.reindex ignores it.
plt.hist(
    gene_species_cos_dist.loc[species_id]
    .reindex(gene_clusters.loc[gene_clusters.marker_id.notna(), 'centroid_75'].unique())
    .dropna(),
    bins=bins,
)
plt.axvline(species_prefilt_thresh, lw=1, linestyle='--', color='k')

plt.yscale('log')
None

## Why do many genes in the focal species' pangenome also correlate with other species?

In [None]:
# Per-cluster cosine distance to species 102506 vs. to species 102538.
plt.scatter(x=gene_species_cos_dist.loc['102506'], y=gene_species_cos_dist.loc['102538'], s=1)

Is it because the two species' depths are themselves correlated?

In [None]:
# Are the two species' GT-Pro depths themselves correlated across the focal samples?
x = species_depth.loc[focal_samples, '102506']
y = species_depth.loc[focal_samples, '102538']
plt.scatter(x, y, alpha=0.2)
plt.xscale('symlog', linthresh=1e-3)
plt.yscale('symlog', linthresh=1e-3)
sp.stats.pearsonr(x, y)

This suggests that it is instead due to pangenome cross-mapping (reads from one species mapping to genes in the other's pangenome).

### How does the distribution of depth-correlation compare to a null distribution?

In [None]:
x = species_depth.loc[focal_samples]
y = genes_depth_75.loc[focal_samples]

# Null model: shuffle which sample each species-depth row is attributed to, breaking its pairing
# with the gene depths. A seeded generator makes the permutation (and the plots) reproducible.
_permute_rng = np.random.default_rng(0)
x_permute = x.copy()
x_permute.index = _permute_rng.permutation(x.index)
x_permute = x_permute.loc[x.index]

_transf = np.cbrt

gene_species_cos_dist_permute = pd.DataFrame(sp.spatial.distance.cdist(_transf(x_permute.T), _transf(y.T), metric='cosine'), index=x.columns, columns=y.columns)

In [None]:
# Null-model counterpart of the focal-species cosine-distance histograms.
bins = np.linspace(0, 1, num=101)
plt.hist(gene_species_cos_dist_permute.loc[species_id], bins=bins)
plt.hist(gene_species_cos_dist_permute.loc[species_id, gene_species_cos_dist_permute.idxmin() == species_id], bins=bins)
# Columns are centroid_75 labels: look up marker genes through their centroid_75 cluster
# (reindexing by the centroid_99 index presumably missed most marker clusters).
# `axis='columns'` is dropped because Series.reindex ignores it.
plt.hist(
    gene_species_cos_dist_permute.loc[species_id]
    .reindex(gene_clusters.loc[gene_clusters.marker_id.notna(), 'centroid_75'].unique())
    .dropna(),
    bins=bins,
)

plt.yscale('log')
None

### Assess core gene hits

In [None]:
# Marker ids among the species-core hits (looked up by the hit labels) and their counts.
d = gene_clusters.loc[species_gene_hits, 'marker_id'].dropna().value_counts()
print(len(d))
d

In [None]:
# Depth of species-core hit clusters in samples where the focal species is present.
sns.clustermap(
    genes_depth_75.loc[species_samples, species_gene_hits],
    norm=mpl.colors.SymLogNorm(linthresh=0.5),
    metric='cosine',
)

In [None]:
# Per-sample species depth estimated from the species-core clusters (trim_mean, 10% cut per end).
mean_depth_species_genes = genes_depth_75[species_gene_hits].apply(
    lambda row: sp.stats.trim_mean(row, 0.1), axis=1
)
mean_depth_species_genes.sort_values(ascending=False).head(10)

In [None]:
# GT-Pro species depth vs. depth estimated from species-core clusters (symlog axes).
d = pd.DataFrame({'species': species_depth[species_id], 'genes': mean_depth_species_genes})
plt.scatter('species', 'genes', data=d)
plt.plot([0, 1e2], [0, 1e2])  # 1:1 reference line
plt.xscale('symlog')
plt.yscale('symlog')

#### Why are species depths lower when estimated from genes?

I believe this may be happening due to read stealing by similar genes.

In [None]:
# GT-Pro species depth vs. mean marker-gene depth (symlog axes).
d = pd.DataFrame({'species': species_depth[species_id], 'genes': mean_marker_depth})
plt.scatter('species', 'genes', data=d)
plt.plot([0, 1e2], [0, 1e2])  # 1:1 reference line
plt.xscale('symlog')
plt.yscale('symlog')

In [None]:
# Same comparison as the species-core plot, on linear axes.
d = pd.DataFrame({'species': species_depth[species_id], 'genes': mean_depth_species_genes})
plt.scatter('species', 'genes', data=d)
plt.plot([0, 1e2], [0, 1e2])  # 1:1 reference line

In [None]:
# Gene depth relative to each sample's species-core depth (clusters and individual genes).
depth_ratio_75 = genes_depth_75.div(mean_depth_species_genes, axis='index')
depth_ratio = genes_depth.div(mean_depth_species_genes, axis='index')

In [None]:
# log2 depth ratio of each cluster in two samples, coloured by species-core hit status.
d = np.log2(depth_ratio_75.loc[species_samples].add(1e-2)).dropna().T
sns.jointplot(
    data=d.assign(hits=lambda frame: frame.index.isin(species_gene_hits)),
    x='C3009_10',
    y='M2034_42',
    hue='hits',
    alpha=0.5,
    s=2,
)
None

## Strain Genes

In [None]:
# Focal strain, chosen by 0-based rank in `top_strains` (ordered by number of dominated samples).
strain_rank = 6
strain_id = top_strains[strain_rank]
strain_id

### Select focal samples

In [None]:
# Samples in which a single focal-species strain exceeds `strain_rabund_thresh` relative
# abundance. `focal_species_strain_rabund` is the same row-normalisation the old apply recomputed;
# `any(axis=1)` is spelled as a keyword because any/all are keyword-only in pandas 2.
pure_samples = idxwhere((focal_species_strain_rabund > strain_rabund_thresh).any(axis=1))
# Sorted (not a raw set) so the sample order is reproducible across runs.
pure_samples_with_strain = sorted(
    set(species_samples)
    & set(idxwhere(focal_species_strain_depth.idxmax(axis=1) == strain_id))
    & set(pure_samples)
)
focal_samples = sorted(set(pure_samples_with_strain) | set(no_species_samples))
len(pure_samples_with_strain), len(focal_samples)

### Depth estimate comparison

In [None]:
# GT-Pro vs. species-core depth, restricted to the strain-focused samples.
d = pd.DataFrame({'species': species_depth[species_id], 'genes': mean_depth_species_genes})
plt.scatter('species', 'genes', data=d.loc[focal_samples])
plt.plot([0, 1e2], [0, 1e2])  # 1:1 reference line
plt.xscale('symlog')
plt.yscale('symlog')

### Identify depth correlated genes in focal samples

In [None]:
# Cosine distance between species depth profiles and individual gene depth profiles, computed over
# the strain-dominated plus species-absent samples (cube-root transformed).
x = species_depth.loc[focal_samples]
y = genes_depth.loc[focal_samples]

_transf = np.cbrt

gene_strain_cos_dist = pd.DataFrame(
    data=sp.spatial.distance.cdist(_transf(x).T, _transf(y).T, metric='cosine'),
    index=x.columns,
    columns=y.columns,
)

In [None]:
# Looser prefilter for strain-level candidate genes.
strain_prefilt_thresh = 0.3
strain_gene_maybe_hits = idxwhere(gene_strain_cos_dist.loc[species_id].lt(strain_prefilt_thresh))
len(strain_gene_maybe_hits)

In [None]:
# Per species: genes for which it is closest, and genes under the strain prefilter threshold.
pd.DataFrame({
    'closest_species': gene_strain_cos_dist.idxmin().value_counts(),
    'hit_species': gene_strain_cos_dist.lt(strain_prefilt_thresh).sum(axis=1),
}).join(species_taxonomy[['f__', 'g__']]).sort_values('hit_species', ascending=False).head(5)

In [None]:
# Gene-level cosine distance to the focal species in the strain-focused samples: all genes, genes
# closest to the focal species, and marker genes. (Series.reindex ignores `axis`, so it is omitted.)
bins = np.linspace(0, 1, num=101)
plt.hist(gene_strain_cos_dist.loc[species_id], bins=bins)
plt.hist(gene_strain_cos_dist.loc[species_id][gene_strain_cos_dist.idxmin() == species_id], bins=bins)
plt.hist(gene_strain_cos_dist.loc[species_id].reindex(gene_clusters.marker_id.dropna().index).dropna(), bins=bins)

plt.yscale('log')
None

### How does the distribution of depth-correlation compare to a null distribution?

In [None]:
x = species_depth.loc[focal_samples]
y = genes_depth.loc[focal_samples]

# Null model: shuffle which sample each species-depth row is attributed to, breaking its pairing
# with the gene depths. A seeded generator makes the permutation (and the plots) reproducible.
_permute_rng = np.random.default_rng(0)
x_permute = x.copy()
x_permute.index = _permute_rng.permutation(x.index)
x_permute = x_permute.loc[x.index]

_transf = np.cbrt

gene_strain_cos_dist_permute = pd.DataFrame(sp.spatial.distance.cdist(_transf(x_permute.T), _transf(y.T), metric='cosine'), index=x.columns, columns=y.columns)

In [None]:
# Null-model counterpart of the gene-level histograms (Series.reindex ignores `axis`; omitted).
bins = np.linspace(0, 1, num=101)
plt.hist(gene_strain_cos_dist_permute.loc[species_id], bins=bins)
plt.hist(gene_strain_cos_dist_permute.loc[species_id][gene_strain_cos_dist_permute.idxmin() == species_id], bins=bins)
plt.hist(gene_strain_cos_dist_permute.loc[species_id].reindex(gene_clusters.marker_id.dropna().index).dropna(), bins=bins)

plt.yscale('log')
None

### Assess depth-correlated genes

In [None]:
# Marker ids among strain-level candidate genes and their counts.
d = gene_clusters.loc[strain_gene_maybe_hits, 'marker_id'].dropna().value_counts()
print(len(d))
d

In [None]:
d = genes_depth.loc[species_samples, strain_gene_maybe_hits]
# Log depth of the focal strain (small pseudocount), min-max scaled to [0, 1] for row colours.
e0 = np.log(focal_species_strain_depth[strain_id] + 1e-3)
e1 = e0.sub(e0.min()).div(e0.max() - e0.min())

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
    row_colors=pd.DataFrame({
        'focal': d.index.to_series().isin(pure_samples_with_strain).astype(float).map(mpl.cm.viridis),
        'depth': e1.map(mpl.cm.viridis),
        'rabund': focal_species_strain_rabund[strain_id].map(mpl.cm.viridis),
    }),
)

In [None]:
# Candidate strain genes in samples dominated by the focal strain only.
d = genes_depth.loc[pure_samples_with_strain, strain_gene_maybe_hits]

sns.clustermap(d, metric='cosine', norm=mpl.colors.SymLogNorm(linthresh=1.0))

### Compare gene depth-ratio estimates

In [None]:
# Each gene's summed depth over strain-dominated samples, relative to the mean summed depth of
# the candidate strain genes.
total_strain_gene_depth = genes_depth.loc[pure_samples_with_strain].sum(axis=0)
strain_mean_depth_ratio = total_strain_gene_depth.div(total_strain_gene_depth[strain_gene_maybe_hits].mean())

In [None]:
# Compare two per-gene depth-ratio summaries over the strain-dominated samples:
#   wmean: ratio of summed depths (samples with more reads weigh more);
#   mean:  per-sample depth ratio averaged across samples.
# Distinct names avoid overwriting `d0`, the strain-dominance mask built earlier.
log2_wmean_depth_ratio = np.log2(strain_mean_depth_ratio)
log2_mean_depth_ratio = np.log2(depth_ratio.loc[pure_samples_with_strain].mean())

depth_ratio_comparison = pd.DataFrame(dict(
    mean_depth_ratio=log2_mean_depth_ratio,
    wmean_depth_ratio=log2_wmean_depth_ratio,
)).assign(hit=lambda x: x.index.isin(strain_gene_maybe_hits))

g = sns.jointplot(data=depth_ratio_comparison, x='wmean_depth_ratio', y='mean_depth_ratio', hue='hit')
g.ax_joint.plot([-6, 6], [-6, 6], color='k')  # 1:1 reference line
None

#### Why does a weighted mean result in a higher estimate?

I believe there is a bias due to truncating the coverage when few reads map to a gene.

## Pick strain genes

In [None]:
# Within each centroid_75 cluster, the member gene best correlated (lowest cosine distance, with
# missing treated as 1.0) with the focal species in the strain-focused samples.
best_strain_gene_match = (
    gene_strain_cos_dist.loc[species_id]
    .fillna(1.0)
    .groupby(gene_clusters.centroid_75)
    .idxmin()
)
assert best_strain_gene_match.index.is_unique and best_strain_gene_match.is_unique
# Inverse lookup: best-matching centroid_99 gene -> its centroid_75 cluster.
strain_to_species_gene = (
    best_strain_gene_match
    .rename('centroid_99')
    .reset_index()
    .set_index('centroid_99', drop=False)
    .loc[:, 'centroid_75']
)
strain_to_species_gene

In [None]:
depth_ratio_thresh = 0.2
strain_cos_thresh = 0.3
species_cos_thresh = 0.3

# Per-gene table for `strain_id`. A gene is a hit when its depth ratio is high enough AND it is
# close to the focal species either at gene level (strain_cos) or via its centroid_75 cluster
# (species_cos). Missing distances count as maximally dissimilar (1.0).
strain_gene_info = (
    strain_mean_depth_ratio.to_frame('depth_ratio')
    .assign(
        marker_id=gene_clusters.marker_id,
        strain_cos=gene_strain_cos_dist.loc[species_id].fillna(1.0),
        species_gene_id=strain_to_species_gene,
    )
    .join(gene_species_cos_dist.loc[species_id].to_frame('species_cos'), on='species_gene_id')
    .assign(species_cos=lambda tbl: tbl.species_cos.fillna(1.0))
    .assign(hit=lambda tbl: tbl.depth_ratio.gt(depth_ratio_thresh) & (tbl.strain_cos.lt(strain_cos_thresh) | tbl.species_cos.lt(species_cos_thresh)))
    .join(gene_clusters.drop(columns=['marker_id']))
)

strain_gene_info.loc[strain_gene_info.hit]

## Assess strain genes

In [None]:
# Hits on log-log strain vs. species cosine-distance axes, coloured by depth ratio.
d = strain_gene_info.loc[strain_gene_info.hit].sort_values('depth_ratio')

plt.scatter('strain_cos', 'species_cos', data=d, c='depth_ratio', s=5, norm=mpl.colors.LogNorm())
plt.xscale('log')
plt.yscale('log')
plt.colorbar()

print(len(d), d.marker_id.value_counts().size)
d.marker_id.value_counts()

### Why multiple hits to the same marker genes? ("contamination" / "redundancy")

In [None]:
strain_gene_info.loc[strain_gene_info.marker_id.eq('B000082')].sort_values('strain_cos')

The two hits to B000082 may reflect read stealing. Together they have only 1.45x depth, and the correlation
is on the lower side for one of them (cosine distance of 0.18). They also share a centroid_90 (although not a centroid_95).

In [None]:
strain_gene_info.loc[strain_gene_info.marker_id.eq('B000086')].sort_values('strain_cos')

The two hits to B000086 seem likely to be read stealing. Together they have only 1.15x depth, and the correlation
is on the lower side for one of them (cosine distance of 0.24).
They also share a centroid_90 (although not a centroid_95).

In [None]:
strain_gene_info.loc[strain_gene_info.marker_id.eq('B000081')].sort_values('strain_cos')

The two hits to B000081 seem less likely to be read stealing. Together they have >2x depth, and the correlation
is relatively high for both. They also do not share a centroid_75.

What does it mean that one of them is way longer than all of the others?

### How often would we miss "species genes" in the strain-specific analysis?

In [None]:
d = strain_gene_info.sort_values('depth_ratio')
_hits = d[d.hit]

# Hits coloured by depth ratio, over a KDE of a subsample of all genes. The subsample is capped
# at the table size (sample(n=10000) fails on smaller tables) and seeded for reproducibility.
_hit_points = plt.scatter('strain_cos', 'species_cos', data=_hits, c='depth_ratio', s=5, norm=mpl.colors.LogNorm())
sns.kdeplot(
    x='strain_cos', y='species_cos', log_scale=True,
    data=d.sample(n=min(10000, len(d)), random_state=0),
    color='black', alpha=1.0, linewidths=1, zorder=1,
)

# Use the same thresholds that define `hit` rather than repeating the literal 0.3.
plt.axvline(strain_cos_thresh, lw=1, linestyle='--', color='k')
plt.axhline(species_cos_thresh, lw=1, linestyle='--', color='k')
plt.colorbar(_hit_points)

# "Uncorrelated" is the exact complement of the `<` test used in the hit definition.
_uncorr_strain = _hits.strain_cos >= strain_cos_thresh
_uncorr_species = _hits.species_cos >= species_cos_thresh
print('Fraction of hits uncorrelated with strain:', _uncorr_strain.mean())
print('Fraction of hits uncorrelated with species:', _uncorr_species.mean())
print('Fraction of hits correlated with strain AND species:', (~_uncorr_strain & ~_uncorr_species).mean())

As much as 40% of the time, apparently.

### Are the strains really the same across subjects?

In [None]:
# Deepest strain-dominated samples by species-core depth.
mean_depth_species_genes.loc[pure_samples_with_strain].sort_values(ascending=False)[:20]

In [None]:
# log2 depth ratio of every gene in two strain-dominated samples from different subjects.
x, y = 'M2064_53', 'H4040_22'
d = (
    np.log2(depth_ratio.loc[[x, y]] + 1e-2)
    .dropna()
    .T
    .join(strain_gene_info)
)
sns.jointplot(data=d.sort_values('hit'), x=x, y=y, hue='hit', s=3, marginal_kws={'common_norm': False})

In [None]:
# Final strain-gene hits in samples dominated by the focal strain.
d = genes_depth.loc[pure_samples_with_strain, strain_gene_info.hit]

sns.clustermap(d, metric='cosine', norm=mpl.colors.SymLogNorm(linthresh=1.0))

In [None]:
d = genes_depth.loc[species_samples, strain_gene_info.hit]
# Log depth of the focal strain (small pseudocount), min-max scaled to [0, 1] for row colours.
e0 = np.log(focal_species_strain_depth[strain_id] + 1e-3)
e1 = e0.sub(e0.min()).div(e0.max() - e0.min())

sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
    row_colors=pd.DataFrame({
        'focal': d.index.to_series().isin(pure_samples_with_strain).astype(float).map(mpl.cm.viridis),
        'depth': e1.map(mpl.cm.viridis),
        'rabund': focal_species_strain_rabund[strain_id].map(mpl.cm.viridis),
    }),
)

### Do multiple hits to the same gene cluster make sense?

#### How often do we get multiple hits to the same clust_75?

Much less frequently than in the database.

In [None]:
# Number of member genes per centroid_75 cluster, split by hit status (columns False / True).
d = strain_gene_info.groupby(['centroid_75', 'hit']).size().unstack(fill_value=0)

bins = np.linspace(0, 20, 21)
plt.hist(d.loc[:, False], bins=bins)
plt.hist(d.loc[:, True], bins=bins, alpha=0.5)
plt.yscale('log')

d.sort_values(True, ascending=False).head(5)

#### Are these groups of homologous genes sensible?

In [None]:
strain_gene_info.loc[strain_gene_info.centroid_75.eq('UHGG144776_03235')].sort_values('strain_cos')

Maybe! For instance, clust_75 'UHGG144776_03235' has 8 copies,
but they _all_ have different centroid_95
(suggesting that they're not doing much read-stealing),
and they all have depth_ratios > 0.25, which suggests that
the gene family gets a lot of coverage across samples
(maybe 8x across all of the homologues).

Interestingly, it's also a Type IV secretion system (or maybe conjugative machinery) according
to the nr annotation:

> ```
> UHGG146925_03529 (WP_113394767.1): type IV conjugative transfer system protein TraL [Shigella sonnei]
> Function assigned by HMM accession: TIGR02762.1
> ```

In [None]:
strain_gene_info.loc[strain_gene_info.centroid_75.eq('UHGG079207_00971')].sort_values('strain_cos').head(10)

Another example: clust_75 'UHGG079207_00971' has 7 copies, and
many have different centroid_95 (suggesting that they're not doing much read-stealing).
They also all have depth_ratios > 0.25, and
the clust_75 may have a total depth of >8x across all of the homologues.

In the nr, this sequence is annotated as a transposase.

> ```
> UHGG140415_04958 (ATB16992.1): IS3 family transposase (plasmid) [Escherichia coli]
> Function assigned by sequence similarity to: WP_076612086.1
> ```

## What can we say about these genes?

### What functions do some species genes have?

In [None]:
strain_gene_info.loc[strain_gene_info.hit].sort_values('species_cos').iloc[:5]

- UHGG146766_03472 is assigned "regulatory protein mokC [Escherichia coli B088]" by blastx
- UHGG146766_00719 is assigned "hypothetical protein [Escherichia coli]" by blastx
- UHGG143923_02826 is assigned "alkyl hydroperoxide reductase subunit F [Escherichia coli]" by blastx

### What functions do some strain genes have?

In [None]:
strain_gene_info.loc[strain_gene_info.hit].sort_values('strain_cos').iloc[:20]

- UHGG001882_03784 is assigned "hypothetical protein [Escherichia coli]" by blastx
- UHGG051562_01340 is assigned "unnamed protein product [Klebsiella pneumoniae]" by blastx
- UHGG212210_02494 is assigned "glutamine-hydrolyzing GMP synthase [Escherichia coli]" by blastx
- UHGG000026_01858 is assigned "type I-F CRISPR-associated protein Csy3 [Escherichia coli]"

### What functions do some species-only genes have?

In [None]:
# Hits that are NOT correlated with the strain (strain_cos > 0.3), ordered by species correlation.
strain_gene_info.loc[strain_gene_info.hit & strain_gene_info.strain_cos.gt(0.3)].sort_values('species_cos').head(5)

- UHGG144539_05148 = tRNA dihydrouridine synthase DusB [Escherichia coli]
- UHGG034380_01358 = heme exporter protein B [Escherichia coli]
- UHGG140694_05370 = hypothetical protein C9900_19320 [Salmonella enterica subsp. enterica serovar Kentucky]
- UHGG222834_02353 = 50S ribosomal protein L5 [Enterobacteriaceae]
- UHGG146130_04125 = conserved hypothetical protein [Escherichia coli 042]

## Various Characterizations of strains

In [None]:
# Among hits: centroid_99 length vs. cosine *distance* to the species depth profile, with the
# Spearman correlation. Author's earlier observation: a moderate positive relationship, i.e.
# genes more tightly correlated with species depth tend to be shorter.
d = strain_gene_info.loc[strain_gene_info.hit].sort_values('species_cos')
plt.scatter(x='species_cos', y='centroid_99_length', data=d, s=1)
sns.kdeplot(data=d, x='species_cos', y='centroid_99_length', log_scale=True, linewidths=1, color='k')

sp.stats.spearmanr(d['centroid_99_length'], d['species_cos'])

In [None]:
# Slight negative relationship between gene length and estimated depth.
d = strain_gene_info[strain_gene_info.hit]

plt.scatter('centroid_99_length', 'depth_ratio', data=d, s=1)
plt.yscale('log')
plt.xscale('log')

sp.stats.spearmanr(d['centroid_99_length'], d['depth_ratio'])

In [None]:
# Moderate negative relationship between "strain-ness" and estimated depth.
d = strain_gene_info[strain_gene_info.hit].assign(strain_species_ratio = lambda x: (1 - x.strain_cos) / (1 - x.species_cos))

plt.scatter('strain_species_ratio', 'depth_ratio', data=d, s=1)
plt.yscale('log')
plt.xscale('log')

sp.stats.spearmanr(d['strain_species_ratio'], d['depth_ratio'])

# Cross-strain Comparison

## Write gene table

In [None]:
# Persist this strain's gene table (relative to the working directory) for cross-strain comparison.
path = f'{strain_id}.tsv'
print(path)
strain_gene_info.to_csv(path_or_buf=path, sep='\t')

## Load tables and compare

In [None]:
strain_list = top_strains[0:7]  # snapshot from a previous run: ['102506-2', '102506-5', '102506-6', '102506-11', '102506-28', '102506-24', '102506-4']
# Per-strain gene tables written by the "Write gene table" cell, keyed by strain id.
all_strains_gene_info = dict()
for _strain_id in strain_list:
    all_strains_gene_info.update({_strain_id: pd.read_table(f'{_strain_id}.tsv', index_col='gene_id')})

In [None]:
# One boolean column per strain (is each gene a hit?), then counts of each hit pattern.
all_strains_gene_hits = pd.DataFrame({strain: info.hit for strain, info in all_strains_gene_info.items()})
all_strains_gene_hits.value_counts().sort_index()

In [None]:
# Same as above at centroid_75 level: a cluster is a hit if any of its member genes is.
all_strains_gene_hits_75 = pd.DataFrame({
    strain: info.groupby(gene_clusters.centroid_75).hit.any()
    for strain, info in all_strains_gene_info.items()
})
all_strains_gene_hits_75.value_counts().sort_index()

In [None]:
# Samples dominated by any of the compared strains (sorted so row order is reproducible).
pure_samples_with_strains = sorted(
    set(species_samples)
    & set(idxwhere(focal_species_strain_depth.idxmax(axis=1).isin(strain_list)))
    & set(pure_samples)
)

# Genes called as a hit for at least one strain. `any(axis=1)` is spelled as a keyword because
# any/all are keyword-only in pandas 2.
strain_hit = idxwhere(
    all_strains_gene_hits.any(axis=1)
)

sample_dominant_strain = focal_species_strain_rabund.idxmax(axis=1)
strain_sample_palette = lib.plot.construct_ordered_palette(strain_list, cm='tab20')
subject_palette = lib.plot.construct_ordered_palette(subject_week.sort_values(['site', 'subject_id']).subject_id, cm='Spectral')


d = genes_depth.loc[pure_samples_with_strains, strain_hit]

# Rows annotated by subject, dominant strain, and cube-root species depth.
sns.clustermap(
    d,
    norm=mpl.colors.SymLogNorm(linthresh=1.0),
    metric='cosine',
    row_colors=pd.DataFrame(dict(
        subject=subject_week.subject_id.map(subject_palette),
        pure=sample_dominant_strain.map(strain_sample_palette),
        depth=(np.cbrt(species_depth[species_id]) / 3).map(mpl.cm.viridis),
    )),
)

In [None]:
# Standard deviation of each compared strain's depth over samples where that depth exceeds
# `focal_sample_species_depth_thresh_pres`: one value per strain. (The previous version used
# `_strain_id` leaked from an earlier loop, so it silently reported only the last strain.)
strain_depth[strain_list].where(lambda x: x > focal_sample_species_depth_thresh_pres).std()
# Since we know this ahead of time, we could select our cutoff for samples based on this relationship.

In [None]:
fig = plt.figure(figsize=(10, 5))
# Loop-invariant: genes called as hits for every compared strain. `all(axis=1)` is spelled as a
# keyword because any/all are keyword-only in pandas 2.
_hit_in_all_strains = all_strains_gene_hits.all(axis=1)

for _strain_id in strain_list:
    num_sample = strain_sample_counts.loc[_strain_id, 'num_samples']
    stdev_depth = strain_depth[_strain_id][lambda x: x > focal_sample_species_depth_thresh_pres].std()
    # Line width grows with depth spread times sqrt(number of dominated samples).
    quality_index = stdev_depth * np.sqrt(num_sample)
    _info = all_strains_gene_info[_strain_id]
    # Solid: genes that are hits in every strain; dashed: genes whose centroid_75 is a species-core hit.
    sns.kdeplot(_info.loc[_hit_in_all_strains, 'strain_cos'], label=f"{_strain_id}", color=strain_sample_palette[_strain_id], lw=np.cbrt(quality_index))
    sns.kdeplot(_info.loc[lambda x: x.species_gene_id.isin(species_gene_hits), 'strain_cos'], color=strain_sample_palette[_strain_id], linestyle='--')

plt.axvline(strain_cos_thresh, linestyle='--', lw=1, color='grey')
plt.legend()