In [None]:
# Import libraries
import pandas as pd
import numpy as np
import os
import sys
import shutil
import re
import subprocess
import pickle as pkl
import matplotlib.pyplot as plt
from matplotlib.axes._axes import _log as matplotlib_axes_logger
from Bio import SeqIO
import warnings
import time
import logging
from IPython.display import display, clear_output
from Bio.PDB import PDBParser
import statistics
import math
from datetime import datetime
import networkx as nx
from scipy.stats import pearsonr
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.gridspec import GridSpec
import json
import random 

# Setting initial options
# Silences every Python warning for the whole kernel (also hides pandas/numpy deprecation warnings).
warnings.filterwarnings('ignore')
# Disables pandas' SettingWithCopyWarning for chained assignments.
pd.options.mode.chained_assignment = None 
# Sets matplotlib's internal axes logger to INFO level.
matplotlib_axes_logger.setLevel('INFO')

# Run-wide paths.
# NOTE(review): DESIGN_FOLDER is not defined in this cell; an earlier cell must set it,
# otherwise Restart & Run All fails here with a NameError.
DESIGN_COUNT = {}
FOLDER_HOME = f'{os.getcwd()}/{DESIGN_FOLDER}'
os.makedirs(FOLDER_HOME, exist_ok=True)
FOLDER_INPUT = f'{os.getcwd()}/Input'
# Only prints a message; execution continues even when the Input folder is missing.
if not os.path.isdir(FOLDER_INPUT): print("ERROR! Input folder missing!")
# The log file sits next to (not inside) the design folder: <cwd>/<DESIGN_FOLDER>.log
LOG_FILE = f'{FOLDER_HOME}.log'
ALL_SCORES_CSV = f'{FOLDER_HOME}/all_scores.csv'
BLOCKED_DAT    = f'{FOLDER_HOME}/blocked.dat'
VARIABLES_JSON  = f'{FOLDER_HOME}/variables.json'

# Configure logging file
log_format = '%(asctime)s - %(levelname)s - %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'

# Remove all handlers associated with the root logger.
# Needed when the cell is re-run: logging.basicConfig does nothing if the root
# logger already has handlers, and re-running would otherwise duplicate output.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Basic configuration for logging to a file (DEBUG and above go to LOG_FILE)
logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG, format=log_format, datefmt=date_format)

# Create a StreamHandler for console output (INFO and above only)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))

# Add the console handler to the root logger
logging.getLogger().addHandler(console_handler)

# main functions - running

In [None]:
def controller(RESET=False, EXPLORE=EXPLORE, UNBLOCK_ALL=False, PRINT_VAR=True, PLOT_DATA=True, BLUEPEBBLE=False, GRID=True):
    """Main AI.zymes loop: keep selecting and submitting designs until a stop marker exists.

    The loop ends once a path named ``str(MAX_DESIGNS)`` exists inside FOLDER_HOME.
    While ``check_running_jobs()`` reports MAX_JOBS or more jobs, it waits 60 s.
    Otherwise it refreshes the scores, picks an index by Boltzmann selection and
    hands that index to ``start_calculation``.

    Args:
        RESET: forwarded to ``setup_aizymes``.
        UNBLOCK_ALL, PRINT_VAR, PLOT_DATA: forwarded to ``startup_controller``.
        EXPLORE, BLUEPEBBLE, GRID: accepted but not read here.
            NOTE(review): ``check_running_jobs`` and the ``run_*`` helpers read the
            module-level globals with these names, so passing them has no effect.
    """
    # One-time startup, then load the score table and the blocked indices.
    setup_aizymes(RESET)
    all_scores_df, blocked_df = startup_controller(UNBLOCK_ALL, PRINT_VAR=PRINT_VAR, PLOT_DATA=PLOT_DATA)

    while True:
        stop_marker = os.path.join(FOLDER_HOME, str(MAX_DESIGNS))
        if os.path.exists(stop_marker):
            break

        # Scheduler is saturated: wait a minute before polling again.
        if check_running_jobs() >= MAX_JOBS:
            time.sleep(60)
            continue

        all_scores_df, blocked_df = update_scores(all_scores_df, blocked_df)
        selected_index = boltzmann_selection(all_scores_df, blocked_df)
        # Decide the fate of the selected index (new design, relax, or skip).
        all_scores_df, blocked_df = start_calculation(all_scores_df, blocked_df, selected_index)

    print(f"Stopped because {stop_marker} exists.")

def check_running_jobs():
    """Return how many of this run's jobs the scheduler currently lists.

    With BLUEPEBBLE set, ``sacct`` is queried and only lines containing "RUNNING"
    are kept. Otherwise, with GRID set, ``qstat -u USERNAME`` is queried. In both
    cases only lines containing SUBMIT_PREFIX are counted.

    Raises:
        RuntimeError: if neither BLUEPEBBLE nor GRID is enabled, because there is
            no scheduler to query.
    """
    if BLUEPEBBLE:
        # BLUEPEBBLE takes precedence; qstat is not called at all in that case.
        output = subprocess.check_output(["sacct"]).decode("utf-8")
        jobs = [job for job in output.split("\n") if "RUNNING" in job]
    elif GRID:
        output = subprocess.check_output(["qstat", "-u", USERNAME]).decode("utf-8")
        jobs = output.split("\n")
    else:
        raise RuntimeError("check_running_jobs: neither GRID nor BLUEPEBBLE is enabled; "
                           "cannot determine the number of running jobs.")
    return sum(1 for job in jobs if SUBMIT_PREFIX in job)

def _select_score_file(row, index):
    """Return the Rosetta score file to (re)read for ``row``, or None if its score is already final."""
    relax_score_file = f"{FOLDER_HOME}/{index}/score_rosetta_relax.sc"
    # A score taken from the relax run is never re-read.
    if row['score_taken_from'] == 'Relax':
        return None
    # RosettaDesign runs use the design score until the relax score file exists.
    if row['design_method'] == "RosettaDesign" and not os.path.exists(relax_score_file):
        if row['score_taken_from'] == 'Design':
            return None
        return f"{FOLDER_HOME}/{index}/score_rosetta_design.sc"
    return relax_score_file


def _read_score_line(score_file_path):
    """Return (headers, values) of a Rosetta score file, or None if it is not completely written."""
    with open(score_file_path, "r") as f:
        lines = f.readlines()
    # Line 1 holds the column names and line 2 the values. If the timing is bad the
    # file can still be incomplete (too few lines, or a value line missing fields).
    if len(lines) <= 2:
        return None
    headers, values = lines[1].split(), lines[2].split()
    if len(values) < len(headers):
        return None
    return headers, values


def update_scores(all_scores_df, blocked_df):
    """Refresh total, interface and catalytic scores from Rosetta score files and unblock finished indices.

    Scores are read from ``score_rosetta_relax.sc``. RosettaDesign rows whose relax
    score file does not exist yet fall back to ``score_rosetta_design.sc``. Rows whose
    score already came from the relax file (or from the design file while the relax
    file is still missing) are not re-read.

    Independently of the score update, every blocked index whose
    ``Rosetta_Relax_<index>.pdb`` exists is removed from the blocked list. BLOCKED_DAT
    is then rewritten.

    Args:
        all_scores_df: score table. Writes use the DataFrame row labels.
        blocked_df: blocked indices as a list or numpy array. ``np.loadtxt`` may
            give a float array, or a 0-d array for a single entry.

    Returns:
        (all_scores_df, blocked_df), where blocked_df is a list of ints.
        ALL_SCORES_CSV is rewritten before returning.
    """
    # Normalise to a flat list of ints; a 0-d array cannot be iterated.
    blocked_df = [int(i) for i in np.atleast_1d(blocked_df)]
    unblocked_any = False

    for label, row in all_scores_df.iterrows():
        index = int(row['index'])

        score_file_path = _select_score_file(row, index)
        if score_file_path is not None and os.path.exists(score_file_path):
            parsed = _read_score_line(score_file_path)
            if parsed is not None:
                headers, values = parsed
                catalytic_score = 0.0
                for header, value in zip(headers, values):
                    if header == 'total_score':
                        all_scores_df.at[label, 'total_score'] = float(value)
                    elif header == 'interface_delta_X':
                        all_scores_df.at[label, 'interface_score'] = float(value)
                    elif header in ('angle_constraint', 'atom_pair_constraint', 'dihedral_constraint'):
                        catalytic_score += float(value)
                all_scores_df.at[label, 'catalytic_score'] = catalytic_score
                all_scores_df.at[label, 'score_taken_from'] = (
                    'Relax' if score_file_path.endswith("score_rosetta_relax.sc") else 'Design')
                logging.info(f"Updated total_score, interface_delta_X, and catalytic_score of index {index}.")

        # Unblock the index once its relaxed structure exists. This check also runs
        # for rows whose score is already final.
        if index in blocked_df and os.path.isfile(f"{FOLDER_HOME}/{index}/Rosetta_Relax_{index}.pdb"):
            logging.debug(f"Unblocked index {index}.")
            blocked_df = [i for i in blocked_df if i != index]
            unblocked_any = True

    if unblocked_any:
        np.savetxt(BLOCKED_DAT, np.array(blocked_df, dtype=int), fmt='%d')

    all_scores_df.to_csv(ALL_SCORES_CSV, index=False)
    return all_scores_df, blocked_df

