# Setup

In [None]:
import os
import pickle

import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as hc
import scipy.spatial as sp

import matplotlib
import matplotlib.patches as patches
from matplotlib import pyplot as plt
import seaborn as sns
import plotly.express as px
from tqdm.notebook import tqdm

# Global figure styling. Statement order matters: sns.set_style() runs between the
# font settings and the color overrides, so the pure-black colors are applied last
# (presumably because set_style resets text/tick colors).
matplotlib.rcParams.update({
    'pdf.fonttype': 42,        # Type 42 (TrueType) font embedding in PDF
    'ps.fonttype': 42,         # same for PostScript
    'svg.fonttype': 'none',    # keep SVG text as text
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
})
sns.set_style('ticks')
for _rc_key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color'):
    matplotlib.rcParams[_rc_key] = '#000000'

In [None]:
from kneebow.rotor import Rotor

In [None]:
# Per-genome folders and the flat directory of .fna files that `mash sketch` reads.
_RAW_DATA_DIR = '../../data/raw'
RAW_GENOMES = f'{_RAW_DATA_DIR}/genomes'
MASH_GENOMES = f'{_RAW_DATA_DIR}/mash_genomes'

In [None]:
# Species-scrubbed genome tables (summary + metadata) this notebook starts from.
_METADATA_DIR = '../../data/metadata'
SCRUBBED_SUMMARY = f'{_METADATA_DIR}/scrubbed_species_summary.csv'
SCRUBBED_METADATA = f'{_METADATA_DIR}/scrubbed_species_metadata.csv'

In [None]:
# Read everything as text so genome IDs such as '550.3788' are never parsed as floats.
scrubbed_summary = pd.read_csv(SCRUBBED_SUMMARY, index_col=0, dtype='object')
scrubbed_metadata = pd.read_csv(SCRUBBED_METADATA, index_col=0, dtype='object')

# Ensure genome id is a str
scrubbed_summary['genome_id'] = scrubbed_summary['genome_id'].astype('str')
scrubbed_metadata['genome_id'] = scrubbed_metadata['genome_id'].astype('str')

# Fix naming format issues: '"[Enterobacter] ...' -> 'Enterobacter ...', then drop stray quotes.
# The .str accessor keeps missing names as NaN (a plain lambda raised AttributeError on NaN).
scrubbed_metadata['genome_name'] = (
    scrubbed_metadata['genome_name']
    .str.replace('"[Enterobacter]', "Enterobacter", regex=False)
    .str.replace('"', "", regex=False)
)

display(
    scrubbed_metadata.shape,
    scrubbed_metadata.head()
)

In [None]:
scrubbed_metadata['genbank_accessions']  # quick look at the GenBank accessions column

## Copy all filtered genomes into a single directory

In [None]:
# # Files have already been moved into the mash directory, no need to move them
# items = []
# item_paths = []

# for item in os.listdir(RAW_GENOMES):
#     curr_path = os.path.join(RAW_GENOMES, item)
#     if os.path.isdir(curr_path):
#         curr_fna = os.path.join(curr_path, f'{item}.fna')
#         items.append(item)
#         item_paths.append(curr_fna)


# display(
#     items[:5],
#     item_paths[:5]
# )

In [None]:
# # make sure every fna file exists as a file on disk
# assert len(items) == np.sum([os.path.isfile(item) for item in item_paths])

In [None]:
# # Already run once, no need to run again
# for item, item_path in tqdm(dict(zip(items, item_paths)).items()):
#    new_path = os.path.join(MASH_GENOMES, f'{item}.fna')
#    cmd = f'cp {item_path} {new_path}'
#    os.system(cmd)

## Make a combined mash sketch file

The following was run on a Linux terminal tmux session:

`mash sketch -o combined_sketch /media/pekar2/pan_phylon/Enterobacter/raw/mash_genomes/*.fna`

## Generate pairwise distance matrix (square)

Linux tmux session:

`mash dist combined_sketch.msh combined_sketch.msh > mash_distances.txt`

# Mash filtration and clustering

In [None]:
# Column names for the headerless, tab-separated output of `mash dist`.
names = ['genome1', 'genome2', 'mash_distance', 'p_value', 'matching_hashes']


def _genome_id_from_path(fna_path):
    '''Reduce '/some/dir/<genome_id>.fna' to '<genome_id>'.'''
    file_name = fna_path.split('/')[-1]
    return file_name.split('.fna')[0]


df_mash = pd.read_csv('../../data/raw/mash_distances.txt', sep='\t', names=names)
for _genome_col in ('genome1', 'genome2'):
    df_mash[_genome_col] = df_mash[_genome_col].apply(_genome_id_from_path)

df_mash

In [None]:
# Long pairwise table -> square matrix (rows: genome1, columns: genome2, cells: Mash distance).
df_mash_square = df_mash.pivot(
    index='genome1',
    columns='genome2',
    values='mash_distance',
)

