In [11]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings
import requests
import time
import logging
from pathlib import Path

# BioPython imports
from Bio.PDB import MMCIFParser, PDBParser, Superimposer
from Bio.PDB.Polypeptide import three_to_one
from Bio import pairwise2
from Bio.SubsMat import MatrixInfo as matlist

# --- Configuration ---
# Resolve the project root both when executed as a script (__file__ defined)
# and inside Jupyter, where __file__ does not exist and the working directory
# is used instead.
try:
    SCRIPT_DIR = Path(__file__).resolve().parent
except NameError:
    SCRIPT_DIR = Path.cwd()

PROJECT_ROOT = SCRIPT_DIR.parent
CIF_DIR = PROJECT_ROOT / "Data/CIF_Files/"
AF_DIR = PROJECT_ROOT / "Data/AF_PDB/"
INPUT_DIR = PROJECT_ROOT / "Input/"
OUTPUT_DIR_BASE = PROJECT_ROOT / "Output/"

# Tables produced by earlier pipeline stages / provided as inputs.
REP_CHAIN_FILE = OUTPUT_DIR_BASE / "Binding_Residue/Rep_GPCR_chain.csv"
CLASSIFICATION_FILE = OUTPUT_DIR_BASE / "Dynamics/GPCR_PDB_classification.csv"
REP_APO_FILE = OUTPUT_DIR_BASE / "Final/Representative_Apo_Structures.csv"
SEQUENCE_INFO_FILE = INPUT_DIR / 'Human_GPCR_PDB_Info.csv'
CLASS_INFO_FILE = INPUT_DIR / 'ChEMBL_GPCR_Info.csv'

# Outputs of this notebook.
OUTPUT_DIR = OUTPUT_DIR_BASE / "Decision_Tree/"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
FEATURE_MATRIX_FILE = OUTPUT_DIR / "GPCR_feature_matrix_for_tree.csv"
LOG_FILE = OUTPUT_DIR / "feature_extraction_log.txt"

# --- Setup Logging ---
# force=True (Python 3.8+) removes any handlers already attached to the root
# logger before installing ours. Without it basicConfig is a silent no-op when
# this cell is re-run in the same kernel (or when some library configured
# logging first), so the log file would never be (re)opened.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE, mode='w'),
        logging.StreamHandler()
    ],
    force=True,
)
# Suppresses every UserWarning process-wide, not only those from this pipeline.
# NOTE(review): consider narrowing with module=... so unrelated warnings still show.
warnings.filterwarnings("ignore", category=UserWarning)



In [45]:
# ==============================================================================
# 1. HELPER FUNCTIONS (GPCRdb generic numbers, chain/apo selection, structure-to-UniProt mapping)
# ==============================================================================
def get_bw_map_from_gpcrdb(entry_name):
    """Fetch the residue -> generic-number map for one receptor from GPCRdb.

    Queries ``https://gpcrdb.org/services/residues/<entry_name>/`` and keeps
    only residues that carry a ``display_generic_number``. Each generic number
    is truncated at the first ``'x'`` (e.g. ``"3.50x50"`` -> ``"3.50"``).

    Args:
        entry_name: receptor entry name; lower-cased before building the URL.

    Returns:
        dict mapping each residue's ``sequence_number`` to its truncated
        generic number, or ``None`` when the request fails, the body is not a
        JSON list, or the list is empty.
    """
    url = f"https://gpcrdb.org/services/residues/{entry_name.lower()}/"
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        residues = response.json()
    except requests.exceptions.RequestException as e:
        logging.warning(f"GPCRdb API call failed for {entry_name}: {e}")
        return None
    except ValueError as e:
        # Older `requests` releases raise a bare ValueError (not a
        # RequestException) for a non-JSON body; without this one bad response
        # would abort the whole batch run.
        logging.warning(f"GPCRdb returned non-JSON data for {entry_name}: {e}")
        return None

    # An error payload may be a JSON object rather than a list of residues.
    if not isinstance(residues, list) or not residues:
        return None
    return {
        res['sequence_number']: res['display_generic_number'].split('x')[0]
        for res in residues
        if res.get('display_generic_number')
    }

