In [None]:
# format_dataset.py

import os
import json
import pickle
import numpy as np
from tqdm import tqdm

# Import config to get paths to the source files
import config

def _stratified_protein_subset(protein_data_list, fraction, rng):
    """Keep ``fraction`` of proteins, sampled separately with/without epitopes.

    Proteins are split into two strata by whether any residue in
    ``df_stats['is_epitope']`` is set, and each stratum is shuffled and
    truncated to ``int(len(stratum) * fraction)`` proteins. At least one
    epitope-bearing protein is always kept when any exist, so the subset
    still contains positive labels.

    Args:
        protein_data_list: Protein records as loaded from the structured
            pickle (each with 'pdb_id' and 'df_stats' entries).
        fraction: Share of each stratum to keep, expected in (0, 1).
        rng: ``numpy.random.Generator`` used for the shuffles; a seeded
            generator makes the subset reproducible.

    Returns:
        The filtered list of protein records, in their original order.
    """
    proteins_with_positives = []
    proteins_without_positives = []
    for p in protein_data_list:
        if np.any(p['df_stats']['is_epitope'].values):
            proteins_with_positives.append(p['pdb_id'])
        else:
            proteins_without_positives.append(p['pdb_id'])
    print(f"Found {len(proteins_with_positives)} proteins with epitopes and {len(proteins_without_positives)} without.")

    rng.shuffle(proteins_with_positives)
    rng.shuffle(proteins_without_positives)
    num_pos_to_take = int(len(proteins_with_positives) * fraction)
    num_neg_only_to_take = int(len(proteins_without_positives) * fraction)
    # Never let rounding down remove the positive class entirely.
    if proteins_with_positives and num_pos_to_take == 0:
        num_pos_to_take = 1
    selected_pos_proteins = proteins_with_positives[:num_pos_to_take]
    selected_neg_only_proteins = proteins_without_positives[:num_neg_only_to_take]
    selected_protein_ids = set(selected_pos_proteins + selected_neg_only_proteins)
    print(f"Selected {len(selected_pos_proteins)} proteins with epitopes and {len(selected_neg_only_proteins)} without.")
    print(f"Total proteins to use: {len(selected_protein_ids)} out of {len(protein_data_list)}.")
    return [p for p in protein_data_list if p['pdb_id'] in selected_protein_ids]


def _flatten_proteins(protein_data_list):
    """Concatenate per-protein arrays into flat residue-level arrays.

    Args:
        protein_data_list: Non-empty list of protein records with 'X_arr',
            'df_stats', 'length' and 'pdb_id' entries. ``X_arr`` is assumed
            to be 2-D (one row per residue) — it is combined with np.vstack.

    Returns:
        Tuple ``(X, y, groups)``: the stacked feature matrix, the concatenated
        ``is_epitope`` labels, and a per-residue array of the owning
        protein's ``pdb_id``.

    Raises:
        ValueError: If a protein's feature row count, label count and
            declared ``length`` disagree; otherwise the residue -> protein
            mapping used for splitting would be misaligned.
    """
    features, labels, groups = [], [], []
    for protein_data in tqdm(protein_data_list, desc="Processing proteins"):
        x_arr = protein_data['X_arr']
        is_epitope = protein_data['df_stats']['is_epitope'].values
        n_rows = len(x_arr)
        if not (n_rows == len(is_epitope) == protein_data['length']):
            raise ValueError(
                f"Protein {protein_data['pdb_id']!r}: X_arr has {n_rows} rows, "
                f"is_epitope has {len(is_epitope)} labels, "
                f"but 'length' is {protein_data['length']}."
            )
        features.append(x_arr)
        labels.append(is_epitope)
        groups.append(np.full(n_rows, protein_data['pdb_id']))
    return np.vstack(features), np.concatenate(labels), np.concatenate(groups)