display(df_mash_square.shape, df_mash_square.head())

## Generate corresponding Pearson-correlation matrix (and distance matrix)

In [None]:
# # This may take HOURS to run
# # Once finished it will IMMEDIATELY save all 3 matrices
# # so you don't have to re-compute this over and over again

# df_mash_corr = df_mash_square.corr()
# df_mash_corr_dist = 1 - df_mash_corr
# df_mash_corr_dist

# # Save matrix so the next time, only the following cell needs to be run
# # This cell should be commented out after being run once
# df_mash_corr_dist.to_csv('../../data/processed/df_mash_corr_dist.csv')
# df_mash_square.to_csv('../../data/processed/df_mash_square.csv')
# df_mash_corr.to_csv('../../data/processed/df_mash_corr.csv')

# display(
#     df_mash_corr_dist.shape,
#     df_mash_corr_dist.head()
# )

In [None]:
PROCESSED_DIR = '../../data/processed'


def _load_square_matrix(file_name, index_col):
    '''Load a cached genome x genome matrix with str genome-ID labels and float values.

    Reading with dtype='object' first keeps IDs such as '550.3788' from being parsed
    as floats before they become the index.
    '''
    matrix = (
        pd.read_csv(f'{PROCESSED_DIR}/{file_name}', dtype='object')
        .set_index(index_col)
        .astype(float)
    )
    matrix.index = matrix.index.astype(str)
    matrix.columns = matrix.columns.astype(str)
    return matrix


df_mash_corr_dist = _load_square_matrix('df_mash_corr_dist.csv', 'genome2')
df_mash_square = _load_square_matrix('df_mash_square.csv', 'genome1')
# df_mash_corr previously skipped the str-label normalization applied to the other two.
df_mash_corr = _load_square_matrix('df_mash_corr.csv', 'genome2')

## Filter by scrubbed genomes

Restrict the matrices to the genomes retained by the cleaning step in notebook `2a`.

In [None]:
scrubbed_strains = scrubbed_metadata['genome_id'].astype(str)  # genome IDs kept after scrubbing


# scrubbed_new = []
# for x in scrubbed_strains:
#     if x not in ['1296536.1', '1296536.17', '1296536.2', '158836.1', '158836.137', '158836.168', '158836.169', '158836.2', '158836.36', '158836.4', '158836.6', '1812934.4', '1812935.1', '1812935.2', '1812935.44', '1812935.8', '2027919.8', '2071710.1', '2071710.2', '208224.2', '208224.3', '2494701.1', '2494701.2', '2831890.1', '2831891.1', '2831892.1', '2870346.1', '299766.1', '299766.2', '299767.2', '301102.1', '301102.2', '301105.1', '548.1', '548.11', '548.12', '548.13', '548.14', '548.15', '548.16', '548.17', '548.6', '548.7', '550.112', '550.114', '550.122', '550.123', '550.13', '550.14', '550.15', '550.157', '550.158', '550.166', '550.167', '550.245', '550.25', '550.26', '550.27', '550.29', '550.3', '550.31', '550.32', '550.33', '550.34', '550.35', '550.36', '550.37', '550.375', '550.376', '550.377', '550.379', '550.38', '550.381', '550.383', '550.385', '550.39', '550.4', '550.42', '550.43', '550.46', '550.48', '550.49', '550.51', '550.52', '550.53', '550.54', '550.55', '550.56', '550.65', '550.66', '550.67', '550.68', '550.69', '550.7', '550.71', '550.72', '550.73', '550.74', '550.76', '550.77', '550.78', '550.79', '550.8', '550.81', '550.82', '550.83', '550.84', '550.85', '550.86', '550.87', '550.88', '550.89', '550.9', '550.91', '550.92', '550.93', '550.94', '550.95', '550.96', '61645.1', '61645.5', '61645.58', '61645.59', '61645.6', '61645.8']:
#         scrubbed_new.append(x)
# scrubbed_strains = scrubbed_new    

# Restrict every matrix (rows and columns) to the scrubbed genomes.
df_mash_square = df_mash_square.loc[scrubbed_strains, scrubbed_strains]
df_mash_corr = df_mash_corr.loc[scrubbed_strains, scrubbed_strains]
# Bug fix: this previously sliced df_mash_square, so the "corr_dist" matrix silently
# held raw Mash distances instead of the Pearson-correlation distances.
df_mash_corr_dist = df_mash_corr_dist.loc[scrubbed_strains, scrubbed_strains]

## Filter strains by Mash distance

- __Criterion 1:__ Mash distance of 0.05 (a soft limit for bacterial species delineation)
- __Criterion 2:__ Any clear outliers


In [None]:
sns.histplot(df_mash_square.to_numpy().ravel())  # distribution of all pairwise Mash distances

