In [None]:
import logging
import re
import urllib
from io import StringIO
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import gzip
import pickle
from tqdm.notebook import tqdm, trange
import multiprocessing
from IPython.display import display, HTML
import itertools

from collections import Counter
import plotly.graph_objects as go
import networkx as nx
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# --- Input file locations (all relative to this notebook's directory) ---

# Strain-by-gene matrix (gzipped pickle), sample metadata and eggNOG table.
DF_GENES = "../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz"
ENRICHED_METADATA = "../../data/metadata/enriched_metadata.csv"
DF_EGGNOG = "../../data/processed/df_eggnog.csv"

# Pickled core / accessory / rare frames stored under CAR_genomes.
DF_CORE_COMPLETE = "../../data/processed/CAR_genomes/df_core_complete.pickle"
DF_ACC_COMPLETE = "../../data/processed/CAR_genomes/df_acc_complete.pickle"
DF_RARE_COMPLETE = "../../data/processed/CAR_genomes/df_rare_complete.pickle"

# NMF outputs: L and A matrices, raw and binarized; plus Bakta annotations.
L_BINARIZED = "../../data/processed/nmf-outputs/L_binarized.csv"
A_BINARIZED = "../../data/processed/nmf-outputs/A_binarized.csv"
L_MATRIX = "../../data/processed/nmf-outputs/L.csv"
A_MATRIX = "../../data/processed/nmf-outputs/A.csv"
BAKTA_ANNOTATIONS = "../../data/processed/bakta_gene_annotations.csv"

In [None]:
# Load the Bakta gene-annotation table; the first CSV column becomes the index.
bakta_annotations = pd.read_csv(BAKTA_ANNOTATIONS, index_col=0)

In [None]:
# Load two gene-location tables from the notebook's working directory
# (first CSV column used as the index in both).
# NOTE(review): unlike the other inputs, these paths are inline literals rather
# than constants from the configuration cell. Presumably 'acc' is the
# accessory-gene subset of the 'complete' table — TODO confirm.
gene_locs_acc = pd.read_csv('acc_gene_location.csv', index_col=0)
gene_locs = pd.read_csv('complete_gene_location.csv', index_col=0)

In [None]:
# Load the three pickled frames stored under CAR_genomes.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only use it on files produced by this project.
df_rare = pd.read_pickle(DF_RARE_COMPLETE)
df_acc = pd.read_pickle(DF_ACC_COMPLETE)
df_core = pd.read_pickle(DF_CORE_COMPLETE)

In [None]:
# Read the enriched metadata table. Every column is loaded with dtype
# 'object' (no numeric inference), and the first CSV column is the index.
metadata = pd.read_csv(
    ENRICHED_METADATA,
    index_col=0,
    dtype='object',
)

# Quick look: table dimensions followed by the first rows.
display(metadata.shape, metadata.head())

In [None]:
# Full strain-by-gene matrix ("P matrix"): one column per genome_id.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only use it on files produced by this project.
df_genes = pd.read_pickle(DF_GENES)

# Restrict the metadata to genomes whose status is 'Complete'.
metadata_complete = metadata.loc[metadata.genome_status == 'Complete']

# Select the matching genome columns, treat missing entries as absent (0),
# then densify the sparse frame and store it as int8 to save memory.
df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

# Drop gene rows that do not occur in any of the complete genomes.
inCompleteseqs = df_genes_complete.sum(axis=1).gt(0)
df_genes_complete = df_genes_complete.loc[inCompleteseqs]

df_genes_complete.shape

In [None]:
# eggNOG annotations (first CSV column is the index); missing cells are
# replaced by the placeholder '-'.
df_eggnog = pd.read_csv(DF_EGGNOG, index_col=0).fillna('-')

# Show the table dimensions and the first rows.
display(df_eggnog.shape, df_eggnog.head())

In [None]:
# Load in A_binarized matrix (first CSV column is used as the index).
# Later cells index its rows by phylon name and its columns by genome_id,
# and treat a value of 1 as "strain belongs to phylon".
A_binarized = pd.read_csv(A_BINARIZED, index_col=0)
A_binarized

