In [None]:
# Base imports
import os
import pickle
import re

# Compute imports
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px
import matplotlib.patches as mpatches

# ML import
from sklearn.decomposition import NMF
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.metrics.pairwise import cosine_similarity

from sklearn.cluster import KMeans


# Font handling for exported figures (PDF / PS / SVG) and the default font family.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
})

sns.set_style('ticks')

# Set after sns.set_style (same order as before): pure black text, labels and ticks.
matplotlib.rcParams.update(
    dict.fromkeys(['text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color'], '#000000')
)

In [None]:
# --- Strain-by-gene matrix and metadata / annotation tables ---
DF_GENES = '../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz'
ENRICHED_METADATA = '../../data/metadata/enriched_metadata.csv'
DF_EGGNOG = '../../data/processed/df_eggnog.csv'
BAKTA_ANNOTATIONS = '../../data/processed/bakta_gene_annotations.csv'

# --- Core / accessory / rare gene tables (complete genomes) ---
DF_CORE_COMPLETE = '../../data/processed/CAR_genomes/df_core_complete.pickle'
DF_ACC_COMPLETE = '../../data/processed/CAR_genomes/df_acc_complete.pickle'
DF_RARE_COMPLETE = '../../data/processed/CAR_genomes/df_rare_complete.pickle'

# --- NMF outputs: raw and binarized L / A matrices ---
L_MATRIX = '../../data/processed/nmf-outputs/L.csv'
A_MATRIX = '../../data/processed/nmf-outputs/A.csv'
L_BINARIZED = '../../data/processed/nmf-outputs/L_binarized.csv'
A_BINARIZED = '../../data/processed/nmf-outputs/A_binarized.csv'

In [None]:
# Bakta gene annotations; the first CSV column is used as the index.
bakta_annotations = pd.read_csv(
    filepath_or_buffer=BAKTA_ANNOTATIONS,
    index_col=0,
)

In [None]:
# Rare / accessory / core gene tables, loaded in that order.
# NOTE(review): pickle files execute code on load — only read files produced by this project.
df_rare, df_acc, df_core = (
    pd.read_pickle(pickle_path)
    for pickle_path in (DF_RARE_COMPLETE, DF_ACC_COMPLETE, DF_CORE_COMPLETE)
)

In [None]:
# Enriched strain metadata; every column is read as a plain object (no type inference).
metadata = pd.read_csv(
    ENRICHED_METADATA,
    index_col=0,
    dtype='object',
)

display(
    metadata.shape,
    metadata.head(),
)

In [None]:
# Full P matrix
df_genes = pd.read_pickle(DF_GENES)

# Metadata rows whose genome_status is 'Complete'
metadata_complete = metadata.loc[metadata.genome_status == 'Complete']

# P matrix restricted to the Complete genomes: missing entries become 0, then the
# sparse columns are densified and stored as int8 for space and compute reasons.
df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

# Drop genes (rows) that are absent from every Complete genome
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete[inCompleteseqs]

df_genes_complete.shape

In [None]:
# eggNOG annotations; missing fields are shown as '-'
df_eggnog = pd.read_csv(DF_EGGNOG, index_col=0).fillna('-')

display(df_eggnog.shape, df_eggnog.head())

In [None]:
# Binarized A matrix from the NMF outputs (first CSV column is the index)
A_binarized = pd.read_csv(
    filepath_or_buffer=A_BINARIZED,
    index_col=0,
)
A_binarized

In [None]:
# Binarized L matrix from the NMF outputs (first CSV column is the index)
L_binarized = pd.read_csv(
    filepath_or_buffer=L_BINARIZED,
    index_col=0,
)
L_binarized

In [None]:
# Display order of all phylons: hormaechei groups, the "unchar" phylons, then the rest.
phylon_order = [
    # hormaechei
    'hormaechei-xiangfangensis',
    'hormaechei-oharae',
    'hormaechei-steigerwaltii-2',
    'hormaechei-steigerwaltii-1',
    'hormaechei-steigerwaltii-3',
    'hormaechei-hormaechei',
    'hormaechei-hoffmannii-1',
    'hormaechei-hoffmannii-2',
    # unchar
    'unchar-1',
    'unchar-2',
    'unchar-3',
    'unchar-4',
    # remaining phylons
    'roggenkampii',
    'asburiae',
    'kobei',
    'bugandensis',
    'cancerogenous',
    'ludwigii',
    'cloacae',
]

