In [None]:
# Base imports
import os
import pickle
import re

# Compute imports
import numpy as np
import pandas as pd
import scipy
from scipy.cluster import hierarchy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
from matplotlib import gridspec
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px
import matplotlib.patches as mpatches

# ML import
from sklearn.decomposition import NMF
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.metrics.pairwise import cosine_similarity

# Figure styling: TrueType (Type 42) fonts in PDF/PS exports, text kept as text in SVG,
# Arial as the sans-serif font, seaborn's 'ticks' style, then pure-black text/ticks.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
})
sns.set_style('ticks')
# Colour overrides are applied after sns.set_style so the style cannot overwrite them
matplotlib.rcParams.update({
    rc_key: '#000000'
    for rc_key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color')
})

In [None]:
# Input locations, relative to the notebook's working directory
_DATA_DIR = '../../data'
_PROCESSED_DIR = f'{_DATA_DIR}/processed'

# Full (P) strain-by-gene matrix, enriched metadata and eggNOG annotations
DF_GENES = f'{_PROCESSED_DIR}/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz'
ENRICHED_METADATA = f'{_DATA_DIR}/metadata/enriched_metadata.csv'
DF_EGGNOG = f'{_PROCESSED_DIR}/df_eggnog.csv'

# Core / accessory / rare gene tables for complete genomes
_CAR_DIR = f'{_PROCESSED_DIR}/CAR_genomes'
DF_CORE_COMPLETE = f'{_CAR_DIR}/df_core_complete.pickle'
DF_ACC_COMPLETE = f'{_CAR_DIR}/df_acc_complete.pickle'
DF_RARE_COMPLETE = f'{_CAR_DIR}/df_rare_complete.pickle'

# NMF outputs (binarized and raw L / A matrices) and Bakta annotations
_NMF_DIR = f'{_PROCESSED_DIR}/nmf-outputs'
L_BINARIZED = f'{_NMF_DIR}/L_binarized.csv'
A_BINARIZED = f'{_NMF_DIR}/A_binarized.csv'
L_MATRIX = f'{_NMF_DIR}/L.csv'
A_MATRIX = f'{_NMF_DIR}/A.csv'
BAKTA_ANNOTATIONS = f'{_PROCESSED_DIR}/bakta_gene_annotations.csv'

In [None]:
bakta_annotations = pd.read_csv(filepath_or_buffer=BAKTA_ANNOTATIONS, index_col=0)  # first CSV column -> index

In [None]:
# Gene-location tables (accessory genes and complete-genome genes, per the file names),
# read from the parent of the working directory; read in the same order as listed.
gene_locs_acc, gene_locs = (
    pd.read_csv(location_csv, index_col=0)
    for location_csv in ('../acc_gene_location.csv', '../complete_gene_location.csv')
)

In [None]:
# Rare / accessory / core gene tables for complete genomes (locally produced pickles)
df_rare, df_acc, df_core = (
    pd.read_pickle(pickle_path)
    for pickle_path in (DF_RARE_COMPLETE, DF_ACC_COMPLETE, DF_CORE_COMPLETE)
)

In [None]:
# Enriched genome metadata; every column is read with object (string) dtype
metadata = pd.read_csv(ENRICHED_METADATA, dtype='object', index_col=0)
display(metadata.shape, metadata.head())

In [None]:
# Full P matrix (genes x genomes), then restricted to genomes whose status is 'Complete'
df_genes = pd.read_pickle(DF_GENES)
metadata_complete = metadata.loc[metadata['genome_status'] == 'Complete']

df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)          # missing entries -> 0
    .sparse.to_dense()  # densify the sparse columns ...
    .astype('int8')     # ... and shrink to int8 for memory and compute
)

# Keep only genes present in at least one complete genome
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete[inCompleteseqs]

df_genes_complete.shape

In [None]:
# eggNOG annotations, with missing cells shown as '-'
df_eggnog = pd.read_csv(DF_EGGNOG, index_col=0).fillna('-')
display(df_eggnog.shape, df_eggnog.head())

In [None]:
# Binarized NMF A matrix: phylons as rows, strains as columns
A_binarized = pd.read_csv(filepath_or_buffer=A_BINARIZED, index_col=0)
A_binarized

In [None]:
# Binarized NMF L matrix: genes as rows, phylons as columns
L_binarized = pd.read_csv(filepath_or_buffer=L_BINARIZED, index_col=0)
L_binarized

