In [None]:
# Import libraries
# Standard library
import getpass
import glob
import json
import logging
import math
import os
import pickle as pkl
import random
import re
import shutil
import statistics
import subprocess
import sys
import tempfile
import time
import warnings
from datetime import datetime

# Third-party
import networkx as nx
import numpy as np
import pandas as pd
#import mdtraj as md
from Bio import SeqIO
from Bio.PDB import PDBParser
from IPython.display import display, clear_output
from matplotlib import pyplot as plt
from matplotlib.axes._axes import _log as matplotlib_axes_logger
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.stats import gmean, pearsonr

# Setting initial options
warnings.filterwarnings('ignore')
pd.options.mode.chained_assignment = None
matplotlib_axes_logger.setLevel('INFO')

DESIGN_COUNT = {}

# USERNAME is only needed on the GRID cluster: check_running_jobs() filters
# `qstat -u USERNAME` with it. os.getlogin() needs a controlling terminal and
# raises OSError in batch / non-interactive kernels, so fall back to
# getpass.getuser() (LOGNAME/USER/... environment variables, then the passwd entry).
if GRID:
    try:
        USERNAME = os.getlogin()
    except OSError:
        USERNAME = getpass.getuser()

# The design folder lives in the current working directory for every run mode
# (GRID, BLUEPEBBLE, BACKGROUND_JOB and local runs).
FOLDER_HOME = f'{os.getcwd()}/{DESIGN_FOLDER}'
os.makedirs(FOLDER_HOME, exist_ok=True)

FOLDER_INPUT = f'{os.getcwd()}/Input'
# Logging is not configured yet at this point, so a missing Input folder is reported with print.
if not os.path.isdir(FOLDER_INPUT): print("ERROR! Input folder missing!")

LOG_FILE       = f'{FOLDER_HOME}.log'   # sits next to (not inside) FOLDER_HOME
ALL_SCORES_CSV = f'{FOLDER_HOME}/all_scores.csv'
VARIABLES_JSON = f'{FOLDER_HOME}/variables.json'

# Configure logging: DEBUG and above go to LOG_FILE, INFO and above to the console.
log_format  = '%(asctime)s - %(levelname)s - %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'

# Remove all handlers of the root logger first. basicConfig() does nothing when the
# root logger already has handlers, and re-running this cell would otherwise stack
# duplicate console handlers.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG, format=log_format, datefmt=date_format)

# Console output (INFO and above) in addition to the DEBUG-level log file.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))
logging.getLogger().addHandler(console_handler)

# main functions - running

In [2]:
def controller(RESET=False, EXPLORE=False, UNBLOCK_ALL=False, 
               PRINT_VAR=True, PLOT_DATA=True, 
               BLUEPEBBLE=False, GRID=True):
    """Main AI.zymes driver: controls the whole design process.

    Runs the one-time setup, loads all_scores_df via startup_controller, then loops
    until the stop marker FOLDER_HOME/<MAX_DESIGNS> exists. In each pass it either
    waits 60 s because check_running_jobs() >= MAX_JOBS, or it:
    1. refreshes the scores,
    2. submits outstanding parent designs, and
    3. once the parent designs are done, submits a calculation for a
       Boltzmann-selected index.

    Every pass ends with a 1 s pause.

    Note: BLUEPEBBLE and GRID are accepted but not used inside this function (the
    helpers it calls read module-level names), and EXPLORE is only forwarded to
    setup_aizymes.
    """

    def _stop_marker():
        # Re-evaluated on every check, exactly like the inline expression it replaces.
        return os.path.join(FOLDER_HOME, str(MAX_DESIGNS))

    def _advance(scores_df):
        # One scheduling step: refresh scores, then start the next design job if possible.
        scores_df = update_scores(scores_df)
        parents_finished, scores_df = start_parent_design(scores_df)
        if not parents_finished:
            return scores_df
        chosen = boltzmann_selection(scores_df)
        if chosen is None:
            return scores_df
        return start_calculation(scores_df, chosen)

    # Startup, will only be executed once in the beginning
    setup_aizymes(RESET, EXPLORE)

    # Check if startup is done; if so, read in all_scores_df
    all_scores_df = startup_controller(UNBLOCK_ALL,
                                       RESET,
                                       PRINT_VAR=PRINT_VAR,
                                       PLOT_DATA=PLOT_DATA)

    while not os.path.exists(_stop_marker()):
        slots_full = check_running_jobs() >= MAX_JOBS
        if slots_full:
            # Every job slot is taken; wait before polling the queue again.
            time.sleep(60)
        else:
            all_scores_df = _advance(all_scores_df)
        time.sleep(1)

    print(f"Stopped because {_stop_marker()} exists.")

def check_running_jobs():
    """Return the number of AI.zymes jobs that are currently queued or running.

    The counting method is chosen by the module-level run-mode flags, checked in order:
      - GRID:           lines of `qstat -u USERNAME` that contain SUBMIT_PREFIX
      - BLUEPEBBLE:     lines of `squeue --me` that contain SUBMIT_PREFIX
      - BACKGROUND_JOB: the integer stored in FOLDER_HOME/n_running_jobs.dat
      - ABBIE_LOCAL:    always 0

    Raises:
        ValueError: if none of the run-mode flags is set. Previously the function
            silently returned None, and the caller's `>= MAX_JOBS` comparison then
            failed with a TypeError.
    """
    if GRID:
        queue_lines = subprocess.check_output(["qstat", "-u", USERNAME]).decode("utf-8").split("\n")
        return sum(1 for line in queue_lines if SUBMIT_PREFIX in line)

    if BLUEPEBBLE:
        queue_lines = subprocess.check_output(["squeue", "--me"]).decode("utf-8").split("\n")
        return sum(1 for line in queue_lines if SUBMIT_PREFIX in line)

    if BACKGROUND_JOB:
        # Bug fix: the file handle must be bound (`as f`) before it can be read.
        with open(f'{FOLDER_HOME}/n_running_jobs.dat', 'r') as f:
            return int(f.read())

    if ABBIE_LOCAL:
        return 0

    raise ValueError("check_running_jobs: no run mode set "
                     "(expected one of GRID, BLUEPEBBLE, BACKGROUND_JOB, ABBIE_LOCAL).")

def update_potential(score_type:str, index:int, all_scores_df:pd.DataFrame):
    '''Record the latest <score_type> score of variant <index> as its potential.

    Always overwrites FOLDER_HOME/<index>/<score_type>_potential.dat with the score
    of <index> and sets all_scores_df.at[index, '<score_type>_potential'] to it.

    If the score was taken from Rosetta Relax and the variant has a real parent
    (its parent_index is not "Parent"), the score is also appended to the parent's
    <score_type>_potential.dat. The parent's <score_type>_potential is then set to
    the mean of all values in that file.

    Returns the updated dataframe, with rows that have no 'index' dropped.
    '''
    score = all_scores_df.at[index, f'{score_type}_score']
    score_taken_from = all_scores_df.at[index, 'score_taken_from']

    # parent_index is stored as a string/object because it can be "Parent". Only
    # convert it to an int after ruling that out: int("Parent") raises ValueError.
    raw_parent_index = all_scores_df.at[index, "parent_index"]

    # Overwrite the file of the current index; only the parent's file is appended to.
    with open(f"{FOLDER_HOME}/{index}/{score_type}_potential.dat", "w") as f:
        f.write(str(score))
    all_scores_df.at[index, f'{score_type}_potential'] = score

    if score_taken_from == "Relax" and str(raw_parent_index) != "Parent":
        # float() first so that values like "3" or "3.0" (e.g. after a CSV round trip) both work.
        parent_index = int(float(raw_parent_index))
        parent_filename = f"{FOLDER_HOME}/{parent_index}/{score_type}_potential.dat"

        with open(parent_filename, "a") as f:
            f.write(f"\n{score}")
        with open(parent_filename, "r") as f:
            # Skip blank lines (e.g. the leading "\n" when the parent file did not exist yet).
            potentials = [float(line) for line in f if line.strip()]

        all_scores_df.at[parent_index, f'{score_type}_potential'] = np.average(potentials)

    return all_scores_df.dropna(subset=['index'])
        
def _parse_rosetta_scores(header_line, score_line):
    """Extract (total_score, interface_score, catalytic_score) from a Rosetta score file.

    header_line and score_line are the column-header line and the first data line of
    a .sc file (its 2nd and 3rd lines, as read by update_scores).
      * total_score     -- the 'total_score' column
      * interface_score -- 'interface_delta_X' minus the if_X_angle/atom_pair/dihedral
                           constraint terms
      * catalytic_score -- unweighted sum of atom_pair_constraint, angle_constraint and
                           dihedral_constraint

    Returns None if the data line has fewer fields than the header (presumably the
    file is still being written) or if there is no 'total_score' column.
    """
    headers = header_line.split()
    values = score_line.split()
    if len(values) < len(headers):
        return None

    total_score = None
    interface_score = 0.0
    catalytic_score = 0.0
    for header, value in zip(headers, values):
        if header == 'total_score':
            total_score = float(value)
        elif header == 'interface_delta_X':
            interface_score += float(value)
        elif header in ('if_X_angle_constraint', 'if_X_atom_pair_constraint', 'if_X_dihedral_constraint'):
            interface_score -= float(value)
        elif header in ('atom_pair_constraint', 'angle_constraint', 'dihedral_constraint'):
            catalytic_score += float(value)

    if total_score is None:
        return None
    return total_score, interface_score, catalytic_score


def update_scores(all_scores_df):
    """Read finished Rosetta score files and store their scores in all_scores_df.

    For every index whose score was not already taken from a Relax run:
      * Pick the score file: FOLDER_HOME/<index>/score_rosetta_relax.sc, or, for
        RosettaDesign runs without a relax score file, score_rosetta_design.sc.
        The design file is skipped if its score was already taken.
      * Skip the index if the score file or its PDB is not readable yet, or if the
        score file has no complete data line yet. Such indices are retried on the
        next call.
      * Otherwise:
          - parse the total/interface/catalytic scores and compute the efield score;
          - fill the current columns and the relax_*/design_* book-keeping columns;
          - update the potentials;
          - unblock the index if its relaxed PDB exists;
          - refresh catalytic residues, sequence and mutation count.

    The dataframe is saved with save_all_scores_df and returned.
    """
    reference_sequence = None  # WT sequence, read lazily the first time it is needed

    for _, row in all_scores_df.iterrows():

        index = int(row['index'])

        # Do NOT update score if score was taken from a relax file. Prevents repeated scoring!
        if row['score_taken_from'] == 'Relax':
            continue

        relax_score_file = f"{FOLDER_HOME}/{index}/score_rosetta_relax.sc"
        design_score_file = f"{FOLDER_HOME}/{index}/score_rosetta_design.sc"

        # Default pdb file path
        if row["design_method"] == "ProteinMPNN":
            pdb_path = f"{FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_{index}.pdb"
        else:
            pdb_path = f"{FOLDER_HOME}/{index}/{WT}_Rosetta_Design_{index}.pdb"

        score_file_path = relax_score_file
        if not os.path.exists(relax_score_file) and row['design_method'] == "RosettaDesign":
            # No relax score yet: fall back to the design score, but never re-score it.
            if row['score_taken_from'] == 'Design':
                continue
            score_file_path = design_score_file
        score_source = 'Relax' if score_file_path == relax_score_file else 'Design'

        # Score file or PDB not readable yet. Usually means the job is not done.
        if not os.access(score_file_path, os.R_OK) or not os.access(pdb_path, os.R_OK):
            continue

        with open(score_file_path, "r") as f:
            score_lines = f.readlines()

        # If the timing is bad, the score file is not fully written yet. The score
        # source is only recorded AFTER a complete data line was parsed. Recording it
        # earlier made the 'Relax' check above ignore a half-written relax file forever.
        if len(score_lines) < 3:
            continue
        parsed = _parse_rosetta_scores(score_lines[1], score_lines[2])
        if parsed is None:
            logging.debug(f"Score file {score_file_path} is incomplete or has no total_score; will retry.")
            continue
        total_score, interface_score, catalytic_score = parsed

        all_scores_df.at[index, 'score_taken_from'] = score_source

        efield_score, efields_df = calc_efields_score(pdb_path)

        # Update scores
        all_scores_df.at[index, 'total_score']     = total_score
        all_scores_df.at[index, 'interface_score'] = interface_score
        all_scores_df.at[index, 'catalytic_score'] = catalytic_score
        all_scores_df.at[index, 'efield_score']    = efield_score

        # Book keeping only. AIzymes always uses the most up-to-date scores saved above.
        prefix = 'relax' if score_source == 'Relax' else 'design'
        all_scores_df.at[index, f'{prefix}_total_score']     = total_score
        all_scores_df.at[index, f'{prefix}_interface_score'] = interface_score
        all_scores_df.at[index, f'{prefix}_catalytic_score'] = catalytic_score
        all_scores_df.at[index, f'{prefix}_efield_score']    = efield_score

        for score_type in ['total', 'interface', 'catalytic', 'efield']:
            all_scores_df = update_potential(score_type=score_type,
                                             index=index,
                                             all_scores_df=all_scores_df)

        logging.info(f"Updated scores and potentials of index {index}.")
        if all_scores_df.at[index, 'score_taken_from'] == 'Relax':
            logging.info(f"Adjusted potentials of parent of index {index}.")

        # Unblock index if the relaxed file exists
        if all_scores_df.at[index, "blocked"] == True:
            if f"{WT}_Rosetta_Relax_{index}.pdb" in os.listdir(os.path.join(FOLDER_HOME, str(index))):
                all_scores_df.at[index, "blocked"] = False
                logging.debug(f"Unblocked index {index}.")

        # Update catalytic residues
        if row["design_method"] != "ProteinMPNN":
            all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, index, pdb_path)

        # Update sequence and mutations
        if reference_sequence is None:
            reference_sequence = extract_sequence_from_pdb(f"{FOLDER_INPUT}/{WT}.pdb")
        current_sequence = extract_sequence_from_pdb(pdb_path)
        mutations = sum(1 for a, b in zip(current_sequence, reference_sequence) if a != b)
        all_scores_df['sequence'] = all_scores_df['sequence'].astype('object')
        all_scores_df.at[index, 'sequence']  = current_sequence
        all_scores_df.at[index, 'mutations'] = int(mutations)

    save_all_scores_df(all_scores_df)

    return all_scores_df

