## 1. Setup and Imports

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import logging
from collections import Counter

# Locate the repository root so `src` is importable whether the notebook is
# launched from the repo root or from notebooks/. When the current directory
# has no `src` entry, its parent is used without further checks.
repo_root = Path.cwd()
if not (repo_root / 'src').exists():
    repo_root = repo_root.parent

# Put the repo root first on sys.path exactly once. Dropping any existing
# occurrence before inserting keeps re-runs of this cell from stacking
# duplicate entries while preserving the "repo root wins" import precedence.
repo_root_str = str(repo_root)
sys.path[:] = [entry for entry in sys.path if entry != repo_root_str]
sys.path.insert(0, repo_root_str)

# seaborn is optional: apply its 'whitegrid' style when the import succeeds,
# otherwise bind `sns` to None.
try:
    import seaborn as sns
except ImportError:
    sns = None
else:
    sns.set_style('whitegrid')

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from src.data_acquisition import KaggleDataAcquisition
from src.datasets import CAFA5Dataset
from src.models.multimodal import VibroStructuralModel
from src.models.losses import FocalLoss, WeightedBCELoss
from src.training import Trainer, MetricComputer
from src.utils import Logger, set_seed, get_device, batch_collate_function

# Notebook-wide logger, seed and device, all obtained from the src.utils helpers.
logger = Logger.setup('QDD-CAFA5', level=logging.INFO)
set_seed(42)
device = get_device()

# Default figure size for every plot in this notebook.
plt.rcParams.update({'figure.figsize': (14, 6)})

# Working directories, all anchored at the repository root.
checkpoints_dir = repo_root / 'checkpoints'
data_cafa_dir = repo_root / 'data' / 'cafa5'
spectral_dir, structures_dir = (
    data_cafa_dir / subdir for subdir in ('spectral', 'structures')
)

# Create any directory that does not exist yet (parents included).
for required_dir in (data_cafa_dir, spectral_dir, structures_dir, checkpoints_dir):
    required_dir.mkdir(parents=True, exist_ok=True)

logger.info('Setup complete!')


## 2. Download and Explore CAFA 5 Data

In [None]:
# Download CAFA 5 data
logger.info("Downloading CAFA 5 competition data...")
kaggle_acq = KaggleDataAcquisition(output_dir=str(data_cafa_dir))

# The download itself is opt-in. Uncomment to fetch the competition files:
# train_terms, test_seqs, go_annot = kaggle_acq.download_cafa5()

# Paths of the data files the rest of the notebook reads or writes.
train_terms_file, train_seqs_file, test_seqs_file, go_vocab_file = (
    data_cafa_dir / filename
    for filename in (
        'train_terms.csv',
        'train_sequences.fasta',
        'test_sequences.fasta',
        'go_vocabulary.csv',
    )
)

# Only the training-terms file is probed to decide whether the data is present.
if train_terms_file.exists():
    logger.info("CAFA 5 data found!")
else:
    logger.warning("CAFA 5 data not found. Please run: kaggle competitions download -c cafa-5-protein-function-prediction -p ./data/cafa5")

## 3. Load and Explore GO Annotations

In [None]:
# Load training terms and normalize the columns to protein_id/go_term, the
# names every later cell of this notebook indexes by.

# Known (protein column, term column) spellings, tried in this order.
column_aliases = [
    ('target_id', 'go_id'),
    ('EntryID', 'term'),
    ('entry_id', 'term'),
]

# Only the file read is guarded, so unrelated errors are never mistaken for a
# missing file.
try:
    df_terms = pd.read_csv(train_terms_file)
except FileNotFoundError:
    logger.warning("train_terms.csv not found. Creating a demo dataset so the notebook can run end-to-end.")
    # One random GO-style term per synthetic protein; written to disk so the
    # dataset cell can read the same file.
    df_terms = pd.DataFrame({
        'protein_id': [f'protein_{i}' for i in range(500)],
        'go_term': [f'GO:{np.random.randint(1000000, 9999999):07d}' for _ in range(500)]
    })
    df_terms.to_csv(train_terms_file, index=False)
    logger.info(f"Demo dataset created: {df_terms.shape[0]} annotations")