In [None]:
# All phylons, in the order used to group genes/strains and to lay out the heatmaps
phylon_order = [
    # hormaechei sub-phylons
    'hormaechei-xiangfangensis',
    'hormaechei-oharae',
    'hormaechei-steigerwaltii-2',
    'hormaechei-steigerwaltii-1',
    'hormaechei-steigerwaltii-3',
    'hormaechei-hormaechei',
    'hormaechei-hoffmannii-1',
    'hormaechei-hoffmannii-2',
    # uncharacterized phylons
    'unchar-1', 'unchar-2', 'unchar-3', 'unchar-4',
    # other species
    'roggenkampii', 'asburiae', 'kobei', 'bugandensis', 'cancerogenous', 'ludwigii', 'cloacae',
]

# Same order with the uncharacterized ('unchar') phylons dropped
characterized_order = [name for name in phylon_order if 'unchar' not in name]

In [None]:
# Row (gene) order for the sorted L heatmap:
#   1. genes active in no phylon;
#   2. genes active in exactly one phylon, grouped following `phylon_order`;
#   3. genes active in k >= 2 phylons, grouped by k and Ward-clustered
#      (Euclidean distance) within each group.
n_active_phylons = L_binarized.sum(axis=1)  # active phylons per gene (hoisted out of the loops)

gene_order = []

# Zero-phylon genes
gene_order.extend(L_binarized[n_active_phylons == 0].index)

# Single-phylon genes
single_cond = n_active_phylons == 1
for phylon in phylon_order:
    inPhylon = L_binarized[phylon] == 1
    gene_order.extend(L_binarized[inPhylon & single_cond].index)

# Poly-phylon genes. Only the dendrogram leaf order is needed, so cluster with scipy
# directly instead of drawing a throw-away sns.clustermap figure per group.
# Ward linkage needs at least two observations, and some k in the range may have no
# genes at all; such groups have a trivial order and are appended as-is.
for num_active_phylons in trange(2, int(n_active_phylons.max()) + 1):
    group = L_binarized[n_active_phylons == num_active_phylons]
    if len(group) < 2:
        gene_order.extend(group.index)
        continue
    row_linkage = hierarchy.linkage(group.to_numpy(), method='ward', metric='euclidean')
    gene_order.extend(group.index[hierarchy.leaves_list(row_linkage)])

In [None]:
# Main sorted clustermap: gene rows keep `gene_order` (row clustering off); the phylon
# columns are clustered with Ward linkage on Euclidean distance.
g = sns.clustermap(
    data=L_binarized.loc[gene_order],
    cmap='Greys',
    row_cluster=False,
    yticklabels=False,
    method='ward',
    metric='euclidean',
);

In [None]:
# Column (strain) order for the A heatmap:
#   1. strains assigned to no phylon;
#   2. for each characterized phylon (in phylon_order): its single-phylon strains, then
#      its multi-phylon strains not already placed;
#   3. the same for the uncharacterized ('unchar') phylons, processed after all
#      characterized ones, so a strain shared with a characterized phylon stays there.
# Within each group strains keep A_binarized's column order. The earlier version took
# set differences, whose iteration order depends on Python's string-hash seed, so the
# column order could change between kernel restarts.
phylons_per_strain = A_binarized.sum()

strain_order = phylons_per_strain[phylons_per_strain == 0].index.tolist()
placed = set(strain_order)  # kept in sync with strain_order for O(1) membership tests

single_phylon_strains = phylons_per_strain[phylons_per_strain == 1].index
multi_phylon_strains = phylons_per_strain[phylons_per_strain > 1].index

characterized_then_unchar = (
    [phylon for phylon in phylon_order if 'unchar' not in phylon]
    + [phylon for phylon in phylon_order if 'unchar' in phylon]
)
for phylon in characterized_then_unchar:
    for candidates in (single_phylon_strains, multi_phylon_strains):
        membership = A_binarized.loc[phylon, candidates]
        for strain in membership[membership == 1].index:
            if strain not in placed:
                strain_order.append(strain)
                placed.add(strain)

# A-binarized
sns.clustermap(A_binarized.loc[phylon_order, strain_order], cmap='Greys', xticklabels=False, row_cluster=False, col_cluster=False)

# Number of rare genes per strain per phylon

In [None]:
# One colour per entry of `characterized_order`, matched by position:
# the eight hormaechei sub-phylons get reds/oranges/golds, the other species distinct hues.
_hormaechei_colors = ["Red", "IndianRed", "DarkRed", "FireBrick", "Tomato",
                      "Gold", "DarkGoldenrod", "Goldenrod"]
_other_species_colors = ["Green", "Blue", "Purple", "Cyan", "Magenta", "Lime", "Pink"]
custom_colors = _hormaechei_colors + _other_species_colors

