# In [None]:
import os
import subprocess
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import torch
import pandas as pd
from Bio import PDB, SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pydssp
import math

# === 1️⃣ Parallel PDB Downloading ===
def download_pdb(pdb_code, save_dir):
    """Fetch a single PDB entry into ``save_dir`` using Biopython's PDBList.

    The entry is requested in legacy PDB format and ``overwrite=False`` is
    passed through to ``retrieve_pdb_file``. Nothing is returned.
    """
    PDB.PDBList().retrieve_pdb_file(
        pdb_code,
        pdir=save_dir,
        file_format="pdb",
        overwrite=False,
    )

def download_all_pdbs(pdb_codes, save_dir):
    """Download several PDB entries concurrently into ``save_dir``.

    Args:
        pdb_codes: Iterable of PDB identifiers; each one is handed to
            ``download_pdb`` on a worker thread.
        save_dir: Destination directory; created if it does not exist.

    A download that raises is reported as a warning on stdout instead of
    being silently discarded; the remaining downloads are unaffected.
    """
    os.makedirs(save_dir, exist_ok=True)

    with ThreadPoolExecutor() as executor:
        # A list (not a dict) so duplicate codes keep their own future.
        submitted = [
            (pdb_code, executor.submit(download_pdb, pdb_code, save_dir))
            for pdb_code in pdb_codes
        ]

    # Leaving the with-block has already waited for every worker, so the
    # futures are all finished here; inspect them to surface failures.
    for pdb_code, future in submitted:
        error = future.exception()
        if error is not None:
            print(f"Warning: Could not download {pdb_code}: {error}")


# === 2️⃣ Amino Acid Conversion ===
# Three-letter residue name -> one-letter code for the 20 standard amino acids.
AA_DICT = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
}


def three_to_one(resname):
    """Map a 3-letter residue name to its 1-letter code.

    The lookup is an exact (case-sensitive) match against ``AA_DICT``;
    any name not present there yields ``"X"``.
    """
    try:
        return AA_DICT[resname]
    except KeyError:
        return "X"


# === 3️⃣ Extract FASTA from PDB ===
def extract_pdb_fasta(pdb_code, pdb_dir, chain_id, fasta_dir):
    """Write the sequence of one chain of a PDB entry to a FASTA file.

    Args:
        pdb_code: PDB identifier; the structure is read from
            ``<pdb_dir>/pdb<pdb_code>.ent``.
        pdb_dir: Directory holding the downloaded ``.ent`` files.
        chain_id: Chain to extract from the first model.
        fasta_dir: Output directory (created if missing); the record is
            written to ``<fasta_dir>/<pdb_code>.fasta`` with id
            ``<pdb_code>_<chain_id>``.

    Returns:
        A list with ``residue.id[1]`` for every amino-acid residue of the
        chain, in chain order, so that element ``k`` corresponds to FASTA
        position ``k + 1``. If the chain is absent from the first model the
        list is empty and an empty sequence is written.
    """
    pdb_filepath = os.path.join(pdb_dir, f"pdb{pdb_code}.ent")
    fasta_filepath = os.path.join(fasta_dir, f"{pdb_code}.fasta")

    # The other output directories in this file are created on demand;
    # do the same here so the write below cannot fail on a fresh setup.
    os.makedirs(fasta_dir, exist_ok=True)

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, pdb_filepath)

    # Only the first model is processed.
    model = structure[0]
    sequence, observed_residues = [], []

    if chain_id in model:
        for residue in model[chain_id].get_residues():
            if PDB.is_aa(residue):
                # Residue names missing from AA_DICT come back as "X".
                sequence.append(three_to_one(residue.get_resname()))
                observed_residues.append(residue.id[1])

    fasta_seq = SeqRecord(
        Seq("".join(sequence)),
        id=f"{pdb_code}_{chain_id}",
        description="",
    )
    SeqIO.write(fasta_seq, fasta_filepath, "fasta")

    return observed_residues