def get_rep_chain_id(uniprot_id, pdb_id, df_rep_chain):
    """Return the highest-scoring chain ID recorded for a (UniProt, PDB) pair.

    Args:
        uniprot_id: value matched against ``df_rep_chain['UniProt_ID']``.
        pdb_id: value matched against ``df_rep_chain['PDB_ID']``.
        df_rep_chain: table with ``UniProt_ID``, ``PDB_ID``, ``score`` and
            ``chain_id`` columns.

    Returns:
        ``chain_id`` of the matching row with the largest ``score``, or
        ``None`` when the pair has no rows.
    """
    is_match = (df_rep_chain['UniProt_ID'] == uniprot_id) & (df_rep_chain['PDB_ID'] == pdb_id)
    candidates = df_rep_chain.loc[is_match]
    if candidates.empty:
        return None
    ranked = candidates.sort_values(by='score', ascending=False)
    best_row = ranked.iloc[0]
    return best_row['chain_id']

def _map_local_alignment(alignment, ordered_residues):
    """Map 1-based UniProt positions to structure residues from a local alignment.

    ``alignment.seqA`` is the gapped UniProt sequence and ``alignment.seqB`` the
    gapped structure sequence; the i-th non-gap character of ``seqB`` is
    ``ordered_residues[i]``. For local alignments pairwise2 returns the full
    sequences, with the unaligned flanks laid out in columns outside
    ``[alignment.start, alignment.end)``. Those flank columns are not real
    alignments (e.g. a fusion-protein tail next to the UniProt N-terminus), so
    only columns inside that range are mapped. Positions are still counted over
    every column so the numbering stays correct.
    """
    mapping = {}
    uni_pos = 0
    pdb_idx = 0
    for column, (uni_char, pdb_char) in enumerate(zip(alignment.seqA, alignment.seqB)):
        if uni_char != '-':
            uni_pos += 1
        if pdb_char == '-':
            continue
        if uni_char != '-' and alignment.start <= column < alignment.end:
            mapping[uni_pos] = ordered_residues[pdb_idx]
        pdb_idx += 1
    return mapping


def load_structure_and_map(pdb_id, chain_id, uniprot_seq, is_alphafold=False):
    """Load one chain of a structure and map UniProt positions onto its residues.

    Experimental structures are read from ``CIF_DIR/<pdb_id lower-cased>.cif``.
    AlphaFold models are read from ``AF_DIR/AF-<pdb_id>-F1-model_v3.pdb``; for
    these, callers pass the UniProt accession as ``pdb_id``. Standard residues
    (hetero flag ``' '``) that have a CA atom are ordered by residue number and
    then insertion code, and converted to one-letter codes. Residues whose names
    ``three_to_one`` cannot convert are skipped. The sequence is then locally
    aligned to ``uniprot_seq`` with BLOSUM62 (gap open -10, extend -0.5).

    Args:
        pdb_id: PDB code, or UniProt accession when ``is_alphafold`` is True.
        chain_id: chain to extract from model 0.
        uniprot_seq: canonical UniProt sequence used as alignment target.
        is_alphafold: load the AlphaFold PDB model instead of the mmCIF file.

    Returns:
        ``(structure, mapping)``. ``structure`` is the Bio.PDB structure.
        ``mapping`` is ``{1-based UniProt position: Residue}`` and covers only
        the locally aligned region. Returns ``(None, None)`` in these cases:
        the file is missing, parsing or alignment fails, the chain is absent,
        or the chain has no usable residues.
    """
    if is_alphafold:
        path = AF_DIR / f"AF-{pdb_id}-F1-model_v3.pdb"
        parser = PDBParser(QUIET=True)
    else:
        path = CIF_DIR / f"{pdb_id.lower()}.cif"
        parser = MMCIFParser(QUIET=True)
    if not path.exists():
        logging.error(f"Structure file not found: {path}")
        return None, None
    try:
        structure = parser.get_structure(str(path.name), path)
        chain_obj = structure[0][chain_id]
        # Sort on (number, insertion code) rather than keying a dict on the
        # number alone, which let residues like 100 and 100A overwrite each other.
        candidates = sorted(
            (res for res in chain_obj.get_residues() if res.id[0] == ' ' and 'CA' in res),
            key=lambda res: (res.id[1], res.id[2]),
        )
        ordered_residues, one_letter_codes = [], []
        for res in candidates:
            try:
                code = three_to_one(res.get_resname())
            except KeyError:
                # Previously one unrecognised residue name (e.g. UNK) sent the
                # whole chain to the generic handler below; skip just that residue.
                logging.debug(f"{path.name}: skipping unrecognised residue {res.get_resname()} {res.id[1]}")
                continue
            ordered_residues.append(res)
            one_letter_codes.append(code)
        if not ordered_residues:
            return None, None
        pdb_seq = "".join(one_letter_codes)
        # Only alignments[0] is used; enumerating every co-optimal alignment
        # can be very slow for long sequences.
        alignments = pairwise2.align.localds(
            uniprot_seq, pdb_seq, matlist.blosum62, -10, -0.5, one_alignment_only=True
        )
        if not alignments:
            return None, None
        return structure, _map_local_alignment(alignments[0], ordered_residues)
    except Exception as e:
        logging.error(f"Failed to load or parse {path}: {e}")
        return None, None