def normalize_scores(all_scores_df):
    """Min-max normalise catalytic, total and interface scores and combine them.

    Each score column is negated, so lower raw scores map to larger values. It is
    then rescaled with the min/max of the rows whose ``parent_index`` is not "None".
    Values below 0 are clipped to 0. Values above that reference range are kept (> 1).
    With a single row, the raw values are returned unchanged. When the reference
    range is empty, all-NaN or zero-width, every non-NaN entry becomes 0.0 (equal
    weight) instead of NaN.

    Returns:
        (catalytic_scores, total_scores, interface_scores, combined_scores) as numpy
        arrays aligned with the rows of ``all_scores_df``. combined_scores is the mean
        of the three.
    """
    def neg_norm_array(array, array_np):
        # A single value cannot be normalised; return it unchanged.
        if len(array) <= 1:
            return array
        array = -np.asarray(array, dtype=float)
        array_np = -np.asarray(array_np, dtype=float)
        # No usable reference range: give every entry the same weight, keep NaN as NaN.
        if np.all(np.isnan(array_np)):
            return np.where(np.isnan(array), np.nan, 0.0)
        lo, hi = np.nanmin(array_np), np.nanmax(array_np)
        if hi == lo:
            return np.where(np.isnan(array), np.nan, 0.0)
        array = (array - lo) / (hi - lo)
        array[array < 0] = 0.0
        return array

    # Reference rows: everything except the starting structure.
    all_scores_df_np = all_scores_df[all_scores_df['parent_index'] != "None"]

    def column_pair(name):
        return (np.array(all_scores_df[name].tolist()),
                np.array(all_scores_df_np[name].tolist()))

    catalytic_scores = neg_norm_array(*column_pair("catalytic_score"))
    total_scores = neg_norm_array(*column_pair("total_score"))
    interface_scores = neg_norm_array(*column_pair("interface_score"))

    combined_scores = (catalytic_scores + total_scores + interface_scores) / 3

    return catalytic_scores, total_scores, interface_scores, combined_scores
        
def boltzmann_selection(all_scores_df, blocked_df):
    """Pick the next index to work on, with probability proportional to exp(score / kbt).

    * Returns 0 (the starting structure) while the table has fewer than
      N_PARENT_JOBS rows.
    * Otherwise the candidates are the rows that have a total_score, are not the
      starting structure (parent_index == "None") and are not blocked. One is drawn
      using the combined normalised score from ``normalize_scores``.
    * Returns 1 if there is no candidate.

    KBT_BOLTZMANN is either a number or a pair [a, b]. For a pair,
    kbt = a * 10**(-b * largest candidate index).

    Raises:
        ValueError: if KBT_BOLTZMANN is neither a number nor a pair.
    """
    if len(all_scores_df) < N_PARENT_JOBS:
        return 0

    candidates = all_scores_df.dropna(subset=['total_score'])
    candidates = candidates[candidates['parent_index'] != "None"]
    candidates = candidates[~candidates['index'].isin(blocked_df)]

    original_index = np.array(candidates["index"].tolist())
    # Explicit length test: comparing an ndarray with [] is an error (or a wrong
    # empty result for a single element) in current NumPy.
    if len(original_index) == 0:
        return 1

    _, _, _, combined_scores = normalize_scores(candidates)

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        kbt_boltzmann = KBT_BOLTZMANN[0] * 10**(-KBT_BOLTZMANN[1]*candidates['index'].max())
    else:
        raise ValueError(f"KBT_BOLTZMANN must be a number or a pair [a, b], got {KBT_BOLTZMANN!r}.")

    # Subtracting the maximum exponent leaves the probabilities unchanged but keeps
    # exp() from overflowing to inf (inf/inf = nan) when kbt is small.
    exponents = combined_scores / kbt_boltzmann
    boltzmann_factors = np.exp(exponents - np.nanmax(exponents))
    probabilities = boltzmann_factors / np.sum(boltzmann_factors)

    return int(np.random.choice(original_index, p=probabilities))
         
# Decides what to do with selected index
def start_calculation(all_scores_df, blocked_df, selected_index):
    """Decide and launch the next job for ``selected_index``.

    * If ``Rosetta_Relax_<selected_index>.pdb`` exists, a child index is created.
      The child is designed with ProteinMPNN (probability ProteinMPNN_PROB) or
      with RosettaDesign.
    * Otherwise, if the index is blocked, nothing is submitted and an error is logged.
    * Otherwise ESMfold + Rosetta Relax is submitted for the index. The index is
      then appended to the blocked list, and the list is written to BLOCKED_DAT.

    Returns:
        (all_scores_df, blocked_df), where blocked_df is a list of ints.
    """
    logging.debug(f"Starting new calculation for index {selected_index}.")

    # blocked_df may be a numpy array from np.loadtxt. On an ndarray, `+= [x]` adds
    # x element-wise instead of appending, so always work on a plain list of ints.
    blocked_df = [int(i) for i in np.atleast_1d(blocked_df)]
    blocked = selected_index in blocked_df

    relaxed = f"Rosetta_Relax_{selected_index}.pdb" in os.listdir(os.path.join(FOLDER_HOME, str(selected_index)))

    if relaxed:
        # ESMfold_Rosetta_Relax is done: create a child index and design it.
        new_index, all_scores_df = create_new_index(parent_index=selected_index, all_scores_df=all_scores_df)

        #####
        # Here, we can add an AI to decide on the next steps
        #####

        if random.random() < ProteinMPNN_PROB:
            all_scores_df.at[new_index, 'design_method'] = "ProteinMPNN"
            run_ProteinMPNN(parent_index=selected_index, new_index=new_index)
        else:
            all_scores_df.at[new_index, 'design_method'] = "RosettaDesign"
            run_RosettaDesign(parent_index=selected_index, new_index=new_index)

        all_scores_df.to_csv(ALL_SCORES_CSV, index=False)

    elif blocked:
        # ESMfold_Rosetta_Relax is still running for this index; do nothing.
        # Selection excludes blocked indices, so reaching this branch points to a bug.
        logging.error(f"Index {selected_index} is being worked on. Skipping index.")
        logging.error(f"Note: This should not happen! Check blocking and Boltzman selection.")

    else:
        # Not relaxed and not blocked: submit ESMfold_Rosetta_Relax and block the index.
        logging.info(f"Index {selected_index} has no relaxed structure, starting ESMfold_Rosetta_Relax.")
        run_ESMfold_RosettaRelax(index=selected_index, RosettaDesign=True)
        blocked_df.append(int(selected_index))
        np.savetxt(BLOCKED_DAT, np.array(blocked_df, dtype=int), fmt='%d')

    return all_scores_df, blocked_df

