In [None]:
# Import libraries
import json
import logging
import math
import networkx as nx
import numpy as np
import os
import pickle as pkl
import random
import re
import shutil
import statistics
import subprocess
import sys
import tempfile
import time
import warnings
import pandas as pd
import esm
#import mdtraj as md
from Bio import SeqIO
from Bio.PDB import PDBParser
from datetime import datetime
from IPython.display import display, clear_output
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.axes._axes import _log as matplotlib_axes_logger
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.stats import gmean, pearsonr
from scipy.optimize import curve_fit
import glob
from sklearn.decomposition import PCA
import umap
from sklearn.preprocessing import LabelEncoder

# Setting initial options
# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Turn off pandas' chained-assignment warning.
pd.options.mode.chained_assignment = None 
# Set the threshold of matplotlib's axes logger to INFO.
matplotlib_axes_logger.setLevel('INFO')

# Module-level dictionary; it is only initialised here (empty).
DESIGN_COUNT = {}

# GRID, DESIGN_FOLDER (and the other run-mode flags) are not defined in this
# cell -- presumably they are set before this cell is executed; verify.
if GRID:
    # Login name, used by check_running_jobs() for `qstat -u USERNAME`.
    USERNAME = os.getlogin()
#     FOLDER_HOME = f'{os.getcwd()}/{DESIGN_FOLDER}'
# if BLUEPEBBLE:
#     FOLDER_HOME = f'{os.getcwd()}/{DESIGN_FOLDER}'
# if BACKGROUND_JOB:
#     FOLDER_HOME = f'{os.getcwd()}/{DESIGN_FOLDER}'

# The design folder lives inside the current working directory, regardless of
# which run mode is active (the per-mode variants above are disabled).
FOLDER_HOME = f'{os.getcwd()}/{DESIGN_FOLDER}'

os.makedirs(FOLDER_HOME, exist_ok=True)
FOLDER_INPUT = f'{os.getcwd()}/Input'
# A missing input folder is only reported; execution continues.
if not os.path.isdir(FOLDER_INPUT): print("ERROR! Input folder missing!")
# The log file sits next to (not inside) the design folder: <FOLDER_HOME>.log
LOG_FILE = f'{FOLDER_HOME}.log'
ALL_SCORES_CSV = f'{FOLDER_HOME}/all_scores.csv'
VARIABLES_JSON  = f'{FOLDER_HOME}/variables.json'

# Configure logging file
log_format = '%(asctime)s - %(levelname)s - %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'

# Remove all handlers associated with the root logger.
# logging.basicConfig() does nothing if the root logger already has handlers,
# so they are cleared first to make re-running this cell take effect.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Basic configuration for logging to a file (DEBUG and above go to LOG_FILE)
logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG, format=log_format, datefmt=date_format)
#logging.basicConfig(filename=LOG_FILE, level=logging.INFO, format=log_format, datefmt=date_format)

# Create a StreamHandler for console output (INFO and above only)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))

# Add the console handler to the root logger
logging.getLogger().addHandler(console_handler)

# main functions - running

In [None]:
def controller(RESET=False, EXPLORE=False, PROMPT=True, UNBLOCK_ALL=False, 
               PRINT_VAR=True, PLOT_DATA=True, 
               BLUEPEBBLE=False, GRID=True):
    """Main AI.zymes function: runs the design loop until the last design folder exists.

    The loop stops as soon as the folder FOLDER_HOME/<MAX_DESIGNS> is present.
    In every cycle the number of running jobs is checked and the scores are
    refreshed. If the job limit MAX_JOBS is reached the loop waits; otherwise
    it first makes sure the parent designs are started and, once they are,
    picks an index by Boltzmann selection and starts a calculation for it.

    Parameters:
    - RESET, EXPLORE, PROMPT: forwarded to setup_aizymes.
    - UNBLOCK_ALL, RESET, PRINT_VAR, PLOT_DATA: forwarded to startup_controller.
    - BLUEPEBBLE, GRID: accepted but not referenced inside this function.
    """

    def stop_marker():
        # Folder whose existence ends the run; re-evaluated on every use.
        return os.path.join(FOLDER_HOME, str(MAX_DESIGNS))

    # Startup, will only be executed once in the beginning
    setup_aizymes(RESET, EXPLORE, PROMPT)

    # Check if startup is done; if so, this reads in all_scores_df
    all_scores_df = startup_controller(UNBLOCK_ALL,
                                       RESET,
                                       PRINT_VAR=PRINT_VAR,
                                       PLOT_DATA=PLOT_DATA)

    while True:

        if os.path.exists(stop_marker()):
            break

        # Compare the number of running jobs with the limit, then refresh scores
        queue_is_full = check_running_jobs() >= MAX_JOBS
        all_scores_df = update_scores(all_scores_df)

        if queue_is_full:
            # Nothing can be submitted right now, wait before polling again
            time.sleep(20)
        else:
            # Parent designs have priority over designs from selected indices
            parent_done, all_scores_df = start_parent_design(all_scores_df)

            if parent_done:
                selected_index = boltzmann_selection(all_scores_df)

                # Decide fate of the selected index (None means nothing to do)
                if selected_index is not None:
                    all_scores_df = start_calculation(all_scores_df, selected_index)

        time.sleep(1)

    all_scores_df = update_scores(all_scores_df)
    print(f"Stopped because {stop_marker()} exists.")

def check_running_jobs():
    """Return the number of currently running jobs for the active run mode.

    The run mode is taken from the module-level flags, checked in this order:
    - GRID:           lines of `qstat -u USERNAME` that contain SUBMIT_PREFIX
    - BLUEPEBBLE:     lines of `squeue --me` that contain SUBMIT_PREFIX
    - BACKGROUND_JOB: integer stored in FOLDER_HOME/n_running_jobs.dat
    - ABBIE_LOCAL:    always 0

    Returns:
    - int: number of running jobs, or None if none of the flags is set.
    """

    if GRID:
        queue_cmd = ["qstat", "-u", USERNAME]
    elif BLUEPEBBLE:
        queue_cmd = ["squeue", "--me"]
    else:
        queue_cmd = None

    if queue_cmd is not None:
        listing = subprocess.check_output(queue_cmd).decode("utf-8").split("\n")
        return sum(1 for line in listing if SUBMIT_PREFIX in line)

    if BACKGROUND_JOB:
        # The file handle must be bound to a name to be readable
        with open(f'{FOLDER_HOME}/n_running_jobs.dat', 'r') as f:
            return int(f.read())

    if ABBIE_LOCAL:
        return 0

    # No run mode selected (same result as before: None)
    return None

def update_potential(score_type:str, index:int, all_scores_df:pd.DataFrame):
    '''Creates a <score_type>_potential.dat file in FOLDER_HOME/<index> 
    If latest score comes from Rosetta Relax - then the score for variant <index>
    will be added to the <score_type>_potential.dat of the parent index 
    and the <score_type>_potential value of the dataframe for the parent 
    will be updated with the average of the parent and child scores. 
    Parameters:
    - score_type(str): Type of score to update, one of these options: total, interface, catalytic, efield
    - index (int): variant index to update
    - all_scores_df (pd.DataFrame): scores dataframe
    Returns: 
    - all_scores_df (pd.DataFrame): updated dataframe'''

    score = all_scores_df.at[index, f'{score_type}_score']
    score_taken_from = all_scores_df.at[index, 'score_taken_from']
    # parent_index can be the string "Parent" or a (possibly string-typed) number
    parent_index = all_scores_df.at[index, "parent_index"]

    # The file of the current index is always overwritten; only the parent file is appended to
    with open(f"{FOLDER_HOME}/{index}/{score_type}_potential.dat", "w") as f:
        f.write(str(score))

    all_scores_df.at[index, f'{score_type}_potential'] = score

    if score_taken_from == "Relax" and parent_index != "Parent":
        # Use an integer row label for the parent. A string label such as "3"
        # would make .at create a NEW row instead of updating row 3.
        parent_label = int(float(parent_index))
        parent_filename = f"{FOLDER_HOME}/{parent_label}/{score_type}_potential.dat"

        # Add the score of the new index to the parent potential file
        with open(parent_filename, "a") as f:
            f.write(f"\n{str(score)}")
        with open(parent_filename, "r") as f:
            # Skip blank lines (e.g. the leading newline of a freshly created file)
            potentials = [float(line) for line in f if line.strip()]

        all_scores_df.at[parent_label, f'{score_type}_potential'] = np.average(potentials)

    # Kept for safety: drop any row that has no index value
    all_scores_df = all_scores_df.dropna(subset=['index'])

    return all_scores_df
        
def update_scores(all_scores_df):
    """Read finished Rosetta score files and store the scores in all_scores_df.

    For every row that has not yet been scored from a relax run, the function
    looks for FOLDER_HOME/<index>/score_rosetta_relax.sc. If that file does not
    exist and the design method is "RosettaDesign", score_rosetta_design.sc is
    used instead (unless the row was already scored from it). Rows whose score
    file or structure file is not readable, or whose score file is not
    completely written yet, are skipped and tried again on the next call.

    For each scored row the total, interface, catalytic and efield scores, the
    potentials (via update_potential), the blocked flag, the catalytic
    residues, the sequence and the number of mutations are updated. The
    dataframe is saved at the end (and after every index divisible by 1000).

    Parameters:
    - all_scores_df (pd.DataFrame): scores dataframe
    Returns:
    - all_scores_df (pd.DataFrame): updated dataframe
    """

    # Wild-type sequence; read lazily, at most once per call
    reference_sequence = None

    for _, row in all_scores_df.iterrows():

        index = int(row['index'])

        # do NOT update score if score was taken from a relax file. Prevents repeated scoring!
        if row['score_taken_from'] == 'Relax': continue

        relax_score_file  = f"{FOLDER_HOME}/{index}/score_rosetta_relax.sc"
        design_score_file = f"{FOLDER_HOME}/{index}/score_rosetta_design.sc"

        # default pdb file path
        if row["design_method"] == "ProteinMPNN":
            pdb_path = f"{FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_{index}.pdb"
        else:
            pdb_path = f"{FOLDER_HOME}/{index}/{WT}_Rosetta_Design_{index}.pdb"

        # The relax score file is the default; fall back to the design score
        # file only for RosettaDesign runs without a relax score file
        score_file_path, score_source = relax_score_file, 'Relax'
        if not os.path.exists(relax_score_file) and row['design_method'] == "RosettaDesign":
            # do NOT update score if score was taken from a design file. Prevents repeated scoring!
            if row['score_taken_from'] == 'Design': continue
            score_file_path, score_source = design_score_file, 'Design'

        # do NOT update score if score_file_path does not exist. Usually means job is not done.
        if not os.access(score_file_path, os.R_OK) or not os.access(pdb_path, os.R_OK): continue

        with open(score_file_path, "r") as f: score_lines = f.readlines()

        # if the timing is bad, the score file is not fully written: retry on the next call
        if len(score_lines) < 3: continue
        headers = score_lines[1].split()
        values  = score_lines[2].split()
        if len(values) < len(headers): continue

        total_score     = None   # reset per row so a value of a previous row is never reused
        catalytic_score = 0.0
        interface_score = 0.0
        for header, value in zip(headers, values):
            if header == 'total_score':
                total_score = float(value)
            elif header == 'interface_delta_X':
                interface_score += float(value)
            elif header in ('if_X_angle_constraint',
                            'if_X_atom_pair_constraint',
                            'if_X_dihedral_constraint'):
                interface_score -= float(value)
            elif header in ('atom_pair_constraint',
                            'angle_constraint',
                            'dihedral_constraint'):
                # The three constraint terms are summed as found in the score file
                catalytic_score += float(value)

        if total_score is None:
            logging.debug(f"No total_score found in {score_file_path}, skipping index {index} for now.")
            continue

        # Only now mark where the score comes from. Marking the row before the
        # file is known to be complete would block it from ever being scored.
        all_scores_df.at[index, 'score_taken_from'] = score_source

        efield_score, index_efields_dict = calc_efields_score(pdb_path)

        update_efieldsdf(index, index_efields_dict)

        # Update scores
        all_scores_df.at[index, 'total_score']     = total_score
        all_scores_df.at[index, 'interface_score'] = interface_score
        all_scores_df.at[index, 'catalytic_score'] = catalytic_score
        all_scores_df.at[index, 'efield_score']    = efield_score

        # This is just for book keeping. AIzymes will always use the most up_to_date scores saved above
        prefix = score_source.lower()   # "relax" or "design"
        all_scores_df.at[index, f'{prefix}_total_score']     = total_score
        all_scores_df.at[index, f'{prefix}_interface_score'] = interface_score
        all_scores_df.at[index, f'{prefix}_catalytic_score'] = catalytic_score
        all_scores_df.at[index, f'{prefix}_efield_score']    = efield_score

        for score_type in ['total', 'interface', 'catalytic', 'efield']:
            all_scores_df = update_potential(score_type = score_type,
                                             index = index,
                                             all_scores_df = all_scores_df)

        logging.info(f"Updated scores and potentials of index {int(index)}.")
        if all_scores_df.at[index, 'score_taken_from'] == 'Relax' and all_scores_df.at[index, 'parent_index'] != "Parent":
            logging.info(f"Adjusted potentials of index parent {all_scores_df.at[index, 'parent_index']} (parent of index {int(index)}).")

        # unblock index if relaxed file exists
        if all_scores_df.at[index, "blocked"] == True:
            if f"{WT}_Rosetta_Relax_{index}.pdb" in os.listdir(os.path.join(FOLDER_HOME, str(index))):
                all_scores_df.at[index, "blocked"] = False
                logging.debug(f"Unblocked index {int(index)}.")

        # Update catalytic residues
        all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, index, pdb_path)

        # Update sequence and mutations
        if reference_sequence is None:
            reference_sequence = extract_sequence_from_pdb(f"{FOLDER_INPUT}/{WT}.pdb")
        current_sequence = extract_sequence_from_pdb(pdb_path)
        mutations = sum(1 for a, b in zip(current_sequence, reference_sequence) if a != b)
        all_scores_df['sequence'] = all_scores_df['sequence'].astype('object')
        all_scores_df.at[index, 'sequence']  = current_sequence
        all_scores_df.at[index, 'mutations'] = int(mutations)

        if index % 1000 == 0:
            save_all_scores_df(all_scores_df)

    save_all_scores_df(all_scores_df)

    return all_scores_df

def normalize_scores(unblocked_all_scores_df, print_norm=False, norm_all=False, extension="score"):
    """Negate and normalize the catalytic, total, interface and efield columns.

    The columns read are "<type>_<extension>". Every column is negated before
    normalization; the efield column is multiplied by -1 first, so its sign is
    kept. With norm_all the values are scaled with the limits in NORM and
    clipped at 0, otherwise they are standardized with their own mean and
    standard deviation. Columns with fewer than two values are returned as [1].

    Returns:
    - catalytic, total, interface, efield: normalized values per score type
    - combined: element-wise mean of total, interface and efield
    """

    def neg_norm_array(array, score_type):

        # do not normalize if array does not contain more than 1 value
        if not len(array) > 1:
            return [1]

        flipped = -array

        if norm_all:
            if print_norm:
                print(score_type,NORM[score_type],end=" ")
            lower = NORM[score_type][0]
            upper = NORM[score_type][1]
            flipped = (flipped-lower)/(upper-lower)
            flipped[flipped < 0] = 0.0
            if np.any(flipped > 1.0): print("\nNORMALIZATION ERROR!",score_type,"has a value >1!") 
            return flipped

        if print_norm:
            print(score_type,[np.mean(flipped),np.std(flipped)],end=" ")

        # Normalize using mean and standard deviation
        if np.std(flipped) == 0:
            # All values are the same: everything becomes 0.0, NaN stays NaN
            return np.where(np.isnan(flipped), flipped, 0.0)
        return (flipped - np.mean(flipped)) / np.std(flipped)

    normalized = {}
    for score_name in ("catalytic", "total", "interface"):
        column = f"{score_name}_{extension}"
        normalized[score_name] = neg_norm_array(unblocked_all_scores_df[column], column)

    # efield is pre-multiplied by -1, which cancels the negation in neg_norm_array
    efield_column = f"efield_{extension}"
    efield_scores = neg_norm_array(-1*unblocked_all_scores_df[efield_column], efield_column)

    catalytic_scores = normalized["catalytic"]
    total_scores     = normalized["total"]
    interface_scores = normalized["interface"]

    if len(total_scores) == 0:
        combined_scores = []
    else:
        stacked         = np.stack((total_scores, interface_scores, efield_scores))
        combined_scores = np.mean(stacked, axis=0)

    if print_norm and combined_scores.size > 0:
        print("HIGHSCORE:","{:.2f}".format(np.amax(combined_scores)),end=" ")
        print("Designs:",len(combined_scores),end=" ")
        PARENTS = [i for i in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if i[-4:] == ".pdb"]
        print("Parents:",len(PARENTS))

    return catalytic_scores, total_scores, interface_scores, efield_scores, combined_scores
        
def boltzmann_selection(all_scores_df):
    """Select the index that the next calculation should be started from.

    The scores are re-read from ALL_SCORES_CSV (the passed dataframe is not
    used). Blocked rows and rows without total_score are removed, and if more
    than 10 rows remain, rows with a catalytic score of at least
    mean + 1 standard deviation are removed as well.

    If a relaxed structure (not from ProteinMPNN) has never been used as a
    parent, it is returned directly. Otherwise an index is drawn at random with
    Boltzmann weights computed from the normalized combined potentials.

    KBT_BOLTZMANN is either a single number, or a pair [kbt_start, decay] that
    ramps the temperature down with the number of designs (lower limit 0.05).

    Returns:
    - int: selected index, or None if there is no row to select from
    Raises:
    - ValueError: if KBT_BOLTZMANN is neither a number nor a list of two values
    """

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    parent_indices = set(all_scores_df['parent_index'].astype(str).values)
    all_scores_df = all_scores_df[all_scores_df["blocked"] == False] # Remove blocked indices
    all_scores_df = all_scores_df.dropna(subset=['total_score'])     # Remove indices without score (design running)

    # Drop catalytic scores > mean + 1 std
    mean_catalytic_score = all_scores_df['catalytic_score'].mean()
    std_catalytic_score = all_scores_df['catalytic_score'].std()
    if len(all_scores_df) > 10:
        all_scores_df = all_scores_df[all_scores_df['catalytic_score'] < mean_catalytic_score + std_catalytic_score]

    # If there are structures that ran through RosettaRelax but have never been used for design,
    # run 1 design (exclude PMPNN as it is always relaxed)
    relaxed = all_scores_df[(all_scores_df['score_taken_from'] == 'Relax') &
                            (all_scores_df['design_method'] != 'ProteinMPNN')]
    never_designed = [label for label in (str(i) for i in relaxed.index) if label not in parent_indices]

    if never_designed:
        selected_index = never_designed[0]
        logging.info(f"{selected_index} selected because its relaxed but nothing was designed from it.")
        return int(selected_index)

    # Do Boltzmann Selection if some scores exist
    _, _, _, _, combined_potentials = normalize_scores(all_scores_df, norm_all=False,
                                                       extension="potential", print_norm=False)

    if len(combined_potentials) == 0:
        return 0

    # Nothing to choose from
    if len(all_scores_df["index"]) == 0:
        return None

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        # Ramp down kbT_boltzmann over time (i.e., with increasing indices)
        # datapoints = length of all_scores_df - number of parents generated
        num_pdb_files = len([file for file in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if file.endswith('.pdb')])
        datapoints = max(all_scores_df['index'].max() + 1 - num_pdb_files*N_PARENT_JOBS, 0)
        kbt_boltzmann = max(KBT_BOLTZMANN[0] * np.exp(-KBT_BOLTZMANN[1]*datapoints), 0.05)
    else:
        # Fail with a clear error instead of an unbound kbt_boltzmann below
        logging.error(f"KBT_BOLTZMANN must either be a single value or list of two values.")
        logging.error(f"KBT_BOLTZMANN is {KBT_BOLTZMANN}")
        raise ValueError(f"KBT_BOLTZMANN must either be a single value or list of two values, got {KBT_BOLTZMANN}")

    boltzmann_factors = np.exp(combined_potentials / kbt_boltzmann)
    probabilities = boltzmann_factors / sum(boltzmann_factors)

    return int(np.random.choice(all_scores_df["index"].to_numpy(), p=probabilities))

def start_parent_design(all_scores_df):
    """Start one RosettaDesign run from a parent structure if parent runs are still missing.

    N_PARENT_JOBS designs are started for every .pdb file found in
    FOLDER_HOME/FOLDER_PARENT. As long as all_scores_df has fewer rows than
    that, one more parent design is created and submitted per call.

    Returns:
    - parent_done (bool): True once enough parent designs exist
    - all_scores_df (pd.DataFrame): updated (and saved) dataframe
    """

    parent_pdbs = [name for name in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if name.endswith(".pdb")]
    n_existing  = len(all_scores_df)
    parent_done = not (n_existing < N_PARENT_JOBS * len(parent_pdbs))

    if not parent_done:

        # Every parent structure gets N_PARENT_JOBS consecutive indices
        parent_file = parent_pdbs[int(n_existing / N_PARENT_JOBS)]
        parent      = parent_file[:-4]

        new_index, all_scores_df = create_new_index(parent_index="Parent", all_scores_df=all_scores_df)

        # Make sure the columns can hold strings before writing into them
        for column, value in (('design_method', "RosettaDesign"), ('luca', parent)):
            all_scores_df[column] = all_scores_df[column].astype('object')
            all_scores_df.at[new_index, column] = value

        # Add cat res to new entry
        all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index,
                                                        f'{FOLDER_HOME}/{FOLDER_PARENT}/{parent_file}',
                                                        from_parent_struct=True)

        run_RosettaDesign(parent_index=parent, new_index=new_index,
                          all_scores_df=all_scores_df, parent_done=parent_done)

    save_all_scores_df(all_scores_df)
    return parent_done, all_scores_df