In [None]:
def get_strains(phylon, A_binarized=None):
    """Return the strains assigned to ``phylon`` in a binarized A matrix.

    Parameters
    ----------
    phylon : str
        Row label of ``A_binarized`` (a phylon name).
    A_binarized : pandas.DataFrame, optional
        Binarized A matrix with phylons as rows and strains as columns. When omitted,
        the notebook-level ``A_binarized`` is looked up at call time.

    Returns
    -------
    pandas.Index
        Column labels whose entry in row ``phylon`` equals 1.
    """
    if A_binarized is None:
        # Resolve at call time: the previous default (`A_binarized=A_binarized`) froze
        # whatever frame existed when this cell ran, so a reloaded matrix was ignored.
        A_binarized = globals()['A_binarized']
    phylon_membership = A_binarized.loc[phylon]
    return (phylon_membership[phylon_membership == 1]).index

In [None]:
# Preview the rare-gene table (shape and first rows), matching the other loading cells
display(df_rare.shape, df_rare.head())

In [None]:
# Number of rare genes carried by each strain, grouped by characterized phylon
data = []
for phylon in characterized_order:
    strains = get_strains(phylon)
    rare_gene_counts = df_rare[strains].sum()  # rare genes present per strain
    data.extend(
        {'Phylon': phylon, 'Strain': strain, 'RareGeneCount': count}
        for strain, count in rare_gene_counts.items()
    )

plot_data = pd.DataFrame(data)

# `order` reserves a slot for every phylon (even one with no strains, which would be
# absent from plot_data), and the dict palette ties each colour to its phylon name.
# A positional palette without `order` would shift colours onto the wrong phylons.
phylon_palette = dict(zip(characterized_order, custom_colors))

plt.figure(figsize=(10, 6))
sns.boxenplot(data=plot_data, x='Phylon', y='RareGeneCount',
              order=characterized_order, palette=phylon_palette)
plt.xticks(rotation=45, ha='right')
plt.xlabel('Phylon')
plt.ylabel('Rare Genes per Strain')
plt.title('Distribution of Rare Genes per Strain in Each Phylon')
plt.tight_layout()
plt.show()

# AMR rare genes across phylons

In [None]:
# AMRFinder hits, reduced to one row per protein: the row with the highest
# '% Coverage of reference sequence' (sort ascending, keep the last duplicate).
amr = pd.read_csv('../../data/processed/amrfinder/output', sep='\t')
# NOTE(review): keeps only the text before the FIRST 'A' of each identifier — confirm
# identifiers cannot contain an 'A' before the suffix this is meant to strip.
amr['Protein identifier'] = amr['Protein identifier'].map(lambda pid: pid.partition('A')[0])
amr = (
    amr.sort_values('% Coverage of reference sequence')
    .drop_duplicates(subset='Protein identifier', keep="last")
)

In [None]:
# Rare genes flagged by AMRFinder, kept in AMR-table order
rare_amr_genes = [pid for pid in amr['Protein identifier'] if pid in df_rare.index]

# Per characterized phylon:
#   amr_counts   - number of its strains carrying each rare AMR gene
#   amr_presence - the same count divided by the phylon's number of strains
#   plot_data    - rare AMR genes per strain (one record per strain)
counts_by_phylon = {}
presence_by_phylon = {}
data = []
for phylon in characterized_order:
    strains = get_strains(phylon)
    amr_slice = df_rare.loc[rare_amr_genes, strains]
    strains_per_gene = amr_slice.sum(axis=1)
    counts_by_phylon[phylon] = strains_per_gene
    presence_by_phylon[phylon] = strains_per_gene / len(strains)
    data.extend(
        {'Phylon': phylon, 'Strain': strain, 'RareGeneCount': count}
        for strain, count in amr_slice.sum().items()
    )

amr_counts = pd.DataFrame(counts_by_phylon, index=rare_amr_genes, columns=characterized_order)
amr_presence = pd.DataFrame(presence_by_phylon, index=rare_amr_genes, columns=characterized_order)
plot_data = pd.DataFrame(data)


In [None]:
# Sum of df_genes_complete entries over all AMRFinder genes present in it
amr_genes_in_complete = [pid for pid in amr['Protein identifier'] if pid in df_genes_complete.index]
df_genes_complete.loc[amr_genes_in_complete].sum(axis=0).sum()

In [None]:
# Same total, restricted to the rare AMR genes
df_genes_complete.loc[rare_amr_genes].sum(axis=0).sum()

