In [None]:
import os
import subprocess
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import torch
import pandas as pd
from Bio import PDB, SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pydssp
import math

# PDB entry codes to run through the pipeline below (empty until filled in).
pdb_codes = []

# === 1️⃣ Parallel PDB Downloading ===
def download_pdb(pdb_code, save_dir):
    """Fetch a single PDB entry into ``save_dir`` in legacy PDB format.

    Called once per code by ``download_all_pdbs``; ``overwrite=False`` is
    passed so an existing local copy is not replaced.
    """
    fetcher = PDB.PDBList()
    fetcher.retrieve_pdb_file(
        pdb_code,
        pdir=save_dir,
        file_format="pdb",
        overwrite=False,
    )

def download_all_pdbs(pdb_codes, save_dir):
    """Download several PDB files concurrently into ``save_dir``.

    Args:
        pdb_codes: Iterable of PDB entry codes to fetch.
        save_dir: Destination folder; created if it does not exist.

    A failure in one download does not stop the others: it is reported as a
    warning on stdout.
    """
    os.makedirs(save_dir, exist_ok=True)
    with ThreadPoolExecutor() as executor:
        submitted = [
            (pdb_code, executor.submit(download_pdb, pdb_code, save_dir))
            for pdb_code in pdb_codes
        ]
        # Futures must be inspected explicitly; an exception raised inside a
        # worker is otherwise stored on the future and never surfaces.
        for pdb_code, future in submitted:
            try:
                future.result()
            except Exception as e:
                print(f"Warning: Could not download {pdb_code}: {e}")


# === 2️⃣ Amino Acid Conversion ===
# Three-letter -> one-letter codes for the 20 standard amino acids.
AA_DICT = dict(zip(
    (
        "ALA ARG ASN ASP CYS "
        "GLN GLU GLY HIS ILE "
        "LEU LYS MET PHE PRO "
        "SER THR TRP TYR VAL"
    ).split(),
    "ARNDCQEGHILKMFPSTWYV",
))


def three_to_one(resname):
    """Return the 1-letter code for a 3-letter residue name, or 'X' if unknown."""
    return AA_DICT[resname] if resname in AA_DICT else "X"


# === 3️⃣ Extract FASTA from PDB ===
def extract_pdb_fasta(pdb_code, pdb_dir, chain_id, fasta_dir):
    """Write the sequence of one chain of a PDB file as FASTA.

    Args:
        pdb_code: PDB entry code; the structure is read from
            ``<pdb_dir>/pdb<pdb_code>.ent``.
        pdb_dir: Folder holding the downloaded PDB files.
        chain_id: Identifier of the chain to extract.
        fasta_dir: Folder that receives ``<pdb_code>.fasta``; created if
            it does not exist.

    Returns:
        List of residue sequence numbers (``residue.id[1]``), one per amino
        acid written to the FASTA record and in the same order. Empty when
        the chain is not present in the first model.
    """
    pdb_filepath = os.path.join(pdb_dir, f"pdb{pdb_code}.ent")
    fasta_filepath = os.path.join(fasta_dir, f"{pdb_code}.fasta")

    # The output folder may not exist yet on a first run.
    os.makedirs(fasta_dir, exist_ok=True)

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, pdb_filepath)

    # Only the first model is processed
    model = structure[0]
    sequence, observed_residues = [], []

    if chain_id in model:
        for residue in model[chain_id]:
            if PDB.is_aa(residue):
                sequence.append(three_to_one(residue.get_resname()))
                observed_residues.append(residue.id[1])
    else:
        # Make the otherwise silent empty-sequence case visible
        print(f"Warning: chain {chain_id} not found in {pdb_code}; writing an empty FASTA record.")

    # Write FASTA
    fasta_seq = SeqRecord(Seq("".join(sequence)), id=f"{pdb_code}_{chain_id}", description="")
    SeqIO.write(fasta_seq, fasta_filepath, "fasta")

    return observed_residues

