# Test Set Run

This notebook loads the pre-trained transcription factor (TF) binding prediction model (a TorchScript file), runs inference on the held-out test sequence dataset (e.g. `chr22_sequences.txt.gz`), loads the corresponding ground-truth scores (e.g. `chr22_scores.txt.gz`), and computes Pearson's r between the predictions and the ground truth.

## 1. Library Setup

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import os
import numpy as np
from scipy.stats import pearsonr

## 2. Configuration: Update the directory and file names for the held-out test dataset

In [None]:
# Folder containing the held-out test files (relative to the notebook's
# working directory).
DATA_DIR = '../data'

# Gzipped, tab-separated sequence and score files inside DATA_DIR.
SEQ_FILE, SCORE_FILE = (
    os.path.join(DATA_DIR, file_name)
    for file_name in ('chr22_sequences.txt.gz', 'chr22_scores.txt.gz')
)

# Serialized TorchScript model, resolved relative to the working directory.
MODEL_WEIGHTS_FILE = 'lee-inhyeok-model.pth'

## 3. Helper Functions

In [None]:
def one_hot_encode(sequence):
    """Return a (4, len(sequence)) float32 one-hot matrix for a DNA string.

    Rows are ordered A, C, G, T and matching is case-insensitive. Any other
    character (e.g. 'N') leaves its column all zeros.
    """
    row_of = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    one_hot = torch.zeros((4, len(sequence)), dtype=torch.float32)

    # Gather (row, column) coordinates of recognised bases, then set them all
    # with a single advanced-indexing assignment.
    hits = [
        (row_of[base], position)
        for position, base in enumerate(sequence.upper())
        if base in row_of
    ]
    if hits:
        rows, columns = zip(*hits)
        one_hot[list(rows), list(columns)] = 1.0
    return one_hot