def normalize_scores(unblocked_all_scores_df, print_norm=False, norm_all=False, extension="score"):
    """Normalize the catalytic, total, interface and efield columns and combine them.

    Each column f"{type}_{extension}" is negated, so lower Rosetta energies score
    higher. The efield column is passed in pre-negated, so effectively it is used
    as-is. The result is then scaled to [0, 1]:
      * norm_all=False: min-max over the given rows (NaNs ignored);
      * norm_all=True:  with the fixed [low, high] ranges in NORM[f"{type}_{extension}"].
    Negative values are clipped to 0.

    Columns with fewer than two values are not normalized and yield [1]. Columns
    whose values are all equal yield 1.0 for every non-NaN row; previously min-max
    gave 0/0 = NaN, which propagated into the Boltzmann probabilities. The combined
    score is the geometric mean of the four normalized arrays.

    NOTE: with zero rows every array is [1] and combined_scores is [1.0].
    boltzmann_selection guards against that case itself.

    Returns:
        (catalytic_scores, total_scores, interface_scores, combined_scores)
    """

    def neg_norm_array(array, score_type):
        # Do not normalize if the array has fewer than two values.
        if len(array) <= 1:
            return [1]

        array = -array

        if norm_all:
            if print_norm:
                print(score_type, NORM[score_type], end=" ")
            array = (array - NORM[score_type][0]) / (NORM[score_type][1] - NORM[score_type][0])
            array[array < 0] = 0.0
            if np.any(array > 1.0): print("\nNORMALIZATION ERROR!", score_type, "has a value >1!")
        else:
            low, high = np.nanmin(array), np.nanmax(array)
            if print_norm:
                print(score_type, [low, high], end=" ")
            if high == low:
                # All values identical: treat every (non-NaN) entry as the best value
                # instead of dividing 0 by 0.
                array = (array - low) + 1.0
            else:
                array = (array - low) / (high - low)
            array[array < 0] = 0.0
        return array

    catalytic_scores = neg_norm_array(unblocked_all_scores_df[f"catalytic_{extension}"], f"catalytic_{extension}")
    total_scores     = neg_norm_array(unblocked_all_scores_df[f"total_{extension}"], f"total_{extension}")
    interface_scores = neg_norm_array(unblocked_all_scores_df[f"interface_{extension}"], f"interface_{extension}")
    # efield: higher is better, so it is pre-negated here to cancel the negation inside neg_norm_array.
    efield_scores    = neg_norm_array(-1 * unblocked_all_scores_df[f"efield_{extension}"], f"efield_{extension}")

    if len(total_scores) == 0:
        combined_scores = []
    else:
        combined_scores = np.stack((catalytic_scores, total_scores, interface_scores, efield_scores))
        combined_scores = gmean(combined_scores, axis=0)

    # Use len() rather than `!= []`, which is an elementwise comparison on ndarrays.
    if print_norm and len(combined_scores) > 0:
        print("HIGHSCORE:", "{:.2f}".format(np.amax(combined_scores)), end=" ")
        print("Designs:", len(combined_scores), end=" ")
        PARENTS = [i for i in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if i[-4:] == ".pdb"]
        print("Parents:", len(PARENTS))

    return catalytic_scores, total_scores, interface_scores, combined_scores
        
def boltzmann_selection(all_scores_df):
    """Pick the index to work on next by Boltzmann-weighted sampling of potentials.

    NOTE: the all_scores_df argument is not used; the current table is re-read from
    ALL_SCORES_CSV. Candidates are the rows that are not blocked and already have a
    total_score.

    Their combined potentials (normalize_scores with extension="potential") give
    probabilities p_i ∝ exp(combined_i / kbT). kbT is KBT_BOLTZMANN when that is a
    number. When it is a two-element [kbT0, decay], kbT is
    kbT0 * 10**(-decay * max_candidate_index), which ramps down as indices grow.

    Returns:
        The selected index (int), or None if there is no candidate.

    Raises:
        ValueError: if KBT_BOLTZMANN is a sequence whose length is not 2.
    """
    scores_df = pd.read_csv(ALL_SCORES_CSV)
    scores_df = scores_df[scores_df["blocked"] == False]   # Remove blocked indices
    scores_df = scores_df.dropna(subset=['total_score'])   # Remove indices without score (design running)

    if len(scores_df) == 0:
        return None

    _, _, _, combined_potentials = normalize_scores(scores_df, norm_all=False,
                                                    extension="potential", print_norm=False)
    if len(combined_potentials) == 0:
        return None

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        # Ramp down kbT_boltzmann over time (i.e., with increasing indices)
        kbt_boltzmann = KBT_BOLTZMANN[0] * 10**(-KBT_BOLTZMANN[1] * scores_df['index'].max())
    else:
        logging.error(f"KBT_BOLTZMANN must either be a single value or list of two values.")
        logging.error(f"KBT_BOLTZMANN is {KBT_BOLTZMANN}")
        raise ValueError(f"KBT_BOLTZMANN must be a number or a list of two values, got {KBT_BOLTZMANN!r}")

    combined = np.asarray(combined_potentials, dtype=float)
    # Shifting by the maximum leaves the probabilities unchanged. It also keeps exp()
    # from overflowing to inf (and inf/inf = NaN) once kbT has been ramped down.
    boltzmann_factors = np.exp((combined - np.max(combined)) / kbt_boltzmann)
    probabilities = boltzmann_factors / np.sum(boltzmann_factors)

    return int(np.random.choice(scores_df["index"].to_numpy(), p=probabilities))

def start_parent_design(all_scores_df):
    """Submit the next RosettaDesign run on a parent structure while any remain.

    Every .pdb file in FOLDER_HOME/FOLDER_PARENT gets N_PARENT_JOBS designs, in
    os.listdir order. While the table has fewer rows than
    N_PARENT_JOBS * <number of parents>, this function:
    1. creates a new index with parent_index "Parent";
    2. records its design method and parent name (luca);
    3. stores the parent's catalytic residues;
    4. starts RosettaDesign from that parent structure.

    Returns:
        (parent_done, all_scores_df). parent_done is True once all parent designs
        have been submitted. The table is saved in both cases.
    """
    n_rows = len(all_scores_df)
    parent_dir = f'{FOLDER_HOME}/{FOLDER_PARENT}'
    parent_pdbs = [name for name in os.listdir(parent_dir) if name.endswith(".pdb")]

    parent_done = not (n_rows < N_PARENT_JOBS * len(parent_pdbs))

    if not parent_done:
        pdb_name = parent_pdbs[int(n_rows / N_PARENT_JOBS)]
        parent = pdb_name[:-4]  # drop the ".pdb" suffix

        new_index, all_scores_df = create_new_index(parent_index="Parent", all_scores_df=all_scores_df)

        # Cast each column to object dtype before storing a string in it.
        for column, value in (('design_method', "RosettaDesign"), ('luca', parent)):
            all_scores_df[column] = all_scores_df[column].astype('object')
            all_scores_df.at[new_index, column] = value

        # Add catalytic residues of the parent structure to the new entry
        all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index,
                                                        f'{parent_dir}/{pdb_name}',
                                                        from_parent_struct=True)

        run_RosettaDesign(parent_index=parent, new_index=new_index, all_scores_df=all_scores_df, parent_done=parent_done)

    save_all_scores_df(all_scores_df)
    return parent_done, all_scores_df

# Decides what to do with selected index
def start_calculation(all_scores_df, selected_index):
    """Decide what to do with the index chosen by boltzmann_selection.

    * If FOLDER_HOME/<selected_index>/<WT>_Rosetta_Relax_<selected_index>.pdb exists,
      create a child index and start ProteinMPNN (with probability ProteinMPNN_PROB)
      or RosettaDesign on it.
    * Otherwise, if the index is blocked, only log an error.
    * Otherwise submit ESMfold + Rosetta Relax (OnlyRelax) for the index, and block
      it if the submission reported success.

    The table is saved in every case and returned.
    """
    logging.debug(f"Starting new calculation for index {selected_index}.")

    is_blocked = bool(all_scores_df.at[selected_index, "blocked"] == True)
    relax_pdb_name = f"{WT}_Rosetta_Relax_{selected_index}.pdb"
    is_relaxed = relax_pdb_name in os.listdir(os.path.join(FOLDER_HOME, str(selected_index)))

    if not is_relaxed:
        # ESMfold_Rosetta_Relax is not done for this index yet.
        if is_blocked:
            # Blocked means ESMfold_Rosetta_Relax should already be running.
            logging.error(f"Index {selected_index} is being worked on. Skipping index.")
            logging.error(f"Note: This should not happen! Check blocking and Boltzman selection.")
        else:
            logging.info(f"Index {selected_index} has no relaxed structure, starting ESMfold_Rosetta_Relax.")
            submitted_status = run_ESMfold_RosettaRelax(index=selected_index, all_scores_df=all_scores_df,
                                                        OnlyRelax=True, EXPLORE=EXPLORE)
            # Only block the index once the submission went through.
            if submitted_status:
                all_scores_df.at[selected_index, "blocked"] = True
        save_all_scores_df(all_scores_df)
        return all_scores_df

    # Relaxed structure exists: branch off a new child design.
    # (Here, an AI could decide on the next steps.)
    new_index, all_scores_df = create_new_index(parent_index=selected_index, all_scores_df=all_scores_df)

    use_protein_mpnn = random.random() < ProteinMPNN_PROB
    all_scores_df.at[new_index, 'design_method'] = "ProteinMPNN" if use_protein_mpnn else "RosettaDesign"
    launcher = run_ProteinMPNN if use_protein_mpnn else run_RosettaDesign
    launcher(parent_index=selected_index, new_index=new_index, all_scores_df=all_scores_df)

    save_all_scores_df(all_scores_df)
    return all_scores_df

def create_new_index(parent_index, all_scores_df):
    """Append a new design row to all_scores_df and create its folder.

    The new index is len(all_scores_df). This presumably assumes existing rows are
    indexed 0..n-1; TODO confirm. The new row sets:
      * index, parent_index;
      * kbt_boltzmann: KBT_BOLTZMANN, or kbT0 * 10**(-decay * new_index) for a
        two-element [kbT0, decay];
      * generation: 0 for "Parent", else the parent's generation + 1;
      * luca: "x" for "Parent", else inherited from the parent;
      * blocked: False.

    The table is saved and FOLDER_HOME/<new_index>/scripts is created.

    Returns:
        (new_index, all_scores_df)

    Raises:
        ValueError: if KBT_BOLTZMANN is a sequence whose length is not 2.
            Previously this left kbt_boltzmann unbound (UnboundLocalError).
    """
    new_index = len(all_scores_df)

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        kbt_boltzmann = KBT_BOLTZMANN[0] * 10 ** (-KBT_BOLTZMANN[1] * new_index)
    else:
        raise ValueError(f"KBT_BOLTZMANN must be a number or a list of two values, got {KBT_BOLTZMANN!r}")

    if parent_index == 'Parent':
        generation = 0
        luca = "x"
    else:
        generation = all_scores_df.at[int(parent_index), 'generation'] + 1
        luca       = all_scores_df.at[int(parent_index), 'luca']

    new_row_df = pd.DataFrame({'index': int(new_index),
                               'parent_index': parent_index,
                               'kbt_boltzmann': kbt_boltzmann,
                               'generation': generation,
                               'luca': luca,
                               'blocked': False,
                               }, index=[0])
    all_scores_df = pd.concat([all_scores_df, new_row_df], ignore_index=True)

    save_all_scores_df(all_scores_df)

    # Create the folders for the new index
    os.makedirs(f"{FOLDER_HOME}/{new_index}/scripts", exist_ok=True)

    logging.debug(f"Child index {new_index} created for {parent_index}.")

    return new_index, all_scores_df

NameError: name 'pd' is not defined

# main functions - design