def get_representative_apo(uniprot_id, df_rep_apo):
    """Pick the best-resolved apo structure that fully covers the binding site.

    Args:
        uniprot_id: value matched against ``df_rep_apo['UniProt_ID']``.
        df_rep_apo: table with ``UniProt_ID``, ``PDB_ID``, ``Binding_Coverage``
            and ``Resolution`` columns.

    Returns:
        ``PDB_ID`` of the lowest-``Resolution`` row whose ``Binding_Coverage``
        equals 100.0, or ``None`` if the receptor has no such row.
    """
    receptor_rows = df_rep_apo[df_rep_apo['UniProt_ID'] == uniprot_id]
    fully_covered = receptor_rows[receptor_rows['Binding_Coverage'] == 100.0]
    if fully_covered.empty:
        return None
    by_resolution = fully_covered.sort_values(by='Resolution', ascending=True)
    return by_resolution.iloc[0]['PDB_ID']

In [46]:
# ==============================================================================
# 2. CORE FEATURE CALCULATION WORKFLOW (superposition onto a reference + per-residue distance/displacement features)
# ==============================================================================
def _structure_features(target_map, uniprot_to_bw, target_res_R3_50, ref_ca_coords):
    """Compute Dist_R3.50_<bw> / Disp_<bw> features for one superimposed structure.

    Only residues with a generic number contribute. ``Disp_<bw>`` is written
    only when the position also exists in the reference. Keys follow
    ``target_map`` iteration order, with each residue's Dist before its Disp.
    """
    anchor = target_res_R3_50['CA'].get_coord()
    features = {}
    for uni_pos, target_res in target_map.items():
        bw_num = uniprot_to_bw.get(uni_pos)
        if bw_num is None:
            continue
        ca = target_res['CA'].get_coord()
        features[f"Dist_R3.50_{bw_num}"] = np.linalg.norm(ca - anchor)
        ref_coord = ref_ca_coords.get(uni_pos)
        if ref_coord is not None:
            features[f"Disp_{bw_num}"] = np.linalg.norm(ca - ref_coord)
    return features


