In [None]:
# NOTE: use FULL (absolute, not relative) paths in the ###### areas, e.g.
# CORRECT:   ROSETTA_BIN = "/home/user/.../rosetta/main/source/bin"
# INCORRECT: ROSETTA_BIN = "rosetta/main/source/bin"

# NOTE: if you followed the README, this notebook should run as is; no modifications are required.

import os
import sys
import string
import numpy as np
import pandas as pd
import glob
from tqdm.notebook import tqdm
import requests
import shutil
import Bio
from Bio import SeqUtils
from Bio.PDB import PDBParser, PDBIO
from Bio.SeqUtils import seq1
import json
import re
import matplotlib.pyplot as plt
import warnings
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
import subprocess
import logging

# Get the absolute path of the current working directory
# (in Jupyter this is normally the directory the notebook was opened from).
notebook_dir = os.getcwd()

# Constants
###########################################################################################################################################
# Number of concurrent jobs: download threads in download_pdbs_list() and
# GNU parallel jobs for the wild-type relax in run_relax().
N_JOBS = 8
# When False, run_relax() skips the wild-type Rosetta relax entirely.
RUN_FAST_RELAX = True
######## FULL PATHS ARE OBLIGATORY ##############
# Directory holding the Rosetta binaries; run_relax() calls
# <ROSETTA_BIN>/relax.static.linuxgccrelease for the wild-type chains.
# ROSETTA_BIN = "absolute/path/rosetta/main/source/bin"
ROSETTA_BIN = os.path.join(notebook_dir, "rosetta/main/source/bin")
# Full path to the relax executable; used per mutant by run_rosetta_relax().
# ROSETTA_PATH = "absolute/path/rosetta/main/source/bin/relax.static.linuxgccrelease"
ROSETTA_PATH = os.path.join(notebook_dir, "rosetta/main/source/bin/relax.static.linuxgccrelease")
# When False, run_rosetta_for_mutants() does nothing (mutant relax is skipped).
flag_parallel = True
###########################################################################################################################################
# Fail fast if Rosetta is not where the configuration says it is.
# NOTE(review): asserts are stripped when Python runs with -O.
for path in [ROSETTA_BIN, ROSETTA_PATH]:
    assert os.path.exists(path), f"Path does not exist: {path}"
    assert os.path.isabs(path), f"Path is not absolute: {path}"


# Paths
###########################################################################################################################################
# Root under which the per-dataset working tree <path_to_save_ds>/<dataset_name>/ is created.
# path_to_save_ds = "absolute/path/to/output/directory"
path_to_save_ds = notebook_dir
dataset_name = "Ssym"
###########################################################################################################################################
# Working directories of the pipeline (all created by make_dirs()).
path_to_ds = os.path.join(path_to_save_ds, dataset_name)
path_to_dw_add_pdbs = os.path.join(path_to_ds, "pdbs")  # downloaded full PDB entries
path_to_chains = os.path.join(path_to_ds, "chains")  # one file per chain
path_to_pdb_chains_filtered = os.path.join(path_to_ds, "pdbs_chains_filtered")  # only chains present in the dataset
path_to_save_rem_hetatm = os.path.join(path_to_ds, "pdb_chains_hetatm_rem")  # filtered chains without HETATM lines
path_to_relaxed_chains = os.path.join(path_to_ds, "relaxed_chains")  # <chain>/<chain>_relaxed.pdb (wild type)
relaxed_chains_total = os.path.join(path_to_ds, "relaxed_chains_total")  # WT + mutant structures merged
relaxed_chains_total_ori = os.path.join(path_to_ds, "relaxed_chains_total_ori")  # oriented copies
rosetta_out = os.path.join(path_to_ds, "rosetta_out")  # mutant relax outputs
features = os.path.join(path_to_ds, "features")
features_ds = os.path.join(features, f"{dataset_name}_ori")  # features from oriented structures
features_ds_nonori = os.path.join(features, f"{dataset_name}_nonori")  # features from non-oriented structures
ds = os.path.join(path_to_ds, "ds") # NOTE: this is the final datasets directory

###########################################################################################################################################
# Helper scripts invoked as subprocesses by orient_dataset() and feature_calc*().
path_to_ori_script = os.path.join(notebook_dir, "orientation_standardization/orient_protein.py")
# path_to_ft_calc_script = os.path.join(notebook_dir, "calculate_features_for_thermonet.py") #no GLY correction (Thermonet)
path_to_ft_calc_script = os.path.join(notebook_dir, "calculate_features_for_orgnet.py") #with GLY correction (Orgnet)
###########################################################################################################################################
for path in [path_to_ori_script, path_to_ft_calc_script]:
    assert os.path.exists(path), f"Path does not exist: {path}"
    assert os.path.isabs(path), f"Path is not absolute: {path}"