# Decides what to do with selected index
def start_calculation(all_scores_df, selected_index):
    """Decide what to do with the selected index and start the matching job.

    - A relaxed structure exists: create a child index and run ProteinMPNN,
      LigandMPNN or RosettaDesign on it, chosen at random according to
      ProteinMPNN_PROB and LMPNN_PROB.
    - No relaxed structure and index is blocked: only log an error.
    - No relaxed structure and index is not blocked: submit
      ESMfold + Rosetta Relax and block the index if the submission worked.

    Returns:
    - all_scores_df (pd.DataFrame): updated (and saved) dataframe
    """

    logging.debug(f"Starting new calculation for index {selected_index}.")

    blocked = bool(all_scores_df.at[selected_index, "blocked"] == True)

    index_files = os.listdir(os.path.join(FOLDER_HOME, str(selected_index)))
    relaxed = f"{WT}_Rosetta_Relax_{selected_index}.pdb" in index_files

    if relaxed:

        # ESMfold_Rosetta_Relax is done, create a new index
        new_index, all_scores_df = create_new_index(parent_index=selected_index, all_scores_df=all_scores_df)

        #####
        # Here, we can add an AI to decide on the next steps
        #####

        # Pick the design method for new_index at random
        draw = random.random()
        if draw < ProteinMPNN_PROB:
            method, runner = "ProteinMPNN", run_ProteinMPNN
        elif draw < ProteinMPNN_PROB + LMPNN_PROB:
            method, runner = "LigandMPNN", run_LigandMPNN
        else:
            method, runner = "RosettaDesign", run_RosettaDesign

        all_scores_df.at[new_index, 'design_method'] = method
        runner(parent_index=selected_index, new_index=new_index, all_scores_df=all_scores_df)

    elif blocked:
        # Blocked --> ESMfold_Rosetta_Relax is still running, do not do anything. This shouldn't happen!
        logging.error(f"Index {selected_index} is being worked on. Skipping index.")
        logging.error(f"Note: This should not happen! Check blocking and Boltzman selection.")

    else:
        # Not blocked --> submit ESMfold_Rosetta_Relax and block index
        logging.info(f"Index {selected_index} has no relaxed structure, starting ESMfold_Rosetta_Relax.")
        submitted_status = run_ESMfold_RosettaRelax(index=selected_index, all_scores_df=all_scores_df,
                                                    OnlyRelax=True, EXPLORE=EXPLORE)

        # Check if submission executed correctly before blocking
        if submitted_status:
            all_scores_df.at[selected_index, "blocked"] = True

    save_all_scores_df(all_scores_df)

    return all_scores_df

def create_new_index(parent_index, all_scores_df):
    """Append a row for a new design to all_scores_df and create its folder.

    The new index is the current number of rows. The row stores the parent
    index, the Boltzmann temperature at creation time, the generation
    (0 for parent designs, parent generation + 1 otherwise), the 'luca' value
    inherited from the parent ("x" for parent designs) and blocked = False.
    The dataframe is saved and FOLDER_HOME/<new_index>/scripts is created.

    Parameters:
    - parent_index: index of the parent design, or the string "Parent"
    - all_scores_df (pd.DataFrame): scores dataframe
    Returns:
    - new_index (int): index of the new row
    - all_scores_df (pd.DataFrame): dataframe including the new row
    Raises:
    - ValueError: if KBT_BOLTZMANN is neither a number nor a list of two values
    """

    # The next index equals the current number of rows
    new_index = len(all_scores_df)

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        # Ramp down the temperature with the number of designs made after the parent designs
        num_pdb_files = len([file for file in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if file.endswith('.pdb')])
        highest_index = all_scores_df['index'].max()
        # For an empty table max() is NaN, which would make kbt_boltzmann NaN
        n_indices = 0 if pd.isna(highest_index) else highest_index + 1
        datapoints = max(n_indices - num_pdb_files*N_PARENT_JOBS, 0)
        kbt_boltzmann = max(KBT_BOLTZMANN[0] * np.exp(-KBT_BOLTZMANN[1]*datapoints), 0.05)
    else:
        # Fail with a clear error instead of an unbound kbt_boltzmann below
        logging.error(f"KBT_BOLTZMANN must either be a single value or list of two values.")
        logging.error(f"KBT_BOLTZMANN is {KBT_BOLTZMANN}")
        raise ValueError(f"KBT_BOLTZMANN must either be a single value or list of two values, got {KBT_BOLTZMANN}")

    if parent_index == 'Parent':
        generation = 0
        luca = "x"
    else:
        generation = all_scores_df['generation'][int(parent_index)]+1
        luca       = all_scores_df['luca'][int(parent_index)]

    # Build the new row as a one-row dataframe and append it
    new_index_df = pd.DataFrame({'index': int(new_index), 
                                 'parent_index': parent_index,
                                 'kbt_boltzmann': kbt_boltzmann,
                                 'generation': generation,
                                 'luca': luca,
                                 'blocked': False,
                                 }, index = [0])
    all_scores_df = pd.concat([all_scores_df, new_index_df], ignore_index=True)

    save_all_scores_df(all_scores_df)

    # Create the folders for the new index
    os.makedirs(f"{FOLDER_HOME}/{new_index}/scripts", exist_ok=True)

    logging.debug(f"Child index {new_index} created for {parent_index}.")

    return new_index, all_scores_df

# main functions - design

In [None]:
def run_ESMfold_RosettaRelax(index, all_scores_df, OnlyRelax=False, ProteinMPNN=False, PreMatchRelax=False,
                             ProteinMPNN_parent_index=0, cmd="", bash=False, EXPLORE=False):
    """Write and submit the ESMfold + Rosetta Relax job for one index.

    The function writes the cpptraj input files, the Rosetta_Relax XML protocol
    and a shell script into FOLDER_HOME/<index>, then hands the script to
    submit_job. The shell script runs ESMfold on the sequence, aligns the
    prediction onto the input structure, re-adds the ligand (not for
    PreMatchRelax) and runs Rosetta Relax.

    Exactly one of the three modes is expected to be set:
    - OnlyRelax:     input structure is <WT>_Rosetta_Design_<index>.pdb of this index
    - PreMatchRelax: input structure is FOLDER_INPUT/<WT>.pdb; no ligand and no
                     constraints are added and the outputs carry the suffix "_APO"
    - ProteinMPNN:   input structure is the relaxed structure of
                     ProteinMPNN_parent_index; the sequence is read from the
                     ProteinMPNN or LigandMPNN .seq file of this index

    Parameters:
    - index: index of the design to work on
    - all_scores_df (pd.DataFrame): scores dataframe, used to build the REMARK header
    - ProteinMPNN_parent_index: parent index, only used in ProteinMPNN mode
    - cmd (str): shell commands that the generated commands are appended to
    - bash: forwarded to submit_job
    - EXPLORE (bool): faster settings for testing (no -ex1 -ex2, 1 relax repeat)
    Returns:
    - bool: True if the job was written and submitted, False if a required
      input was missing or no mode was selected
    """

    # Giving the ESMfold algorithm the needed inputs
    output_file = f'{FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb'

    protein_mpnn_seq_file = f'{FOLDER_HOME}/{index}/ProteinMPNN/{WT}_{index}.seq'
    ligand_mpnn_seq_file = f'{FOLDER_HOME}/{index}/LigandMPNN/{WT}_{index}.seq'
    esmfold_seq_file = f'{FOLDER_HOME}/{index}/ESMfold/{WT}_{index}.seq'

    if ProteinMPNN:
        if os.path.exists(protein_mpnn_seq_file):
            sequence_file = protein_mpnn_seq_file
        elif os.path.exists(ligand_mpnn_seq_file):
            sequence_file = ligand_mpnn_seq_file
        else:
            logging.error(f"Neither {protein_mpnn_seq_file} nor {ligand_mpnn_seq_file} exists, but ProteinMPNN was True.")
            return False
    else:
        sequence_file = esmfold_seq_file

    # Make directories
    os.makedirs(f"{FOLDER_HOME}/{index}/ESMfold", exist_ok=True)
    os.makedirs(f"{FOLDER_HOME}/{index}/scripts", exist_ok=True)

    # Options for EXPLORE, accelerated script for testing
    ex = "-ex1 -ex2"
    if EXPLORE: ex = ""

    # Get Name of parent PDB
    if OnlyRelax: 
        PDBFile = f"{FOLDER_HOME}/{index}/{WT}_Rosetta_Design_{index}.pdb"
    elif PreMatchRelax: 
        PDBFile = f"{FOLDER_INPUT}/{WT}.pdb"
    elif ProteinMPNN:
        PDBFile = f"{FOLDER_HOME}/{ProteinMPNN_parent_index}/{WT}_Rosetta_Relax_{ProteinMPNN_parent_index}.pdb"
    else:
        logging.error("I don't know what you want me to do")
        return False
    if not os.path.isfile(PDBFile):
        logging.error(f"{PDBFile} not present!")
        return False

    # Make sequence file
    if OnlyRelax or PreMatchRelax: 
        seq = extract_sequence_from_pdb(PDBFile)
        with open(f"{FOLDER_HOME}/{index}/ESMfold/{WT}_{index}.seq","w") as f: f.write(seq)

    # Get the pdb file from the last step and strip away ligand and hydrogens 
    cpptraj = f'''parm    {PDBFile}
trajin  {PDBFile}
strip   :{LIGAND}
strip   !@C,N,O,CA
trajout {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_Apo_{index}.in','w') as f: f.write(cpptraj)

    # Get the pdb file from the last step and strip away everything except the ligand
    cpptraj = f'''parm    {PDBFile}