In [None]:
def run_ESMfold_RosettaRelax(index, all_scores_df, OnlyRelax=False, ProteinMPNN=False, PreMatchRelax=False,
                             ProteinMPNN_parent_index=0, cmd="", bash=False, EXPLORE=False):
    """Write and submit a job script that predicts a structure with ESMfold and relaxes it with Rosetta.

    The input mode is the first true flag among:
      * OnlyRelax     -- input is FOLDER_HOME/<index>/<WT>_Rosetta_Design_<index>.pdb
      * PreMatchRelax -- input is FOLDER_INPUT/<WT>.pdb. No ligand is added back, the
                         relaxed output files get the "_APO" suffix, and the match
                         constraints are not added in the Relax XML.
      * ProteinMPNN   -- sequence from FOLDER_HOME/<index>/ProteinMPNN/<WT>_<index>.seq
                         (not written here); the structural reference is the relaxed
                         PDB of ProteinMPNN_parent_index.
    For OnlyRelax / PreMatchRelax, the sequence file is written here from the input PDB.

    The generated shell script is appended to the `cmd` prefix. It:
      1. runs ESMfold.py;
      2. uses cpptraj to reduce the input to backbone atoms (C, N, O, CA);
      3. extracts the ligand and CA-aligns the ESMfold prediction onto the input;
      4. re-attaches the ligand (except for PreMatchRelax);
      5. runs Rosetta Relax with the RosettaScripts XML written to
         FOLDER_HOME/<index>/scripts/Rosetta_Relax_<index>.xml.

    Args:
        index: design index; its folder is FOLDER_HOME/<index>.
        all_scores_df: scores table, used to build the PDB REMARK header (not for PreMatchRelax).
        cmd: shell commands to prepend to the generated script.
        bash: forwarded to submit_job.
        EXPLORE: fast test mode (no -ex1 -ex2 rotamer flags, 1 FastRelax repeat).

    Returns:
        True once submit_job has been called, False if the input PDB is missing, and
        None (after printing a message) if no mode flag is set.
    """
    
    # Giving the ESMfold algorithm the needed inputs
    output_file = f'{FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb'
    if ProteinMPNN:
        sequence_file = f'{FOLDER_HOME}/{index}/ProteinMPNN/{WT}_{index}.seq'
    else:
        sequence_file = f'{FOLDER_HOME}/{index}/ESMfold/{WT}_{index}.seq'
        
    # Make directories
    os.makedirs(f"{FOLDER_HOME}/{index}/ESMfold", exist_ok=True)
    os.makedirs(f"{FOLDER_HOME}/{index}/scripts", exist_ok=True)
        
    # Options for EXPLORE, accelerated script for testing
    ex = "-ex1 -ex2"
    if EXPLORE: ex = ""
        
    # Get Name of parent PDB
    if OnlyRelax: 
        PDBFile = f"{FOLDER_HOME}/{index}/{WT}_Rosetta_Design_{index}.pdb"
    elif PreMatchRelax: 
        PDBFile = f"{FOLDER_INPUT}/{WT}.pdb"
    elif ProteinMPNN:
        PDBFile = f"{FOLDER_HOME}/{ProteinMPNN_parent_index}/{WT}_Rosetta_Relax_{ProteinMPNN_parent_index}.pdb"
    else:
        # NOTE(review): no mode flag set -- this prints and returns None (falsy) instead of logging and returning False.
        print("I don't know what you want me to do")
        return
    if not os.path.isfile(PDBFile):
        logging.error(f"{PDBFile} not present!")
        return False

    # Make sequence file
    if OnlyRelax or PreMatchRelax: 
        seq = extract_sequence_from_pdb(PDBFile)
        with open(f"{FOLDER_HOME}/{index}/ESMfold/{WT}_{index}.seq","w") as f: f.write(seq)
                    
    # Get the pdb file from the last step: strip the ligand and keep only backbone atoms (C, N, O, CA)
    cpptraj = f'''parm    {PDBFile}
trajin  {PDBFile}
strip   :{LIGAND}
strip   !@C,N,O,CA
trajout {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_Apo_{index}.in','w') as f: f.write(cpptraj)

    # Get the pdb file from the last step and strip away everything except the ligand
    cpptraj = f'''parm    {PDBFile}