### Find your Reference/Representative Strain ID (for filtration)


#### Note from Josh: we need to determine the representative strain for the genus, and examine the plot above to understand what is occurring there and whether any changes to the data are needed. The reference strains below were pulled from the BV-BRC list, excluding references not found in the Mash matrix.

In [None]:
# Reference/representative genome IDs used to anchor the Mash-distance cutoff.
repr_strains = [
    '550.3788',
    '158836.1174',
    '1812935.7',
    '2478464.3',
    '299767.18',
    '539813.36',
    '2494702.15',
    '69218.53',
    '881260.71',
    '1400147.3',
    '550.2510',
    '2494701.30',
]

In [None]:
# This cutoff is dependent on the data you see above.
# Past studies have gone down as low as the 98.5th percentile,
# but 99th or 99.9th percentiles are also acceptable.
# Each representative contributes the 99th percentile of its Mash distances
# to all genomes; the cutoff is the plain mean of those values.
cutoffs = [np.quantile(df_mash_square.loc[strain], 0.99) for strain in repr_strains]

cutoff = sum(cutoffs) / len(cutoffs)

# # alternative cutoff using max of possible values
# cutoff = max(cutoffs)

cutoff

In [None]:
# Keep only genomes within `cutoff` Mash distance of EVERY representative strain.
# Filters are applied one after another, so a representative can itself be removed by
# an earlier representative's filter; skip it with a notice instead of hitting a KeyError.
for repr_strain in repr_strains:
    if repr_strain not in df_mash_square.index:
        print(f'Skipping representative {repr_strain}: no longer in the filtered matrix')
        continue

    repr_dists = df_mash_square.loc[repr_strain]
    good_strains = repr_dists[repr_dists < cutoff].index

    df_mash_square = df_mash_square.loc[good_strains, good_strains]
    df_mash_corr = df_mash_corr.loc[good_strains, good_strains]
    # Bug fix: previously sliced df_mash_square, overwriting the correlation distances.
    df_mash_corr_dist = df_mash_corr_dist.loc[good_strains, good_strains]

df_mash_corr_dist.shape

In [None]:
# Restrict summary and metadata to the genomes that survived Mash filtering, sorted by ID.
mash_genome_ids = sorted(df_mash_square.index)
# Bug fix: the summary was previously built from scrubbed_metadata, so the file later
# saved as mash_scrubbed_species_summary.csv was a copy of the metadata table.
# NOTE(review): assumes every Mash-kept genome_id is present in the summary table.
mash_scrubbed_summary = scrubbed_summary.set_index('genome_id').loc[mash_genome_ids].reset_index()
mash_scrubbed_metadata = scrubbed_metadata.set_index('genome_id').loc[mash_genome_ids].reset_index()


display(
    mash_scrubbed_metadata.shape,
    mash_scrubbed_metadata.head()
)

## Useful functions for later analysis

In [None]:
def cluster_corr_dist(df_mash_corr_dist, thresh=0.1, method='ward', metric='euclidean'):
    '''
    Hierarchically cluster genomes from a square, symmetric pairwise-distance matrix
    (intended for the Mash-based Pearson-correlation distance matrix).

    Parameters
    ----------
    df_mash_corr_dist : pd.DataFrame
        Square distance matrix with a zero diagonal, indexed by genome ID.
    thresh : float
        Flat-cluster cutoff expressed as a fraction of the largest pairwise distance.
    method, metric : str
        Forwarded to scipy.cluster.hierarchy.linkage. Because a condensed distance
        vector is passed, scipy does not use `metric` to compute distances.

    Returns
    -------
    link : np.ndarray
        Linkage matrix, shape (n - 1, 4).
    dist : np.ndarray
        Condensed (upper-triangle) distance vector.
    clst : pd.DataFrame
        One row per genome (same index as the input) with its integer `cluster` label.
    '''
    # Condense once; the same vector feeds the linkage and scales the threshold.
    dist = sp.distance.squareform(df_mash_corr_dist)
    link = hc.linkage(dist, method=method, metric=metric)

    clst = pd.DataFrame(index=df_mash_corr_dist.index)
    clst['cluster'] = hc.fcluster(link, thresh * dist.max(), 'distance')

    return link, dist, clst


def remove_bad_strains(df_mash_scd, bad_strains_list):
    '''Drop the given genome IDs from both axes of a square genome x genome matrix.

    IDs not present in the matrix are ignored; survivors come back sorted on both axes.
    '''
    keep = sorted(set(df_mash_scd.index).difference(bad_strains_list))
    return df_mash_scd.loc[keep, keep]