trajin  {PDBFile}
strip   !:{LIGAND}
trajout {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Lig_{index}.pdb
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_Lig_{index}.in','w') as f: f.write(cpptraj)

    # Get the ESMfold pdb file and strip away all hydrogens
    cpptraj = f'''parm    {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb
trajin  {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb
strip   !@C,N,O,CA
trajout {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_no_hydrogens_{index}.pdb
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_no_hydrogens_{index}.in','w') as f: f.write(cpptraj)

    # Align substrate and ESM prediction of scaffold without hydrogens
    cpptraj = f'''parm    {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_no_hydrogens_{index}.pdb
reference {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb [apo]
trajin    {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_no_hydrogens_{index}.pdb
rmsd      @CA ref [apo]
trajout   {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb noter
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_aligned_{index}.in','w') as f: f.write(cpptraj) 

    # Rosetta binary extension depends on the system the job runs on
    if GRID:           extension = "linuxgccrelease"
    if BLUEPEBBLE:     extension = "serialization.linuxgccrelease"
    if BACKGROUND_JOB: extension = "serialization.linuxgccrelease"
    if ABBIE_LOCAL:    
        extension = "linuxgccrelease"
        bash_args = "OMP_NUM_THREADS=1"
    else:
        bash_args = ""

    # NOTE: a backslash at the end of a line inside these (non-raw) strings is a
    # Python line continuation, i.e. the two lines are joined in the script.
    # The backslash needed by sed is therefore written as "\\/".
    cmd += f"""
    
{bash_args} python {FOLDER_HOME}/ESMfold.py {output_file} {sequence_file}

sed -i '/PARENT N\\/A/d' {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Apo_{index}.in           &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Apo_{index}.out
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Lig_{index}.in           &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Lig_{index}.out
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_no_hydrogens_{index}.in  &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_no_hydrogens_{index}.out
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_aligned_{index}.in       &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_aligned_{index}.out

# Assemble the final protein
sed -i '/END/d' {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb
# Return HETATM to ligand output and remove TER
sed -i -e 's/^ATOM  /HETATM/' -e '/^TER/d' {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Lig_{index}.pdb
"""

    if PreMatchRelax:
        extension_relax = "_APO"
        ## No ligand necessary so just use the aligned pdb from ESMfold
        cmd += f"""
cp {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb \
   {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}{extension_relax}.pdb
"""  

    else:
        extension_relax = ""
        # The REMARK header is written now; the job appends protein and ligand to it
        remark = generate_remark_from_all_scores_df(all_scores_df, index)
        with open(f'{FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb', 'w') as f: f.write(remark+"\n")
        cmd += f"""
cat {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb >> {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb
cat {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Lig_{index}.pdb     >> {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb
sed -i '/TER/d' {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb
"""

    cmd += f"""
# Run Rosetta Relax
{ROSETTA_PATH}/bin/rosetta_scripts.{extension} \
                -s                                        {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}{extension_relax}.pdb \
                -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
                -parser:protocol                          {FOLDER_HOME}/{index}/scripts/Rosetta_Relax_{index}.xml \
                -out:file:scorefile                       {FOLDER_HOME}/{index}/score_rosetta_relax.sc \
                -nstruct                                  1 \
                -ignore_zero_occupancy                    false \
                -corrections::beta_nov16                  true \
                -run:preserve_header                      true \
                -overwrite {ex}

# Rename the output file
mv {WT}_ESMfold_{index}{extension_relax}_0001.pdb {WT}_Rosetta_Relax_{index}{extension_relax}.pdb
sed -i '/        H  /d' {WT}_Rosetta_Relax_{index}{extension_relax}.pdb
"""

    if PreMatchRelax:

        cmd += f"""
# Align relaxed ESM prediction of scaffold without hydrogens
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.in           &> \
           {FOLDER_HOME}/{index}/ESMfold/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.out
sed -i '/END/d' {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.pdb
"""  

        cpptraj = f'''parm    {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_{index}{extension_relax}.pdb [protein]
parm      {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb [reference]
reference {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb parm [reference] [apo]
trajin    {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_{index}{extension_relax}.pdb parm [protein]
rmsd      @CA ref [apo]
trajout   {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.pdb noter
'''
        with open(f'{FOLDER_HOME}/{index}/ESMfold/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.in','w') as f: 
            f.write(cpptraj) 

    # Create the Rosetta_Relax.xml file
    repeats = "3"
    if EXPLORE: repeats = "1"
    Rosetta_Relax_xml = f"""
<ROSETTASCRIPTS>

    <SCOREFXNS>
    
        <ScoreFunction name      = "score"                   weights = "beta_nov16" >
            <Reweight scoretype  = "atom_pair_constraint"    weight  = "4" />
            <Reweight scoretype  = "angle_constraint"        weight  = "2" />
            <Reweight scoretype  = "dihedral_constraint"     weight  = "1" />
        </ScoreFunction> 
        
        <ScoreFunction name      = "score_final"             weights = "beta_nov16" >
            <Reweight scoretype  = "atom_pair_constraint"    weight  = "4" />
            <Reweight scoretype  = "angle_constraint"        weight  = "2" />
            <Reweight scoretype  = "dihedral_constraint"     weight  = "1" />
        </ScoreFunction>
        
    </SCOREFXNS>
       
    <MOVERS>
                                  
        <FastRelax  name="mv_relax" disable_design="false" repeats="{repeats}" /> 
"""
    # The constraint mover is only defined when a ligand is present
    if not PreMatchRelax:
        Rosetta_Relax_xml += f"""
        <AddOrRemoveMatchCsts     name="mv_add_cst" 
                                  cst_instruction="add_new" 
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />

"""
    Rosetta_Relax_xml += f"""

        <InterfaceScoreCalculator   name                   = "mv_inter" 
                                    chains                 = "X" 
                                    scorefxn               = "score_final" />
    </MOVERS>
    
    <PROTOCOLS>  

        <Add mover_name="mv_relax" />
"""
    # Constraints and interface scoring are only run when a ligand is present
    if not PreMatchRelax:
        Rosetta_Relax_xml += f"""                                  
        <Add mover_name="mv_add_cst" />       
        <Add mover_name="mv_inter" />
"""
    Rosetta_Relax_xml += f"""
    </PROTOCOLS>
    
</ROSETTASCRIPTS>
"""
    # Write the Rosetta_Relax.xml to a file
    with open(f'{FOLDER_HOME}/{index}/scripts/Rosetta_Relax_{index}.xml', 'w') as f:
        f.write(Rosetta_Relax_xml)

    if OnlyRelax or PreMatchRelax: 
        with open(f'{FOLDER_HOME}/{index}/scripts/ESMfold_Rosetta_Relax_{index}.sh', 'w') as file:
            file.write(cmd)
        logging.info(f"Run ESMfold & Rosetta_Relax for index {index}.")
        submit_job(index=index, job="ESMfold_Rosetta_Relax", bash=bash)

    if ProteinMPNN:
        with open(f'{FOLDER_HOME}/{index}/scripts/ProteinMPNN_ESMfold_Rosetta_Relax_{index}.sh', 'w') as file:
            file.write(cmd)
        logging.info(f"Run Ligand/ProteinMPNN for index {index} based on index {ProteinMPNN_parent_index}.")
        submit_job(index=index, job="ProteinMPNN_ESMfold_Rosetta_Relax", bash=bash)

    return True
        
def run_RosettaDesign(parent_index, new_index, all_scores_df, parent_done=True):
    """
    Write the shell command and the RosettaScripts XML for a FastDesign run and submit the job.

    Parameters:
    - parent_index: Index of the parent variant. If parent_done is False it is instead used as
      the base name of a PDB file in FOLDER_HOME/FOLDER_PARENT.
    - new_index: Index of the new variant; scripts and the score file go to FOLDER_HOME/new_index.
    - all_scores_df (DataFrame): Score table. After save_cat_res_into_all_scores_df has been
      applied, its 'cat_resi' entry for new_index is read as a ';'-separated string.
    - parent_done (bool): True to start from the parent's relaxed structure in
      FOLDER_HOME/parent_index, False to start from a structure in FOLDER_PARENT.
    - --------------------------------GLOBAL variables used ----------------------------
    - FOLDER_HOME, FOLDER_INPUT, FOLDER_PARENT, ROSETTA_PATH, WT, LIGAND, DESIGN, REPACK,
      CST_WEIGHT, EXPLORE, GRID, BLUEPEBBLE, BACKGROUND_JOB, ABBIE_LOCAL

    Returns:
    True once the job has been handed to submit_job.

    Raises:
    RuntimeError: If none of GRID, BLUEPEBBLE, BACKGROUND_JOB or ABBIE_LOCAL is set, because
    the name of the Rosetta executable cannot be determined.
    """

    # EXPLORE is the accelerated test mode: no extra rotamers and a single design repeat
    ex = "" if EXPLORE else "-ex1 -ex2"
    repeats = "1" if EXPLORE else "3"

    # Pick the executable suffix; when several flags are set the last one listed wins
    extension = None
    if GRID:           extension = "linuxgccrelease"
    if BLUEPEBBLE:     extension = "serialization.linuxgccrelease"
    if BACKGROUND_JOB: extension = "serialization.linuxgccrelease"
    if ABBIE_LOCAL:    extension = "linuxgccrelease"
    if extension is None:
        raise RuntimeError("Cannot select the Rosetta executable: none of GRID, BLUEPEBBLE, "
                           "BACKGROUND_JOB or ABBIE_LOCAL is set.")

    if parent_done:
        PDB_input  = f'{FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb'
        PDB_output = f'{WT}_Rosetta_Relax_{parent_index}_0001.pdb'
    else:
        PDB_input  = f'{FOLDER_HOME}/{FOLDER_PARENT}/{parent_index}.pdb'
        PDB_output = f'{parent_index}_0001.pdb'
        
    all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index, PDB_input, from_parent_struct=True)
    
    # The trailing backslashes are Python line continuations inside the string literal,
    # so the Rosetta call ends up on a single line of the shell script
    cmd = f"""{ROSETTA_PATH}/bin/rosetta_scripts.{extension}\
    -s                                        {PDB_input} \
    -in:file:native                           {PDB_input} \
    -run:preserve_header                      true \
    -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
    -parser:protocol                          {FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.xml \
    -out:file:scorefile                       {FOLDER_HOME}/{new_index}/score_rosetta_design.sc \
    -nstruct                                  1  \
    -ignore_zero_occupancy                    false  \
    -corrections::beta_nov16                  true \
    -overwrite {ex}
    
mv {PDB_output} {WT}_Rosetta_Design_{new_index}.pdb 
"""
    # Write the shell command to a file
    with open(f'{FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.sh', 'w') as file:
        file.write(cmd)
                
    # Create XML script for Rosetta Design  
    Rosetta_Design_xml = f"""
<ROSETTASCRIPTS>

    <SCOREFXNS>

        <ScoreFunction            name="score"                           weights="beta_nov16" >  
            <Reweight             scoretype="atom_pair_constraint"       weight="4" />
            <Reweight             scoretype="angle_constraint"           weight="2" />    
            <Reweight             scoretype="dihedral_constraint"        weight="1" />        
            <Reweight             scoretype="res_type_constraint"        weight="1" />              
        </ScoreFunction>
       
        <ScoreFunction            name="score_unconst"                   weights="beta_nov16" >        
            <Reweight             scoretype="atom_pair_constraint"       weight="0" />
            <Reweight             scoretype="dihedral_constraint"        weight="0" />
            <Reweight             scoretype="angle_constraint"           weight="0" />              
        </ScoreFunction>

        <ScoreFunction            name="score_final"                     weights="beta_nov16" >    
            <Reweight             scoretype="atom_pair_constraint"       weight="4" />
            <Reweight             scoretype="angle_constraint"           weight="2" />    
            <Reweight             scoretype="dihedral_constraint"        weight="1" />               
        </ScoreFunction>
   
   </SCOREFXNS>
   
    <RESIDUE_SELECTORS>
   
        <Index                    name="sel_design"
                                  resnums="{DESIGN}" />

        <Index                    name="sel_repack"
                                  resnums="{REPACK}" />
"""
    
    # Add one selector per catalytic residue position (via all_scores_df['cat_resi'])
    cat_resis = all_scores_df.at[new_index, 'cat_resi'].split(';')
    for idx, cat_resi in enumerate(cat_resis): 
        
        Rosetta_Design_xml += f"""
        <Index                    name="sel_cat_{idx}"
                                  resnums="{int(cat_resi)}" />
"""
        
    Rosetta_Design_xml += f"""
        <Or                       name="sel_desrep"
                                  selectors="sel_design,sel_repack" />

        <Not                      name="sel_nothing"
                                  selector="sel_desrep" />
    </RESIDUE_SELECTORS>
   
    <TASKOPERATIONS>
   
        <OperateOnResidueSubset   name="tsk_design"                      selector="sel_design" >
                                  <RestrictAbsentCanonicalAASRLT         aas="GPAVLIMFYWHKRQNEDST" />
        </OperateOnResidueSubset>
"""
    
    # Add residue identity constraints from constraint file
    with open(f'{FOLDER_HOME}/cst.dat', 'r') as f: cat_resns = f.read()
    cat_resns = cat_resns.split(";")

    # Task operation tsk_cat_<i> refers to selector sel_cat_<i>, so both lists should line up
    if len(cat_resis) != len(cat_resns):
        logging.warning(f"Index {new_index}: {len(cat_resis)} catalytic position(s) in cat_resi "
                        f"but {len(cat_resns)} residue identity entries in cst.dat.")
    
    for idx, cat_resn in enumerate(cat_resns): 
        Rosetta_Design_xml += f"""
        <OperateOnResidueSubset   name="tsk_cat_{idx}"                   selector="sel_cat_{idx}" >
                                  <RestrictAbsentCanonicalAASRLT         aas="{cat_resn}" />
        </OperateOnResidueSubset>
"""
    
    # Comma-separated names of all catalytic task operations, appended to every mover below
    tsk_cat = ",".join(f"tsk_cat_{idx}" for idx in range(len(cat_resns)))
        
    Rosetta_Design_xml += f"""
       
        <OperateOnResidueSubset   name="tsk_repack"                      selector="sel_repack" >
                                  <RestrictToRepackingRLT />
        </OperateOnResidueSubset>
       
        <OperateOnResidueSubset   name="tsk_nothing"                     selector="sel_nothing" >
                                  <PreventRepackingRLT />
        </OperateOnResidueSubset>
       
    </TASKOPERATIONS>

    <FILTERS>
   
        <HbondsToResidue          name="flt_hbonds"
                                  scorefxn="score"
                                  partners="1"
                                  residue="1X"
                                  backbone="true"
                                  sidechain="true"
                                  from_other_chains="true"
                                  from_same_chain="false"
                                  confidence="0" />
    </FILTERS>
   
    <MOVERS>
       
        <FavorSequenceProfile     name="mv_native"
                                  weight="{CST_WEIGHT}"
                                  use_native="true"
                                  matrix="IDENTITY"
                                  scorefxns="score" />  
                               
        <AddOrRemoveMatchCsts     name="mv_add_cst"
                                  cst_instruction="add_new"
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />

        <FastDesign               name                   = "mv_design"
                                  disable_design         = "false"
                                  task_operations        = "tsk_design,tsk_repack,tsk_nothing,{tsk_cat}"
                                  repeats                = "{repeats}"
                                  ramp_down_constraints  = "false"
                                  scorefxn               = "score" />
        
        <FastDesign               name                   = "mv_design_no_native"
                                  disable_design         = "false"
                                  task_operations        = "tsk_design,tsk_repack,tsk_nothing,{tsk_cat}"
                                  repeats                = "1"
                                  ramp_down_constraints  = "false"
                                  scorefxn               = "score" />
                                  
        <FastRelax                name                   = "mv_relax"
                                  disable_design         = "true"
                                  task_operations        = "tsk_design,tsk_repack,tsk_nothing,{tsk_cat}"
                                  repeats                = "1"
                                  ramp_down_constraints  = "false"
                                  scorefxn               = "score_unconst" />  
                                  
        <InterfaceScoreCalculator name                   = "mv_inter"
                                  chains                 = "X"
                                  scorefxn               = "score_final" />
                                 
    </MOVERS>

    <PROTOCOLS>
        <Add mover_name="mv_add_cst" />
        <Add mover_name="mv_design_no_native" />
        <Add mover_name="mv_native" />
        <Add mover_name="mv_design" />
        <Add mover_name="mv_relax" />
        <Add mover_name="mv_inter" />
    </PROTOCOLS>
   
</ROSETTASCRIPTS>

"""
    # Write the XML script to a file
    with open(f'{FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.xml', 'w') as f:
        f.write(Rosetta_Design_xml)
        
    # Submit the job using the submit_job function
    logging.info(f"Run RosettaDesign for index {new_index} based on index {parent_index}.")
    submit_job(index=new_index, job="Rosetta_Design", ram=2)

    return True



def run_ProteinMPNN(parent_index, new_index, all_scores_df):
    """
    Submit a ProteinMPNN design job for a parent variant.

    In its current state the function only writes a one-line shell script that calls
    PMPNN_submit.py (looked up in the current working directory) with the parent index,
    the new index and FOLDER_HOME, passes it to submit_job and returns. Everything after
    that early return is unreachable; it is the in-process implementation of the same
    pipeline (bias file, helper scripts, ProteinMPNN run, sequence selection, ESMfold/Relax).

    Parameters:
    - parent_index: The index of the parent protein variant.
    - new_index: The index assigned to the new protein variant.
    - all_scores_df (DataFrame): Score table; not used by the active code path.
    - --------------------------------GLOBAL variables used ----------------------------
    - FOLDER_HOME (str): Design folder; the script is written to FOLDER_HOME/new_index/scripts.
    - WT, DESIGN, ProteinMPNN_T, PMPNN_BIAS, EXPLORE: referenced only by the unreachable code.

    Returns:
    None.
    """

    # Active code path: delegate the whole pipeline to PMPNN_submit.py as a submitted job
    cmd = f'python {os.getcwd()}/PMPNN_submit.py --index {parent_index} --new_index {new_index} --home_folder {FOLDER_HOME}'
    with open(f'{FOLDER_HOME}/{new_index}/scripts/PMPNN_ESM_Relax_{new_index}.sh','w') as file: file.write(cmd)
    submit_job(index=new_index, job="PMPNN_ESM_Relax")
    return

    # NOTE(review): everything below is unreachable because of the return above.
    # Decide whether to delete it or to restore it once PMPNN_submit.py testing is finished.

    # Ensure ProteinMPNN is available
    if not os.path.exists(f'{FOLDER_HOME}/../ProteinMPNN'):
        logging.error(f"ProteinMPNN not installed in {FOLDER_HOME}/../ProteinMPNN.")
        logging.error("Install using: git clone https://github.com/dauparas/ProteinMPNN.git")
        return

    # Prepare file paths; the relaxed parent structure is the design template
    pdb_file = f"{FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb"
    if not os.path.isfile(pdb_file):
        logging.error(f"{pdb_file} not present!")
        return

    # Work on a copy of the parent structure inside the new variant's ProteinMPNN folder
    protein_mpnn_folder = f"{FOLDER_HOME}/{new_index}/ProteinMPNN"
    os.makedirs(protein_mpnn_folder, exist_ok=True)
    shutil.copy(pdb_file, os.path.join(protein_mpnn_folder, f"{WT}_Rosetta_Relax_{parent_index}.pdb"))

    # Store the parent sequence; find_highest_scoring_sequence reads this file back
    seq = extract_sequence_from_pdb(pdb_file)
    with open(os.path.join(protein_mpnn_folder, f"Rosetta_Relax_{parent_index}.seq"), "w") as f:
        f.write(seq)
    

    # Run ProteinMPNN steps using subprocess after creating the bias file
    helper_scripts_path = f"{FOLDER_HOME}/../ProteinMPNN/helper_scripts"
    protein_mpnn_path = f"{FOLDER_HOME}/../ProteinMPNN"

    # Prepare input JSON for bias dictionary creation (a single chain, labelled A)
    input_json = {"name": f"{WT}_Rosetta_Relax_{parent_index}", "seq_chain_A": seq}

    # Create bias dictionary: map each letter of the 21-letter alphabet to its column index
    mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    mpnn_alphabet_dict = {aa: idx for idx, aa in enumerate(mpnn_alphabet)}
    
    bias_dict = {}
    for chain_key, sequence in input_json.items():
        if chain_key.startswith('seq_chain_'):
            chain = chain_key[-1]  # chain letter is the last character of the key
            chain_length = len(sequence)
            bias_per_residue = np.zeros([chain_length, 21])  # 21 for each amino acid in the alphabet

            # Apply a positive bias for the amino acid at each position
            for idx, aa in enumerate(sequence):
                if aa in mpnn_alphabet_dict:  # Ensure the amino acid is in the defined alphabet
                    aa_index = mpnn_alphabet_dict[aa]
                    bias_per_residue[idx, aa_index] = PMPNN_BIAS  # Use the global bias variable

            # Assignment replaces any earlier entry; only one chain key exists in input_json
            bias_dict[input_json["name"]] = {chain: bias_per_residue.tolist()}

    # Write the bias dictionary to a JSON file (one JSON object followed by a newline)
    bias_json_path = os.path.join(protein_mpnn_folder, "bias_by_res.jsonl")
    with open(bias_json_path, 'w') as f:
        json.dump(bias_dict, f)
        f.write('\n')

    # Parse multiple chains
    run_command([
        "python", os.path.join(helper_scripts_path, "parse_multiple_chains.py"),
        "--input_path", protein_mpnn_folder,
        "--output_path", os.path.join(protein_mpnn_folder, "parsed_chains.jsonl")
    ])

    # Assign fixed chains
    run_command([
        "python", os.path.join(helper_scripts_path, "assign_fixed_chains.py"),
        "--input_path", os.path.join(protein_mpnn_folder, "parsed_chains.jsonl"),
        "--output_path", os.path.join(protein_mpnn_folder, "assigned_chains.jsonl"),
        "--chain_list", 'A'
    ])

    # Make fixed positions dict; the DESIGN positions are passed space-separated
    # NOTE(review): presumably this keeps the DESIGN positions fixed — confirm that is intended
    run_command([
        "python", os.path.join(helper_scripts_path, "make_fixed_positions_dict.py"),
        "--input_path", os.path.join(protein_mpnn_folder, "parsed_chains.jsonl"),
        "--output_path", os.path.join(protein_mpnn_folder, "fixed_positions.jsonl"),
        "--chain_list", 'A',
        "--position_list", " ".join(DESIGN.split(","))
    ])

    # Protein MPNN run
    # NOTE(review): ProteinMPNN_T is passed unconverted; confirm it is a str or that
    # run_command converts its arguments
    run_command([
        "python", os.path.join(protein_mpnn_path, "protein_mpnn_run.py"),
        "--jsonl_path", os.path.join(protein_mpnn_folder, "parsed_chains.jsonl"),
        "--chain_id_jsonl", os.path.join(protein_mpnn_folder, "assigned_chains.jsonl"),
        "--fixed_positions_jsonl", os.path.join(protein_mpnn_folder, "fixed_positions.jsonl"),
        "--bias_by_res_jsonl", os.path.join(protein_mpnn_folder, "bias_by_res.jsonl"),
        "--out_folder", protein_mpnn_folder,
        "--num_seq_per_target", "100",
        "--sampling_temp", ProteinMPNN_T,
        "--seed", "37",
        "--batch_size", "1"
    ])
    

    # Find highest scoring sequence
    highest_scoring_sequence = find_highest_scoring_sequence(protein_mpnn_folder, parent_index, input_sequence_path=
                                                             f"{FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq")

    # Save highest scoring sequence and prepare for ESMfold
    # (the file is written even when no sequence was found, in which case it is empty)
    with open(os.path.join(protein_mpnn_folder, f"{WT}_{new_index}.seq"), "w") as f:
        f.write(highest_scoring_sequence)
    
    if highest_scoring_sequence:
        logging.info(f"Ran ProteinMPNN for index {parent_index} and found a new sequence with index {new_index}.")
    else:
        logging.error(f"Failed to find a new sequnce for index {parent_index} with ProteinMPNN.")
    
    all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index, pdb_file, from_parent_struct=False)
        
    # Run ESMfold Relax with the ProteinMPNN Flag
    run_ESMfold_RosettaRelax(index=new_index, all_scores_df=all_scores_df, OnlyRelax=False, \
                             ProteinMPNN=True, ProteinMPNN_parent_index=parent_index, EXPLORE=EXPLORE)

def find_highest_scoring_sequence(folder_path, parent_index, input_sequence_path):
    """
    Pick the generated ProteinMPNN sequence with the largest 'global_score' that differs
    from the parent and does not match the wild-type pattern.

    Parameters:
    - folder_path (str): The path to the directory containing sequence files (/ProteinMPNN).
      The FASTA file is read from folder_path/seqs and the parent sequence from
      folder_path/Rosetta_Relax_{parent_index}.seq.
    - parent_index: The index of the parent protein sequence (used to build the file names).
    - input_sequence_path (str): The path to a file whose first line is the input sequence
      pattern, where 'X' represents wildcard positions that can match any character.
    - -------------------------------- GLOBAL variables used ----------------------------
    - WT (str): The wild type or reference protein identifier (part of the FASTA file name).

    Returns:
    - highest_scoring_sequence (str): The selected sequence, or '' if no record qualifies
      (records need a 'global_score=<digits>.<digits>' field with a value above 0).

    Note:
    NOTE(review): ProteinMPNN describes global_score as an average negative log probability,
    which would make lower values the more likely sequences. Confirm that selecting the
    maximum is intended.
    """
    # Construct the file paths for the generated sequences and the parent sequence
    file_path = f'{folder_path}/seqs/{WT}_Rosetta_Relax_{parent_index}.fa'
    parent_seq_file = f'{folder_path}/Rosetta_Relax_{parent_index}.seq'

    # Read the parent sequence from its file
    with open(parent_seq_file, 'r') as file:
        parent_sequence = file.readline().strip()

    # Read the input sequence pattern; every 'X' becomes the regex wildcard '.'
    with open(input_sequence_path, 'r') as file:
        input_sequence = file.readline().strip()
    wt_pattern = re.compile(input_sequence.replace('X', '.'))

    # Raw string so that \d and \. reach the regex engine instead of being treated
    # as (invalid) string escape sequences
    score_pattern = re.compile(r'global_score=(\d+\.\d+)')

    highest_score = 0
    highest_scoring_sequence = ''

    # Process the sequence file to find the highest scoring sequence
    with open(file_path, 'r') as file:
        for line in file:
            if not line.startswith('>'):
                continue
            score_match = score_pattern.search(line)
            if not score_match:
                continue
            score = float(score_match.group(1))
            sequence = next(file, '').strip()  # The sequence is on the line after its header

            # Keep the sequence only if it beats the current best, differs from the parent
            # and does not match the wild-type pattern (anchored at the start of the sequence)
            if score > highest_score and sequence != parent_sequence and not wt_pattern.match(sequence):
                highest_score = score
                highest_scoring_sequence = sequence

    # Return the highest scoring sequence found
    return highest_scoring_sequence

def find_highest_confidence_sequence(fa_file_path, output_seq_file_path):
    """
    Select the sequence with the largest 'overall_confidence' from a .fa file and save it.

    Header lines (starting with '>') update the confidence that applies to the lines that
    follow; a header without an 'overall_confidence=' field leaves the previous value in
    place. A sequence is only taken when its confidence is strictly greater than the best
    seen so far (starting at 0.0).

    Parameters:
    - fa_file_path (str): Path to the .fa file generated by LigandMPNN.
    - output_seq_file_path (str): Path where the .seq file should be saved. The file is only
      written when a non-empty sequence was selected; otherwise an error is logged.
    """
    confidence_field = re.compile('overall_confidence=([0-9.]+)')

    best_confidence = 0.0
    best_sequence = None
    confidence = 0.0

    with open(fa_file_path, 'r') as fasta:
        for entry in fasta:
            if entry.startswith('>'):
                # Header line: pick up the confidence for the upcoming sequence
                found = confidence_field.search(entry)
                if found:
                    confidence = float(found.group(1))
                continue

            # Sequence line: keep it if it beats the best confidence so far
            if confidence > best_confidence:
                best_confidence = confidence
                best_sequence = entry.strip()

    if not best_sequence:
        logging.error(f"No sequence found with overall confidence for {fa_file_path}.")
        return

    # Write the winning sequence to the .seq file
    with open(output_seq_file_path, 'w') as output_file:
        output_file.write(best_sequence)
    logging.info(f"Extracted sequence with highest confidence: {best_confidence} to {output_seq_file_path}")

def run_LigandMPNN(parent_index, new_index, all_scores_df):
    """
    Executes the LigandMPNN pipeline for a given protein-ligand structure and generates
    a new protein sequence, then starts ESMfold and Rosetta Relax for it.

    The catalytic residue positions of the parent are kept fixed during sequence design.

    Parameters:
    - parent_index: The index of the parent protein variant.
    - new_index: The index assigned to the new protein variant.
    - all_scores_df (DataFrame): A DataFrame containing information for protein variants.
      Its 'cat_resi' entry for parent_index may be a single residue number or a
      ';'-separated string of residue numbers.
    - --------------------------------GLOBAL variables used ----------------------------
    - FOLDER_HOME (str): Design folder; LigandMPNN is expected in FOLDER_HOME/../LigandMPNN.
    - WT (str): The wild type or reference protein identifier.
    - LMPNN_T: The sampling temperature passed to LigandMPNN.

    Returns:
    None. Returns early (after logging an error) if LigandMPNN or the parent structure is missing.
    """
    # Ensure LigandMPNN is available; report the same path that is actually checked
    ligand_mpnn_path = f"{FOLDER_HOME}/../LigandMPNN"
    if not os.path.exists(ligand_mpnn_path):
        logging.error(f"LigandMPNN not installed in {ligand_mpnn_path}.")
        logging.error("Install using: git clone https://github.com/dauparas/LigandMPNN.git")
        return

    # Prepare file paths
    pdb_file = f"{FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb"
    if not os.path.isfile(pdb_file):
        logging.error(f"{pdb_file} not present!")
        return

    ligand_mpnn_folder = f"{FOLDER_HOME}/{new_index}/LigandMPNN"
    os.makedirs(ligand_mpnn_folder, exist_ok=True)
    shutil.copy(pdb_file, os.path.join(ligand_mpnn_folder, f"{WT}_Rosetta_Relax_{parent_index}.pdb"))

    # Extract catalytic residue information. The entry can hold several positions separated
    # by ';' and may have been read back from CSV as a number, hence str() and float().
    cat_resi_entry = str(all_scores_df.at[parent_index, 'cat_resi'])
    fixed_residues = " ".join(f"A{int(float(resi))}" for resi in cat_resi_entry.split(';') if resi.strip())

    # Run LigandMPNN; all arguments are passed as strings
    run_command([
        "python", os.path.join(ligand_mpnn_path, "run.py"),
        "--model_type", "ligand_mpnn",
        "--temperature", str(LMPNN_T),
        "--seed", "37",
        "--pdb_path", os.path.join(ligand_mpnn_folder, f"{WT}_Rosetta_Relax_{parent_index}.pdb"),
        "--out_folder", ligand_mpnn_folder,
        #"--pack_side_chains", "1",
        "--number_of_packs_per_design", "4",
        "--fixed_residues", fixed_residues
    ], cwd=ligand_mpnn_path)

    # Save the most confident generated sequence as the new variant's .seq file
    find_highest_confidence_sequence(f"{FOLDER_HOME}/{new_index}/LigandMPNN/seqs/{WT}_Rosetta_Relax_{parent_index}.fa",
                                    f"{FOLDER_HOME}/{new_index}/LigandMPNN/{WT}_{new_index}.seq")

    logging.info(f"Ran LigandMPNN for index {parent_index} and generated index {new_index}.")

    # Update all_scores_df
    all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index, pdb_file, from_parent_struct=False)

    # NOTE(review): EXPLORE is hard-coded to False here, whereas run_ProteinMPNN forwards the
    # global EXPLORE flag - confirm which is intended
    run_ESMfold_RosettaRelax(index=new_index, all_scores_df=all_scores_df, OnlyRelax=False, \
                             ProteinMPNN=True, ProteinMPNN_parent_index=parent_index, EXPLORE=False)

    # Save updates to all_scores_df
    #save_all_scores_df(all_scores_df)

# main functions - startup

In [None]:
def startup_controller(UNBLOCK_ALL, RESET, PRINT_VAR=True, PLOT_DATA=True):
    """
    Prepare a session: run setup if needed, restore saved variables and load the score table.

    Parameters:
    - UNBLOCK_ALL (bool): If true, the 'blocked' column of the returned table is set to False.
    - RESET (bool): Forwarded to setup_aizymes when VARIABLES_JSON does not exist yet.
    - PRINT_VAR (bool): If true and VARIABLES_JSON exists, every saved variable is copied
      into the module globals and printed. The program exits if the saved DESIGN_FOLDER
      differs from the current one.
    - PLOT_DATA (bool): If true, plot_scores is called.

    Returns:
    DataFrame read from ALL_SCORES_CSV.

    NOTE(review): the saved variables are only restored into globals when PRINT_VAR is true,
    so switching off printing also switches off loading - confirm this coupling is intended.
    """
    # First use of this design folder: run the setup
    if not os.path.isfile(VARIABLES_JSON):
        setup_aizymes(RESET)

    # Regenerate the input files; not strictly needed here but harmless
    prepare_input_files()

    # Checked again on purpose: the setup call above may have just created the file
    if PRINT_VAR and os.path.isfile(VARIABLES_JSON):
        with open(VARIABLES_JSON, 'r') as handle:
            saved_variables = json.load(handle)

        if saved_variables['DESIGN_FOLDER'] != DESIGN_FOLDER:
            print("WRONG DESIGN FOLDER!")
            sys.exit()

        for name, value in saved_variables.items():
            globals()[name] = value
            print(name.ljust(16), ':', value)

    time.sleep(1)

    if PLOT_DATA:
        plot_scores()

    # Read in current databases of AIzymes
    scores = pd.read_csv(ALL_SCORES_CSV)
    if UNBLOCK_ALL:
        scores["blocked"] = False

    return scores

def prepare_input_files():
    """
    Create the helper files in FOLDER_HOME that the design jobs rely on.

    Three files are written:
    - ESMfold.py: stand-alone script that folds the sequence in a file and writes a PDB.
    - input_sequence_with_X_as_wildecard.seq: sequence of FOLDER_INPUT/{WT}.pdb with an 'X'
      at every position listed in DESIGN (positions are 1-based).
    - cst.dat: the last token of every 'TEMPLATE::   ATOM_MAP: 2 res...' line of the enzdes
      constraint file, joined with ';'.

    The program exits if the scaffold structure FOLDER_INPUT/{WT}.pdb is missing.
    """

    # Create the ESMfold.py script
    ESMfold_python_script = """import sys
from transformers import AutoTokenizer, EsmForProteinFolding, EsmConfig
import torch
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

# Set PyTorch to use only one thread
torch.set_num_threads(1)

output_file = sys.argv[1]
sequence_file = sys.argv[2]

with open(sequence_file) as f: sequence=f.read()

def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
torch.backends.cuda.matmul.allow_tf32 = True
model.trunk.set_chunk_size(64)
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids']
with torch.no_grad(): output = model(tokenized_input)
pdb = convert_outputs_to_pdb(output)
with open(output_file, "w") as f: f.write("".join(pdb))
"""

    # Write the ESMfold.py to a file
    with open(f"{FOLDER_HOME}/ESMfold.py", "w") as script_file:
        script_file.write(ESMfold_python_script)

    # The scaffold structure is required for everything that follows
    scaffold_pdb = f"{FOLDER_INPUT}/{WT}.pdb"
    if not os.path.isfile(scaffold_pdb):
        print(f"Error, scaffold protein structure {FOLDER_INPUT}/{WT}.pdb is missing!")
        sys.exit()

    # Save input sequence with X as wildcard at the (1-based) design positions
    scaffold_sequence = extract_sequence_from_pdb(scaffold_pdb)
    design_positions = [int(position) for position in DESIGN.split(',')]
    masked_residues = []
    for position, residue in enumerate(scaffold_sequence, start=1):
        masked_residues.append('X' if position in design_positions else residue)
    with open(f'{FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq', 'w') as seq_file:
        seq_file.write(''.join(masked_residues))

    # Get the Constraint Residues from enzdes constraints file
    constraint_residues = []
    with open(f'{FOLDER_INPUT}/{LIGAND}_enzdes.cst', 'r') as cst_file:
        for cst_line in cst_file.readlines():
            if "TEMPLATE::   ATOM_MAP: 2 res" in cst_line:
                constraint_residues.append(cst_line.split()[-1])
    with open(f'{FOLDER_HOME}/cst.dat', 'w') as dat_file:
        dat_file.write(";".join(constraint_residues))
    
def setup_aizymes(RESET=False, EXPLORE=False, PROMPT=True):
    """
    Initialise (or re-initialise) the design folder FOLDER_HOME.

    Setup runs when VARIABLES_JSON does not exist yet (after an optional confirmation
    prompt), or when it exists and RESET is set and the user confirms. In every other case
    the function returns without doing anything. When setup runs it resets the log file,
    deletes everything in FOLDER_HOME except FOLDER_MATCH, recreates the input files and an
    empty score table, and stores the global settings in VARIABLES_JSON.

    Parameters:
    - RESET (bool): Allow wiping an existing setup (asks for confirmation).
    - EXPLORE (bool): Stored in VARIABLES_JSON under 'EXPLORE'.
    - PROMPT (bool): Ask for confirmation before a first-time setup.

    Returns:
    None.
    """
    already_set_up = os.path.isfile(VARIABLES_JSON)

    # Decide whether setup should run at all
    if already_set_up and not RESET:
        return  # VARIABLES_JSON exists and RESET not set by user
    if already_set_up:
        answer = input(f'''Do you really want to restart AIzymes from scratch? 
    This will delete all existing files in {FOLDER_HOME} [y/n]

    ''')
        if answer != 'y':
            return  # VARIABLES_JSON exists and RESET set, but the user canceled
    elif PROMPT:
        answer = input(f'''Do you want to start AIzymes? [y/n]

    ''')
        if answer != 'y':
            return  # VARIABLES_JSON is missing, but the user elected not to set up AIzymes

    # Reset the log file and announce the setup
    with open(LOG_FILE, 'w'):
        pass
    logging.info(f"Running AI.zymes setup.")
    logging.info(f"Content of {FOLDER_HOME} deleted.")
    logging.info(f"Happy AI.zymeing! :)")

    # Empty the design folder, sparing only the matching folder
    if os.path.exists(FOLDER_HOME):
        leftovers = [f'{FOLDER_HOME}/{entry}' for entry in os.listdir(FOLDER_HOME) if entry != FOLDER_MATCH]
        for leftover in leftovers:
            if os.path.isfile(leftover):
                os.remove(leftover)
            elif os.path.isdir(leftover):
                shutil.rmtree(leftover)
    os.makedirs(FOLDER_HOME, exist_ok=True)

    prepare_input_files()

    # Create (and save) the empty score table
    make_empty_all_scores_df()

    # Collect the global variables that are stored with the design folder
    variables_to_save = [
        'DESIGN_FOLDER', 'FOLDER_MATCH', 'MAX_JOBS', 'N_PARENT_JOBS', 'MAX_DESIGNS', 'KBT_BOLTZMANN', 'CST_WEIGHT',
        'ProteinMPNN_PROB', 'LMPNN_PROB', 'WT', 'LIGAND', 'ROSETTA_PATH', 'REPACK', 'DESIGN', 'MATCH', 'FOLDER_PARENT',
        'ProteinMPNN_T', 'PMPNN_BIAS', 'LMPNN_T', 'SUBMIT_PREFIX', 'BLUEPEBBLE', 'GRID', 'BACKGROUND_JOB', 'ABBIE_LOCAL', 'FIELD_TOOLS'
    ]
    module_globals = globals()
    globals_to_save = {}
    for name in variables_to_save:
        globals_to_save[name] = module_globals[name]
    globals_to_save['EXPLORE'] = EXPLORE

    if N_PARENT_JOBS < MAX_JOBS*2:
        logging.warning(f"To ensure a smooth start, N_PARENT_JOBS should be at least 2 x MAX_JOBS.")
        logging.warning(f"N_PARENT_JOBS: {N_PARENT_JOBS}, MAX_JOBS: {MAX_JOBS}.")

    with open(VARIABLES_JSON, 'w') as json_file:
        json.dump(globals_to_save, json_file, indent=4)
    
    
def make_empty_all_scores_df():
    """
    Create the score table with all of its columns but no rows, save it and return it.

    Returns:
    DataFrame without rows; it is also passed to save_all_scores_df.
    """
    identity_columns = ['index', 'sequence', 'parent_index']
    score_columns = [
        'interface_score', 'total_score', 'catalytic_score', 'efield_score',
        'interface_potential', 'total_potential', 'catalytic_potential', 'efield_potential',
        'relax_interface_score', 'relax_total_score', 'relax_catalytic_score', 'relax_efield_score',
        'design_interface_score', 'design_total_score', 'design_catalytic_score', 'design_efield_score',
    ]
    bookkeeping_columns = ['generation', 'mutations', 'design_method', 'score_taken_from', 'blocked',
                           'cat_resi', 'cat_resn']

    empty_scores = pd.DataFrame(columns=identity_columns + score_columns + bookkeeping_columns)
    save_all_scores_df(empty_scores)
    return empty_scores

# RosettaMatch

In [None]:
def run_RosettaMatch(EXPLORE=False, submit=False, bash=True):
    """
    Advance the matching workflow by one stage, depending on which outputs already exist.

    Stages, checked in this order inside FOLDER_HOME/FOLDER_MATCH:
    1. no 'scripts' folder      -> start ESMfold and Rosetta Relax (PreMatchRelax)
    2. no relaxed APO structure -> report that ESMfold and Relax are still running
    3. no 'matches' folder      -> start the matcher
    4. otherwise                -> report that matching is done

    Parameters:
    - EXPLORE (bool): Forwarded to run_ESMfold_RosettaRelax.
    - submit, bash: Not used by this function.

    NOTE(review): os.makedirs creates FOLDER_MATCH relative to the current working
    directory, while every check below looks in FOLDER_HOME/FOLDER_MATCH - confirm whether
    the folder was meant to be created inside FOLDER_HOME.
    """
    prepare_input_files()

    os.makedirs(FOLDER_MATCH, exist_ok=True)

    match_folder = f'{FOLDER_HOME}/{FOLDER_MATCH}'

    if not os.path.isdir(f'{match_folder}/scripts'):
        run_ESMfold_RosettaRelax(FOLDER_MATCH, all_scores_df=None, PreMatchRelax=True, EXPLORE=EXPLORE)
        return

    if not os.path.isfile(f'{match_folder}/{WT}_Rosetta_Relax_{FOLDER_MATCH}_APO.pdb'):
        print(f"ESMfold and Relax of {FOLDER_MATCH} still running.")
        return

    if not os.path.isdir(f'{match_folder}/matches'):
        run_Matcher()
        return

    print("Matching is done")
            
def run_Matcher():
    """
    Write the RosettaMatch shell script for FOLDER_MATCH and submit it as a job.

    The script works inside FOLDER_HOME/FOLDER_MATCH. It writes the ligand's .central and
    .pos files (the positions come from the comma-separated MATCH variable), runs
    gen_lig_grids on the aligned APO structure and the ligand structure, and then runs the
    Rosetta matcher in a fresh 'matches' sub-folder.

    GLOBAL variables used: FOLDER_HOME, FOLDER_MATCH, FOLDER_INPUT, ROSETTA_PATH, WT, LIGAND, MATCH

    Returns:
    None.
    """
    # Scaffold positions are written space-separated into the .pos file
    match_positions = " ".join(MATCH.split(","))

    # NOTE(review): the required active-site atom name C9 is hard-coded - confirm it applies
    # to every ligand used with this workflow.
    # stderr of the optional clean-up goes to /dev/null ("2>1" would create a file named "1"),
    # and "rm -rf" stays quiet when there is no previous matches folder.
    cmd = f"""       
  
cd {FOLDER_HOME}/{FOLDER_MATCH}

echo C9 > {LIGAND}.central
echo {match_positions} > {LIGAND}.pos

{ROSETTA_PATH}/bin/gen_lig_grids.linuxgccrelease \
    -s                      {WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb ESMfold/{WT}_CPPTraj_Lig_{FOLDER_MATCH}.pdb \
    -extra_res_fa           {FOLDER_INPUT}/{LIGAND}.params \
    -grid_delta             0.5 \
    -grid_lig_cutoff        5.0 \
    -grid_bb_cutoff         2.25 \
    -grid_active_res_cutoff 15.0 \
    -overwrite 

mv {WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb_0.gridlig {WT}.gridlig
rm {WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb_0.pos 2>/dev/null

rm -rf matches
mkdir matches
cd matches

{ROSETTA_PATH}/bin/match.linuxgccrelease \
    -s                                        ../{WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb \
    -match:lig_name                           {LIGAND} \
    -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
    -match:geometric_constraint_file          {FOLDER_INPUT}/{LIGAND}_enzdes.cst \
    -match::scaffold_active_site_residues     ../{LIGAND}.pos \
    -match:required_active_site_atom_names    ../{LIGAND}.central \
    -match:active_site_definition_by_gridlig  ../{WT}.gridlig  \
    -match:grid_boundary                      ../{WT}.gridlig  \
    -gridligpath                              ../{WT}.gridlig  \
    -overwrite  \
    -output_format PDB  \
    -output_matches_per_group 1  \
    -consolidate_matches true 
""" 
    with open(f'{FOLDER_HOME}/{FOLDER_MATCH}/scripts/RosettaMatch_{FOLDER_MATCH}.sh', 'w') as file:
        file.write(cmd)
    logging.info(f"Run Rosetta_Match for index {FOLDER_MATCH}.")
    submit_job(FOLDER_MATCH, job="RosettaMatch", bash=False)

def separate_matches():
    '''Sort the matcher output PDBs in {FOLDER_HOME}/{FOLDER_PARENT} by catalytic residue.

    For every *.pdb file directly inside the parent folder, the first line starting
    with "REMARK 666 MATCH TEMPLATE" is read and its 12th whitespace-separated token
    (the catalytic residue index) is used as the name of a sub-folder. The PDB is
    copied (not moved) into that sub-folder.

    Files without a usable REMARK line are skipped with a warning, so they can no
    longer crash the function or end up in the folder of the previously handled file.'''
    parent_folder = f'{FOLDER_HOME}/{FOLDER_PARENT}'
    pdb_files = [f for f in os.listdir(parent_folder) if f.endswith('.pdb')]

    for pdb_file in pdb_files:
        pdb_path = os.path.join(parent_folder, pdb_file)

        # Reset for every file so that nothing leaks over from the previous PDB
        catalytic_residue_index = None
        with open(pdb_path, 'r') as file:
            for line in file:
                if line.startswith('REMARK 666 MATCH TEMPLATE'):
                    parts = line.split()
                    if len(parts) > 11:
                        catalytic_residue_index = parts[11]  # Extract the catalytic residue index
                    break  # Only the first match remark is considered

        if catalytic_residue_index is None:
            logging.warning(f"No catalytic residue index found in {pdb_path}, file skipped.")
            continue

        # Create the destination folder if it doesn't exist
        dest_folder = os.path.join(parent_folder, catalytic_residue_index)
        os.makedirs(dest_folder, exist_ok=True)

        # Copy the PDB file to the destination folder instead of moving it
        shutil.copy(pdb_path, os.path.join(dest_folder, pdb_file))
    


# helper functions

In [None]:
def submit_job(index, job, bash=False, ram=16):
    '''Write scripts/submit_{job}_{index}.sh for design <index> and run or submit it.

    The submission script changes into {FOLDER_HOME}/{index} and executes
    scripts/{job}_{index}.sh. Depending on the global backend flags it is prefixed
    with a scheduler header (GRID: SGE, BLUEPEBBLE: SLURM) or left without header
    (BACKGROUND_JOB, ABBIE_LOCAL).

    Parameters:
    - index: Design index; selects the folder {FOLDER_HOME}/{index}.
    - job:   Job name; selects the script {job}_{index}.sh and names the output files.
    - bash:  If True the submission script is executed directly (blocking) with bash
             instead of being submitted, for testing.
    - ram:   Memory request in GB. Only used in the GRID header.

    With BACKGROUND_JOB the file {FOLDER_HOME}/n_running_jobs.dat is incremented here
    and decremented again by the submission script once the job script has finished.
    NOTE: this counter update is a plain read-modify-write without locking.'''

    scripts_dir = f'{FOLDER_HOME}/{index}/scripts'
    submit_path = f'{scripts_dir}/submit_{job}_{index}.sh'
    os.makedirs(scripts_dir, exist_ok=True)

    # Default to "no header" so the script is defined even if no backend flag is set
    submission_script = ""

    if GRID:
        submission_script = f"""#!/bin/bash
#$ -V
#$ -cwd
#$ -N {SUBMIT_PREFIX}_{job}_{index}
#$ -hard -l mf={ram}G
#$ -o {scripts_dir}/AI_{job}_{index}.out
#$ -e {scripts_dir}/AI_{job}_{index}.err
"""
    if BLUEPEBBLE:
        submission_script = f"""#!/bin/bash
#SBATCH --account={BLUEPEBBLE_ACCOUNT}
#SBATCH --partition=short
#SBATCH --mem=40GB
#SBATCH --ntasks-per-node=1
#SBATCH --time=2:00:00
#SBATCH --nodes=1
#SBATCH --job-name={SUBMIT_PREFIX}_{job}_{index}
#SBATCH --output={scripts_dir}/AI_{job}_{index}.out
#SBATCH --error={scripts_dir}/AI_{job}_{index}.err
"""

    if BACKGROUND_JOB:
        # Register one more running job in the shared counter file
        counter_path = f'{FOLDER_HOME}/n_running_jobs.dat'
        if not os.path.isfile(counter_path):
            with open(counter_path, 'w') as f:
                f.write('0')
        with open(counter_path, 'r') as f:
            jobs = int(f.read().strip() or 0)
        with open(counter_path, 'w') as f:
            f.write(str(jobs + 1))
        submission_script = ""

    if ABBIE_LOCAL:
        submission_script = ""

    submission_script += f"""
# Output folder
cd {FOLDER_HOME}/{index}
pwd
bash {scripts_dir}/{job}_{index}.sh
"""
    if BACKGROUND_JOB:
        # Appended (not assigned!) so that the job itself is still part of the script:
        # after the job has finished, the counter of running jobs is decremented again.
        submission_script += f"""
jobs=$(cat {FOLDER_HOME}/n_running_jobs.dat)
jobs=$((jobs - 1))
echo "$jobs" > {FOLDER_HOME}/n_running_jobs.dat
"""

    # Create the submission_script
    with open(submit_path, 'w') as file:
        file.write(submission_script)

    if bash:
        # Bash the submission_script for testing
        subprocess.run(f'bash {submit_path}', shell=True, text=True)
        return

    # Submit the submission_script
    if GRID:
        if "ESM" in job:
            # ESM jobs must not land on the hosts excluded below
            qsub_command = (
                f'qsub -l h="!bs-dsvr64&!bs-dsvr58&!bs-dsvr42&'
                f'!bs-grid64&!bs-grid65&!bs-grid66&!bs-grid67&'
                f'!bs-grid68&!bs-grid69&!bs-grid70&!bs-grid71&'
                f'!bs-grid72&!bs-grid73&!bs-grid74&!bs-grid75&'
                f'!bs-grid76&!bs-grid77&!bs-grid78&!bs-headnode04&'
                f'!bs-stellcontrol05&!bs-stellsubmit05" -q regular.q '
                f'{submit_path}'
            )
        else:
            qsub_command = f'qsub -q regular.q {submit_path}'
        output = subprocess.check_output(qsub_command, shell=True, text=True)
        logging.debug(output[:-1])  # remove newline at end of output

    if BLUEPEBBLE:
        output = subprocess.check_output(f'sbatch {submit_path}', shell=True, text=True)
        logging.debug(output[:-1])  # remove newline at end of output

    if BACKGROUND_JOB or ABBIE_LOCAL:
        # Launch detached on this machine; stdout/stderr go to per-job log files
        stdout_log_file_path = f'{scripts_dir}/submit_{job}_{index}_stdout.log'
        stderr_log_file_path = f'{scripts_dir}/submit_{job}_{index}_stderr.log'

        with open(stdout_log_file_path, 'w') as stdout_log_file, open(stderr_log_file_path, 'w') as stderr_log_file:
            subprocess.Popen(f'bash {submit_path} &',
                             shell=True, stdout=stdout_log_file, stderr=stderr_log_file)
        
def extract_sequence_from_pdb(pdb_path):
    '''Return the amino acid sequence read from the ATOM records of <pdb_path>.

    The file is parsed with Bio.SeqIO in "pdb-atom" mode. If the parser yields
    several records, the sequence of the last one is returned.

    Raises:
    - ValueError: if the parser yields no record at all.'''
    seq = None
    with open(pdb_path, "r") as pdb_file:
        for record in SeqIO.parse(pdb_file, "pdb-atom"):
            seq = str(record.seq)
    if seq is None:
        # Fail with a clear message instead of an UnboundLocalError
        raise ValueError(f"No sequence could be extracted from {pdb_path}")
    return seq
    
def generate_remark_from_all_scores_df(all_scores_df, index):
    '''Builds the "REMARK 666" header lines for the catalytic residues of row <index>.

    The ";"-separated entries of the columns 'cat_resn' and 'cat_resi' are paired up
    and one remark line is written per pair. Residue numbers are passed through
    float() and int() first, so that values like "12.0" become 12.
    Returns all remark lines joined by newlines (no trailing newline).'''

    residue_names = str(all_scores_df.at[index, 'cat_resn']).split(';')
    residue_numbers = [
        int(float(token))
        for token in str(all_scores_df.at[index, 'cat_resi']).split(';')
    ]

    return "\n".join(
        f'REMARK 666 MATCH TEMPLATE X {LIGAND}    0 MATCH MOTIF A {name}{str(number).rjust(5)}  1  1'
        for number, name in zip(residue_numbers, residue_names)
    )

def save_cat_res_into_all_scores_df(all_scores_df, index, PDB_file_path, from_parent_struct=False):

    '''Reads the catalytic residues from the "REMARK 666" lines of <PDB_file_path>
       and stores them in row <index> of <all_scores_df>.

       Residue numbers are taken from columns 55-58 of each remark line. The residue
       name is looked up on the first " CA " atom line with that residue number,
       searching only the lines after the first (number of remarks + 2) lines.
       Both are stored as ";"-joined strings in 'cat_resi' and 'cat_resn'; use
       .split(";") to get lists back.
       If <from_parent_struct> is True (the file is the input structure for a design)
       only 'cat_resi' is written.
       Returns the updated all_scores_df'''

    time.sleep(0.1)

    with open(PDB_file_path, 'r') as handle:
        pdb_lines = handle.readlines()

    match_remarks = [entry for entry in pdb_lines if entry.startswith('REMARK 666')]
    residue_numbers = [str(int(entry[55:59])) for entry in match_remarks]

    # Lines behind the header block are searched for the residue names
    coordinate_lines = pdb_lines[len(match_remarks) + 2:]

    residue_names = []
    for wanted in residue_numbers:
        found = next(
            (record[17:20]
             for record in coordinate_lines
             if record[12:16] == " CA " and str(int(record[22:26])) == wanted),
            None,
        )
        if found is not None:
            residue_names.append(found)

    all_scores_df.at[index, 'cat_resi'] = ";".join(residue_numbers)
    # The residue names of an input structure are not those of the design, so skip them
    if not from_parent_struct:
        all_scores_df.at[index, 'cat_resn'] = ";".join(residue_names)

    return all_scores_df

def reset_to_after_parent_design():
    '''Resets the run to the state directly after the parent designs.

    All numbered folders in FOLDER_HOME with a number >= N_PARENT_JOBS * (number of
    parent PDBs in FOLDER_PARENT) are deleted. In the remaining folders every file
    with 'potential.dat' in its name is removed, and a fresh all_scores_df is rebuilt
    with one "RosettaDesign" entry per folder, whose 'luca' value is read from that
    folder's score_rosetta_design.sc. The new table is saved at the end.'''

    folders = []

    for folder_name in os.listdir(FOLDER_HOME):
        if os.path.isdir(os.path.join(FOLDER_HOME, folder_name)) and folder_name.isdigit():
            folders.append(int(folder_name))

    all_scores_df = make_empty_all_scores_df()

    PARENTS = [i for i in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if i[-4:] == ".pdb"]
    n_parent_designs = N_PARENT_JOBS * len(PARENTS)

    for folder in sorted(folders):

        folder_path = os.path.join(FOLDER_HOME, str(folder))

        if folder >= n_parent_designs:

            #Remove non-parent designs
            shutil.rmtree(folder_path)
            continue

        #Remove Potentials
        for item in os.listdir(folder_path):
            if 'potential.dat' not in item: continue
            item_path = os.path.join(folder_path, item)
            os.remove(item_path)
            print(item_path)

        #Update Scorefile
        new_index, all_scores_df = create_new_index(parent_index="Parent", all_scores_df=all_scores_df)
        all_scores_df['design_method'] = all_scores_df['design_method'].astype('object')
        all_scores_df.at[new_index, 'design_method'] = "RosettaDesign"
        all_scores_df['luca'] = all_scores_df['luca'].astype('object')
        # The score file is read from the folder that is currently processed
        # (the name "index" used here before is not defined in this function).
        # NOTE(review): assumes new_index and folder agree, i.e. parent folders are
        # numbered 0..n-1 without gaps - confirm against create_new_index.
        score_file_path = f"{FOLDER_HOME}/{folder}/score_rosetta_design.sc"
        with open(score_file_path, 'r') as f: score = f.readlines()[2]
        # Last column of the third line, without its last five characters
        all_scores_df.at[new_index, 'luca'] = score.split()[-1][:-5]

        if new_index % 100 == 0: print(folder, new_index)

    save_all_scores_df(all_scores_df)

def reset_to_after_index(index):
    '''Rolls the run back to <index>: every later entry is dropped from all_scores.csv
    and every numbered folder in the home dir with a larger number is deleted.
    index: The last index to keep, after which everything will be deleted.'''

    # Collect the numbered design folders first
    numbered_folders = sorted(
        int(entry)
        for entry in os.listdir(FOLDER_HOME)
        if os.path.isdir(os.path.join(FOLDER_HOME, entry)) and entry.isdigit()
    )

    # Keep only the rows up to and including <index> and write the table back
    scores = pd.read_csv(ALL_SCORES_CSV)
    save_all_scores_df(scores[scores['index'] <= index])

    # Delete the folders of all later designs
    for number in numbered_folders:
        if number <= index:
            continue
        shutil.rmtree(os.path.join(FOLDER_HOME, str(number)))

    print(f"Reset completed. All entries and folders after index {index} have been removed.")

    
def save_all_scores_df(all_scores_df):
    '''Saves <all_scores_df> to ALL_SCORES_CSV without ever leaving a half-written file.

    The table is first written to a temporary file inside FOLDER_HOME (same file
    system as the target) and then moved over ALL_SCORES_CSV in one step.
    If anything fails, the temporary file is removed and the original error is
    re-raised.'''

    temp_fd, temp_path = tempfile.mkstemp(dir=FOLDER_HOME)  # Create a temporary file
    # pandas opens the path itself, so the descriptor is closed right away. This also
    # guarantees it is closed exactly once, whichever step below fails.
    os.close(temp_fd)

    try:
        all_scores_df.to_csv(temp_path, index=False)        # Save DataFrame to the temporary file
        os.replace(temp_path, ALL_SCORES_CSV)               # Move temporary file over the final filename
    except Exception:
        if os.path.exists(temp_path):
            os.unlink(temp_path)                            # Remove the temporary file if an error occurs
        raise

def get_best_structures(save_structures = False, seq_per_active_site = 100):
    '''Selects the best unique sequences from all_scores.csv, optionally copies their
    structures to {FOLDER_HOME}/best_structures, and plots where they lie in the
    score distributions.

    Steps:
    - rows without 'total_score' are dropped, combined scores come from normalize_scores()
    - per sequence, the number of replicates and the mean / standard deviation of their
      combined score are stored; only the best-scoring replicate is kept
    - variants whose 'catalytic_score' is not below mean + one standard deviation are removed
    - going from the best combined score downwards, up to 100 variants are picked, with
      at most <seq_per_active_site> variants per design group (the residues at the
      DESIGN positions)

    Parameters:
    - save_structures (bool): If True, copy the relaxed structure of each selected variant
      (or its design structure, if no relaxed one exists) into the best_structures folder.
    - seq_per_active_site (int): Maximum number of selected variants per design group.

    Returns:
    - selected_indices (np.array of int): The 'index' values of the selected variants.'''
    print(save_structures)
    max_variants = 100

    # Read the scores DataFrame and drop rows with NaN in 'total_score'
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    all_scores_df = all_scores_df.dropna(subset=['total_score']).copy()

    # Calculate the combined scores using the normalize_scores function
    _, _, _, _, combined_scores = normalize_scores(all_scores_df, print_norm=False, norm_all=True)
    all_scores_df['combined_score'] = combined_scores

    # Replicate statistics per sequence (count, average score, standard deviation).
    # Rows without a sequence belong to no group and get a count of 0.
    by_sequence = all_scores_df.groupby('sequence')
    replicate_count = by_sequence['sequence'].transform('count')
    replicate_mean = by_sequence['combined_score'].transform('mean')
    replicate_std = by_sequence['combined_score'].transform('std')
    all_scores_df['replicate_sequences'] = replicate_count.fillna(0).astype(int)
    all_scores_df['replicate_sequences_combined_score'] = replicate_mean
    all_scores_df['replicate_sequences_combined_score_std'] = replicate_std

    # Remove replicates and keep only highest combined_score
    all_scores_df.sort_values(by=['combined_score'], ascending=[False], inplace=True)
    all_scores_df.drop_duplicates(subset=['sequence'], keep='first', inplace=True)

    # Define Design group
    def get_design_sequence(sequence, design_positions):
        return ''.join(sequence[pos - 1] for pos in design_positions)

    design_positions = [int(pos) for pos in DESIGN.split(',')]
    #manually define AS here
    #design_positions = [95, 84, 65, 15, 97, 116, 99, 58, 18, 54, 14, 55, 26, 114, 38, 112, 30, 36, 82, 98, 63]
    all_scores_df['design_group'] = all_scores_df['sequence'].apply(lambda seq: get_design_sequence(seq, design_positions))

    # Use the standard deviation selection for catalytic score
    mean_catalytic_score = all_scores_df['catalytic_score'].mean()
    std_catalytic_score = all_scores_df['catalytic_score'].std()
    all_scores_df = all_scores_df[all_scores_df['catalytic_score'] < mean_catalytic_score + std_catalytic_score]

    # Get the best variants while respecting the seq_per_active_site limit
    top_variants = []
    group_counts = {}

    for _, row in all_scores_df.iterrows():
        group = row['design_group']
        if group_counts.get(group, 0) < seq_per_active_site:
            top_variants.append(row)
            group_counts[group] = group_counts.get(group, 0) + 1
        if len(top_variants) >= max_variants:
            break

    if not top_variants:
        # Nothing passed the filters: there is nothing to copy or plot
        print("No variants passed the filters, no structures selected.")
        return np.array([], dtype=int)

    top100 = pd.DataFrame(top_variants)

    selected_indices = np.array(top100['index'].tolist(), dtype=int)

    # Print average scores for top 100 and all data points
    score_types = ['combined_score', 'total_score', 'interface_score', 'efield_score', 'catalytic_score', 'mutations']
    print_average_scores(all_scores_df, top100, score_types)

    # Create the destination folder if it doesn't exist
    best_structures_folder = os.path.join(FOLDER_HOME, 'best_structures')
    os.makedirs(best_structures_folder, exist_ok=True)

    # Copy files based on the top100 'index'
    print(best_structures_folder)
    if save_structures:
        print("Saving...")
        for _, row in top100.iterrows():
            # Use the 'index' column, not the DataFrame row label: the label is only
            # the line position in the csv and need not match the design folder.
            design_index = int(row['index'])
            geom_mean = "{:.3f}".format(row['combined_score'])
            relax_file = f"{FOLDER_HOME}/{design_index}/{WT}_Rosetta_Relax_{design_index}.pdb"
            design_file = f"{FOLDER_HOME}/{design_index}/{WT}_Rosetta_Design_{design_index}.pdb"
            src_file = relax_file if os.path.isfile(relax_file) else design_file
            dest_file = os.path.join(best_structures_folder, f"{geom_mean}_{WT}_Rosetta_{os.path.basename(src_file)}")
            shutil.copy(src_file, dest_file)
        print("Saved strucutures to: ", best_structures_folder)

    # Plot sorted score distributions, selected variants highlighted in orange
    def plot_elbow_curve(scores_dict, title, top_indices):
        sorted_scores = sorted(scores_dict.items(), key=lambda x: x[1], reverse=True)
        indices, scores = zip(*sorted_scores)
        colors = ['orange' if idx in top_indices else 'blue' for idx in indices]
        alphas = [1.0 if idx in top_indices else 0.05 for idx in indices]
        plt.scatter(range(len(scores)), scores, color=colors, alpha=alphas, s=1)
        plt.title(title)
        plt.xlabel('Variants')
        plt.ylabel('Score')
        plt.show()

    top100_indices = set(top100['index'])
    scores_by_index = all_scores_df.set_index('index')

    for column, title in [('total_score', 'Total Score Elbow Curve'),
                          ('interface_score', 'Interface Score Elbow Curve'),
                          ('efield_score', 'Efield Score Elbow Curve'),
                          ('catalytic_score', 'Catalytic Score Elbow Curve')]:
        plot_elbow_curve(scores_by_index[column].to_dict(), title, top100_indices)

    return selected_indices



def trace_mutation_tree(all_scores_df, index):
    '''Follows the lineage of design <index> back through its parents and plots how the
    scores developed along the way.

    Rows without 'total_score' are ignored and combined scores come from
    normalize_scores(). Starting at <index>, the function walks along 'parent_index'
    for as long as the parent is itself a row of the table. For every such
    parent -> child step it records the mutations between the two sequences, the
    total number of descendants of the parent, and the scores of the child.

    Parameters:
    - all_scores_df (pd.DataFrame): The scores table (not modified).
    - index (int): The 'index' value of the design whose lineage is traced.

    Returns (each list ordered from the oldest step to <index>):
    - mutations: per step a list of strings like "A12G" (parent residue, position, child residue)
    - offspring_counts: per step the number of descendants of the parent
    - combined_scores, total_scores, interface_scores, efield_scores: scores of the child

    Raises:
    - IndexError: if <index> is not present among the scored designs.'''
    mutations = []
    offspring_counts = []
    combined_scores = []
    total_scores = []
    interface_scores = []
    efield_scores = []
    generations = []

    # Work on a copy so the columns added below never touch the caller's table
    all_scores_df = all_scores_df.dropna(subset=['total_score']).copy()

    # Calculate combined scores using normalized scores
    _, _, _, _, combined_scores_normalized = normalize_scores(all_scores_df, print_norm=True, norm_all=True)

    # Add combined scores to the DataFrame
    all_scores_df['combined_score'] = combined_scores_normalized

    # Cast index column to int
    all_scores_df['index'] = all_scores_df['index'].astype(int)

    def parse_parent_index(value):
        # "Parent" marks designs without a parent row; missing values are left alone.
        # Going through float() also accepts entries such as "12.0".
        if value == "Parent" or pd.isna(value):
            return value
        return int(float(value))

    all_scores_df['parent_index'] = all_scores_df['parent_index'].apply(parse_parent_index)

    def get_mutations(parent_seq, child_seq):
        return [f"{p}{i+1}{c}" for i, (p, c) in enumerate(zip(parent_seq, child_seq)) if p != c]

    # Lookup tables, built once instead of filtering the whole table at every step
    known_indices = set(all_scores_df['index'].tolist())
    rows_by_index = all_scores_df.drop_duplicates(subset='index').set_index('index')
    children_of = {}
    for child, parent in zip(all_scores_df['index'].tolist(), all_scores_df['parent_index'].tolist()):
        children_of.setdefault(parent, []).append(child)

    def count_offspring(parent_index):
        # Number of all descendants (children, grandchildren, ...) of parent_index
        count = 0
        pending = [parent_index]
        while pending:
            children = children_of.get(pending.pop(), [])
            count += len(children)
            pending.extend(children)
        return count

    if index not in known_indices:
        raise IndexError(f"Index {index} not found among the scored designs.")

    current_index = index
    visited = set()  # guards against endless loops if the parent links form a cycle

    while current_index in known_indices and current_index not in visited:
        visited.add(current_index)
        current_row = rows_by_index.loc[current_index]
        parent_index = current_row['parent_index']

        if parent_index in known_indices:
            parent_row = rows_by_index.loc[parent_index]

            mutations.append(get_mutations(parent_row['sequence'], current_row['sequence']))
            offspring_counts.append(count_offspring(parent_index))
            generations.append(current_row['generation'])

            # Store actual scores
            combined_scores.append(current_row['combined_score'])
            total_scores.append(current_row['total_score'])
            interface_scores.append(current_row['interface_score'])
            efield_scores.append(current_row['efield_score'])

        current_index = parent_index

    # Plot the actual scores
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    def plot_score_history(ax, scores, title):
        ax.plot(generations[::-1], scores[::-1], marker='o', linestyle='-', color='b')
        ax.set_title(title)
        ax.set_xlabel('Generation')
        ax.set_ylabel('Score')
        ax.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)

    plot_score_history(axs[0, 0], combined_scores, 'Combined Score vs Generations')
    plot_score_history(axs[0, 1], total_scores, 'Total Score vs Generations')
    plot_score_history(axs[1, 0], interface_scores, 'Interface Score vs Generations')
    plot_score_history(axs[1, 1], efield_scores, 'Efield Score vs Generations')

    plt.tight_layout()
    plt.show()

    return mutations[::-1], offspring_counts[::-1], combined_scores[::-1], total_scores[::-1], interface_scores[::-1], efield_scores[::-1]

            


def print_average_scores(all_scores_df, top100, score_types):
    '''Prints a three-column table with, for every column name in <score_types>,
    the mean over <all_scores_df> and the mean over <top100> (four decimals each).'''
    table_header = f"{'Score Type':<20} {'Average of All':<20} {'Average of Top 100':<20}"

    print("\nSummary of Average Scores:")
    print(table_header)
    print("=" * 60)

    for column in score_types:
        overall_mean = all_scores_df[column].mean()
        selected_mean = top100[column].mean()
        # Column name shown in title case, e.g. 'total_score' -> 'Total Score'
        label = column.replace('_', ' ').title()
        print(f"{label:<20} {overall_mean:<20.4f} {selected_mean:<20.4f}")

    print("\n")

def run_command(command, cwd=None, capture_output=False):
    """Run an external command and log an error message if it fails.

    Parameters:
    - command: The command to run as a list of strings.
    - cwd: Optional; The directory to execute the command in.
    - capture_output: Optional; If True, stdout and stderr are captured. If False
      (default, to conserve memory) both are discarded.

    Returns:
    - The captured stdout as text, or None if capture_output is False.

    Any exception is logged and re-raised; a non-zero exit status surfaces as
    subprocess.CalledProcessError.
    """
    if capture_output:
        stream_options = {"capture_output": True}
    else:
        # Nothing is kept: both streams are sent to the null device
        stream_options = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}

    try:
        completed = subprocess.run(command, text=True, check=True, cwd=cwd, **stream_options)
    except subprocess.CalledProcessError as failure:
        logging.error(f"Command '{failure.cmd}' failed with return code {failure.returncode}")
        logging.error(failure.stderr)
        #maybe rerun command here in case of efields
        raise
    except Exception:
        logging.error(f"An error occurred while running command: {command}")
        raise

    return completed.stdout

