In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import RobustScaler # Keep for potential use in prepare_structure_data if needed elsewhere, but not directly used here
import matplotlib.pyplot as plt
import random
import gc # Garbage Collection

# Set random seeds for reproducibility: Python's `random`, NumPy's global RNG
# and TensorFlow's global seed are all fixed to SEED. SEED is also reused by
# the training-data shuffle and the KFold split further down this file.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# --- Data Loading and Preparation Functions ---
def _read_prot_t5_file(path):
    """Parse one ProtT5 embedding file into a ``{(entry, pos): [floats]}`` dict."""
    embeddings_by_key = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing empty line) instead of crashing on int('')
                continue
            entry, pos, *values = line.split(',')
            # Later duplicates of the same key overwrite earlier ones
            embeddings_by_key[(entry, int(pos))] = [float(x) for x in values]
    return embeddings_by_key


def load_prot_t5_data(pos_file, neg_file):
    """Load ProtT5 per-site embeddings for positive and negative sites.

    Each non-blank line of either file is parsed as a comma-separated record
    ``entry,pos,e1,e2,...``: ``entry`` is kept as a string, ``pos`` is parsed
    as an int and the remaining fields are parsed as floats.

    Args:
        pos_file: Path to the positive-site embedding file.
        neg_file: Path to the negative-site embedding file.

    Returns:
        Tuple ``(pos_dict, neg_dict)``. Each maps ``(entry, pos)`` to the list of
        float embedding values; if a key repeats within a file, the last one wins.

    Raises:
        FileNotFoundError: If either file does not exist.
        ValueError: If a non-blank line has fewer than two fields or a
            non-numeric position/embedding value.
    """
    return _read_prot_t5_file(pos_file), _read_prot_t5_file(neg_file)

def prepare_aligned_data(seq_struct_df, pos_dict, neg_dict):
    """Align ProtT5 embeddings with the rows of a sequence+structure DataFrame.

    For every row, the embedding is looked up by ``(row['entry'], row['pos'])``
    in ``pos_dict`` when ``row['label'] == 1`` and in ``neg_dict`` otherwise.
    Rows without an embedding are dropped.

    Args:
        seq_struct_df: DataFrame with at least ``entry``, ``pos`` and ``label`` columns.
        pos_dict: ``{(entry, pos): embedding}`` for positive sites.
        neg_dict: ``{(entry, pos): embedding}`` for negative sites.

    Returns:
        Tuple ``(X_prot_t5, aligned_df)``:
        ``X_prot_t5`` is ``np.array`` of the matched embeddings in the
        DataFrame's row order; ``aligned_df`` holds the kept rows with a fresh
        0..n-1 index, row-aligned with ``X_prot_t5``.
    """
    embeddings = []
    # Positional indices (not index labels): selecting by label would return
    # several rows per label if seq_struct_df's index has duplicates, which
    # would silently misalign aligned_df against X_prot_t5.
    kept_positions = []

    rows = zip(seq_struct_df['entry'], seq_struct_df['pos'], seq_struct_df['label'])
    for position, (entry, pos, label) in enumerate(rows):
        lookup = pos_dict if label == 1 else neg_dict
        emb = lookup.get((entry, pos))
        if emb is not None:
            embeddings.append(emb)
            kept_positions.append(position)

    X_prot_t5 = np.array(embeddings)
    aligned_df = seq_struct_df.iloc[kept_positions].reset_index(drop=True)

    return X_prot_t5, aligned_df

def extract_entry_id(header):
    """Return the entry ID from a possibly pipe-delimited header.

    If ``header`` contains ``'|'``, the text between the first and second
    ``'|'`` is returned (everything after the first ``'|'`` if there is only
    one), e.g. ``'sp|P12345|NAME'`` -> ``'P12345'``. Otherwise ``header`` is
    returned unchanged, on the assumption that it is already a bare ID.
    """
    if '|' not in header:
        return header
    # Once a '|' is present, split always yields at least two parts, so
    # index 1 cannot raise; no exception handling is needed.
    return header.split('|')[1]