# Sensitivity analysis to pick the threshold (for E. coli we use 0.1)
# We pick the threshold where the curve just starts to bottom out
def sensitivity_analysis(df_mash_corr_dist_complete):
    '''
    Count flat clusters over a grid of thresholds and locate the elbow of that curve.

    Returns
    -------
    tmp : pd.DataFrame
        Columns `threshold` and `num_clusters`, in grid order.
    df_temp : pd.DataFrame
        `tmp` sorted by ascending `num_clusters` with a fresh 0..n-1 index (the Rotor input).
    elbow_idx : int
        Row of `df_temp` that kneebow's Rotor reports as the elbow.
    elbow_threshold : float
        First grid threshold (in `tmp` order) whose cluster count equals the elbow's.
    '''
    x = list(np.logspace(-3, -1, 10)) + list(np.linspace(0.1, 1, 19))

    # The linkage does not depend on the threshold, so build it once rather than
    # once per grid point (previously recomputed for all 29 thresholds).
    dist = sp.distance.squareform(df_mash_corr_dist_complete)
    link = hc.linkage(dist, method='ward', metric='euclidean')
    max_dist = dist.max()

    def num_uniq_clusters(thresh):
        labels = hc.fcluster(link, thresh * max_dist, 'distance')
        return len(pd.unique(labels))

    tmp = pd.DataFrame()
    tmp['threshold'] = pd.Series(x)
    tmp['num_clusters'] = pd.Series(x).apply(num_uniq_clusters)

    # Find which value the elbow corresponds to
    df_temp = tmp.sort_values(by='num_clusters', ascending=True).reset_index(drop=True)

    # transform input into form necessary for package: (rank, num_clusters) pairs
    data = list(zip(df_temp.index, df_temp['num_clusters']))

    rotor = Rotor()
    rotor.fit_rotate(data)
    elbow_idx = rotor.get_elbow_index()
    elbow_num_clusters = df_temp['num_clusters'][elbow_idx]

    # Grab elbow threshold
    cond = tmp['num_clusters'] == elbow_num_clusters
    elbow_threshold = tmp[cond]['threshold'].iloc[0]

    return tmp, df_temp, elbow_idx, elbow_threshold



## Find threshold for Mash clustering

In [None]:
# Only looking at Complete sequences (that also survived Mash filtering)
cond = scrubbed_summary.genome_status == 'Complete'
complete_seqs = sorted(
    set(scrubbed_summary.loc[cond, 'genome_id']).intersection(df_mash_square.index)
)

df_mash_square_complete = df_mash_square.loc[complete_seqs, complete_seqs]
# Bug fix: both matrices below were previously sliced from df_mash_square, so all
# "correlation-distance" clustering actually ran on raw Mash distances.
# NOTE(review): elbow thresholds / cluster counts quoted in this notebook were derived
# under that bug and should be re-checked after re-running.
df_mash_corr_complete = df_mash_corr.loc[complete_seqs, complete_seqs]
df_mash_corr_dist_complete = df_mash_corr_dist.loc[complete_seqs, complete_seqs]

df_mash_corr_dist_complete.shape

In [None]:
# Initial sensitivity analysis (gives min val to consider)
tmp, df_temp, elbow_idx, elbow_threshold = sensitivity_analysis(df_mash_corr_dist_complete)
elbow_num_clusters = df_temp['num_clusters'][elbow_idx]

# Plot (tells us to pick something > 0.25)
plt.rcParams["figure.dpi"] = 200
fig, axs = plt.subplots(figsize=(4, 3))
axs.plot(tmp['threshold'], tmp['num_clusters'])
axs.axhline(y=elbow_num_clusters, c="#ff00ff", linestyle='--')
axs.set_ylabel('num_clusters')
# x values are the clustering thresholds (previously mislabeled 'index')
axs.set_xlabel('threshold')
fig.suptitle(
    f"Num clusters decelerates \nafter a value of {elbow_num_clusters} (threshold: {elbow_threshold})",
    y=1
)
plt.show()

In [None]:
px.line(data_frame=tmp, x='threshold', y='num_clusters')  # interactive view to read exact thresholds

## Plot initial clustermap of Mash values

In [None]:
# "Round" the elbow up by 0.1. NOTE: this reassigns elbow_threshold, so re-running this
# cell without re-running the sensitivity analysis adds another 0.1.
elbow_threshold = elbow_threshold + 0.1

link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, thresh=elbow_threshold)

# Color each cluster, cycling through tab20 so every cluster gets a color
# (cm.colors+cm.colors ran out after 40 clusters and left NaN colors).
cm = matplotlib.colormaps.get_cmap('tab20')
cluster_ids = sorted(clst.cluster.unique())
clr = {cid: cm.colors[i % len(cm.colors)] for i, cid in enumerate(cluster_ids)}
clst['color'] = clst.cluster.map(clr)

print('Number of colors: ', len(clr))
print('Number of clusters', len(clst.cluster.unique()))