def wait_for_file(file_path, timeout=5):
    """Poll every 0.1 seconds until <file_path> exists with a non-zero size.

    Returns True as soon as that is the case, or False once <timeout> seconds
    have passed without success.
    """
    began = time.time()
    while True:
        if not (time.time() - began < timeout):
            return False
        if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
            return True
        time.sleep(0.1)  # Pause before looking again

def hamming_distance(seq1, seq2):
    """Count the positions at which two equally long sequences differ.

    Returns None if seq2 is not a string; raises ValueError if the lengths differ.
    """
    if isinstance(seq2, str):
        if len(seq1) == len(seq2):
            mismatches = 0
            for left, right in zip(seq1, seq2):
                if left != right:
                    mismatches += 1
            return mismatches
        raise ValueError("Sequences must be of equal length")
    return None



## Electric Fields

In [None]:
def generate_AMBER_files(structure_filename : str):
    '''Creates {structure_filename}.parm7 and {structure_filename}.rst7 from
    {structure_filename}.pdb with tleap. Requires ambertools and pdb-tools
    (pip installable from  https://www.bonvinlab.org/pdb-tools/).
    The tleap input loads Input/5TS.prepi and Input/5TS.frcmod, so both must exist.
    TODO: Add script to generate these if not present.

    A file "tleap.in" is written to the current working directory, and the
    hydrogen-free structure is kept as {structure_filename}_noHclean.pdb.

    Parameters:
    - structure_filename (str): The path to the pdb file to analyse without the file extension.'''

    source_pdb = f"{structure_filename}.pdb"
    truncated_pdb = f"{structure_filename}_clean.pdb"
    heavy_atom_pdb = f"{structure_filename}_noHclean.pdb"

    # Keep only the lines before the first 'CONECT' line - otherwise Amber tries
    # to read what follows as extra residues
    truncated_text = run_command(['sed', '-n', '/CONECT/q;p', source_pdb], capture_output=True)
    with open(truncated_pdb, "w") as handle:
        handle.write(truncated_text)

    # Remove hydrogens - requires pip install of https://www.bonvinlab.org/pdb-tools/
    heavy_atom_text = run_command(['pdb_delelem', '-H', truncated_pdb], capture_output=True)
    with open(heavy_atom_pdb, "w") as handle:
        handle.write(heavy_atom_text)

    # The truncated intermediate is not needed any more
    os.remove(truncated_pdb)

    # tleap input, one entry per line of the file (whitespace kept as it always was)
    leap_lines = [
        "source leaprc.protein.ff19SB ",
        "    source leaprc.gaff",
        "    loadamberprep   Input/5TS.prepi",
        "    loadamberparams Input/5TS.frcmod",
        f"    mol = loadpdb {structure_filename}_noHclean.pdb",
        f"    saveamberparm mol {structure_filename}.parm7 {structure_filename}.rst7",
        "    quit",
        "    ",
    ]
    with open("tleap.in", "w") as handle:
        handle.write("\n".join(leap_lines))

    run_command(["tleap", "-s", "-f", "tleap.in"])