def calculate_pearsonr(preds, targets):
    """Compute Pearson's r between predictions and targets over all elements.

    Both inputs are flattened and paired element-wise. Positions where either
    value is NaN or infinite are dropped before the correlation is computed.
    The p-value reported by SciPy is printed as a side effect.

    Args:
        preds: Predicted values as a torch.Tensor, NumPy array, or any other
            array-like (e.g. nested lists). Tensors are detached and copied to
            the CPU first.
        targets: Ground-truth values in any of the same forms. Must contain the
            same total number of elements as ``preds``.

    Returns:
        Pearson's r as a float. Returns 0.0 when fewer than two finite pairs
        remain, when r is not finite (e.g. constant input), or when SciPy
        raises a ValueError.

    Raises:
        ValueError: If ``preds`` and ``targets`` hold different numbers of
            elements.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu().numpy()

    # np.ravel accepts any array-like, so plain lists work as well as arrays.
    preds_flat = np.ravel(preds)
    targets_flat = np.ravel(targets)

    # Without this check a size mismatch surfaces later as an opaque
    # broadcasting ValueError, or as an IndexError when one side has size 1.
    if preds_flat.size != targets_flat.size:
        raise ValueError(
            "preds and targets must contain the same number of elements "
            f"(got {preds_flat.size} and {targets_flat.size})."
        )

    valid_indices = np.isfinite(preds_flat) & np.isfinite(targets_flat)
    preds_flat = preds_flat[valid_indices]
    targets_flat = targets_flat[valid_indices]

    if len(preds_flat) < 2:
        print("Warning: Not enough valid data points to calculate Pearson R.")
        return 0.0

    try:
        r, p_value = pearsonr(preds_flat, targets_flat)
    except ValueError as e:
        print(f"Error calculating Pearson R: {e}")
        return 0.0

    print(f"Pearson R p-value: {p_value}")
    # r is NaN when either input is constant; report that as no correlation.
    return float(r) if np.isfinite(r) else 0.0

## 4. Load Model and Weights

In [None]:
print("Loading TorchScript model...")
model = None
model_loaded_successfully = False

# Prefer Apple-Silicon MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device (Apple Silicon GPU).")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device.")
else:
    device = torch.device("cpu")
    print("Using CPU device.")

try:
    # torch.jit.load reports a missing path as ValueError rather than
    # FileNotFoundError. Check explicitly so the dedicated "not found"
    # message below is actually reachable.
    if not os.path.isfile(MODEL_WEIGHTS_FILE):
        raise FileNotFoundError(MODEL_WEIGHTS_FILE)
    model = torch.jit.load(MODEL_WEIGHTS_FILE, map_location=device)
    model.to(device)
    model.eval()  # switch to evaluation mode before inference
    print(f"Successfully loaded TorchScript model from {MODEL_WEIGHTS_FILE}")
    model_loaded_successfully = True
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_WEIGHTS_FILE}. Cannot proceed.")
except Exception as e:
    print(f"Error loading TorchScript model: {e}")

if not model_loaded_successfully:
    print("Model loading failed. Cannot proceed.")

## 5. Load Sequence and Score Data

In [None]:
print("Loading sequence and score data...")
sequences_df = None
scores_df = None
data_loaded_successfully = False
if model_loaded_successfully:
    try:
        # Both files are gzipped, tab-separated tables. The scores file is
        # read as one column per sequence.
        sequences_df = pd.read_csv(SEQ_FILE, sep="\t", compression='gzip')
        scores_df = pd.read_csv(SCORE_FILE, sep="\t", compression='gzip')
        print(f"Loaded {len(sequences_df)} sequences and {len(scores_df.columns)} score vectors.")

        # Validate the schema first, so a missing column is reported even when
        # the sequence and score counts also disagree.
        if 'sequence' not in sequences_df.columns:
            print("Warning: 'sequence' column not found in sequences file.")
            raise ValueError("'sequence' column missing from sequence file.")

        if len(sequences_df) == 0 or len(scores_df.columns) == 0:
            raise ValueError("No sequence/score pairs available to evaluate.")

        if len(sequences_df) != len(scores_df.columns):
            print(f"Warning: Number of sequences ({len(sequences_df)}) does not match number of score columns ({len(scores_df.columns)}). Ensure they correspond correctly.")
            # Pairs are matched by position: sequence row i <-> score column i.
            min_len = min(len(sequences_df), len(scores_df.columns))
            sequences_df = sequences_df.iloc[:min_len]
            scores_df = scores_df.iloc[:, :min_len]
            print(f"Proceeding with {min_len} sequence/score pairs.")

        # Reached on both the matched and the truncated path. Previously the
        # truncated path never set this flag, so later cells silently skipped.
        data_loaded_successfully = True

    except FileNotFoundError as e:
        print(f"Error: Data file not found. Ensure '{e.filename}' exists in '{DATA_DIR}'.")
    except Exception as e:
        print(f"Error loading or processing data: {e}")

if data_loaded_successfully:
    print("Data loaded successfully.")
else:
    print("Data loading failed or was skipped. Cannot proceed with analysis.")

## 6. Prepare Input and Ground Truth Data

In [None]:
print("Preparing input and ground truth data...")
inputs_tensor = None
ground_truth_tensor = None

if not data_loaded_successfully:
    print("Skipping data preparation due to previous loading errors.")
else:
    try:
        # One (4, L) one-hot matrix per sequence. torch.stack below requires
        # every sequence to have the same length L.
        all_sequences = sequences_df['sequence'].tolist()
        all_inputs_list = list(map(one_hot_encode, all_sequences))

        # One float32 vector per score column, in column order. This assumes
        # column i holds the scores for sequence row i.
        all_score_tensors = [
            torch.tensor(scores_df[column_name].values, dtype=torch.float32)
            for column_name in scores_df.columns
        ]

        inputs_tensor = torch.stack(all_inputs_list).to(device)
        ground_truth_tensor = torch.stack(all_score_tensors).to(device)

        print(
            f"Prepared {len(all_inputs_list)} inputs and corresponding ground truth scores.",
            f"Input tensor shape: {inputs_tensor.shape}",
            f"Ground truth tensor shape: {ground_truth_tensor.shape}",
            sep="\n",
        )

    except Exception as e:
        print(f"Error during data preparation: {e}")
        # Reset everything so downstream cells see a consistent "failed" state.
        inputs_tensor = None
        ground_truth_tensor = None
        data_loaded_successfully = False

## 7. Predict on the Test Dataset

In [None]:
print("Performing inference on the test dataset...")
predictions = None
# Number of sequences per forward pass. Pushing the whole test set through the
# model in a single call can exhaust GPU/MPS memory on large inputs; lower this
# if inference still runs out of memory.
INFERENCE_BATCH_SIZE = 256
if model_loaded_successfully and data_loaded_successfully and inputs_tensor is not None:
    try:
        with torch.no_grad():
            batch_predictions = [
                model(inputs_tensor[start:start + INFERENCE_BATCH_SIZE].to(device))
                for start in range(0, inputs_tensor.shape[0], INFERENCE_BATCH_SIZE)
            ]
        # Batches are concatenated in input order, so row i still corresponds
        # to input i. Assumes the model output keeps the batch on dim 0 even
        # for a size-1 batch -- TODO confirm for this model.
        predictions = torch.cat(batch_predictions, dim=0)
        print(f"Inference complete. Prediction tensor shape: {predictions.shape}")
    except Exception as e:
        print(f"Error during inference: {e}")
        predictions = None
else:
    print("Skipping inference due to missing model, data, or prepared input tensor.")

## 8. Calculate Pearson's R Correlation

In [None]:
print("Calculating Pearson's R correlation...")
# Compare predictions with ground truth only if both exist and line up
# element-for-element.
if predictions is None or ground_truth_tensor is None:
    print("Skipping correlation calculation due to missing predictions or ground truth data, or errors during inference.")
elif predictions.shape != ground_truth_tensor.shape:
    print(f"Error: Shape mismatch between predictions ({predictions.shape}) and ground truth ({ground_truth_tensor.shape}). Cannot calculate correlation.")
    print("Please check data loading and preparation steps.")
else:
    try:
        overall_pearson_r = calculate_pearsonr(predictions, ground_truth_tensor)
        print(f"\nOverall Pearson's R score: {overall_pearson_r:.4f}")
    except Exception as e:
        print(f"Error calculating final Pearson R: {e}")