In [6]:
import os
import subprocess
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
import pandas as pd

from Bio import PDB, SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pydssp


# === 1️⃣ Parallel PDB Downloading ===
def download_pdb(pdb_code, save_dir):
    """Downloads a PDB file in parallel."""
    pdb_file = PDB.PDBList()
    pdb_file.retrieve_pdb_file(pdb_code, file_format="pdb", pdir=save_dir, overwrite=False)


def download_all_pdbs(pdb_codes, save_dir):
    """Downloads multiple PDB files in parallel."""
    os.makedirs(save_dir, exist_ok=True)
    with ThreadPoolExecutor() as executor:
        executor.map(download_pdb, pdb_codes, [save_dir] * len(pdb_codes))


# === 2️⃣ Amino Acid Conversion ===
AA_DICT = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V"
}


def three_to_one(resname):
    """Converts 3-letter residue name to 1-letter code."""
    return AA_DICT.get(resname, "X")  # 'X' for unknown residues


# === 3️⃣ Extract FASTA from PDB ===
def extract_pdb_fasta(pdb_code, pdb_dir, chain_id, fasta_dir):
    """Extracts the sequence of a specific chain from a PDB file and writes it as FASTA."""
    pdb_filepath = os.path.join(pdb_dir, f"pdb{pdb_code}.ent")
    fasta_filepath = os.path.join(fasta_dir, f"{pdb_code}.fasta")

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, pdb_filepath)

    # Only process first model (speed optimization)
    model = structure[0]
    sequence, observed_residues = [], []

    chain = model[chain_id] if chain_id in model else None
    if chain:
        for residue in chain.get_residues():
            if PDB.is_aa(residue):
                sequence.append(three_to_one(residue.get_resname()))
                observed_residues.append(residue.id[1])

    # Write FASTA
    fasta_seq = SeqRecord(Seq("".join(sequence)), id=f"{pdb_code}_{chain_id}", description="")
    SeqIO.write(fasta_seq, fasta_filepath, "fasta")

    return observed_residues


# === 4️⃣ Run DeepTMHMM Asynchronously ===
def run_deeptmhmm(pdb_code, fasta_filepath, results_dir):
    """Runs DeepTMHMM for a single PDB code asynchronously."""
    pdb_results_dir = os.path.join(results_dir, pdb_code)
    os.makedirs(pdb_results_dir, exist_ok=True)

    process = subprocess.Popen(
        ["biolib", "run", "DTU/DeepTMHMM", "--fasta", fasta_filepath],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=pdb_results_dir,
        start_new_session=True  # Allows independent execution
    )
    return process


def run_all_deeptmhmm(pdb_codes, fasta_dir, results_dir):
    """Runs DeepTMHMM for multiple PDB codes in parallel asynchronously."""
    os.makedirs(results_dir, exist_ok=True)

    with ThreadPoolExecutor() as executor:
        processes = {
            pdb_code: executor.submit(run_deeptmhmm, pdb_code, os.path.join(fasta_dir, f"{pdb_code}.fasta"), results_dir)
            for pdb_code in pdb_codes
        }

    # Wait for all processes to complete
    for pdb_code, future in processes.items():
        process = future.result()
        stdout, stderr = process.communicate()
        print(f"[{pdb_code}] STDOUT:", stdout)
        print(f"[{pdb_code}] STDERR:", stderr)


def keep_only_tmr(results_dir, pdb_codes):
    """Removes all files except TMRs.gff3 in each PDB results folder."""
    for pdb_code in pdb_codes:
        pdb_results_dir = os.path.join(results_dir, pdb_code, "biolib_results")  # Ensure correct subfolder

        if os.path.exists(pdb_results_dir):
            for filename in os.listdir(pdb_results_dir):
                file_path = os.path.join(pdb_results_dir, filename)

                if filename != "TMRs.gff3":
                    try:
                        os.remove(file_path)  # Remove only files, leave the folder
                    except Exception as e:
                        print(f"Warning: Could not remove {file_path}: {e}")



# === 5️⃣ Define Directories and Execute ===
pdb_codes = ["3d4s", "2rh1", "2r4r"]
pdb_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files"
fasta_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\Fasta_files"
results_dir = r"C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\DeepTMHMM_results"

# Download PDB files in parallel
download_all_pdbs(pdb_codes, pdb_dir)

# Extract sequences
chain_id = "A"
all_observed_residues = {
    pdb_code: extract_pdb_fasta(pdb_code, pdb_dir, chain_id, fasta_dir) for pdb_code in pdb_codes
}