def calc_efields_score(pdb_path):
    '''Runs the FieldTools.py script (FIELD_TOOLS) on a structure and reads back the
    electric field across the C-H bond of 5TS.
    Requires a field_target.dat in the Input folder. Currently hard-coded based on 5TS
    TODO: Make this function agnostic to contents of field_target

    The AMBER files are generated first; FieldTools then writes its result to
    "<pdb path without extension>_fields.pkl", which is loaded from here.

    Parameters:
    - pdb_path (str): The path to the pdb structure to analyse - either from design or relax.

    Returns:
    - bond_field (float): First entry of the 'Total' field across the 5TS@C9_:5TS@H04 bond
      (in MV/cm according to the original documentation; currently hard coded to these atoms)
    - all_fields (dict): All entries stored for the 5TS@C9_:5TS@H04 bond,
      i.e. the field components per residue.'''

    target_bond = ':5TS@C9_:5TS@H04'
    structure_filename = os.path.splitext(pdb_path)[0]

    generate_AMBER_files(structure_filename)

    # FieldTools gets paths relative to the current working directory
    relative_stem = os.path.relpath(structure_filename)
    field_tools_call = [
        "python", f"{FIELD_TOOLS}",
        "-nc", f"{relative_stem}.rst7",
        "-parm", f"{relative_stem}.parm7",
        "-out", f"{relative_stem}_fields.pkl",
        "-target", "Input/field_target.dat",
        "-solvent", "WAT",
    ]
    run_command(field_tools_call)

    with open(f"{structure_filename}_fields.pkl", "rb") as handle:
        computed_fields = pkl.load(handle)

    all_fields = computed_fields[target_bond]
    total_field = all_fields['Total']

    return total_field[0], all_fields