def _load_structure_keys(struct_path):
    """Return the set of ``(entry, pos)`` keys present in the structure CSV.

    A missing/empty file or an undeterminable ID/position column yields an
    empty set (with a printed warning), i.e. every site is treated as
    structure-less.
    """
    print(f"Loading structure index from: {struct_path}...")
    try:
        struct_data = pd.read_csv(struct_path)
    except FileNotFoundError:
        print(f"Warning: Structure file not found at {struct_path}. Assuming no structure data available.")
        return set()
    except pd.errors.EmptyDataError:
        print(f"Warning: Structure file at {struct_path} is empty. Assuming no structure data available.")
        return set()

    if 'entry' in struct_data.columns:
        id_col = 'entry'
    else:
        print("Warning: 'entry' column not found in structure file. Cannot clean IDs.")
        if 'pos' not in struct_data.columns:
            print("Error: Cannot determine entry ID column in structure file.")
            return set()
        # Fallback: assume the first column carries the entry ID
        id_col = struct_data.columns[0]
        print(f"Attempting to use column '{id_col}' as entry ID.")

    if 'pos' not in struct_data.columns:
        print("Warning: 'entry' or 'pos' column missing in structure data. Cannot determine structure availability accurately.")
        return set()

    entries = struct_data[id_col].map(lambda x: extract_entry_id(x) if isinstance(x, str) else x)
    # Non-numeric positions become NaN and are skipped, instead of crashing int()
    positions = pd.to_numeric(struct_data['pos'], errors='coerce')
    valid = entries.notna() & positions.notna()

    n_skipped = int((~valid).sum())
    if n_skipped:
        print(f"Warning: Skipping {n_skipped} row(s) with missing or non-numeric entry/pos in structure file.")

    # str/int casts keep key types consistent with the FASTA-derived keys
    return {(str(e), int(p)) for e, p in zip(entries[valid], positions[valid])}


def _iter_fasta_records(lines):
    """Yield ``(header, sequence)`` pairs from FASTA text lines.

    The header loses its leading '>'; sequence lines up to the next header are
    concatenated, and blank lines are ignored.
    """
    header, chunks = None, []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        if line.startswith('>'):
            if header is not None:
                yield header, ''.join(chunks)
            header, chunks = line[1:], []
        elif header is not None:
            chunks.append(line)
    if header is not None:
        yield header, ''.join(chunks)


def _read_fasta_sites(fasta_path, label, struct_keys, kind):
    """Parse one site FASTA file into row dicts carrying ``label`` and a structure flag."""
    try:
        with open(fasta_path) as fasta_file:
            lines = fasta_file.readlines()
    except FileNotFoundError:
        print(f"Error: {kind.capitalize()} FASTA file not found at {fasta_path}")
        return []

    rows = []
    for header, sequence in _iter_fasta_records(lines):
        try:
            header_parts = header.split('|-|')
            entry = str(extract_entry_id(header_parts[0]))  # str for consistent key types
            pos = int(header_parts[1])
        except (IndexError, ValueError):
            print(f"Warning: Skipping malformed {kind} FASTA header: {header}")
            continue

        rows.append({
            'entry': entry,
            'pos': pos,
            'sequence': sequence,
            'label': label,
            'has_structure': (entry, pos) in struct_keys,
        })
    return rows


