# 2b Mash filtration & clustering

In this notebook, we run Mash to generate genome-wise pairwise similarity scores (which correspond to Average Nucleotide Identity (ANI) & DNA-DNA reassociation value).

Mash will be used as a final filtration metric to filter out strains which are too dissimilar from the rest of the genome collection.

In [None]:
import subprocess
import pickle

import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as hc
import scipy.spatial as sp

import matplotlib
import matplotlib.patches as patches
from matplotlib import pyplot as plt
import seaborn as sns
import plotly.express as px
import os

from tqdm.notebook import tqdm

from kneebow.rotor import Rotor

# pyphylon imports
import pyphylon.mash as mash
from pyphylon.util import load_config

In [None]:
# Load the project configuration and pick up the working directory from it.
CONFIG = load_config("config.yml")
WORKDIR = CONFIG["WORKDIR"]

# Folder (relative to this notebook) that the CSV tables below are read from
# and written to.
output_folder = "../output/"

In [None]:
# Locations of the raw inputs underneath the working directory.
RAW = os.path.join(WORKDIR, "1b_protected/raw")
RAW_GENOMES, FNA_GENOMES = (
    os.path.join(RAW, subdir) for subdir in ("genomes", "genomes/fna")
)

In [None]:
# Read the '2a' genome summary and metadata tables from the output folder.
# The first CSV column becomes the index and genome IDs are kept as strings.
_read_opts = dict(index_col=0, dtype={'genome_id': str})

scrubbed_species_summary = pd.read_csv(
    os.path.join(output_folder, '2a_genome_summary.csv'), **_read_opts
)
scrubbed_species_metadata = pd.read_csv(
    os.path.join(output_folder, '2a_genome_metadata.csv'), **_read_opts
)

# Show the shape and the first rows of each table.
for _table in (scrubbed_species_summary, scrubbed_species_metadata):
    display(_table.shape, _table.head())

## Run Mash

- Use Snakemake to run MASH.  See example/readme.md

WSL workaround

```bash
  MSYS_NO_PATHCONV=1 wsl -d pangenome bash -c "
    source ~/miniforge3/etc/profile.d/conda.sh && \
    conda activate pangenome && \
    cd /mnt/f/lab_projects/pangenomics/pyphylon && \
    mkdir -p temp/2b_mash && \
    mash sketch -o temp/2b_mash/combined_sketch temp/1b_protected/raw/genomes/fna/*.fna && \
    mash dist temp/2b_mash/combined_sketch.msh temp/2b_mash/combined_sketch.msh > temp/2b_mash/mash_distances.txt
  "
```

# Mash filtration and clustering

In [None]:
# Column names for the header-less, tab-separated Mash distance table
names = ['genome1', 'genome2', 'mash_distance', 'p_value', 'matching_hashes']

TEMP_DIR = CONFIG.get("REUSE_TEMP_DIR", "../temp/")
mash_table_path = os.path.join(TEMP_DIR, '2b_mash/mash_distances.txt')
df_mash = pd.read_csv(mash_table_path, sep='\t', names=names)

# Reduce each genome path to its file name with everything from '.fna' onwards dropped
for genome_col in ('genome1', 'genome2'):
    df_mash[genome_col] = df_mash[genome_col].apply(
        lambda path: path.split('/')[-1].split('.fna')[0]
    )

df_mash

In [None]:
# Reshape the long table into a genome1 x genome2 matrix of Mash distances
df_mash_square = df_mash.pivot(
    index='genome1',
    columns='genome2',
    values='mash_distance',
)

display(df_mash_square.shape, df_mash_square.head())

In [None]:
# Heatmap of the raw pairwise Mash distance matrix, drawn on an explicit Axes
_, heatmap_ax = plt.subplots()
sns.heatmap(data=df_mash_square, cmap='viridis', ax=heatmap_ax)

# Generate corresponding Pearson correlation matrix (& distance matrix)

In [None]:
# NOTE: this step may take HOURS to run on a large genome collection.
# Only the correlation-distance matrix is written to disk here (right after it
# is computed); the raw Mash matrix and the correlation matrix are not saved
# by this cell.

# Column-wise correlation of the Mash matrix (pandas default: Pearson),
# converted to a distance as 1 - r.
df_mash_corr = df_mash_square.corr()
df_mash_corr_dist = 1 - df_mash_corr

# Persist the distance matrix straight away
df_mash_corr_dist.to_csv(os.path.join(output_folder, '2b_mash_corr_dist.csv'))

display(df_mash_corr_dist.shape, df_mash_corr_dist.head())

## Filter by scrubbed genomes

Based on any cleaning that may have been done in `2a`