def update_efieldsdf(index:int, index_efields_dict:dict):
    '''Stores the electric fields of design <index> as a row (labelled <index>) in
    "{FOLDER_HOME}/electric_fields.csv".

    The first element of every value in <index_efields_dict> is written, in dict order.
    The columns are named "Total", "Protein", "Solvent", "WAT" followed by
    "RESI_1" ... "RESI_n", where n is the number of dict entries minus four.
    If the file already exists the row is added to it and the rows are sorted by index.'''

    csv_path = f"{FOLDER_HOME}/electric_fields.csv"

    residue_count = len(index_efields_dict) - 4
    column_names = ["Total", "Protein", "Solvent", "WAT"]
    column_names += [f"RESI_{number}" for number in range(1, residue_count + 1)]

    row_values = [entry[0] for entry in index_efields_dict.values()]
    table = pd.DataFrame([row_values], columns=column_names, index=[index])

    if os.path.isfile(csv_path):
        # Add the new row to the rows stored so far
        table = pd.concat([pd.read_csv(csv_path, index_col=0), table])
        table.sort_index(inplace=True)

    table.to_csv(csv_path)


# plotting functions

In [None]:
def plot_scores(combined_score_min=0, combined_score_max=1, combined_score_bin=0.01,
                interface_score_min=0, interface_score_max=1, interface_score_bin=0.01,
                total_score_min=0, total_score_max=1, total_score_bin=0.01,
                catalytic_score_min=0, catalytic_score_max=1, catalytic_score_bin=0.01,
                mut_min=0,mut_max=len(DESIGN.split(","))+1,
                efield_score_min=None, efield_score_max=None, efield_score_bin=None):
    '''Plots a 3 x 4 overview of the scores in all_scores.csv.

    Nothing is plotted if the file does not exist or if fewer than three designs
    have a 'total_score'. Scores are normalized with normalize_scores() first.

    Parameters:
    - <score>_min, <score>_max, <score>_bin: Range and bin width handed to the plot
      of the respective normalized score.
    - mut_min, mut_max: Range for the mutations-per-generation plot. If ProteinMPNN_PROB
      or LMPNN_PROB is above 0, mut_max is replaced by the length of the first
      available sequence.
    - efield_score_min, efield_score_max, efield_score_bin: Range and bin width of the
      efield score plot. Each one left at None falls back to the corresponding
      catalytic_score value, which is what this plot always used before.'''

    # Break because file does not exist
    if not os.path.isfile(ALL_SCORES_CSV): return

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    all_scores_df['sequence'] = all_scores_df['sequence'].astype(str)
    all_scores_df['design_method'] = all_scores_df['design_method'].astype(str)
    all_scores_df['score_taken_from'] = all_scores_df['score_taken_from'].astype(str)

    # Break because not enough data
    if len(all_scores_df.dropna(subset=['total_score'])) < 3: return

    if (ProteinMPNN_PROB > 0 or LMPNN_PROB > 0):
        #first not nan sequence from all_scores_df (keep mut_max if there is none)
        available_sequences = all_scores_df[all_scores_df['sequence'] != 'nan']['sequence']
        if len(available_sequences) > 0:
            mut_max = len(available_sequences.iloc[0])

    # The efield plot shares the catalytic score settings unless told otherwise
    if efield_score_min is None: efield_score_min = catalytic_score_min
    if efield_score_max is None: efield_score_max = catalytic_score_max
    if efield_score_bin is None: efield_score_bin = catalytic_score_bin

    # Plot data
    fig, axs = plt.subplots(3, 4, figsize=(15, 9))

    all_scores_df = all_scores_df.dropna(subset=['total_score'])
    catalytic_scores, total_scores, interface_scores, efield_scores, combined_scores = normalize_scores(all_scores_df,
                                                                                         print_norm=True,
                                                                                         norm_all=True)

    # Top row: histograms of the normalized scores
    plot_combined_score(axs[0,0], combined_scores,
                        combined_score_min, combined_score_max, combined_score_bin)
    plot_interface_score(axs[0,1], interface_scores,
                         interface_score_min, interface_score_max, interface_score_bin)
    plot_total_score(axs[0,2], total_scores,
                     total_score_min, total_score_max, total_score_bin)
    plot_catalytic_score(axs[0,3], catalytic_scores,
                         catalytic_score_min, catalytic_score_max, catalytic_score_bin)
    plot_efield_score(axs[2,0], efield_scores,
                      efield_score_min, efield_score_max, efield_score_bin)

    # Remaining panels: development over indices and generations
    plot_boltzmann_histogram(axs[1,0], combined_scores, all_scores_df,
                             combined_score_min, combined_score_max, combined_score_bin)
    plot_combined_score_v_index(axs[1,1], combined_scores, all_scores_df)
    plot_combined_score_v_generation_violin(axs[1,2], combined_scores, all_scores_df)
    plot_mutations_v_generation_violin(axs[1,3], all_scores_df, mut_min, mut_max)
    plot_score_v_generation_violin(axs[2,1], 'total_score', all_scores_df)
    plot_score_v_generation_violin(axs[2,2], 'interface_score', all_scores_df)
    plot_score_v_generation_violin(axs[2,3], 'efield_score', all_scores_df)
    plt.tight_layout()
    plt.show()
    
    # fig, ax = plt.subplots(1, 1, figsize=(15, 2.5))
    # plot_tree(ax, combined_scores, all_scores_df, G)
    # plt.show()

def plot_summary():
    '''Draws a 3 x 2 summary figure from all_scores.csv.

    Panels (by row): total vs interface score colored by generation, and a stacked
    histogram by generation; stacked histograms by catalytic residue index (all
    generations / without generation 0); total vs interface score colored by
    catalytic residue index and by catalytic residue name.
    Three figure-level legends are placed in the free space on the right.
    The figure is not shown from inside this function.'''
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)

    fig, axs = plt.subplots(3, 2, figsize=(20, 23))

    # Normalize scores
    catalytic_scores, total_scores, interface_scores, efield_scores, combined_scores = normalize_scores(all_scores_df,
                                                                                         print_norm=True,
                                                                                         norm_all=True)

    # Plot interface vs total score colored by generation
    plot_interface_v_total_score_generation(axs[0,0], total_scores, interface_scores, all_scores_df['generation'])

    # Plot stacked histogram of interface scores by generation
    plot_stacked_histogram_by_generation(axs[0, 1], all_scores_df)

    # Plot stacked histogram of interface scores by catalytic residue index.
    # One color map is shared by both histograms, so a residue has the same color in each.
    # Entries of 'cat_resi' that are not a single number become NaN here.
    all_scores_df['cat_resi'] = pd.to_numeric(all_scores_df['cat_resi'], errors='coerce')
    unique_cat_resi = all_scores_df['cat_resi'].dropna().unique()
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_cat_resi)))
    color_map = {resi: colors[i] for i, resi in enumerate(unique_cat_resi)}
    plot_stacked_histogram_by_cat_resi(axs[1, 0], all_scores_df, color_map=color_map)

    # Plot stacked histogram of interface scores by catalytic residue index (excluding generation 0)
    all_scores_df_cleaned = all_scores_df[all_scores_df['generation'] != 0]  # Exclude generation 0
    plot_stacked_histogram_by_cat_resi(axs[1, 1], all_scores_df_cleaned, color_map=color_map)

    # Plot interface vs total score colored by catalytic residue index
    plot_interface_v_total_score_cat_resi(axs[2,0], total_scores, interface_scores, all_scores_df['cat_resi'])

    # Plot interface vs total score colored by catalytic residue name
    legend_elements = plot_interface_v_total_score_cat_resn(axs[2, 1], total_scores, interface_scores, all_scores_df['cat_resn'])

    # Titles are set last, so they replace any title set by the plotting helpers
    panel_titles = {
        (0, 0): 'Total Scores vs Interface Scores by Generation',
        (0, 1): 'Stacked Histogram of Interface Scores by Generation',
        (1, 0): 'Stacked Histogram of Interface Scores by Catalytic Residue Index',
        (1, 1): 'Stacked Histogram of Interface Scores by Catalytic Residue Index (Excluding Generation 0)',
        (2, 0): 'Total Scores vs Interface Scores by Catalytic Residue Index',
        (2, 1): 'Total Scores vs Interface Scores by Catalytic Residue Name',
    }
    for (row, column), title in panel_titles.items():
        axs[row, column].set_title(title)

    # Adjust legends
    # For cat_resi
    handles_cat_resi, labels_cat_resi = axs[1,1].get_legend_handles_labels()
    fig.legend(handles_cat_resi, labels_cat_resi, loc='upper right', bbox_to_anchor=(0.95, 0.65), title="Catalytic Residue Index")

    # For generation
    handles_generation, labels_generation = axs[0,1].get_legend_handles_labels()
    fig.legend(handles_generation, labels_generation, loc='upper right', bbox_to_anchor=(0.95, 0.98), title="Generation")

    # For cat_resn: uses the legend entries returned by the plotting function
    fig.legend(handles = legend_elements, loc='upper right', bbox_to_anchor=(0.95, 0.31), title="Catalytic Residue")

    # Adjust layout to make space for the legends on the right
    plt.tight_layout(rect=[0, 0, 0.85, 1])

def plot_interface_v_total_score_selection(ax, total_scores, interface_scores, selected_indices):
    """
    Plots a scatter plot of total_scores vs interface_scores and highlights the points
    corresponding to the selected indices.

    Unselected points are drawn in gray and selected points in red. Both axes are
    fixed to the range [0, 1].

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - total_scores (list or np.array): The total scores of the structures.
    - interface_scores (list or np.array): The interface scores of the structures.
    - selected_indices (list of int): Positional indices of the points to highlight.
    """
    # Coerce to arrays so that plain lists (allowed by the contract above) support
    # the boolean-mask and integer-list indexing used below.
    total_scores = np.asarray(total_scores)
    interface_scores = np.asarray(interface_scores)

    # True for every point that is NOT selected
    background = np.ones(len(total_scores), dtype=bool)
    background[selected_indices] = False

    # Background points first, so the highlighted ones are drawn on top
    ax.scatter(total_scores[background], interface_scores[background],
               color='gray', alpha=0.3, label='All Points', s=1)
    ax.scatter(total_scores[selected_indices], interface_scores[selected_indices],
               color='red', alpha=0.4, label='Selected Points', s=1)

    ax.set_title('Total Scores vs Interface Scores')
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Interface Score')
    ax.legend()
    ax.grid(True)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

def plot_interface_v_total_score_cat_resi(ax, total_scores, interface_scores, cat_resi):
    """
    Plots a scatter plot of total_scores vs interface_scores and colors the points
    according to the catalytic residue number (cat_resi), using categorical coloring.

    Entries whose cat_resi is NaN are not plotted. No legend is drawn on ax; the
    legend handles are returned so the caller can place them (e.g. in a figure legend).

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - total_scores (list or np.array): The total scores of the structures.
    - interface_scores (list or np.array): The interface scores of the structures.
    - cat_resi (pd.Series or np.array): Catalytic residue numbers for all data points.

    Returns:
    - list of matplotlib.lines.Line2D: One legend handle per unique catalytic residue number.
    """
    # Function-scope import kept from the original implementation
    from matplotlib.lines import Line2D

    # Ensure cat_resi is a pandas Series for easier handling and remove NaN values
    if not isinstance(cat_resi, pd.Series):
        cat_resi = pd.Series(cat_resi)
    cat_resi = cat_resi.dropna()

    # One tab20 color per distinct residue number, in order of first appearance
    unique_cat_resi = cat_resi.unique()
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_cat_resi)))
    color_map = {resi: colors[i] for i, resi in enumerate(unique_cat_resi)}
    cat_resi_colors = cat_resi.map(color_map).values

    # The index labels that survived dropna() select the matching score entries
    ax.scatter(total_scores[cat_resi.index], interface_scores[cat_resi.index],
               c=cat_resi_colors, alpha=0.4, s=2)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label=f'Cat Resi {resi}',
                              markerfacecolor=color_map[resi], markersize=10) for resi in unique_cat_resi]

    ax.set_title('Total Scores vs Interface Scores')
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Interface Score')
    ax.grid(True)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    return legend_elements

def plot_interface_v_total_score_cat_resn(ax, total_scores, interface_scores, cat_resn):
    """
    Plots a scatter plot of total_scores vs interface_scores and colors the points
    according to the catalytic residue name (cat_resn), using categorical coloring.

    Entries whose cat_resn is NaN are not plotted. No legend is drawn on ax; the
    legend handles are returned so the caller can place them (e.g. in a figure legend).

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - total_scores (list or np.array): The total scores of the structures.
    - interface_scores (list or np.array): The interface scores of the structures.
    - cat_resn (pd.Series or np.array): Catalytic residue names for all data points.

    Returns:
    - list of matplotlib.lines.Line2D: One legend handle per unique catalytic residue name.
    """
    # Function-scope import kept from the original implementation
    from matplotlib.lines import Line2D

    # Ensure cat_resn is a pandas Series for easier handling and remove NaN values
    if not isinstance(cat_resn, pd.Series):
        cat_resn = pd.Series(cat_resn)
    cat_resn = cat_resn.dropna()

    # One tab20 color per distinct residue name, in order of first appearance
    unique_cat_resn = cat_resn.unique()
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_cat_resn)))
    color_map = {resn: colors[i] for i, resn in enumerate(unique_cat_resn)}
    cat_resn_colors = cat_resn.map(color_map).values

    # The index labels that survived dropna() select the matching score entries
    ax.scatter(total_scores[cat_resn.index], interface_scores[cat_resn.index],
               c=cat_resn_colors, alpha=0.4, s=2)

    legend_elements = [Line2D([0], [0], marker='o', color='w', label=f'Cat Resn {resn}',
                              markerfacecolor=color_map[resn], markersize=10) for resn in unique_cat_resn]

    ax.set_title('Total Scores vs Interface Scores')
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Interface Score')
    ax.grid(True)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    return legend_elements

def plot_interface_v_total_score_generation(ax, total_scores, interface_scores, generation):
    """
    Scatter total_scores against interface_scores, one color per generation.

    Colors are taken from the viridis colormap in order of first appearance of each
    generation. Generation 0 is drawn more transparently (alpha 0.2 instead of 0.8)
    and is the only series that receives a label. No legend is drawn here.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - total_scores (list or np.array): The total scores of the structures.
    - interface_scores (list or np.array): The interface scores of the structures.
    - generation (pd.Series or np.array): Generation numbers for all data points.
    """
    gen_series = generation if isinstance(generation, pd.Series) else pd.Series(generation)
    gen_series = gen_series.dropna()  # entries without a generation are not plotted

    gen_values = gen_series.unique()
    palette = plt.cm.viridis(np.linspace(0, 1, len(gen_values)))

    kept = gen_series.index
    for rgba, gen in zip(palette, gen_values):
        in_generation = gen_series == gen
        is_initial = gen == 0
        ax.scatter(
            total_scores[kept][in_generation],
            interface_scores[kept][in_generation],
            c=[rgba],
            alpha=0.2 if is_initial else 0.8,
            s=2,
            label=f'Generation {gen}' if is_initial else None,
        )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.grid(True)

    ax.set_title('Total vs Interface Scores - Generation', fontsize=18)
    ax.set_xlabel('Total Score', fontsize=16)
    ax.set_ylabel('Interface Score', fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=16)

def plot_stacked_histogram_by_cat_resi(ax, all_scores_df, color_map=None, show_legend=False, xlim=(-32.5, -13.5)):
    """
    Plots a stacked bar plot of interface scores colored by cat_resi on the given Axes object,
    where each bar's segments represent counts of different cat_resi values in that bin.

    Rows with a missing 'interface_score' or a missing / non-numeric 'cat_resi' are ignored.
    The input DataFrame is not modified. If no rows remain, only the labels and limits are set.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - all_scores_df (pd.DataFrame): DataFrame containing 'cat_resi' and 'interface_score' columns.
    - color_map (dict): Optional; A dictionary mapping catalytic residue indices to colors.
      When None, colors are generated from the tab20 colormap.
    - show_legend (bool): Optional; Whether to show the legend. Defaults to False.
    - xlim (tuple or None): Optional; (left, right) x-axis limits. Pass None to keep
      matplotlib's automatic limits. Defaults to (-32.5, -13.5).
    """
    # Work on a copy; assigning the coerced column and THEN dropping rows is what
    # actually removes the NaN entries (assigning a dropna()'d Series re-aligns
    # on the index and silently puts the NaNs back).
    scores = all_scores_df.dropna(subset=['interface_score']).copy()
    scores['cat_resi'] = pd.to_numeric(scores['cat_resi'], errors='coerce')
    scores = scores.dropna(subset=['cat_resi'])
    unique_cat_resi = scores['cat_resi'].unique()

    # Use the provided color_map or generate a new one
    if color_map is None:
        colors = plt.cm.tab20(np.linspace(0, 1, len(unique_cat_resi)))
        color_map = {resi: colors[i] for i, resi in enumerate(unique_cat_resi)}

    if not scores.empty:
        # 20 equal-width bins spanning the observed interface scores
        bins = np.linspace(scores['interface_score'].min(), scores['interface_score'].max(), 21)
        bin_centers = 0.5 * (bins[:-1] + bins[1:])

        # Stack one bar series per catalytic residue index
        bottom = np.zeros(len(bin_centers))
        for resi in unique_cat_resi:
            resi_scores = scores.loc[scores['cat_resi'] == resi, 'interface_score']
            counts = np.histogram(resi_scores, bins=bins)[0]
            ax.bar(bin_centers, counts, bottom=bottom, width=np.diff(bins),
                   label=f'Cat Resi {resi}', color=color_map[resi], align='center')
            bottom += counts

    if show_legend:
        # Custom handles: one round marker per catalytic residue index
        legend_elements = [Line2D([0], [0], marker='o', color='w', label=f'Cat Resi {resi}',
                                  markerfacecolor=color_map[resi], markersize=10) for resi in unique_cat_resi]
        ax.legend(handles=legend_elements, title="Catalytic Residue")

    ax.set_title('Stacked Histogram of Interface Scores')
    ax.set_xlabel('Interface Score')
    ax.set_ylabel('Count')

    if xlim is not None:
        ax.set_xlim(*xlim)

def plot_stacked_histogram_by_cat_resn(ax, all_scores_df, xlim=(-32.5, -13.5)):
    """
    Plots a stacked bar plot of interface scores colored by cat_resn on the given Axes object,
    where each bar's segments represent counts of different cat_resn values in that bin.

    Rows with a missing 'interface_score' or 'cat_resn' are ignored. The input DataFrame
    is not modified. If no rows remain, only the labels and limits are set.
    No legend is drawn; every bar series carries a 'Cat Resn <name>' label.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - all_scores_df (pd.DataFrame): DataFrame containing 'cat_resn' and 'interface_score' columns.
    - xlim (tuple or None): Optional; (left, right) x-axis limits. Pass None to keep
      matplotlib's automatic limits. Defaults to (-32.5, -13.5).
    """
    # Drop missing names BEFORE the string conversion: astype(str) would turn NaN
    # into the literal 'nan', which would then show up as its own category.
    scores = all_scores_df.dropna(subset=['interface_score', 'cat_resn']).copy()
    scores['cat_resn'] = scores['cat_resn'].astype(str)
    unique_cat_resn = scores['cat_resn'].unique()

    # One tab20 color per residue name, in order of first appearance
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_cat_resn)))
    color_map = {resn: colors[i] for i, resn in enumerate(unique_cat_resn)}

    if not scores.empty:
        # 20 equal-width bins spanning the observed interface scores
        bins = np.linspace(scores['interface_score'].min(), scores['interface_score'].max(), 21)
        bin_centers = 0.5 * (bins[:-1] + bins[1:])

        # Stack one bar series per catalytic residue name
        bottom = np.zeros(len(bin_centers))
        for resn in unique_cat_resn:
            resn_scores = scores.loc[scores['cat_resn'] == resn, 'interface_score']
            counts = np.histogram(resn_scores, bins=bins)[0]
            ax.bar(bin_centers, counts, bottom=bottom, width=np.diff(bins),
                   label=f'Cat Resn {resn}', color=color_map[resn], align='center')
            bottom += counts

    if xlim is not None:
        ax.set_xlim(*xlim)

    ax.set_title('Stacked Histogram of Interface Scores by Catalytic Residue')
    ax.set_xlabel('Interface Score')
    ax.set_ylabel('Count')

def plot_stacked_histogram_by_generation(ax, all_scores_df, xlim=(-32.5, -13.5)):
    """
    Plots a stacked bar plot of interface scores colored by generation on the given Axes object,
    where each bar's segments represent counts of different generation values in that bin.

    Rows with a missing 'interface_score' or a missing / non-numeric 'generation' are ignored.
    The input DataFrame is not modified. If no rows remain, only the labels and limits are set.
    No legend is drawn; every bar series carries a 'Generation <n>' label.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - all_scores_df (pd.DataFrame): DataFrame containing 'generation' and 'interface_score' columns.
    - xlim (tuple or None): Optional; (left, right) x-axis limits. Pass None to keep
      matplotlib's automatic limits. Defaults to (-32.5, -13.5).
    """
    scores = all_scores_df.dropna(subset=['interface_score', 'generation']).copy()
    # Coerce first, then drop: values that cannot be parsed become NaN and must be
    # removed as rows (assigning a dropna()'d Series would re-align and keep them).
    scores['generation'] = pd.to_numeric(scores['generation'], errors='coerce')
    scores = scores.dropna(subset=['generation'])
    unique_generations = scores['generation'].unique()

    # One viridis color per generation, in order of first appearance
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_generations)))
    color_map = {gen: colors[i] for i, gen in enumerate(unique_generations)}

    if not scores.empty:
        # 20 equal-width bins spanning the observed interface scores
        bins = np.linspace(scores['interface_score'].min(), scores['interface_score'].max(), 21)
        bin_centers = 0.5 * (bins[:-1] + bins[1:])

        # Stack one bar series per generation
        bottom = np.zeros(len(bin_centers))
        for gen in unique_generations:
            gen_scores = scores.loc[scores['generation'] == gen, 'interface_score']
            counts = np.histogram(gen_scores, bins=bins)[0]
            ax.bar(bin_centers, counts, bottom=bottom, width=np.diff(bins),
                   label=f'Generation {gen}', color=color_map[gen], align='center')
            bottom += counts

    ax.set_title('Stacked Histogram of Interface Scores by Generation')
    ax.set_xlabel('Interface Score')
    ax.set_ylabel('Count')

    if xlim is not None:
        ax.set_xlim(*xlim)