In [None]:
# Display order of the characterized phylons (the list contains no "unchar" entries).
characterized_order = [
    # hormaechei
    'hormaechei-xiangfangensis',
    'hormaechei-oharae',
    'hormaechei-steigerwaltii-2',
    'hormaechei-steigerwaltii-1',
    'hormaechei-steigerwaltii-3',
    'hormaechei-hormaechei',
    'hormaechei-hoffmannii-1',
    'hormaechei-hoffmannii-2',
    # remaining phylons
    'roggenkampii',
    'asburiae',
    'kobei',
    'bugandensis',
    'cancerogenous',
    'ludwigii',
    'cloacae',
]

In [None]:
# Paths to the normalized NMF outputs. L_MATRIX and A_MATRIX are already defined
# (with these same values) in the path-constants cell near the top of the notebook,
# so they are not re-declared here.
L_NORM = '../../data/processed/nmf-outputs/L_norm.csv'
A_NORM = '../../data/processed/nmf-outputs/A_norm.csv'

# Raw L matrix, with its columns relabelled positionally to match L_binarized
# (pandas raises if the number of columns differs).
L = pd.read_csv(L_MATRIX, index_col=0)
L.columns = L_binarized.columns
L_norm = pd.read_csv(L_NORM, index_col=0)

# Raw and normalized A matrices
A = pd.read_csv(A_MATRIX, index_col=0)
A_norm = pd.read_csv(A_NORM, index_col=0)

In [None]:
def recommended_threshold(A_norm, i, n_clusters=3, random_state=0):
    """Return the k-means-recommended binarization threshold for one row of ``A_norm``.

    The values of row ``i`` are clustered in 1-D with k-means. The threshold is the
    smallest value that falls in the cluster with the highest center, i.e. the
    minimum value that would still binarize to 1.

    Parameters
    ----------
    A_norm : pandas.DataFrame
        Matrix whose row ``i`` is thresholded.
    i : hashable
        Row label in ``A_norm``.
    n_clusters : int, default 3
        Number of k-means clusters (the original hard-coded value was 3).
    random_state : int, default 0
        Seed passed to ``KMeans`` (the original hard-coded value was 0).

    Returns
    -------
    scalar
        Minimum value of row ``i`` among the entries in the highest-mean cluster.
    """
    row_values = A_norm.loc[i].to_numpy()

    # k-means on the row, reshaped to a single-feature column as KMeans expects
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init='auto')
    kmeans.fit(row_values.reshape(-1, 1))

    # Cluster with the highest center
    highest_mean_cluster = np.argmax(kmeans.cluster_centers_)

    # Select positionally with a boolean mask rather than round-tripping through a
    # {column label: flag} dict, which would collapse duplicate column labels.
    in_top_cluster = kmeans.labels_ == highest_mean_cluster
    return row_values[in_top_cluster].min()

# Test functionality on original P matrix

In [None]:
import infer_affinities

In [None]:
# Rows of the Complete-genome P matrix aligned to the genes in L_binarized
genes_complete_aligned = df_genes_complete.loc[L_binarized.index]

# Infer affinities from L_norm and wrap the result with L_binarized's column labels
# as the index and strain ids as columns.
A_new = pd.DataFrame(
    infer_affinities.infer_affinities(
        L_norm.to_numpy(),
        genes_complete_aligned.to_numpy(),
        n_jobs=40,
    ),
    index=L_binarized.columns,
    columns=genes_complete_aligned.columns,
)
A_complete_new = A_new.copy()