def process_gpcr(gpcr_row, df_rep_chain, df_seq_info, ref_apo_map, bw_cache, min_common_atoms=20):
    """Build one feature dict per ligand-bound structure of a single GPCR.

    Agonist-bound structures get label 1 and antagonist-bound structures label
    0; a PDB listed under both keeps label 0, and a warning is logged. The
    reference is the apo PDB in ``ref_apo_map``, or the AlphaFold model (chain
    'A') when the receptor has no entry there. Each target structure is
    superimposed onto the reference on CA atoms of commonly mapped UniProt
    positions, and then these features are recorded:
      * ``Dist_R3.50_<bw>``: CA distance to the residue at generic position 3.50.
      * ``Disp_<bw>``: CA displacement from the reference after superposition.

    Args:
        gpcr_row: mapping with ``GPCR`` (UniProt accession),
            ``agonist_bound_PDBs`` and ``antagonist_bound_PDBs`` (iterables).
        df_rep_chain: representative-chain table (see ``get_rep_chain_id``).
        df_seq_info: table with ``Entry``, ``Sequence`` and ``Entry Name``.
        ref_apo_map: {UniProt accession: reference PDB ID or None}.
        bw_cache: {UniProt accession: generic-number map or None}; filled in place.
        min_common_atoms: minimum number of shared CA positions required to
            superimpose a structure (structures below it are skipped).

    Returns:
        ``(features, bw_cache)``, where ``features`` is a list of dicts and is
        empty when the receptor cannot be processed.
    """
    uniprot_id = gpcr_row['GPCR']
    seq_rows = df_seq_info[df_seq_info['Entry'] == uniprot_id]
    if seq_rows.empty:
        # .iloc[0] on an empty frame used to raise IndexError and abort the run.
        logging.warning(f"{uniprot_id}: not found in sequence info table; skipping.")
        return [], bw_cache
    uniprot_info = seq_rows.iloc[0]
    uniprot_seq, entry_name = uniprot_info['Sequence'], uniprot_info['Entry Name']

    if uniprot_id not in bw_cache:
        bw_cache[uniprot_id] = get_bw_map_from_gpcrdb(entry_name)
    uniprot_to_bw = bw_cache.get(uniprot_id)
    if not uniprot_to_bw:
        logging.warning(f"{uniprot_id}: no generic-number map available; skipping.")
        return [], bw_cache
    bw_to_uniprot = {bw: pos for pos, bw in uniprot_to_bw.items()}
    uniprot_pos_R3_50 = bw_to_uniprot.get('3.50')
    if not uniprot_pos_R3_50:
        logging.warning(f"{uniprot_id}: no residue at generic position 3.50; skipping.")
        return [], bw_cache

    # --- Reference structure ---
    ref_pdb_id = ref_apo_map.get(uniprot_id)
    is_ref_alphafold = ref_pdb_id is None
    ref_pdb_id_for_loading = uniprot_id if is_ref_alphafold else ref_pdb_id
    ref_chain_id = 'A' if is_ref_alphafold else get_rep_chain_id(uniprot_id, ref_pdb_id, df_rep_chain)
    if not ref_chain_id:
        logging.warning(f"{uniprot_id}: no representative chain for reference {ref_pdb_id}; skipping.")
        return [], bw_cache
    _, ref_uniprot_to_pdb_map = load_structure_and_map(
        ref_pdb_id_for_loading, ref_chain_id, uniprot_seq, is_alphafold=is_ref_alphafold
    )
    if not ref_uniprot_to_pdb_map:
        logging.warning(f"{uniprot_id}: reference {ref_pdb_id_for_loading} could not be mapped; skipping.")
        return [], bw_cache
    # Copies, so these stay the untransformed reference coordinates.
    ref_ca_coords = {pos: res['CA'].get_coord().copy() for pos, res in ref_uniprot_to_pdb_map.items()}

    # --- Labelled target structures ---
    agonist_pdbs = gpcr_row['agonist_bound_PDBs']
    antagonist_pdbs = gpcr_row['antagonist_bound_PDBs']
    all_pdbs = {pdb: 1 for pdb in agonist_pdbs}
    conflicting = sorted(set(all_pdbs) & set(antagonist_pdbs))
    if conflicting:
        logging.warning(f"{uniprot_id}: PDBs listed as both agonist- and antagonist-bound are labelled 0: {conflicting}")
    all_pdbs.update({pdb: 0 for pdb in antagonist_pdbs})

    gpcr_features_list = []
    for pdb_id, label in all_pdbs.items():
        chain_id = get_rep_chain_id(uniprot_id, pdb_id, df_rep_chain)
        if not chain_id:
            continue
        target_structure, target_map = load_structure_and_map(pdb_id, chain_id, uniprot_seq)
        if not target_map:
            continue
        # Check the anchor residue before spending time on superposition.
        target_res_R3_50 = target_map.get(uniprot_pos_R3_50)
        if target_res_R3_50 is None:
            continue

        common_positions = sorted(ref_ca_coords.keys() & target_map.keys())
        if len(common_positions) < min_common_atoms:
            continue
        super_imposer = Superimposer()
        super_imposer.set_atoms(
            [ref_uniprot_to_pdb_map[pos]['CA'] for pos in common_positions],
            [target_map[pos]['CA'] for pos in common_positions],
        )
        # Move every target atom into the reference frame.
        super_imposer.apply(target_structure.get_atoms())

        pdb_feature_dict = {'PDB_ID': pdb_id, 'UniProt_ID': uniprot_id, 'Label': label}
        pdb_feature_dict.update(
            _structure_features(target_map, uniprot_to_bw, target_res_R3_50, ref_ca_coords)
        )
        gpcr_features_list.append(pdb_feature_dict)

    return gpcr_features_list, bw_cache