In [None]:
# AMR figure (v1): heatmap of the fraction of each phylon's strains carrying each rare
# AMR gene; a boxen plot of rare AMR genes per strain above it; and a bar plot of the
# number of strains carrying each gene (coloured by AMRFinder Class) to its right.
row_sums = amr_counts.sum(axis=1)  # strains carrying each rare AMR gene, over all phylons

# AMRFinder 'Class' of each rare AMR gene; heatmap rows are grouped by it
subclass_values = amr.set_index('Protein identifier').loc[rare_amr_genes].Class
sorted_rare_amr_genes = subclass_values.sort_values().index

# Layout: heatmap bottom-left, boxen plot above it, bar plot on its right
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.2], height_ratios=[0.2, 1], wspace=0.05, hspace=0.05)

# Heatmap; its columns are amr_presence's columns, i.e. characterized_order
ax_heatmap = plt.subplot(gs[1, 0])
sns.heatmap(
    amr_presence.loc[sorted_rare_amr_genes],
    cmap='Greys',
    cbar=False,  # a colorbar is added manually below
    ax=ax_heatmap
)
ax_heatmap.set_xlabel('Phylon')
ax_heatmap.set_ylabel('Rare AMR Genes Percentage Presence')
ax_heatmap.set_yticks([])  # Remove y-ticks

# Top boxen plot, one box per heatmap column. `order` reserves a slot for every phylon
# (even one absent from plot_data) so boxes stay aligned with the heatmap columns, and
# the dict palette ties each colour to its phylon name rather than to its position.
phylon_palette = dict(zip(characterized_order, custom_colors))
ax_top_barplot = plt.subplot(gs[0, 0])
sns.boxenplot(
    data=plot_data, x='Phylon', y='RareGeneCount',
    order=characterized_order, palette=phylon_palette, ax=ax_top_barplot
)
ax_top_barplot.set_xticks([])  # Remove x-ticks
ax_top_barplot.set_ylabel('# AMR Genes')
ax_top_barplot.set_title('Distribution of rare AMR Genes per strain in Phylon')
ax_top_barplot.tick_params(axis='x', which='both', bottom=False)

# Match the boxen plot's horizontal extent to the heatmap's
pos_heatmap = ax_heatmap.get_position()
pos_top_bar = ax_top_barplot.get_position()
ax_top_barplot.set_position([pos_heatmap.x0, pos_top_bar.y0, pos_heatmap.width, pos_top_bar.height])

# Side bar plot (strains per gene), coloured by Class
unique_subclasses = subclass_values.sort_values().unique()
subclass_palette = sns.color_palette("tab20", len(unique_subclasses))
subclass_colors = subclass_values.map(dict(zip(unique_subclasses, subclass_palette)))

ax_side_barplot = plt.subplot(gs[1, 1])
sns.barplot(
    y=amr_presence.loc[sorted_rare_amr_genes].index,
    x=row_sums.loc[sorted_rare_amr_genes],
    ax=ax_side_barplot,
    orient='h',
    dodge=False,  # Prevent spacing between bars
    palette=subclass_colors.loc[sorted_rare_amr_genes].values,
    width=1
)
ax_side_barplot.set_yticks([])  # Remove y-ticks
ax_side_barplot.set_xlabel('Number of Strains')
ax_side_barplot.set_ylabel('Genes')
ax_side_barplot.tick_params(axis='y', which='both', left=False)

# Manual colorbar. The x offset (figure coordinates) is hand-tuned; the range is meant to
# match the heatmap's min/max colour scaling of the same data.
cbar_ax = fig.add_axes([pos_heatmap.x1 - 0.73, pos_heatmap.y0, 0.02, pos_heatmap.height])
sm = plt.cm.ScalarMappable(cmap='Greys', norm=plt.Normalize(vmin=amr_presence.min().min(), vmax=amr_presence.max().max()))
cbar = fig.colorbar(sm, cax=cbar_ax)

# Legend for the Class colours
handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=subclass)
    for subclass, color in zip(unique_subclasses, subclass_palette)
]
ax_side_barplot.legend(handles=handles, title='Antibiotic Type', bbox_to_anchor=(1.1, 1), loc='upper left', fontsize=12, ncols=1)

# Legend for the phylon colours, attached to the top boxen plot
phylon_handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=phylon)
    for phylon, color in zip(characterized_order, custom_colors)
]
ax_top_barplot.legend(
    handles=phylon_handles,
    title='Phylon',
    bbox_to_anchor=(1.05, 1),  # to the right of the top boxen plot
    loc='upper left',
    fontsize=10,
    ncol=3
)

plt.savefig('amr_presence_plot.svg', format='svg')
plt.show()