# Log file that orient_dataset() appends the orientation script output to.
log = os.path.join(path_to_ds, "ori_log.txt")
# Exported so child processes inherit it; presumably read by HTMD inside the
# helper scripts to disable interactive prompts — TODO confirm.
os.environ['HTMD_NONINTERACTIVE'] = '1'

# Configure logging: all logging.* calls in this notebook go to <path_to_ds>/process.log.
log_file = os.path.join(path_to_save_ds, dataset_name, "process.log")
os.makedirs(os.path.dirname(log_file), exist_ok=True)
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')

def make_dirs():
    """
    Build the on-disk directory tree used by the processing pipeline.

    - Every top-level working directory (downloaded PDBs, split chains,
      filtered / HETATM-stripped chains, relaxed chains, Rosetta output,
      features and final datasets) is created if missing.
    - One sub-directory per feature variant (defdif/def/dif x direct/reverse),
      named `<dataset_name>_<variant>`, is created under both the oriented
      and the non-oriented feature roots.
    """
    top_level = (
        path_to_ds,
        path_to_dw_add_pdbs,
        path_to_chains,
        path_to_pdb_chains_filtered,
        path_to_save_rem_hetatm,
        path_to_relaxed_chains,
        relaxed_chains_total,
        relaxed_chains_total_ori,
        rosetta_out,
        features,
        features_ds,
        features_ds_nonori,
        ds,
    )
    for folder in top_level:
        os.makedirs(folder, exist_ok=True)

    feature_kinds = (
        "defdif_direct", "def_direct", "dif_direct",
        "defdif_reverse", "def_reverse", "dif_reverse",
    )
    for kind in feature_kinds:
        sub_name = f"{dataset_name}_{kind}"
        for feature_root in (features_ds, features_ds_nonori):
            os.makedirs(os.path.join(feature_root, sub_name), exist_ok=True)

def make_chain_dirs(path, chains_list):
    """
    Ensure that a folder `path/<chain>` exists for every chain identifier.

    Parameters:
    - path (str): Parent directory under which the per-chain folders are made.
    - chains_list (list): Chain identifiers; each one becomes a folder name.
      Folders that already exist are left untouched.
    """
    targets = [os.path.join(path, chain_name) for chain_name in chains_list]
    for target in targets:
        os.makedirs(target, exist_ok=True)
    

def download_pdb(pdb, path_to_dwl, timeout=60):
    """
    Download a single PDB file from the RCSB repository.

    Parameters:
    - pdb (str): The PDB id to download.
    - path_to_dwl (str): Directory where the downloaded PDB file will be saved.
    - timeout (float): Seconds to wait for the server before giving up
      (default 60). Without a timeout a stalled connection would block a
      worker thread of `download_pdbs_list` indefinitely.

    Operation:
    - Requests `https://files.rcsb.org/download/<pdb>.pdb`.
    - On a network error or a non-200 status, logs a warning and returns
      without writing anything. This is best effort: one bad id must not abort
      the whole concurrent batch.
    - Otherwise writes the response body verbatim to `<path_to_dwl>/<pdb>.pdb`.
    """
    fname = os.path.join(path_to_dwl, f'{pdb}.pdb')
    url = f'https://files.rcsb.org/download/{pdb}.pdb'
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        logging.warning(f"{url} request failed: {err}")
        return
    if response.status_code != 200:
        logging.warning(f"{url} status code {response.status_code}")
        return
    # Write the raw bytes. This avoids a UTF-8 decode followed by a re-encode
    # with the platform's default codec.
    with open(fname, 'wb') as f:
        f.write(response.content)


def download_pdbs_list(lst, path_to_dwl):
    """
    Fetch every PDB id in `lst` that is not already on disk, using threads.

    Parameters:
    - lst (list): PDB ids to make available locally.
    - path_to_dwl (str): Directory holding (and receiving) `<id>.pdb` files.

    Operation:
    - A file counts as present when the first four characters of an existing
      `*.pdb` file name match the id.
    - Missing ids are handed to `download_pdb` on a pool of `N_JOBS` threads,
      with a tqdm progress bar.
    """
    already_present = {
        os.path.basename(existing)[:4]
        for existing in glob.glob(os.path.join(path_to_dwl, '*.pdb'))
    }
    missing = set(lst) - already_present

    def _fetch(pdb_id):
        return download_pdb(pdb_id, path_to_dwl)

    with ThreadPoolExecutor(max_workers=N_JOBS) as pool:
        results = pool.map(_fetch, missing)
        list(tqdm(results, total=len(missing), desc='Downloading PDB files'))