# Run DeepTMHMM in parallel asynchronously
run_all_deeptmhmm(pdb_codes, fasta_dir, results_dir)
keep_only_tmr(results_dir, pdb_codes)


Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb3d4s.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb2rh1.ent' 
Structure exists: 'C:\Users\Student\OneDrive - Aston University\Documents\Biology\Project\Project_automation\Python\PDB_files\pdb2r4r.ent' 
[3d4s] STDOUT: 2025-03-04 10:21:41,613 | INFO : Extracted zip file to: output/

2025-03-04 10:21:41,613 | INFO : Done in 2.74 seconds

2025-03-04 10:21:41,613 | INFO : Extracted zip file to: output/
2025-03-04 10:21:41,613 | INFO : Done in 2.74 seconds

[3d4s] STDERR: 
[2rh1] STDOUT: 2025-03-04 10:21:53,758 | INFO : Extracted zip file to: output/

2025-03-04 10:21:53,759 | INFO : Done in 2.76 seconds

2025-03-04 10:21:53,758 | INFO : Extracted zip file to: output/
2025-03-04 10:21:53,759 | INFO : Done in 2.76 seconds

[2rh1] STDERR: 
[2r4r] STDOUT: 2025-03-04 10:21:59,058 |

In [7]:
tmh_data = {}
pdb_data = {}

# Align output ranges with actual observed residues
for pdb_code in pdb_codes:
    tmh_ranges = []
    tmh_result_file = f"C:\\Users\\Student\\OneDrive - Aston University\\Documents\\Biology\\Project\\Project_automation\\Python\\DeepTMHMM_results\\{pdb_code}\\biolib_results\\TMRs.gff3"
    
    with open(tmh_result_file) as file:
        for line in file:
            if "TMhelix" in line:
                parts = line.strip().split("\t")
                start, end = int(parts[2]), int(parts[3])
                tmh_ranges.append((start, end))
    
    # Use DSSP
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_code, pdb_filepath)
    
    # Select the first structure in the file and the A chain
    chain = structure[0]['A']
    
    # Retrieve coordinates
    coordinates = []
    for residue in chain:
        if PDB.is_aa(residue):
            res_name = residue.get_resname()
            for atom in residue:
                if atom.get_name() in ['N', 'CA', 'C', 'O']:
                    coordinates.append(atom.coord)
    
    # Prepare eliments of pytorch tensor
    L = sum(1 for residue in chain if PDB.is_aa(residue))
    atoms = 4
    xyz = 3
    
    # Make coordinates into array first to increase efficiency
    coord_array = np.array(coordinates)
    
    # Create tensor
    coord_tensor = torch.tensor(coord_array, dtype=torch.float32).reshape([L, atoms, xyz])
    
    # Use pydssp to get secondary structure
    ss_data = pydssp.assign(coord_tensor, out_type='c3')

    def extend_tmh_ranges(tmh_ranges, ss_data, max_extend=9):
        extended_tmh_ranges = []
    
        for start, end in tmh_ranges:
            # Extend start position backwards if residues are consecutive 'H'
            extended_start = start
            for i in range(1, max_extend + 1):
                prev_res = start - i
                if prev_res >= 0 and ss_data[prev_res] == "H":  # Ensure valid index
                    extended_start = prev_res
                else:
                    break  # Stop if a non-'H' is encountered
    
            # Extend end position forwards if residues are consecutive 'H'
            extended_end = end
            for i in range(1, max_extend + 1):
                next_res = end + i
                if next_res < len(ss_data) and ss_data[next_res] == "H":  # Ensure valid index
                    extended_end = next_res
                else:
                    break  # Stop if a non-'H' is encountered
    
            extended_tmh_ranges.append((extended_start, extended_end))
    
    
        return extended_tmh_ranges
    
    extended_tmh_ranges = extend_tmh_ranges(tmh_ranges, ss_data)
    print(extended_tmh_ranges)

    # Convert TMH positions to actual PDB residue numbers as pairs
    tmh_extended_pairs = [
        (observed_residues[start - 1], observed_residues[end - 1]) for start, end in extended_tmh_ranges
    ]

    def reorder_gpcr_tmh_ends(tmh_extended_pairs):
        pattern = ["extra", "intra", "intra", "extra", "extra", "intra", "intra", 
                   "extra", "extra", "intra", "intra", "extra", "extra", "intra"]
    
        reordered = []
        for i, label in enumerate(pattern):
            if label == "extra":
                reordered.append(tmh_extended_pairs[i // 2][0])  # Take start residue
            else:  # "intra"
                reordered.append(tmh_extended_pairs[i // 2][1])  # Take end residue
    
        return reordered
    
    flattened_tmh = reorder_gpcr_tmh_ends(tmh_extended_pairs)
    
    pdb_data[pdb_code.upper()] = {"A": flattened_tmh}


print(pdb_data)

IndexError: index 371 is out of bounds for axis 0 with size 216