In [47]:
# ==============================================================================
# 3. MAIN EXECUTION BLOCK (with final post-processing)
# ==============================================================================
def main():
    """Run the GPCR feature-extraction pipeline and save the cleaned matrix.

    Steps:
      1. Load these tables: classification, representative chains, UniProt
         sequences, ChEMBL class info and representative apo structures.
      2. Pick a reference apo PDB per GPCR and one-hot encode the GPCR class.
      3. Extract per-structure features with ``process_gpcr``.
      4. Drop feature columns with more than 20% missing values and
         median-impute the rest. Write the result to ``FEATURE_MATRIX_FILE``.
    """
    import ast
    import math

    def _parse_pdb_collection(cell):
        """Parse a stringified list/set of PDB IDs from the CSV; missing cells -> []."""
        if isinstance(cell, (list, tuple, set)):
            return cell
        if cell is None or (isinstance(cell, float) and math.isnan(cell)):
            return []
        text = str(cell).strip()
        if text in ('', 'set()'):
            return []
        # literal_eval only accepts Python literals; eval() would execute
        # whatever code happened to be stored in the CSV.
        return ast.literal_eval(text)

    logging.info("--- 🚀 Starting GPCR Decision Tree Feature Matrix Generation (v9) ---")

    try:
        df_class_assigned = pd.read_csv(CLASSIFICATION_FILE)
        df_rep_chain = pd.read_csv(REP_CHAIN_FILE)
        df_seq_info = pd.read_csv(SEQUENCE_INFO_FILE)
        df_class_info = pd.read_csv(CLASS_INFO_FILE)
        df_rep_apo = pd.read_csv(REP_APO_FILE)
    except FileNotFoundError as e:
        logging.critical(f"FATAL: Input file not found - {e}.")
        return
    for col in ('agonist_bound_PDBs', 'antagonist_bound_PDBs'):
        df_class_assigned[col] = df_class_assigned[col].apply(_parse_pdb_collection)

    ref_apo_map = {
        uniprot: get_representative_apo(uniprot, df_rep_apo)
        for uniprot in df_class_assigned['GPCR'].unique()
    }

    # A left merge emits one row per matching right-hand row, so a duplicated
    # accession in the ChEMBL table would duplicate that GPCR (and every one of
    # its structure samples). Keep the first class per accession.
    df_class_info_subset = (
        df_class_info[['UniProt Accessions', 'Class']]
        .drop_duplicates(subset='UniProt Accessions')
    )
    df_class_assigned = pd.merge(
        df_class_assigned, df_class_info_subset,
        left_on='GPCR', right_on='UniProt Accessions', how='left',
    )
    # Plain assignment: chained fillna(inplace=True) on a column does not update
    # the parent frame under pandas Copy-on-Write.
    df_class_assigned['Class'] = df_class_assigned['Class'].fillna('Unknown')
    class_dummies = pd.get_dummies(df_class_assigned['Class'], prefix='is_class', dtype=int)
    df_class_assigned = pd.concat([df_class_assigned, class_dummies], axis=1).drop(columns=['UniProt Accessions', 'Class'])
    logging.info("Successfully merged and one-hot encoded GPCR class information.")

    all_results_list = []
    bw_cache = {}
    class_feature_cols = [c for c in df_class_assigned.columns if c.startswith('is_class_')]
    gpcr_rows_list = df_class_assigned.to_dict('records')

    for row in tqdm(gpcr_rows_list, desc="Processing GPCRs"):
        gpcr_pdb_features, bw_cache = process_gpcr(row, df_rep_chain, df_seq_info, ref_apo_map, bw_cache)
        for pdb_dict in gpcr_pdb_features:
            for col in class_feature_cols:
                pdb_dict[col] = row[col]
        all_results_list.extend(gpcr_pdb_features)

    if not all_results_list:
        logging.error("No features could be extracted. Please check the log file for detailed errors.")
        return

    df_raw_matrix = pd.DataFrame(all_results_list)
    logging.info(f"Successfully generated raw feature matrix. Shape: {df_raw_matrix.shape}")

    logging.info("\n--- Starting Data Cleaning and Finalization ---")

    id_cols = ['UniProt_ID', 'PDB_ID', 'Label']
    feature_cols = [col for col in df_raw_matrix.columns if col not in id_cols]
    df_features = df_raw_matrix[feature_cols]

    missing_threshold = 0.20
    n_samples = len(df_features)
    # Keep a column iff its missing count <= floor(n * threshold). The former
    # thresh=int(n * (1 - threshold)) truncated the required non-missing count
    # and so kept some columns with slightly more than 20% missing.
    max_missing = math.floor(n_samples * missing_threshold)
    df_filtered = df_features.dropna(axis=1, thresh=n_samples - max_missing)

    dropped_cols = df_features.shape[1] - df_filtered.shape[1]
    logging.info(f"Feature Filtering: Dropped {dropped_cols} columns with > {missing_threshold*100}% missing values.")

    if df_filtered.isnull().sum().sum() > 0:
        logging.info("Imputing remaining missing values with column medians...")
        df_imputed = df_filtered.fillna(df_filtered.median())
    else:
        logging.info("No missing values remained after filtering. Skipping imputation.")
        df_imputed = df_filtered

    df_final_matrix = pd.concat([df_raw_matrix[id_cols], df_imputed], axis=1)
    logging.info(f"Final clean matrix shape: {df_final_matrix.shape}")

    df_final_matrix.to_csv(FEATURE_MATRIX_FILE, index=False)
    logging.info(f"✅ Final clean feature matrix generation complete. Saved to {FEATURE_MATRIX_FILE}")

    if df_final_matrix.empty:
        logging.warning("Final feature matrix is empty.")
        return

    logging.info("\n--- Final Clean DataFrame Sample ---")
    class_cols = sorted(c for c in df_final_matrix.columns if c.startswith('is_class_'))
    landmark_cols = [c for c in df_final_matrix.columns if '6.30' in c or '7.53' in c][:4]
    sample_cols = [c for c in id_cols + class_cols + landmark_cols if c in df_final_matrix.columns]
    logging.info(f"Final DataFrame Shape: {df_final_matrix.shape}")
    logging.info(df_final_matrix[sample_cols].head().to_string())