In [None]:
# TEMP: restrict the scrubbed strains to those present in the Mash matrix
_scrubbed_ids = scrubbed_species_metadata.genome_id.astype('str')
_present_in_mash = _scrubbed_ids.isin(df_mash_corr_dist.index.astype(str))
scrubbed_strains = _scrubbed_ids[_present_in_mash]
scrubbed_strains

In [None]:
# Restrict all three matrices (rows and columns) to the scrubbed strains.
# Each matrix is sliced from itself: slicing df_mash_corr_dist out of
# df_mash_square would silently replace the correlation distances with raw
# Mash distances.
df_mash_square = df_mash_square.loc[scrubbed_strains, scrubbed_strains]
df_mash_corr = df_mash_corr.loc[scrubbed_strains, scrubbed_strains]
df_mash_corr_dist = df_mash_corr_dist.loc[scrubbed_strains, scrubbed_strains]

## Filter strains by Mash distance

- __Criterion 1:__ Mash value of 0.05 (soft-limit on bacterial species delineation)
- __Criterion 2:__ Any clear outliers


In [None]:
# Distribution of every pairwise Mash distance (matrix flattened to 1-D)
sns.histplot(df_mash_square.to_numpy().ravel())

### Find your Reference/Representative Strain ID (for filtration)

In [None]:
# Representative strains: genomes whose 'reference_genome' metadata field is
# populated and which are present in the Mash matrix.
_ref_mask = scrubbed_species_metadata['reference_genome'].notna()
_ref_ids = scrubbed_species_metadata.loc[_ref_mask, 'genome_id'].astype(str)
_ref_ids_in_mash = _ref_ids[_ref_ids.isin(df_mash_square.index)]
repr_strains = sorted(_ref_ids_in_mash.tolist())

if len(repr_strains) == 0:
    # Fallback: the medoid, i.e. the genome with the smallest mean Mash distance
    _mean_distance = df_mash_square.mean(axis=1)
    repr_strains = [_mean_distance.idxmin()]

print(f"Representative strains: {repr_strains}")

In [None]:
# The percentile is data dependent (see the histogram above). Past studies
# have gone as low as the 98.5th percentile; the 99th or 99.9th percentiles
# are also acceptable.
# One 99th-percentile cutoff per representative strain ...
cutoffs = [
    np.quantile(df_mash_square.loc[strain], 0.99)
    for strain in repr_strains
]

# ... averaged into the single cutoff used for filtration
cutoff = sum(cutoffs) / len(cutoffs)
cutoff

In [None]:
# Keep only genomes whose Mash distance to EVERY representative strain is
# below the cutoff.
# The mask is computed in one pass on the current matrix rather than by
# filtering once per representative: with sequential filtering, a
# representative dropped by an earlier representative's filter raises a
# KeyError on its own iteration. Representatives that are no longer in the
# matrix (e.g. when this cell is re-run) are skipped for the same reason.
present_repr_strains = [s for s in repr_strains if s in df_mash_square.index]
within_cutoff = (df_mash_square.loc[present_repr_strains] < cutoff).all(axis=0)
good_strains = within_cutoff[within_cutoff].index

# Each matrix is sliced from itself; df_mash_corr_dist must not be rebuilt
# from df_mash_square or it would hold raw Mash distances.
df_mash_square = df_mash_square.loc[good_strains, good_strains]
df_mash_corr = df_mash_corr.loc[good_strains, good_strains]
df_mash_corr_dist = df_mash_corr_dist.loc[good_strains, good_strains]

df_mash_corr_dist.shape

In [None]:
# Subset the summary and metadata tables to the genomes that survived Mash
# filtration, ordered by genome ID.
# The summary is taken from scrubbed_species_summary; building it from the
# metadata table would make the saved '2b_genome_summary.csv' a copy of the
# metadata.
kept_genomes = sorted(df_mash_square.index)

mash_scrubbed_summary = scrubbed_species_summary.set_index('genome_id').loc[kept_genomes].reset_index()
mash_scrubbed_metadata = scrubbed_species_metadata.set_index('genome_id').loc[kept_genomes].reset_index()


display(
    mash_scrubbed_metadata.shape,
    mash_scrubbed_metadata.head()
)

## Find threshold for Mash clustering

In [None]:
# Genomes flagged 'Complete' in the summary that also survived Mash filtration
cond = scrubbed_species_summary.genome_status == 'Complete'
complete_seqs = set(scrubbed_species_summary[cond].genome_id)
complete_seqs = sorted(
    complete_seqs.intersection(set(df_mash_square.index))
)

# Slice each matrix from its own source. The correlation and
# correlation-distance subsets must come from df_mash_corr and
# df_mash_corr_dist, not from the raw Mash matrix, so that the clustering
# below works on correlation distances.
df_mash_square_complete = df_mash_square.loc[complete_seqs, complete_seqs]
df_mash_corr_complete = df_mash_corr.loc[complete_seqs, complete_seqs]
df_mash_corr_dist_complete = df_mash_corr_dist.loc[complete_seqs, complete_seqs]