In [None]:
size = 6
legend_TN = [patches.Patch(color=color, label=cluster) for cluster, color in clr.items()]

sns.set(rc={'figure.facecolor': 'white'})
# Raw Mash distances, ordered by the clustering linkage; center=.05 is the
# soft Mash species boundary used above.
clustermap_kwargs = dict(
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    col_colors=clst.color,
    yticklabels=False,
    xticklabels=False,
    cmap='BrBG_r',
    robust=True,
    center=.05,
)
g = sns.clustermap(df_mash_square_complete, **clustermap_kwargs)

l2 = g.ax_heatmap.legend(
    handles=legend_TN,
    loc='upper left',
    bbox_to_anchor=(1.01, 0.85),
    frameon=True,
)
l2.set_title(title='Clusters', prop={'size': 10})

## Filter out small clusters (typically with < 5 members)

In [None]:
px.histogram(clst.cluster.value_counts().to_frame(), nbins=100)  # cluster-size distribution

In [None]:
# Clusters with fewer than 5 member genomes (index: cluster label, value: size)
cluster_sizes = clst.cluster.value_counts()
bad_clusters = cluster_sizes[cluster_sizes < 5]
bad_clusters

In [None]:
# Genomes in one of the small clusters (`in bad_clusters` tests the Series index,
# i.e. the cluster labels).
bad_genomes_list = [
    genome for genome in df_mash_square_complete.index
    if clst.loc[genome, 'cluster'] in bad_clusters
]

# Update filtration. Boolean masks keep the current row order; rebuilding the tables
# from a set made the row order (and the saved CSVs) vary between Python sessions,
# because str hashing is randomized per process.
keep_meta = ~mash_scrubbed_metadata['genome_id'].isin(bad_genomes_list)
mash_scrubbed_metadata = mash_scrubbed_metadata[keep_meta].reset_index(drop=True)
keep_summary = ~mash_scrubbed_summary['genome_id'].isin(bad_genomes_list)
mash_scrubbed_summary = mash_scrubbed_summary[keep_summary].reset_index(drop=True)
good_genome_ids = list(mash_scrubbed_metadata['genome_id'])

df_mash_square_complete = remove_bad_strains(df_mash_square_complete, bad_genomes_list)
# Bug fix: was remove_bad_strains(df_mash_square_complete, ...)
df_mash_corr_complete = remove_bad_strains(df_mash_corr_complete, bad_genomes_list)
df_mash_corr_dist_complete = remove_bad_strains(df_mash_corr_dist_complete, bad_genomes_list)

## Keep filtering until robust clusters show up

In [None]:
# Repeatedly re-cluster and drop genomes that land in small (< 5 member) clusters,
# stopping once a clustering pass finds no small clusters.
# Stopping on "nothing removed" (instead of on an unchanged cluster count) guarantees
# that re-clustering the final matrix reproduces these clusters, so the later
# `min() >= 5` check holds. Every pass that continues removes at least one genome,
# so the loop always terminates.
iteration = 1
while True:
    # Cluster
    link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, thresh=elbow_threshold)

    # Color each cluster; cycle tab20 so clusters beyond the 20th still get a color
    # (zipping with cm.colors alone left them with NaN colors).
    cm = matplotlib.colormaps.get_cmap('tab20')
    cluster_ids = sorted(clst.cluster.unique())
    clr = {cid: cm.colors[i % len(cm.colors)] for i, cid in enumerate(cluster_ids)}
    clst['color'] = clst.cluster.map(clr)

    # Define bad clusters (index: cluster label, value: member count)
    cluster_sizes = clst.cluster.value_counts()
    bad_clusters = cluster_sizes[cluster_sizes < 5]

    # Genomes belonging to a bad cluster
    bad_genomes_list = [
        genome for genome in df_mash_square_complete.index
        if clst.loc[genome, 'cluster'] in bad_clusters
    ]
    print(f'iteration {iteration}...{len(cluster_ids)} clusters, '
          f'removing {len(bad_genomes_list)} genomes')
    if not bad_genomes_list:
        break

    # Update filtration; boolean masks keep the current row order (rebuilding from a
    # set made the saved CSV row order vary between Python sessions).
    keep_meta = ~mash_scrubbed_metadata['genome_id'].isin(bad_genomes_list)
    mash_scrubbed_metadata = mash_scrubbed_metadata[keep_meta].reset_index(drop=True)
    keep_summary = ~mash_scrubbed_summary['genome_id'].isin(bad_genomes_list)
    mash_scrubbed_summary = mash_scrubbed_summary[keep_summary].reset_index(drop=True)
    good_genome_ids = list(mash_scrubbed_metadata['genome_id'])

    df_mash_square_complete = remove_bad_strains(df_mash_square_complete, bad_genomes_list)
    # Bug fix: was remove_bad_strains(df_mash_square_complete, ...)
    df_mash_corr_complete = remove_bad_strains(df_mash_corr_complete, bad_genomes_list)
    df_mash_corr_dist_complete = remove_bad_strains(df_mash_corr_dist_complete, bad_genomes_list)

    # Increment
    iteration += 1

