# MASH
Summary of [MASH](https://github.com/marbl/Mash) results from project: `[{{ project().name }}]`

## Description
Mash performs fast genome and metagenome distance estimation using MinHash sketches.

In [None]:
import pandas as pd
from pathlib import Path

import warnings
warnings.filterwarnings('ignore')

#import os
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as shc
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import networkx as nx
import plotly.graph_objects as go
import yaml

# Apply seaborn's "paper" plotting context to every figure drawn in this notebook.
sns.set_context("paper")


def kMeansRes(scaled_data, k, alpha_k=0.02):
    """Return the penalised, normalised K-means inertia for a single ``k``.

    Approach adapted from
    https://medium.com/towards-data-science/an-approach-for-choosing-number-of-clusters-for-k-means-c28e614ecb2c

    The within-cluster inertia of a K-means fit is divided by the total
    squared deviation of the data from its column means (the inertia of a
    single cluster). A linear penalty ``alpha_k * k`` is then added, so a
    larger ``k`` has to earn its extra clusters.

    Parameters
    ----------
    scaled_data : array-like
        Scaled data; rows are samples and columns are clustering features.
    k : int
        Number of clusters to fit.
    alpha_k : float, default 0.02
        Manually tuned penalty applied per cluster.

    Returns
    -------
    float
        Scaled inertia for ``k``; lower is better.
    """
    deviations = scaled_data - scaled_data.mean(axis=0)
    baseline_inertia = np.square(deviations).sum()

    # Fixed random_state so repeated runs give the same fit.
    # NOTE(review): n_init is left at the installed scikit-learn's default;
    # confirm results are stable across the versions in use.
    model = KMeans(n_clusters=k, random_state=0)
    model.fit(scaled_data)

    penalty = alpha_k * k
    return model.inertia_ / baseline_inertia + penalty

def chooseBestKforKMeans(scaled_data, k_range, alpha_k=0.02):
    """Evaluate every k in ``k_range`` and return the one with the lowest scaled inertia.

    Parameters
    ----------
    scaled_data : array-like
        Scaled data, passed unchanged to ``kMeansRes``.
    k_range : iterable of int
        Candidate cluster counts to evaluate.
    alpha_k : float, default 0.02
        Per-cluster penalty forwarded to ``kMeansRes``. The default equals
        that function's own default.

    Returns
    -------
    best_k : int
        Candidate with the minimum scaled inertia.
    results : pandas.DataFrame
        One row per candidate, indexed by ``k``, with a single
        ``'Scaled Inertia'`` column.

    Raises
    ------
    ValueError
        If ``k_range`` yields no candidates (e.g. too few samples).
    """
    ans = [(k, kMeansRes(scaled_data, k, alpha_k=alpha_k)) for k in k_range]
    if not ans:
        raise ValueError("k_range is empty; need at least one candidate number of clusters")
    results = pd.DataFrame(ans, columns=['k', 'Scaled Inertia']).set_index('k')
    # Select the column explicitly. The former ``results.idxmin()[0]`` relied on
    # positional Series indexing with an integer key, which pandas deprecates.
    best_k = results['Scaled Inertia'].idxmin()
    return best_k, results

def create_edge_trace(Graph, name, showlegend=False, color='#888', width=0.5, opacity=0.8,
                      legendgroup="edges", legendgrouptitle_text="edges"):
    """Build a plotly line trace for the edges of ``Graph`` with a given relation type.

    Only edges whose ``"relation_type"`` attribute equals ``name`` are drawn.
    Each selected edge contributes its two endpoint coordinates, taken from
    the nodes' ``"pos"`` (x, y) attribute, followed by a ``None`` gap. The gap
    makes plotly draw every edge as a separate segment.

    Parameters
    ----------
    Graph : networkx.Graph
        Graph whose nodes carry a ``"pos"`` attribute and whose edges carry a
        ``"relation_type"`` attribute.
    name : str
        Relation type to select. It is also used as the trace name.
    showlegend, color, width, opacity, legendgroup, legendgrouptitle_text
        Forwarded to the ``go.Scatter`` trace and its line style.

    Returns
    -------
    plotly.graph_objects.Scatter
    """
    xs, ys = [], []
    # Everything is read from the ``Graph`` argument. Looking up edge
    # attributes on a global ``G`` would break for any other graph.
    for u, v, relation in Graph.edges(data="relation_type"):
        if relation != name:
            continue
        x0, y0 = Graph.nodes[u]['pos']
        x1, y1 = Graph.nodes[v]['pos']
        xs.extend([x0, x1, None])
        ys.extend([y0, y1, None])

    return go.Scatter(
        x=xs,
        y=ys,
        name=name,
        opacity=opacity,
        line=dict(width=width, color=color),
        hoverinfo='none',
        mode='lines',
        showlegend=showlegend,
        legendgroup=legendgroup,
        legendgrouptitle_text=legendgrouptitle_text,
    )

def create_node_trace(G, node_trace_category, color, showtextlabel=False, nodesize=10, nodeopacity=0.8, 
                      nodesymbol="circle", linewidth=1, linecolor="black", textposition="top center", showlegend=False,
                     legendgroup="nodes", legendgrouptitle_text="nodes"):
    """Build a plotly marker trace for the nodes in one ``node_trace`` category.

    Nodes of ``G`` whose ``"node_trace"`` attribute equals
    ``node_trace_category`` are plotted at their ``"pos"`` (x, y) attribute.
    Their ``"text"`` attribute is used as the hover label, and optionally as
    a visible label.

    Parameters
    ----------
    G : networkx.Graph
        Graph with ``"node_trace"``, ``"pos"`` and ``"text"`` node attributes.
    node_trace_category : str
        Category to select. It is also used as the trace (legend) name.
    color : str
        Marker fill colour.
    showtextlabel : bool
        If True, draw each node's text label next to its marker.
    nodesize, nodeopacity, nodesymbol, linewidth, linecolor
        Marker styling.
    textposition, showlegend, legendgroup, legendgrouptitle_text
        Forwarded to ``go.Scatter``.

    Returns
    -------
    plotly.graph_objects.Scatter
    """
    selected = np.array([n for n in G.nodes() if G.nodes[n]["node_trace"] == node_trace_category])
    coords = np.array([G.nodes[n]['pos'] for n in selected.flatten()]).reshape(-1, 2)
    labels = np.array([G.nodes[n]['text'] for n in selected])

    marker_style = dict(
        symbol=nodesymbol,
        opacity=nodeopacity,
        showscale=False,
        color=color,
        size=nodesize,
        line=dict(width=linewidth, color=linecolor),
    )
    return go.Scatter(
        x=coords[:, 0].tolist(),
        y=coords[:, 1].tolist(),
        text=labels.tolist(),
        textposition=textposition,
        mode="markers+text" if showtextlabel else "markers",
        hoverinfo='text',
        name=node_trace_category,
        showlegend=showlegend,
        legendgroup=legendgroup,
        legendgrouptitle_text=legendgrouptitle_text,
        marker=marker_style,
    )

## File Configurations

In [None]:
# Load notebook settings from config.yaml. Later cells read project_name,
# report_dir / bgcflow_dir and antismash_version from it.
# safe_load is used so the YAML cannot construct arbitrary Python objects.
with open("config.yaml", "r") as f:
    notebook_configuration = yaml.safe_load(f)
# Display the parsed configuration as the cell output.
notebook_configuration

In [None]:
# Resolve project settings from the loaded configuration.
# Raise explicitly rather than assert, so the check is not stripped under -O.
if "project_name" not in notebook_configuration:
    raise KeyError("Please specify a project name in the config.yaml file")
# Bind project_name; the bgcflow_dir branch below needs it to build the path.
project_name = notebook_configuration["project_name"]

# An explicit report_dir wins; otherwise derive it from the BGCFlow root.
if "report_dir" in notebook_configuration:
    report_dir = Path(notebook_configuration["report_dir"])
elif "bgcflow_dir" in notebook_configuration:
    report_dir = Path(notebook_configuration["bgcflow_dir"]) / f"data/processed/{project_name}"
else:
    # Fail here instead of with a NameError on report_dir in a later cell.
    raise KeyError("Please specify either report_dir or bgcflow_dir in the config.yaml file")

antismash_version = notebook_configuration.get("antismash_version", "7.0.0")

# Figure identifiers used to build output file names in later cells.
FIGURE = "Figure_3"
FIGURE2 = "Figure_S7"
FIGURE2 = "Figure_S7"

In [None]:
# Inputs read from report_dir:
# - df_mash: MASH distances indexed by genome. Later cells use it as a square
#   adjacency matrix, which assumes rows and columns list the same genomes
#   in the same order (TODO confirm).
# - df_gtdb: GTDB metadata indexed by genome_id; its Organism column supplies node labels.
df_mash = pd.read_csv(report_dir / 'mash/df_mash.csv', index_col=0)
df_gtdb = pd.read_csv(report_dir / 'tables' / 'df_gtdb_meta.csv', index_col='genome_id')

## Hierarchical Clustering Based on MASH Distances

In [None]:
# Column-wise correlation (pandas default: Pearson) between the genomes'
# MASH distance profiles. Genomes with similar distances to all others
# correlate strongly.
df_mash_corr = df_mash.corr()

plt.figure(figsize=(30, 7))
plt.title("MASH Distances")

# Ward linkage over the rows of the correlation matrix, with leaves
# optimally ordered. `clusters` is reused later as the row/column linkage of
# the MASH clustermap.
selected_data = df_mash_corr.copy()
clusters = shc.linkage(selected_data, 
            method='ward', 
            metric="euclidean",
            optimal_ordering=True,)
shc.dendrogram(Z=clusters, labels=df_mash_corr.index, leaf_rotation=90)
plt.show()

## Estimate Number of Clusters

In [None]:
# choose features
data_for_clustering = df_mash.copy()
data_for_clustering.fillna(0,inplace=True)

# create data matrix
data_matrix = np.matrix(data_for_clustering).astype(float)
data_matrix

# scale the data
mms = MinMaxScaler()
scaled_data = mms.fit_transform(np.asarray(data_matrix))

# choose k range
if len(df_mash) <= 21:
    max_range = len(df_mash) - 1
else:
    max_range = 20

k_range=range(2, max_range)
# compute adjusted intertia
best_k, results = chooseBestKforKMeans(scaled_data, k_range)

# plot the results
plt.figure(figsize=(7,4))
plt.plot(results,'o')
#plt.title('Adjusted Inertia for each K')
plt.xlabel('K-means clusters')
plt.ylabel('Adjusted Inertia')
plt.xticks(range(2,max_range,1))
print(f"Estimated number of clusters: {best_k}")

image_format = 'svg'
image_name = Path(f'assets/figures/{FIGURE2}/{FIGURE2}_a.svg')
image_name.parent.mkdir(parents=True, exist_ok=True)

plt.savefig(image_name, format=image_format, dpi=1200)

## MASH Clustermap

In [None]:
n_clusters = best_k

# Colour at most 12 clusters individually.
top_clusters = min(best_k, 12)

# create output folder
fig_folder = Path(f"assets/figures/{FIGURE}")
fig_folder.mkdir(parents=True, exist_ok=True)

# Ward linkage requires Euclidean distances, which is the default metric in
# every scikit-learn release. The `affinity=` keyword was removed in
# scikit-learn 1.4 and raises TypeError there, so the metric is not passed.
Agg_hc = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
y_hc = Agg_hc.fit_predict(df_mash_corr)
df_hclusts = pd.DataFrame(index=df_mash_corr.index, columns=['hcluster', 'color_code'])
df_hclusts['hcluster'] = y_hc

In [None]:
# Cluster ids ordered by size (value_counts sorts descending), limited to top_clusters.
top_clusters_new = df_hclusts.hcluster.value_counts().index.tolist()[:top_clusters]

color_set = ['#264653','#2a9d8f','#e9c46a','#f4a261','#e76f51',  "#808080", "#808080", "#808080"]
fallback_color = '#808080'

# The i-th largest cluster gets color_set[i], and grey once the palette runs
# out. The former `color_set[len(colors) - 1]` gave the largest cluster the
# palette's last entry (grey) and shifted every other colour by one.
colors = [
    color_set[i] if i < len(color_set) else fallback_color
    for i in range(len(top_clusters_new))
]

dict_top_colors = dict(zip(top_clusters_new, colors))
dict_top_colors

In [None]:
# Colour each genome by its cluster. Clusters outside the top set are black.
df_hclusts['color_code'] = df_hclusts['hcluster'].map(dict_top_colors).fillna("#000000")

comm_colors = df_hclusts['color_code']

# Similarity heatmap (1 - distance), ordered by the Ward linkage computed
# earlier. clustermap creates its own figure, so no extra plt.figure() call
# (which only left an empty figure behind) is made.
g = sns.clustermap(1 - df_mash,
                   figsize=(8, 8), row_linkage=clusters, col_linkage=clusters,
                   row_colors=comm_colors, col_colors=comm_colors,
                   yticklabels=False, xticklabels=False, cmap="rocket_r")
g.cax.set_visible(True)

# Legend handles for the top clusters. The bars sit at x=-10, presumably
# outside the dendrogram's visible range, so only their legend entries show.
for cluster_id, cluster_color in dict_top_colors.items():
    g.ax_col_dendrogram.bar(-10, 10, color=cluster_color,
                            label=cluster_id, linewidth=4)
# Build the legend once, after all handles exist.
g.ax_col_dendrogram.legend(ncol=2)

plt.savefig(f"assets/figures/{FIGURE}/b.svg")
plt.show()

### Define phylogroups (numbered by cluster size, largest first)

In [None]:
# Name phylogroups P0, P1, ... in dict_top_colors order (largest cluster first).
# Clusters outside the coloured top set get the following numbers, so every
# hcluster has a phylogroup. Without them, labelling genomes raises KeyError
# whenever best_k exceeds the 12 coloured clusters.
ordered_clusters = list(dict_top_colors) + [
    cluster for cluster in df_hclusts.hcluster.value_counts().index
    if cluster not in dict_top_colors
]
phylogroup_mapping = {cluster: f"P{num}" for num, cluster in enumerate(ordered_clusters)}

In [None]:
# Label each genome with its phylogroup. This raises KeyError for any
# cluster missing from phylogroup_mapping.
df_hclusts["phylogroup"] = [phylogroup_mapping[hclust] for hclust in df_hclusts.hcluster.tolist()]
# Reorder rows to match the clustermap. g.data2d is presumably the data in
# its displayed (dendrogram-reordered) order; confirm.
df_hclusts = df_hclusts.reindex(index=g.data2d.index)

In [None]:
outdir = Path("assets/tables")
outdir.mkdir(parents=True, exist_ok=True)
# Save the per-genome cluster / colour / phylogroup table.
df_hclusts.to_csv(outdir / f"{FIGURE}b_mash_hcluster.csv")

### Draw MASH network

In [None]:
def _short_species_name(organism):
    """Abbreviate a GTDB organism name for use as a network node label.

    ``"Genus epithet"`` becomes ``"G. epithet"``. The single token ``"s"``
    becomes ``"Saccharopolyspora sp."``. Anything else is returned unchanged
    as a string.
    """
    words = str(organism).split()
    if len(words) == 2:
        genus, epithet = words
        return f"{genus[0]}. {epithet}"
    # NOTE(review): the original compared the split *list* to the string 's',
    # which is never true. It is interpreted here as a one-word organism "s";
    # confirm the intended condition.
    if words == ['s']:
        return "Saccharopolyspora sp."
    # Fall back to the full name. Leaving the label unset would silently
    # reuse the previous genome's species (or raise NameError on the first).
    return str(organism)


tax_mapping = {
    genome_id: _short_species_name(organism)
    for genome_id, organism in df_gtdb.Organism.to_dict().items()
}

In [None]:
# Marker style per phylogroup. Every genome row writes its group's entry,
# so the last row of each phylogroup determines the stored colour; the
# first occurrence fixes the key order.
node_annotation_map = {
    row.phylogroup: {'color': row.color_code, 'node_symbol': "circle"}
    for row in df_hclusts.itertuples()
}

In [None]:
# Edge styling for the 'mash' relation.
# NOTE(review): not referenced by the network-drawing cell, which passes
# colour/width to create_edge_trace directly. Confirm whether it is still needed.
edge_annotation_map = {'mash' : {'color':'black',
                                 'width':0.5}}

In [None]:
# Build the MASH similarity network: genomes are nodes, and edges join pairs
# whose MASH similarity (1 - distance) is at least `cutoff`.
traces = []
cutoff = 0.85
# df_mash values become the edge 'weight' attribute (MASH distances).
# NOTE(review): from_pandas_adjacency presumably treats zero entries as
# "no edge", so pairs at distance exactly 0 would get no edge. Confirm this is acceptable.
G = nx.from_pandas_adjacency(df_mash)
# Drop edges whose distance exceeds 1 - cutoff (similarity below cutoff).
edge_to_remove = [e for e in G.edges if G.edges[e]['weight'] > 1-cutoff]
G.remove_edges_from(edge_to_remove)

# define layout options
options = {
    'prog': 'neato',
}
# Graphviz 'neato' layout via networkx's pygraphviz interface.
pos = nx.nx_agraph.graphviz_layout(G, **options)#, args='-Goverlap=false -Elen=weight')
# Store per-node position, phylogroup (selects the node trace) and species label.
for n, p in pos.items():
    G.nodes[n]['pos'] = p
    G.nodes[n]['node_trace'] = df_hclusts.loc[n, "phylogroup"]
    G.nodes[n]['text'] = f'{tax_mapping[n]}'

# Convert each edge's distance to a similarity string with 2 decimals and tag
# the edge with its percentage label (e.g. "87%"). Edges sharing a label are
# drawn together as one trace.
weights = []
for e in G.edges:
    weight = G.edges[e]['weight']
    weight = f"{1-weight:.2f}"
    weights.append(weight)
    G.edges[e]['relation_type'] = f'{float(weight):.0%}'

# Distinct similarity levels. Sorting the strings works because they share the "d.dd" format.
weights = sorted(set(weights))

# Linear map from similarity in [x_min, x_max] to line width in
# [y_min, y_max]: width = similarity * x + c. x_min mirrors `cutoff`.
x_max, x_min = 1, 0.85
y_max, y_min = 3, 0.2
x = (y_max - y_min) / (x_max - x_min)
c = y_max - (x_max*x)

# One edge trace per similarity level, thicker for more similar genomes.
for w in weights:
    width = float(w)*x + c
    edge_trace = create_edge_trace(G, f'{float(w):.0%}', color='black', width=width, showlegend=True, opacity=0.5,
                                   legendgroup="MASH distances", legendgrouptitle_text="Similarity")
    traces.append(edge_trace)

# One node trace per phylogroup, coloured via node_annotation_map.
for trace in df_hclusts["phylogroup"].unique():
    nodeopacity = 0.8
    showtextlabel = True
    linecolor = None
    linewidth = 0.5
    textposition="middle center"
    node_size = 28
    color = node_annotation_map[trace]['color']
    node_trace = create_node_trace(G, trace, color, showtextlabel=showtextlabel, 
                                   nodesymbol=node_annotation_map[trace]['node_symbol'], nodeopacity=nodeopacity, 
                                   showlegend=True, linecolor=linecolor, linewidth=linewidth, nodesize=node_size,
                                   textposition=textposition, legendgroup="genomes", legendgrouptitle_text="Species phylogroup")
    traces.append(node_trace)

In [None]:
# Assemble the network figure from the edge and node traces. The paper
# background is transparent and the plot area white. Axes show no grid, zero
# line or ticks, only a mirrored black frame. The size is fixed at 800x700.
fig = go.Figure(data=traces,
                layout=go.Layout(
                    paper_bgcolor='rgba(0,0,0,0)',
                    plot_bgcolor='white',
                    showlegend=True,
                    hovermode='closest',
                    margin=dict(b=20,l=5,r=5,t=40),
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, linecolor='black', mirror=True, linewidth=1),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, linecolor='black', mirror=True, linewidth=1),
                    width=800, height=700)
                )