df_mash_corr_dist_complete.shape

In [None]:
# Initial sensitivity analysis (gives min val to consider)
# NOTE(review): the exact structure of the returned objects is defined in
# pyphylon.mash; here `tmp` supplies the 'threshold'/'num_clusters' curve and
# `df_temp['num_clusters'][elbow_idx]` the cluster count at the elbow.
from pyphylon.mash import sensitivity_analysis, cluster_corr_dist, remove_bad_strains
tmp, df_temp, elbow_idx, elbow_threshold = sensitivity_analysis(df_mash_corr_dist_complete)

elbow_num_clusters = df_temp['num_clusters'][elbow_idx]

# Plot number of clusters against the clustering threshold
# (the original run suggested picking something > 0.25)
plt.rcParams["figure.dpi"] = 200
fig, axs = plt.subplots(figsize=(4,3),)
axs.plot(tmp['threshold'], tmp['num_clusters'])
axs.axhline(y=elbow_num_clusters, c="#ff00ff", linestyle='--')
axs.set_ylabel('num_clusters')
# The x values are thresholds, so label the axis accordingly (was 'index')
axs.set_xlabel('threshold')
fig.suptitle(
    f"Num clusters decelerates \nafter a value of {elbow_num_clusters} (threshold: {elbow_threshold})",
    y=1
)
plt.show()

## Plot initial clustermap of Mash values

In [None]:
# Pad the elbow threshold by 0.1 ("round" up).
# NOTE: this rebinds elbow_threshold, so re-running this cell without
# re-running the sensitivity analysis adds another 0.1 each time.
elbow_threshold = elbow_threshold + 0.1

link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, thresh=elbow_threshold)

# Give each cluster label one 'tab20' colour. zip() stops at the shorter
# sequence, so labels beyond the palette length end up without a colour;
# the two counts printed below make that visible.
cm = matplotlib.colormaps['tab20']
cluster_labels = sorted(clst.cluster.unique())
clr = dict(zip(cluster_labels, cm.colors))
clst['color'] = clst.cluster.map(clr)

print('Number of colors: ', len(clr))
print('Number of clusters', len(cluster_labels))

In [None]:
size = 6

# One legend entry per coloured cluster
legend_TN = [
    patches.Patch(color=colour, label=label)
    for label, colour in clr.items()
]

sns.set(rc={'figure.facecolor':'white'})

# Heatmap of the Mash matrix, ordered on both axes by the precomputed linkage
clustermap_options = dict(
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    col_colors=clst.color,
    yticklabels=False,
    xticklabels=False,
    cmap='BrBG_r',
    robust=True,
)
g = sns.clustermap(df_mash_square_complete, **clustermap_options)

l2 = g.ax_heatmap.legend(
    handles=legend_TN,
    loc='upper left',
    bbox_to_anchor=(1.01, 0.85),
    frameon=True,
)
l2.set_title(title='Clusters', prop={'size': 10})

## Filter out small clusters (typically with < 5 members)

In [None]:
# Minimum cluster size: clusters with FEWER members than this are flagged as
# "bad" and their genomes removed in the cells below. With a value of 0 no
# cluster can be flagged, so no filtering takes place.
small_clst_limit = 0

In [None]:
# Histogram of cluster sizes (number of genomes per cluster)
px.histogram(clst.cluster.value_counts().to_frame(), nbins=100)

In [None]:
# Cluster sizes (indexed by cluster label) that fall below the size limit
bad_clusters = clst.cluster.value_counts().loc[
    lambda sizes: sizes < small_clst_limit
]
bad_clusters

In [None]:
# Collect the genomes that belong to an undersized cluster.
# bad_clusters is indexed by cluster label, so membership is tested against
# its index.
bad_genomes_list = [
    genome
    for genome in df_mash_square_complete.index
    if clst.loc[genome, 'cluster'] in bad_clusters.index
]

# Update filtration: every matrix is filtered from itself.
# df_mash_corr_complete must not be rebuilt from df_mash_square_complete,
# which would replace the correlations with raw Mash distances.
df_mash_square_complete = remove_bad_strains(df_mash_square_complete, bad_genomes_list)
df_mash_corr_complete = remove_bad_strains(df_mash_corr_complete, bad_genomes_list)
df_mash_corr_dist_complete = remove_bad_strains(df_mash_corr_dist_complete, bad_genomes_list)

## Keep filtering until robust clusters show up

In [None]:
# Re-cluster and drop undersized clusters until the number of clusters stops
# changing between two consecutive iterations.
iteration = 1
prev = 0
curr = len(clst.cluster.unique())

