In [36]:
import pandas as pd
import os
from plotnine import *
import seaborn as sns
from plotnine.themes import theme_seaborn
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, leaves_list
import matplotlib.patches as mpatches
import viridis

In [None]:

def determine_gene_type(gene):
    """Classify a gene symbol as "BLEE", "CPO" or "Other".

    Membership is tested against the module-level reference lists
    ``list_BLEE`` and ``list_CPO``. They are looked up when the function is
    called, so they must be defined before the first call. ``list_BLEE`` is
    checked first, so it wins if a symbol appears in both lists.

    Parameters
    ----------
    gene
        Gene symbol to classify.

    Returns
    -------
    str
        "BLEE", "CPO" or "Other".
    """
    if gene in list_BLEE:
        return "BLEE"
    return "CPO" if gene in list_CPO else "Other"
    
# Work inside the directory that holds the "*_modified.tsv" reports so the
# files below can be opened by bare file name.
os.chdir("/data/bioinfo_doc/research/20240409_TFM-CMARTINEZ_IC-SM_T/20240711_RESULTS/reports_modified")


def _read_tsv(filename):
    """Load one tab-separated file from the current working directory."""
    return pd.read_csv(filename, sep="\t")


# Reference lists used by determine_gene_type() to tell BLEE and CPO genes apart.
data_BLEE = _read_tsv("20240805eq_BLEE.tsv")
data_CPO = _read_tsv("20240805eq_CPO.tsv")

list_BLEE = data_BLEE['blee_modified'].to_list()
list_CPO = data_CPO['cpo_modified'].to_list()

# One report per tool/database combination, keyed "<tool>_<database>" to ease
# iteration through the reports.
all_reports = {
    "ariba_CARD": _read_tsv("20240805ariba_CARD_modified.tsv"),
    "ariba_NCBI": _read_tsv("20240805ariba_NCBI_modified.tsv"),
    "ariba_RESFINDER": _read_tsv("20240805ariba_RESFINDER_modified.tsv"),
    "abricate_CARD": _read_tsv("20240805abricate_CARD_modified.tsv"),
    "abricate_NCBI": _read_tsv("20240805abricate_NCBI_modified.tsv"),
    "abricate_RESFINDER": _read_tsv("20240805abricate_RESFINDER_modified.tsv"),
    "amrfinderplus_NCBI": _read_tsv("20240805amrfinderplus_NCBI_modified.tsv"),
    "rgi_CARD": _read_tsv("20240805rgi_CARD_modified.tsv"),
}

# Keep each report reachable under its own module-level name as well; these
# are the very same DataFrame objects stored in all_reports.
ariba_CARD = all_reports["ariba_CARD"]
ariba_NCBI = all_reports["ariba_NCBI"]
ariba_RESFINDER = all_reports["ariba_RESFINDER"]
abricate_CARD = all_reports["abricate_CARD"]
abricate_NCBI = all_reports["abricate_NCBI"]
abricate_RESFINDER = all_reports["abricate_RESFINDER"]
amrfinderplus_NCBI = all_reports["amrfinderplus_NCBI"]
rgi_CARD = all_reports["rgi_CARD"]

# Step back up one level; relative paths used after this point are resolved
# from the parent of the reports directory.
os.chdir("../")

Build the DataFrame of median SeqID per gene and tool

In [None]:
# Minimum sequence identity (%) that a BLEE/CPO hit must reach in at least one
# report for its gene to be included in the summary table.
MIN_SEQ_IDENTITY = 99

# Tag every hit of every report as BLEE, CPO or Other.
# This adds a 'gene_type' column to the report DataFrames themselves.
for report in all_reports.values():
    report['gene_type'] = report['gene_symbol_modified'].apply(determine_gene_type)

# Keep only the BLEE and CPO hits of each report.
all_reports_blees_cpos = {
    tool: report[report['gene_type'].isin(["BLEE", "CPO"])]
    for tool, report in all_reports.items()
}

# Collect the BLEE/CPO genes that reach sequence_identity >= MIN_SEQ_IDENTITY
# in at least one report.
unique_genes = set()
for report in all_reports_blees_cpos.values():
    high_identity_hits = report[report['sequence_identity'] >= MIN_SEQ_IDENTITY]
    unique_genes.update(high_identity_hits['gene_symbol_modified'].unique())

# Sort the genes so the row order of the table (and of the saved TSV) is the
# same on every run; the iteration order of a set of strings is not.
unique_genes_list = sorted(unique_genes, key=str)

# One row per selected gene; one column per tool is merged in below.
result_df = pd.DataFrame(unique_genes_list, columns=['gene_symbol_modified'])

# For every tool, add the median sequence_identity of each selected gene.
# The median is taken over all hits of that gene in the full report (not only
# the hits above the threshold). Genes absent from a report are left as NaN
# by the left merge.
for tool, report in all_reports.items():
    selected_hits = report[report['gene_symbol_modified'].isin(unique_genes_list)]
    median_seq_identity_df = (
        selected_hits
        .groupby('gene_symbol_modified')['sequence_identity']
        .median()
        .reset_index()
        .rename(columns={'sequence_identity': tool})
    )
    result_df = result_df.merge(median_seq_identity_df, how='left', on='gene_symbol_modified')

# Classify every gene of the table in one pass.
result_df['gene_type'] = result_df['gene_symbol_modified'].apply(determine_gene_type)