# Vertical legend. The returned figure is the cell's displayed output.
fig.update_layout(legend=dict(
    orientation="v"
))

In [None]:
figure_dir = Path(f"assets/figures/{FIGURE2}")
outfile = figure_dir / f"{FIGURE2}_b.svg"
outfile_html = outfile.with_suffix(".html")
figure_dir.mkdir(exist_ok=True, parents=True)
# Static SVG (embedded as panel B of the composite figure) plus an interactive HTML copy.
fig.write_image(outfile)
fig.write_html(outfile_html)

## Annotate network image

In [None]:
from svgutils.compose import *
from svgutils.compose import Figure
from IPython.display import SVG as disp_SVG

In [None]:
# Outline-only rectangle used as an overlay box in the composite figure.
# width/height/color are also referenced by the composite-figure cell to locate this file.
width, height = 20, 50
color = "black"

rectangle = f"""<svg width="{width+20}" height="{height+20}">
    <rect x="20" y="20" width="{width}" height="{height}"
    style="fill:none;stroke:{color};stroke-width:1;opacity:1" />
    </svg>"""

outfile = Path(f"assets/figures/{FIGURE2}/doodles/rectangle_{width}x{height}_{color}.svg")
outfile.parent.mkdir(parents=True, exist_ok=True)
outfile.write_text(rectangle)