def load_complete_data(mode='train', exclude_emb=False, base_data_path='../../data', struct_path_override=None):
    """Load labelled site windows from the positive/negative FASTA files.

    The header of each FASTA record is split on ``'|-|'``. The first segment is
    reduced to an entry ID with ``extract_entry_id``, and the segment right
    after it is parsed as the integer site position. Records with malformed
    headers are skipped with a warning. Each row is flagged
    ``has_structure=True`` when its ``(entry, pos)`` key appears in the
    structure-feature CSV. ProtT5 embeddings are not touched here; see
    ``load_prot_t5_data`` and ``prepare_aligned_data``.

    Args:
        mode: ``'train'`` selects the training files; any other value selects
            the test files.
        exclude_emb: Unused; kept for backward compatibility with callers.
        base_data_path: Root directory containing the ``train/`` and ``test/`` trees.
        struct_path_override: If truthy, the structure CSV path to use instead of
            the mode's default.

    Returns:
        DataFrame with columns ``entry, pos, sequence, label, has_structure``
        (positives first, then negatives), or an empty DataFrame when no record
        could be loaded from either FASTA file.
    """
    if mode == 'train':
        pos_fasta = f'{base_data_path}/train/fasta/positive_sites.fasta'
        neg_fasta = f'{base_data_path}/train/fasta/negative_sites.fasta'
        default_struct_path = f"{base_data_path}/train/structure/processed_features_train.csv"
    else:  # test
        pos_fasta = f'{base_data_path}/test/fasta/test_positive_sites.fasta'
        neg_fasta = f'{base_data_path}/test/fasta/test_negative_sites.fasta'
        default_struct_path = f"{base_data_path}/test/structure/processed_features_test.csv"
    struct_path = struct_path_override if struct_path_override else default_struct_path

    struct_keys = _load_structure_keys(struct_path)

    print("\nProcessing positive data...")
    positive_data = _read_fasta_sites(pos_fasta, 1, struct_keys, 'positive')
    print("Processing negative data...")
    negative_data = _read_fasta_sites(neg_fasta, 0, struct_keys, 'negative')

    if not positive_data and not negative_data:
        print("Error: No data loaded from FASTA files. Returning empty DataFrame.")
        return pd.DataFrame()

    all_data = pd.DataFrame(positive_data + negative_data)

    # Duplicates are only reported, not removed; downstream alignment keeps them all.
    duplicates = all_data.duplicated(subset=['entry', 'pos'], keep=False)
    if duplicates.any():
        print(f"\nWarning: Found {duplicates.sum()} duplicate entries based on ('entry', 'pos').")

    print("\nDataset statistics:")
    print(f"Total entries loaded: {len(all_data)}")
    print(f"Positive examples: {all_data['label'].sum()}")
    print(f"Negative examples: {len(all_data) - all_data['label'].sum()}")
    print(f"Entries marked as having structure: {all_data['has_structure'].sum()}")
    print(f"Unique proteins (entries): {all_data['entry'].nunique()}")

    return all_data