def plot_combined_score(ax, combined_scores, score_min, score_max, score_bin):
    """
    Draws a histogram of the combined scores on the given Axes object.

    Vertical lines mark HIGHSCORE (blue) and NEG_BEST (red).

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - combined_scores (array-like): Combined scores to bin.
    - score_min (float): Left edge of the first bin and lower x-axis limit.
    - score_max (float): Upper x-axis limit; bin edges run up to score_max.
    - score_bin (float): Bin width.
    """
    bin_edges = np.arange(score_min, score_max + score_bin, score_bin)
    ax.hist(combined_scores, bins=bin_edges)

    # Reference lines
    for reference, line_color in ((HIGHSCORE, 'b'), (NEG_BEST, 'r')):
        ax.axvline(reference, color=line_color)

    ax.set_xlim(score_min, score_max)
    ax.set_title('Histogram of Score')
    ax.set_xlabel('Combined Score')
    ax.set_ylabel('Frequency')
    
def plot_interface_score(ax, interface_scores, interface_score_min, interface_score_max, interface_score_bin):
    """
    Draws a density-normalized histogram of the interface scores on the given Axes object.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - interface_scores (array-like): Interface scores to bin.
    - interface_score_min (float): Left edge of the first bin and lower x-axis limit.
    - interface_score_max (float): Upper x-axis limit; bin edges run up to this value.
    - interface_score_bin (float): Bin width.
    """
    bin_edges = np.arange(interface_score_min,
                          interface_score_max + interface_score_bin,
                          interface_score_bin)
    ax.hist(interface_scores, density=True, bins=bin_edges)

    ax.set_title('Histogram of Interface Score')
    ax.set_xlabel('Interface Score')
    ax.set_ylabel('Frequency')
    ax.set_xlim(interface_score_min, interface_score_max)
    
def plot_total_score(ax, total_scores, total_score_min, total_score_max, total_score_bin):
    """
    Draws a density-normalized histogram of the total scores on the given Axes object.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - total_scores (array-like): Total scores to bin.
    - total_score_min (float): Left edge of the first bin and lower x-axis limit.
    - total_score_max (float): Upper x-axis limit; bin edges run up to this value.
    - total_score_bin (float): Bin width.
    """
    upper_edge = total_score_max + total_score_bin
    histogram_bins = np.arange(total_score_min, upper_edge, total_score_bin)

    ax.hist(total_scores, density=True, bins=histogram_bins)
    ax.set_xlim(total_score_min, total_score_max)

    ax.set_xlabel('Total Score')
    ax.set_ylabel('Frequency')
    ax.set_title('Histogram of Total Score')
    
def plot_catalytic_score(ax, catalytic_scores, total_score_min, total_score_max, total_score_bin):
    """
    Draws a density-normalized histogram of the catalytic scores on the given Axes object.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - catalytic_scores (array-like): Catalytic scores to bin.
    - total_score_min (float): Left edge of the first bin and lower x-axis limit.
    - total_score_max (float): Upper x-axis limit; bin edges run up to this value.
    - total_score_bin (float): Bin width.
    """
    lo, hi, step = total_score_min, total_score_max, total_score_bin

    ax.hist(catalytic_scores, density=True, bins=np.arange(lo, hi + step, step))
    ax.set_xlim(lo, hi)

    ax.set_title('Histogram of Catalytic Score')
    ax.set_ylabel('Frequency')
    ax.set_xlabel('Catalytic Score')
    
def plot_efield_score(ax, efield_scores, total_score_min, total_score_max, total_score_bin):
    """
    Draws a density-normalized histogram of the efield scores on the given Axes object.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - efield_scores (array-like): Efield scores to bin.
    - total_score_min (float): Left edge of the first bin and lower x-axis limit.
    - total_score_max (float): Upper x-axis limit; bin edges run up to this value.
    - total_score_bin (float): Bin width.
    """
    edges = np.arange(total_score_min, total_score_max + total_score_bin, total_score_bin)
    x_range = (total_score_min, total_score_max)

    ax.hist(efield_scores, density=True, bins=edges)
    ax.set_xlim(*x_range)

    ax.set_title('Histogram of Efield Score')
    ax.set_xlabel('Efield Score')
    ax.set_ylabel('Frequency')

def plot_boltzmann_histogram(ax, combined_scores, all_scores_df, score_min, score_max, score_bin):
    """
    Compares uniform random sampling with Boltzmann-weighted sampling of the combined potentials.

    The potentials come from normalize_scores(all_scores_df, ..., extension="score").
    10000 values are drawn with replacement twice: once uniformly and once with
    probabilities proportional to exp(potential / kbT). Both samples are shown as
    density histograms; the Boltzmann sample uses a twin y-axis.

    kbT is taken from KBT_BOLTZMANN: a plain number is used as is; a 2-element
    sequence (k0, rate) gives max(k0 * exp(-rate * max design index), 0.05).

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - combined_scores: Not used by this function; kept for interface compatibility.
    - all_scores_df (pd.DataFrame): Scores table; its 'index' column is read when
      KBT_BOLTZMANN is a 2-element sequence.
    - score_min (float): Histogram bins start at score_min - 2.
    - score_max (float): Histogram bin edges run up to score_max + 1.
    - score_bin (float): Bin width.

    Raises:
    - ValueError: If KBT_BOLTZMANN is a sequence whose length is not 2.
    """
    _, _, _, _, combined_potentials = normalize_scores(all_scores_df, print_norm=False, norm_all=False, extension="score")

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        # Decaying temperature, floored at 0.05
        kbt_boltzmann = max(KBT_BOLTZMANN[0] * np.exp(-KBT_BOLTZMANN[1] * all_scores_df['index'].max()), 0.05)
    else:
        # Previously this case fell through and failed later with an UnboundLocalError
        raise ValueError(
            f"KBT_BOLTZMANN must be a number or a 2-element sequence, got {KBT_BOLTZMANN!r}"
        )

    boltzmann_factors = np.exp(combined_potentials / (kbt_boltzmann))
    print(f"Min/Max boltzmann factors: {min(boltzmann_factors)}, {max(boltzmann_factors)}")
    probabilities = boltzmann_factors / sum(boltzmann_factors)

    random_scores = np.random.choice(combined_potentials, size=10000, replace=True)
    boltzmann_scores = np.random.choice(combined_potentials, size=10000, replace=True, p=probabilities)

    # Both histograms share the same bin edges
    bins = np.arange(score_min - 2, score_max + 1 + score_bin, score_bin)

    # Plot the first histogram
    ax.hist(random_scores, density=True, alpha=0.7, label='Random Sampling', bins=bins)
    # NOTE(review): the text position is given in data coordinates; if axes-relative
    # placement was intended, transform=ax.transAxes would be needed - confirm.
    ax.text(0.05, 0.95, "normalized only to \n this dataset")
    ax.set_xlabel('Potential')
    ax.set_ylabel('Density (Normal)')
    ax.set_title(f'kbT = {kbt_boltzmann:.1e}')

    # Create a twin y-axis for the second histogram
    ax_dup = ax.twinx()
    ax_dup.hist(boltzmann_scores, density=True, alpha=0.7, color='orange', label='Boltzmann Sampling', bins=bins)
    ax.set_xlim(score_min - 2, score_max + 1)
    ax_dup.set_ylabel('Density (Boltzmann)')
    ax_dup.tick_params(axis='y', labelcolor='orange')

def plot_interface_score_v_total_score(ax, all_scores_df,
                                       total_score_min, total_score_max, interface_score_min, interface_score_max):
    """
    Scatters interface score against total score, colored by design index, with a linear trend line.

    The Pearson correlation between the two columns is shown in the title and a
    first-degree polynomial fit is drawn in black across the observed total-score range.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - all_scores_df (pd.DataFrame): DataFrame containing 'total_score', 'interface_score' and 'index' columns.
    - total_score_min (float): Lower x-axis limit.
    - total_score_max (float): Upper x-axis limit.
    - interface_score_min (float): Lower y-axis limit.
    - interface_score_max (float): Upper y-axis limit.
    """
    x_values = all_scores_df['total_score']
    y_values = all_scores_df['interface_score']

    ax.scatter(x_values, y_values, c=all_scores_df['index'], cmap='coolwarm_r', s=5)

    pearson_r, _p_value = pearsonr(x_values, y_values)

    # Least-squares straight line, evaluated on 100 points over the data range
    trend = np.poly1d(np.polyfit(x_values, y_values, 1))
    trend_x = np.linspace(x_values.min(), x_values.max(), 100)
    ax.plot(trend_x, trend(trend_x), "k")

    ax.set_title(f'Pearson r: {pearson_r:.2f}')
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Interface Score')
    ax.set_xlim(total_score_min, total_score_max)
    ax.set_ylim(interface_score_min, interface_score_max)

def plot_combined_score_v_index(ax, combined_scores, all_scores_df):
    """
    Plots the combined score of every design against its index, with a 20-point moving average.

    Horizontal lines mark HIGHSCORE (blue) and NEG_BEST (red). A vertical black line is
    drawn at N_PARENT_JOBS times the number of '.pdb' files found in
    FOLDER_HOME/FOLDER_PARENT.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - combined_scores (array-like): Combined score per design.
    - all_scores_df (pd.DataFrame): DataFrame containing an 'index' column.
    """
    score_series = pd.Series(combined_scores)
    rolling_mean = score_series.rolling(window=20).mean()

    ax.scatter(all_scores_df['index'], score_series, c='lightgrey', s=5)
    ax.axhline(HIGHSCORE, color='b', alpha = 0.5)
    ax.axhline(NEG_BEST, color='r', alpha = 0.5)

    # Count the parent structures on disk
    parent_dir = f'{FOLDER_HOME}/{FOLDER_PARENT}'
    parent_pdbs = [entry for entry in os.listdir(parent_dir) if entry.endswith(".pdb")]
    ax.axvline(N_PARENT_JOBS * len(parent_pdbs), color='k')

    # The moving average is plotted against its position, not the 'index' column
    ax.plot(range(len(rolling_mean)), rolling_mean, c="k")

    ax.set_ylim(0, 1)
    ax.set_xlim(0, MAX_DESIGNS)
    ax.set_title('Score vs Index')
    ax.set_xlabel('Index')
    ax.set_ylabel('Combined Score')
    
def plot_combined_score_v_generation(ax, combined_scores, all_scores_df):
    """
    Draws one box plot of combined scores per generation on the given Axes object.

    Horizontal lines mark HIGHSCORE (blue) and NEG_BEST (red). The y-axis is fixed
    to [0, 1]. The input DataFrame is not modified.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - combined_scores (array-like): Combined score per row of all_scores_df; rows with a NaN score are skipped.
    - all_scores_df (pd.DataFrame): DataFrame containing a 'generation' column.
    """
    # assign() works on a copy, so the temporary column never leaks into the caller's DataFrame
    scored_df = all_scores_df.assign(tmp=combined_scores).dropna(subset=['tmp'])

    max_gen = int(scored_df['generation'].max())
    boxplot_data = [scored_df.loc[scored_df['generation'] == gen, 'tmp'] for gen in range(max_gen + 1)]

    # Define properties for the outliers
    flierprops = dict(marker='o', markerfacecolor='green', markersize=3, linestyle='none')

    ax.boxplot(boxplot_data, positions=range(len(boxplot_data)), flierprops=flierprops)
    ax.axhline(HIGHSCORE, color='b')
    ax.axhline(NEG_BEST, color='r')
    ax.set_xticks(range(len(boxplot_data)))
    ax.set_xticklabels(range(len(boxplot_data)))
    ax.set_ylim(0, 1)
    ax.set_title('Combined Score vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Combined Score')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)

def exponential_func(x, A, k, c):
    """
    Saturating exponential model: c - A * exp(-k * x).

    Parameters:
    - x (float or np.array): Point(s) at which to evaluate the model.
    - A (float): Amplitude of the exponential term.
    - k (float): Rate constant in the exponent.
    - c (float): Offset from which the exponential term is subtracted.

    Returns:
    - float or np.array: Model value(s) at x.
    """
    decay = np.exp(-k * x)
    return c - A * decay

def plot_combined_score_v_generation_violin(ax, combined_scores, all_scores_df):
    """
    Draws one violin of combined scores per generation and overlays an exponential fit of the means.

    The per-generation mean scores are fitted with exponential_func (c - A*exp(-k*x)).
    If the fit cannot be computed (e.g. fewer generations than fit parameters, or no
    convergence), a warning is logged and the plot is drawn without the fitted curve.
    Horizontal lines mark HIGHSCORE (blue) and NEG_BEST (red). The y-axis is fixed to
    [0, 1]. The input DataFrame is not modified.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - combined_scores (array-like): Combined score per row of all_scores_df; rows with a NaN score are skipped.
    - all_scores_df (pd.DataFrame): DataFrame containing a 'generation' column.
    """
    # assign() works on a copy, so the temporary column never leaks into the caller's DataFrame
    scored_df = all_scores_df.assign(tmp=combined_scores).dropna(subset=['tmp'])

    max_gen = int(scored_df['generation'].max())
    generations = np.arange(0, max_gen + 1)
    violin_data = [scored_df.loc[scored_df['generation'] == gen, 'tmp'] for gen in generations]

    # Create violin plots
    parts = ax.violinplot(violin_data, positions=generations, showmeans=False, showmedians=True)

    # Customizing the color of violin plots
    for pc in parts['bodies']:
        pc.set_facecolor('green')
        pc.set_edgecolor('black')
        pc.set_alpha(0.7)

    # Thin lines for bars and extrema, a thicker one for the medians
    for partname, linewidth in (('cbars', 0.5), ('cmins', 0.5), ('cmaxes', 0.5), ('cmedians', 2.0)):
        vp = parts.get(partname)
        if vp:
            vp.set_edgecolor('tomato')
            vp.set_linewidth(linewidth)

    # Fit the per-generation means to the exponential model (uniform weights)
    weights = np.ones(len(generations))
    mean_scores = [np.mean(data) for data in violin_data]
    try:
        popt, _ = curve_fit(exponential_func, generations, mean_scores,
                            p0=(1, 0.1, 0.7), sigma=weights, maxfev=2000)
    except (RuntimeError, TypeError, ValueError) as fit_error:
        # A failed fit should not prevent the rest of the figure from being drawn
        logging.warning("Exponential fit of combined score vs generation skipped: %s", fit_error)
    else:
        fitted_curve = exponential_func(generations, *popt)
        # Label matches the model actually fitted by exponential_func
        ax.plot(generations, fitted_curve, 'r--',
                label=f'Fit: c - A*exp(-kt)\nA={popt[0]:.2f}, k={popt[1]:.2f}, c={popt[2]:.2f}')

    ax.axhline(HIGHSCORE, color='b', label='High Score', alpha = 0.5)
    ax.axhline(NEG_BEST, color='r', label='Negative Best', alpha = 0.5)

    # Select every second generation for ticks
    every_second_generation = generations[::2]
    ax.set_xticks(every_second_generation)
    ax.set_xticklabels(every_second_generation)

    ax.set_ylim(0, 1)
    ax.set_title('Combined Score vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Combined Score')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)
    ax.legend(loc='lower right')

def plot_score_v_generation_violin(ax, score_type, all_scores_df):
    """
    Draws one violin per generation for the given score column.

    Rows where the score is NaN are skipped. Every fourth generation gets an x tick.
    The title and y label are derived from score_type (underscores replaced by
    spaces, title-cased).

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - score_type (str): Name of the score column in all_scores_df to plot.
    - all_scores_df (pd.DataFrame): DataFrame containing 'generation' and the score_type column.
    """
    scored = all_scores_df.dropna(subset=[score_type])

    last_generation = int(scored['generation'].max())
    generations = np.arange(0, last_generation + 1)
    per_generation = [scored.loc[scored['generation'] == gen, score_type] for gen in generations]

    parts = ax.violinplot(per_generation, positions=generations, showmeans=False, showmedians=True)

    # Violin bodies: green fill with black outline
    for body in parts['bodies']:
        body.set_facecolor('green')
        body.set_edgecolor('black')
        body.set_alpha(0.7)

    # Thin lines for bars and extrema, a thicker one for the medians
    line_widths = (('cbars', 0.5), ('cmins', 0.5), ('cmaxes', 0.5), ('cmedians', 2.0))
    for part_key, width in line_widths:
        lines = parts.get(part_key)
        if lines:
            lines.set_edgecolor('tomato')
            lines.set_linewidth(width)

    # Select every fourth generation for ticks
    tick_generations = generations[::4]
    ax.set_xticks(tick_generations)
    ax.set_xticklabels(tick_generations)

    pretty_name = score_type.replace("_", " ").title()
    ax.set_title(f'{pretty_name} vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel(f'{pretty_name}')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)

    
def plot_mutations_v_generation(ax, all_scores_df,  mut_min, mut_max):
    """
    Draws one box plot of the number of mutations per generation.

    A red horizontal line is drawn at the number of comma-separated entries in DESIGN.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - all_scores_df (pd.DataFrame): DataFrame containing 'generation' and 'mutations' columns;
      rows with NaN 'mutations' are skipped.
    - mut_min (float): Lower y-axis limit.
    - mut_max (float): Upper y-axis limit.
    """
    with_mutations = all_scores_df.dropna(subset=['mutations'])

    last_generation = int(with_mutations['generation'].max())
    per_generation = [
        with_mutations.loc[with_mutations['generation'] == gen, 'mutations']
        for gen in range(last_generation + 1)
    ]
    box_positions = range(len(per_generation))

    # Small red dots for the outliers
    outlier_style = {'marker': 'o', 'markerfacecolor': 'red', 'markersize': 3, 'linestyle': 'none'}

    ax.boxplot(per_generation, positions=box_positions, flierprops=outlier_style)
    ax.axhline(len(DESIGN.split(",")), color='r')

    ax.set_xticks(box_positions)
    ax.set_xticklabels(box_positions)
    ax.set_ylim(mut_min, mut_max)

    ax.set_title('Mutations vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Number of Mutations')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)

def plot_mutations_v_generation_violin(ax, all_scores_df, mut_min, mut_max):
    """
    Draws one violin of the number of mutations per generation.

    A red horizontal line labelled 'Design Length' is drawn at the number of
    comma-separated entries in DESIGN. Every second generation gets an x tick.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to plot on.
    - all_scores_df (pd.DataFrame): DataFrame containing 'generation' and 'mutations' columns;
      rows with NaN 'mutations' are skipped.
    - mut_min (float): Lower y-axis limit.
    - mut_max (float): Upper y-axis limit.
    """
    with_mutations = all_scores_df.dropna(subset=['mutations'])

    last_generation = int(with_mutations['generation'].max())
    generations = np.arange(0, last_generation + 1)
    per_generation = [
        with_mutations.loc[with_mutations['generation'] == gen, 'mutations']
        for gen in generations
    ]

    parts = ax.violinplot(per_generation, positions=generations, showmeans=False, showmedians=True)

    # Violin bodies: green fill with black outline
    for body in parts['bodies']:
        body.set_facecolor('green')
        body.set_edgecolor('black')
        body.set_alpha(0.7)

    # Thin lines for bars and extrema, a thicker one for the medians
    line_widths = (('cbars', 0.5), ('cmins', 0.5), ('cmaxes', 0.5), ('cmedians', 2.0))
    for part_key, width in line_widths:
        lines = parts.get(part_key)
        if lines:
            lines.set_edgecolor('tomato')
            lines.set_linewidth(width)

    ax.axhline(len(DESIGN.split(",")), color='r', label='Design Length')

    # Select every second generation for ticks
    tick_generations = generations[::2]
    ax.set_xticks(tick_generations)
    ax.set_xticklabels(tick_generations)

    ax.set_ylim(mut_min, mut_max)
    ax.set_title('Mutations vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Number of Mutations')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)
    ax.legend(loc='lower right')

def plot_delta_scores():
    """
    Plots, per generation, how much each design's scores differ from those of its parent.

    Reads ALL_SCORES_CSV, drops rows without a 'total_score', computes the combined
    score with normalize_scores, and derives 'delta_combined', 'delta_total',
    'delta_interface' and 'delta_efield' as (design score - parent score). The parent
    is the first row whose 'index' equals the design's 'parent_index'. Rows whose
    'parent_index' is "Parent" get a delta of 0. The four deltas are shown as violin
    plots in a 2x2 figure, which is displayed with plt.show().

    Designs whose parent cannot be resolved (unparsable 'parent_index', or a parent
    that is not among the remaining rows) get NaN deltas and are left out of the violins.
    """
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    all_scores_df = all_scores_df.dropna(subset=['total_score'])

    # Calculate combined scores using normalized scores
    _, _, _, _, combined_scores = normalize_scores(all_scores_df, print_norm=True, norm_all=True)
    all_scores_df['combined_score'] = combined_scores

    # Resolve every design's parent once instead of scanning the table for each row
    is_child = all_scores_df['parent_index'] != "Parent"
    parent_ids = pd.to_numeric(all_scores_df['parent_index'].where(is_child), errors='coerce')
    parent_ids = np.trunc(parent_ids)  # same truncation as int(float(...))
    # First occurrence per design index, used as the lookup table for parent scores
    first_rows = all_scores_df.drop_duplicates(subset='index').set_index('index')

    score_names = ('combined', 'total', 'interface', 'efield')
    for score_name in score_names:
        column = f'{score_name}_score'
        parent_values = parent_ids.map(first_rows[column])
        delta = all_scores_df[column] - parent_values
        # Parent rows have no ancestor to compare against: their delta is defined as 0
        all_scores_df[f'delta_{score_name}'] = delta.where(is_child, 0)

    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    def plot_violin(ax, delta_scores, title, scores_df):
        """Draw one violin of delta_scores per generation on ax."""
        # assign() works on a copy, so no temporary column is left on scores_df
        scored = scores_df.assign(tmp=delta_scores).dropna(subset=['tmp'])

        max_gen = int(scored['generation'].max())
        generations = np.arange(0, max_gen + 1)
        violin_data = [scored.loc[scored['generation'] == gen, 'tmp'] for gen in generations]

        # Create violin plots
        parts = ax.violinplot(violin_data, positions=generations, showmeans=False, showmedians=True)

        # Customizing the color of violin plots
        for pc in parts['bodies']:
            pc.set_facecolor('green')
            pc.set_edgecolor('black')
            pc.set_alpha(0.7)

        # Thin lines for bars and extrema, a thicker one for the medians
        for partname, linewidth in (('cbars', 0.5), ('cmins', 0.5), ('cmaxes', 0.5), ('cmedians', 2.0)):
            vp = parts.get(partname)
            if vp:
                vp.set_edgecolor('tomato')
                vp.set_linewidth(linewidth)

        ax.set_title(title)
        ax.set_xlabel('Generation')
        ax.set_ylabel('Delta Score')
        ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)

    # One panel per delta, filled row by row
    for ax, score_name in zip(axs.flat, score_names):
        plot_violin(ax, all_scores_df[f'delta_{score_name}'],
                    f'Delta {score_name.title()} Score vs Generations', all_scores_df)

    plt.tight_layout()
    plt.show()