else:
    # Rename the first matching alias pair; frames that already use
    # protein_id/go_term pass through untouched.
    for protein_col, term_col in column_aliases:
        if {protein_col, term_col}.issubset(df_terms.columns):
            df_terms = df_terms.rename(columns={protein_col: 'protein_id', term_col: 'go_term'})
            break

    logger.info(f"Training terms shape: {df_terms.shape}")
    logger.info("\nFirst few rows:")
    print(df_terms.head(10))

    logger.info("\nData info:")
    # DataFrame.info() prints its report itself and returns None, so it must
    # not be wrapped in print() (that emitted a stray "None").
    df_terms.info()

# Fail here, with a readable message, if the schema was not recognised; later
# cells would otherwise raise a bare KeyError on the first column access.
missing_columns = {'protein_id', 'go_term'} - set(df_terms.columns)
if missing_columns:
    raise KeyError(
        f"{train_terms_file} lacks required column(s) {sorted(missing_columns)}; "
        f"found columns: {list(df_terms.columns)}"
    )

# Ensure we have sequence FASTA files for training/inference.
# Placeholder sequence used for every demo record (the 20 standard amino-acid letters).
demo_sequence = "ACDEFGHIKLMNPQRSTVWY"

if not train_seqs_file.exists():
    # Demo training FASTA: the first 200 protein ids in sorted order.
    proteins = sorted(df_terms['protein_id'].unique())[:200]
    with open(train_seqs_file, 'w', encoding='utf-8') as f:
        f.writelines(f">{pid}\n{demo_sequence}\n" for pid in proteins)
    logger.info(f"Wrote demo train_sequences.fasta with {len(proteins)} proteins")

if not test_seqs_file.exists():
    with open(test_seqs_file, 'w', encoding='utf-8') as f:
        f.writelines(f">test_{i}\n{demo_sequence}\n" for i in range(50))
    logger.info("Wrote demo test_sequences.fasta with 50 proteins")


In [None]:
# Analyze GO term distribution
logger.info("Analyzing GO term distribution...")

# Headline counts: annotation rows, distinct proteins and distinct GO terms.
n_annotations = len(df_terms)
n_proteins = df_terms['protein_id'].nunique()
n_go_terms = df_terms['go_term'].nunique()

for label, value in (
    ("Unique proteins", n_proteins),
    ("Unique GO terms", n_go_terms),
    ("Total annotations", n_annotations),
):
    logger.info(f"{label}: {value:,}")
logger.info(f"Avg terms per protein: {n_annotations / n_proteins:.2f}")

# Number of annotation rows per GO term, and its extremes/median.
term_counts = df_terms['go_term'].value_counts()
logger.info("\nGO term frequency:")
for stat_label, stat_value in (
    ("Max", term_counts.max()),
    ("Min", term_counts.min()),
    ("Median", term_counts.median()),
):
    logger.info(f"  {stat_label}: {stat_value}")

In [None]:
# Visualize GO distribution: most frequent terms (left) and number of terms
# per protein (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_terms, ax_hist = axes

# Left panel: horizontal bars for the (up to) 15 most frequent GO terms.
top_go = df_terms['go_term'].value_counts().head(15)
bar_positions = range(len(top_go))
ax_terms.barh(bar_positions, top_go.values, color='steelblue')
ax_terms.set_yticks(bar_positions)
ax_terms.set_yticklabels(top_go.index, fontsize=9)
ax_terms.set_title('Top 15 Most Frequent GO Terms', fontsize=12, fontweight='bold')
ax_terms.set_xlabel('Frequency', fontsize=11)
ax_terms.grid(axis='x', alpha=0.3)

# Right panel: histogram of how many annotation rows each protein has.
terms_per_protein = df_terms.groupby('protein_id').size()
ax_hist.hist(terms_per_protein, bins=30, color='coral', edgecolor='black', alpha=0.7)
ax_hist.set_title('GO Terms per Protein Distribution', fontsize=12, fontweight='bold')
ax_hist.set_xlabel('Number of GO Terms', fontsize=11)
ax_hist.set_ylabel('Number of Proteins', fontsize=11)
ax_hist.grid(alpha=0.3)