def pdbs_to_chains(path_to_pdbs, path_to_chains):
    """
    Split every downloaded PDB entry into per-chain files and tabulate them.

    Parameters:
    - path_to_pdbs (str): Directory containing the `*.pdb` entries.
    - path_to_chains (str): Directory receiving `<id><chain>.pdb` files.

    Operation:
    - Parses each file with BioPython (parser warnings are silenced).
    - Asserts that consecutive models encode the same one-letter sequence.
    - For every chain (in every model), records its name, one-letter sequence
      and tuple of residue numbers, and writes the chain to disk unless the
      file already exists.

    Returns:
    - pandas.DataFrame with columns `PDB_chain`, `sequence`, `pdb_ids`,
      with duplicate rows (e.g. identical chains repeated across models) removed.
    """
    records = []
    parser = PDBParser()
    writer = PDBIO()
    pdb_paths = glob.glob(os.path.join(path_to_pdbs, '*.pdb'))

    def _one_letter(residues):
        # Concatenate three-letter residue names and translate to one-letter codes.
        return seq1(''.join(res.resname for res in residues))

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for pdb_path in tqdm(pdb_paths, desc='Processing PDBs'):
            structure = parser.get_structure(os.path.basename(pdb_path)[:4], pdb_path)

            # Every model must describe the same sequence.
            models = list(structure.get_models())
            for previous, current in zip(models, models[1:]):
                seq_prev = _one_letter(previous.get_residues())
                seq_curr = _one_letter(current.get_residues())
                assert seq_prev == seq_curr, "Invalid models"

            for chain in structure.get_chains():
                chain_name = f'{structure.get_id()}{chain.get_id()}'
                records.append({
                    'PDB_chain': chain_name,
                    'sequence': _one_letter(chain),
                    'pdb_ids': tuple(res.get_id()[1] for res in chain),
                })
                out_path = os.path.join(path_to_chains, f'{chain_name}.pdb')
                if os.path.exists(out_path):
                    continue
                writer.set_structure(chain)
                writer.save(out_path)

    return pd.DataFrame(records).drop_duplicates()
    

def filter_chains(path_to_chains, path_to_pdb_chains_filtered):
    """
    Copy only the chains that appear in the dataset into a separate folder.

    Parameters:
    - path_to_chains (str): Directory containing individual chain PDB files.
    - path_to_pdb_chains_filtered (str): Destination directory for kept chains.

    Operation:
    - Reads the module-level list `un_pdb_chain_dat` (the dataset's chain
      identifiers; must be defined before calling).
    - A file is kept when the first five characters of its name (PDB id +
      chain letter) are in that list.
    """
    wanted = set(un_pdb_chain_dat)  # set for O(1) membership tests
    for name in tqdm(os.listdir(path_to_chains), desc='Filtering Chains'):
        if name[:5] not in wanted:
            continue
        shutil.copy(os.path.join(path_to_chains, name), path_to_pdb_chains_filtered)
    

def rem_hetatm(path_to_flt_pdbs, path_to_save_rem_hetatm):
    """
    Write a copy of every file in `path_to_flt_pdbs` without its HETATM records.

    Parameters:
    - path_to_flt_pdbs (str): Directory containing the filtered PDB files.
    - path_to_save_rem_hetatm (str): Destination directory; output files keep
      their original names and are overwritten if present.

    Every line not starting with 'HETATM' is copied unchanged.
    """
    for name in os.listdir(path_to_flt_pdbs):
        src = os.path.join(path_to_flt_pdbs, name)
        dst = os.path.join(path_to_save_rem_hetatm, name)
        with open(src, 'r') as reader, open(dst, 'w') as writer:
            writer.writelines(row for row in reader if not row.startswith('HETATM'))
                    

def run_relax():
    """
    Run the Rosetta relax protocol on the wild-type chains that still need it.

    Reads the module-level settings `RUN_FAST_RELAX`, `non_relaxed_chains`,
    `path_to_save_rem_hetatm`, `N_JOBS`, `ROSETTA_BIN` and `path_to_ds`.

    Operation:
    - Does nothing beyond the banner if RUN_FAST_RELAX is False.
    - Returns early when `non_relaxed_chains` is empty. Otherwise the command
      would degrade to a bare `ls`, which lists every file in the directory
      (including earlier `*_relaxed.pdb` outputs and score files) and feeds
      all of them to Rosetta.
    - Otherwise relaxes each `<chain>.pdb` in the HETATM-stripped directory
      with GNU parallel (`N_JOBS` jobs), writing `<chain>_relaxed.pdb` there.
    - Command output is written to `<path_to_ds>/rosetta_relax.log`.

    Raises:
    - subprocess.CalledProcessError if the pipeline exits non-zero.
    """
    print("RUNNING WT ROSETTA")
    if not RUN_FAST_RELAX:
        return
    if not non_relaxed_chains:
        logging.info("All chains already relaxed; skipping Rosetta relax")
        return

    nrc = ' '.join([f'{f}.pdb' for f in non_relaxed_chains])
    cmd = f"ls {nrc} | parallel -j {N_JOBS} {ROSETTA_BIN}/relax.static.linuxgccrelease -in:file:s {{}}  -relax:constrain_relax_to_start_coords -out:suffix _relaxed -out:no_nstruct_label -relax:ramp_constraints false"
    logging.info(f"Executing in {path_to_save_rem_hetatm}: {cmd}")
    log_file_path = os.path.join(path_to_ds, 'rosetta_relax.log')
    with open(log_file_path, 'w') as logfile:
        # cwd= replaces the former unquoted `cd <dir> &&`, which broke on
        # directories containing spaces.
        subprocess.run(cmd, shell=True, check=True, cwd=path_to_save_rem_hetatm,
                       stdout=logfile, stderr=subprocess.STDOUT)
    