def create_new_index(parent_index, all_scores_df):
    """Append a child row for ``parent_index``, save the table and create the child's folder.

    The new index is one more than the largest value in the ``index`` column. The
    row stores its kbt value: KBT_BOLTZMANN itself, or a * 10**(-b * new_index)
    when KBT_BOLTZMANN is a pair [a, b].

    Returns:
        (new_index, all_scores_df) with the new row appended.

    Raises:
        ValueError: if KBT_BOLTZMANN is neither a number nor a pair.
    """
    new_index = int(all_scores_df['index'].max() + 1)

    if isinstance(KBT_BOLTZMANN, (float, int)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        kbt_boltzmann = KBT_BOLTZMANN[0] * 10 ** (- KBT_BOLTZMANN[1] * new_index)
    else:
        raise ValueError(f"KBT_BOLTZMANN must be a number or a pair [a, b], got {KBT_BOLTZMANN!r}.")

    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported replacement.
    new_row = pd.DataFrame([{'index': new_index,
                             'parent_index': parent_index,
                             'kbt_boltzmann': kbt_boltzmann}])
    all_scores_df = pd.concat([all_scores_df, new_row], ignore_index=True)
    all_scores_df.to_csv(ALL_SCORES_CSV, index=False)

    # Create the folders for the new index
    os.makedirs(f"{FOLDER_HOME}/{new_index}/scripts", exist_ok=True)

    logging.debug(f"Child index {new_index} created for {parent_index}.")
    return new_index, all_scores_df

# main functions - design

In [None]:
def run_ESMfold_RosettaRelax(index, RosettaDesign=False, ProteinMPNN=False, StartupRelax=False,
                             ProteinMPNN_parent_index=0, cmd="", bash=False):
    """Write and submit a job script that folds a sequence with ESMfold and relaxes it with Rosetta.

    The commands appended to ``cmd`` do the following:
      1. Run ESMfold.py on ``ESMfold/<PARENT>_<index>.seq``.
      2. With cpptraj: strip ligand and hydrogens from the input structure (apo
         reference), extract the ligand, strip hydrogens from the ESMfold model and
         align it on CA atoms to the apo reference.
      3. Concatenate the input structure's REMARK lines, the aligned model and the
         ligand into ``ESMfold_<index>.pdb``.
      4. Run rosetta_scripts with Rosetta_Relax.xml and rename the output to
         ``Rosetta_Relax_<index>.pdb``.
    NOTE(review): the final mv/sed use relative paths, so the job presumably runs
    inside ``FOLDER_HOME/<index>``. That is set by submit_job; confirm.

    Args:
        index: design index whose folder receives the results.
        RosettaDesign: the input structure is ``Rosetta_Design_<index>.pdb``, and the
            sequence file is written here from it.
        ProteinMPNN: the input structure is the parent's ``Rosetta_Relax`` PDB. The
            sequence file must be produced by the commands already in ``cmd``.
            Takes precedence over RosettaDesign for the input structure.
        StartupRelax: treated like ``RosettaDesign=True``.
        ProteinMPNN_parent_index: parent index used when ``ProteinMPNN`` is True.
        cmd: shell commands to run before the ESMfold step.
        bash: forwarded to ``submit_job``.

    Logs an error and returns without submitting if no input mode is selected or
    the input PDB is missing.
    """
    index_dir = f"{FOLDER_HOME}/{index}"
    esm_dir = f"{index_dir}/ESMfold"
    os.makedirs(esm_dir, exist_ok=True)
    os.makedirs(f"{index_dir}/scripts", exist_ok=True)

    # The -ex1 -ex2 flags are omitted in EXPLORE mode.
    ex = "" if EXPLORE else "-ex1 -ex2"

    # Select the input structure.
    if StartupRelax:
        RosettaDesign = True
    PDBFile = None
    if RosettaDesign:
        PDBFile = f"{index_dir}/Rosetta_Design_{index}.pdb"
    if ProteinMPNN:
        PDBFile = f"{FOLDER_HOME}/{ProteinMPNN_parent_index}/Rosetta_Relax_{ProteinMPNN_parent_index}.pdb"
    if PDBFile is None:
        logging.error(f"No input structure selected for index {index}: set RosettaDesign, StartupRelax or ProteinMPNN.")
        return
    if not os.path.isfile(PDBFile):
        logging.error(f"{PDBFile} not present!")
        return

    # Make sequence file (the ProteinMPNN commands in `cmd` write it themselves).
    sequence_file = f'{esm_dir}/{PARENT}_{index}.seq'
    if RosettaDesign:
        seq = extract_sequence_from_pdb(PDBFile)
        with open(sequence_file, "w") as f:
            f.write(seq)

    output_file = f'{esm_dir}/ESMfold_output_{index}.pdb'
    # Per-index REMARK file: a shared FOLDER_HOME/remark.txt would be overwritten by
    # concurrently running jobs between their grep and cat steps.
    remark_file = f'{esm_dir}/remark_{index}.txt'

    cpptraj_inputs = {
        # Input structure without ligand and hydrogens (alignment reference).
        "CPPTraj_Apo": f'''parm    {PDBFile}
trajin  {PDBFile}
strip   :{LIGAND}
strip   @H=
trajout {esm_dir}/CPPTraj_Apo_{index}.pdb
''',
        # Only the ligand from the input structure.
        "CPPTraj_Lig": f'''parm    {PDBFile}
trajin  {PDBFile}
strip   !:{LIGAND}
trajout {esm_dir}/CPPTraj_Lig_{index}.pdb
''',
        # ESMfold model without hydrogens.
        "CPPTraj_no_hydrogens": f'''parm    {output_file}
trajin  {output_file}
strip   @H=
trajout {esm_dir}/ESMfold_no_hydrogens_{index}.pdb
''',
        # ESMfold model aligned on CA atoms to the apo reference.
        "CPPTraj_aligned": f'''parm    {esm_dir}/ESMfold_no_hydrogens_{index}.pdb
reference {esm_dir}/CPPTraj_Apo_{index}.pdb [apo]
trajin    {esm_dir}/ESMfold_no_hydrogens_{index}.pdb
rmsd      @CA ref [apo]
trajout   {esm_dir}/ESMfold_aligned_{index}.pdb noter
''',
    }
    for name, text in cpptraj_inputs.items():
        with open(f'{esm_dir}/{name}_{index}.in', 'w') as f:
            f.write(text)

    # Give the ESMfold algorithm the needed inputs, then post-process and relax.
    # Trailing backslashes are Python line continuations (joined into one shell line);
    # doubled backslashes produce the literal backslashes sed needs.
    cmd += f"""

python {FOLDER_HOME}/ESMfold.py {output_file} {sequence_file}

sed -i '/PARENT N\\/A/d' {output_file}
cpptraj -i {esm_dir}/CPPTraj_Apo_{index}.in           &> \
           {esm_dir}/CPPTraj_Apo_{index}.out
cpptraj -i {esm_dir}/CPPTraj_Lig_{index}.in           &> \
           {esm_dir}/CPPTraj_Lig_{index}.out
cpptraj -i {esm_dir}/CPPTraj_no_hydrogens_{index}.in  &> \
           {esm_dir}/CPPTraj_no_hydrogens_{index}.out
cpptraj -i {esm_dir}/CPPTraj_aligned_{index}.in       &> \
           {esm_dir}/CPPTraj_aligned_{index}.out

sed -i '/END/d' {esm_dir}/ESMfold_aligned_{index}.pdb
grep '^REMARK' {PDBFile} > {remark_file}

cat {remark_file} \
    {esm_dir}/ESMfold_aligned_{index}.pdb \
    {esm_dir}/CPPTraj_Lig_{index}.pdb > {index_dir}/ESMfold_{index}.pdb
sed -i '/TER/d' {index_dir}/ESMfold_{index}.pdb

# Run Rosetta Relax
{ROSETTA_PATH}/bin/rosetta_scripts.linuxgccrelease \
                -s                                        {index_dir}/ESMfold_{index}.pdb \
                -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
                -parser:protocol                          {FOLDER_HOME}/Rosetta_Relax.xml \
                -out:file:scorefile                       {index_dir}/score_rosetta_relax.sc \
                -nstruct                                  1 \
                -ignore_zero_occupancy                    false \
                -corrections::beta_nov16                  true \
                -run:preserve_header  \
                -overwrite {ex}

# Rename the output file
mv ESMfold_{index}_0001.pdb Rosetta_Relax_{index}.pdb

# Clean output file
sed -i '/^\\(HET\\|ATO\\|TER\\|REM\\)/!d' Rosetta_Relax_{index}.pdb
"""

    if RosettaDesign:
        with open(f'{index_dir}/scripts/ESMfold_Rosetta_Relax_{index}.sh', 'w') as file:
            file.write(cmd)
        logging.info(f"Run ESMfold & Rosetta_Relax for index {index}.")
        submit_job(index=index, job="ESMfold_Rosetta_Relax", bash=bash)

    if ProteinMPNN:
        with open(f'{index_dir}/scripts/ProteinMPNN_ESMfold_Rosetta_Relax_{index}.sh', 'w') as file:
            file.write(cmd)
        logging.info(f"Run ProteinMPNN for index {index} based on index {ProteinMPNN_parent_index}.")
        submit_job(index=index, job="ProteinMPNN_ESMfold_Rosetta_Relax", bash=bash)
        
def run_RosettaDesign(parent_index, new_index):
    """Write and submit a RosettaDesign job for ``new_index`` from the parent's relaxed structure.

    The job runs rosetta_scripts with Rosetta_Design.xml on
    ``<parent_index>/Rosetta_Relax_<parent_index>.pdb``, which is also passed as the
    native structure. It writes ``score_rosetta_design.sc`` into the ``new_index``
    folder and renames the output to ``Rosetta_Design_<new_index>.pdb`` (relative
    path, i.e. in the job's working directory).
    """
    # The -ex1 -ex2 flags are omitted in EXPLORE mode.
    ex = "" if EXPLORE else "-ex1 -ex2"
    parent_pdb = f"{FOLDER_HOME}/{parent_index}/Rosetta_Relax_{parent_index}.pdb"

    # Trailing backslashes are Python line continuations: the Rosetta call becomes one line.
    # -run:preserve_header is given once (it used to be listed twice).
    cmd = f"""{ROSETTA_PATH}/bin/rosetta_scripts.linuxgccrelease \
    -s                                        {parent_pdb} \
    -in:file:native                           {parent_pdb} \
    -run:preserve_header                      true \
    -extra_res_fa                             {FOLDER_INPUT}/{LIGAND}.params \
    -parser:protocol                          {FOLDER_HOME}/Rosetta_Design.xml \
    -out:file:scorefile                       {FOLDER_HOME}/{new_index}/score_rosetta_design.sc \
    -nstruct                                  1 \
    -ignore_zero_occupancy                    false \
    -corrections::beta_nov16                  true \
    -overwrite {ex}

# Rename the output file
mv Rosetta_Relax_{parent_index}_0001.pdb Rosetta_Design_{new_index}.pdb
"""
    # Write the shell command to a file
    with open(f'{FOLDER_HOME}/{new_index}/scripts/Rosetta_Design_{new_index}.sh', 'w') as file:
        file.write(cmd)

    logging.info(f"Run RosettaDesign for index {new_index} based on index {parent_index}.")
    submit_job(index=new_index, job="Rosetta_Design")

def run_ProteinMPNN(parent_index, new_index, bash=False):
    """Prepare and submit a ProteinMPNN -> ESMfold -> Rosetta Relax job for ``new_index``.

    The parent's ``Rosetta_Relax`` PDB and its sequence are copied into
    ``<new_index>/ProteinMPNN``. The job then:
      1. runs the ProteinMPNN helper scripts with chain A, passing the DESIGN
         positions to ``make_fixed_positions_dict.py``;
      2. samples 100 sequences;
      3. keeps the sequence with the largest ``global_score`` that differs from the
         parent and does not match the wildcard input sequence;
      4. writes that sequence as the ESMfold input.
    The job script is completed and submitted by ``run_ESMfold_RosettaRelax``.

    Logs an error and returns if ProteinMPNN is not cloned next to FOLDER_HOME or if
    the parent PDB is missing.
    """
    mpnn_repo = f'{FOLDER_HOME}/../ProteinMPNN'

    # Throw error if ProteinMPNN is not cloned.
    if not os.path.exists(mpnn_repo):
        logging.error(f"ProteinMPNN not installed in {mpnn_repo}.")
        logging.error(f"Install using: git clone https://github.com/dauparas/ProteinMPNN.git")
        return

    PDBFile = f"{FOLDER_HOME}/{parent_index}/Rosetta_Relax_{parent_index}.pdb"

    # Throw error if the parent structure is not present.
    if not os.path.isfile(PDBFile):
        logging.error(f"{PDBFile} not present!")
        return

    mpnn_dir = f"{FOLDER_HOME}/{new_index}/ProteinMPNN"
    os.makedirs(mpnn_dir, exist_ok=True)
    os.makedirs(f"{FOLDER_HOME}/{new_index}/ESMfold", exist_ok=True)
    shutil.copy(PDBFile, f"{mpnn_dir}/Rosetta_Relax_{parent_index}.pdb")
    seq = extract_sequence_from_pdb(PDBFile)
    with open(f"{mpnn_dir}/Rosetta_Relax_{parent_index}.seq", "w") as f:
        f.write(seq)

    position_list = " ".join(DESIGN.split(","))

    # Trailing backslashes are Python line continuations; doubled backslashes produce
    # literal ones for grep -P. The wildcard file is READ (not its path used as the
    # pattern); if it is missing or empty the pattern filter is skipped.
    # NOTE(review): ProteinMPNN documents score/global_score as an average negative
    # log-probability (lower = more likely), so taking the maximum picks the least
    # likely sampled sequence. Confirm this is intended.
    cmd = f'''
# Run ProteinMPNN

python {mpnn_repo}/helper_scripts/parse_multiple_chains.py \
    --input_path={mpnn_dir}/ \
    --output_path={mpnn_dir}/parsed_chains.json1

python {mpnn_repo}/helper_scripts/assign_fixed_chains.py \
    --input_path={mpnn_dir}/parsed_chains.json1 \
    --output_path={mpnn_dir}/assigned_chains.json1 \
    --chain_list 'A'

python {mpnn_repo}/helper_scripts/make_fixed_positions_dict.py \
    --input_path={mpnn_dir}/parsed_chains.json1 \
    --output_path={mpnn_dir}/fixe_positions.json1 \
    --chain_list 'A' \
    --position_list '{position_list}'

python {mpnn_repo}/protein_mpnn_run.py \
    --jsonl_path            {mpnn_dir}/parsed_chains.json1 \
    --chain_id_jsonl        {mpnn_dir}/assigned_chains.json1 \
    --fixed_positions_jsonl {mpnn_dir}/fixe_positions.json1 \
    --out_folder            {mpnn_dir}/ \
    --num_seq_per_target    100 \
    --sampling_temp         "{ProteinMPNN_T}" \
    --seed                  37 \
    --batch_size            1

# Get the highest-scoring sequence
file_path='{mpnn_dir}/seqs/Rosetta_Relax_{parent_index}.fa'
parent_seq_file='{mpnn_dir}/Rosetta_Relax_{parent_index}.seq'
input_sequence_file='{FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq'
input_sequence=''
if [ -s "$input_sequence_file" ]; then
    read -r input_sequence < "$input_sequence_file"
fi
# Replace 'X' with the regex wildcard '.'
pattern=$(echo "$input_sequence" | sed 's/X/./g')
highest_score=0
highest_scoring_sequence=''
read -r parent_sequence < "$parent_seq_file"

# Loop through scores and find the greatest score whose sequence differs from the parent
while read -r line; do
    if [[ $line == ">"* ]]; then
        score=$(echo $line | grep -oP 'global_score=\\K[\\d.]+')
        read -r sequence

        # Check if score is higher than the highest score
        if (( $(echo "$score > $highest_score" | bc -l) )); then

            # Check if sequence is different from parent_sequence
            if [ "$sequence" != "$parent_sequence" ]; then

                # Check if sequence does not match the input_sequence pattern
                if [[ -z "$pattern" || ! "$sequence" =~ $pattern ]]; then
                    highest_score=$score
                    highest_scoring_sequence=$sequence
                fi
            fi
        fi
    fi
done < "$file_path"

# Save highest scoring sequence
echo $highest_scoring_sequence > {mpnn_dir}/{PARENT}_{new_index}.seq
echo $highest_scoring_sequence > {FOLDER_HOME}/{new_index}/ESMfold/{PARENT}_{new_index}.seq
'''
    run_ESMfold_RosettaRelax(index=new_index, RosettaDesign=False,
                             ProteinMPNN=True, ProteinMPNN_parent_index=parent_index, cmd=cmd, bash=bash)

# main functions - startup

In [None]:
def startup_controller(UNBLOCK_ALL, PRINT_VAR=True, PLOT_DATA=True):
    """Prepare inputs, restore saved variables, wait for the initial relax and load the run state.

    Args:
        UNBLOCK_ALL: if True, BLOCKED_DAT is emptied (use if all ESMfold_Relax jobs
            had to be killed).
        PRINT_VAR: if True and VARIABLES_JSON exists, every saved variable is
            restored as a module-level global and printed. Exits if the saved
            DESIGN_FOLDER differs from the current one.
        PLOT_DATA: if True, ``plot_scores()`` is called.

    Blocks (polling every 60 s) until ``FOLDER_HOME/0/Rosetta_Relax_0.pdb`` exists.

    Returns:
        (all_scores_df, blocked_df): the table from ALL_SCORES_CSV and the blocked
        indices as a list of ints, both refreshed by ``update_scores``.
    """
    prepare_input_files()

    if PRINT_VAR and os.path.isfile(VARIABLES_JSON):
        with open(VARIABLES_JSON, 'r') as f:
            globals_dict = json.load(f)
        if globals_dict['DESIGN_FOLDER'] != DESIGN_FOLDER:
            print("WRONG DESIGN FOLDER!")
            sys.exit()
        for k, v in globals_dict.items():
            globals()[k] = v
            # Drop REMARK's last character for display (presumably its trailing newline).
            if k == "REMARK": v = v[:-1]
            print(k.ljust(16), ':', v)

    if PLOT_DATA:
        plot_scores()

    # Unblock all (use if all ESMfold_Relax jobs had to be killed.)
    if UNBLOCK_ALL:
        np.savetxt(BLOCKED_DAT, np.array([], dtype=int), fmt='%d')

    # Wait until the initial relax of index 0 has produced its structure.
    while "Rosetta_Relax_0.pdb" not in os.listdir(os.path.join(FOLDER_HOME, str(0))):
        print(f'Initial relax running. Rosetta_Relax_0.pdb not in ./{FOLDER_HOME.split("/")[-1]}/{0}')
        time.sleep(60)

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    for column in ('sequence', 'design_method', 'score_taken_from'):
        all_scores_df[column] = all_scores_df[column].astype(str)

    # np.loadtxt returns floats, and a 0-d array for a single entry. ndmin=1 keeps it
    # iterable; a list of ints supports membership tests and appending downstream.
    blocked_df = np.loadtxt(BLOCKED_DAT, ndmin=1).astype(int).tolist()

    all_scores_df, blocked_df = update_scores(all_scores_df, blocked_df)

    return all_scores_df, blocked_df

def prepare_input_files():
                
    # Write the REEMARK to a textfile
    with open(f"{FOLDER_HOME}/remark.txt", "w") as f: f.write(REMARK)
        
    # Create the ESMfold.py script
    ESMfold_python_script = """import sys
from transformers import AutoTokenizer, EsmForProteinFolding, EsmConfig
import torch
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

output_file = sys.argv[1]
sequence_file = sys.argv[2]

with open(sequence_file) as f: sequence=f.read()

def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
torch.backends.cuda.matmul.allow_tf32 = True
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids']
with torch.no_grad(): output = model(tokenized_input)
pdb = convert_outputs_to_pdb(output)
with open(output_file, "w") as f: f.write("".join(pdb))
"""

    # Write the ESMfold.py to a file
    with open(f"{FOLDER_HOME}/ESMfold.py", "w") as f: 
        f.write(ESMfold_python_script)

    # Create the Rosetta_Relax.xml file
    repeats = "3"
    if EXPLORE: repeats = "1"
    Rosetta_Relax_xml = f"""
<ROSETTASCRIPTS>

    <SCOREFXNS>
    
        <ScoreFunction name      = "score"                   weights = "beta_nov16" >
            <Reweight scoretype  = "atom_pair_constraint"    weight  = "1" />
            <Reweight scoretype  = "dihedral_constraint"     weight  = "1" />
            <Reweight scoretype  = "angle_constraint"        weight  = "1" />
        </ScoreFunction> 
        
        <ScoreFunction name      = "score_final"             weights = "beta_nov16" >
            <Reweight scoretype  = "atom_pair_constraint"    weight  = "1" />
            <Reweight scoretype  = "dihedral_constraint"     weight  = "1" />
            <Reweight scoretype  = "angle_constraint"        weight  = "1" />
        </ScoreFunction>
        
    </SCOREFXNS>
       
    <MOVERS>

        <AddOrRemoveMatchCsts     name="mv_add_cst" 
                                  cst_instruction="add_new" 
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />
                                  
        <FastRelax  name="mv_relax" disable_design="false" repeats="{repeats}" 
                    ramp_down_constraints="false" scorefxn="score" >
        </FastRelax>
        
        <InterfaceScoreCalculator   name                   = "mv_inter" 
                                    chains                 = "X" 
                                    scorefxn               = "score_final" />
    </MOVERS>
    
    <PROTOCOLS>
        <Add mover_name="mv_add_cst" />
        <Add mover_name="mv_relax" />
        <Add mover_name="mv_inter" />
    </PROTOCOLS>
    
</ROSETTASCRIPTS>
"""
    # Write the Rosetta_Relax.xml to a file
    with open(f'{FOLDER_HOME}/Rosetta_Relax.xml', 'w') as f:
        f.writelines(Rosetta_Relax_xml)      

    # Create XML script for Rosetta Design  
    repeats = "2"
    if EXPLORE: repeats = "1"
    Rosetta_Design_xml = f"""
<ROSETTASCRIPTS>

    <SCOREFXNS>
    
        <ScoreFunction            name="score_hbnet"                     weights="beta_nov16" >        
            <Reweight             scoretype="hbnet"                      weight="1.0" />
            <Reweight             scoretype="buried_unsatisfied_penalty" weight="1" />
            <Reweight             scoretype="atom_pair_constraint"       weight="1" />
            <Reweight             scoretype="dihedral_constraint"        weight="1" />
            <Reweight             scoretype="angle_constraint"           weight="1" />      
            <Reweight             scoretype="res_type_constraint"        weight="1" />         
        </ScoreFunction>

        <ScoreFunction            name="score"                           weights="beta_nov16" >        
            <Reweight             scoretype="atom_pair_constraint"       weight="1" />
            <Reweight             scoretype="dihedral_constraint"        weight="1" />
            <Reweight             scoretype="angle_constraint"           weight="1" />      
            <Reweight             scoretype="res_type_constraint"        weight="1" />               
        </ScoreFunction>

        <ScoreFunction            name="score_final"                     weights="beta_nov16" >    
            <Reweight             scoretype="atom_pair_constraint"       weight="1" />
            <Reweight             scoretype="dihedral_constraint"        weight="1" />
            <Reweight             scoretype="angle_constraint"           weight="1" />                
        </ScoreFunction>
   
   </SCOREFXNS>
    
    <RESIDUE_SELECTORS>
    
        <Index                    name="sel_design" 
                                  resnums="{DESIGN}" />

        <Index                    name="sel_repack" 
                                  resnums="{REPACK}" />

        <Index                    name="sel_cat" 
                                  resnums="{RESTRICT.split(",")[0]}" />

        <Or                       name="sel_desrep" 
                                  selectors="sel_design,sel_repack" />

        <Not                      name="sel_nothing"
                                  selector="sel_desrep" />
    </RESIDUE_SELECTORS>
    
    <TASKOPERATIONS>
    
        <OperateOnResidueSubset   name="tsk_design"                      selector="sel_design" >
                                  <RestrictAbsentCanonicalAASRLT         aas="GPAVLIMFYWHKRQNEDST" />
        </OperateOnResidueSubset>
        
        <OperateOnResidueSubset   name="tsk_cat"                         selector="sel_cat" >
                                  <RestrictAbsentCanonicalAASRLT         aas="{RESTRICT.split(",")[1]}" />
        </OperateOnResidueSubset>
        
        <OperateOnResidueSubset   name="tsk_repack"                      selector="sel_repack" >
                                  <RestrictToRepackingRLT />
        </OperateOnResidueSubset>
        
        <OperateOnResidueSubset   name="tsk_nothing"                     selector="sel_nothing" >
                                  <PreventRepackingRLT />
        </OperateOnResidueSubset>
        
    </TASKOPERATIONS>

    <FILTERS>
    
        <HbondsToResidue          name="flt_hbonds" 
                                  scorefxn="score" 
                                  partners="1"
                                  residue="1X"
                                  backbone="true" 
                                  sidechain="true" 
                                  from_other_chains="true" 
                                  from_same_chain="false"
                                  confidence="0" />
    </FILTERS>
    
    <MOVERS>
        
        <FavorSequenceProfile     name="mv_native" 
                                  weight="{CST_WEIGHT}" 
                                  use_native="true" 
                                  matrix="IDENTITY" 
                                  scorefxns="score,score_hbnet" />  
                                
        <AddOrRemoveMatchCsts     name="mv_add_cst" 
                                  cst_instruction="add_new" 
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />

        <FastDesign               name="mv_design_hbnet" 
                                  disable_design="false" 
                                  task_operations="tsk_design,tsk_repack,tsk_nothing,tsk_cat" 
                                  repeats="{repeats}" 
                                  ramp_down_constraints="false" 
                                  scorefxn="score_hbnet" />
        
        <AddOrRemoveMatchCsts     name="mv_remove_cst" 
                                  cst_instruction="remove" 
                                  cstfile="{FOLDER_INPUT}/{LIGAND}_enzdes.cst" />
                              
        <FastDesign               name                   ="mv_design" 
                                  disable_design         ="false" 
                                  task_operations        ="tsk_design,tsk_repack,tsk_nothing,tsk_cat" 
                                  repeats                ="1" 
                                  ramp_down_constraints  = "false" 
                                  scorefxn               = "score" />                                 
            
        <InterfaceScoreCalculator name                   = "mv_inter" 
                                  chains                 = "X" 
                                  scorefxn               = "score_final" />
                                  
    </MOVERS>

    <PROTOCOLS>
        <Add mover_name="mv_native" />
        <Add mover_name="mv_add_cst" />
        <Add mover_name="mv_design_hbnet" />
        <Add mover_name="mv_remove_cst" />
        <Add mover_name="mv_design" />
        <Add mover_name="mv_add_cst" />
        <Add mover_name="mv_inter" />
    </PROTOCOLS>
    
</ROSETTASCRIPTS>

"""
    # Write the XML script to a file
    with open(f'{FOLDER_HOME}/Rosetta_Design.xml', 'w') as f:
        f.writelines(Rosetta_Design_xml)    
        
    # Save input sequence with X as wildcard
    seq = extract_sequence_from_pdb(f"{FOLDER_INPUT}/{PARENT}.pdb")
    design_positions = [int(x) for x in DESIGN.split(',')]
    # Replace seq with X at design positions. Note: Subtract 1 from each position to convert to Python's 0-based indexing
    seq = ''.join('X' if (i+1) in design_positions else amino_acid for i, amino_acid in enumerate(seq))
    with open(f'{FOLDER_HOME}/input_sequence_with_X_as_wildecard.seq', 'w') as f:
        f.writelines(seq)    
    
def setup_aizymes(RESET):
    """Create a fresh AIzymes run in FOLDER_HOME after interactive confirmation.

    Once the user answers 'y', this:
      1. truncates LOG_FILE and deletes/recreates FOLDER_HOME;
      2. writes the Rosetta/ProteinMPNN inputs via prepare_input_files();
      3. writes ALL_SCORES_CSV with a single parent row (index 0), an empty
         BLOCKED_DAT and VARIABLES_JSON holding the run configuration;
      4. copies the parent PDB to design 0, prepends REMARK, and starts its
         ESMfold/Rosetta relax via run_ESMfold_RosettaRelax(0, StartupRelax=True).

    Args:
        RESET: if FOLDER_HOME already exists, setup only runs when this is True
            and the user confirms deleting the existing files.
    """
    # Decide whether to run setup at all; every early return leaves the disk untouched.
    if not os.path.isdir(FOLDER_HOME):
        # NOTE(review): the import cell at the top of this notebook already runs
        # os.makedirs(FOLDER_HOME, exist_ok=True), so this branch is only reached if
        # the folder was removed afterwards — confirm this is the intended check.
        prompt = f"{FOLDER_HOME} not present.\nDo you want to start AIzymes? [y/n]\n\n"
        if input(prompt) != 'y':
            return  # FOLDER_HOME is missing, but the user chose not to start AIzymes
    elif not RESET:
        return  # FOLDER_HOME exists and no reset was requested
    else:
        prompt = (f"Do you really want to restart AIzymes from scratch? \n"
                  f"This will delete all existing files in {FOLDER_HOME} [y/n]\n\n")
        if input(prompt) != 'y':
            return  # reset requested, but the user cancelled

    with open(LOG_FILE, 'w'):
        pass  # truncate the log file
    logging.info("Running AI.zymes setup.")

    if os.path.exists(FOLDER_HOME):
        shutil.rmtree(FOLDER_HOME)
    os.makedirs(FOLDER_HOME, exist_ok=True)
    # Logged only after the deletion has actually happened
    logging.info(f"Content of {FOLDER_HOME} deleted.")
    logging.info("Happy AI.zymeing! :)")

    prepare_input_files()

    # The scores table starts with the parent design (index 0) as its only row.
    # Built directly instead of via DataFrame.append, which was removed in pandas 2.0.
    score_columns = ['index', 'sequence', 'parent_index',
                     'interface_score', 'total_score', 'catalytic_score',
                     'generation', 'mutations', 'design_method', 'score_taken_from']
    parent_row = {'index': 0,
                  'sequence': '',
                  'parent_index': "None",
                  'interface_score': 0.0,
                  'total_score': 0.0,
                  'catalytic_score': 0.0,
                  'generation': 0,
                  'mutations': 0,
                  'design_method': "parent",
                  'score_taken_from': ''}
    all_scores_df = pd.DataFrame([parent_row], columns=score_columns)
    for column in ('sequence', 'design_method', 'score_taken_from'):
        all_scores_df[column] = all_scores_df[column].astype(str)
    all_scores_df.to_csv(ALL_SCORES_CSV, index=False)

    # Create an empty blocked.dat
    np.savetxt(BLOCKED_DAT, np.array([], dtype=int), fmt='%d')

    # Create the folders for design 0
    os.makedirs(f"{FOLDER_HOME}/0/ESMfold", exist_ok=True)
    os.makedirs(f"{FOLDER_HOME}/0/scripts", exist_ok=True)

    # Save important variables so the run configuration can be inspected later
    variables_to_save = [
        'DESIGN_FOLDER', 'MAX_JOBS', 'N_PARENT_JOBS', 'MAX_DESIGNS', 'KBT_BOLTZMANN', 'CST_WEIGHT',
        'ProteinMPNN_PROB', 'PARENT', 'LIGAND', 'ROSETTA_PATH', 'REPACK', 'DESIGN',
        'RESTRICT', 'SCORE_NAMES', 'REMARK', 'EXPLORE', 'ProteinMPNN_T', 'SUBMIT_PREFIX', 'BLUEPEBBLE', 'GRID'
    ]
    globals_to_save = {k: globals()[k] for k in variables_to_save}
    with open(VARIABLES_JSON, 'w') as f:
        json.dump(globals_to_save, f, indent=4)

    # Copy the parent PDB from INPUT as design 0 and prepend the REMARK header
    pdb_file_path = f"{FOLDER_HOME}/0/Rosetta_Design_0.pdb"
    shutil.copy(f"{FOLDER_INPUT}/{PARENT}.pdb", pdb_file_path)
    with open(pdb_file_path, 'r') as file:
        original_content = file.readlines()
    with open(pdb_file_path, 'w') as file:
        file.writelines([REMARK] + original_content)
    run_ESMfold_RosettaRelax(0, StartupRelax=True)

# helper functions

In [None]:
def submit_job(index, job, bash=False):
    """Write the submission script for one job of one design and launch it.

    The script is written to ``{FOLDER_HOME}/{index}/scripts/submit_{job}_{index}.sh``.
    When executed, it changes into ``{FOLDER_HOME}/{index}`` and runs
    ``scripts/{job}_{index}.sh`` only if ``{job}_{index}.pdb`` is not already there.

    Exactly one scheduler is used for both the script header and the submit command.
    BLUEPEBBLE (SLURM, ``sbatch``) takes precedence over GRID (SGE, ``qsub``) when both
    flags are set, so a job is never submitted twice.

    Args:
        index: design index; selects the folder ``{FOLDER_HOME}/{index}``.
        job: job name used in script names, the scheduler job name and log names.
        bash: if True, run the script directly with ``bash`` (for testing) instead
            of submitting it. Without a scheduler flag, a plain bash header is used.

    Raises:
        ValueError: if neither BLUEPEBBLE nor GRID is set and ``bash`` is False.
    """
    script_path = f'{FOLDER_HOME}/{index}/scripts/submit_{job}_{index}.sh'

    if BLUEPEBBLE:
        submission_script = f"""#!/bin/bash
#SBATCH --account=ptch000361
#SBATCH --partition=short
#SBATCH --mem=40GB
#SBATCH --ntasks-per-node=1
#SBATCH --time=2:00:00    
#SBATCH --nodes=1          
#SBATCH --job-name={SUBMIT_PREFIX}_{job}_{index}
#SBATCH --output=AI_{job}_{index}.out
#SBATCH --error=AI_{job}_{index}.err
"""
    elif GRID:
        submission_script = f"""#!/bin/bash
#$ -V
#$ -cwd
#$ -N {SUBMIT_PREFIX}_{job}_{index}
#$ -hard -l mf=16G
#$ -o {FOLDER_HOME}/{index}/scripts/AI_{job}_{index}.out
#$ -e {FOLDER_HOME}/{index}/scripts/AI_{job}_{index}.err
"""
    elif bash:
        submission_script = "#!/bin/bash\n"  # local run: no scheduler directives needed
    else:
        raise ValueError("submit_job: set BLUEPEBBLE or GRID to choose a scheduler, or pass bash=True.")

    submission_script += f"""
# Output folder
cd {FOLDER_HOME}/{index}
pwd

# Count the number of .pdb files in the current directory
file_count=$(find . -maxdepth 1 -type f -name "{job}_{index}.pdb" | wc -l)

# Check whether the number of .pdb files is equal to the input structure number and if so write step done in log
if [ "$file_count" = 0 ]; then
# Run the bash script
bash {FOLDER_HOME}/{index}/scripts/{job}_{index}.sh
fi
"""

    with open(script_path, 'w') as file:
        file.write(submission_script)

    if bash:
        # Run the submission script locally for testing
        subprocess.run(f'bash {script_path}', shell=True, text=True)
    elif BLUEPEBBLE:
        output = subprocess.check_output(f'sbatch {script_path}', shell=True, text=True)
        logging.debug(output.rstrip('\n'))
    else:
        output = subprocess.check_output(f'qsub -l h="!bs-dsvr64&!bs-dsvr58" -q regular.q {script_path}',
                                         shell=True, text=True)
        logging.debug(output.rstrip('\n'))
        
def extract_sequence_from_pdb(pdb_path):
    """Return the amino-acid sequence parsed from a PDB file's ATOM records.

    Uses Biopython's "pdb-atom" parser. If it yields several records (e.g. one
    per chain), only the sequence of the last record is returned.

    Raises:
        ValueError: if the parser yields no record at all.
    """
    seq = None
    with open(pdb_path, "r") as pdb_file:
        for record in SeqIO.parse(pdb_file, "pdb-atom"):
            seq = str(record.seq)  # later records overwrite earlier ones
    if seq is None:
        raise ValueError(f"No chain sequence could be read from {pdb_path}")
    return seq
    
def count_descendants(G, node, counts):
    """Return the size of the subtree rooted at ``node`` (``node`` itself included).

    Recurses over ``G.successors`` and, as a side effect, records the subtree
    size of every visited node in the ``counts`` dict.
    """
    children = list(G.successors(node))
    child_sizes = [count_descendants(G, child, counts) for child in children]
    subtree_size = 1 + sum(child_sizes)
    counts[node] = subtree_size
    return subtree_size

# plotting functions

In [None]:
def plot_scores(combined_score_min=0, combined_score_max=1, combined_score_bin=0.01,
                interface_score_min=0, interface_score_max=1, interface_score_bin=0.01,
                total_score_min=0, total_score_max=1, total_score_bin=0.01,
                catalytic_score_min=0, catalytic_score_max=1, catalytic_score_bin=0.01,
                mut_min=0, mut_max=None):
    """Refresh lineage bookkeeping in ALL_SCORES_CSV and draw the overview figures.

    Reads ALL_SCORES_CSV and recomputes every design's generation from its parent
    chain. Missing mutation counts (and sequences) are filled in by comparing each
    design's Rosetta_Design PDB with design 0's Rosetta_Relax PDB. The table is
    written back, then the function plots score histograms, score/mutation trends
    and the lineage tree.

    The ``*_min``/``*_max``/``*_bin`` arguments set axis limits and histogram bin
    widths for the matching panels. ``mut_max`` defaults to the number of design
    positions + 1; it is evaluated at call time, so a changed DESIGN is picked up.

    Returns early without drawing if the scores file is missing, or if fewer than
    three scored, non-parent designs exist.
    """
    if mut_max is None:
        mut_max = len(DESIGN.split(",")) + 1

    # Nothing to plot before setup has written the scores table
    if not os.path.isfile(ALL_SCORES_CSV):
        return

    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    for column in ('sequence', 'design_method', 'score_taken_from'):
        all_scores_df[column] = all_scores_df[column].astype(str)

    # The parent design is stored with parent_index "None"; pandas >= 2.0 reads that
    # string back as NaN, so both spellings mark a root row.
    is_root = all_scores_df['parent_index'].isna() | (all_scores_df['parent_index'].astype(str) == "None")

    ### Calculate generations and missing mutation counts
    G = nx.DiGraph()
    generations = {}
    reference_sequence = None  # parsed at most once, and only if some count is missing
    for idx, row in all_scores_df.iterrows():
        G.add_node(idx)
        if is_root[idx]:
            generations[idx] = 0
        else:
            parent_idx = int(float(row['parent_index']))
            G.add_edge(parent_idx, idx)
            # If the parent row has not been visited yet, get() falls back to generation 0
            generations[idx] = generations.get(parent_idx, 0) + 1
        all_scores_df.at[idx, 'generation'] = generations[idx]

        if pd.isna(all_scores_df.at[idx, 'mutations']):
            if reference_sequence is None:
                reference_sequence = extract_sequence_from_pdb(f"{FOLDER_HOME}/0/Rosetta_Relax_0.pdb")
            pdb_path = f"{FOLDER_HOME}/{idx}/Rosetta_Design_{idx}.pdb"
            if os.path.exists(pdb_path):
                current_sequence = extract_sequence_from_pdb(pdb_path)
                all_scores_df.at[idx, 'sequence'] = current_sequence
                all_scores_df.at[idx, 'mutations'] = sum(
                    1 for a, b in zip(current_sequence, reference_sequence) if a != b)

    all_scores_df.to_csv(ALL_SCORES_CSV, index=False)

    # Only scored, non-parent designs are plotted
    plot_df = all_scores_df[~is_root].dropna(subset=['total_score'])
    catalytic_scores, total_scores, interface_scores, combined_scores = normalize_scores(plot_df)

    # Bail out before creating the figure, so no empty panel grid is left behind
    if len(plot_df) < 3:
        return

    fig, axs = plt.subplots(2, 4, figsize=(15, 6))
    plot_combined_score(axs[0, 0], combined_scores,
                        combined_score_min, combined_score_max, combined_score_bin)
    plot_interface_score(axs[0, 1], interface_scores,
                         interface_score_min, interface_score_max, interface_score_bin)
    plot_total_score(axs[0, 2], total_scores,
                     total_score_min, total_score_max, total_score_bin)
    plot_catalytic_score(axs[0, 3], catalytic_scores,
                         catalytic_score_min, catalytic_score_max, catalytic_score_bin)

    plot_boltzmann_histogram(axs[1, 0], combined_scores, plot_df,
                             combined_score_min, combined_score_max, combined_score_bin)
    plot_combined_score_v_generation_scatter(axs[1, 1], combined_scores, plot_df,
                                             combined_score_min, combined_score_max)
    plot_combined_score_v_generation(axs[1, 2], combined_scores, plot_df,
                                     combined_score_min, combined_score_max)
    plot_mutations_v_generation(axs[1, 3], plot_df, mut_min, mut_max)
    plt.tight_layout()
    plt.show()

    fig, ax = plt.subplots(1, 1, figsize=(15, 2.5))
    plot_tree(ax, combined_scores, plot_df, G)
    plt.show()

def plot_combined_score(ax, combined_scores, score_min, score_max, score_bin):
    """Draw a raw-count histogram of combined scores, with HIGHSCORE marked in red."""
    edges = np.arange(score_min, score_max + score_bin, score_bin)
    ax.hist(combined_scores, bins=edges)
    ax.axvline(HIGHSCORE, color='r')
    ax.set_xlim(score_min, score_max)
    labels = (
        (ax.set_title, 'Histogram of Score'),
        (ax.set_xlabel, 'Combined Score'),
        (ax.set_ylabel, 'Frequency'),
    )
    for setter, text in labels:
        setter(text)
    
def _plot_density_histogram(ax, scores, score_min, score_max, score_bin, title, xlabel):
    """Draw a density-normalised histogram of ``scores`` on ``ax`` over [score_min, score_max]."""
    ax.hist(scores, density=True,
            bins=np.arange(score_min, score_max + score_bin, score_bin))
    ax.set_xlim(score_min, score_max)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    # density=True normalises bar heights to a probability density, not raw counts
    ax.set_ylabel('Density')

def plot_interface_score(ax, interface_scores, interface_score_min, interface_score_max, interface_score_bin):
    """Density histogram of interface scores with the given limits and bin width."""
    _plot_density_histogram(ax, interface_scores, interface_score_min, interface_score_max,
                            interface_score_bin, 'Histogram of Interface Score', 'Interface Score')

def plot_total_score(ax, total_scores, total_score_min, total_score_max, total_score_bin):
    """Density histogram of total scores with the given limits and bin width."""
    _plot_density_histogram(ax, total_scores, total_score_min, total_score_max,
                            total_score_bin, 'Histogram of Total Score', 'Total Score')

def plot_catalytic_score(ax, catalytic_scores, total_score_min, total_score_max, total_score_bin):
    """Density histogram of catalytic scores with the given limits and bin width.

    The ``total_score_*`` parameter names are kept for keyword compatibility; they
    are the catalytic-score min, max and bin width.
    """
    _plot_density_histogram(ax, catalytic_scores, total_score_min, total_score_max,
                            total_score_bin, 'Histogram of Catalytic Score', 'Catalytic Score')
    
def plot_combined_score_v_generation_scatter(ax, combined_scores, all_scores_df, combined_score_min, combined_score_max, window=20):
    """Scatter combined score against design index, with a rolling-mean trend line.

    Args:
        ax: matplotlib Axes to draw on.
        combined_scores: one score per row of ``all_scores_df`` — assumed to be in
            the same row order.
        all_scores_df: frame whose 'index' column gives each point's x position.
        combined_score_min, combined_score_max: y-axis limits.
        window: number of consecutive rows averaged for the black trend line.
    """
    combined_scores = pd.Series(combined_scores)
    design_index = all_scores_df['index'].to_numpy()
    moving_avg = combined_scores.rolling(window=window).mean()
    ax.scatter(all_scores_df['index'], combined_scores, c='lightgrey', s=5)
    ax.axhline(HIGHSCORE, color='r')
    # Plot the trend at the same x positions as the points; range(len(...)) would
    # shift it left whenever design indices are missing from the frame.
    ax.plot(design_index, moving_avg.to_numpy(), c="k")
    ax.set_ylim(combined_score_min, combined_score_max)
    ax.set_title('Score vs Index')
    ax.set_xlabel('Index')
    ax.set_ylabel('Combined Score')

def plot_boltzmann_histogram(ax, combined_scores, all_scores_df, score_min, score_max, score_bin):
    """Compare uniform and Boltzmann-weighted resampling of the combined scores.

    kbT comes from KBT_BOLTZMANN, which is either a single number or a pair
    ``[kbT0, decay]`` giving ``kbT = kbT0 * 10 ** (-decay * len(all_scores_df))``.
    1000 scores are drawn with replacement from ``combined_scores``: uniformly
    (left axis) and with probability proportional to ``exp(score / kbT)``
    (orange, twin right axis).

    Raises:
        ValueError: if KBT_BOLTZMANN is a sequence whose length is not 2.
    """
    if isinstance(KBT_BOLTZMANN, (float, int, np.number)):
        kbt_boltzmann = KBT_BOLTZMANN
    elif len(KBT_BOLTZMANN) == 2:
        kbt_start, kbt_decay = KBT_BOLTZMANN
        kbt_boltzmann = kbt_start * 10 ** (- kbt_decay * len(all_scores_df))
    else:
        raise ValueError(f"KBT_BOLTZMANN must be a number or a [kbT, decay] pair, got {KBT_BOLTZMANN!r}")

    scores = np.asarray(combined_scores, dtype=float)
    # Subtracting the maximum leaves the normalised probabilities unchanged, but keeps
    # exp() from overflowing to inf (inf/inf -> NaN) once kbT has decayed to small values.
    boltzmann_factors = np.exp((scores - scores.max()) / kbt_boltzmann)
    probabilities = boltzmann_factors / boltzmann_factors.sum()

    random_scores = np.random.choice(scores, size=1000, replace=True)
    boltzmann_scores = np.random.choice(scores, size=1000, replace=True, p=probabilities)

    # Uniform resampling on the primary axis
    ax.hist(random_scores, density=True, alpha=0.7, label='Random Sampling',
            bins=np.arange(score_min, score_max + score_bin, score_bin))
    ax.axvline(HIGHSCORE, color='r')
    ax.set_xlabel('Combined Score')
    ax.set_ylabel('Density (Normal)')
    ax.set_title(f'kbT = {kbt_boltzmann:.1e}')
    # Boltzmann resampling on a twin y-axis, since the two densities can differ greatly
    ax_dup = ax.twinx()
    ax_dup.hist(boltzmann_scores, density=True, alpha=0.7, color='orange', label='Boltzmann Sampling',
                bins=np.arange(score_min, score_max + score_bin, score_bin))
    ax.set_xlim(score_min, score_max)
    ax_dup.set_ylabel('Density (Boltzmann)')
    ax_dup.tick_params(axis='y', labelcolor='orange')



def plot_interface_score_v_total_score(ax, all_scores_df, 
                                       total_score_min, total_score_max, interface_score_min, interface_score_max):
    """Scatter interface vs total score, coloured by design index, with a least-squares line.

    The panel title shows the Pearson correlation of the two scores.
    """
    total = all_scores_df['total_score']
    interface = all_scores_df['interface_score']

    ax.scatter(total, interface,
               c=all_scores_df['index'], cmap='coolwarm_r', s=5)
    r_value, _p_value = pearsonr(total, interface)

    fit = np.poly1d(np.polyfit(total, interface, 1))
    xs = np.linspace(total.min(), total.max(), 100)
    ax.plot(xs, fit(xs), "k")

    ax.set_title(f'Pearson r: {r_value:.2f}')
    ax.set_xlim(total_score_min, total_score_max)
    ax.set_ylim(interface_score_min, interface_score_max)
    ax.set_xlabel('Total Score')
    ax.set_ylabel('Interface Score')
        
def plot_mutations_v_generation(ax, all_scores_df,  mut_min, mut_max):
    """Box plot of mutation counts per generation (1..max), with the number of design positions in red.

    Rows without a mutation count are ignored. If no row has one, the axes are
    still labelled, but no boxes are drawn; previously this path crashed on
    ``int(NaN)``.
    """
    with_mutations = all_scores_df.dropna(subset=['mutations'])

    max_gen = 0 if with_mutations.empty else int(with_mutations['generation'].max())
    boxplot_data = [with_mutations[with_mutations['generation'] == gen]['mutations']
                    for gen in range(1, max_gen + 1)]
    if boxplot_data:
        ax.boxplot(boxplot_data, positions=range(len(boxplot_data)))
    # Reference line: total number of design positions
    ax.axhline(len(DESIGN.split(",")), color='r')
    ax.set_xticks(range(len(boxplot_data)))
    ax.set_xticklabels(range(1, len(boxplot_data) + 1, 1))
    ax.set_ylim(mut_min, mut_max)
    ax.set_title('Mutations vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Number of Mutations')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)