# Preview the rectangle inline.
disp_SVG(rectangle)

In [None]:
# Compose the supplementary figure (700 x 1000) from the SVGs saved above,
# using fixed pixel offsets and scales.
final_figure = Figure("700", "1000",
                      # Panel A: adjusted-inertia plot used to choose k.
                      Panel(
                          SVG(f"assets/figures/{FIGURE2}/{FIGURE2}_a.svg").scale(1.3).move(0, 0),
                          Text("A", 0, 30, size=18, weight='bold'),
                      ),                 
                      # Rectangle overlay; width/height/color come from the doodle cell.
                      Panel(
                          SVG(f"assets/figures/{FIGURE2}/doodles/rectangle_{width}x{height}_{color}.svg").scale(1).move(240, 290),
                      ),
                      # Panel B: MASH similarity network.
                      Panel(
                          SVG(f"assets/figures/{FIGURE2}/{FIGURE2}_b.svg").scale(0.8).move(80, 380),
                          Text("B", 0, 380+30, size=18, weight='bold'),
                      ),
                     )
outfile = Path(f"assets/figures/{FIGURE2}/{FIGURE2}.svg")
outfile.parent.mkdir(parents=True, exist_ok=True)
final_figure.save(outfile)
# Display the composed figure as the cell output.
final_figure

## References
<font size="2">
{% for i in project().rule_used['mash']['references'] %}
- *{{ i }}*
{% endfor %}
</font>