def plot_tree_lin(leaf_nodes=None):
    """Plot the largest design lineage with edges coloured by combined potential.

    Reads ``ALL_SCORES_CSV``, builds a directed parent -> child graph whose node
    ids are the DataFrame row labels, keeps the largest weakly connected
    component, lays it out recursively from its root and draws the edges.

    Parameters
    ----------
    leaf_nodes : iterable of node ids, optional
        When given, the graph is restricted to these nodes and all of their
        ancestors before the largest component is selected.
    """
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    _, _, _, _, combined_potentials = normalize_scores(all_scores_df, print_norm=False, norm_all=False, extension="potential")

    G = nx.DiGraph()

    for idx, row in all_scores_df.iterrows():
        G.add_node(idx, sequence=row['sequence'], interface_potential=row['interface_potential'], gen=int(row['generation']))
        if row['parent_index'] != "Parent":
            parent_idx = int(float(row['parent_index']))
            # Direct label lookup instead of scanning the whole frame for every row
            parent_sequence = all_scores_df.at[parent_idx, 'sequence']
            # Store the Hamming distance to the parent as an edge attribute
            distance = hamming_distance(parent_sequence, row['sequence'])
            G.add_edge(parent_idx, idx, hamming_distance=distance)

    if leaf_nodes is not None:
        subgraph_nodes = set()
        for leaf in leaf_nodes:
            subgraph_nodes.update(nx.ancestors(G, leaf))
            subgraph_nodes.add(leaf)
        G = G.subgraph(subgraph_nodes)

    # Keep only the largest (weakly) connected component
    connected_components = list(nx.connected_components(G.to_undirected()))
    largest_component = max(connected_components, key=len)
    G_largest = G.subgraph(largest_component)

    def set_node_positions(G, node, pos, x, y, counts):
        """Place `node` at (x, y) and stack its subtrees one column further along x."""
        pos[node] = (x, y)
        next_y = y - counts[node] / 2
        for neighbor in G.successors(node):
            set_node_positions(G, neighbor, pos, x + 1, next_y + counts[neighbor] / 2, counts)
            next_y += counts[neighbor]

    def count_descendants(G, node, counts):
        """Store in `counts` the size of the subtree rooted at `node` (node included)."""
        count = 1
        for neighbor in G.successors(node):
            count += count_descendants(G, neighbor, counts)
        counts[node] = count
        return count

    # The layout must start at the root, i.e. the node without a parent inside the
    # component. Taking an arbitrary element of the component set is not guaranteed
    # to return it.
    root_node = next(node for node in G_largest.nodes if G_largest.in_degree(node) == 0)

    counts = {}
    count_descendants(G_largest, root_node, counts)

    pos = {}
    set_node_positions(G_largest, root_node, pos, 0, 0, counts)
    y_values = [y for x, y in pos.values()]
    y_span = max(y_values) - min(y_values)
    print(y_span)

    # Normalise the potentials to [0, 1] using every entry except the first one,
    # which is masked out; NaN entries end up as 0 and are drawn in black below.
    colors = np.array(combined_potentials, dtype=float)
    colors[0] = np.nan
    color_min = np.nanmin(colors[1:])
    color_max = np.nanmax(colors[1:])
    normed_colors = np.nan_to_num((colors - color_min) / (color_max - color_min), nan=0)
    normed_colors = normed_colors**2

    # Scale x to [0, 2*pi]; computed once, and guarded against a single-column layout
    x_values = [x for x, _ in pos.values()]
    x_range = (max(x_values) - min(x_values)) or 1
    polar_pos = {node: ((x / x_range) * 2 * np.pi, y) for node, (x, y) in pos.items()}

    # NOTE(review): the scaled x value is used as the radius and the raw y value as
    # the angle in radians below. This looks swapped relative to how polar_pos is
    # built -- confirm the intended mapping. Behaviour is kept unchanged here.
    cartesian_pos = {node: (radius * np.cos(angle), radius * np.sin(angle)) for node, (radius, angle) in polar_pos.items()}

    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)

    # Draw every edge, coloured and weighted by the normalised potential of its child
    for start, end in G_largest.edges():
        color = plt.cm.coolwarm_r(normed_colors[end])
        if float(normed_colors[end]) == 0.0:
            color = [0., 0., 0., 1.]
        linewidth = 0.1 + 2 * normed_colors[end] * 0.01

        x0, y0 = cartesian_pos[start]
        x1, y1 = cartesian_pos[end]
        ax.plot([x0, x1], [y0, y1], color=color, linewidth=linewidth)

    ax.axis('on')
    ax.set_title("Colored by Potential")
    ax.set_xlabel("X Coordinate")
    ax.set_ylabel("Y Coordinate")
    ax.set_yticks([])
    ax.set_xticks([])
    ax.axis('equal')
    ax.grid(False)
    plt.show()

def plot_tree_nx_all(parent_pdb=None):
    """Plot every scored design as a radial tree rooted at the parent structure.

    Nodes are the rows of ``ALL_SCORES_CSV`` (ids shifted by +1 so that 0 is free
    for an artificial root), laid out with graphviz ``twopi``. Nodes are coloured
    by their id, the artificial root is red, and edges are coloured by the
    Hamming distance between parent and child sequence.

    Parameters
    ----------
    parent_pdb : str, optional
        Path of the PDB file whose sequence is compared with the first-generation
        designs. Defaults to the path that used to be hard-coded in this function.
    """
    from networkx.drawing.nx_agraph import graphviz_layout

    if parent_pdb is None:
        parent_pdb = '/net/bs-gridfs/export/grid/scratch/lmerlicek/design/Input/1ohp.pdb'

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    _, _, interface_potentials, _, _ = normalize_scores(all_scores_df, print_norm=False, norm_all=False, extension="potential")
    all_scores_df["interface_potential"] = interface_potentials
    G = nx.DiGraph()

    for _, row in all_scores_df.iterrows():
        index = int(float(row['index'])) + 1
        if not isinstance(row['sequence'], str):
            continue
        G.add_node(index, sequence=row['sequence'], interface_potential=row['interface_potential'], gen=int(row['generation']) + 1)
        if row['parent_index'] != "Parent":
            parent_idx = int(float(row['parent_index'])) + 1
            parent_sequence = all_scores_df.loc[all_scores_df.index == parent_idx - 1, 'sequence'].values[0]
            # Store the Hamming distance to the parent as an edge attribute
            distance = hamming_distance(parent_sequence, row['sequence'])
            G.add_edge(parent_idx, index, hamming_distance=distance)

    # Create a new root node
    G.add_node(0, sequence='root', interface_potential=0, gen=0)

    # Connect the new root node to all nodes of generation 1. Nodes that were only
    # created implicitly by add_edge carry no 'gen' attribute and are ignored.
    first_generation = [node for node, gen in G.nodes(data='gen') if gen == 1]
    if first_generation:
        # Parse the parent structure once instead of once per node
        root_sequence = extract_sequence_from_pdb(parent_pdb)
        for node in first_generation:
            distance = hamming_distance(root_sequence, G.nodes[node]['sequence'])
            G.add_edge(0, node, hamming_distance=distance)

    # Use graphviz_layout to get the positions for a circular layout
    pos = graphviz_layout(G, prog="twopi", args="")

    # Node colors follow the node id; generation 0 (the artificial root) is red
    n_nodes = len(G.nodes)
    node_colors = [
        'red' if G.nodes[node].get('gen') == 0 else plt.cm.viridis(int(node) / n_nodes)
        for node in G.nodes
    ]

    # Normalize Hamming distances to [0, 1] for the edge colors
    hamming_distances = [dist for _, _, dist in G.edges(data='hamming_distance')]
    min_hamming = min(hamming_distances, default=0)
    max_hamming = max(hamming_distances, default=0)
    # When all distances are equal the range is 0; divide by 1 instead
    hamming_range = (max_hamming - min_hamming) or 1
    normalized_hamming = [(dist - min_hamming) / hamming_range for dist in hamming_distances]

    # Plot the graph
    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=5, node_color=node_colors, linewidths=0.01)

    # Draw edges with custom color based on normalized Hamming distance
    edge_colors = ['blue' if norm_dist == 0 else plt.cm.RdYlGn(norm_dist) for norm_dist in normalized_hamming]
    nx.draw_networkx_edges(G, pos, ax=ax, width=0.2, edge_color=edge_colors, style='-', arrows=False)

    # Create a colorbar as a legend for Hamming distances
    sm = plt.cm.ScalarMappable(cmap=plt.cm.RdYlGn, norm=plt.Normalize(vmin=min_hamming, vmax=max_hamming))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label('Hamming Distance')

    ax.set_title("Colored by Index, Gen 0 in Red, Edges by Hamming Distance")
    ax.axis("equal")
    plt.show()

def calculate_rank_order(matrix):
    """Replace every symbol in *matrix* by its frequency rank.

    The most frequent symbol gets rank 1, the next one rank 2, and so on.
    Symbols that occur equally often are ranked alphabetically.
    Returns an array with the same shape as *matrix*.
    """
    residues, occurrences = np.unique(matrix, return_counts=True)

    # Most frequent first; ties resolved by ascending character code
    ordering = sorted(zip(residues, occurrences), key=lambda pair: (-pair[1], ord(pair[0])))

    ranks = {}
    for position, (residue, _) in enumerate(ordering):
        ranks[residue] = position + 1

    # Look up the rank of every element
    to_rank = np.vectorize(ranks.get)
    return to_rank(matrix)

def seq_to_rank_order_matrix(sequences):
    """Encode sequences as a matrix of per-position frequency ranks.

    Each sequence becomes one row and each position one column; every column
    is converted independently with ``calculate_rank_order``.
    Returns an integer array with one row per sequence.
    """
    # One row per sequence, one column per position
    residues = np.array([list(seq) for seq in sequences])
    n_positions = residues.shape[1]

    ranked = np.zeros(residues.shape, dtype=int)
    for position in range(n_positions):
        ranked[:, position] = calculate_rank_order(residues[:, position])

    return ranked

def seq_to_numeric(seq):
    """Encode an amino-acid sequence as a list of integers.

    The 20 standard residues map to 1-20 (alphabetical by one-letter code).
    'X' and any character that is not in the table map to 0, so unknown or
    non-standard residues no longer raise ``KeyError``.
    """
    # Mapping for all 20 standard amino acids plus 'X' for unknown
    mapping = {
        'A': 1,  'C': 2,  'D': 3,  'E': 4,
        'F': 5,  'G': 6,  'H': 7,  'I': 8,
        'K': 9,  'L': 10, 'M': 11, 'N': 12,
        'P': 13, 'Q': 14, 'R': 15, 'S': 16,
        'T': 17, 'V': 18, 'W': 19, 'Y': 20,
        'X': 0   # 'X' for any unknown or non-standard amino acid
    }
    unknown = mapping['X']
    return [mapping.get(char, unknown) for char in seq]

def plot_pca_umap(total_score_cutoff=-340):
    """Plot a UMAP projection of the rank-order encoded sequences.

    Rows of ``ALL_SCORES_CSV`` lacking a total, catalytic or interface score or
    a sequence are dropped. Three panels are drawn, coloured by interface score,
    by total score (only designs at or below the cutoff) and by generation.
    Despite the function name only UMAP is shown: the PCA projections previously
    computed here were never used.

    Parameters
    ----------
    total_score_cutoff : float, optional
        Only designs with ``total_score <= total_score_cutoff`` appear in the
        total-score panel.
    """
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    # 'sequence' is required as well: a NaN sequence cannot be encoded
    all_scores_df = all_scores_df.dropna(subset=['total_score', 'catalytic_score', 'interface_score', 'sequence'])

    numeric_seqs = seq_to_rank_order_matrix(all_scores_df['sequence'].tolist())

    # Perform UMAP
    reducer = umap.UMAP()
    umap_result = reducer.fit_transform(numeric_seqs)

    # Designs shown in the total-score panel
    below_cutoff = (all_scores_df['total_score'] <= total_score_cutoff).to_numpy()

    # (title suffix, colorbar label, coordinates, colour values) for each panel.
    # The last panel is coloured by the 'generation' column so that the data
    # agrees with its title and colorbar label.
    panels = [
        ('Interface score', 'Interface Score', umap_result, all_scores_df['interface_score']),
        ('Total score', 'Total Score', umap_result[below_cutoff], all_scores_df.loc[below_cutoff, 'total_score']),
        ('Generation', 'Generation', umap_result, all_scores_df['generation']),
    ]

    # Create a figure with three stacked subplots
    fig, axs = plt.subplots(3, 1, figsize=(12, 24))  # Adjust the figure size as needed

    # Define a base font size
    base_font_size = 10  # Adjust here
    label_size = base_font_size * 2

    for ax, (title_suffix, cbar_label, coords, values) in zip(axs, panels):
        points = ax.scatter(coords[:, 0], coords[:, 1], c=values, cmap='viridis', alpha=0.6, s=1)
        cbar = fig.colorbar(points, ax=ax, label=cbar_label)
        ax.set_title(f'UMAP of Sequences - {title_suffix}', fontsize=label_size)
        ax.set_xlabel('UMAP1', fontsize=label_size)
        ax.set_ylabel('UMAP2', fontsize=label_size)
        cbar.set_label(cbar_label, size=label_size)

    plt.tight_layout()
    plt.show()



def plot_esm_umap(total_score_cutoff=-340, batch_size=16):
    """Plot a UMAP projection of ESM-2 sequence embeddings.

    Every sequence in ``ALL_SCORES_CSV`` that has a total, catalytic and
    interface score is embedded with ``esm2_t33_650M_UR50D`` (layer 33, mean
    over the residue positions). Three panels are drawn, coloured by interface
    score, by total score (only designs at or below the cutoff) and by
    generation.

    Parameters
    ----------
    total_score_cutoff : float, optional
        Only designs with ``total_score <= total_score_cutoff`` appear in the
        total-score panel.
    batch_size : int, optional
        Number of sequences sent through the model per forward pass; keeps
        memory use bounded for large score tables.
    """
    # Imported here because torch is not among the file's top-level imports
    import torch

    # Load ESM model
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    # The batch converter adds the BOS/EOS tokens and pads sequences of unequal length
    batch_converter = alphabet.get_batch_converter()

    # Load and preprocess data
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    all_scores_df = all_scores_df.dropna(subset=['total_score', 'catalytic_score', 'interface_score', 'sequence'])

    # Extract sequences
    sequences = all_scores_df['sequence'].tolist()

    sequence_embeddings = []
    with torch.no_grad():
        for start in range(0, len(sequences), batch_size):
            batch = [(str(start + offset), seq) for offset, seq in enumerate(sequences[start:start + batch_size])]
            _, _, tokens = batch_converter(batch)
            token_embeddings = model(tokens, repr_layers=[33])["representations"][33]

            # Number of non-padding tokens per sequence (residues + BOS + EOS)
            token_counts = (tokens != alphabet.padding_idx).sum(1).tolist()
            for row, n_tokens in enumerate(token_counts):
                # Mean pooling over the residue positions only (skip BOS, EOS and padding)
                sequence_embeddings.append(token_embeddings[row, 1:n_tokens - 1].mean(0))

    embeddings_array = torch.stack(sequence_embeddings).cpu().numpy()

    # Perform UMAP
    reducer = umap.UMAP()
    umap_result = reducer.fit_transform(embeddings_array)

    # Designs shown in the total-score panel
    below_cutoff = (all_scores_df['total_score'] <= total_score_cutoff).to_numpy()

    # (title suffix, colorbar label, coordinates, colour values) for each panel.
    # The last panel is coloured by the 'generation' column so that the data
    # agrees with its title and colorbar label.
    panels = [
        ('Interface score', 'Interface Score', umap_result, all_scores_df['interface_score']),
        ('Total score', 'Total Score', umap_result[below_cutoff], all_scores_df.loc[below_cutoff, 'total_score']),
        ('Generation', 'Generation', umap_result, all_scores_df['generation']),
    ]

    # Create a figure with three stacked subplots
    fig, axs = plt.subplots(3, 1, figsize=(12, 24))  # Adjust the figure size as needed

    # Define a base font size
    base_font_size = 10  # Adjust here
    label_size = base_font_size * 2

    for ax, (title_suffix, cbar_label, coords, values) in zip(axs, panels):
        points = ax.scatter(coords[:, 0], coords[:, 1], c=values, cmap='viridis', alpha=0.6, s=1)
        cbar = fig.colorbar(points, ax=ax, label=cbar_label)
        ax.set_title(f'UMAP of Sequences - {title_suffix}', fontsize=label_size)
        ax.set_xlabel('UMAP1', fontsize=label_size)
        ax.set_ylabel('UMAP2', fontsize=label_size)
        cbar.set_label(cbar_label, size=label_size)

    plt.tight_layout()
    plt.show()

def find_mutations(seq1, seq2):
    """Return the 0-based positions at which two sequences differ.

    Only the overlapping part is compared: positions beyond the end of the
    shorter sequence are ignored.
    """
    differing_positions = []
    for position, residues in enumerate(zip(seq1, seq2)):
        if residues[0] != residues[1]:
            differing_positions.append(position)
    return differing_positions

def normalize_columnwise(matrix):
    """Min-max scale every column of *matrix* independently.

    Each column is shifted by its minimum and divided by its value range.
    Columns whose values are all equal are divided by 1 instead, so they
    come out as all zeros.
    """
    column_min = matrix.min(axis=0)
    column_range = matrix.max(axis=0) - column_min
    # A constant column has zero range; divide by 1 there to avoid 0/0
    safe_range = np.where(column_range == 0, 1, column_range)
    return (matrix - column_min) / safe_range

def plot_mut_location(max_length=None):
    """Plot a heatmap of where mutations occur along the chain in each generation.

    For every design in ``ALL_SCORES_CSV`` that has a parent, the positions at
    which its sequence differs from the parent sequence are counted per
    generation. The counts are min-max normalised within each generation and
    shown as an image (x: generation, y: sequence position).

    Parameters
    ----------
    max_length : int, optional
        Number of sequence positions (rows) of the heatmap. Defaults to the
        length of the longest sequence in the score table. A smaller value
        raises ``IndexError`` when a mutation lies beyond it.
    """
    # Load the data
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    all_scores_df = all_scores_df.dropna(subset=['sequence'])

    if max_length is None:
        # Size the matrix from the data instead of assuming a fixed chain length
        max_length = int(all_scores_df['sequence'].str.len().max())
    max_generation = int(all_scores_df['generation'].max())

    # Initialize a matrix to hold mutation frequencies
    mutation_matrix = np.zeros((max_length, max_generation + 1))

    # Lookup table index -> sequence, built once instead of scanning the frame for
    # every row. The first occurrence wins if an index value is repeated.
    sequence_by_index = all_scores_df.drop_duplicates(subset='index').set_index('index')['sequence'].to_dict()

    # Populate the mutation matrix
    for _, row in all_scores_df.iterrows():
        parent_index = row['parent_index']
        if pd.isnull(parent_index) or parent_index == "Parent":
            continue  # no valid parent
        parent_seq = sequence_by_index.get(float(parent_index))
        if parent_seq is None:
            continue  # parent row is absent or was dropped because it has no sequence
        for pos in find_mutations(row['sequence'], parent_seq):
            mutation_matrix[pos, int(row['generation'])] += 1

    # Normalize the mutation_matrix column-wise (i.e., each generation separately)
    normalized_mutation_matrix = normalize_columnwise(mutation_matrix)

    # Plotting. The matrix has max_generation + 1 columns, so the x extent is
    # widened by half a column on each side to centre column g on x = g.
    fig, ax = plt.subplots(figsize=(10, 6))
    c = ax.imshow(normalized_mutation_matrix, aspect='auto', origin='lower', cmap='viridis', extent=[-0.5, max_generation + 0.5, 0, max_length])
    ax.set_xlabel('Generation')
    ax.set_ylabel('Position along AA chain')
    ax.set_title('Frequency of Mutation Over Generations')
    fig.colorbar(c, ax=ax, label='Normalized Frequency of Mutation')
    plt.show()


## Functions needed to run AIzyme algorithm

In [None]:
# Print a confirmation message once this cell has run, then pause for half a second.
print("AIzyme Functions loaded!")
time.sleep(0.5)