def prepare_sequence_data(df, expected_len=33):
    """Integer-encode the ``sequence`` column of ``df`` into fixed-length rows.

    Each of the 20 standard amino acids ``'ARNDCQEGHILKMFPSTWYV'`` maps to its
    position 0..19; every other character (``'-'``, ``'X'``, lowercase letters,
    ...) maps to 20. The resulting vocabulary size is 21. A sequence whose
    length differs from ``expected_len`` triggers a printed warning and is
    brought to that length. Longer sequences are truncated around their
    center. Shorter ones are padded with ``'-'`` on both sides, with the extra
    pad character going after.

    Args:
        df: DataFrame with a ``sequence`` column of strings.
        expected_len: Output length per sequence (window size); defaults to 33.

    Returns:
        Integer ndarray of shape ``(len(df), expected_len)``; for an empty ``df``
        an empty array of shape ``(0, expected_len)``.
    """
    alphabet = 'ARNDCQEGHILKMFPSTWYV'
    char_to_int = {c: i for i, c in enumerate(alphabet)}
    unknown_int = len(alphabet)  # shared code for gaps/unknowns

    encodings = []
    for seq in df['sequence'].values:
        if len(seq) != expected_len:
            print(f"Warning: Sequence length mismatch. Expected {expected_len}, got {len(seq)}. Sequence: {seq[:10]}...{seq[-10:]}. Padding/Truncating.")
            if len(seq) > expected_len:
                # Keep the central window (the site is assumed to be centered)
                start = (len(seq) - expected_len) // 2
                seq = seq[start:start + expected_len]
            else:
                pad = expected_len - len(seq)
                seq = '-' * (pad // 2) + seq + '-' * (pad - pad // 2)

        encodings.append([char_to_int.get(char, unknown_int) for char in seq])

    if not encodings:
        print("Warning: No sequences were successfully encoded.")
        return np.empty((0, expected_len), dtype=int)

    return np.array(encodings)

# --- Model Creation Functions ---

def create_sequence_model(seq_length=33, alphabet_size=21, seq_embed_dim=21):
    """Create a CNN binary classifier over integer-encoded sequence windows.

    Args:
        seq_length: Number of residues per input window.
        alphabet_size: Embedding vocabulary size; must cover every integer code
            produced by ``prepare_sequence_data`` (20 residues + 1 unknown = 21).
        seq_embed_dim: Width of each residue embedding (previously hard-coded 21).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model with one sigmoid output.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(seq_length,)),
        # input_length is omitted: the Input layer already fixes the length,
        # and the argument is deprecated in Keras 3.
        tf.keras.layers.Embedding(alphabet_size, seq_embed_dim),
        # Treat the (residue x embedding) map as a single-channel image
        tf.keras.layers.Reshape((seq_length, seq_embed_dim, 1)),
        # 'valid' padding: the 17-row kernel shrinks the residue axis by 16
        tf.keras.layers.Conv2D(32, kernel_size=(17, 3), activation='relu', padding='valid'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    return model

def create_prot_t5_model(embedding_dim=1024):
    """Build an MLP binary classifier over per-site ProtT5 embedding vectors.

    Args:
        embedding_dim: Length of each input embedding vector.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model ending in one sigmoid unit.
    """
    layers = tf.keras.layers
    net = tf.keras.Sequential()
    net.add(layers.Input(shape=(embedding_dim,)))

    # NOTE(review): this first Dense has no activation, so it is a linear
    # projection fed straight into BatchNormalization; confirm that is intended.
    net.add(layers.Dense(256))
    net.add(layers.BatchNormalization())
    net.add(layers.Dropout(0.5))

    net.add(layers.Dense(128, activation='relu'))
    net.add(layers.BatchNormalization())
    net.add(layers.Dropout(0.3))

    net.add(layers.Dense(32, activation='relu'))
    net.add(layers.Dropout(0.2))

    net.add(layers.Dense(1, activation='sigmoid'))
    return net

def create_seq_prot_t5_model(seq_length=33, embedding_dim=1024, alphabet_size=21, seq_embed_dim=21):
    """Create a two-input model combining a sequence CNN track and a ProtT5 MLP track.

    Args:
        seq_length: Number of residues per integer-encoded sequence window.
        embedding_dim: Length of each ProtT5 embedding vector.
        alphabet_size: Embedding vocabulary size of the sequence track; must cover
            every integer code produced by ``prepare_sequence_data``.
        seq_embed_dim: Width of each residue embedding (previously hard-coded 21).

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs
        ``[sequence_input, prot_t5_input]`` and one sigmoid output.
    """
    # One L2(0.01) regularizer instance shared by all regularized layers
    regularizer = tf.keras.regularizers.l2(0.01)

    # Sequence track
    seq_input = tf.keras.layers.Input(shape=(seq_length,), name='sequence_input')
    # input_length is omitted: the Input layer already fixes the length,
    # and the argument is deprecated in Keras 3.
    x_seq = tf.keras.layers.Embedding(alphabet_size, seq_embed_dim)(seq_input)
    x_seq = tf.keras.layers.Reshape((seq_length, seq_embed_dim, 1))(x_seq)
    x_seq = tf.keras.layers.Conv2D(32, kernel_size=(17, 3), activation='relu', padding='valid')(x_seq)
    x_seq = tf.keras.layers.BatchNormalization()(x_seq)
    x_seq = tf.keras.layers.Dropout(0.4)(x_seq)
    x_seq = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x_seq)
    x_seq = tf.keras.layers.Flatten()(x_seq)
    x_seq = tf.keras.layers.Dense(32, activation='relu',
                                  kernel_regularizer=regularizer,
                                  name='seq_features')(x_seq)
    x_seq = tf.keras.layers.BatchNormalization()(x_seq)
    x_seq = tf.keras.layers.Dropout(0.4)(x_seq)

    # ProtT5 track
    prot_t5_input = tf.keras.layers.Input(shape=(embedding_dim,), name='prot_t5_input')
    # NOTE(review): no activation here (linear projection before BatchNormalization); confirm intent.
    x_prot_t5 = tf.keras.layers.Dense(256, kernel_regularizer=regularizer)(prot_t5_input)
    x_prot_t5 = tf.keras.layers.BatchNormalization()(x_prot_t5)
    x_prot_t5 = tf.keras.layers.Dropout(0.5)(x_prot_t5)
    x_prot_t5 = tf.keras.layers.Dense(128, activation='relu',
                                      kernel_regularizer=regularizer)(x_prot_t5)
    x_prot_t5 = tf.keras.layers.BatchNormalization()(x_prot_t5)
    x_prot_t5 = tf.keras.layers.Dropout(0.5)(x_prot_t5)

    # Combine features by plain concatenation (no per-track weighting)
    combined = tf.keras.layers.Concatenate()([x_seq, x_prot_t5])

    # Classification head with heavy regularization
    x = tf.keras.layers.Dense(64, activation='relu',
                              kernel_regularizer=regularizer)(combined)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(32, activation='relu',
                              kernel_regularizer=regularizer)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=[seq_input, prot_t5_input], outputs=outputs)

    return model


# --- Utility Functions ---
def create_callbacks(patience=5, reduce_lr_patience=3):
    """Return ``[EarlyStopping, ReduceLROnPlateau]``, both monitoring ``val_loss``.

    Args:
        patience: Epochs without val_loss improvement before training stops;
            the best weights seen are restored at that point.
        reduce_lr_patience: Epochs without improvement before the learning rate
            is multiplied by 0.5 (never below 1e-6).
    """
    watched = 'val_loss'

    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor=watched, patience=patience, restore_best_weights=True, verbose=1,
    )
    lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=watched, factor=0.5, patience=reduce_lr_patience, min_lr=1e-6, verbose=1,
    )

    return [early_stop, lr_schedule]

def calculate_metrics(y_true, y_pred_proba, threshold=0.5):
    """Compute standard binary-classification metrics from predicted probabilities.

    Args:
        y_true: Binary ground-truth labels (0/1), any shape that flattens to (n,).
        y_pred_proba: Predicted positive-class probabilities, e.g. the ``(n, 1)``
            output of ``model.predict``; flattened before thresholding.
        threshold: Probabilities strictly above this are predicted positive.

    Returns:
        Dict with ``acc``, ``balanced_acc``, ``mcc``, ``sn`` (sensitivity),
        ``sp`` (specificity) and ``cm`` (2x2 confusion matrix, rows = true
        label 0/1, columns = predicted label 0/1).
    """
    y_true = np.asarray(y_true).ravel()
    y_pred_binary = (np.asarray(y_pred_proba).ravel() > threshold).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in
    # y_true and y_pred; otherwise cm is 1x1 and cannot unpack into 4 values.
    cm = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    return {
        'acc': accuracy_score(y_true, y_pred_binary),
        'balanced_acc': balanced_accuracy_score(y_true, y_pred_binary),
        'mcc': matthews_corrcoef(y_true, y_pred_binary),
        'sn': tp / (tp + fn) if (tp + fn) > 0 else 0.0,
        'sp': tn / (tn + fp) if (tn + fp) > 0 else 0.0,
        'cm': cm,
    }

def print_results_summary(model_name, cv_metrics, test_metrics):
    """Print per-fold CV mean ± std and the final test-set metrics for one model.

    Args:
        model_name: Name shown in the header; also sets the footer rule width.
        cv_metrics: Mapping of metric name -> sequence of per-fold values.
        test_metrics: Mapping of metric name -> scalar, plus ``'cm'`` holding the
            confusion matrix.
    """
    keys = ('acc', 'balanced_acc', 'mcc', 'sn', 'sp')

    print(f"\n--- Results Summary for: {model_name} ---")

    print("\nAverage Cross-validation Results (Validation Set):")
    for key in keys:
        fold_values = cv_metrics[key]
        print(f"{key.upper()}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

    print("\nFinal Test Set Results (No Structure Subset):")
    for key in keys:
        print(f"{key.upper()}: {test_metrics[key]:.4f}")

    print("Confusion Matrix (Test Set):")
    print(test_metrics['cm'])

    rule_width = len(model_name) + 25
    print("-" * rule_width)


# --- Main Training and Evaluation Function ---

def train_evaluate_no_structure_models(n_splits=5, epochs=50, batch_size=32):
    """Train Sequence-Only, ProtT5-Only and Seq+ProtT5 models with K-fold CV.

    The models are trained on the full training data (all samples, with or
    without structure). In every fold each architecture is trained on the
    fold's training split, with EarlyStopping/ReduceLROnPlateau on the fold's
    validation split. It then predicts the test samples that have *no*
    structure data. Per architecture, the test predictions of the fold models
    are averaged and scored once.

    Note: the validation split both drives early stopping and is used for the
    reported CV metrics, so those CV numbers are somewhat optimistic.

    Args:
        n_splits: Number of KFold splits.
        epochs: Maximum epochs per fit (EarlyStopping usually stops earlier).
        batch_size: Mini-batch size for training.

    Returns:
        The return value depends on how far the run gets:
        - ``None`` if data could not be loaded or aligned.
        - The raw per-fold ``results`` dict if there are no structure-less test
          samples.
        - Otherwise a dict mapping model name -> test metrics from
          ``calculate_metrics``.
    """
    # Define data paths (Adjust as needed)
    BASE_DATA_PATH = '../../data'
    TRAIN_POS_PROTT5 = f'{BASE_DATA_PATH}/train/PLM/train_positive_ProtT5-XL-UniRef50.csv'
    TRAIN_NEG_PROTT5 = f'{BASE_DATA_PATH}/train/PLM/train_negative_ProtT5-XL-UniRef50.csv'
    TEST_POS_PROTT5 = f'{BASE_DATA_PATH}/test/PLM/test_positive_ProtT5-XL-UniRef50.csv'
    TEST_NEG_PROTT5 = f'{BASE_DATA_PATH}/test/PLM/test_negative_ProtT5-XL-UniRef50.csv'
    STRUCT_PATH_TRAIN = f'{BASE_DATA_PATH}/train/structure/processed_features_train.csv'
    STRUCT_PATH_TEST = f'{BASE_DATA_PATH}/test/structure/processed_features_test.csv'

    # 1. Load Data (Sequence Info + Structure Flag)
    print("Loading training data identifiers...")
    train_df_info = load_complete_data(mode='train', base_data_path=BASE_DATA_PATH, struct_path_override=STRUCT_PATH_TRAIN)
    print("\nLoading test data identifiers...")
    test_df_info = load_complete_data(mode='test', base_data_path=BASE_DATA_PATH, struct_path_override=STRUCT_PATH_TEST)

    if train_df_info.empty or test_df_info.empty:
        print("Error: Could not load necessary data. Exiting.")
        return

    # 2. Load and Align ProtT5 Embeddings
    print("\nLoading ProtT5 embeddings...")
    try:
        train_pos_dict, train_neg_dict = load_prot_t5_data(TRAIN_POS_PROTT5, TRAIN_NEG_PROTT5)
        test_pos_dict, test_neg_dict = load_prot_t5_data(TEST_POS_PROTT5, TEST_NEG_PROTT5)
    except FileNotFoundError as e:
        print(f"Error loading ProtT5 data: {e}. Make sure paths are correct.")
        return

    print("Aligning training data with ProtT5...")
    X_train_prot_t5, train_data_aligned = prepare_aligned_data(train_df_info, train_pos_dict, train_neg_dict)
    print("Aligning test data with ProtT5...")
    X_test_prot_t5, test_data_aligned = prepare_aligned_data(test_df_info, test_pos_dict, test_neg_dict)

    del train_df_info, test_df_info, train_pos_dict, train_neg_dict, test_pos_dict, test_neg_dict  # Free memory
    gc.collect()

    print(f"\nAligned training data size: {len(train_data_aligned)}")
    print(f"Aligned test data size: {len(test_data_aligned)}")

    if train_data_aligned.empty or test_data_aligned.empty:
        print("Error: No data after alignment. Check input files and keys ('entry', 'pos').")
        return

    # 3. Prepare Full Training Data
    print("\nPreparing full training data...")
    X_train_seq_all = prepare_sequence_data(train_data_aligned)
    y_train_all = train_data_aligned['label'].values

    # 20 standard residues + 1 unknown/gap code, matching prepare_sequence_data
    alphabet_size = 21
    seq_length = X_train_seq_all.shape[1] if X_train_seq_all.ndim == 2 and X_train_seq_all.shape[1] > 0 else 33
    embedding_dim = X_train_prot_t5.shape[1] if X_train_prot_t5.ndim == 2 and X_train_prot_t5.shape[1] > 0 else 1024

    print(f"Using Sequence Length: {seq_length}, Alphabet Size: {alphabet_size}, ProtT5 Dim: {embedding_dim}")

    # Shuffle training data consistently (same permutation for every array)
    shuffle_idx_all = np.random.RandomState(SEED).permutation(len(y_train_all))
    X_train_seq_all = X_train_seq_all[shuffle_idx_all]
    X_train_prot_t5 = X_train_prot_t5[shuffle_idx_all]
    y_train_all = y_train_all[shuffle_idx_all]

    print("Full training data shapes:")
    print(f"X_train_seq_all: {X_train_seq_all.shape}")
    print(f"X_train_prot_t5: {X_train_prot_t5.shape}")
    print(f"y_train_all: {y_train_all.shape}")

    # 4. Prepare Test Data (No Structure Subset Only)
    print("\nPreparing test data subset (no structure)...")
    # astype(bool) first: bitwise ~ on an object-dtype column would give ints, not a mask
    no_struct_mask_test = ~test_data_aligned['has_structure'].astype(bool).to_numpy()
    test_data_no_struct = test_data_aligned[no_struct_mask_test].reset_index(drop=True)

    can_evaluate = not test_data_no_struct.empty
    if not can_evaluate:
        print("Warning: No test samples found without structure data. Cannot perform evaluation on this subset.")
    else:
        X_test_seq_no_struct = prepare_sequence_data(test_data_no_struct)
        # Rows of X_test_prot_t5 are aligned with test_data_aligned, so the same mask applies
        X_test_prot_t5_no_struct = X_test_prot_t5[no_struct_mask_test]
        y_test_no_struct = test_data_no_struct['label'].values

        print("Test data subset (no structure) shapes:")
        print(f"X_test_seq_no_struct: {X_test_seq_no_struct.shape}")
        print(f"X_test_prot_t5_no_struct: {X_test_prot_t5_no_struct.shape}")
        print(f"y_test_no_struct: {y_test_no_struct.shape}")
        print(f"Number of test samples without structure: {len(y_test_no_struct)}")
        print(f"Positive samples in this subset: {np.sum(y_test_no_struct == 1)}")
        print(f"Negative samples in this subset: {np.sum(y_test_no_struct == 0)}")

    del test_data_aligned, X_test_prot_t5  # Free memory
    gc.collect()

    # 5. Calculate Class Weights (Based on Full Training Set)
    total_samples_all = len(y_train_all)
    pos_samples_all = np.sum(y_train_all == 1)
    neg_samples_all = np.sum(y_train_all == 0)

    if pos_samples_all == 0 or neg_samples_all == 0:
        print("Warning: Training data contains only one class. Class weights cannot be computed effectively.")
        class_weights_all = None
    else:
        class_weights_all = {
            0: total_samples_all / (2 * neg_samples_all),
            1: total_samples_all / (2 * pos_samples_all)
        }
    print("\nClass weights (full training set):", class_weights_all)

    # 6. K-Fold Cross-Validation and Evaluation
    metric_names = ['acc', 'balanced_acc', 'mcc', 'sn', 'sp']

    # model name -> (builder, input selector). The selector maps the
    # (sequence, ProtT5) arrays to the input the model expects, so one
    # compile/fit/predict path serves all three architectures.
    model_specs = {
        'Sequence Only': (
            lambda: create_sequence_model(seq_length=seq_length, alphabet_size=alphabet_size),
            lambda seq, p5: seq,
        ),
        'ProtT5 Only': (
            lambda: create_prot_t5_model(embedding_dim=embedding_dim),
            lambda seq, p5: p5,
        ),
        'Sequence + ProtT5': (
            lambda: create_seq_prot_t5_model(seq_length=seq_length, embedding_dim=embedding_dim,
                                             alphabet_size=alphabet_size),
            lambda seq, p5: [seq, p5],
        ),
    }
    model_types = list(model_specs)
    results = {model_name: {'cv_metrics': {m: [] for m in metric_names},
                            'test_preds': []}
               for model_name in model_types}

    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    for fold, (train_idx, val_idx) in enumerate(kfold.split(X_train_seq_all), 1):
        print(f"\n===== FOLD {fold}/{n_splits} =====")

        # --- Split Data for this Fold (once, shared by all models) ---
        X_train_seq_fold, X_val_seq_fold = X_train_seq_all[train_idx], X_train_seq_all[val_idx]
        X_train_p5_fold, X_val_p5_fold = X_train_prot_t5[train_idx], X_train_prot_t5[val_idx]
        y_train_fold, y_val_fold = y_train_all[train_idx], y_train_all[val_idx]

        for model_name, (build_model, select_inputs) in model_specs.items():
            print(f"\n--- Training {model_name} ---")
            tf.keras.backend.clear_session()
            gc.collect()

            model = build_model()
            model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                          loss='binary_crossentropy', metrics=['accuracy'])

            val_inputs = select_inputs(X_val_seq_fold, X_val_p5_fold)
            model.fit(select_inputs(X_train_seq_fold, X_train_p5_fold), y_train_fold,
                      validation_data=(val_inputs, y_val_fold),
                      epochs=epochs, batch_size=batch_size,
                      callbacks=create_callbacks(), class_weight=class_weights_all,
                      verbose=0)

            y_pred_val_proba = model.predict(val_inputs).flatten()
            if can_evaluate:
                test_preds_fold = model.predict(select_inputs(X_test_seq_no_struct, X_test_prot_t5_no_struct))
                results[model_name]['test_preds'].append(test_preds_fold.flatten())
                del test_preds_fold

            # --- Store Validation Metrics for this fold ---
            val_metrics = calculate_metrics(y_val_fold, y_pred_val_proba)
            print(f"Validation Metrics (Fold {fold}) - {model_name}: ",
                  f"Acc: {val_metrics['acc']:.4f}, BAcc: {val_metrics['balanced_acc']:.4f}, MCC: {val_metrics['mcc']:.4f}")
            for metric_name in metric_names:
                results[model_name]['cv_metrics'][metric_name].append(val_metrics[metric_name])

            del model, y_pred_val_proba, val_metrics  # Clean up memory
            gc.collect()

    # 7. Final Evaluation on Test Set (No Structure Subset)
    print("\n===== FINAL EVALUATION =====")

    if not can_evaluate:
        print("Skipping final evaluation because no test samples without structure were found.")
        return results  # Return partial (per-fold) results

    final_results = {}
    for model_name in model_types:
        if not results[model_name]['test_preds']:
            print(f"Warning: No test predictions recorded for {model_name}. Skipping final evaluation.")
            continue

        # Average the fold models' predictions (simple per-architecture ensemble)
        avg_test_preds = np.mean(results[model_name]['test_preds'], axis=0)

        test_metrics = calculate_metrics(y_test_no_struct, avg_test_preds)
        final_results[model_name] = test_metrics

        print_results_summary(model_name, results[model_name]['cv_metrics'], test_metrics)

    return final_results


if __name__ == "__main__":
    # Experiment configuration. EPOCHS is only an upper bound: the
    # EarlyStopping callback normally ends each fit sooner.
    N_SPLITS, EPOCHS, BATCH_SIZE = 5, 50, 32

    final_model_results = train_evaluate_no_structure_models(
        n_splits=N_SPLITS, epochs=EPOCHS, batch_size=BATCH_SIZE,
    )

    # When evaluation succeeds, per-model test metrics are available, e.g.
    # final_model_results['Sequence Only']['mcc']
    print("\nScript finished.")

2025-04-13 11:50:21.922262: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-13 11:50:23.080799: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib/x86_64-linux-gnu/:
2025-04-13 11:50:23.080913: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib/x86_64-linux-gnu/:


Loading training data identifiers...
Loading structure index from: ../data/processed_features_fixed_train_contactmap.csv...

Processing positive data...
Processing negative data...


Dataset statistics:
Total entries loaded: 9500
Positive examples: 4750
Negative examples: 4750
Entries marked as having structure: 8853
Unique proteins (entries): 2193

Loading test data identifiers...
Loading structure index from: ../data/processed_features_fixed_test_contactmap.csv...

Processing positive data...
Processing negative data...

Dataset statistics:
Total entries loaded: 3224
Positive examples: 253
Negative examples: 2971
Entries marked as having structure: 2737
Unique proteins (entries): 123

Loading ProtT5 embeddings...
Aligning training data with ProtT5...
Aligning test data with ProtT5...

Aligned training data size: 9500
Aligned test data size: 3224

Preparing full training data...
Using Sequence Length: 33, Alphabet Size: 21, ProtT5 Dim: 1024
Full training data shapes:
X_train_seq_all: 

2025-04-13 11:50:33.822551: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:981] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2025-04-13 11:50:33.863091: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib/x86_64-linux-gnu/:
2025-04-13 11:50:33.863131: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2025-04-13 11:50:33.864352: I tensorflow/core/platform/cpu_feature_guard.cc:193] This Tensor


Epoch 9: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

Epoch 13: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Restoring model weights from the end of the best epoch: 10.
Epoch 15: early stopping
Validation Metrics (Fold 1) - Sequence Only:  Acc: 0.7489, BAcc: 0.7483, MCC: 0.4996

--- Training ProtT5 Only ---

Epoch 7: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
Restoring model weights from the end of the best epoch: 4.
Epoch 9: early stopping
Validation Metrics (Fold 1) - ProtT5 Only:  Acc: 0.7411, BAcc: 0.7402, MCC: 0.4852

--- Training Sequence + ProtT5 ---

Epoch 13: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

Epoch 23: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.

Epoch 28: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
Restoring model weights from the end of the best epoch: 25.
Epoch 30: early stopping
Validation Metrics (Fold 1) - Sequence + ProtT5