def copy_relaxed(path_to_dat):
    """
    Move the wild-type relax outputs into one folder per chain.

    Parameters:
    - path_to_dat (str): Dataset root containing `pdb_chains_hetatm_rem/`
      (relax output) and `relaxed_chains/` (destination).

    Operation:
    - Deletes `pdb_chains_hetatm_rem/score_relaxed.sc` if present. Its name
      contains "relaxed", so it would otherwise be copied as well.
    - Copies each file whose name contains "relaxed" to
      `relaxed_chains/<first five characters of the file name>/`.
    """
    relaxed_dir = os.path.join(path_to_dat, "pdb_chains_hetatm_rem")
    score_file = os.path.join(relaxed_dir, "score_relaxed.sc")
    if os.path.exists(score_file):
        os.remove(score_file)
        logging.info(f"Removed {score_file}")

    for name in tqdm(os.listdir(relaxed_dir), desc='Copying Relaxed Chains'):
        if "relaxed" not in name:
            continue
        destination = os.path.join(path_to_dat, "relaxed_chains", name[:5])
        os.makedirs(destination, exist_ok=True)
        shutil.copy(os.path.join(relaxed_dir, name), destination)
            

def create_mut_df():
    """
    Build the mutation table consumed by `run_rosetta_for_mutants`.

    Takes the 'pdb', 'pos', 'wt' and 'mut' columns of the module-level
    `init_df` (as an independent copy) and upper-cases their names to
    'PDB', 'POS', 'WT', 'MUT'.
    """
    selected = init_df.loc[:, ['pdb', 'pos', 'wt', 'mut']].copy()
    selected.columns = [name.upper() for name in selected.columns]
    return selected
    

def run_rosetta_relax(pdb_id, wt, mut, pos, path_to_relaxed_chains, PDBDIR, OUTDIR, ROSETTA_PATH):
    """
    Relax one point mutant of a chain with Rosetta, using a resfile that
    restricts the mutated position to the mutant residue.

    Parameters:
    - pdb_id (str): Chain identifier; its last character is written as the
      chain id in the resfile.
    - wt (str): Wild-type residue (presumably a one-letter code).
    - mut (str): Mutant residue (presumably a one-letter code).
    - pos (str or int): Residue number of the mutation.
    - path_to_relaxed_chains (str): Folder holding `<pdb_id>/<pdb_id>_relaxed.pdb`.
    - PDBDIR (str): Accepted for interface compatibility; not used.
    - OUTDIR (str): Root folder; all outputs go to `OUTDIR/<pdb_id>/`.
    - ROSETTA_PATH (str): Rosetta relax executable.

    Returns None. If the relaxed wild-type structure is missing, an error is
    logged (after the resfile is written) and Rosetta is not run.

    Raises:
    - ValueError if any of wt, mut or pos is '-'. The output folder has
      already been created at that point.
    """
    out_folder = os.path.join(OUTDIR, pdb_id)
    os.makedirs(out_folder, exist_ok=True)

    if wt == '-' or mut == '-' or pos == '-':
        logging.error("Invalid mutation data")
        raise ValueError('Not clear if this row is for wildtype or a mutant type')

    variant = wt + str(pos) + mut  # e.g. D234K

    # Resfile: keep every residue native, except allow only the mutant
    # residue at the mutated position of this chain.
    resfile_path = os.path.join(out_folder, f'{pdb_id}_{variant}.resfile')
    resfile_lines = [
        'NATAA\n',
        'start\n',
        f"{variant[1:-1]} {pdb_id[-1]} PIKAA {variant[-1]}\n",
    ]
    with open(resfile_path, 'wt') as handle:
        handle.writelines(resfile_lines)

    start_struct = os.path.join(path_to_relaxed_chains, f'{pdb_id}/{pdb_id}_relaxed.pdb')
    if not os.path.exists(start_struct):
        logging.error(f"No such file {start_struct}")
        return

    score_path = os.path.join(out_folder, f'{pdb_id}_{variant}_relaxed.sc')
    cmd = [
        ROSETTA_PATH,
        '-in:file:s', start_struct, '-in:file:fullatom',
        '-relax:constrain_relax_to_start_coords',
        '-out:no_nstruct_label', '-relax:ramp_constraints', 'false',
        '-relax:respect_resfile',
        '-packing:resfile', resfile_path,
        '-default_max_cycles', '200',
        '-out:file:scorefile', score_path,
        '-out:suffix', f'_{variant}_relaxed',
    ]
    logging.info(f"Executing: {' '.join(cmd)}")

    relax_log = os.path.join(out_folder, f'{pdb_id}_{variant}_relax.log')
    with open(relax_log, 'w') as logfile:
        subprocess.run(cmd, cwd=out_folder, stdout=logfile, stderr=subprocess.STDOUT)

    # Rosetta names the output <input stem><suffix>.pdb inside cwd (out_folder).
    output_pdb = os.path.join(out_folder, f"{pdb_id}_relaxed_{variant}_relaxed.pdb")
    if os.path.exists(output_pdb):
        logging.info(f"Output PDB file found: {output_pdb}")
    else:
        logging.error(f"Output PDB file not found: {output_pdb}")
    