In [None]:
# AMR figure (v2): heatmap of the fraction of each phylon's strains carrying each rare
# AMR gene, with a thin strip of phylon colours above it (in place of tick labels) and a
# bar plot of strains per gene (coloured by AMRFinder Class) to its right.
row_sums = amr_counts.sum(axis=1)  # strains carrying each rare AMR gene, over all phylons

# AMRFinder 'Class' of each rare AMR gene; heatmap rows are grouped by it
subclass_values = amr.set_index('Protein identifier').loc[rare_amr_genes].Class
sorted_rare_amr_genes = subclass_values.sort_values().index

# Layout: heatmap bottom-left, colour strip above it, bar plot on its right
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.2], height_ratios=[0.075, 1], wspace=0.05, hspace=0.05)

# Heatmap; its columns are amr_presence's columns, i.e. characterized_order
ax_heatmap = plt.subplot(gs[1, 0])
sns.heatmap(
    amr_presence.loc[sorted_rare_amr_genes],
    cmap='Greys',
    cbar=False,  # a colorbar is added manually below
    ax=ax_heatmap
)
ax_heatmap.set_xlabel('Phylon')
ax_heatmap.set_ylabel('Rare AMR Genes Percentage Presence')
ax_heatmap.set_yticks([])  # Remove y-ticks
ax_heatmap.set_xticks([])  # Remove x-ticks; phylons are identified by the colour strip

# Colour strip: one unit-wide block per phylon at x = i..i+1 with xlim 0..n_phylons,
# intended to line up with the heatmap's unit-width columns. The strip's on-page height
# comes from the GridSpec height ratio, so the blocks simply fill the axis vertically.
phylon_colors = dict(zip(characterized_order, custom_colors))
ax_top_color_blocks = plt.subplot(gs[0, 0])
for i, phylon in enumerate(characterized_order):
    ax_top_color_blocks.add_patch(plt.Rectangle(
        (i, 0), 1, 1, color=phylon_colors[phylon], lw=0))
ax_top_color_blocks.set_xlim(0, len(characterized_order))
ax_top_color_blocks.set_ylim(0, 1)
ax_top_color_blocks.set_xticks([])  # Remove x-ticks
ax_top_color_blocks.set_yticks([])  # Remove y-ticks
ax_top_color_blocks.set_title('Phylon Representation')

# Side bar plot (strains per gene), coloured by Class
unique_subclasses = subclass_values.sort_values().unique()
subclass_palette = sns.color_palette("tab20", len(unique_subclasses))
subclass_colors = subclass_values.map(dict(zip(unique_subclasses, subclass_palette)))

ax_side_barplot = plt.subplot(gs[1, 1])
sns.barplot(
    y=amr_presence.loc[sorted_rare_amr_genes].index,
    x=row_sums.loc[sorted_rare_amr_genes],
    ax=ax_side_barplot,
    orient='h',
    dodge=False,  # Prevent spacing between bars
    palette=subclass_colors.loc[sorted_rare_amr_genes].values,
    width=1
)
ax_side_barplot.set_yticks([])  # Remove y-ticks
ax_side_barplot.set_xlabel('Number of Strains')
ax_side_barplot.set_ylabel('Genes')
ax_side_barplot.tick_params(axis='y', which='both', left=False)

# Manual colorbar. The x offset (figure coordinates) is hand-tuned; the range is meant to
# match the heatmap's min/max colour scaling of the same data.
pos_heatmap = ax_heatmap.get_position()
cbar_ax = fig.add_axes([pos_heatmap.x1 - 0.73, pos_heatmap.y0, 0.02, pos_heatmap.height])
sm = plt.cm.ScalarMappable(cmap='Greys', norm=plt.Normalize(vmin=amr_presence.min().min(), vmax=amr_presence.max().max()))
cbar = fig.colorbar(sm, cax=cbar_ax)

# Legend for the Class colours
handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=subclass)
    for subclass, color in zip(unique_subclasses, subclass_palette)
]
ax_side_barplot.legend(handles=handles, title='Antibiotic Type', bbox_to_anchor=(1.1, 1), loc='upper left', fontsize=12, ncols=1)

# Legend for the phylon colours, attached to the colour strip
phylon_handles = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=phylon)
    for phylon, color in zip(characterized_order, custom_colors)
]
ax_top_color_blocks.legend(
    handles=phylon_handles,
    title='Phylon',
    bbox_to_anchor=(1.05, 1.9),  # to the right of (and above) the colour strip
    loc='upper left',
    fontsize=10,
    ncol=3
)

plt.savefig('amr_presence_plot_with_smaller_color_blocks_height.svg', format='svg')
plt.show()