In [None]:
# Load in L_binarized matrix (first CSV column is used as the index).
# NOTE(review): presumably the gene-by-phylon counterpart of A_binarized —
# TODO confirm; it is not used by the cells below.
L_binarized = pd.read_csv(L_BINARIZED, index_col=0)
L_binarized

In [None]:
# Ordered list of phylon names. Later cells use it to select rows of
# A_binarized and, in the same order, to pair each phylon with a colour.
# The strings must match the A_binarized row labels exactly.
characterized_order = [
    # hormaechei-prefixed phylons
    "hormaechei-xiangfangensis",
    "hormaechei-oharae",
    "hormaechei-steigerwaltii-2",
    "hormaechei-steigerwaltii-1",
    "hormaechei-steigerwaltii-3",
    "hormaechei-hormaechei",
    "hormaechei-hoffmannii-1",
    "hormaechei-hoffmannii-2",
    # remaining phylons
    "roggenkampii",
    "asburiae",
    "kobei",
    "bugandensis",
    "cancerogenous",
    "ludwigii",
    "cloacae",
]

# Analyze the overlap between phylons and mash clusters

In [None]:
def get_strains(phylon, A_binarized=None):
    """Return the strains that are members of ``phylon``.

    Parameters
    ----------
    phylon : hashable
        Row label of ``A_binarized`` to look up.
    A_binarized : pandas.DataFrame, optional
        Binarized membership matrix with one row per phylon and one column
        per strain. When omitted, the notebook-level ``A_binarized`` is
        looked up at call time, so the function always sees the matrix that
        is currently loaded rather than the one that existed when this cell
        was executed.

    Returns
    -------
    pandas.Index
        Column labels whose entry in the ``phylon`` row equals 1.

    Raises
    ------
    NameError
        If no matrix is passed and no global ``A_binarized`` exists.
    KeyError
        If ``phylon`` is not a row label of the matrix.
    """
    if A_binarized is None:
        # Late binding: a default of ``A_binarized=A_binarized`` would freeze
        # whatever object the global pointed to at definition time.
        try:
            A_binarized = globals()["A_binarized"]
        except KeyError:
            raise NameError(
                "A_binarized is not defined; load it or pass it explicitly"
            ) from None
    phylon_membership = A_binarized.loc[phylon]
    return phylon_membership[phylon_membership == 1].index

In [None]:
# Build a per-strain lookup table: the phylon a strain is assigned to and the
# mash cluster it falls in (one row per complete genome).
genome_ids = metadata_complete.genome_id.values
mash_labels = metadata_complete.complete_mash_cluster.values

strain_maps = pd.DataFrame(index=genome_ids, columns=['phylon', 'mash_cluster'])
for strain, mash in zip(genome_ids, mash_labels):
    # Membership of this strain across the characterized phylons.
    membership = A_binarized.loc[characterized_order, strain]
    # idxmax returns the first phylon (in characterized_order) holding the
    # maximum; strains belonging to no characterized phylon are tagged 'None'.
    assigned_phylon = membership.idxmax() if membership.max() > 0 else 'None'
    # Cluster labels are strings (metadata is read as object), e.g. '3.0' -> 3.
    strain_maps.loc[strain] = [assigned_phylon, int(float(mash))]

In [None]:
# Candidate node lists and per-cluster strain totals. None of these three
# names is read again inside this cell; they are kept so that any other cell
# relying on them keeps working.
mash_nodes = list(range(1, len(metadata_complete.complete_mash_cluster.unique())+1))
phylon_nodes = characterized_order
mash_strain_counts = metadata_complete.complete_mash_cluster.astype(float).astype(int).value_counts()

# Count how many strains fall on each (mash cluster, phylon) edge. Iterating
# over the two columns directly avoids one .loc lookup per strain and does not
# depend on the strain index being unique.
edge_counter = Counter(
    (int(float(mash)), phylon)
    for phylon, mash in zip(strain_maps['phylon'], strain_maps['mash_cluster'])
)