In [None]:
# Run the feature-extraction pipeline defined above. main() queries GPCRdb over
# the network (through process_gpcr) and writes FEATURE_MATRIX_FILE; log
# messages also go to LOG_FILE as configured in the setup cell.
main()

Decision Tree modeling

In [2]:
import pandas as pd
from pathlib import Path
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import GroupKFold, cross_val_score
import graphviz

# --- 0. Configuration ---
# Paths are relative to the notebook's working directory. This matches the
# figure path below and the cwd().parent project root used by the
# feature-extraction cells; the previous absolute /home/... path only worked
# on one machine.
DECISION_TREE_DIR = Path("../Output/Decision_Tree")
FEATURE_MATRIX_PATH = DECISION_TREE_DIR / "GPCR_feature_matrix_for_tree.csv"
ID_COLS = ['UniProt_ID', 'PDB_ID', 'Label']
# process_gpcr labels agonist-bound structures 1 and antagonist-bound 0;
# sklearn orders the classes as [0, 1].
CLASS_NAMES = ['Antagonist-bound', 'Agonist-bound']
REPORT_NAMES = ['Antagonist', 'Agonist']
SEED = 42
VIS_MAX_DEPTH = 3   # only the top levels are drawn, to keep the figure readable
CV_MAX_SPLITS = 5

# --- 1. Load and Prepare Data ---
df = pd.read_csv(FEATURE_MATRIX_PATH)

