In [None]:
import logging
import re
import urllib
from io import StringIO
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import gzip
import pickle
from tqdm.notebook import tqdm, trange
import multiprocessing
from IPython.display import display, HTML
import itertools

import plotly.graph_objects as go
from Bio import SeqIO
import os

import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import os
import pickle

import plotly.express as px
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as hc
import scipy.spatial as sp
from kneebow.rotor import Rotor

In [None]:
# Binarized NMF output matrices; the first CSV column becomes the index.
L_BINARIZED = '../../data/processed/nmf-outputs/L_binarized.csv'
A_BINARIZED = '../../data/processed/nmf-outputs/A_binarized.csv'

L_bin, A_bin = (pd.read_csv(path, index_col=0) for path in (L_BINARIZED, A_BINARIZED))

In [None]:
# Root of the project data directory, relative to this notebook.
PATH_TO_DATA = '../../data/'
# os.path.join avoids the doubled '/' that f'{PATH_TO_DATA}/...' produces with the trailing slash above.
ENRICHED_METADATA = os.path.join(PATH_TO_DATA, 'metadata', 'enriched_metadata.csv')

In [None]:
# Enriched genome metadata with every column read as a string (dtype='object').
metadata = pd.read_csv(ENRICHED_METADATA, index_col=0, dtype='object')
# Keep only genomes whose status is 'Complete'; these drive the contig extraction below.
is_complete = metadata['genome_status'] == 'Complete'
complete_metadata = metadata.loc[is_complete]

In [None]:
# For each complete genome, treat the longest contig as the chromosome and write every
# other contig to its own FASTA, named '<genome>_<contig id>.fna', for Mash comparison.
# NOTE(review): assumes one chromosome per genome — a second chromosome would be exported
# as a "plasmid" here; confirm for the organisms in this dataset.
BAKTA_DIR = os.path.join(PATH_TO_DATA, 'processed', 'bakta')
PLASMID_FNA_DIR = os.path.join(PATH_TO_DATA, 'processed', 'plasmid_fna_files')
os.makedirs(PLASMID_FNA_DIR, exist_ok=True)

for genome in complete_metadata.genome_id:
    input_file = os.path.join(BAKTA_DIR, genome, f'{genome}.fna')
    contigs = list(SeqIO.parse(input_file, "fasta"))
    if not contigs:
        # max() below raises on an empty sequence; report and move on instead.
        logging.warning('No contigs found in %s; skipping', input_file)
        continue

    largest_contig = max(contigs, key=lambda record: len(record.seq))
    for contig in contigs:
        if contig.id == largest_contig.id:
            continue
        outfile = os.path.join(PLASMID_FNA_DIR, f'{genome}_{contig.id}.fna')
        SeqIO.write(contig, outfile, "fasta")


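## Run Mash

Sketch every plasmid FASTA written above and compute all-vs-all Mash distances. Run these commands from this notebook's directory so the relative paths resolve:

mash sketch -o combined_sketch ../../data/processed/plasmid_fna_files/*.fna

mash dist combined_sketch.msh combined_sketch.msh > mash_distances.txt

Then move `mash_distances.txt` to `../../data/processed/plasmid_data/`, which is where the next cell reads it from.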

## Calculations

In [None]:
# Column names for the tab-separated `mash dist` output table.
names = ['genome1', 'genome2', 'mash_distance', 'p_value', 'matching_hashes']


def _fasta_stem(path):
    """Strip the directory part of `path` and everything from the first '.fna' onward."""
    return path.split('/')[-1].split('.fna')[0]


df_mash = pd.read_csv('../../data/processed/plasmid_data/mash_distances.txt', sep='\t', names=names)
for column in ('genome1', 'genome2'):
    df_mash[column] = df_mash[column].map(_fasta_stem)

df_mash

In [None]:
# Reshape the long pairwise table into a square genome1 x genome2 matrix of Mash distances.
df_mash_square = df_mash.pivot(
    index='genome1',
    columns='genome2',
    values='mash_distance',
)

square_shape = df_mash_square.shape
square_preview = df_mash_square.head()
display(square_shape, square_preview)

In [None]:
# Pearson correlation between the columns (Mash-distance profiles) of the square matrix,
# converted to a distance as 1 - r. Reportedly slow (hours) on large plasmid sets.
# NOTE: nothing here is written to disk — the matrices exist only in the kernel.
# NOTE(review): the clustering cells below are fed `df_mash_square` (raw Mash distance),
# not `df_mash_corr_dist`; confirm which matrix is intended.
df_mash_corr = df_mash_square.corr()
df_mash_corr_dist = 1 - df_mash_corr

display(
    df_mash_corr_dist.shape,
    df_mash_corr_dist.head()
)

In [None]:
# Distribution of all pairwise Mash distances; the self-comparisons on the diagonal
# (all-vs-all `mash dist`) are included.
fig, ax = plt.subplots()
sns.histplot(df_mash_square.values.flatten(), ax=ax)
ax.set_xlabel('Mash distance')
ax.set_title('Pairwise plasmid Mash distances')
plt.show()

In [None]:
def cluster_corr_dist(df_mash_corr_dist, maxclust=4, method='ward', metric='euclidean'):
    '''
    Hierarchically Mash-based pairwise-pearson-distance matrix
    '''
    link = hc.linkage(sp.distance.squareform(df_mash_corr_dist), method=method, metric=metric)
    dist = sp.distance.squareform(df_mash_corr_dist)
    
    clst = pd.DataFrame(index=df_mash_corr_dist.index)
    clst['cluster'] = hc.fcluster(link, maxclust, 'maxclust')
    
    return link, dist, clst


def remove_bad_strains(df_mash_scd, bad_strains_list):
    '''
    Drop the given strains from both axes of a square strain x strain matrix.

    Rows and columns of the result are the remaining labels in sorted order; labels in
    ``bad_strains_list`` that are not present in the index are ignored.
    '''
    keep = sorted(set(df_mash_scd.index).difference(bad_strains_list))
    return df_mash_scd.loc[keep, keep]


# Sensitivity analysis to pick the threshold (for E. coli we use 0.1)
# We pick the threshold where the curve just starts to bottom out
def sensitivity_analysis(df_mash_corr_dist_complete):
    '''
    Sweep distance cut-offs and locate the elbow of the resulting cluster-count curve.

    For each threshold the Ward tree is cut at ``thresh * max(pairwise distance)`` and the
    flat clusters are counted; ``kneebow.rotor.Rotor`` then finds the elbow of the counts
    sorted in ascending order.

    Parameters
    ----------
    df_mash_corr_dist_complete : pd.DataFrame
        Square, symmetric distance matrix with a zero diagonal.

    Returns
    -------
    tmp : pd.DataFrame
        Columns 'threshold' and 'num_clusters', one row per threshold in sweep order.
    df_temp : pd.DataFrame
        ``tmp`` sorted by 'num_clusters' (ascending) with a fresh 0..n-1 index.
    elbow_idx
        Row of ``df_temp`` at the elbow, as returned by ``Rotor.get_elbow_index``.
    elbow_threshold : float
        First threshold in ``tmp`` whose cluster count equals the count at the elbow.
    '''
    # 10 log-spaced values in [0.001, 0.1] plus 19 linear values in [0.1, 1];
    # 0.1 therefore appears twice.
    x = list(np.logspace(-3, -1, 10)) + list(np.linspace(0.1, 1, 19))

    # The linkage does not depend on the threshold, so build it once instead of once per
    # threshold. SciPy ignores `metric` for condensed input.
    dist = sp.distance.squareform(df_mash_corr_dist_complete)
    link = hc.linkage(dist, method='ward', metric='euclidean')
    max_dist = dist.max()

    def num_uniq_clusters(thresh):
        '''Number of flat clusters when the tree is cut at ``thresh * max_dist``.'''
        labels = hc.fcluster(link, thresh * max_dist, 'distance')
        return len(np.unique(labels))

    tmp = pd.DataFrame({'threshold': x})
    tmp['num_clusters'] = tmp['threshold'].apply(num_uniq_clusters)

    # Curve handed to the elbow finder: (rank, cluster count) with counts ascending.
    df_temp = tmp.sort_values(by='num_clusters', ascending=True).reset_index(drop=True)
    data = list(zip(df_temp.index, df_temp['num_clusters']))

    rotor = Rotor()
    rotor.fit_rotate(data)
    elbow_idx = rotor.get_elbow_index()

    # Map the elbow's cluster count back to the first threshold that produced it.
    elbow_num_clusters = df_temp['num_clusters'].loc[elbow_idx]
    cond = tmp['num_clusters'] == elbow_num_clusters
    elbow_threshold = tmp.loc[cond, 'threshold'].iloc[0]

    return tmp, df_temp, elbow_idx, elbow_threshold



In [None]:
# NOTE(review): despite their names, all three "complete" matrices are the raw square Mash
# distance matrix (`df_mash_square`), not `df_mash_corr` / `df_mash_corr_dist` from above.
# Confirm this is intended before interpreting the clusters as correlation-based.
df_mash_square_complete = df_mash_corr_complete = df_mash_corr_dist_complete = df_mash_square

df_mash_corr_dist_complete.shape

In [None]:
# Initial sensitivity analysis (gives min val to consider)
tmp, df_temp, elbow_idx, elbow_threshold = sensitivity_analysis(df_mash_corr_dist_complete)

# Plot (tells us to pick something > 0.25)
plt.rcParams["figure.dpi"] = 200
fig, axs = plt.subplots(figsize=(4,3),)
axs.plot(tmp['threshold'], tmp['num_clusters'])
plt.axhline(y=df_temp['num_clusters'][elbow_idx], c="#ff00ff", linestyle='--')
axs.set_ylabel('num_clusters')
axs.set_xlabel('index')
fig.suptitle(
    f"Num clusters decelerates \nafter a value of {df_temp['num_clusters'][elbow_idx]} (threshold: {elbow_threshold})",
    y=1
)
plt.show()

In [None]:
# Interactive view of the threshold sweep (cluster count vs. threshold).
px.line(data_frame=tmp, x='threshold', y='num_clusters')

## Plot initial clustermap of Mash values

In [None]:
# Elbow threshold nudged up by 0.1 ("rounded" up). Stored under its own name so re-running
# this cell does not keep adding 0.1 to `elbow_threshold`.
# NOTE(review): not consumed below — `cluster_corr_dist` cuts by cluster count (maxclust=4),
# not by a distance threshold. Confirm which cut is intended.
rounded_elbow_threshold = elbow_threshold + 0.1

link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete)

# One tab20 colour per cluster. Cycling the palette guarantees every cluster gets a colour
# (colours repeat every 20) instead of zip silently leaving extra clusters unmapped (NaN).
cm = matplotlib.colormaps.get_cmap('tab20')
clr = dict(zip(sorted(clst.cluster.unique()), itertools.cycle(cm.colors)))
clst['color'] = clst.cluster.map(clr)

print('Number of colors: ', len(clr))
print('Number of clusters', len(clst.cluster.unique()))

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform

import matplotlib.pyplot as plt


# Dendrogram of the Ward linkage `link` returned by `cluster_corr_dist` in the previous cell.
fig, ax = plt.subplots()
dendrogram(link, ax=ax)
ax.set_title('Ward linkage of plasmid Mash distances')
ax.set_ylabel('Linkage distance')
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, leaves_list
import seaborn as sns

# Square distance matrix as a NumPy array; rows/columns follow df_mash_corr_dist_complete's index.
distance_matrix = df_mash_corr_dist_complete.values

# Ward clustering cut into at most 5 flat clusters. cluster_corr_dist converts the square
# matrix to condensed form itself, so no separate squareform step is needed here.
link, dist, clst = cluster_corr_dist(df_mash_corr_dist_complete, maxclust=5)

# Dendrogram leaf order, as positions into distance_matrix.
ordered_leaves = leaves_list(link)

# Reorder rows and columns together so members of each flat cluster sit next to each other.
ordered_distance_matrix = distance_matrix[np.ix_(ordered_leaves, ordered_leaves)]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches

# Heatmap of the leaf-ordered distance matrix with a box around each flat cluster.
# Uses ordered_distance_matrix, clst and ordered_leaves from the previous cell.
PLASMID_MASH_FIG = '../images/supplemental/plasmid_mash.jpg'

fig, ax = plt.subplots(figsize=(12, 12))
cax = ax.matshow(ordered_distance_matrix, cmap='Greens_r')

# Colorbar sized to roughly match the matrix height
cbar = fig.colorbar(cax, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
cbar.set_label('Distance')

# Five evenly spaced colorbar ticks spanning the data range
cbar_ticks = np.linspace(np.min(ordered_distance_matrix), np.max(ordered_distance_matrix), num=5)
cbar.set_ticks(cbar_ticks)
cbar.ax.set_yticklabels(['{:.2f}'.format(tick) for tick in cbar_ticks])

# Cluster label of each row/column in dendrogram order. `.iloc` because ordered_leaves are
# positions (clst is indexed by name). It is a NumPy array so `==` below compares
# element-wise; comparing a Python list to an int yields a single False.
cluster_labels = clst['cluster'].iloc[ordered_leaves].astype(int).to_numpy()

# Outline each cluster. matshow centres cell i on coordinate i, so cell edges sit at i - 0.5.
for cluster in np.unique(cluster_labels):
    indices = np.flatnonzero(cluster_labels == cluster)
    min_idx, max_idx = indices.min(), indices.max()
    size = max_idx - min_idx + 1
    rect = patches.Rectangle((min_idx - 0.5, min_idx - 0.5), size, size,
                             linewidth=5, edgecolor='red', facecolor='none', label=f'Cluster {cluster}')
    ax.add_patch(rect)

# Optionally, add a legend
# handles, labels = ax.get_legend_handles_labels()
# ax.legend(handles, labels, loc='upper right', bbox_to_anchor=(.95, 1), ncols=2)

# Set the tick parameters to make the labels more readable
ax.tick_params(axis='x', which='major', labelsize=10)
ax.tick_params(axis='y', which='major', labelsize=10)

# Make sure the output directory exists before saving.
os.makedirs(os.path.dirname(PLASMID_MASH_FIG), exist_ok=True)
fig.savefig(PLASMID_MASH_FIG, dpi=400)
plt.show()


In [None]:
# Map each flat cluster to the distinct genome IDs among its members. Row labels are
# '<genome>_<contig id>' (the plasmid .fna naming above), so the genome is the text before
# the first '_'. NOTE(review): breaks if a genome ID itself contains '_'.
cluster_strains = {
    cluster_id: list({label.split('_')[0] for label in clst.index[clst['cluster'] == cluster_id]})
    for cluster_id in clst['cluster'].unique()
}

In [None]:
# For each plasmid cluster (row) and phylon (column): the fraction of genomes carrying the
# phylon (row sum of the binarized A matrix) that belong to the cluster.
# NOTE(review): `.loc` raises KeyError if a cluster genome is missing from A_bin's columns.
phylon_totals = A_bin.sum(axis=1)
information = pd.DataFrame(
    [A_bin.loc[:, strains].sum(axis=1) / phylon_totals for strains in cluster_strains.values()],
    index=list(cluster_strains.keys()),
    columns=A_bin.index,
)

unchar_phylons = [x for x in A_bin.index if 'unchar' in x]
fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(information.loc[:, unchar_phylons], ax=ax, cmap='coolwarm', annot=True)
# Values are fractions in [0, 1], not percentages.
ax.set_title('Fraction of Strains for each Plasmid Cluster Associated with Unchar Phylons')
plt.show()