plt.tight_layout()
plt.show()

logger.info("GO analysis complete!")

## 4. Create GO Term Vocabulary

In [None]:
# Build GO vocabulary
logger.info("Building GO term vocabulary...")

# Sorting the distinct terms gives each GO term a stable integer index; the
# two dicts are the forward and reverse views of that numbering.
unique_go_terms = sorted(df_terms['go_term'].unique())
idx_to_go = dict(enumerate(unique_go_terms))
go_to_idx = {go: idx for idx, go in idx_to_go.items()}

logger.info(f"Vocabulary size: {len(go_to_idx)} GO terms")
logger.info("\nSample terms:")
# Show the first ten terms with their index and annotation count.
for idx, go_term in enumerate(unique_go_terms[:10]):
    count = (df_terms['go_term'] == go_term).sum()
    logger.info(f"  {idx}: {go_term} (n={count})")

# Save vocabulary (go_vocab_file is repo_root-relative, so this is safe when
# executed from notebooks/).
vocab_df = pd.DataFrame({
    'go_term': unique_go_terms,
    'index': list(range(len(unique_go_terms))),
})
vocab_df.to_csv(go_vocab_file, index=False)
logger.info(f"\nVocabulary saved to {go_vocab_file}")


## 5. Create Dataset and DataLoaders

In [None]:
# Create CAFA 5 dataset
logger.info("Creating CAFA 5 dataset...")

dataset = CAFA5Dataset(
    sequences_fasta=str(train_seqs_file),
    terms_csv=str(train_terms_file),
    spectra_dir=str(spectral_dir),
    structure_dir=str(structures_dir),
    go_terms_list=unique_go_terms,
)
n_total = len(dataset)
logger.info(f"Dataset created: {n_total} proteins")

# Split into train/val and reuse val as test (robust for tiny demo datasets).
if n_total <= 0:
    raise ValueError(f"No sequences found in {train_seqs_file}")

if n_total == 1:
    # A single sample cannot be split: every split is the whole dataset.
    train_dataset = val_dataset = dataset
    train_size = val_size = 1
else:
    # 80/20 split; for n_total == 2 this works out to one sample each.
    train_size = max(1, int(0.8 * n_total))
    val_size = n_total - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

# The validation split doubles as the test split.
test_dataset, test_size = val_dataset, val_size

logger.info(f"Split: train={train_size}, val={val_size}, test={test_size}")

# Create DataLoaders (important: collate_fn for PyG graphs). Only `shuffle`
# differs between them, so the shared options live in one place.
loader_options = dict(
    batch_size=32,
    num_workers=0,
    collate_fn=batch_collate_function,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

logger.info("DataLoaders created successfully!")


## 6. Initialize Model and Training

In [None]:
# Initialize model
logger.info("Initializing Vibro-Structural model for multi-label classification...")

# One output per GO term in the vocabulary.
num_go_terms = len(unique_go_terms)

model = VibroStructuralModel(
    latent_dim=128,
    gnn_input_dim=24,
    fusion_type='bilinear',
    dropout=0.2,
    num_go_terms=num_go_terms,
)
# NOTE(review): the model is not moved to `device` in this cell; presumably
# Trainer does that from its `device` argument — confirm.

# Count parameters in a single pass over model.parameters().
param_sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, is_trainable in param_sizes if is_trainable)
logger.info(f"Model: {total_params:,} total parameters, {trainable_params:,} trainable")
logger.info(f"Output dimension: {num_go_terms} GO terms")

# Setup training
optimizer = Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

# Ask for verbose LR-change messages first; if this torch build rejects the
# `verbose` keyword (TypeError), build the scheduler without it.
scheduler_options = dict(mode='min', factor=0.5, patience=5)
try:
    scheduler = ReduceLROnPlateau(optimizer, verbose=True, **scheduler_options)
except TypeError:
    scheduler = ReduceLROnPlateau(optimizer, **scheduler_options)