# === 4️⃣ Run DeepTMHMM Asynchronously ===
def run_deeptmhmm(pdb_code, fasta_filepath, results_dir):
    """Start a DeepTMHMM run for one entry without waiting for it.

    Args:
        pdb_code: Identifier used to name the per-entry working directory
            ``<results_dir>/<pdb_code>`` (created if missing).
        fasta_filepath: FASTA file to submit. It is converted to an absolute
            path because the child process runs with a different working
            directory, where a relative path would no longer resolve.
        results_dir: Parent directory for all per-entry working directories.

    Returns:
        The ``subprocess.Popen`` handle of the launched ``biolib`` command,
        with stdout and stderr captured as text pipes. The caller is
        responsible for draining them (e.g. via ``communicate()``).
    """
    pdb_results_dir = os.path.join(results_dir, pdb_code)
    os.makedirs(pdb_results_dir, exist_ok=True)

    # Resolve against the caller's cwd *before* the child switches to
    # pdb_results_dir.
    fasta_filepath = os.path.abspath(fasta_filepath)

    process = subprocess.Popen(
        ["biolib", "run", "DTU/DeepTMHMM", "--fasta", fasta_filepath],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=pdb_results_dir,  # the rest of this file looks for biolib_results/ here
        start_new_session=True,  # Allows independent execution
    )
    return process


def run_all_deeptmhmm(pdb_codes, fasta_dir, results_dir):
    """Run DeepTMHMM for several entries concurrently and wait for all of them.

    Args:
        pdb_codes: Iterable of PDB identifiers; the FASTA file for each is
            expected at ``<fasta_dir>/<pdb_code>.fasta``.
        fasta_dir: Directory containing the FASTA files.
        results_dir: Parent directory for per-entry results (created if
            missing).

    Every child process is started first, then each one is waited on in
    turn. Captured stdout/stderr are printed per entry, and a warning is
    printed when a run exits with a non-zero status so that failed
    predictions are not mistaken for successful ones.
    """
    os.makedirs(results_dir, exist_ok=True)

    # run_deeptmhmm returns as soon as the child is spawned, so the children
    # already execute in parallel; no thread pool is needed to launch them.
    # A list (not a dict) keeps every handle even if a code is repeated.
    launched = [
        (
            pdb_code,
            run_deeptmhmm(
                pdb_code,
                os.path.join(fasta_dir, f"{pdb_code}.fasta"),
                results_dir,
            ),
        )
        for pdb_code in pdb_codes
    ]

    # Wait for all processes to complete
    for pdb_code, process in launched:
        stdout, stderr = process.communicate()
        print(f"[{pdb_code}] STDOUT:", stdout)
        print(f"[{pdb_code}] STDERR:", stderr)
        if process.returncode != 0:
            print(
                f"Warning: DeepTMHMM exited with code {process.returncode} "
                f"for {pdb_code}"
            )


def keep_only_tmr(results_dir, pdb_codes):
    """Prune each entry's ``biolib_results`` folder down to ``TMRs.gff3``.

    For every code, all entries of ``<results_dir>/<code>/biolib_results``
    other than ``TMRs.gff3`` are passed to ``os.remove``. Anything that
    cannot be removed is reported with a warning and left in place.
    Codes whose results folder does not exist are skipped.
    """
    for code in pdb_codes:
        output_dir = os.path.join(results_dir, code, "biolib_results")
        if not os.path.exists(output_dir):
            continue

        unwanted = [name for name in os.listdir(output_dir) if name != "TMRs.gff3"]
        for name in unwanted:
            file_path = os.path.join(output_dir, name)
            try:
                os.remove(file_path)
            except Exception as e:
                print(f"Warning: Could not remove {file_path}: {e}")