def _run_rosetta_relax_packed(args):
    """Unpack one argument tuple for `run_rosetta_relax` (Pool.imap_unordered passes a single object)."""
    return run_rosetta_relax(*args)


def run_rosetta_for_mutants():
    """
    Relax every valid mutant in the module-level `ds_mut` table in parallel.

    Operation:
    - Does nothing (beyond the banner) when `flag_parallel` is False.
    - Keeps rows whose WT, MUT and POS are not '-' and have no missing values,
      and casts POS to int. POS == '-' used to crash the int cast.
    - Runs `run_rosetta_relax` for each row on a process pool of
      (CPU count - 1, at least 1) workers, writing into `rosetta_out`.
    - Uses imap_unordered so the tqdm bar advances as each job finishes.
      `starmap` only returned after all jobs were done, so the bar jumped
      from 0 to 100%.

    Raises:
    - Any exception raised by a worker (e.g. ValueError from
      `run_rosetta_relax`) is re-raised here.
    """
    print("RUNNING ROSETTA")
    if not flag_parallel:
        return

    valid = (ds_mut['WT'] != '-') & (ds_mut['MUT'] != '-') & (ds_mut['POS'] != '-')
    # Explicit copy: we assign to POS below, which on a filtered view would
    # trigger pandas' chained-assignment warning.
    df = ds_mut[valid].dropna().copy()
    df['POS'] = df['POS'].astype(int)

    n_cpu = mp.cpu_count()
    pool_size = max(1, n_cpu - 1)
    logging.info(f'Using {pool_size} CPUs')

    args_list = [(row['PDB'], row['WT'], row['MUT'], row['POS'],
                  path_to_relaxed_chains, path_to_relaxed_chains,
                  rosetta_out, ROSETTA_PATH) for _, row in df.iterrows()]
    if not args_list:
        logging.info("No valid mutations to relax")
        return

    with mp.Pool(pool_size) as pool:
        for _ in tqdm(pool.imap_unordered(_run_rosetta_relax_packed, args_list),
                      total=len(args_list)):
            pass


def copy_pdb_files_and_directories(src_folders, dest_folder):
    """
    Merge the `.pdb` files of several directory trees into one destination tree.

    Parameters:
    - src_folders (list): Source roots to walk recursively.
    - dest_folder (str): Destination root (created if missing).

    Each `.pdb` file is copied (with metadata, via shutil.copy2) to the same
    path relative to its source root under `dest_folder`. Destination
    sub-directories are only created where at least one `.pdb` file lives.
    Files with the same relative path in later roots overwrite earlier ones.
    """
    os.makedirs(dest_folder, exist_ok=True)

    for src_root in src_folders:
        for current_dir, _subdirs, filenames in os.walk(src_root):
            pdb_names = [name for name in filenames if name.endswith(".pdb")]
            if not pdb_names:
                continue
            target_dir = os.path.join(dest_folder, os.path.relpath(current_dir, src_root))
            os.makedirs(target_dir, exist_ok=True)
            for name in pdb_names:
                shutil.copy2(os.path.join(current_dir, name), os.path.join(target_dir, name))

def feature_calc_nonori(ptc, ptofeat):
    """
    Calculate non-oriented voxel-based features for protein structures.

    Parameters:
    - ptc (str): Directory containing one folder per chain, each holding
      `<folder>_relaxed.pdb` (wild type) and mutant files containing "_relaxed_".
    - ptofeat (str): Feature root; every existing sub-directory of it gets a
      `<folder>` sub-folder before the script runs.

    Operation:
    - For each wild-type/mutant pair, runs `path_to_ft_calc_script` with the
      notebook's own interpreter (sys.executable). The arguments are passed as
      a list without a shell, so absolute paths containing spaces are safe.
    - Script output is appended to `<path_to_ds>/feature_calc_nonori.log`.
      A non-zero exit status is logged as a warning.
    - Missing input files are logged as warnings.

    Returns:
    - list[str]: folders whose processing raised an exception (also logged).
    """
    bad_fold = []
    log_file_path = os.path.join(path_to_ds, 'feature_calc_nonori.log')
    for folder in tqdm(os.listdir(ptc), desc='Calculating nonoriented features'):
        try:
            for ff in os.listdir(ptofeat):
                os.makedirs(os.path.join(ptofeat, ff, folder), exist_ok=True)

            folder_path = os.path.join(ptc, folder)
            path_to_wt = os.path.join(folder_path, f"{folder}_relaxed.pdb")
            mut_prots = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if "_relaxed_" in f]

            for path_to_mut in mut_prots:
                if not (os.path.exists(path_to_mut) and os.path.exists(path_to_wt)):
                    logging.warning(f"Files not found: {path_to_mut}, {path_to_wt}")
                    continue
                cmd = [sys.executable, path_to_ft_calc_script,
                       '-iwt', path_to_wt, '-imut', path_to_mut,
                       '-o', ptofeat + '/', '--boxsize', '16', '--voxelsize', '1']
                with open(log_file_path, 'a') as logfile:
                    result = subprocess.run(cmd, stdout=logfile, stderr=subprocess.STDOUT)
                if result.returncode != 0:
                    logging.warning(f"Feature script failed (exit {result.returncode}) for {path_to_mut}")
        except Exception as e:
            logging.error(f"Error processing folder {folder}: {e}")
            bad_fold.append(folder)
    return bad_fold