# Up-weight positive examples (simple baseline; tune on real data): the same
# weight of 2.0 for every GO term.
pos_weight = torch.full((num_go_terms,), 2.0, dtype=torch.float32, device=device)
loss_fn = WeightedBCELoss(pos_weight=pos_weight)
# Alternative loss:
# loss_fn = FocalLoss(alpha=0.25, gamma=2.0)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir=str(checkpoints_dir),
)

logger.info("Training setup complete!")


## 7. Train Model

In [None]:
# Train model
logger.info("Starting training...")
logger.info(f"Task: Multi-label GO term prediction ({num_go_terms} terms)")
logger.info("Metric: F-max score (mean F1 optimized over thresholds)")

# All training-run settings gathered in one place, then passed to Trainer.fit.
fit_options = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    epochs=5,
    metric_fn=MetricComputer.f_max,
    early_stopping_patience=3,
    task='cafa5',
)
# The return value is reported below as the best validation loss.
best_loss = trainer.fit(**fit_options)

logger.info(f"\nTraining complete! Best validation loss: {best_loss:.4f}")


## 8. Evaluate with F-max Metric

In [None]:
logger.info("Evaluating on test split and selecting a threshold...")


def _mean_row_f1(probs, labels, cutoff):
    """Return the mean per-row F1 after binarizing `probs` at `cutoff`.

    Counts are summed along axis 1, so each row (presumably one protein) gets
    its own precision/recall/F1; the 1e-8 terms guard against division by zero.
    """
    predicted = (probs >= cutoff).astype(int)

    tp = np.sum((predicted == 1) & (labels == 1), axis=1)
    fp = np.sum((predicted == 1) & (labels == 0), axis=1)
    fn = np.sum((predicted == 0) & (labels == 1), axis=1)

    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
    return float(np.mean(f1))


# Collect sigmoid probabilities and labels for the whole test split.
model.eval()
prob_chunks, label_chunks = [], []

with torch.no_grad():
    for batch in test_loader:
        graph = batch['graph'].to(device)
        spectra = batch['spectra'].to(device)
        labels = batch['labels'].to(device)

        # global_features is optional in the batch dict.
        global_features = batch.get('global_features')
        if global_features is not None:
            global_features = global_features.to(device)

        logits = model(graph, spectra, global_features=global_features, task='cafa5')
        probs = torch.sigmoid(logits)

        prob_chunks.append(probs.cpu().numpy())
        label_chunks.append(labels.cpu().numpy())

all_probs = np.concatenate(prob_chunks, axis=0)
all_labels = np.concatenate(label_chunks, axis=0)

# Score every candidate threshold and keep the one with the highest mean F1.
thresholds = np.arange(0.1, 0.95, 0.05)
f_scores = [_mean_row_f1(all_probs, all_labels, cutoff) for cutoff in thresholds]

best_threshold_idx = int(np.argmax(f_scores))
best_threshold = float(thresholds[best_threshold_idx])
best_f_max = float(f_scores[best_threshold_idx])

logger.info(f"Best F-max (mean F1): {best_f_max:.4f} at threshold {best_threshold:.2f}")


In [None]:
# Plot threshold sensitivity
# NOTE(review): the guard tests `dataset`, although the plot itself only reads
# thresholds, f_scores and best_threshold — confirm this is the intended check.
if dataset is not None:
    fig, ax = plt.subplots(figsize=(10, 6))

    # F-score curve over all candidate thresholds, with the selected one marked.
    ax.plot(thresholds, f_scores, marker='o', linewidth=2, markersize=8, color='steelblue')
    ax.axvline(best_threshold, color='red', linestyle='--', linewidth=2, label=f'Best: {best_threshold:.2f}')

    ax.set_title('F-max Score vs Classification Threshold', fontsize=13, fontweight='bold')
    ax.set_xlabel('Classification Threshold', fontsize=12)
    ax.set_ylabel('F-max Score', fontsize=12)
    ax.grid(alpha=0.3)
    ax.legend(fontsize=11)

    plt.tight_layout()
    plt.show()

    logger.info("Threshold sensitivity plot shown.")

## 9. Generate Competition Predictions

In [None]:
logger.info("Generating test-set predictions for submission...")