def extract_coordinates(pdb_code, pdb_dir, chain_id='A'):
    """Return the backbone coordinates of one chain as a ``[L, 4, 3]`` tensor.

    Args:
        pdb_code: PDB identifier; the structure is read from
            ``<pdb_dir>/pdb<pdb_code>.ent``.
        pdb_dir: Directory holding the downloaded ``.ent`` files.
        chain_id: Chain to read from the first model.

    Returns:
        A ``torch.float32`` tensor of shape ``[L, 4, 3]`` where ``L`` is the
        number of amino-acid residues in the chain, the second axis is the
        backbone atoms in the fixed order N, CA, C, O, and the last axis is
        x, y, z.

    Raises:
        KeyError: If ``chain_id`` is not present in the first model.
        ValueError: If an amino-acid residue lacks one of the four backbone
            atoms. Dropping such a residue would shift every later index
            relative to the chain's sequence, so it is reported instead.
    """
    pdb_filepath = os.path.join(pdb_dir, f"pdb{pdb_code}.ent")
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, pdb_filepath)

    # Select the first model (as a default)
    chain = structure[0][chain_id]

    backbone_names = ("N", "CA", "C", "O")
    backbone = []
    for residue in chain:
        if not PDB.is_aa(residue):
            continue
        missing = [name for name in backbone_names if name not in residue]
        if missing:
            raise ValueError(
                f"Residue {residue.get_resname()} {residue.id[1]} in chain "
                f"{chain_id} of {pdb_code} is missing backbone atom(s): "
                f"{', '.join(missing)}"
            )
        # Look atoms up by name so the order is always N, CA, C, O,
        # independent of the order in which the file lists them.
        backbone.append([residue[name].coord for name in backbone_names])

    # reshape(-1, 4, 3) also gives a well-formed [0, 4, 3] array for a chain
    # without amino-acid residues.
    coord_array = np.asarray(backbone, dtype=np.float32).reshape(-1, 4, 3)
    coord_tensor = torch.tensor(coord_array, dtype=torch.float32)

    return coord_tensor


# === 5️⃣ TMH Extension Processing ===
def calculate_desired_extensions(tmh_ranges, ss_data, max_extend=9):
    """Count how far each TMH could grow along uninterrupted 'H' positions.

    ``tmh_ranges`` holds ``(start, end)`` pairs that are used directly as
    indices into ``ss_data``. For each pair, the run of consecutive ``"H"``
    entries immediately before ``start`` and immediately after ``end`` is
    measured, each capped at ``max_extend`` and stopped at the first
    non-``"H"`` entry or at the edge of ``ss_data``.

    Returns a list of ``(backward, forward)`` counts, one per input range.
    """
    wishes = []

    for first, last in tmh_ranges:
        # Walk towards index 0, starting just before the helix.
        upstream = 0
        for idx in range(first - 1, first - max_extend - 1, -1):
            if idx < 0 or ss_data[idx] != "H":
                break
            upstream += 1

        # Walk towards the end of ss_data, starting just after the helix.
        downstream = 0
        for idx in range(last + 1, last + max_extend + 1):
            if idx >= len(ss_data) or ss_data[idx] != "H":
                break
            downstream += 1

        wishes.append((upstream, downstream))

    return wishes




def calculate_available_spaces(tmh_ranges):
    """Return the gap size between each pair of neighbouring TMHs.

    For consecutive ranges ``(_, prev_end)`` and ``(next_start, _)`` the gap
    is ``next_start - prev_end - 1``, i.e. the positions strictly between
    the two helices. The result has one entry fewer than ``tmh_ranges``
    (and is empty for zero or one range).
    """
    return [
        following[0] - current[1] - 1
        for current, following in zip(tmh_ranges, tmh_ranges[1:])
    ]


