In [None]:
# Load IPython's autoreload extension so edits to local modules (e.g. lib.*) can be picked up.
%load_ext autoreload

In [None]:
%autoreload

In [None]:
import os

# Run from the parent of the notebook's directory, presumably so relative paths such as
# 'data/...' and the local `lib` package resolve. The starting directory is remembered on
# the first run so re-executing this cell does not keep climbing the directory tree.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.path.realpath(os.curdir)
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))
os.path.realpath(os.curdir)

In [None]:
from glob import glob

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sfacts as sf
import xarray as xr
from scipy.spatial.distance import squareform
from sklearn.cluster import AgglomerativeClustering
from tqdm import tqdm

import lib.plot
from lib.pandas_util import idxwhere
from lib.plot import construct_ordered_palette

In [None]:
# GT-Pro species taxonomy, indexed by species_id (as a string), with the full
# ';'-separated taxonomy string plus one column per rank (p__, c__, o__, f__, g__, s__).
species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .astype({'species_id': str})
    .set_index('species_id')
    [['taxonomy_string']]
)
_ranks = species_taxonomy.taxonomy_string.str.split(';')
# Field 0 of the split string is skipped; fields 1..6 become the rank columns.
species_taxonomy = species_taxonomy.assign(**{
    rank: _ranks.apply(lambda parts, i=depth: parts[i])
    for depth, rank in enumerate(['p__', 'c__', 'o__', 'f__', 'g__', 's__'], start=1)
})

In [None]:
PDIST_SUFFIX = '.zshi.a.r.proc.gtpro.filt-poly05-cvrg10.mgen.pdist.nc'


def find_species_with_pdist(data_dir='data', suffix=PDIST_SUFFIX):
    """Return the species IDs that have an `sp-{id}{suffix}` file in `data_dir`.

    IDs are parsed from each file's basename, so the result does not depend on how
    `data_dir` is spelled. They are returned sorted because `glob` order is
    filesystem-dependent and would otherwise make downstream plots non-reproducible.
    """
    species_ids = []
    for path in glob(os.path.join(data_dir, 'sp-*' + suffix)):
        filename = os.path.basename(path)
        species_ids.append(filename[len('sp-'):-len(suffix)])
    return sorted(species_ids)


species_list = find_species_with_pdist()

In [None]:
# Peek at the ten largest species IDs in sort order (IDs are strings, so this is string order).
sorted(species_list)[-10:]

In [None]:
# Number of species that have a pairwise-distance (.pdist.nc) file.
len(species_list)

In [None]:
# Cut each species' complete-linkage dendrogram (built on its precomputed pairwise
# distances) at STRAIN_DISTANCE_THRESHOLD and record the size of every resulting cluster.
STRAIN_DISTANCE_THRESHOLD = 0.1


def cluster_strains(pdmat, distance_threshold=STRAIN_DISTANCE_THRESHOLD):
    """Return a Series mapping each sample (from `pdmat.sampleA`) to its cluster label.

    `pdmat` is a precomputed square distance matrix; clusters are formed with complete
    linkage so every pair within a cluster is closer than `distance_threshold`.
    """
    kwargs = dict(n_clusters=None, distance_threshold=distance_threshold, linkage='complete')
    try:
        # scikit-learn >= 1.2 calls this `metric`; `affinity` was removed in 1.4.
        model = AgglomerativeClustering(metric='precomputed', **kwargs)
    except TypeError:
        # Older scikit-learn only knows `affinity`.
        model = AgglomerativeClustering(affinity='precomputed', **kwargs)
    return pd.Series(model.fit(pdmat).labels_, index=pdmat.sampleA)


phyla = species_taxonomy.loc[species_list].p__.unique()
palette = construct_ordered_palette(phyla, cm='tab20')

species_strain_counts = {}
for species_id in tqdm(species_list):
    pdmat = xr.load_dataarray(f'data/sp-{species_id}.zshi.a.r.proc.gtpro.filt-poly05-cvrg10.mgen.pdist.nc')
    # value_counts(): number of samples per cluster, largest cluster first.
    species_strain_counts[species_id] = cluster_strains(pdmat).value_counts()

In [None]:
# Rank-abundance curves: for each species, its strain-cluster sizes in descending order
# (value_counts() sorts largest first), colored by phylum. The palette is rebuilt here
# instead of reusing the global `palette`, which a later cell reassigns.
curve_phyla = species_taxonomy.loc[species_list].p__.unique()
curve_palette = construct_ordered_palette(curve_phyla, cm='tab20')

fig, ax = plt.subplots()
for species_id in species_list:
    ax.plot(
        species_strain_counts[species_id].values,
        color=curve_palette[species_taxonomy.loc[species_id].p__],
    )