def run_deeptmhmm(pdb_code, fasta_filepath_wsl):
    """Run DeepTMHMM (via biolib inside WSL) on one FASTA file.

    The command is executed with a per-structure results folder as its
    working directory; that folder is created if needed.

    Args:
        pdb_code: PDB entry code, used to name the results folder.
        fasta_filepath_wsl: Path to the FASTA file as passed to the WSL
            command line.

    Returns:
        ``pdb_code``, unchanged. A non-zero exit status is reported as a
        warning on stdout rather than raised.
    """
    pdb_results_dir = f"/Users/Student/OneDrive - Aston University/Documents/Biology/Project/Project_automation/Python/DeepTMHMM_results/{pdb_code}"
    os.makedirs(pdb_results_dir, exist_ok=True)

    process = subprocess.run(
        ["wsl", "/home/dan/.local/bin/biolib", "run", "--local", "DTU/DeepTMHMM:1.0.24", "--fasta", fasta_filepath_wsl],
        text=True,
        capture_output=True,
        cwd=pdb_results_dir
    )

    # Output is captured, so a failed run would otherwise go unnoticed until
    # the results file turns out to be missing later in the pipeline.
    if process.returncode != 0:
        print(
            f"Warning: DeepTMHMM failed for {pdb_code} "
            f"(exit code {process.returncode}): {process.stderr.strip()}"
        )

    return pdb_code

def keep_only_tmr(results_dir, pdb_codes):
    """Delete everything but ``TMRs.gff3`` from each ``biolib_results`` folder.

    For every code, ``<results_dir>/<code>/biolib_results`` is scanned if it
    exists. Entries that cannot be removed are reported with a warning and
    left in place.
    """
    wanted = "TMRs.gff3"

    for code in pdb_codes:
        folder = os.path.join(results_dir, code, "biolib_results")
        if not os.path.exists(folder):
            continue

        for entry in os.listdir(folder):
            if entry == wanted:
                continue

            file_path = os.path.join(folder, entry)
            try:
                os.remove(file_path)
            except Exception as e:
                print(f"Warning: Could not remove {file_path}: {e}")


def extract_coordinates(pdb_code, pdb_dir, chain_id='A'):
    """Collect backbone coordinates of one chain as a tensor.

    Args:
        pdb_code: PDB entry code; the structure is read from
            ``<pdb_dir>/pdb<pdb_code>.ent``.
        pdb_dir: Folder holding the downloaded PDB files.
        chain_id: Identifier of the chain to read from the first model.

    Returns:
        ``torch.float32`` tensor of shape ``(L, 4, 3)``, where ``L`` is the
        number of amino-acid residues in the chain and the second axis is
        ordered N, CA, C, O.

    Raises:
        KeyError: If the chain is not present in the first model.
        ValueError: If an amino-acid residue lacks one of the four backbone
            atoms, since the tensor could then not keep one row per residue.
    """
    backbone_atoms = ("N", "CA", "C", "O")

    pdb_filepath = os.path.join(pdb_dir, f"pdb{pdb_code}.ent")
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, pdb_filepath)

    # Select the first model (as a default)
    chain = structure[0][chain_id]

    coordinates = []
    for residue in chain:
        if not PDB.is_aa(residue):
            continue

        missing = [name for name in backbone_atoms if name not in residue]
        if missing:
            raise ValueError(
                f"{pdb_code} chain {chain_id} residue {residue.get_resname()}"
                f"{residue.id[1]} is missing backbone atom(s): {', '.join(missing)}"
            )

        # Look atoms up by name so the N, CA, C, O order never depends on the
        # order in which the file happens to list them.
        coordinates.append([residue[name].coord for name in backbone_atoms])

    coord_array = np.array(coordinates, dtype=np.float32).reshape(
        len(coordinates), len(backbone_atoms), 3
    )
    return torch.tensor(coord_array, dtype=torch.float32)