def orient_dataset(wk_dir, out_dir):
    """
    Orient wild-type and mutant structures with the standardization script.

    Parameters:
    - wk_dir (str): Directory with one folder per chain after the Rosetta
      relax steps. The reference (wild type) is the `.pdb` file whose name
      has exactly one underscore, e.g. `1ABCA_relaxed.pdb`.
    - out_dir (str): Destination; one sub-folder per chain is created.

    Operation:
    - Non-directory entries of `wk_dir` are skipped.
    - For every other `.pdb` in a folder, the mutation position is parsed
      from the name (`..._D234K_relaxed.pdb` -> `234`). The script then runs
      twice: on the mutant with `-fl 0`, and on the reference with
      `-fl <pos>_wt`, both centred on that position.
    - The script is run with the notebook's interpreter (sys.executable) and
      list arguments (no shell), so paths with spaces are safe.
    - Output is appended to the module-level `log` file. A non-zero exit
      status is logged as a warning.
    """
    os.makedirs(out_dir, exist_ok=True)
    for folder in tqdm(os.listdir(wk_dir), desc='Orienting Dataset'):
        folder_path = os.path.join(wk_dir, folder)
        if not os.path.isdir(folder_path):
            continue
        # Trailing slash kept: it is passed verbatim to the script's -o option.
        out_folder = os.path.join(out_dir, folder + "/")
        os.makedirs(out_folder, exist_ok=True)

        pdb_names = [f for f in os.listdir(folder_path) if f.endswith(".pdb")]
        ref_pdb_files = [f for f in pdb_names if len(f.split("_")) == 2]
        if not ref_pdb_files:
            continue
        ref_pdb_file = ref_pdb_files[0]
        ref_pdb_path = os.path.join(folder_path, ref_pdb_file)

        for pdb_file in pdb_names:
            if pdb_file == ref_pdb_file:
                continue
            pos = pdb_file.split("_")[-2][1:-1]
            cmd1 = [sys.executable, path_to_ori_script, '-i', os.path.join(folder_path, pdb_file),
                    '-o', out_folder, '--mut_pos', pos, '-fl', '0']
            cmd2 = [sys.executable, path_to_ori_script, '-i', ref_pdb_path,
                    '-o', out_folder, '--mut_pos', pos, '-fl', f'{pos}_wt']

            with open(log, 'a') as logfile:
                for cmd in (cmd1, cmd2):
                    result = subprocess.run(cmd, stdout=logfile, stderr=subprocess.STDOUT)
                    if result.returncode != 0:
                        logging.warning(f"Orientation failed (exit {result.returncode}): {' '.join(cmd)}")
                