In [None]:
# Two independent all-zero frames shaped like A_new: `A_binarized_temp` receives the
# candidate assignments; `A_binarized_new` stays all-zero unless the check below passes.
A_binarized_new, A_binarized_temp = (
    pd.DataFrame(np.zeros_like(A_new.values), index=A_norm.index, columns=A_new.columns)
    for _ in range(2)
)

# Per row: mark strains whose inferred affinity reaches the threshold derived from A_norm
for idx in A_norm.index:
    cond = A_new.loc[idx] >= recommended_threshold(A_norm, idx)
    A_binarized_temp.loc[idx, cond] = 1

# Number of characterized phylons assigned to each strain
characterized_per_strain = A_binarized_temp.loc[characterized_order].sum()

if characterized_per_strain.max() > 1:
    print("Binarization results in multiple assingment to charactarized phylons")
else:
    print("Number of strains with an assigned charactarized phylon:", int(characterized_per_strain.sum()))
    A_binarized_new = A_binarized_temp.copy()
    print("Original A matrix strains with assigned charactarized phylon:", int(A_binarized.loc[characterized_order].sum().sum()))

In [None]:
# Observed gene presence for the binarized strains vs. its L_binarized x A_binarized_new reconstruction
P_new = df_genes.loc[L_binarized.index, A_binarized_new.columns].fillna(0)
P_new_recon = L_binarized.dot(A_binarized_new)

# NOTE(review): reconstruction entries can exceed 1 if a strain is assigned several
# phylons that share a gene — confirm whether the product should be clipped to 0/1.
total_abs_error = (P_new - P_new_recon).abs().sum().sum()

# 1 - mean absolute error over all matrix entries
1 - total_abs_error / P_new.shape[0] / P_new.shape[1]

# Infer Affinities for non-complete strains

In [None]:
# Strain columns of the full P matrix that are not among the Complete genomes (original order kept)
non_complete_ids = df_genes.columns[~df_genes.columns.isin(df_genes_complete.columns)]

df_genes_new = df_genes.loc[L_binarized.index, non_complete_ids].fillna(0)
df_genes_new.head()

In [None]:
# Infer affinities for the non-complete strains and label the result:
# rows follow L_binarized's columns, columns are the strain ids of df_genes_new.
A_new = pd.DataFrame(
    infer_affinities.infer_affinities(
        L_norm.to_numpy(),
        df_genes_new.to_numpy(),
        n_jobs=40,
    ),
    index=L_binarized.columns,
    columns=df_genes_new.columns,
)

## Binarize the new A matrix
**NOTE:** This is ongoing, and there are a few issues with making it work correctly. We still need to decide which of the methods below to use for binarizing new strains.

In [None]:
# Basic binarization with thresholds from original matrix:
A_binarized_new = pd.DataFrame(np.zeros_like(A_new.values), index=A_norm.index, columns=A_new.columns)
A_binarized_temp =  pd.DataFrame(np.zeros_like(A_new.values), index=A_norm.index, columns=A_new.columns)


for idx in A_norm.index: # same as A_norm.index
    cond = A_new.loc[idx] >= recommended_threshold(A_norm, idx)
    A_binarized_temp.loc[idx,cond] = 1

if A_binarized_temp.loc[characterized_order].sum().max() > 1:
    print("Binarization results in multiple assingment to charactarized phylons")
else:
    print("Number of strains with an assigned charactarized phylon:", int(A_binarized_temp.loc[characterized_order].sum().sum()))
    A_binarized_new = A_binarized_temp.copy()

In [None]:
# Observed gene presence for the new strains and its reconstruction from the binarized matrices
P_new = df_genes.loc[L_binarized.index, A_binarized_new.columns].fillna(0)
P_new_recon = L_binarized.dot(A_binarized_new)

n_rows, n_cols = P_new.shape
abs_error = (P_new - P_new_recon).abs()

# 1 - mean absolute error over all matrix entries
1 - abs_error.sum().sum() / n_rows / n_cols

In [None]:
import numpy as np

# Number of strains assigned to each phylon, in plotting order.
strain_counts = A_binarized_new.loc[phylon_order].sum(axis=1)

# Category used for strains of an "unchar" phylon that have no characterized phylon.
unassigned_label = 'unassigned'