while prev != curr:
    print(f'iteration {iteration}...{curr}')

    # Cluster
    link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, thresh=elbow_threshold)

    # Color each cluster
    cm = matplotlib.colormaps.get_cmap('tab20')
    clr = dict(zip(sorted(clst.cluster.unique()), cm.colors))
    clst['color'] = clst.cluster.map(clr)

    # Track the cluster count of the previous and the current round
    prev = curr
    curr = len(clst.cluster.unique())

    # Define bad clusters (fewer members than small_clst_limit)
    cluster_sizes = clst.cluster.value_counts()
    bad_clusters = cluster_sizes[cluster_sizes < small_clst_limit]

    # Genomes belonging to a bad cluster (bad_clusters is indexed by label)
    bad_genomes_list = [
        genome
        for genome in df_mash_square_complete.index
        if clst.loc[genome, 'cluster'] in bad_clusters.index
    ]

    # Update filtration: every matrix is filtered from itself.
    # df_mash_corr_complete must not be rebuilt from df_mash_square_complete.
    df_mash_square_complete = remove_bad_strains(df_mash_square_complete, bad_genomes_list)
    df_mash_corr_complete = remove_bad_strains(df_mash_corr_complete, bad_genomes_list)
    df_mash_corr_dist_complete = remove_bad_strains(df_mash_corr_dist_complete, bad_genomes_list)

    iteration += 1

In [None]:
# Current shape of the complete-genome Mash matrix after filtration
df_mash_square_complete.shape

In [None]:
# Final clustering of the filtered correlation-distance matrix
link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, thresh=elbow_threshold)

# Give each cluster label one 'tab20' colour. zip() stops at the shorter
# sequence, so labels beyond the palette length end up without a colour;
# the two counts printed below make that visible.
cm = matplotlib.colormaps['tab20']
cluster_labels = sorted(clst.cluster.unique())
clr = dict(zip(cluster_labels, cm.colors))
clst['color'] = clst.cluster.map(clr)

print('Number of colors: ', len(clr))
print('Number of clusters', len(cluster_labels))

In [None]:
# Sanity check: after filtration the smallest cluster must have at least
# small_clst_limit members.
assert clst.cluster.value_counts().min() >= small_clst_limit

In [None]:
# Histogram of cluster sizes (number of genomes per cluster) for the final clustering
px.histogram(clst.cluster.value_counts(), nbins=50)

# Plot filtered Mash clustermap

__From this it looks like our final rank for NMF decomposition will be 16 for Enterobacter__

In [None]:
size = 6

# One legend entry per coloured cluster
legend_TN = [
    patches.Patch(color=colour, label=label)
    for label, colour in clr.items()
]

sns.set(rc={'figure.facecolor':'white'})

# Heatmap of the filtered Mash matrix, ordered on both axes by the final linkage
clustermap_options = dict(
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    col_colors=clst.color,
    yticklabels=False,
    xticklabels=False,
    cmap='BrBG_r',
    robust=True,
)
g = sns.clustermap(df_mash_square_complete, **clustermap_options)

l2 = g.ax_heatmap.legend(
    handles=legend_TN,
    loc='upper left',
    bbox_to_anchor=(1.05, 0.85),
    frameon=True,
)
l2.set_title(title='Clusters', prop={'size': 10})

# Save Mash-scrubbed `summary` and `metadata`

In [None]:
# Attach each genome's cluster label as the 'mash_cluster' column.
# The lookup goes through Series.map, keyed on genome_id, instead of
# clst.loc[...]: clst only holds the genomes that were clustered, so genomes
# absent from it (e.g. not 'Complete', or removed with a small cluster) get
# NaN rather than raising a KeyError. Assigning the column by name also keeps
# the cell idempotent (re-running does not append a duplicate column).
mash_scrubbed_metadata = mash_scrubbed_metadata.assign(
    mash_cluster=mash_scrubbed_metadata['genome_id'].map(clst['cluster'])
)

In [None]:
# Write the Mash-scrubbed genome summary to the output folder
filepath = os.path.join(output_folder, '2b_genome_summary.csv')
mash_scrubbed_summary.to_csv(path_or_buf=filepath)

In [None]:
# Write the Mash-scrubbed genome metadata (including the 'mash_cluster'
# column when present) to the output folder
filepath = os.path.join(output_folder, '2b_genome_metadata.csv')
mash_scrubbed_metadata.to_csv(path_or_buf=filepath)

## Save Mash results

In [None]:
# Write the filtered Mash matrices to the output folder.
# '2b_mash_corr_dist.csv' is the same file name written right after the
# correlation step, so that earlier file is overwritten here.
for _file_name, _matrix in (
    ('2b_mash_square.csv', df_mash_square),
    ('2b_mash_corr_dist.csv', df_mash_corr_dist),
):
    filepath = os.path.join(output_folder, _file_name)
    _matrix.to_csv(filepath)