In [None]:
np.shape(df_mash_square_complete)  # Current shape after filtration

In [None]:
# Final clustering of the filtered complete genomes
link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, thresh=elbow_threshold)

# Color each cluster, cycling through tab20 so every cluster gets a color
# (cm.colors+cm.colors ran out after 40 clusters and left NaN colors).
cm = matplotlib.colormaps.get_cmap('tab20')
cluster_ids = sorted(clst.cluster.unique())
clr = {cid: cm.colors[i % len(cm.colors)] for i, cid in enumerate(cluster_ids)}
clst['color'] = clst.cluster.map(clr)

print('Number of colors: ', len(clr))
print('Number of clusters', len(clst.cluster.unique()))

In [None]:
# Every surviving cluster must have at least 5 member genomes; on failure, show which don't.
_final_cluster_sizes = clst.cluster.value_counts()
assert _final_cluster_sizes.min() >= 5, (
    f'Clusters with < 5 genomes remain:\n{_final_cluster_sizes[_final_cluster_sizes < 5]}'
)

In [None]:
px.histogram(clst['cluster'].value_counts(), nbins=50)  # sizes of the final clusters

# Plot filtered Mash clustermap

__From this it looks like our final rank for NMF decomposition will be 16 for Enterobacter__

In [None]:
size = 6
legend_TN = [patches.Patch(color=color, label=cluster) for cluster, color in clr.items()]

sns.set(rc={'figure.facecolor': 'white'})
# Filtered complete genomes; column colors mark the final clusters.
g = sns.clustermap(
    df_mash_square_complete,
    col_colors=clst.color,
    row_linkage=link,
    col_linkage=link,
    figsize=(size, size),
    cmap='BrBG_r',
    center=.05,
    robust=True,
    xticklabels=False,
    yticklabels=False,
)

l2 = g.ax_heatmap.legend(handles=legend_TN, loc='upper left', bbox_to_anchor=(1.05, 0.85), frameon=True)
l2.set_title(title='Clusters', prop={'size': 10})

## Labeled with species information

In [None]:
display(mash_scrubbed_metadata.shape, mash_scrubbed_metadata.head())  # metadata after filtering

In [None]:
# Species label = first two words of the genome name. Joining the first two tokens
# also handles one-word names, which previously raised IndexError.
complete_mask = mash_scrubbed_metadata['genome_status'] == 'Complete'
df_species = (
    mash_scrubbed_metadata.loc[complete_mask, ["genome_id", "genome_name"]]
    .copy()  # own the data so later column assignments don't hit a chained-copy warning
    .assign(species=lambda d: d["genome_name"].apply(lambda name: " ".join(name.split()[:2])))
    .set_index('genome_id')
)
df_species.head(20)

In [None]:
# Treat "uncultured Enterobacter" genomes as unnamed Enterobacter species.
df_species.loc[df_species['species'] == "uncultured Enterobacter", 'species'] = "Enterobacter sp."

In [None]:
# Species treated here as members of the Enterobacter cloacae complex.
# TODO: check which species are officially a part of the complex
cloacae = [
    "Enterobacter cloacae",
    "Enterobacter asburiae",
    "Enterobacter hormaechei",
    "Enterobacter kobei",
    "Enterobacter ludwigii",
    "Enterobacter nimipressuralis",
]
df_species["group"] = "Other"
df_species.loc[df_species["species"].isin(cloacae), "group"] = "Cloacae Complex"

In [None]:
# Fixed color per species; species missing from this map get NaN in `color`.
custom_colors = {
    'Enterobacter hormaechei': 'FireBrick',
    'Enterobacter cloacae': 'Pink',
    'Enterobacter sp.': 'SlateGray',
    'Enterobacter roggenkampii': 'Green',
    'Enterobacter kobei': 'Purple',
    'Enterobacter cancerogenus': 'Magenta',
    'Enterobacter bugandensis': 'Cyan',
    'Enterobacter asburiae': 'Blue',
    'Enterobacter ludwigii': 'Lime',
    'Enterobacter mori': 'Seashell',
    'Enterobacter xiangfangensis': 'Red',
}

df_species['color'] = df_species['species'].map(custom_colors)

In [None]:
size = 6

# One legend entry per species color.
legend_TN = [patches.Patch(color=color, label=species) for species, color in custom_colors.items()]

sns.set(rc={'figure.facecolor': 'white'})
g = sns.clustermap(
    df_mash_square_complete,
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    col_colors=df_species.color,
    yticklabels=False,
    xticklabels=False,
    cmap='BrBG_r',
    robust=True,
    center=.05
)