def feature_calc(ptc, ptofeat):
    """
    Calculate oriented voxel-based features for protein structures.

    Parameters:
    - ptc (str): Directory of oriented structures, one folder per chain.
      Mutants are files containing "_relaxed_0_oriented.pdb". The matching
      wild type is `<folder>_relaxed_<pos>_wt_oriented.pdb`.
    - ptofeat (str): Feature root; every existing sub-directory of it gets a
      `<folder>` sub-folder before the script runs.

    Operation:
    - The mutation position is parsed from the mutant file name
      (`<chain>_relaxed_D234K_relaxed_0_oriented.pdb` -> `234`).
    - `path_to_ft_calc_script` is run with the notebook's interpreter
      (sys.executable) and list arguments (no shell), so paths with spaces
      are safe.
    - Output is appended to `<path_to_ds>/feature_calc_ori.log`. A non-zero
      exit status is logged as a warning.
    - Missing input files are logged as warnings.

    Returns:
    - list[str]: folders whose processing raised an exception (also logged).
    """
    bad_fold = []
    log_file_path = os.path.join(path_to_ds, 'feature_calc_ori.log')
    for folder in tqdm(os.listdir(ptc), desc='Calculating features'):
        try:
            for ff in os.listdir(ptofeat):
                os.makedirs(os.path.join(ptofeat, ff, folder), exist_ok=True)

            folder_path = os.path.join(ptc, folder)
            mut_prots = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if "_relaxed_0_oriented.pdb" in f]

            for path_to_mut in mut_prots:
                pos = os.path.basename(path_to_mut).split("_")[2][1:-1]
                path_to_wt = os.path.join(folder_path, f"{folder}_relaxed_{pos}_wt_oriented.pdb")
                if not (os.path.exists(path_to_mut) and os.path.exists(path_to_wt)):
                    logging.warning(f"Files not found: {path_to_mut}, {path_to_wt}")
                    continue
                cmd = [sys.executable, path_to_ft_calc_script,
                       '-iwt', path_to_wt, '-imut', path_to_mut,
                       '-o', ptofeat + '/', '--boxsize', '16', '--voxelsize', '1']
                with open(log_file_path, 'a') as logfile:
                    result = subprocess.run(cmd, stdout=logfile, stderr=subprocess.STDOUT)
                if result.returncode != 0:
                    logging.warning(f"Feature script failed (exit {result.returncode}) for {path_to_mut}")
        except Exception as e:
            logging.error(f"Error processing folder {folder}: {e}")
            bad_fold.append(folder)
    return bad_fold
                

def load_dataset_dir(evaluation_dataset_path, evaluation_features_dir_dir):
    """
    Load the direct-mutation dataset together with its voxel features.

    Parameters:
    - evaluation_dataset_path (str): CSV with at least the columns `pdb_id`,
      `position`, `wild_type`, `mutant` (and `ddg`, used by callers).
    - evaluation_features_dir_dir (str): Directory holding
      `<pdb_id>/<pdb_id>_<wild_type><position><mutant>.npy` feature files.

    Returns:
    - pandas.DataFrame (an independent copy): only rows whose feature file
      exists. The `features` column holds each loaded array transposed with
      axes (1, 2, 3, 0), i.e. the first axis moved last.
    """
    logging.info("Loading dataset direct mutations")
    df = pd.read_csv(evaluation_dataset_path)
    logging.info(f'Total unique mutations: {len(df)}')

    # A list comprehension (rather than df.apply(axis=1)) also behaves
    # correctly for an empty CSV, where apply returns a DataFrame.
    feature_paths = [
        os.path.join(evaluation_features_dir_dir,
                     f"{r.pdb_id}/{r.pdb_id}_{r.wild_type}{r.position}{r.mutant}.npy")
        for r in df.itertuples(index=False)
    ]
    df['features'] = feature_paths
    # Explicit copy so the assignment below does not hit a filtered view
    # (SettingWithCopyWarning).
    df = df.loc[[os.path.exists(p) for p in feature_paths]].copy()
    logging.info(f'Total mutations with features: {len(df)}')
    df['features'] = [np.transpose(np.load(f), (1, 2, 3, 0))
                      for f in tqdm(df['features'], desc="Loading features")]
    logging.info(f'Total mutations after filtering: {len(df)}')
    return df
    

def load_dataset_rev(evaluation_dataset_path, evaluation_features_dir_rev):
    """
    Load the reverse-mutation dataset together with its voxel features.

    Parameters:
    - evaluation_dataset_path (str): CSV with at least the columns `pdb_id`,
      `position`, `wild_type`, `mutant` and `ddg`.
    - evaluation_features_dir_rev (str): Directory holding
      `<pdb_id>/<pdb_id>_<wild_type><position><mutant>.npy` reverse features.

    Returns:
    - pandas.DataFrame (an independent copy) with `ddg` negated. Reverting a
      mutation flips the sign of its ddG. Only rows whose feature file exists
      are kept, and `features` holds each loaded array transposed with axes
      (1, 2, 3, 0).
    """
    logging.info('Loading dataset reverse mutations')
    df_rev = pd.read_csv(evaluation_dataset_path)
    df_rev['ddg'] = -df_rev['ddg']

    # A list comprehension (rather than df.apply(axis=1)) also behaves
    # correctly for an empty CSV, where apply returns a DataFrame.
    feature_paths = [
        os.path.join(evaluation_features_dir_rev,
                     f"{r.pdb_id}/{r.pdb_id}_{r.wild_type}{r.position}{r.mutant}.npy")
        for r in df_rev.itertuples(index=False)
    ]
    df_rev['features'] = feature_paths
    # Explicit copy so the assignment below does not hit a filtered view
    # (SettingWithCopyWarning).
    df_rev = df_rev.loc[[os.path.exists(p) for p in feature_paths]].copy()
    logging.info(f'Total mutations with features: {len(df_rev)}')
    df_rev['features'] = [np.transpose(np.load(f), (1, 2, 3, 0))
                          for f in tqdm(df_rev['features'], desc="Loading features")]
    logging.info(f'Total mutations after filtering: {len(df_rev)}')
    return df_rev
    