# One row per edge, built in a single pass instead of growing the frame row by
# row. NOTE: despite its name, 'strain_percentage' holds the raw strain count
# of the edge; the column name is kept because later cells read it.
sankey_diagram = pd.DataFrame(
    [(mash, phylon, n_strains) for (mash, phylon), n_strains in edge_counter.items()],
    columns=['mash', 'phylon', 'strain_percentage'],
)

sankey_diagram

In [None]:
# Draw a Sankey diagram linking mash clusters (left) to phylons (right).
# plotly.graph_objects (go), matplotlib and itertools are imported in the
# first cell of the notebook, so they are not re-imported here.

# Sort the edges so that node indices are assigned in a deterministic order
sankey_diagram = sankey_diagram.sort_values(['mash', 'phylon'])

# Source (mash) and target (phylon) nodes, in order of first appearance
mash_nodes = sankey_diagram['mash'].unique()
phylon_nodes = sankey_diagram['phylon'].unique()

# Single node list: mash nodes first, phylon nodes after them
all_nodes = list(mash_nodes) + list(phylon_nodes)

# Position of every node in all_nodes; plotly addresses nodes by position
node_indices = {node: idx for idx, node in enumerate(all_nodes)}

# Translate the edge endpoints into node positions
sankey_diagram['source'] = sankey_diagram['mash'].map(node_indices)
sankey_diagram['target'] = sankey_diagram['phylon'].map(node_indices)

# Mash cluster colors: walk through tab20 and start over when it is exhausted,
# so every cluster gets a palette color no matter how many clusters exist
cm = matplotlib.colormaps.get_cmap('tab20')
mash_clr = dict(zip(sorted(strain_maps.dropna().mash_cluster.unique()), itertools.cycle(cm.colors)))

# Phylon colors: white for the 'None' bucket, then one named color per phylon
# in characterized_order
custom_colors = [
    '#FFFFFF',
    "Red", "IndianRed", "DarkRed", "FireBrick", "Tomato",
    "Gold", "DarkGoldenrod", "Goldenrod", "Green",
    "Blue", "Purple", "Cyan", "Magenta", "Lime", "Pink",
]
phylon_colors = dict(zip(['None'] + characterized_order, custom_colors))

# Pick a color per node: mash palette first, then phylon palette, then a
# neutral fallback for anything unknown
raw_node_colors = [
    mash_clr.get(node, phylon_colors.get(node, "lightblue"))
    for node in all_nodes
]

# Colormap entries are RGB float tuples; plotly wants strings. to_hex rounds
# each channel, whereas int(channel * 255) truncates and can come out one step
# too low when the product falls just below an integer.
node_colors = [
    matplotlib.colors.to_hex(color) if isinstance(color, tuple) else color
    for color in raw_node_colors
]

# Create the Sankey diagram
fig = go.Figure(go.Sankey(
    node=dict(
        pad=5,  # vertical gap between neighbouring nodes
        thickness=50,  # width of the node rectangles
        line=dict(color="black", width=0.5),
        # mash labels are integers and phylon labels are strings; pass them
        # all as text so the label list has a single type
        label=[str(node) for node in all_nodes],
        color=node_colors
    ),
    link=dict(
        source=sankey_diagram['source'].astype(int).tolist(),
        target=sankey_diagram['target'].astype(int).tolist(),
        # link width is proportional to the number of strains on the edge
        value=sankey_diagram['strain_percentage'].astype(float).tolist()
    )
))

# Narrow, slightly tall layout with tight margins
fig.update_layout(
    title_text="Sankey Diagram of Mash to Phylon Nodes",
    font_size=10,
    margin=dict(l=10, r=10, t=40, b=10),
    width=500,
    height=600
)
# Static export (the target directory must already exist), then inline display
fig.write_image("../images/supplemental/sankey_diagram.svg")
fig.show()