l2 = g.ax_heatmap.legend(loc='upper left', bbox_to_anchor=(1.05, 0.85), handles=legend_TN, frameon=True)
# The legend lists species, not clusters (title was previously 'Clusters').
l2.set_title(title='Species', prop={'size': 10})

plt.savefig("../images/mash_clustermap.png", format='png', bbox_inches='tight')

In [None]:
# Same species-annotated clustermap, saved as a 400-dpi JPG without a legend.
# The previous version built an unused cluster legend and called l2.set_title() on the
# legend of an earlier figure: it never affected this figure and only worked if that
# earlier cell had already run (hidden kernel state).
size = 6

sns.set(rc={'figure.facecolor': 'white'})
g = sns.clustermap(
    df_mash_square_complete,
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    col_colors=df_species.color,
    yticklabels=False,
    xticklabels=False,
    cmap='BrBG_r',
    center=.05,
    robust=True
)

plt.savefig('../images/mash_clustermap.jpg', dpi=400)

In [None]:
## NOTE:
# consider how to resolve the sp. distinctions, as well as uncultured Enterobacter

# KNN Classification

In [None]:
# Metadata with each complete genome's Mash cluster (as str); other genomes get NaN.
x = (
    mash_scrubbed_metadata
    .set_index('genome_id')
    .assign(complete_mash_cluster=clst.cluster.astype(str))
    .reset_index()
)

In [None]:
# Square Mash-distance matrix with BOTH axes in metadata-row order: the KNN step indexes
# it positionally with np.ix_(rows, cols), so selecting rows only (as before) left the
# columns in a different genome order.
df = df_mash_square.loc[mash_scrubbed_metadata.genome_id, mash_scrubbed_metadata.genome_id]

In [None]:
import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier

# Propagate complete-genome Mash clusters to the remaining genomes with a
# 5-nearest-neighbour vote over precomputed Mash distances.
labels = x.complete_mash_cluster.dropna().astype(str)  # index = row positions in x

# Positional distance matrix (without clobbering `df`, so the cell can be re-run).
# If `df` is still a DataFrame, reorder its columns to match its rows, because np.ix_
# below uses the same positions on both axes.
if isinstance(df, pd.DataFrame):
    dist_matrix = df.loc[:, df.index].to_numpy()
else:
    dist_matrix = np.asarray(df)

n_points = dist_matrix.shape[0]
labeled_indices = list(labels.index)
labeled_set = set(labeled_indices)
unlabeled_indices = [i for i in range(n_points) if i not in labeled_set]

labels_array = np.full(n_points, fill_value=None, dtype=object)
labels_array[labeled_indices] = labels.to_numpy()

# Use only labeled data for training
X_train = dist_matrix[np.ix_(labeled_indices, labeled_indices)]
y_train = labels_array[labeled_indices]

knn = KNeighborsClassifier(n_neighbors=5, metric='precomputed')
knn.fit(X_train, y_train)

# Nothing to predict when every genome already has a label (predict rejects empty input).
if unlabeled_indices:
    X_unlabeled = dist_matrix[np.ix_(unlabeled_indices, labeled_indices)]
    labels_array[unlabeled_indices] = knn.predict(X_unlabeled)

# labels_array now holds the original labels plus the KNN-predicted ones


In [None]:
x = x.assign(KNN=labels_array)  # KNN label for every genome, in x row order

In [None]:
# How often the KNN label agrees with the direct Mash cluster (NaN clusters count as False).
knn_vs_mash = x.set_index('genome_id')[['KNN', 'complete_mash_cluster']]
(knn_vs_mash['KNN'] == knn_vs_mash['complete_mash_cluster']).value_counts()

In [None]:
sns.barplot(x['KNN'].value_counts())  # genomes per KNN-assigned cluster

# Save Mash-scrubbed `summary` and `metadata`

In [None]:
# Attach the final complete-genome cluster IDs (int) to the metadata; unclustered genomes get NaN.
x = mash_scrubbed_metadata.set_index('genome_id').assign(
    complete_mash_cluster=clst['cluster'].astype('int')
)
mash_scrubbed_metadata = x.reset_index()
mash_scrubbed_metadata.head()

In [None]:
# Same directory as the input summary, with a "mash_" prefixed file name.
summary_dir, _, _ = SCRUBBED_SUMMARY.partition('scrubbed_species_summary.csv')
filepath = os.path.join(summary_dir, 'mash_scrubbed_species_summary.csv')
filepath

In [None]:
mash_scrubbed_summary.to_csv(path_or_buf=filepath)  # write the Mash-filtered summary

In [None]:
# Same directory as the input metadata, with a "mash_" prefixed file name.
metadata_dir, _, _ = SCRUBBED_METADATA.partition('scrubbed_species_metadata.csv')
filepath = os.path.join(metadata_dir, 'mash_scrubbed_species_metadata.csv')
filepath