# For each "unchar" phylon: the fraction of its member strains that belong to each
# characterized phylon (or to `unassigned_label`).
unchar_proportions = {}
for unchar_phylon in ['unchar-1', 'unchar-2', 'unchar-3', 'unchar-4']:
    member_mask = (A_binarized_new.loc[unchar_phylon] == 1).to_numpy()
    member_strains = A_binarized_new.columns[member_mask]
    if len(member_strains) == 0:
        # No strains in this phylon: nothing to stack (and avoids dividing by zero).
        unchar_proportions[unchar_phylon] = pd.Series(dtype=float)
        continue

    characterized_block = A_binarized_new.loc[characterized_order, member_strains]
    # idxmax() returns the FIRST row label for an all-zero column, which would count
    # unassigned strains as the first characterized phylon; relabel those explicitly.
    primary_phylon = characterized_block.idxmax().where(
        characterized_block.max() > 0, unassigned_label
    )
    unchar_proportions[unchar_phylon] = (
        primary_phylon.value_counts() / A_binarized_new.loc[unchar_phylon].sum()
    )

# Initialize the figure
fig, ax = plt.subplots(figsize=(12, 8))

# Track bottom positions for stacking
bottoms = np.zeros(len(phylon_order))

# Define colors for primary categories
primary_colors = {
    'hormaechei-xiangfangensis': "Red",
    'hormaechei-oharae': "IndianRed",
    'hormaechei-steigerwaltii-2': "DarkRed",
    'hormaechei-steigerwaltii-1': "FireBrick",
    'hormaechei-steigerwaltii-3': "Tomato",
    'hormaechei-hormaechei': "Gold",
    'hormaechei-hoffmannii-1': "DarkGoldenrod",
    'hormaechei-hoffmannii-2': "Goldenrod",
    'roggenkampii': "Green",
    'asburiae': "Blue",
    'kobei': "Purple",
    'bugandensis': "Cyan",
    'cancerogenous': "Magenta",
    'ludwigii': "Lime",
    'cloacae': "Pink",
    None: "grey"  # Default for unclassified
}

# Numeric bar positions keep every phylon at its own slot even when a phylon has no
# bar drawn (categorical string positions would shift and misalign the tick labels).
x_positions = np.arange(len(phylon_order))

# Plot each bar
for i, phylon in enumerate(phylon_order):
    total = strain_counts.iloc[i]  # explicit positional access
    if 'unchar' in phylon:  # stacked bar split by the strains' characterized phylon
        proportions = unchar_proportions.get(phylon, {})
        for category, proportion in proportions.items():
            height = total * proportion
            ax.bar(x_positions[i], height, bottom=bottoms[i], color=primary_colors.get(category, "grey"))
            bottoms[i] += height
    else:  # characterized phylons: one solid bar
        ax.bar(x_positions[i], total, color=primary_colors.get(phylon, "grey"))

# Customize plot ("unchar" phylons are shown as "mobile")
ax.set_xticks(x_positions)
ax.set_xticklabels([x.replace('unchar', 'mobile') for x in phylon_order], rotation=90)
ax.set_title("Assignment of WGS Strains to Phylons")
ax.set_xlabel("Phylons")
ax.set_ylabel("Number of Assigned Strains")

# Legend entries: one per color; the None (default) key is labelled explicitly
# so it does not render as a blank entry.
legend_labels = [unassigned_label if key is None else key for key in primary_colors]
legend_colors = list(primary_colors.values())

# Add the legend to the plot
handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in legend_colors]
ax.legend(handles, legend_labels, title="Phylons", loc="upper right", bbox_to_anchor=(1, 1), ncols=2)


# Save and display; bbox_inches='tight' keeps the rotated tick labels inside the saved file.
fig.savefig('../images/supplemental/inferred_affinities.svg', format='svg', dpi=600, bbox_inches='tight')
plt.show()


In [None]:
# Sum of all entries in the characterized-phylon rows of the binarized A matrix
A_binarized_new.loc[characterized_order].to_numpy().sum()