# Inference dataset: same feature directories and vocabulary, no labels file.
infer_dataset = CAFA5Dataset(
    sequences_fasta=str(test_seqs_file),
    terms_csv=None,
    spectra_dir=str(spectral_dir),
    structure_dir=str(structures_dir),
    go_terms_list=unique_go_terms,
)
# NOTE(review): ids are matched to predictions purely by position, which
# assumes the dataset yields samples in the key order of `sequences` — confirm
# against CAFA5Dataset.
protein_ids = list(infer_dataset.sequences.keys())

infer_loader = DataLoader(
    infer_dataset,
    shuffle=False,
    batch_size=32,
    num_workers=0,
    collate_fn=batch_collate_function,
)

# Use the threshold chosen by the evaluation cell when it has been run,
# otherwise fall back to 0.5.
threshold = globals().get('best_threshold', 0.5)

model.eval()
rows = []
idx_offset = 0  # position in protein_ids of the first sample of the current batch

with torch.no_grad():
    for batch in infer_loader:
        graph = batch['graph'].to(device)
        spectra = batch['spectra'].to(device)

        # global_features is optional in the batch dict.
        global_features = batch.get('global_features')
        if global_features is not None:
            global_features = global_features.to(device)

        logits = model(graph, spectra, global_features=global_features, task='cafa5')
        probs = torch.sigmoid(logits).cpu().numpy()

        for row_in_batch, sample_probs in enumerate(probs):
            pid = protein_ids[idx_offset + row_in_batch]

            # Every term at or above the threshold is predicted; when none
            # qualifies, a single fallback term is emitted instead.
            pred_idx = np.where(sample_probs >= threshold)[0]
            terms = [idx_to_go[i] for i in pred_idx] or ['GO:0005575']

            rows.append({'protein_id': pid, 'go_terms': ' '.join(terms)})

        idx_offset += len(probs)

# One CSV row per protein, with its predicted terms joined by spaces.
out_path = data_cafa_dir / 'submission.csv'
df_submission = pd.DataFrame(rows)
df_submission.to_csv(out_path, index=False)

logger.info(f"Submission file created: {out_path} ({len(df_submission)} proteins)")
print(df_submission.head())


## 10. Summary and Next Steps

In [None]:
# Final recap of the run. The figures quoted here mirror values hard-coded in
# the model-setup and training cells (lr=5e-4, scheduler patience=5,
# early_stopping_patience=3); update them together if those cells change.
summary_rule = "=" * 60
summary_lines = [
    "\n" + summary_rule,
    "CAFA 5 Competition Execution Summary",
    summary_rule,
    "Competition: CAFA 5 - Protein Function Prediction",
    "Task: Multi-label GO term prediction",
    "Metric: F-max score (optimized F1 across thresholds)",
    "Approach: Vibro-structural multimodal model",
    "\nModel: VibroStructuralModel (Multi-label Head)",
    "  - GNN branch: Structural graph encoding",
    "  - CNN branch: Spectral fingerprint encoding",
    "  - Fusion: Bilinear transformation",
    f"  - Head: Multi-label logistic regression ({num_go_terms} GO terms)",
    "\nTraining Details:",
    "  - Loss: Weighted BCE (focal loss alternative available)",
    "  - Optimizer: Adam (lr=5e-4)",
    "  - Schedule: ReduceLROnPlateau (patience=5)",
    # Previously reported patience=15, which contradicted the value of 3
    # actually passed to trainer.fit.
    "  - Early stopping: patience=3",
    "\nNext steps:",
    "  1. Download full CAFA 5 dataset",
    "  2. Retrieve 3D structures from AlphaFold DB",
    "  3. Precompute spectral features for all proteins",
    "  4. Implement hierarchical GO prediction (respecting ontology)",
    "  5. Train on full dataset (~30,000 proteins)",
    "  6. Optimize threshold for F-max metric",
    "  7. Ensemble with ESM-2 sequence embeddings",
    "  8. Generate final submission",
    summary_rule,
]

# One log record per line, in order.
for summary_line in summary_lines:
    logger.info(summary_line)