In [None]:
mash_scrubbed_metadata.to_csv(path_or_buf=filepath)  # write the Mash-filtered metadata

# Save Mash results

In [None]:
# Reuse the metadata directory for the Mash distance matrix.
output_dir, _, _ = filepath.partition('mash_scrubbed_species_metadata.csv')
filepath = os.path.join(output_dir, 'df_mash_square.csv')
filepath

In [None]:
df_mash_square.to_csv(path_or_buf=filepath)  # filtered square Mash-distance matrix

In [None]:
# Same directory, file for the correlation-distance matrix.
output_dir, _, _ = filepath.partition('df_mash_square.csv')
filepath = os.path.join(output_dir, 'df_mash_corr_dist.csv')
filepath

In [None]:
df_mash_corr_dist.to_csv(path_or_buf=filepath)  # filtered correlation-distance matrix

# Characterize Each Cluster

In [None]:
mash_scrubbed_metadata['complete_mash_cluster'].value_counts()  # genomes per final cluster

In [None]:
display(df_mash_square)  # filtered Mash-distance matrix

In [None]:
# BV-BRC Strains with same strain as from https://f1000research.com/articles/7-521
# Keys are display labels (used in legends); species names spelled as in the
# custom_colors map (cancerogenus, roggenkampii) plus subsp. steigerwaltii.
type_strains = {
    'asburiae': '1646339.4',
    'bugandensis': '881260.3',
    'cancerogenus': '69218.16',
    'cloacae clade K': '550.420',
    'asburiae clade L': '61645.63',
    'cloacae clade N': '550.1227',
    'cloacae clade s': '550.979',
    'cancerogenus clade t': '69218.15',
    'cloacae cloacae': '716541.4',
    'cloacae dissolvens': '1104326.3',
    'hormaechei hoffmannii': '1812934.3',
    'hormaechei hormaechei': '888063.8',
    'hormaechei oharae': '301102.37',
    'hormaechei steigerwaltii': '299766.39',
    'hormaechei xiangfangensis': '1296536.16',
    'kobei': '208224.12',
    'ludwigii': '299767.18',
    'mori': '539813.36', # from BV-BRC
    'roggenkampii': '1812935.7',
}

# Reverse lookup genome ID -> label; IDs must be unique or entries would be lost.
id_to_type = {strain: key for key, strain in type_strains.items()}
assert len(id_to_type) == len(type_strains), 'duplicate genome IDs in type_strains'

In [None]:
# For each final Mash cluster, pick the type strain with the smallest mean Mash
# distance to the cluster's genomes and record it as "<type label> - <type strain id>".
# Type strains removed by the Mash filtering can't be compared; report and skip them
# instead of failing with a KeyError.
type_strain_ids = [s for s in type_strains.values() if s in df_mash_square.index]
missing_type_strains = sorted(set(type_strains.values()) - set(type_strain_ids))
if missing_type_strains:
    print('Type strains not in the filtered Mash matrix:', missing_type_strains)

# Complete genomes that survived filtering (the columns of the complete-genome matrix).
clustered = mash_scrubbed_metadata[mash_scrubbed_metadata.genome_id.isin(df_mash_corr_complete.columns)]

cluster_to_type = {}
for cluster in sorted(clustered.complete_mash_cluster.astype(float).astype(int).unique()):
    strains = list(mash_scrubbed_metadata[mash_scrubbed_metadata.complete_mash_cluster == cluster].genome_id.values)
    mean_dist = df_mash_square.loc[type_strain_ids, strains].mean(axis=1)
    best_match = id_to_type[mean_dist.idxmin()]
    cluster_to_type[cluster] = best_match + ' - ' + type_strains[best_match]
    print(int(cluster), best_match, type_strains[best_match])

In [None]:
size = 6

# Legend entries read "<cluster> - <best-matching type label> - <type strain id>".
legend_TN = [
    patches.Patch(color=color, label=f'{cluster} - {cluster_to_type[cluster]}')
    for cluster, color in clr.items()
]

sns.set(rc={'figure.facecolor': 'white'})
clustermap_kwargs = dict(
    figsize=(size, size),
    row_linkage=link,
    col_linkage=link,
    col_colors=clst.color,
    yticklabels=False,
    xticklabels=False,
    cmap='BrBG_r',
    robust=True,
    center=.05,
)
g = sns.clustermap(df_mash_square_complete, **clustermap_kwargs)

l2 = g.ax_heatmap.legend(
    handles=legend_TN,
    loc='upper left',
    bbox_to_anchor=(1.05, 1),
    frameon=True,
)
l2.set_title(title='Clusters', prop={'size': 10})
plt.savefig("../images/supplemental/mash_clustermap_mash_labels.png", format='png', bbox_inches='tight')