In [None]:
from src.data_paths import MERGED, PROTEIN_CLASSES
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
import seaborn as sns

class Matheo:
    """Plots of binding-affinity parameters for a fixed list of protein targets.

    On construction the two pickled tables referenced by ``MERGED`` and
    ``PROTEIN_CLASSES`` are loaded. ``initialize`` then derives
    ``self.filtered_df``: the rows of ``merged_df`` whose ``target_name`` is in
    ``HIGHLY_STUDIED_PROTEINS``, plus a ``drugbank_drug_name_present`` label
    column. The plotting methods build on these frames.

    Attributes:
        merged_df: Frame loaded from ``MERGED``.
        protein_classes: Frame loaded from ``PROTEIN_CLASSES``.
        chemical_param_list: Names of the affinity columns that get plotted.
        filtered_df: ``None`` until ``initialize`` has run.
    """

    # Values of ``target_name`` that ``initialize`` keeps.
    HIGHLY_STUDIED_PROTEINS = [
        'Cytochrome P450 3A4',
        'Epidermal growth factor receptor',
        'Proto-oncogene tyrosine-protein kinase Src',
        'Vascular endothelial growth factor receptor 2',
        'Adenosine receptor A2a',
        'Cytochrome P450 2C9',
        'Cytochrome P450 1A2',
        'Cytochrome P450 2C19',
        'Cytochrome P450 2D6',
        'Prostaglandin G/H synthase 1',
        'Prostaglandin G/H synthase 2',
    ]

    # Labels written to the ``drugbank_drug_name_present`` column.
    WITH_DRUG_LABEL = 'With DrugBank Drug'
    WITHOUT_DRUG_LABEL = 'Without DrugBank Drug'

    def __init__(self):
        self.merged_df, self.protein_classes = self.load_dataset()
        self.chemical_param_list = ["ki", "kd", "ec50", "ic50"]
        # Populated by initialize(); the plotting code calls it on demand.
        self.filtered_df = None

    def load_dataset(self):
        """Load both pickled tables.

        Returns:
            A ``(merged_df, protein_classes)`` tuple read from ``MERGED`` and
            ``PROTEIN_CLASSES`` respectively.
        """
        # NOTE: unpickling executes arbitrary code, so these paths must only
        # ever point at trusted, locally produced files.
        merged_df = pd.read_pickle(MERGED)
        protein_classes = pd.read_pickle(PROTEIN_CLASSES)

        return merged_df, protein_classes

    def initialize(self):
        """Build ``self.filtered_df`` from ``self.merged_df``.

        Keeps the rows whose ``target_name`` is one of
        ``HIGHLY_STUDIED_PROTEINS`` and adds ``drugbank_drug_name_present``,
        which is ``WITH_DRUG_LABEL`` where ``drugbank_drug_name`` is not null
        and ``WITHOUT_DRUG_LABEL`` otherwise. ``self.merged_df`` is left
        untouched.
        """
        is_highly_studied = self.merged_df['target_name'].isin(
            self.HIGHLY_STUDIED_PROTEINS
        )
        # reset_index() returns a new frame, so the column assignment below
        # never writes into a view of merged_df. It keeps the previous index
        # as a column.
        # NOTE(review): if that previous index is unique, the later
        # drop_duplicates() call cannot remove any row -- confirm intent.
        filtered_df = self.merged_df[is_highly_studied].reset_index()

        filtered_df['drugbank_drug_name_present'] = (
            filtered_df['drugbank_drug_name']
            .notna()
            .map({True: self.WITH_DRUG_LABEL, False: self.WITHOUT_DRUG_LABEL})
        )

        self.filtered_df = filtered_df

    def _get_filtered_df(self):
        """Return ``self.filtered_df``, running ``initialize`` first if needed."""
        if getattr(self, "filtered_df", None) is None:
            self.initialize()
        return self.filtered_df

    def _rows_for_boxplot(self, col, min_count):
        """Return the rows of ``filtered_df`` to draw for affinity column ``col``.

        Rows with a null ``col`` are discarded and duplicates dropped; a target
        is kept only when both label groups have at least ``min_count`` rows.
        """
        filtered_df = self._get_filtered_df()
        data = filtered_df[filtered_df[col].notna()].drop_duplicates()
        if data.empty:
            return data

        counts = (
            data.groupby(["target_name", "drugbank_drug_name_present"])
            .size()
            .unstack(fill_value=0)
            # A label that never occurs for this column would otherwise be a
            # missing column and raise KeyError in the comparison below.
            .reindex(
                columns=[self.WITHOUT_DRUG_LABEL, self.WITH_DRUG_LABEL],
                fill_value=0,
            )
        )
        enough_in_both = (
            (counts[self.WITHOUT_DRUG_LABEL] >= min_count)
            & (counts[self.WITH_DRUG_LABEL] >= min_count)
        )
        valid_targets = counts[enough_in_both].index
        return data[data["target_name"].isin(valid_targets)]

    def plot_binding_affinity_corr_mtx(self, method="pearson"):
        """Show a heatmap of the correlations between the affinity columns.

        The correlation is computed over every row of ``self.merged_df``
        (all targets), restricted to ``self.chemical_param_list``.

        Args:
            method: Forwarded to ``DataFrame.corr``. Defaults to ``"pearson"``;
                pass ``"spearman"`` for a rank correlation.
        """
        affinity_values = self.merged_df[self.chemical_param_list]
        correlation_matrix = affinity_values.corr(method=method)

        plt.figure(figsize=(15, 15))
        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', fmt='.2f', cbar=True, square=True)
        plt.title('Correlation Matrix Heatmap of affinity parameters', size=20)
        plt.show()

    def plot_binding_measurements_to_drug_presence(self, min_count=5):
        """Show per-target boxplots of each affinity column, split by label.

        One panel is drawn per entry of ``self.chemical_param_list`` on a
        two-column grid, with a log-scaled y axis. Within a panel only targets
        having at least ``min_count`` rows in both label groups are shown.

        Args:
            min_count: Minimum number of rows required in each label group for
                a target to appear in a panel.
        """
        custom_palette = {
            self.WITHOUT_DRUG_LABEL: "royalblue",  # rows without a DrugBank name
            self.WITH_DRUG_LABEL: "orange",        # rows with a DrugBank name
        }

        params = self.chemical_param_list
        n_cols = 2
        # Ceiling division so that every parameter gets a panel.
        n_rows = max(1, -(-len(params) // n_cols))
        _, axes = plt.subplots(
            n_rows, n_cols, figsize=(20, 7.5 * n_rows), squeeze=False
        )
        axes = axes.flatten()

        for axis, col in zip(axes, params):
            plot_data = self._rows_for_boxplot(col, min_count)

            axis.set_title(f"{col} by Protein and DrugBank Drug Presence", fontsize=16)

            if plot_data.empty:
                # Say why the panel is blank instead of leaving bare axes.
                axis.text(
                    0.5,
                    0.5,
                    f"No target with at least {min_count} measurements in both groups",
                    ha="center",
                    va="center",
                    transform=axis.transAxes,
                )
                continue

            sns.boxplot(
                data=plot_data,
                x="target_name",
                y=col,
                hue="drugbank_drug_name_present",
                ax=axis,
                palette=custom_palette,
            )

            axis.set_yscale("log")
            axis.tick_params(axis='x', rotation=30)
            for tick in axis.get_xticklabels():
                tick.set_ha('right')

            axis.set_ylabel(f"{col} (log scale)", fontsize=10)
            axis.set_xlabel("")

            axis.legend(title="", fontsize=10, loc='upper right')
            axis.grid(axis="y")

        # Hide grid cells that have no parameter assigned to them.
        for axis in axes[len(params):]:
            axis.set_visible(False)

        plt.tight_layout()
        plt.show()