def save_dataset_npy():
    """
    Export the direct and reverse datasets as NumPy arrays.

    Uses the module-level `evaluation_dataset_path`,
    `evaluation_features_dir_dir`, `evaluation_features_dir_rev`, `ds` and
    `dataset_name`. For each direction it writes
    `<ds>/<dataset_name>_X_<direction>.npy` (stacked feature arrays) and
    `<ds>/<dataset_name>_y_<direction>.npy` (ddg values). The direct files are
    written first, then the reverse ones.
    """
    splits = {
        "direct": load_dataset_dir(evaluation_dataset_path, evaluation_features_dir_dir),
        "reverse": load_dataset_rev(evaluation_dataset_path, evaluation_features_dir_rev),
    }
    for direction, frame in splits.items():
        inputs = np.array(frame.features.to_list())
        targets = frame.ddg.to_numpy()
        np.save(os.path.join(ds, f"{dataset_name}_X_{direction}.npy"), inputs)
        np.save(os.path.join(ds, f"{dataset_name}_y_{direction}.npy"), targets)

In [None]:
# preprocess the input df
# init_df = pd.read_csv("/datasets/path_to_your_csv_dataset/")  # NOTE: columns in dataset - "pdb", "pos", "wt", "mut", "ddg"
# Load the mutation table and rename its columns to the short names read by
# create_mut_df() ("pdb", "pos", "wt", "mut"); the "ddg" entry is a no-op rename.
# NOTE(review): this is a relative path, resolved against the current working
# directory rather than an absolute path as the header recommends.
init_df = pd.read_csv("datasets/Ssym.csv").rename(columns={
    "pdb_id": "pdb",
    "position": "pos",
    "wild_type": "wt",
    "mutant": "mut",
    "ddg": "ddg"
})
init_df  # display the loaded table in the notebook output

In [None]:
# Run the full preprocessing pipeline, in order:
# download -> split chains -> filter -> strip HETATM -> relax WT -> relax mutants
# -> collect structures -> non-oriented features -> orient -> oriented features
make_dirs()

# Four-letter PDB ids to download; AlphaFold ("AF-") entries are excluded.
un_pdb_dat = list(set([f[0:4] for f in init_df["pdb"].unique().tolist() if "AF-" not in f]))
# Chain identifiers used by the dataset (read by filter_chains()).
un_pdb_chain_dat = [f for f in init_df["pdb"].unique().tolist() if "AF-" not in f]

download_pdbs_list(un_pdb_dat, path_to_dw_add_pdbs)

df_chains = pdbs_to_chains(path_to_dw_add_pdbs, path_to_chains)
filter_chains(path_to_chains, path_to_pdb_chains_filtered)
rem_hetatm(path_to_pdb_chains_filtered, path_to_save_rem_hetatm)


# A chain counts as relaxed only if its relaxed structure actually exists.
# Previously any folder under relaxed_chains counted. make_chain_dirs()
# creates a folder for every chain even when relaxation failed, so failed
# chains were never retried, and every rerun produced an empty list.
relaxed_chains = {
    chain for chain in un_pdb_chain_dat
    if os.path.exists(os.path.join(path_to_relaxed_chains, chain, f"{chain}_relaxed.pdb"))
}
non_relaxed_chains = list(set(un_pdb_chain_dat) - relaxed_chains)

run_relax()


make_chain_dirs(path_to_relaxed_chains, un_pdb_chain_dat)
copy_relaxed(path_to_ds)
ds_mut = create_mut_df()

run_rosetta_for_mutants()
copy_pdb_files_and_directories([path_to_relaxed_chains, rosetta_out], relaxed_chains_total)
feature_calc_nonori(relaxed_chains_total,features_ds_nonori)
orient_dataset(relaxed_chains_total, relaxed_chains_total_ori)
feature_calc(relaxed_chains_total_ori,features_ds)

In [None]:
####################################### specify here
# Feature variant to export: "defdif", "def" or "dif" (the sub-directory
# families created by make_dirs()).
feature_type = "defdif"
# evaluation_dataset_path = '/path_to_your_csv_dataset/' # NOTE: columns in dataset - "pdb_id","position","wild_type","mutant","ddg"
evaluation_dataset_path = "datasets/Ssym.csv"

#NOTE: here dataset should have the following column names "pdb_id","position","wild_type","mutant","ddg"
#renaming example
#init_df2 = pd.read_csv("ThermoNet/data/datasets/p53.txt", sep = " ", names=["pdb_id","position","wild_type","mutant","ddg"])
#init_df2.to_csv("datasets/p53_dataset.csv")
############################################

# Oriented-feature directories for the chosen variant (under features_ds).
evaluation_features_dir_dir = features_ds+f"/{dataset_name}_{feature_type}_direct/"
evaluation_features_dir_rev = features_ds+f"/{dataset_name}_{feature_type}_reverse/"

# Writes <dataset_name>_{X,y}_{direct,reverse}.npy into the `ds` directory.
save_dataset_npy()