def reorder_gpcr_tmh_ends(tmh_extended_pairs):
    """Flatten seven ``(start, end)`` TMH pairs into 14 ends in a fixed order.

    The order follows a hard-coded extra/intra label pattern: position ``p``
    of the output is taken from pair ``p // 2``, using the pair's first
    element where the label is "extra" and its second element where the
    label is "intra". Pairs beyond the seventh are ignored.
    """
    labels = (
        "extra", "intra",
        "intra", "extra",
        "extra", "intra",
        "intra", "extra",
        "extra", "intra",
        "intra", "extra",
        "extra", "intra",
    )
    # Which element of the (start, end) pair each label selects.
    slot_for = {"extra": 0, "intra": 1}

    return [
        tmh_extended_pairs[position // 2][slot_for[label]]
        for position, label in enumerate(labels)
    ]


# === 6️⃣ Define Directories and Execute ===
def _parse_tmh_ranges(tmh_result_file):
    """Read the (start, end) columns of every TMhelix line in a TMRs.gff3 file."""
    tmh_ranges = []
    with open(tmh_result_file) as file:
        for line in file:
            if "TMhelix" in line:
                parts = line.strip().split("\t")
                tmh_ranges.append((int(parts[2]), int(parts[3])))
    return tmh_ranges


def _fit_extensions_to_gaps(desired_extensions, available_spaces):
    """Shrink neighbouring extensions so that together they fit the gap between two TMHs."""
    fitted = list(desired_extensions)
    for i, available in enumerate(available_spaces):
        forward_of_left = fitted[i][1]
        backward_of_right = fitted[i + 1][0]
        if forward_of_left + backward_of_right <= available:
            continue

        half_space = math.floor(available / 2)
        if half_space >= forward_of_left:
            # Left helix fits in its half: the right helix gets the remainder.
            fitted[i + 1] = (available - forward_of_left, fitted[i + 1][1])
        elif half_space >= backward_of_right:
            # Right helix fits in its half: the left helix gets the remainder.
            fitted[i] = (fitted[i][0], available - backward_of_right)
        else:
            # Both want more than half: split the gap evenly.
            fitted[i] = (fitted[i][0], half_space)
            fitted[i + 1] = (half_space, fitted[i + 1][1])
    return fitted


def main():
    """Run the full pipeline for a fixed list of GPCR structures.

    Steps: download the PDB files, write chain-A FASTA files, run DeepTMHMM,
    prune its output to ``TMRs.gff3``, then for each entry extend every
    predicted TM helix along neighbouring helical residues (from pydssp),
    map the extended ends back to PDB residue numbers and print them in the
    order produced by ``reorder_gpcr_tmh_ends``.

    Entries without a DeepTMHMM result file, or with fewer helices than
    ``reorder_gpcr_tmh_ends`` needs, are skipped with a warning so the
    remaining entries are still processed.
    """
    pdb_codes = ["6e67", "4gbr", "5d5a", "7dhi"]
    pdb_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files"
    fasta_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\Fasta_files"
    results_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\DeepTMHMM_results"

    chain_id = "A"
    max_extension = 9
    expected_helices = 7  # reorder_gpcr_tmh_ends indexes pairs 0..6

    # Download PDB files in parallel
    download_all_pdbs(pdb_codes, pdb_dir)

    # Extract sequences; element k of each list is the PDB residue number
    # of FASTA position k + 1.
    all_observed_residues = {
        pdb_code: extract_pdb_fasta(pdb_code, pdb_dir, chain_id, fasta_dir)
        for pdb_code in pdb_codes
    }

    # Run DeepTMHMM in parallel asynchronously
    run_all_deeptmhmm(pdb_codes, fasta_dir, results_dir)
    keep_only_tmr(results_dir, pdb_codes)

    for pdb_code in pdb_codes:
        tmh_result_file = os.path.join(results_dir, pdb_code, "biolib_results", "TMRs.gff3")
        if not os.path.exists(tmh_result_file):
            print(f"Warning: No DeepTMHMM result for {pdb_code}: {tmh_result_file}")
            continue

        # Positions in the result file are 1-based (they are mapped to list
        # indices with "- 1" further down).
        tmh_ranges = _parse_tmh_ranges(tmh_result_file)
        if len(tmh_ranges) < expected_helices:
            print(
                f"Warning: Expected {expected_helices} TM helices for {pdb_code}, "
                f"found {len(tmh_ranges)}; skipping"
            )
            continue

        coord_tensor = extract_coordinates(pdb_code, pdb_dir, chain_id)

        # Per-residue secondary structure from pydssp (three-state output).
        ss_data = pydssp.assign(coord_tensor, out_type='c3')

        # calculate_desired_extensions uses its ranges directly as indices
        # into ss_data, so hand it 0-based positions. Passing the 1-based
        # values would test the helix's own first residue, skip the residue
        # right after its end, and could extend to position 0.
        zero_based_ranges = [(start - 1, end - 1) for start, end in tmh_ranges]
        desired_extensions = calculate_desired_extensions(
            zero_based_ranges, ss_data, max_extend=max_extension
        )

        # Gap sizes are differences, so they are the same in either base.
        available_spaces = calculate_available_spaces(tmh_ranges)
        extensions = _fit_extensions_to_gaps(desired_extensions, available_spaces)

        extended_tmh_ranges = [
            (start - left, end + right)
            for (start, end), (left, right) in zip(tmh_ranges, extensions)
        ]

        # Map 1-based sequence positions back to PDB residue numbers.
        observed = all_observed_residues[pdb_code]
        tmh_extended_pairs = [
            (observed[start - 1], observed[end - 1])
            for start, end in extended_tmh_ranges
        ]
        reordered_tmh_ends = reorder_gpcr_tmh_ends(tmh_extended_pairs)

        print(f"Reordered TMH ends for {pdb_code}: {reordered_tmh_ends}")


if __name__ == "__main__":
    main()