X = df.drop(columns=ID_COLS)
y = df['Label']
groups = df['UniProt_ID']   # used to keep each receptor within one CV fold

print("--- Data Overview ---")
print(f"Total samples for training: {len(df)}")
print(f"Number of features: {X.shape[1]}")
print("-" * 25)

# --- 2. Train the Decision Tree Model on 100% of the Data ---
dt_full_classifier = DecisionTreeClassifier(random_state=SEED)

print("\n--- Model Training (on 100% of data) ---")
print("Training a full-depth Decision Tree...")
dt_full_classifier.fit(X, y)
print("Training complete.")
print("-" * 25)

# --- 3. Visualize the Tree using Graphviz ---
print("\n--- Full Decision Tree Visualization (using Graphviz) ---")

# The fitted tree is unpruned; only the top VIS_MAX_DEPTH levels are exported.
dot_data = export_graphviz(
    dt_full_classifier,
    out_file=None,  # return the DOT source as a string
    feature_names=list(X.columns),
    class_names=CLASS_NAMES,
    filled=True,
    rounded=True,
    special_characters=True,
    max_depth=VIS_MAX_DEPTH,
)

graph = graphviz.Source(dot_data)

# Render to PNG; cleanup=True deletes the intermediate DOT file.
output_filename = str(DECISION_TREE_DIR / "decision_tree_full_visualization_graphviz_v2")
graph.render(output_filename, format='png', cleanup=True)
print(f"Tree visualization saved to '{output_filename}.png'")
print("Please check the new file. It should be rendered correctly.")

# --- 4. Evaluate Model Performance on the Training Data ---
# NOTE: an unpruned tree can fit its training data perfectly, so this score is
# expected to be ~1.0 and is NOT an estimate of generalization (see section 6).
print("\n--- Model Performance Evaluation (on Full Training Set) ---")
y_pred_train = dt_full_classifier.predict(X)
accuracy = accuracy_score(y, y_pred_train)
class_report = classification_report(y, y_pred_train, target_names=REPORT_NAMES)

print(f"Training Accuracy: {accuracy:.4f}")
print("\nClassification Report (on training data):")
print(class_report)
print("-" * 25)

# --- 5. Check the ACTUAL Full Depth of the Tree ---
full_depth = dt_full_classifier.get_depth()
print(f"\n--------------------------------------------------")
print(f"✅ The actual full depth of the trained tree is: {full_depth} levels.")
print(f"--------------------------------------------------")

# --- 6. Receptor-grouped cross-validation ---
# All structures of one receptor stay in the same fold. Structures of the same
# GPCR are presumably very similar, so mixing them across folds would inflate
# the held-out score.
n_receptors = groups.nunique()
if n_receptors >= 2:
    cv = GroupKFold(n_splits=min(CV_MAX_SPLITS, n_receptors))
    cv_scores = cross_val_score(
        DecisionTreeClassifier(random_state=SEED), X, y,
        groups=groups, cv=cv, scoring='accuracy',
    )
    print(f"\nReceptor-grouped {cv.get_n_splits()}-fold CV accuracy: "
          f"{cv_scores.mean():.4f} ± {cv_scores.std():.4f}")
else:
    print("\nSkipping grouped cross-validation: fewer than 2 receptors in the data.")

--- Data Overview ---
Total samples for training: 617
Number of features: 445
-------------------------

--- Model Training (on 100% of data) ---
Training a full-depth Decision Tree...
Training complete.
-------------------------

--- Full Decision Tree Visualization (using Graphviz) ---
Tree visualization saved to '../Output/Decision_Tree/decision_tree_full_visualization_graphviz_v2.png'
Please check the new file. It should be rendered correctly.

--- Model Performance Evaluation (on Full Training Set) ---
Training Accuracy: 1.0000

Classification Report (on training data):
              precision    recall  f1-score   support

  Antagonist       1.00      1.00      1.00       217
   Agagonist       1.00      1.00      1.00       400

    accuracy                           1.00       617
   macro avg       1.00      1.00      1.00       617
weighted avg       1.00      1.00      1.00       617

-------------------------

--------------------------------------------------
✅ The actual 