# === 5️⃣ TMH Extension Processing ===
def calculate_desired_extensions(tmh_ranges, ss_data, max_extend=9):
    """Count how far each TMH could grow into flanking helical residues.

    Args:
        tmh_ranges: Sequence of ``(start, end)`` pairs using 1-based residue
            positions. A pair containing NaN marks a missing helix.
        ss_data: Per-residue secondary-structure symbols indexed from 0, so
            the residue at 1-based position ``p`` is ``ss_data[p - 1]``.
        max_extend: Maximum number of residues to extend in each direction.

    Returns:
        List of ``(backward, forward)`` counts, one per input pair: the number
        of consecutive ``"H"`` residues immediately before ``start`` and
        immediately after ``end``. Missing helices yield ``(0, 0)``.
    """
    n_residues = len(ss_data)
    desired_extensions = []

    for start, end in tmh_ranges:
        if np.isnan(start) or np.isnan(end):
            # Skip if start or end are NaN (missing)
            desired_extensions.append((0, 0))
            continue

        # Convert the 1-based helix boundaries to 0-based ss_data indices
        first_idx = int(start) - 1
        last_idx = int(end) - 1

        # Backward extension: first residue examined is the one *before* start
        backward_extension = 0
        for step in range(1, max_extend + 1):
            ss_index = first_idx - step
            if 0 <= ss_index < n_residues and ss_data[ss_index] == "H":
                backward_extension += 1
            else:
                break  # Stop at first non-'H' or out-of-bounds

        # Forward extension: first residue examined is the one *after* end
        forward_extension = 0
        for step in range(1, max_extend + 1):
            ss_index = last_idx + step
            if 0 <= ss_index < n_residues and ss_data[ss_index] == "H":
                forward_extension += 1
            else:
                break  # Stop at first non-'H' or out-of-bounds

        desired_extensions.append((backward_extension, forward_extension))

    return desired_extensions
    
def calculate_available_spaces(tmh_ranges):
    """Return the residue count lying strictly between each pair of neighbouring TMHs.

    For ``n`` ranges the result has ``n - 1`` entries; entry ``i`` is the gap
    between the end of range ``i`` and the start of range ``i + 1``.
    """
    return [
        following[0] - current[1] - 1
        for current, following in zip(tmh_ranges, tmh_ranges[1:])
    ]