# Empty scatters give one legend entry per phylum.
for p__ in curve_phyla:
    ax.scatter([], [], label=p__, color=curve_palette[p__])
ax.legend(bbox_to_anchor=(1, 1))

ax.set_yscale('symlog')
ax.set_xscale('symlog')
ax.set_xlim(0)
ax.set_xlabel('strain rank (largest first)')
ax.set_ylabel('samples in strain');

In [None]:
# Per-species summary: total samples, number of strain clusters and samples per strain,
# joined with taxonomy and ordered by taxonomy string.
d = (
    pd.DataFrame.from_dict(
        {
            species_id: (counts.sum(), len(counts))
            for species_id, counts in species_strain_counts.items()
        },
        orient='index',
        columns=['num_samples', 'num_strains'],
    )
    .assign(samples_per_strain=lambda x: x.num_samples / x.num_strains)
    .join(species_taxonomy)
    .sort_values('taxonomy_string')
)

taxa = d.p__.unique()
# Own name, so the palette used by the earlier strain-size plot is not clobbered.
taxa_palette = construct_ordered_palette(taxa, cm='tab20')

fig, ax = plt.subplots()
# The '__none__' label keeps this scatter out of the legend (matplotlib skips '_'-prefixed labels).
ax.scatter('num_samples', 'num_strains', data=d, c=d.p__.map(taxa_palette), label='__none__')
ax.set_yscale('log')
ax.set_xscale('log')

for tax in taxa:
    ax.scatter([], [], label=tax, color=taxa_palette[tax])

# Identity line. Every cluster holds at least one sample, so num_strains <= num_samples and
# all points lie on or below it. It starts at 1 because 0 is not representable on log axes.
max_samples = d.num_samples.max()
ax.plot([1, max_samples], [1, max_samples], lw=1, linestyle='--', color='grey', zorder=0)
ax.set_xlabel('samples')
ax.set_ylabel('strain clusters')
ax.legend(bbox_to_anchor=(1, 1));

In [None]:
# The 20 species with the most samples per strain cluster.
d.sort_values(by='samples_per_strain', ascending=False).iloc[:20]

In [None]:
# Inspect the summary row of a single species (the index is the species ID string).
d.loc['100170']

In [None]:
# Summary rows restricted to species in the family Rikenellaceae.
d.loc[d['f__'].eq('f__Rikenellaceae')]

In [None]:
# Drill into one species: load its metagenotype and pairwise distances, then re-derive
# strain clusters with the same settings as the all-species loop (threshold 0.1).
species_id = '100653'
_path_stem = f'data/sp-{species_id}.zshi.a.r.proc.gtpro.filt-poly05-cvrg10.mgen'
mgen = sf.Metagenotype.load(f'{_path_stem}.nc')
# NOTE(review): `mgen_ss` is not used by any later cell shown here and the subsample is
# unseeded; confirm it is still needed.
mgen_ss = mgen.random_sample(position=1)
pdmat = xr.load_dataarray(f'{_path_stem}.pdist.nc')

_clust_kws = dict(n_clusters=None, distance_threshold=0.1, linkage='complete')
try:
    # scikit-learn >= 1.2 calls this `metric`; `affinity` was removed in 1.4.
    _clust_model = AgglomerativeClustering(metric='precomputed', **_clust_kws)
except TypeError:
    # Older scikit-learn only knows `affinity`.
    _clust_model = AgglomerativeClustering(affinity='precomputed', **_clust_kws)
clust = pd.Series(_clust_model.fit(pdmat).labels_, index=pdmat.sampleA, name='clust')
clust_size = clust.value_counts()

In [None]:
# Metagenotype dimension sizes alongside the five largest strain clusters (value_counts sorts descending).
mgen.sizes, clust_size.head()

In [None]:
# Focus on the most populous strain cluster (value_counts puts it first).
c = clust_size.index[0]
print(c)

in_clust = idxwhere(clust == c)
# For each sampleB, the largest distance to any focal-cluster sample; keep samples within
# 0.2 of *every* focal sample, then take the clusters those samples belong to.
max_dist_to_focal = pdmat.loc[in_clust].to_pandas().max()
close_to_clust = idxwhere(max_dist_to_focal < 0.2)
sister_clust = clust.loc[close_to_clust].unique()
# Samples of all sister clusters, ordered by cluster label so each cluster is contiguous.
sample_list = idxwhere(clust.sort_values().isin(sister_clust))
m = mgen.sel(sample=sample_list)

# Order variable positions by the focal cluster's pooled, row-normalized 'alt' fraction.
focal_allele_counts = (
    m.select_variable_positions(0.02)
    .sel(sample=in_clust)
    .sum("sample")
    .to_series()
    .unstack()
)
position_order = (
    focal_allele_counts
    .apply(lambda row: row / row.sum(), axis=1)
    .sort_values('alt')
    .index
    .to_list()
)