# Save the result DataFrame to a TSV file (index=True also writes the row
# index as the first column).
result_df.to_csv('output_dir/20240809_result_df_median_seqID_tool.tsv', sep='\t', index=True)

# Preview the table; it has to be the last expression of the cell to be displayed.
result_df.head()

Build heatmap by median seqID

In [None]:

def label_color_mapping(df):
    """Map each gene symbol in ``df`` to a label colour based on its gene type.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'gene_symbol_modified' and 'gene_type'.

    Returns
    -------
    dict
        ``{gene_symbol_modified: colour}`` where the colour is "brown" for
        "BLEE", "blue" for "CPO" and "black" for any other gene type. If a
        gene symbol occurs more than once, the last row wins.
    """
    color_mapping = {
        "BLEE": "brown",
        "CPO": "blue",
    }

    # Iterate the frame that was passed in, not a module-level DataFrame.
    return {
        gene: color_mapping.get(gene_type, "black")
        for gene, gene_type in zip(df['gene_symbol_modified'], df['gene_type'])
    }

# Gene symbol -> label colour lookup built from result_df.
# NOTE(review): label_colors is not referenced anywhere in the visible code
# after this line - confirm whether it is still needed.
label_colors = label_color_mapping(result_df)

def split_label(label):
    """Break an underscore-separated X-axis label over several lines.

    A label with exactly one underscore (e.g. "ariba_CARD") is returned as the
    upper-cased first part, a newline, and the second part unchanged. Any
    other label is returned with every underscore replaced by a newline and
    no change of case.

    Parameters
    ----------
    label : str
        Tick label text.

    Returns
    -------
    str
        The label with underscores turned into line breaks.
    """
    parts = label.split('_')
    if len(parts) == 2:
        first, second = parts
        return f"{first.upper()}\n{second}"
    # Operate on the original string here: `parts` is a list and has no
    # .replace() method.
    return label.replace('_', '\n')


# Tools shown on the X-axis, in the order the columns are passed to seaborn.
tool_order = ['ariba_CARD', 'ariba_NCBI', 'ariba_RESFINDER', 'abricate_CARD', 'abricate_NCBI', 'abricate_RESFINDER', 'amrfinderplus_NCBI', 'rgi_CARD']

# Background colours used to highlight the gene labels on the Y-axis.
BLEE_LABEL_COLOR = '#38e6c7'
CP_LABEL_COLOR = '#f6d2f3'

# Step 1: Melt only the per-tool columns, so the string column 'gene_type'
# never ends up mixed with the numeric sequence identities.
melted_df = result_df.melt(
    id_vars=['gene_symbol_modified'],
    value_vars=tool_order,
    var_name='Tool',
    value_name='sequence_identity',
)

# Step 2: Pivot back to a gene x tool matrix with the columns in tool_order.
# A gene that a tool did not report (NaN) is shown as 0.
heatmap_data = (
    melted_df
    .pivot(index='gene_symbol_modified', columns='Tool', values='sequence_identity')
    [tool_order]
    .fillna(0)
)

# Custom annotation array: every cell formatted with two decimals.
# NOTE(review): the 0 placeholders are annotated as "0.00" too; blank them
# here if undetected genes should show no number.
annot = heatmap_data.apply(lambda column: column.map("{:.2f}".format))

# Create the clustermap (rows and columns are reordered by the clustering).
clustermap = sns.clustermap(
    heatmap_data,
    cmap='coolwarm',
    annot=annot,
    fmt="",
    figsize=(16, 18),
    method='average'
)

# Customize the plot
clustermap.ax_heatmap.set_title('Mediana del SeqID% de las BLEEs/CPs detectados por, al menos, un método ', fontsize=16)
clustermap.ax_heatmap.set_xlabel('Método', fontsize=14)
clustermap.ax_heatmap.set_ylabel('ARG', fontsize=14)

clustermap.ax_heatmap.set_xticklabels([split_label(label.get_text()) for label in clustermap.ax_heatmap.get_xticklabels()], rotation=0, ha='center', fontsize=12)
clustermap.ax_heatmap.set_yticklabels(clustermap.ax_heatmap.get_yticklabels(), fontsize=12)

# Highlight each gene label: genes in list_BLEE get the BLEE colour, every
# other gene gets the CP colour.
for tick_label in clustermap.ax_heatmap.get_yticklabels():
    is_blee = tick_label.get_text() in list_BLEE
    tick_label.set_bbox(dict(
        facecolor=BLEE_LABEL_COLOR if is_blee else CP_LABEL_COLOR,
        edgecolor='none',
        pad=2,
    ))
    tick_label.set_color('black')

# Style the colour bar: tick size and label are set once, here.
cbar = clustermap.ax_heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=12)
cbar.set_label('Median SeqID', size=12)

legend_patches = [
    mpatches.Patch(color=BLEE_LABEL_COLOR, label='BLEE'),
    mpatches.Patch(color=CP_LABEL_COLOR, label='CP')
]

clustermap.ax_heatmap.legend(handles=legend_patches, title='Gene Type', bbox_to_anchor=(1.15, 1), loc='upper left', fontsize=12, title_fontsize=14)


# Save the plot to a file
# clustermap.savefig("output_dir/20240809_clustered_heatmap_median_seqID_seaborn.png", dpi = 300, bbox_inches='tight')
# clustermap.savefig("output_dir/20240809_clustered_heatmap_median_seqID_seaborn.svg", dpi = 300, bbox_inches='tight')


# The object returned by sns.clustermap has no .show() method; display the
# figure through pyplot instead.
plt.show()