def reorder_gpcr_tmh_ends(tmh_extended_pairs):
    """Flatten seven (start, end) TMH pairs into 14 ends following a fixed side pattern.

    Each pattern slot belongs to helix ``slot // 2``; an "extra" slot takes
    that helix's start residue and an "intra" slot takes its end residue.
    """
    pattern = ["extra", "intra", "intra", "extra", "extra", "intra", "intra",
               "extra", "extra", "intra", "intra", "extra", "extra", "intra"]

    # Position inside a (start, end) pair that each label selects
    pick = {"extra": 0, "intra": 1}

    return [
        tmh_extended_pairs[slot // 2][pick[label]]
        for slot, label in enumerate(pattern)
    ]

def main():
    """Run the full pipeline for every code in ``pdb_codes``.

    Steps: download the PDB files, write chain "A" of each as FASTA, run
    DeepTMHMM on each FASTA, read the predicted TM helices, extend them into
    flanking helical residues (limited by the gap to the neighbouring helix),
    translate the extended boundaries to PDB residue numbers and reorder them
    with ``reorder_gpcr_tmh_ends``.

    Returns:
        Dict mapping each PDB code to ``{chain_id: reordered_tmh_ends}``.
        Helices that were not detected contribute NaN entries.
    """
    pdb_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files"
    fasta_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\Fasta_files"
    results_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\DeepTMHMM_results"

    # Download PDB files
    download_all_pdbs(pdb_codes, pdb_dir)

    # Extract sequences
    chain_id = "A"
    all_observed_residues = {
        pdb_code: extract_pdb_fasta(pdb_code, pdb_dir, chain_id, fasta_dir) for pdb_code in pdb_codes
    }

    # Run DeepTMHMM one structure at a time (each call blocks until it finishes)
    for pdb_code in pdb_codes:
        fasta_filepath_wsl = f"/mnt/c/Users/Student/OneDrive - Aston University/Documents/Biology/Project/Project_automation/Python/Fasta_files/{pdb_code}.fasta"
        run_deeptmhmm(pdb_code, fasta_filepath_wsl)
    keep_only_tmr(results_dir, pdb_codes)

    pdb_data = {}

    for pdb_code in pdb_codes:
        tmh_ranges = []
        tmh_result_file = os.path.join(results_dir, pdb_code, "biolib_results", "TMRs.gff3")

        if os.path.exists(tmh_result_file):
            with open(tmh_result_file) as file:
                for line in file:
                    if "TMhelix" in line:
                        parts = line.strip().split("\t")
                        start, end = int(parts[2]), int(parts[3])
                        tmh_ranges.append((start, end))

        if len(tmh_ranges) < 7:
            print(f"⚠️ Warning: {pdb_code} has only {len(tmh_ranges)} TMHs detected.")
            # Pad to seven helices so the later reordering always has 7 pairs
            tmh_ranges += [(np.nan, np.nan)] * (7 - len(tmh_ranges))

        # Extract secondary structure
        coord_tensor = extract_coordinates(pdb_code, pdb_dir, chain_id)
        ss_data = pydssp.assign(coord_tensor, out_type='c3')

        # Compute desired extensions
        desired_extensions = calculate_desired_extensions(tmh_ranges, ss_data)
        available_spaces = calculate_available_spaces(tmh_ranges)
        max_extension = 9

        desired_extensions = [
            (min(start, max_extension), min(end, max_extension)) for start, end in desired_extensions
        ]

        # Align extensions with available spaces: when two neighbouring
        # helices together want more residues than the gap holds, shrink one
        # or both so they never overlap.
        desired_extensions_dict = dict(enumerate(desired_extensions))
        for i in range(len(available_spaces)):
            total_desired = desired_extensions_dict[i][1] + desired_extensions_dict[i + 1][0]
            available = available_spaces[i]

            if total_desired > available:
                half_space = math.floor(available / 2)
                if half_space >= desired_extensions_dict[i][1]:
                    desired_extensions_dict[i + 1] = (available - desired_extensions_dict[i][1], desired_extensions_dict[i + 1][1])
                elif half_space >= desired_extensions_dict[i + 1][0]:
                    desired_extensions_dict[i] = (desired_extensions_dict[i][0], available - desired_extensions_dict[i + 1][0])
                else:
                    desired_extensions_dict[i] = (desired_extensions_dict[i][0], half_space)
                    desired_extensions_dict[i + 1] = (half_space, desired_extensions_dict[i + 1][1])

        observed = all_observed_residues[pdb_code]

        extended_tmh_ranges = []
        for i, (start, end) in enumerate(tmh_ranges):
            left_extension = desired_extensions_dict[i][0]
            right_extension = desired_extensions_dict[i][1]

            if np.isnan(start) or np.isnan(end):
                new_start, new_end = np.nan, np.nan
            else:
                # Keep positions inside 1..len(observed): position 0 would
                # become list index -1 below and silently pick the last residue.
                new_start = max(1, start - left_extension)
                new_end = min(len(observed), end + right_extension)

            extended_tmh_ranges.append((new_start, new_end))

        # Translate 1-based sequence positions to PDB residue numbers
        tmh_extended_pairs = [
            (observed[int(start) - 1] if not np.isnan(start) else np.nan,
             observed[int(end) - 1] if not np.isnan(end) else np.nan)
            for start, end in extended_tmh_ranges
        ]

        # Reorder TMH ends
        reordered_tmh_ends = reorder_gpcr_tmh_ends(tmh_extended_pairs)
        print(f"Reordered TMH ends for {pdb_code}: {reordered_tmh_ends}")

        # Store reordered_tmh_ends in pdb_data
        pdb_data[pdb_code] = {chain_id: reordered_tmh_ends}

    print(pdb_data)  # Debugging step
    return pdb_data  # Return final dictionary

# 🔹 Run main() and store result in a variable
if __name__ == "__main__":
    # Bind the result at module level: the export code further down in this
    # file reads the global name ``pdb_data``.
    pdb_data = main()
    print("Final PDB Data:", pdb_data)  # Echo the dictionary returned by main()



# Create function to extract coordinates
def get_coords(pdb_id, chain_id, residues):
    """Build one flat output row of C-alpha coordinates for selected residues.

    The structure is downloaded into a temporary folder that is removed when
    the function returns.

    Args:
        pdb_id: PDB entry code (any case).
        chain_id: Identifier of the chain to read.
        residues: Iterable of residue identifiers to look up in the chain.

    Returns:
        List ``[PDB_ID, resolution, chain_id, label, x, y, z, ...]`` with one
        ``label, x, y, z`` group per requested residue, in input order.
        ``resolution`` is the numeric text from the ``REMARK   2 RESOLUTION``
        line, or ``"Unknown"`` when that line is absent or carries no number.
        A residue that cannot be found, or has no CA atom, contributes
        ``"Unknown<id>", "NA", "NA", "NA"``.
    """
    pdb_id = pdb_id.lower()
    pdbl = PDB.PDBList()

    # Create temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        pdb_file_path = pdbl.retrieve_pdb_file(pdb_id, pdir=temp_dir, file_format="pdb")

        # Parse the file
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure(pdb_id, pdb_file_path)

        # Extract resolution
        resolution = "Unknown"
        with open(pdb_file_path, "r") as f:
            for line in f:
                if line.startswith("REMARK   2 RESOLUTION"):
                    fields = line.split()
                    try:
                        float(fields[3])
                    except (IndexError, ValueError):
                        # e.g. "NOT APPLICABLE." -- leave as "Unknown"
                        pass
                    else:
                        resolution = fields[3]
                    break

        coords = [pdb_id.upper(), resolution, chain_id]

        # Loop through the requested residues to extract CA coordinates
        for residue_id in residues:
            found = False
            for model in structure:
                if chain_id not in model:
                    continue
                chain = model[chain_id]
                if residue_id not in chain:
                    continue

                residue = chain[residue_id]
                ca_atom = next(
                    (atom for atom in residue if atom.get_name() == "CA"), None
                )
                if ca_atom is None:
                    continue

                coords.extend([f"{residue.get_resname()}{residue_id}", *ca_atom.coord])
                found = True
                # Stop at the first model that has the residue; otherwise a
                # multi-model file would add one group per model and break
                # the fixed four-values-per-residue row layout.
                break

            if not found:
                coords.extend([f"Unknown{residue_id}", "NA", "NA", "NA"])

    return coords

# Fail early, before any download, when main() produced nothing to export
if not pdb_data or all(not chains for chains in pdb_data.values()):
    raise ValueError("pdb_data is empty or contains no residue information.")

data = []

# Build one output row per (PDB entry, chain)
for pdb_id, chains in pdb_data.items():
    for chain_id, residues in chains.items():
        data.append(get_coords(pdb_id, chain_id, residues))

# Organise the data frame for accurate conversion to Excel: the widest row
# determines how many "Res, X, Y, Z" column groups are needed
max_residues = max(len(residues) for chains in pdb_data.values() for residues in chains.values())

# Define columns dynamically
columns = ["PDB ID", "Resolution", "Chain"] + ["Res", "X", "Y", "Z"] * max_residues

# Convert to DataFrame
df = pd.DataFrame(data, columns=columns)

# Save as Excel output
output_file = "C:/Users/Student/OneDrive - Aston University/Documents/Biology/Project/Landmarks/Automated landmarks/Protein_coordinates.xlsx"
df.to_excel(output_file, index=False)

print(pdb_data)

# Keep the row of the last processed entry available as ``coordinates``;
# it is already in ``data``, so no second download is needed
coordinates = data[-1]


Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb2r4r.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb2r4s.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb2rh1.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb3kj6.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb3d4s.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb3ny8.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb3ny9.ent' 
Structure exists: 'C