print(len(in_clust), len(sample_list), len(position_order))

sf.plot.plot_metagenotype(
    m.sel(position=position_order),
    row_cluster=False,
    # NOTE(review): this ignores `w` and calls m.metagenotype.linkage() on the unsubset `m`;
    # confirm that linking samples on all positions (not just the plotted ones) is intended.
    col_linkage_func=lambda w: m.metagenotype.linkage(),
    col_colors_func=lambda w: xr.Dataset(dict(
        aa=clust.rename_axis('sample').loc[w.sample].mod(20),
        bb=w.sample.isin(in_clust).astype(int),
    )),
    row_col_annotation_cmap=mpl.cm.tab20,
)

In [None]:
# Pool allele counts over the samples of each strain cluster, giving a metagenotype whose
# "sample" labels are cluster labels. Samples become rows (transpose), rows are grouped by
# cluster, then transposed back: `groupby(..., axis='columns')` is deprecated since
# pandas 2.1 and removed in 3.0.
pooled_counts = (
    m.to_series()
    .unstack("sample")
    .T.groupby(clust).sum().T
    .rename_axis(columns='sample')
    .stack()
    .reorder_levels(('sample', 'position', 'allele'))
)
m2 = sf.Metagenotype(pooled_counts.to_xarray())
sf.plot.plot_metagenotype(
    m2.sel(position=position_order),
    row_cluster=False,
)

In [None]:
# Distribution of the pooled focal cluster's estimated genotype, folded as |g - 0.5| + 0.5
# so that 0 and 1 both map to 1 and 0.5 stays at 0.5.
focal_genotype = m2.sel(sample=[c]).to_estimated_genotype().to_series().squeeze()
fig, ax = plt.subplots()
ax.hist(np.abs(focal_genotype - 0.5) + 0.5, bins=22)
ax.set_yscale('log')
ax.set(xlabel='folded estimated genotype |g - 0.5| + 0.5', ylabel='count');

In [None]:
# Pairwise distances among the selected samples (focal + sister clusters). squareform turns
# the square matrix into its condensed form, one entry per unordered pair. The bins cover
# [0, 0.2]; larger distances fall outside the bins and are not counted.
within_dists = squareform(pdmat.sel(sampleA=sample_list, sampleB=sample_list).values)
fig, ax = plt.subplots()
ax.hist(within_dists, bins=np.linspace(0, 0.2, num=44))
ax.set_yscale('log')
ax.set(xlabel='pairwise distance', ylabel='sample pairs');

In [None]:
# Sample metadata: table shi2021_s7 (indexed by 'Sample ID') joined onto table shi2021_s8
# (indexed by NCBI accession) via its 'Sample ID' column; overlapping s7 columns get a '_'
# suffix. Set HAPLO_MANUSCRIPT_RAW_DIR to point elsewhere instead of editing the absolute path.
MANUSCRIPT_RAW_DIR = os.environ.get(
    'HAPLO_MANUSCRIPT_RAW_DIR',
    '/pollard/home/bsmith/Projects/haplo-manuscript/raw',
)
_mgen_meta = pd.read_table(
    os.path.join(MANUSCRIPT_RAW_DIR, 'shi2021_s7.tsv'), index_col=['Sample ID'],
).rename_axis(index='sample_id')
mgen_meta = pd.read_table(
    os.path.join(MANUSCRIPT_RAW_DIR, 'shi2021_s8.tsv'), index_col='NCBI Accession Number',
).join(_mgen_meta, on='Sample ID', rsuffix='_')
mgen_meta.head()

In [None]:
# Number of metadata rows per country.
mgen_meta['Country'].value_counts()

In [None]:
# Ordination (lib.plot.pca_ordination) of the pairwise-distance matrix, colored by 'Study'.
# NOTE(review): `pdmat_square` was never defined in this notebook (NameError on a fresh
# kernel). Assuming ordination_plot expects a square sample x sample DataFrame whose labels
# match mgen_meta's index, it is derived from the current `pdmat`; confirm against lib.plot.
pdmat_square = pdmat.to_pandas()

fig, ax = plt.subplots(figsize=(10, 10))
lib.plot.ordination_plot(
    pdmat_square,
    meta=mgen_meta,
    ordin=lib.plot.pca_ordination,
    ordin_kws={},
    colorby='Study',
    color_palette=lib.plot.construct_ordered_palette(mgen_meta.Study.sort_values()),
    ax=ax,
    edgecolor_palette={'__none__': 'none'},
);

In [None]:
# Display the pairwise-distance matrix (previously the undefined name `pdmatx`, a NameError).
pdmat