def plot_combined_score_v_generation(ax, combined_scores, all_scores_df, score_min, score_max):
    """Box plot of combined score for each generation 1..max, with HIGHSCORE in red.

    ``combined_scores`` is masked by ``all_scores_df['generation'] == gen``, so it
    must have one entry per row of ``all_scores_df``.
    """
    generation = all_scores_df['generation']
    last_generation = int(generation.max())

    per_generation = []
    for gen in range(1, last_generation + 1):
        per_generation.append(combined_scores[generation == gen])

    n_boxes = len(per_generation)
    ax.boxplot(per_generation, positions=range(n_boxes))
    ax.axhline(HIGHSCORE, color='r')
    ax.set_xticks(range(n_boxes))
    ax.set_xticklabels(range(1, n_boxes + 1))
    ax.set_ylim(score_min, score_max)
    ax.set_title('Combined Score vs Generations')
    ax.set_xlabel('Generation')
    ax.set_ylabel('Combined Score')
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.7)
    
def plot_tree(ax, combined_scores, all_scores_df, G):
    """Draw the design lineage tree ``G`` on ``ax``, with generations along the y axis.

    Args:
        ax: matplotlib Axes to draw on.
        combined_scores, all_scores_df: accepted for call compatibility but not
            used; the full table is re-read from ALL_SCORES_CSV instead.
            NOTE(review): this looks deliberate. In plot_scores, G's nodes are row
            positions of the full table (parent row included), which the filtered
            frame passed in does not contain.
        G: directed parent -> child graph rooted at node 0.

    Each edge is coloured and sized by the child's combined score, min-max
    normalised over all rows except row 0 (the parent design).
    """
    all_scores_df = pd.read_csv(ALL_SCORES_CSV)
    _, _, _, combined_scores = normalize_scores(all_scores_df)

    max_gen = int(all_scores_df['generation'].max())

    def set_node_positions(G, node, pos, x, y, counts):
        # Depth-first layout: x is the generation, and each child gets a vertical
        # slot proportional to its subtree size, centred within the parent's slot.
        pos[node] = (x, y)
        next_y = y - counts[node] / 2
        for neighbor in list(G.successors(node)):
            set_node_positions(G, neighbor, pos, x + 1, next_y + counts[neighbor] / 2, counts)
            next_y += counts[neighbor]

    counts = {}
    count_descendants(G, 0, counts)

    pos = {}
    set_node_positions(G, 0, pos, 0, 0, counts)

    # Row 0 is the parent design and is excluded from the normalisation range
    colors = np.array(combined_scores, dtype=float)
    colors[0] = np.nan
    low, high = np.nanmin(colors[1:]), np.nanmax(colors[1:])
    normed_colors = np.nan_to_num((colors - low) / (high - low), nan=0)

    for start, end in G.edges():
        weight = normed_colors[end]
        color = plt.cm.coolwarm_r(weight)
        # NOTE(review): weight 0 covers both a missing score (NaN -> 0) and the
        # lowest-scoring design; both are drawn black.
        if float(weight) == 0.0:
            color = [0., 0., 0., 1.]

        x0, y0 = pos[start]
        x1, y1 = pos[end]
        # Axes are swapped: the layout's x (generation) is plotted on the y axis
        ax.plot([y0, y1], [x0, x1], color=color, linewidth=0.1 + 2 * weight)

    # Axis labels and ticks for the swapped axes
    ax.axis('on')
    ax.set_xlabel("Variants")
    ax.set_ylabel("Generations")
    ax.set_yticks(range(max_gen + 1))
    ax.set_ylim(0, max_gen + 0.25)
    ax.set_yticklabels(range(max_gen + 1))

    ax.tick_params(axis='y', which='both', bottom=True, top=False, left=True, right=False,
                   labelbottom=True, labelleft=True)

## Functions needed to run AIzyme algorithm

In [None]:
# Confirm that the function-definition cells above have been executed.
print("AIzyme Functions loaded!")
# NOTE(review): the purpose of this 0.5 s pause is not visible here — confirm it is still needed.
time.sleep(0.5)