trajin  {PDBFile}
strip   !:{LIGAND}
trajout {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Lig_{index}.pdb
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_Lig_{index}.in','w') as f: f.write(cpptraj)

    # Get the ESMfold pdb file and keep only backbone atoms (which also removes all hydrogens)
    cpptraj = f'''parm    {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb
trajin  {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb
strip   !@C,N,O,CA
trajout {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_no_hydrogens_{index}.pdb
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_no_hydrogens_{index}.in','w') as f: f.write(cpptraj)

    # Align the backbone-only ESM prediction onto the backbone-only input structure (CA RMSD fit)
    cpptraj = f'''parm    {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_no_hydrogens_{index}.pdb
reference {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb [apo]
trajin    {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_no_hydrogens_{index}.pdb
rmsd      @CA ref [apo]
trajout   {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb noter
'''
    with open(f'{FOLDER_HOME}/{index}/ESMfold/CPPTraj_aligned_{index}.in','w') as f: f.write(cpptraj) 
              
    # Rosetta binary suffix per platform.
    # NOTE(review): `extension` stays unbound if none of the run-mode flags is set.
    if GRID:           extension = "linuxgccrelease"
    if BLUEPEBBLE:     extension = "serialization.linuxgccrelease"
    if BACKGROUND_JOB: extension = "serialization.linuxgccrelease"
    if ABBIE_LOCAL:    
        extension = "linuxgccrelease"
        bash_args = "OMP_NUM_THREADS=1"
    else:
        bash_args = ""
 
    # NOTE: these are non-raw triple-quoted strings, so a trailing backslash followed by a
    # newline is a Python line continuation -- the shell receives those commands on one line.
    cmd += f"""
    
{bash_args} python {FOLDER_HOME}/ESMfold.py {output_file} {sequence_file}

sed -i '/PARENT N\/A/d' {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_output_{index}.pdb
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Apo_{index}.in           &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Apo_{index}.out
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Lig_{index}.in           &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_Lig_{index}.out
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_no_hydrogens_{index}.in  &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_no_hydrogens_{index}.out
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/CPPTraj_aligned_{index}.in       &> \
           {FOLDER_HOME}/{index}/ESMfold/CPPTraj_aligned_{index}.out

# Assemble the final protein
sed -i '/END/d' {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb
# Return HETATM to ligand output and remove TER
sed -i -e 's/^ATOM  /HETATM/' -e '/^TER/d' {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Lig_{index}.pdb
"""
    
    # NOTE(review): input_extension_relax is assigned but never used.
    input_extension_relax = ""
    if PreMatchRelax:
        extension_relax = "_APO"
        ## No ligand necessary so just use the aligned pdb from ESMfold
        cmd += f"""
cp {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb \
   {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}{extension_relax}.pdb
"""  

    else:
        extension_relax = ""
        # Write the REMARK header now; the job script appends the aligned protein and the ligand to it.
        remark = generate_remark_from_all_scores_df(all_scores_df, index)
        with open(f'{FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb', 'w') as f: f.write(remark+"\n")
        cmd += f"""
cat {FOLDER_HOME}/{index}/ESMfold/{WT}_ESMfold_aligned_{index}.pdb >> {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb
cat {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Lig_{index}.pdb     >> {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb
sed -i '/TER/d' {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}.pdb
"""
        
    # NOTE(review): the mv/sed below use relative paths, so they presumably rely on the job
    # running with FOLDER_HOME/<index> as its working directory (set by submit_job?) -- confirm.
    cmd += f"""
# Run Rosetta Relax
{ROSETTA_PATH}/bin/rosetta_scripts.{extension} \
                -s                                        {FOLDER_HOME}/{index}/{WT}_ESMfold_{index}{extension_relax}.pdb \
                -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
                -parser:protocol                          {FOLDER_HOME}/{index}/scripts/Rosetta_Relax_{index}.xml \
                -out:file:scorefile                       {FOLDER_HOME}/{index}/score_rosetta_relax.sc \
                -nstruct                                  1 \
                -ignore_zero_occupancy                    false \
                -corrections::beta_nov16                  true \
                -run:preserve_header                      true \
                -overwrite {ex}

# Rename the output file
mv {WT}_ESMfold_{index}{extension_relax}_0001.pdb {WT}_Rosetta_Relax_{index}{extension_relax}.pdb
sed -i '/        H  /d' {WT}_Rosetta_Relax_{index}{extension_relax}.pdb
"""
    
    if PreMatchRelax:
        extension_relax = "_APO"
        
        # The cpptraj input referenced here is written just below; the script only runs after submission.
        cmd += f"""
# Align relaxed ESM prediction of scaffold without hydrogens
cpptraj -i {FOLDER_HOME}/{index}/ESMfold/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.in           &> \
           {FOLDER_HOME}/{index}/ESMfold/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.out
sed -i '/END/d' {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.pdb
"""  
        
        cpptraj = f'''parm    {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_{index}{extension_relax}.pdb [protein]
parm      {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb [reference]
reference {FOLDER_HOME}/{index}/ESMfold/{WT}_CPPTraj_Apo_{index}.pdb parm [reference] [apo]
trajin    {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_{index}{extension_relax}.pdb parm [protein]
rmsd      @CA ref [apo]
trajout   {FOLDER_HOME}/{index}/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.pdb noter
'''
        with open(f'{FOLDER_HOME}/{index}/ESMfold/{WT}_Rosetta_Relax_aligned_{index}{extension_relax}.in','w') as f: 
            f.write(cpptraj) 
    
    # Create the Rosetta_Relax.xml file (FastRelax repeats: 3, or 1 in EXPLORE mode)
    repeats = "3"
    if EXPLORE: repeats = "1"
    Rosetta_Relax_xml = f"""
<ROSETTASCRIPTS>

    <SCOREFXNS>
    
        <ScoreFunction name      = "score"                   weights = "beta_nov16" >
            <Reweight scoretype  = "atom_pair_constraint"    weight  = "4" />
            <Reweight scoretype  = "angle_constraint"        weight  = "2" />
            <Reweight scoretype  = "dihedral_constraint"     weight  = "1" />
        </ScoreFunction> 
        
        <ScoreFunction name      = "score_final"             weights = "beta_nov16" >
            <Reweight scoretype  = "atom_pair_constraint"    weight  = "4" />
            <Reweight scoretype  = "angle_constraint"        weight  = "2" />
            <Reweight scoretype  = "dihedral_constraint"     weight  = "1" />
        </ScoreFunction>
        
    </SCOREFXNS>
       
    <MOVERS>
                                  
        <FastRelax  name="mv_relax" disable_design="false" repeats="{repeats}" /> 
"""
    # The enzdes match constraints are only defined/added when a ligand is present (not PreMatchRelax).
    if not PreMatchRelax: Rosetta_Relax_xml += f"""
        <AddOrRemoveMatchCsts     name="mv_add_cst" 
                                  cst_instruction="add_new" 
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />

"""
    Rosetta_Relax_xml += f"""

        <InterfaceScoreCalculator   name                   = "mv_inter" 
                                    chains                 = "X" 
                                    scorefxn               = "score_final" />
    </MOVERS>
    
    <PROTOCOLS>  

        <Add mover_name="mv_relax" />
"""
    # mv_inter is always declared above but only added to the protocol when a ligand is present.
    if not PreMatchRelax: Rosetta_Relax_xml += f"""                                  
        <Add mover_name="mv_add_cst" />       
        <Add mover_name="mv_inter" />
"""
    Rosetta_Relax_xml += f"""
    </PROTOCOLS>
    
</ROSETTASCRIPTS>
"""
    # Write the Rosetta_Relax.xml to a file
    with open(f'{FOLDER_HOME}/{index}/scripts/Rosetta_Relax_{index}.xml', 'w') as f:
        f.writelines(Rosetta_Relax_xml)      
        
    # Write the job script and submit it; the script/job name depends on the mode.
    if OnlyRelax or PreMatchRelax: 
        with open(f'{FOLDER_HOME}/{index}/scripts/ESMfold_Rosetta_Relax_{index}.sh', 'w') as file:
            file.write(cmd)
        logging.info(f"Run ESMfold & Rosetta_Relax for index {index}.")
        submit_job(index=index, job="ESMfold_Rosetta_Relax", bash=bash)
        
    if ProteinMPNN:
        with open(f'{FOLDER_HOME}/{index}/scripts/ProteinMPNN_ESMfold_Rosetta_Relax_{index}.sh', 'w') as file:
            file.write(cmd)
        logging.info(f"Run ProteinMPNN for index {index} based on index {ProteinMPNN_parent_index}.")
        submit_job(index=index, job="ProteinMPNN_ESMfold_Rosetta_Relax", bash=bash)

    return True
        
def run_RosettaDesign(parent_index, new_index, all_scores_df, parent_done=True):
    """
    Write and submit a RosettaScripts FastDesign job that creates variant `new_index` from `parent_index`.

    Writes {FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.sh (the rosetta_scripts
    command line, followed by a rename of the output PDB to {WT}_Rosetta_Design_{new_index}.pdb) and
    {FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.xml (the design protocol), then
    submits the job with submit_job(job="Rosetta_Design").

    Parameters:
    - parent_index: Index of the parent variant.
    - new_index: Index of the variant to create; its scripts/ folder must already exist.
    - all_scores_df (DataFrame): Receives the catalytic residue numbers ('cat_resi') read from the
      input PDB for row `new_index` (via save_cat_res_into_all_scores_df).
    - parent_done (bool): If True, start from {FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb;
      otherwise from {FOLDER_HOME}/{FOLDER_PARENT}/{parent_index}.pdb.

    Raises:
    - ValueError: If none of GRID, BLUEPEBBLE, BACKGROUND_JOB or ABBIE_LOCAL is set, so no
      Rosetta binary suffix can be chosen.
    """
    # Options for EXPLORE, accelerated script for testing
    ex = "-ex1 -ex2"
    if EXPLORE: ex = ""

    # Rosetta binary suffix for the selected environment; later flags take precedence
    extension = None
    if GRID:           extension = "linuxgccrelease"
    if BLUEPEBBLE:     extension = "serialization.linuxgccrelease"
    if BACKGROUND_JOB: extension = "serialization.linuxgccrelease"
    if ABBIE_LOCAL:    extension = "linuxgccrelease"
    if extension is None:
        raise ValueError("run_RosettaDesign: set one of GRID, BLUEPEBBLE, BACKGROUND_JOB or ABBIE_LOCAL.")

    if parent_done:
        PDB_input  = f'{FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb'
        PDB_output = f'{WT}_Rosetta_Relax_{parent_index}_0001.pdb'
    else:
        PDB_input  = f'{FOLDER_HOME}/{FOLDER_PARENT}/{parent_index}.pdb'
        PDB_output = f'{parent_index}_0001.pdb'

    all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index, PDB_input, from_parent_struct=True)

    cmd = f"""{ROSETTA_PATH}/bin/rosetta_scripts.{extension}\
    -s                                        {PDB_input} \
    -in:file:native                           {PDB_input} \
    -run:preserve_header                      true \
    -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
    -parser:protocol                          {FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.xml \
    -out:file:scorefile                       {FOLDER_HOME}/{new_index}/score_rosetta_design.sc \
    -nstruct                                  1  \
    -ignore_zero_occupancy                    false  \
    -corrections::beta_nov16                  true \
    -overwrite {ex}
    
mv {PDB_output} {WT}_Rosetta_Design_{new_index}.pdb 
"""
    # Write the shell command to a file
    with open(f'{FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.sh', 'w') as file:
        file.write(cmd)

    # Create XML script for Rosetta Design
    repeats = "1" if EXPLORE else "3"

    # Catalytic residue numbers (REMARK 666 lines of the input PDB, via all_scores_df['cat_resi'])
    # and the residue identities allowed for them (cst.dat, written by prepare_input_files).
    # They are paired by position so every tsk_cat_<i> refers to an existing sel_cat_<i>.
    cat_resis = all_scores_df.at[new_index, 'cat_resi'].split(';')
    with open(f'{FOLDER_HOME}/cst.dat', 'r') as f:
        cat_resns = f.read().split(";")
    if len(cat_resis) != len(cat_resns):
        logging.warning(f"Index {new_index}: {len(cat_resis)} catalytic residue(s) in {PDB_input} but "
                        f"{len(cat_resns)} in cst.dat; only the first "
                        f"{min(len(cat_resis), len(cat_resns))} pair(s) are restricted.")
    cat_pairs = list(zip(cat_resis, cat_resns))

    tsk_cat_names = [f"tsk_cat_{idx}" for idx in range(len(cat_pairs))]
    # Joined as a list so no trailing comma appears when there are no catalytic tasks
    task_ops = ",".join(["tsk_design", "tsk_repack", "tsk_nothing"] + tsk_cat_names)

    Rosetta_Design_xml = f"""
<ROSETTASCRIPTS>

    <SCOREFXNS>

        <ScoreFunction            name="score"                           weights="beta_nov16" >
            <Reweight             scoretype="atom_pair_constraint"       weight="4" />
            <Reweight             scoretype="angle_constraint"           weight="2" />
            <Reweight             scoretype="dihedral_constraint"        weight="1" />
            <Reweight             scoretype="res_type_constraint"        weight="1" />
        </ScoreFunction>

        <ScoreFunction            name="score_unconst"                   weights="beta_nov16" >
            <Reweight             scoretype="atom_pair_constraint"       weight="0" />
            <Reweight             scoretype="dihedral_constraint"        weight="0" />
            <Reweight             scoretype="angle_constraint"           weight="0" />
        </ScoreFunction>

        <ScoreFunction            name="score_final"                     weights="beta_nov16" >
            <Reweight             scoretype="atom_pair_constraint"       weight="4" />
            <Reweight             scoretype="angle_constraint"           weight="2" />
            <Reweight             scoretype="dihedral_constraint"        weight="1" />
        </ScoreFunction>

   </SCOREFXNS>

    <RESIDUE_SELECTORS>

        <Index                    name="sel_design"
                                  resnums="{DESIGN}" />

        <Index                    name="sel_repack"
                                  resnums="{REPACK}" />
"""

    # One Index selector per catalytic residue number
    for idx, (cat_resi, _) in enumerate(cat_pairs):
        Rosetta_Design_xml += f"""
        <Index                    name="sel_cat_{idx}"
                                  resnums="{int(cat_resi)}" />
"""

    Rosetta_Design_xml += f"""
        <Or                       name="sel_desrep"
                                  selectors="sel_design,sel_repack" />

        <Not                      name="sel_nothing"
                                  selector="sel_desrep" />
    </RESIDUE_SELECTORS>

    <TASKOPERATIONS>

        <OperateOnResidueSubset   name="tsk_design"                      selector="sel_design" >
                                  <RestrictAbsentCanonicalAASRLT         aas="GPAVLIMFYWHKRQNEDST" />
        </OperateOnResidueSubset>
"""

    # Restrict each catalytic position to the residue identities allowed by the constraint file
    for idx, (_, cat_resn) in enumerate(cat_pairs):
        Rosetta_Design_xml += f"""
        <OperateOnResidueSubset   name="tsk_cat_{idx}"                   selector="sel_cat_{idx}" >
                                  <RestrictAbsentCanonicalAASRLT         aas="{cat_resn}" />
        </OperateOnResidueSubset>
"""

    Rosetta_Design_xml += f"""

        <OperateOnResidueSubset   name="tsk_repack"                      selector="sel_repack" >
                                  <RestrictToRepackingRLT />
        </OperateOnResidueSubset>

        <OperateOnResidueSubset   name="tsk_nothing"                     selector="sel_nothing" >
                                  <PreventRepackingRLT />
        </OperateOnResidueSubset>

    </TASKOPERATIONS>

    <FILTERS>

        <HbondsToResidue          name="flt_hbonds"
                                  scorefxn="score"
                                  partners="1"
                                  residue="1X"
                                  backbone="true"
                                  sidechain="true"
                                  from_other_chains="true"
                                  from_same_chain="false"
                                  confidence="0" />
    </FILTERS>

    <MOVERS>

        <FavorSequenceProfile     name="mv_native"
                                  weight="{CST_WEIGHT}"
                                  use_native="true"
                                  matrix="IDENTITY"
                                  scorefxns="score" />

        <AddOrRemoveMatchCsts     name="mv_add_cst"
                                  cst_instruction="add_new"
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />

        <FastDesign               name                   = "mv_design"
                                  disable_design         = "false"
                                  task_operations        = "{task_ops}"
                                  repeats                = "{repeats}"
                                  ramp_down_constraints  = "false"
                                  scorefxn               = "score" />

        <FastRelax                name                   = "mv_relax"
                                  disable_design         = "true"
                                  task_operations        = "{task_ops}"
                                  repeats                = "1"
                                  ramp_down_constraints  = "false"
                                  scorefxn               = "score_unconst" />

        <InterfaceScoreCalculator name                   = "mv_inter"
                                  chains                 = "X"
                                  scorefxn               = "score_final" />

    </MOVERS>

    <PROTOCOLS>
        <Add mover_name="mv_native" />
        <Add mover_name="mv_add_cst" />
        <Add mover_name="mv_design" />
        <Add mover_name="mv_relax" />
        <Add mover_name="mv_inter" />
    </PROTOCOLS>

</ROSETTASCRIPTS>

"""
    # Write the XML script to a file
    with open(f'{FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.xml', 'w') as f:
        f.write(Rosetta_Design_xml)

    # Submit the job using the submit_job function
    logging.info(f"Run RosettaDesign for index {new_index} based on index {parent_index}.")
    submit_job(index=new_index, job="Rosetta_Design")



def run_ProteinMPNN(parent_index, new_index, all_scores_df):
    """
    Run the ProteinMPNN pipeline on the relaxed parent structure and fold/relax the best new sequence.

    Steps: copy the parent's relaxed PDB into {FOLDER_HOME}/{new_index}/ProteinMPNN, run the
    ProteinMPNN helper scripts (parse chains, assign chain A, build the fixed-positions dict from
    DESIGN), sample 100 sequences, pick one with find_highest_scoring_sequence, save it as
    ProteinMPNN/{WT}_{new_index}.seq and hand over to run_ESMfold_RosettaRelax(ProteinMPNN=True).

    Parameters:
    - parent_index (str): The index of the parent protein variant.
    - new_index (str): The index assigned to the new protein variant.
    - all_scores_df (DataFrame): A DataFrame containing information for protein variants.
    - --------------------------------GLOBAL variables used ----------------------------
    - FOLDER_HOME (str): The base directory; ProteinMPNN is expected in {FOLDER_HOME}/../ProteinMPNN.
    - WT (str): The wild type or reference protein identifier.
    - DESIGN (str): Comma-separated residue positions to design with RosettaDesign.
    - ProteinMPNN_T (float or str): The sampling temperature for ProteinMPNN.
    - EXPLORE (bool): Flag to indicate whether exploration mode is enabled.

    Returns:
    None. Returns early (after logging an error) when ProteinMPNN is not installed, the parent
    structure is missing, or no acceptable new sequence is found.

    Note:
    This function assumes the ProteinMPNN toolkit is available and properly set up in the specified location.
    It involves multiple subprocess calls to Python scripts for processing protein structures and generating new sequences.
    """
    protein_mpnn_path = f"{FOLDER_HOME}/../ProteinMPNN"
    helper_scripts_path = f"{protein_mpnn_path}/helper_scripts"

    # Ensure ProteinMPNN is available
    if not os.path.exists(protein_mpnn_path):
        logging.error(f"ProteinMPNN not installed in {FOLDER_HOME}/../ProteinMPNN.")
        logging.error("Install using: git clone https://github.com/dauparas/ProteinMPNN.git")
        return

    # Prepare file paths
    pdb_file = f"{FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb"
    if not os.path.isfile(pdb_file):
        logging.error(f"{pdb_file} not present!")
        return

    protein_mpnn_folder = f"{FOLDER_HOME}/{new_index}/ProteinMPNN"
    os.makedirs(protein_mpnn_folder, exist_ok=True)
    shutil.copy(pdb_file, os.path.join(protein_mpnn_folder, f"{WT}_Rosetta_Relax_{parent_index}.pdb"))

    # Parent sequence, used later to exclude unchanged samples
    seq = extract_sequence_from_pdb(pdb_file)
    with open(os.path.join(protein_mpnn_folder, f"Rosetta_Relax_{parent_index}.seq"), "w") as f:
        f.write(seq)

    parsed_chains = os.path.join(protein_mpnn_folder, "parsed_chains.json1")
    assigned_chains = os.path.join(protein_mpnn_folder, "assigned_chains.json1")
    fixed_positions = os.path.join(protein_mpnn_folder, "fixed_positions.json1")

    # Parse multiple chains
    run_command([
        "python", os.path.join(helper_scripts_path, "parse_multiple_chains.py"),
        "--input_path", protein_mpnn_folder,
        "--output_path", parsed_chains
    ])

    # Assign fixed chains
    run_command([
        "python", os.path.join(helper_scripts_path, "assign_fixed_chains.py"),
        "--input_path", parsed_chains,
        "--output_path", assigned_chains,
        "--chain_list", 'A'
    ])

    # Make fixed positions dict.
    # NOTE(review): without --specify_non_fixed, ProteinMPNN treats --position_list as positions to
    # keep FIXED, so the DESIGN positions are held and the rest of chain A is redesigned — confirm intent.
    run_command([
        "python", os.path.join(helper_scripts_path, "make_fixed_positions_dict.py"),
        "--input_path", parsed_chains,
        "--output_path", fixed_positions,
        "--chain_list", 'A',
        "--position_list", " ".join(DESIGN.split(","))
    ])

    # Protein MPNN run; every argv element must be a string for subprocess
    run_command([
        "python", os.path.join(protein_mpnn_path, "protein_mpnn_run.py"),
        "--jsonl_path", parsed_chains,
        "--chain_id_jsonl", assigned_chains,
        "--fixed_positions_jsonl", fixed_positions,
        "--out_folder", protein_mpnn_folder,
        "--num_seq_per_target", "100",
        "--sampling_temp", str(ProteinMPNN_T),
        "--seed", "37",
        "--batch_size", "1"
    ])

    # Find highest scoring sequence
    highest_scoring_sequence = find_highest_scoring_sequence(
        protein_mpnn_folder, parent_index,
        input_sequence_path=f"{FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq")

    # Nothing to fold: stop here, like the other failure paths above
    if not highest_scoring_sequence:
        logging.error(f"Failed to find a new sequnce for index {parent_index} with ProteinMPNN.")
        return

    # Save highest scoring sequence and prepare for ESMfold
    with open(os.path.join(protein_mpnn_folder, f"{WT}_{new_index}.seq"), "w") as f:
        f.write(highest_scoring_sequence)
    logging.info(f"Ran ProteinMPNN for index {parent_index} and found a new sequence with index {new_index}.")

    all_scores_df = save_cat_res_into_all_scores_df(all_scores_df, new_index, pdb_file, from_parent_struct=False)

    # Run ESMfold Relax with the ProteinMPNN Flag
    run_ESMfold_RosettaRelax(index=new_index, all_scores_df=all_scores_df, OnlyRelax=False,
                             ProteinMPNN=True, ProteinMPNN_parent_index=parent_index, EXPLORE=EXPLORE)

def find_highest_scoring_sequence(folder_path, parent_index, input_sequence_path):
    """
    Pick the sequence with the largest global_score from a ProteinMPNN .fa output.

    A sequence is excluded when it equals the parent sequence, or when it matches (from its start)
    the input pattern: the WT sequence with 'X' as a wildcard at every DESIGN position, i.e. a
    sequence that differs from WT only at DESIGN positions is rejected.

    Parameters:
    - folder_path (str): The path to the directory containing sequence files (/ProteinMPNN).
    - parent_index (str): The index of the parent protein sequence.
    - input_sequence_path (str): The path to a file whose first line is the input sequence pattern,
      where 'X' represents wildcard positions that can match any character.
    - -------------------------------- GLOBAL variables used ----------------------------
    - WT (str): The wild type or reference protein identifier.

    Returns:
    - highest_scoring_sequence (str): The accepted sequence with the largest global_score,
      or '' if no sequence qualifies.

    Note:
    Only header lines containing 'global_score=<digits>.<digits>' are considered; the line following
    such a header is taken as its sequence.
    NOTE(review): ProteinMPNN documents global_score as an average negative log probability (lower is
    more likely); confirm that selecting the LARGEST value is intended.
    """
    file_path = f'{folder_path}/seqs/{WT}_Rosetta_Relax_{parent_index}.fa'
    parent_seq_file = f'{folder_path}/Rosetta_Relax_{parent_index}.seq'

    # Read the parent sequence from its file
    with open(parent_seq_file, 'r') as file:
        parent_sequence = file.readline().strip()

    # Read the input sequence pattern; 'X' becomes the regex wildcard '.'
    with open(input_sequence_path, 'r') as file:
        input_sequence = file.readline().strip()
    wildcard_pattern = re.compile(input_sequence.replace('X', '.'))

    # Raw string: '\d' in a plain string literal is an invalid escape sequence
    score_regex = re.compile(r'global_score=(\d+\.\d+)')

    highest_score = 0
    highest_scoring_sequence = ''

    with open(file_path, 'r') as file:
        for line in file:
            if not line.startswith('>'):
                continue
            score_match = score_regex.search(line)
            if not score_match:
                continue
            score = float(score_match.group(1))
            sequence = next(file, '').strip()  # the sequence is on the line after its header

            if (score > highest_score
                    and sequence != parent_sequence
                    and not wildcard_pattern.match(sequence)):
                highest_score = score
                highest_scoring_sequence = sequence

    return highest_scoring_sequence

def run_LigandMPNN(parent_index, new_index, all_scores_df):
    """
    Run LigandMPNN on the relaxed parent structure, keeping the catalytic residues fixed.

    Copies {FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb into
    {FOLDER_HOME}/{new_index}/LigandMPNN and runs LigandMPNN's run.py (model_type ligand_mpnn,
    seed 37) with every catalytic residue of chain A passed as --fixed_residues.

    Parameters:
    - parent_index (str): The index of the parent protein variant.
    - new_index (str): The index assigned to the new protein variant.
    - all_scores_df (DataFrame): Source of the parent's 'cat_resi' entry; not modified here.

    Returns:
    None. Returns early (after logging an error) if LigandMPNN or the parent PDB is missing.
    """
    ligand_mpnn_path = f"{FOLDER_HOME}/../LigandMPNN"

    # Ensure LigandMPNN is available
    if not os.path.exists(ligand_mpnn_path):
        logging.error(f"LigandMPNN not installed in {ligand_mpnn_path}.")
        logging.error("Install using: git clone https://github.com/dauparas/LigandMPNN.git")
        return

    # Prepare file paths
    pdb_file = f"{FOLDER_HOME}/{parent_index}/{WT}_Rosetta_Relax_{parent_index}.pdb"
    if not os.path.isfile(pdb_file):
        logging.error(f"{pdb_file} not present!")
        return

    ligand_mpnn_folder = f"{FOLDER_HOME}/{new_index}/LigandMPNN"
    os.makedirs(ligand_mpnn_folder, exist_ok=True)
    local_pdb = os.path.join(ligand_mpnn_folder, f"{WT}_Rosetta_Relax_{parent_index}.pdb")
    shutil.copy(pdb_file, local_pdb)

    # 'cat_resi' is written ';'-joined by save_cat_res_into_all_scores_df, so it may hold several
    # residues; a single residue may come back from the CSV as int or float, hence str() and int(float()).
    cat_resi_entry = str(all_scores_df.at[parent_index, 'cat_resi'])
    cat_resis = [int(float(resi)) for resi in cat_resi_entry.split(';') if resi.strip()]
    # LigandMPNN expects a space-separated list of <chain><resnum> tokens
    fixed_residues = " ".join(f"A{resi}" for resi in cat_resis)

    # Run LigandMPNN
    run_command([
        "python", os.path.join(ligand_mpnn_path, "run.py"),
        "--model_type", "ligand_mpnn",
        "--seed", "37",
        "--pdb_path", local_pdb,
        "--out_folder", ligand_mpnn_folder,
        "--fixed_residues", fixed_residues
    ])

    logging.info(f"Ran LigandMPNN for index {parent_index} and generated modifications for new index {new_index}.")

# main functions - startup

In [1]:
def startup_controller(UNBLOCK_ALL, RESET, PRINT_VAR=True, PLOT_DATA=True):
    """
    Prepare an AIzymes session and return the current scores table.

    Parameters:
    - UNBLOCK_ALL (bool): If True, set 'blocked' to False for every row of the returned
      DataFrame (the change is not written to disk here).
    - RESET (bool): Forwarded to setup_aizymes when VARIABLES_JSON does not exist yet.
    - PRINT_VAR (bool): If True and VARIABLES_JSON exists, copy every saved variable into this
      module's globals and print it; exits if the saved DESIGN_FOLDER differs from the current one.
    - PLOT_DATA (bool): If True, call plot_scores().

    Returns:
    - DataFrame read from ALL_SCORES_CSV.
    """
    if not os.path.isfile(VARIABLES_JSON):
        # No saved run yet: setup_aizymes asks the user whether to initialise one
        setup_aizymes(RESET)

    # Regenerates ESMfold.py, the wildcard sequence and cst.dat; harmless to repeat
    prepare_input_files()

    if PRINT_VAR and os.path.isfile(VARIABLES_JSON):
        with open(VARIABLES_JSON, 'r') as f:
            saved_vars = json.load(f)
        if saved_vars['DESIGN_FOLDER'] != DESIGN_FOLDER:
            print("WRONG DESIGN FOLDER!")
            sys.exit()
        for name, value in saved_vars.items():
            # Saved settings override the module-level values of this session
            globals()[name] = value
            print(name.ljust(16), ':', value)

    time.sleep(1)

    if PLOT_DATA:
        plot_scores()

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    if UNBLOCK_ALL:
        all_scores_df["blocked"] = False
    return all_scores_df

def prepare_input_files():
    """
    Write the helper files an AIzymes run relies on into FOLDER_HOME.

    Creates:
    - {FOLDER_HOME}/ESMfold.py: stand-alone script called as `python ESMfold.py <output_pdb> <sequence_file>`
      that folds the sequence with facebook/esmfold_v1 and writes the predicted PDB.
    - {FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq: sequence of {FOLDER_INPUT}/{WT}.pdb with
      every DESIGN position replaced by 'X'.
    - {FOLDER_HOME}/cst.dat: last token of every "TEMPLATE::   ATOM_MAP: 2 res" line in
      {FOLDER_INPUT}/{LIGAND}_enzdes.cst, joined with ';'.

    Prints an error and calls sys.exit() if the scaffold PDB is missing.
    """
    # Create the ESMfold.py script
    ESMfold_python_script = """import sys
from transformers import AutoTokenizer, EsmForProteinFolding, EsmConfig
import torch
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

output_file = sys.argv[1]
sequence_file = sys.argv[2]

with open(sequence_file) as f: sequence=f.read()

def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
torch.backends.cuda.matmul.allow_tf32 = True
model.trunk.set_chunk_size(64)
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids']
with torch.no_grad(): output = model(tokenized_input)
pdb = convert_outputs_to_pdb(output)
with open(output_file, "w") as f: f.write("".join(pdb))
"""

    # Write the ESMfold.py to a file
    with open(f"{FOLDER_HOME}/ESMfold.py", "w") as f:
        f.write(ESMfold_python_script)

    # Save input sequence with X as wildcard
    scaffold_pdb = f"{FOLDER_INPUT}/{WT}.pdb"
    if not os.path.isfile(scaffold_pdb):
        print(f"Error, scaffold protein structure {FOLDER_INPUT}/{WT}.pdb is missing!")
        sys.exit()
    seq = extract_sequence_from_pdb(scaffold_pdb)
    # DESIGN positions are compared against 1-based sequence positions, hence i + 1
    design_positions = {int(x) for x in DESIGN.split(',')}
    seq = ''.join('X' if (i + 1) in design_positions else amino_acid for i, amino_acid in enumerate(seq))
    with open(f'{FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq', 'w') as f:
        f.write(seq)

    # Get the allowed catalytic residue names from the enzdes constraints file
    with open(f'{FOLDER_INPUT}/{LIGAND}_enzdes.cst', 'r') as f:
        cst_lines = f.readlines()
    cat_resns = [line.split()[-1] for line in cst_lines if "TEMPLATE::   ATOM_MAP: 2 res" in line]
    with open(f'{FOLDER_HOME}/cst.dat', 'w') as f:
        f.write(";".join(cat_resns))
    
def setup_aizymes(RESET=False, EXPLORE=False):
    """
    Interactively (re)initialise an AIzymes run in FOLDER_HOME.

    - If VARIABLES_JSON does not exist, ask whether to start AIzymes.
    - If it exists, do nothing unless RESET is True; then ask for confirmation before wiping the run.
    - On confirmation: truncate LOG_FILE, delete everything in FOLDER_HOME except FOLDER_MATCH,
      regenerate the input files, create and save an empty all-scores table and store the run's
      global settings (plus EXPLORE) in VARIABLES_JSON.

    Parameters:
    - RESET (bool): Allow wiping an existing run (after confirmation).
    - EXPLORE (bool): Saved as 'EXPLORE' in VARIABLES_JSON.
    """
    # Decide whether setup should run at all
    if os.path.isfile(VARIABLES_JSON):
        if not RESET:
            return  # already set up and no reset requested
        answer = input(f'''Do you really want to restart AIzymes from scratch? 
This will delete all existing files in {FOLDER_HOME} [y/n]

''')
        if answer != 'y':
            return  # user cancelled the reset
    else:
        answer = input(f'''Do you want to start AIzymes? [y/n]

''')
        if answer != 'y':
            return  # user declined to set up AIzymes

    with open(LOG_FILE, 'w'): pass  # truncate the log file
    logging.info(f"Running AI.zymes setup.")

    # Wipe the previous run, keeping only the FOLDER_MATCH subfolder
    if os.path.exists(FOLDER_HOME):
        for item in os.listdir(FOLDER_HOME):
            if item == FOLDER_MATCH: continue
            item_path = f'{FOLDER_HOME}/{item}'
            if os.path.isfile(item_path):
                os.remove(item_path)
            elif os.path.isdir(item_path):
                shutil.rmtree(item_path)
    os.makedirs(FOLDER_HOME, exist_ok=True)
    # Logged only after the deletion actually happened
    logging.info(f"Content of {FOLDER_HOME} deleted.")
    logging.info(f"Happy AI.zymeing! :)")

    prepare_input_files()

    # Creates the empty all-scores table and saves it
    make_empty_all_scores_df()

    # Save global variables
    variables_to_save = [
        'DESIGN_FOLDER', 'FOLDER_MATCH', 'MAX_JOBS', 'N_PARENT_JOBS', 'MAX_DESIGNS', 'KBT_BOLTZMANN', 'CST_WEIGHT',
        'ProteinMPNN_PROB', 'WT', 'LIGAND', 'ROSETTA_PATH', 'REPACK', 'DESIGN', 'MATCH', 'FOLDER_PARENT',
        'ProteinMPNN_T', 'SUBMIT_PREFIX', 'BLUEPEBBLE', 'GRID', 'BACKGROUND_JOB', 'ABBIE_LOCAL'
    ]
    globals_to_save = {k: globals()[k] for k in variables_to_save}
    globals_to_save['EXPLORE'] = EXPLORE

    if N_PARENT_JOBS < MAX_JOBS*2:
        logging.warning(f"To ensure a smooth start, N_PARENT_JOBS should be at least 2 x MAX_JOBS.")
        logging.warning(f"N_PARENT_JOBS: {N_PARENT_JOBS}, MAX_JOBS: {MAX_JOBS}.")

    with open(VARIABLES_JSON, 'w') as f:
        json.dump(globals_to_save, f, indent=4)
    
    
def make_empty_all_scores_df():
    """
    Create the empty all-scores table, save it with save_all_scores_df and return it.

    One column per tracked property: identity/lineage, the current scores and potentials,
    the relax- and design-stage scores, bookkeeping fields and the catalytic residue numbers/names.
    The column names are listed explicitly with commas; adjacent string literals would
    otherwise be silently concatenated into a single bogus column.
    """
    columns = [
        'index', 'sequence', 'parent_index',
        'interface_score', 'total_score', 'catalytic_score', 'efield_score',
        'interface_potential', 'total_potential', 'catalytic_potential', 'efield_potential',
        'relax_interface_score', 'relax_total_score', 'relax_catalytic_score', 'relax_efield_score',
        'design_interface_score', 'design_total_score', 'design_catalytic_score', 'design_efield_score',
        'generation', 'mutations', 'design_method', 'score_taken_from', 'blocked',
        'cat_resi', 'cat_resn',
    ]
    all_scores_df = pd.DataFrame(columns=columns)

    save_all_scores_df(all_scores_df)

    return all_scores_df

# RosettaMatch

In [None]:
def run_RosettaMatch(EXPLORE=False, submit=False, bash=True):
    """
    Advance the RosettaMatch preparation of FOLDER_MATCH by one step per call.

    1. If {FOLDER_HOME}/{FOLDER_MATCH}/scripts does not exist, start ESMfold + Rosetta Relax of the
       scaffold (run_ESMfold_RosettaRelax with PreMatchRelax=True).
    2. Else, if {WT}_Rosetta_Relax_{FOLDER_MATCH}_APO.pdb is not there yet, report that it is still running.
    3. Else, if the matches/ folder does not exist, start the matcher (run_Matcher).
    4. Otherwise report that matching is done.

    Parameters:
    - EXPLORE (bool): Forwarded to run_ESMfold_RosettaRelax.
    - submit, bash: Currently unused; kept for backward compatibility.
    """
    prepare_input_files()

    match_dir = f'{FOLDER_HOME}/{FOLDER_MATCH}'
    # Create the match folder inside FOLDER_HOME, where every path below points
    # (a bare FOLDER_MATCH would be created relative to the current working directory).
    os.makedirs(match_dir, exist_ok=True)

    if not os.path.isdir(f'{match_dir}/scripts'):
        run_ESMfold_RosettaRelax(FOLDER_MATCH, all_scores_df=None, PreMatchRelax=True, EXPLORE=EXPLORE)
    elif not os.path.isfile(f'{match_dir}/{WT}_Rosetta_Relax_{FOLDER_MATCH}_APO.pdb'):
        print(f"ESMfold and Relax of {FOLDER_MATCH} still running.")
    elif not os.path.isdir(f'{match_dir}/matches'):
        run_Matcher()
    else:
        print("Matching is done")
            
def run_Matcher(central_atom="C9"):
    """
    Write and submit the RosettaMatch job for the relaxed scaffold in {FOLDER_HOME}/{FOLDER_MATCH}.

    The generated script scripts/RosettaMatch_{FOLDER_MATCH}.sh writes {LIGAND}.central and
    {LIGAND}.pos, builds a ligand grid with gen_lig_grids from the aligned APO structure and
    ESMfold/{WT}_CPPTraj_Lig_{FOLDER_MATCH}.pdb, renames it to {WT}.gridlig, and runs
    match.linuxgccrelease inside a fresh matches/ folder. The job is submitted with
    submit_job(bash=False).

    Parameters:
    - central_atom (str): Ligand atom name written to {LIGAND}.central, which is passed as
      -match:required_active_site_atom_names. Defaults to "C9", the previously hard-coded value.
    """
    match_positions = " ".join(MATCH.split(","))

    # `rm -f` / `rm -rf` silence "no such file" errors; the previous `2>1` wrote them to a file named "1"
    cmd = f"""       
  
cd {FOLDER_HOME}/{FOLDER_MATCH}

echo {central_atom} > {LIGAND}.central
echo {match_positions} > {LIGAND}.pos

{ROSETTA_PATH}/bin/gen_lig_grids.linuxgccrelease \
    -s                      {WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb ESMfold/{WT}_CPPTraj_Lig_{FOLDER_MATCH}.pdb \
    -extra_res_fa           {FOLDER_INPUT}/{LIGAND}.params \
    -grid_delta             0.5 \
    -grid_lig_cutoff        5.0 \
    -grid_bb_cutoff         2.25 \
    -grid_active_res_cutoff 15.0 \
    -overwrite 

mv {WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb_0.gridlig {WT}.gridlig
rm -f {WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb_0.pos

rm -rf matches
mkdir matches
cd matches

{ROSETTA_PATH}/bin/match.linuxgccrelease \
    -s                                        ../{WT}_Rosetta_Relax_aligned_{FOLDER_MATCH}_APO.pdb \
    -match:lig_name                           {LIGAND} \
    -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
    -match:geometric_constraint_file          {FOLDER_INPUT}/{LIGAND}_enzdes.cst \
    -match::scaffold_active_site_residues     ../{LIGAND}.pos \
    -match:required_active_site_atom_names    ../{LIGAND}.central \
    -match:active_site_definition_by_gridlig  ../{WT}.gridlig  \
    -match:grid_boundary                      ../{WT}.gridlig  \
    -gridligpath                              ../{WT}.gridlig  \
    -overwrite  \
    -output_format PDB  \
    -output_matches_per_group 1  \
    -consolidate_matches true 
"""
    with open(f'{FOLDER_HOME}/{FOLDER_MATCH}/scripts/RosettaMatch_{FOLDER_MATCH}.sh', 'w') as file:
        file.write(cmd)
    logging.info(f"Run Rosetta_Match for index {FOLDER_MATCH}.")
    submit_job(FOLDER_MATCH, job="RosettaMatch", bash=False)

# helper functions

In [None]:
def submit_job(index, job, bash=False):
    """
    Wrap {FOLDER_HOME}/{index}/scripts/{job}_{index}.sh in a submission script and run or submit it.

    The submission script scripts/submit_{job}_{index}.sh gets a scheduler header (GRID: SGE,
    BLUEPEBBLE: SLURM; later flags override earlier ones, none for BACKGROUND_JOB/ABBIE_LOCAL),
    then cds into {FOLDER_HOME}/{index} and runs the job script. With BACKGROUND_JOB, the
    counter in {FOLDER_HOME}/n_running_jobs.dat is incremented here and decremented by the
    submission script when the job script has finished.

    Parameters:
    - index: Variant index; {FOLDER_HOME}/{index}/scripts must exist.
    - job (str): Job name, used in the script and log file names.
    - bash (bool): If True, run the submission script synchronously with bash (for testing);
      otherwise submit it with qsub (GRID) / sbatch (BLUEPEBBLE) or start it in the background
      (BACKGROUND_JOB / ABBIE_LOCAL).
    """
    submit_script_path = f'{FOLDER_HOME}/{index}/scripts/submit_{job}_{index}.sh'
    counter_path = f'{FOLDER_HOME}/n_running_jobs.dat'

    # Without any environment flag there is no header, so bash=True still works locally
    submission_script = ""

    if GRID:
        submission_script = f"""#!/bin/bash
#$ -V
#$ -cwd
#$ -N {SUBMIT_PREFIX}_{job}_{index}
#$ -hard -l mf=16G
#$ -o {FOLDER_HOME}/{index}/scripts/AI_{job}_{index}.out
#$ -e {FOLDER_HOME}/{index}/scripts/AI_{job}_{index}.err
"""
    if BLUEPEBBLE:
        submission_script = f"""#!/bin/bash
#SBATCH --account={BLUEPEBBLE_ACCOUNT}
#SBATCH --partition=short
#SBATCH --mem=40GB
#SBATCH --ntasks-per-node=1
#SBATCH --time=2:00:00
#SBATCH --nodes=1
#SBATCH --job-name={SUBMIT_PREFIX}_{job}_{index}
#SBATCH --output={FOLDER_HOME}/{index}/scripts/AI_{job}_{index}.out
#SBATCH --error={FOLDER_HOME}/{index}/scripts/AI_{job}_{index}.err
"""

    if BACKGROUND_JOB:
        # Increment the running-jobs counter (the submission script decrements it when done)
        if not os.path.isfile(counter_path):
            with open(counter_path, 'w') as f:
                f.write('0')
        with open(counter_path, 'r') as f:
            jobs = int(f.read())
        with open(counter_path, 'w') as f:
            f.write(str(jobs + 1))
        submission_script = ""

    if ABBIE_LOCAL:
        submission_script = ""

    submission_script += f"""
# Output folder
cd {FOLDER_HOME}/{index}
pwd
bash {FOLDER_HOME}/{index}/scripts/{job}_{index}.sh
"""
    if BACKGROUND_JOB:
        # Appended (not assigned) so the job command above is kept
        submission_script += f"""
jobs=$(cat {FOLDER_HOME}/n_running_jobs.dat)
jobs=$((jobs - 1))
echo "$jobs" > {FOLDER_HOME}/n_running_jobs.dat
"""

    # Create the submission_script
    with open(submit_script_path, 'w') as file:
        file.write(submission_script)

    if bash:
        # Bash the submission_script for testing
        subprocess.run(f'bash {submit_script_path}', shell=True, text=True)
        return

    # Submit the submission_script
    if GRID:
        output = subprocess.check_output(
            f'qsub -l h="!bs-dsvr64&!bs-dsvr58" -q regular.q {submit_script_path}',
            shell=True, text=True)
        logging.debug(output[:-1])  # remove newline at end of output

    if BLUEPEBBLE:
        output = subprocess.check_output(f'sbatch {submit_script_path}', shell=True, text=True)
        logging.debug(output[:-1])  # remove newline at end of output

    if BACKGROUND_JOB:
        _start_background_submission(index, job)

    if ABBIE_LOCAL:
        _start_background_submission(index, job)


def _start_background_submission(index, job):
    """Start scripts/submit_{job}_{index}.sh with `bash ... &`, logging its stdout/stderr next to it."""
    script_base = f'{FOLDER_HOME}/{index}/scripts/submit_{job}_{index}'
    with open(f'{script_base}_stdout.log', 'w') as stdout_log_file, \
         open(f'{script_base}_stderr.log', 'w') as stderr_log_file:
        subprocess.Popen(f'bash {script_base}.sh &', shell=True,
                         stdout=stdout_log_file, stderr=stderr_log_file)
        
def extract_sequence_from_pdb(pdb_path):
    """
    Return the one-letter sequence of a PDB file, read with Biopython's "pdb-atom" parser.

    If the parser yields several chains, the sequence of the LAST one is returned.

    Raises:
    - ValueError: If the parser yields no chain at all (previously an UnboundLocalError).
    """
    seq = None
    with open(pdb_path, "r") as pdb_file:
        for record in SeqIO.parse(pdb_file, "pdb-atom"):
            seq = str(record.seq)
    if seq is None:
        raise ValueError(f"No protein chain found in {pdb_path}.")
    return seq
    
def generate_remark_from_all_scores_df(all_scores_df, index, ligand=None):

    '''Build the REMARK 666 match-template lines for design <index>.

    Reads the ";"-joined 'cat_resi' and 'cat_resn' entries of <all_scores_df> at row <index>
    and returns one REMARK 666 line per catalytic residue, joined with "\\n".

    ligand: ligand code written into every line; defaults to the global LIGAND.

    Residue numbers that pandas read back as floats (a numeric column containing NaN is
    parsed as float64, e.g. 12 -> 12.0) are written as integers. Otherwise "12.0" would end
    up in the fixed-width residue field, which save_cat_res_into_all_scores_df parses with
    int(remark[55:59]).'''

    if ligand is None:
        ligand = LIGAND

    def _as_int_string(token):
        # "12.0" -> "12"; tokens that are not finite numbers (e.g. "nan") are kept unchanged
        try:
            return str(int(float(token)))
        except (ValueError, OverflowError):
            return token

    cat_resns = str(all_scores_df.at[index, 'cat_resn']).split(';')
    cat_resis = [_as_int_string(tok) for tok in str(all_scores_df.at[index, 'cat_resi']).split(';')]

    remarks = []
    for cat_resi, cat_resn in zip(cat_resis, cat_resns):
        # Residue number right-justified in 5 columns keeps the fixed-width layout of the line
        remarks.append(f'REMARK 666 MATCH TEMPLATE X {ligand}    0 MATCH MOTIF A {cat_resn}{cat_resi.rjust(5)}  1  1')
    return "\n".join(remarks)

def save_cat_res_into_all_scores_df(all_scores_df, index, PDB_file_path, from_parent_struct=False):

    '''Read the catalytic residues listed in the REMARK 666 lines of <PDB_file_path> and
       store them in row <index> of <all_scores_df>.

       'cat_resi' receives the residue numbers and 'cat_resn' the three-letter residue names
       (taken from each residue's first CA atom line), both as ";"-joined strings; use
       .split(";") to turn them back into lists.
       If <from_parent_struct> is True (the PDB is an input structure for design rather than
       a designed structure) only 'cat_resi' is written.
       Returns the updated all_scores_df.'''

    time.sleep(0.1)  # short pause before reading the file; presumably lets a fresh write settle

    with open(PDB_file_path, 'r') as pdb_handle:
        pdb_lines = pdb_handle.readlines()

    header = [line for line in pdb_lines if line.startswith('REMARK 666')]

    # The residue number sits in the fixed-width columns [55:59] of every REMARK 666 line
    resi_list = [str(int(line[55:59])) for line in header]

    # Coordinate lines are assumed to start two lines after the REMARK 666 block
    coordinate_lines = pdb_lines[len(header) + 2:]

    def lookup_resn(wanted):
        # Residue name of the first CA atom whose residue number equals <wanted>; None if absent
        for line in coordinate_lines:
            if line[12:16] != " CA ":
                continue
            if str(int(line[22:26])) == wanted:
                return line[17:20]
        return None

    resn_list = []
    for wanted in resi_list:
        found = lookup_resn(wanted)
        if found is not None:
            resn_list.append(found)

    all_scores_df.at[index, 'cat_resi'] = ";".join(resi_list)
    # cat_resn is only meaningful for designed structures, not for the input structure
    if not from_parent_struct:
        all_scores_df.at[index, 'cat_resn'] = ";".join(resn_list)

    return all_scores_df

def reset_to_after_parent_design():

    '''Roll the design run back to the state right after the parent designs.

    - Deletes every numbered design folder in FOLDER_HOME whose number is
      >= N_PARENT_JOBS * <number of .pdb files in FOLDER_HOME/FOLDER_PARENT>.
    - In the remaining numbered folders, deletes every file whose name contains 'potential.dat'.
    - Rebuilds the score table from make_empty_all_scores_df(): one new row (parent "Parent",
      design_method "RosettaDesign") per remaining folder, with 'luca' read from that folder's
      score_rosetta_design.sc, then saves it with save_all_scores_df.'''

    folders = []
    for folder_name in os.listdir(FOLDER_HOME):
        if os.path.isdir(os.path.join(FOLDER_HOME, folder_name)) and folder_name.isdigit():
            folders.append(int(folder_name))

    all_scores_df = make_empty_all_scores_df()

    PARENTS = [i for i in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if i[-4:] == ".pdb"]
    n_parent_designs = N_PARENT_JOBS * len(PARENTS)

    for folder in sorted(folders):

        folder_path = os.path.join(FOLDER_HOME, str(folder))

        if folder >= n_parent_designs:
            # Remove non-parent designs
            shutil.rmtree(folder_path)
            continue

        # Remove potentials
        for item in os.listdir(folder_path):
            if 'potential.dat' not in item:
                continue
            item_path = os.path.join(folder_path, item)
            os.remove(item_path)
            print(item_path)

        # Update scorefile
        new_index, all_scores_df = create_new_index(parent_index="Parent", all_scores_df=all_scores_df)
        all_scores_df['design_method'] = all_scores_df['design_method'].astype('object')
        all_scores_df.at[new_index, 'design_method'] = "RosettaDesign"
        all_scores_df['luca'] = all_scores_df['luca'].astype('object')
        # Read the score file of the folder being processed (the old code used an undefined
        # name `index` here)
        score_file_path = os.path.join(folder_path, "score_rosetta_design.sc")
        with open(score_file_path, 'r') as f:
            score = f.readlines()[2]  # third line: presumably the first data row after the headers
        # Last column of the score row; [:-5] drops a 5-character suffix (presumably "_0001")
        all_scores_df.at[new_index, 'luca'] = score.split()[-1][:-5]

        if new_index % 100 == 0:
            print(folder, new_index)

    save_all_scores_df(all_scores_df)
    
def save_all_scores_df(all_scores_df, path=None):

    '''Atomically write <all_scores_df> (without the DataFrame index) to <path>.

    path: target CSV file; defaults to ALL_SCORES_CSV.
    The frame is first written to a temporary file in the same directory as <path> and then
    moved over the target with os.replace. Readers therefore never see a half-written CSV, and
    an existing file is overwritten on every platform. On any failure the temporary file is
    removed and the original exception is re-raised unchanged.'''

    if path is None:
        path = ALL_SCORES_CSV

    # Same directory as the target so os.replace stays a same-filesystem (atomic) rename
    temp_fd, temp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)))
    # Only the unique name is needed; pandas reopens the file by path. Closing the descriptor
    # right away means the error path can never close it twice (which would raise EBADF and
    # hide the real error).
    os.close(temp_fd)

    try:
        all_scores_df.to_csv(temp_path, index=False)
        os.replace(temp_path, path)
    except BaseException:
        if os.path.exists(temp_path):
            os.unlink(temp_path)
        raise

def get_best_structures(csv_path=None):
    '''Rank finished designs and copy the 10 best PDBs into the current working directory.

    WARNING: first deletes every *.pdb file in the current working directory.

    Steps:
    1. Load the score table from <csv_path> (default: ALL_SCORES_CSV, i.e. the current
       FOLDER_HOME, not a hard-coded user path) and drop rows without 'total_score'.
    2. Min-max normalise 'interface_score', 'total_score' and 'catalytic_score' to [0, 1],
       mapping the lowest raw value to 1. A column with a single distinct value maps to 1
       instead of 0/0 = NaN.
    3. 'combined_score' = geometric mean of the three normalised scores.
    4. Per sequence, store the number of replicates and the mean/std of their
       'combined_score', then keep only the best-scoring replicate of each sequence.
    5. Keep rows with normalised 'catalytic_score' >= 0.95, display the table and its top 10,
       and copy each top-10 PDB to "<combined_score>_<WT>_Rosetta_Design_<index>.pdb".'''

    if csv_path is None:
        csv_path = ALL_SCORES_CSV

    # Remove all .pdb files in the current directory (e.g. copies from a previous call)
    for pdb_file in glob.glob('*.pdb'):
        os.remove(pdb_file)

    all_scores_df = pd.read_csv(csv_path)
    all_scores_df = all_scores_df.dropna(subset=['total_score'])

    score_columns = ['interface_score', 'total_score', 'catalytic_score']
    # TODO: decide whether the electric-field score should be normalised here as well
    for column in score_columns:
        min_val = all_scores_df[column].min()
        span = all_scores_df[column].max() - min_val
        if span == 0:
            # All values identical: avoid 0/0 and give every row the same normalised value
            all_scores_df[column] = 1.0
        else:
            all_scores_df[column] = 1 - (all_scores_df[column] - min_val) / span

    # Geometric mean of the three normalised scores
    all_scores_df['combined_score'] = all_scores_df[score_columns].prod(axis=1) ** (1 / 3)

    # Replicate statistics per sequence. These use 'combined_score', as the column names say
    # (the old code used 'total_score'). groupby replaces the former O(n^2) row-by-row scan.
    by_sequence = all_scores_df.groupby('sequence', dropna=False)['combined_score']
    all_scores_df['replicate_sequences'] = by_sequence.transform(len)
    all_scores_df['replicate_sequences_combined_score'] = by_sequence.transform('mean')
    all_scores_df['replicate_sequences_combined_score_std'] = by_sequence.transform('std')

    # Keep only the highest combined_score per sequence
    all_scores_df = (all_scores_df
                     .sort_values(by='combined_score', ascending=False)
                     .drop_duplicates(subset=['sequence'], keep='first'))

    # Remove all structures with normalised 'catalytic_score' < 0.95
    all_scores_df = all_scores_df[all_scores_df['catalytic_score'] >= 0.95]

    top10 = all_scores_df.nlargest(10, 'combined_score')[['index', 'luca', 'cat_resn', 'cat_resi', 'interface_score',
                                                          'total_score', 'catalytic_score', 'combined_score', 'mutations',
                                                          'replicate_sequences', 'replicate_sequences_combined_score',
                                                          'replicate_sequences_combined_score_std']]

    display(all_scores_df)
    display(top10)

    # Copy files of the top 10. The design folder is named after the 'index' column,
    # which is also the value shown in the displayed table.
    for _, row in top10.iterrows():
        design_index = int(row['index'])
        src_file = f"{FOLDER_HOME}/{design_index}/{WT}_Rosetta_Design_{design_index}.pdb"
        dest_file = f"{row['combined_score']:.3f}_{WT}_Rosetta_Design_{design_index}.pdb"
        shutil.copy(src_file, dest_file)

def run_command(command):
    """Run an external command and return its standard output as text.

    command: argument list passed to subprocess.run, e.g. ['tleap', '-f', 'tleap.in'].
    Returns the captured stdout (str).
    Raises subprocess.CalledProcessError on a non-zero exit status, after logging the exit
    code together with the captured stdout AND stderr. Any other exception (e.g.
    FileNotFoundError for a missing executable) is logged with its traceback and re-raised."""
    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        return result.stdout
    except subprocess.CalledProcessError as e:
        logging.error(f"Command '{e.cmd}' failed with return code {e.returncode}")
        # e.output is only stdout; error messages usually go to stderr, which was never logged
        logging.error(f"stdout: {e.stdout}")
        logging.error(f"stderr: {e.stderr}")
        raise
    except Exception:
        logging.error(f"An error occurred while running command: {command}", exc_info=True)
        raise

def wait_for_file(file_path, timeout=5):
    """Poll every 0.1 s until <file_path> exists and is non-empty.

    Returns True as soon as that holds, or False once <timeout> seconds have passed."""
    began = time.time()

    def _ready():
        return os.path.exists(file_path) and os.path.getsize(file_path) > 0

    while time.time() - began < timeout:
        if _ready():
            return True
        time.sleep(0.1)
    return False



## Electric Fields

In [None]:
def generate_AMBER_files(structure_filename : str):
    '''Create <structure_filename>.parm7 and <structure_filename>.rst7 from <structure_filename>.pdb with tleap.

    <structure_filename> is the PDB path without the ".pdb" extension.
    Steps:
    1. Copy the PDB up to (not including) the first line containing "CONECT" into
       <structure_filename>_clean.pdb. Otherwise Amber tries to read the trailing records
       as extra residues.
    2. Remove hydrogens with pdb_delelem from pdb-tools (pip installable from
       https://www.bonvinlab.org/pdb-tools/), writing <structure_filename>_noHclean.pdb,
       and delete the _clean.pdb file.
    3. Run tleap (AmberTools) with ff19SB + GAFF and the ligand parameters Input/5TS.prepi
       and Input/5TS.frcmod (relative to the current working directory).

    The tleap script goes to <structure_filename>_tleap.in instead of a shared "tleap.in" in
    the working directory. Structures processed concurrently from the same directory
    therefore no longer overwrite each other's input.
    TODO: Add script to generate the 5TS parameter files if not present'''

    # Keep everything before the first CONECT line (same result as `sed -n '/CONECT/q;p'`)
    clean_lines = []
    with open(f"{structure_filename}.pdb", "r") as pdb_file:
        for line in pdb_file:
            if 'CONECT' in line:
                break
            clean_lines.append(line)
    with open(f"{structure_filename}_clean.pdb", "w") as f:
        f.writelines(clean_lines)

    # Remove hydrogens - requires pip install of https://www.bonvinlab.org/pdb-tools/
    noHclean_pdb = run_command(['pdb_delelem', '-H', f"{structure_filename}_clean.pdb"])
    with open(f"{structure_filename}_noHclean.pdb", "w") as f:
        f.write(noHclean_pdb)

    os.remove(f"{structure_filename}_clean.pdb")

    tleap_input = f"{structure_filename}_tleap.in"
    with open(tleap_input, "w") as f:
        f.write("\n".join([
            "source leaprc.protein.ff19SB",
            "source leaprc.gaff",
            "loadamberprep   Input/5TS.prepi",
            "loadamberparams Input/5TS.frcmod",
            f"mol = loadpdb {structure_filename}_noHclean.pdb",
            f"saveamberparm mol {structure_filename}.parm7 {structure_filename}.rst7",
            "quit",
        ]) + "\n")

    run_command(["tleap", "-s", "-f", tleap_input])

def calc_efields_score(pdb_path, bond_key=':5TS@C9_:5TS@H04'):
    '''Compute the electric field for one bond of the ligand in <pdb_path> with FieldTools.py.

    Builds <name>.parm7/.rst7 next to the PDB via generate_AMBER_files. It then runs
    FieldTools.py (expected in the current working directory) with the target file
    Input/field_target.dat and solvent "WAT", and loads the resulting <name>_fields.pkl.

    bond_key: entry of the FieldTools result to report. The default is the C9-H04 bond of 5TS
              (the previously hard-coded key). It presumably has to match a target listed in
              Input/field_target.dat; confirm against that file.
    Returns (bond_field, all_fields): bond_field is element [0] of the entry's 'Total'
    value; all_fields is the full entry for <bond_key>.
    TODO: derive bond_key from field_target.dat instead of requiring the caller to match it.'''

    structure_filename = os.path.splitext(pdb_path)[0]
    rel_name = os.path.relpath(structure_filename)

    generate_AMBER_files(structure_filename)

    run_command(["python", "FieldTools.py",
                 "-nc", f"{rel_name}.rst7",
                 "-parm", f"{rel_name}.parm7",
                 "-out", f"{rel_name}_fields.pkl",
                 "-target", "Input/field_target.dat",
                 "-solvent", "WAT"])

    # The pickle was just produced by FieldTools.py; never point pkl.load at untrusted files
    with open(f"{structure_filename}_fields.pkl", "rb") as f:
        FIELDS = pkl.load(f)

    all_fields = FIELDS[bond_key]
    bond_field = all_fields['Total']

    return bond_field[0], all_fields

# def update_efieldsdf(index, all_fields):
#     if index == 0:       

## Plotting functions

In [None]:
def plot_scores(combined_score_min=0, combined_score_max=1, combined_score_bin=0.01,
                interface_score_min=0, interface_score_max=1, interface_score_bin=0.01,
                total_score_min=0, total_score_max=1, total_score_bin=0.01,
                catalytic_score_min=0, catalytic_score_max=1, catalytic_score_bin=0.01,
                mut_min=0, mut_max=None):
    '''Plot a 2x4 overview of the design run stored in ALL_SCORES_CSV.

    Top row: histograms of the combined, interface, total and catalytic scores, normalised
    with normalize_scores(..., norm_all=True). Bottom row: random vs Boltzmann sampling,
    combined score vs index, combined score per generation, mutations per generation.
    Returns without plotting if ALL_SCORES_CSV does not exist or fewer than 3 rows have a
    total_score.

    The *_min / *_max / *_bin arguments set axis ranges and histogram bin widths.
    mut_max: upper y-limit of the mutation plot. None (default) means
             len(DESIGN.split(",")) + 1, computed when the function is called. A default
             expression in the signature would be frozen when this cell runs, and DESIGN
             would have to exist at that point.'''

    # Break because file does not exist
    if not os.path.isfile(ALL_SCORES_CSV): return

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    all_scores_df['sequence'] = all_scores_df['sequence'].astype(str)
    all_scores_df['design_method'] = all_scores_df['design_method'].astype(str)
    all_scores_df['score_taken_from'] = all_scores_df['score_taken_from'].astype(str)

    # Break because not enough data
    if len(all_scores_df.dropna(subset=['total_score'])) < 3: return

    if mut_max is None:
        mut_max = len(DESIGN.split(",")) + 1

    # Parent -> child genealogy graph; only consumed by the (currently disabled) plot_tree below
    G = nx.DiGraph()
    for idx, row in all_scores_df.iterrows():
        G.add_node(idx)
        if row['parent_index'] != "Parent":
            G.add_edge(int(float(row['parent_index'])), idx)

    # Plot data
    fig, axs = plt.subplots(2, 4, figsize=(15, 6))

    all_scores_df = all_scores_df.dropna(subset=['total_score'])
    catalytic_scores, total_scores, interface_scores, combined_scores = normalize_scores(all_scores_df,
                                                                                         print_norm=True,
                                                                                         norm_all=True)

    plot_combined_score(axs[0,0], combined_scores,
                        combined_score_min, combined_score_max, combined_score_bin)
    plot_interface_score(axs[0,1], interface_scores,
                         interface_score_min, interface_score_max, interface_score_bin)
    plot_total_score(axs[0,2], total_scores,
                     total_score_min, total_score_max, total_score_bin)
    plot_catalytic_score(axs[0,3], catalytic_scores,
                         catalytic_score_min, catalytic_score_max, catalytic_score_bin)

    plot_boltzmann_histogram(axs[1,0], combined_scores, all_scores_df,
                             combined_score_min, combined_score_max, combined_score_bin)
    plot_combined_score_v_index(axs[1,1], combined_scores, all_scores_df)
    plot_combined_score_v_generation(axs[1,2], combined_scores, all_scores_df)
    plot_mutations_v_generation(axs[1,3], all_scores_df, mut_min, mut_max)
    plt.tight_layout()
    plt.show()

    #fig, ax = plt.subplots(1, 1, figsize=(15, 2.5))
    #plot_tree(ax, combined_scores, all_scores_df, G)
    #plt.show()

def plot_combined_score(ax, combined_scores, score_min, score_max, score_bin):
    '''Draw a count histogram of <combined_scores> on <ax>.

    Bins are <score_bin> wide from <score_min> to <score_max> (also the x-range). Vertical
    reference lines mark HIGHSCORE (blue) and NEG_BEST (red).'''
    edges = np.arange(score_min, score_max + score_bin, score_bin)
    ax.hist(combined_scores, bins=edges)
    for level, colour in ((HIGHSCORE, 'b'), (NEG_BEST, 'r')):
        ax.axvline(level, color=colour)
    ax.set_xlim(score_min, score_max)
    ax.set_title('Histogram of Score')
    ax.set_xlabel('Combined Score')
    ax.set_ylabel('Frequency')
    
def plot_interface_score(ax, interface_scores, interface_score_min, interface_score_max, interface_score_bin):
    '''Draw a density-normalised histogram of <interface_scores> on <ax>.

    Bins are <interface_score_bin> wide from <interface_score_min> to <interface_score_max>
    (also the x-range). The y axis is labelled 'Density' because density=True normalises
    the bar area to 1, so the bars are not counts (the old label said 'Frequency').'''
    lo, hi, step = interface_score_min, interface_score_max, interface_score_bin
    ax.hist(interface_scores, density=True, bins=np.arange(lo, hi + step, step))
    ax.set_xlim(lo, hi)
    ax.set_title('Histogram of Interface Score')
    ax.set_xlabel('Interface Score')
    ax.set_ylabel('Density')
    
def plot_total_score(ax, total_scores, total_score_min, total_score_max, total_score_bin):
    '''Draw a density-normalised histogram of <total_scores> on <ax>.

    Bins are <total_score_bin> wide from <total_score_min> to <total_score_max> (also the
    x-range). The y axis is labelled 'Density' because density=True normalises the bar area
    to 1, so the bars are not counts (the old label said 'Frequency').'''
    edges = np.arange(total_score_min, total_score_max + total_score_bin, total_score_bin)
    ax.hist(total_scores, density=True, bins=edges)
    ax.set_xlim(total_score_min, total_score_max)
    ax.set_title('Histogram of Total Score')
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Density')
    
def plot_catalytic_score(ax, catalytic_scores, total_score_min, total_score_max, total_score_bin):
    '''Draw a density-normalised histogram of <catalytic_scores> on <ax>.

    NOTE: despite their names, total_score_min / total_score_max / total_score_bin are the
    catalytic-score range and bin width (names kept for keyword-argument compatibility).
    The y axis is labelled 'Density' because density=True normalises the bar area to 1, so
    the bars are not counts (the old label said 'Frequency').'''
    edges = np.arange(total_score_min, total_score_max + total_score_bin, total_score_bin)
    ax.hist(catalytic_scores, density=True, bins=edges)
    ax.set_xlim(total_score_min, total_score_max)
    ax.set_title('Histogram of Catalytic Score')
    ax.set_xlabel('Catalytic Score')
    ax.set_ylabel('Density')
    

def plot_boltzmann_histogram(ax, combined_scores, all_scores_df, score_min, score_max, score_bin):
    '''Compare uniform and Boltzmann-weighted sampling of the combined potentials on <ax>.

    Potentials come from normalize_scores(all_scores_df, norm_all=False, extension="potential").
    kbT is KBT_BOLTZMANN when that is a number. For a 2-element KBT_BOLTZMANN = (kbt0, decay)
    it is kbt0 * 10**(-decay * len(all_scores_df)).
    1000 potentials are drawn uniformly (left axis) and 1000 with probability
    proportional to exp(potential / kbT) (orange, twin right axis).

    <combined_scores> is accepted for signature compatibility but not used.
    Raises ValueError if KBT_BOLTZMANN is a sequence whose length is not 2 (previously an
    UnboundLocalError).'''

    _, _, _, combined_potentials = normalize_scores(all_scores_df, print_norm=False, norm_all=False, extension="potential")

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        kbt_boltzmann = KBT_BOLTZMANN[0] * 10 ** (- KBT_BOLTZMANN[1] * len(all_scores_df))
    else:
        raise ValueError(f"KBT_BOLTZMANN must be a number or a 2-element sequence, got {KBT_BOLTZMANN!r}")

    # Softmax with the maximum subtracted. This is mathematically identical to
    # exp(E/kbT) / sum(exp(E/kbT)), but avoids exp overflowing to inf (and inf/inf = NaN
    # probabilities) once kbT has decayed to small values.
    logits = np.asarray(combined_potentials, dtype=float) / kbt_boltzmann
    weights = np.exp(logits - np.max(logits))
    probabilities = weights / weights.sum()

    random_scores     = np.random.choice(combined_potentials, size=1000, replace=True)
    boltzmann_scores  = np.random.choice(combined_potentials, size=1000, replace=True, p=probabilities)

    bins = np.arange(score_min, score_max + score_bin, score_bin)

    # Uniform sampling on the primary axis
    ax.hist(random_scores, density=True, alpha=0.7, label='Random Sampling', bins=bins)
    # Axes coordinates pin the note to the top-left corner whatever the data ranges are
    ax.text(0.05, 0.95, "normalized only to \n this dataset", transform=ax.transAxes, va='top')
    ax.set_xlabel('Potential')
    ax.set_ylabel('Density (Normal)')
    ax.set_title(f'kbT = {kbt_boltzmann:.1e}')
    # Boltzmann sampling on a twin y-axis
    ax_dup = ax.twinx()
    ax_dup.hist(boltzmann_scores, density=True, alpha=0.7, color='orange', label='Boltzmann Sampling', bins=bins)
    ax.set_xlim(score_min, score_max)
    ax_dup.set_ylabel('Density (Boltzmann)')
    ax_dup.tick_params(axis='y', labelcolor='orange')

def plot_interface_score_v_total_score(ax, all_scores_df,
                                       total_score_min, total_score_max, interface_score_min, interface_score_max):
    '''Scatter 'interface_score' against 'total_score' on <ax>, coloured by the 'index' column.

    Adds a black least-squares trend line over the observed total-score range and reports
    the Pearson correlation in the title. Axis limits come from the *_min / *_max arguments.'''
    x_values = all_scores_df['total_score']
    y_values = all_scores_df['interface_score']

    ax.scatter(x_values, y_values, c=all_scores_df['index'], cmap='coolwarm_r', s=5)
    correlation, _ = pearsonr(x_values, y_values)

    fit = np.poly1d(np.polyfit(x_values, y_values, 1))
    trend_x = np.linspace(x_values.min(), x_values.max(), 100)
    ax.plot(trend_x, fit(trend_x), "k")

    ax.set_title(f'Pearson r: {correlation:.2f}')
    ax.set_xlim(total_score_min, total_score_max)
    ax.set_ylim(interface_score_min, interface_score_max)
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Interface Score')

def plot_combined_score_v_index(ax, combined_scores, all_scores_df):
    '''Scatter combined score against design index on <ax> and overlay a 20-row rolling mean.

    Reference lines: HIGHSCORE (blue), NEG_BEST (red), and a vertical black line at
    N_PARENT_JOBS * <number of .pdb files in FOLDER_HOME/FOLDER_PARENT>.
    <combined_scores> must be row-aligned with <all_scores_df>. The rolling mean follows row
    order, which is presumably ascending 'index'.'''
    combined_scores = pd.Series(combined_scores)
    moving_avg = combined_scores.rolling(window=20).mean()
    design_index = all_scores_df['index'].to_numpy()

    ax.scatter(design_index, combined_scores, c='lightgrey', s=5)
    ax.axhline(HIGHSCORE, color='b')
    ax.axhline(NEG_BEST, color='r')
    PARENTS = [i for i in os.listdir(f'{FOLDER_HOME}/{FOLDER_PARENT}') if i[-4:] == ".pdb"]
    ax.axvline(N_PARENT_JOBS*len(PARENTS), color='k')
    # Plot the rolling mean at the same x positions as the points. The old
    # range(len(moving_avg)) drifted left of the data once rows without scores were dropped.
    ax.plot(design_index, moving_avg.to_numpy(), c="k")
    ax.set_ylim(0,1)
    ax.set_xlim(0,MAX_DESIGNS)
    ax.set_title('Score vs Index')
    ax.set_xlabel('Index')
    ax.set_ylabel('Combined Score')
    
def plot_combined_score_v_generation(ax, combined_scores, all_scores_df):
    '''Draw box plots of combined score per generation (0 .. max generation) on <ax>.

    <combined_scores> is attached to a copy of <all_scores_df>, aligned like a column
    assignment, and rows without a score are dropped. The caller's DataFrame is no longer
    modified (the old version added a 'tmp' column to it). Horizontal lines mark HIGHSCORE
    (blue) and NEG_BEST (red).'''
    scored = all_scores_df.assign(_combined_score=combined_scores).dropna(subset=['_combined_score'])

    max_gen = int(scored['generation'].max())
    boxplot_data = [scored[scored['generation'] == gen]['_combined_score'] for gen in range(0, max_gen + 1, 1)]
    ax.boxplot(boxplot_data, positions=range(len(boxplot_data)))
    ax.axhline(HIGHSCORE, color='b')
    ax.axhline(NEG_BEST, color='r')
    ax.set_xticks(range(len(boxplot_data)))
    ax.set_xticklabels(range(0, len(boxplot_data), 1))
    ax.set_ylim(0,1)
    ax.set_title('Combined Score vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Combined Score')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)
    
def plot_mutations_v_generation(ax, all_scores_df,  mut_min, mut_max):
    '''Draw box plots of 'mutations' per generation (0 .. max generation) on <ax>.

    Rows without a 'mutations' value are ignored. A red horizontal line marks
    len(DESIGN.split(",")), presumably the number of designable positions. The y-range is
    [mut_min, mut_max].'''
    with_mutations = all_scores_df.dropna(subset=['mutations'])

    last_gen = int(with_mutations['generation'].max())
    per_generation = [with_mutations.loc[with_mutations['generation'] == gen, 'mutations']
                      for gen in range(last_gen + 1)]
    slots = range(len(per_generation))

    ax.boxplot(per_generation, positions=slots)
    ax.axhline(len(DESIGN.split(",")), color='r')
    ax.set_xticks(slots)
    ax.set_xticklabels(slots)
    ax.set_ylim(mut_min, mut_max)
    ax.set_title('Mutations vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Number of Mutations')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)
    
def plot_tree(ax, combined_scores, all_scores_df, G):
    '''Draw the design genealogy <G> on <ax>, one line segment per parent -> child edge.

    NOTE(review): the <combined_scores> and <all_scores_df> arguments are ignored; the score
    table is re-read from ALL_SCORES_CSV. Edges are coloured and sized by the normalised
    combined potential of the child node (black when that value is 0 or missing).
    Layout: the plot's y axis is the tree depth below node 0 (labelled "Generations"). The
    x axis gives every subtree a width equal to its number of descendants.
    Assumes node 0 is the single root and that node ids are row positions in ALL_SCORES_CSV
    (they index into the potentials array) - TODO confirm. The layout helpers are recursive,
    so a very deep genealogy could hit Python's recursion limit.'''
    
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    _, _, _, combined_potentials = normalize_scores(all_scores_df, print_norm=False, norm_all=False, extension="potential")

    max_gen = int(all_scores_df['generation'].max())
    
    def set_node_positions(G, node, pos, x, y, counts):
        # Place <node> at (depth x, spread y); its children are stacked along y, each taking a
        # slice as wide as its descendant count, one step deeper in x
        pos[node] = (x, y)
        neighbors = list(G.successors(node))
        next_y = y - counts[node] / 2
        for neighbor in neighbors:
            set_node_positions(G, neighbor, pos, x + 1, next_y + counts[neighbor] / 2, counts)
            next_y += counts[neighbor]
            
    def count_descendants(G, node, counts):
        # counts[node] = size of the subtree rooted at <node> (including the node itself)
        neighbors = list(G.successors(node))
        count = 1
        for neighbor in neighbors:
            count += count_descendants(G, neighbor, counts)
        counts[node] = count
        return count

    counts = {}
    count_descendants(G, 0, counts)

    pos = {}
    set_node_positions(G, 0, pos, 0, 0, counts)
    y_values = [y for x, y in pos.values()]
    y_span = max(y_values) - min(y_values)  
    
    # NOTE: `colors` is the same object as combined_potentials, so this also overwrites its
    # element 0 (the root, excluded from the colour scale)
    colors = combined_potentials
    colors[0] = np.nan
    # Min-max normalise over the non-root nodes, map NaN to 0, then square to emphasise high values
    normed_colors = [(x - np.nanmin(colors[1:])) / (np.nanmax(colors[1:]) - np.nanmin(colors[1:])) for x in colors]
    normed_colors = np.nan_to_num(normed_colors, nan=0)
    normed_colors = normed_colors**2

    # Draw the graph with the positions set
    #nx.draw(G, pos, ax=ax, sld', arrows=False, cmap=plt.cm.coolwarm_r)
    for start, end in G.edges():
        color = plt.cm.coolwarm_r(normed_colors[end])
        if float(normed_colors[end]) == 0.0: color = [0., 0., 0., 1.]
        linewidth = 0.1+2*normed_colors[end]
        
        x0, y0 = pos[start]
        x1, y1 = pos[end]
        # Axes are swapped: spread goes on the plot's x axis, depth on its y axis
        ax.plot([y0, y1], [x0, x1], color=color, linewidth=linewidth)

    # Adjust axis labels and ticks for the swapped axes
    ax.axis('on')
    ax.set_title("Colored by Potential")
    ax.set_xlabel("Variants")
    ax.set_ylabel("Generations")
    ax.set_yticks(range(max_gen+1))
    ax.set_ylim(0,max_gen+0.25)
    ax.set_yticklabels(range(max_gen+1))

    ax.tick_params(axis='y', which='both', bottom=True, top=False, left=True, right=False,
                   labelbottom=True, labelleft=True)

## Functions needed to run the AIzyme algorithm

In [None]:
# Confirmation message printed once execution reaches this cell
print("AIzyme Functions loaded!")
time.sleep(0.5)  # brief pause after the message (purpose not visible in this notebook section)