def format_epitope_dataset(
    data_percentage=1.0,
    *,
    dataset_name='epitope_prediction',
    output_root='output',
    structured_data_path=None,
    splits_path=None,
    seed=42,
):
    """
    Reformat the structured .pkl dataset into per-split .npy files + info.json.

    Reads the per-protein records from the structured pickle and the
    protein-level train/val/test assignment from the splits JSON, flattens
    the proteins into residue-level arrays, partitions them by split, and
    writes ``X_num_{train,val,test}.npy``, ``y_{train,val,test}.npy`` and
    ``info.json`` into ``<output_root>/<dataset_name>``.

    Args:
        data_percentage: Fraction in (0, 1] of proteins to keep. Values below
            1.0 draw a subset stratified by whether a protein has any
            epitope residue.
        dataset_name: Output directory name and the ``name`` in info.json.
        output_root: Parent directory of the dataset directory.
        structured_data_path: Pickle path; defaults to
            ``config.STRUCTURED_DATA_PATH``.
        splits_path: Splits JSON path (mapping 'train', 'val', 'test' to
            lists of pdb_ids); defaults to ``config.SPLITS_FILE_PATH``.
        seed: Seed for the subsetting RNG, so the same subset is drawn on
            every run (pass None for a fresh random draw). Only used when
            ``data_percentage < 1.0``.

    Returns:
        The info.json contents as a dict on success, or None when a
        precondition fails (an error message is printed in that case).

    Raises:
        ValueError: If a protein's feature rows, label count and declared
            length disagree.
    """
    # Resolve config lazily so explicit arguments work without it.
    if structured_data_path is None:
        structured_data_path = config.STRUCTURED_DATA_PATH
    if splits_path is None:
        splits_path = config.SPLITS_FILE_PATH
    target_dir = os.path.join(output_root, dataset_name)

    print("--- Starting Dataset Formatting ---")

    if not 0.0 < data_percentage <= 1.0:
        print(f"Error: data_percentage must be in (0, 1], got {data_percentage}.")
        return None

    # --- 1. Check for source files ---
    if not os.path.exists(structured_data_path):
        print(f"Error: Source file not found at '{structured_data_path}'")
        print("Please run the feature engineering pipeline first.")
        return None
    if not os.path.exists(splits_path):
        print(f"Error: Splits file not found at '{splits_path}'")
        print("Please run the sequence clustering step in the pipeline first.")
        return None

    # --- 2. Load the source data and splits ---
    print(f"Loading structured data from: {structured_data_path}")
    # pickle.load can execute arbitrary code: only load trusted files
    # produced by this project's own pipeline.
    with open(structured_data_path, 'rb') as f:
        protein_data_list = pickle.load(f)

    print(f"Loading data splits from: {splits_path}")
    with open(splits_path, 'r', encoding='utf-8') as f:
        splits = json.load(f)

    # --- 3. Optionally subset the proteins (stratified, reproducible) ---
    if data_percentage < 1.0:
        print(f"\nSubsetting the data to use {data_percentage * 100}% of proteins (stratified).")
        rng = np.random.default_rng(seed)
        protein_data_list = _stratified_protein_subset(protein_data_list, data_percentage, rng)

    # --- 4. Reconstruct the full, flat dataset arrays ---
    print("\nReconstructing flat data arrays from the list of proteins...")
    if not protein_data_list:
        print("Error: No data left after subsetting. Check your percentage or source data.")
        return None
    X, y, groups = _flatten_proteins(protein_data_list)
    print(f"Full dataset reconstructed. Total residues (samples): {len(y)}, Features: {X.shape[1]}")

    # --- 5. Partition the data based on the pre-computed splits ---
    print("Applying train/validation/test splits...")
    masks = {name: np.isin(groups, splits[name]) for name in ('train', 'val', 'test')}

    # Number of splits each residue's protein belongs to: 0 means the
    # residue is silently dropped, >1 means the protein leaks across splits.
    membership = sum(mask.astype(np.int8) for mask in masks.values())
    checks = (
        (membership == 0, "are not in any split and will be dropped"),
        (membership > 1, "appear in more than one split (data leakage)"),
    )
    for bad, problem in checks:
        if bad.any():
            print(f"Warning: {int(bad.sum())} residues from {len(np.unique(groups[bad]))} proteins {problem}.")

    print("\nSplit sizes:")
    partitions = {}
    for name, label in (('train', 'Train'), ('val', 'Validation'), ('test', 'Test')):
        mask = masks[name]
        X_part, y_part = X[mask], y[mask]
        partitions[name] = (X_part, y_part)
        print(f"  - {label}: {len(y_part)} residues from {len(np.unique(groups[mask]))} proteins. Shape: {X_part.shape}")

    # --- 6. Save the data into the target directory structure ---
    print(f"\nSaving formatted data to directory: '{target_dir}'")
    os.makedirs(target_dir, exist_ok=True)
    # All features are numerical, hence the 'X_num' prefix and no categorical files.
    for name, (X_part, y_part) in partitions.items():
        np.save(os.path.join(target_dir, f'X_num_{name}.npy'), X_part.astype(np.float32))
        np.save(os.path.join(target_dir, f'y_{name}.npy'), y_part)

    # --- 7. Create and save the info.json metadata file ---
    info = {
        'name': dataset_name,
        'task_type': 'binclass',  # binary classification: epitope vs. not
        'train_size': int(len(partitions['train'][1])),
        'val_size': int(len(partitions['val'][1])),
        'test_size': int(len(partitions['test'][1])),
        'n_num_features': int(X.shape[1]),
        'n_cat_features': 0,
    }

    info_path = os.path.join(target_dir, 'info.json')
    with open(info_path, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=4)
    print(f"Metadata saved to: {info_path}")

    print("\n--- Formatting Complete! ---")
    return info


if __name__ == "__main__":
    # Script entry point: run the formatter with its default settings.
    format_epitope_dataset()

--- Starting Dataset Formatting ---
Loading structured data from: output/structured_protein_data.pkl
Loading data splits from: output/data_splits.json

Subsetting the data to use 25.0% of proteins (stratified).
Found 5071 proteins with epitopes and 2 without.
Selected 1267 proteins with epitopes and 0 without.
Total proteins to use: 1267 out of 5073.

Reconstructing flat data arrays from the list of proteins...


Processing proteins: 100%|██████████| 1267/1267 [00:00<00:00, 93988.14it/s]

Full dataset reconstructed. Total residues (samples): 475294, Features: 24
Applying train/validation/test splits...

Split sizes:
  - Train: 198906 residues from 600 proteins. Shape: (198906, 24)
  - Validation: 26865 residues from 80 proteins. Shape: (26865, 24)
  - Test: 55533 residues from 180 proteins. Shape: (55533, 24)

Saving formatted data to directory: 'output/epitope_prediction'
Metadata saved to: output/epitope_prediction/info.json

--- Formatting Complete! ---



