## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import logging
from collections import Counter

# Add src to path
sys.path.insert(0, './src')

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from models.multimodal import VibroStructuralModel
from models.losses import FocalLoss, WeightedBCELoss
from datasets import CAFA5Dataset, create_dataloaders
from training import Trainer, MetricComputer, create_training_config
from data_acquisition import KaggleDataAcquisition
from utils import Logger, set_seed, get_device

# Setup: logging, reproducibility and compute device
logger = Logger.setup('QDD-CAFA5', level=logging.INFO)
set_seed(42)
device = get_device()

# Plot defaults applied to every figure drawn later in the notebook
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

# Make sure the data and checkpoint directories exist (no-op if they already do)
for required_dir in ('./data/cafa5', './data/pdb_structures', './checkpoints'):
    Path(required_dir).mkdir(parents=True, exist_ok=True)

logger.info("Setup complete!")

## 2. Download and Explore CAFA 5 Data

In [None]:
# Download CAFA 5 data (optional step).
# The acquisition helper is constructed, but the download call itself is left
# commented out, so despite the log message below nothing is fetched here.
logger.info("Downloading CAFA 5 competition data...")
kaggle_acq = KaggleDataAcquisition(output_dir="./data/cafa5")

# To actually fetch the files, uncomment:
# train_terms, test_seqs, go_annot = kaggle_acq.download_cafa5()

# Expected file locations
train_terms_file = Path('./data/cafa5/train_terms.csv')
test_seqs_file = Path('./data/cafa5/test_sequences.fasta')
go_vocab_file = Path('./data/cafa5/go_vocabulary.csv')

# Only the presence of train_terms.csv is checked
if train_terms_file.exists():
    logger.info("CAFA 5 data found!")
else:
    logger.warning("CAFA 5 data not found. Please run: kaggle competitions download -c cafa-5-protein-function-prediction -p ./data/cafa5")

## 3. Load and Explore GO Annotations

In [None]:
# Load the training annotations (one row per protein / GO-term pair).
# If the CSV is missing, build a small synthetic table with the same two
# columns so the exploration cells below can still run.
# NOTE(review): later cells read the 'protein_id' and 'go_term' columns;
# confirm this CSV uses those names (the raw Kaggle release may differ).
try:
    df_terms = pd.read_csv('./data/cafa5/train_terms.csv')
except FileNotFoundError:
    logger.warning("train_terms.csv not found. Using demo data structure.")
    # Synthetic demo: 1000 proteins, each with one random 7-digit GO id
    df_terms = pd.DataFrame({
        'protein_id': [f'protein_{i}' for i in range(1000)],
        'go_term': [f'GO:{np.random.randint(1000000, 9999999):07d}' for _ in range(1000)]
    })
    logger.info(f"Demo dataset created: {df_terms.shape[0]} annotations")
else:
    logger.info(f"Training terms shape: {df_terms.shape}")
    logger.info("\nFirst few rows:")
    print(df_terms.head(10))

    logger.info("\nData info:")
    # DataFrame.info() prints its own summary and returns None; wrapping it
    # in print() would emit an extra "None" line.
    df_terms.info()

In [None]:
# Summarise the annotation table: counts of proteins, GO terms and
# annotations, plus the spread of per-term annotation frequencies.
logger.info("Analyzing GO term distribution...")

# Count unique proteins and terms
n_proteins = df_terms['protein_id'].nunique()
n_go_terms = df_terms['go_term'].nunique()
n_annotations = len(df_terms)

logger.info(f"Unique proteins: {n_proteins:,}")
logger.info(f"Unique GO terms: {n_go_terms:,}")
logger.info(f"Total annotations: {n_annotations:,}")
# Guard: an empty annotation table would otherwise raise ZeroDivisionError
avg_terms_per_protein = n_annotations / n_proteins if n_proteins else 0.0
logger.info(f"Avg terms per protein: {avg_terms_per_protein:.2f}")

# Per-term annotation counts (value_counts sorts most frequent first)
term_counts = df_terms['go_term'].value_counts()
logger.info("\nGO term frequency:")
logger.info(f"  Max: {term_counts.max()}")
logger.info(f"  Min: {term_counts.min()}")
logger.info(f"  Median: {term_counts.median()}")

In [None]:
# Two-panel overview: the 15 most frequent GO terms (left) and a histogram
# of how many GO terms each protein carries (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_top, ax_hist = axes[0], axes[1]

# Left panel: most frequent GO terms
top_go = df_terms['go_term'].value_counts().head(15)
bar_positions = range(len(top_go))
ax_top.barh(bar_positions, top_go.values, color='steelblue')
ax_top.set_yticks(bar_positions)
ax_top.set_yticklabels(top_go.index, fontsize=9)
ax_top.set_xlabel('Frequency', fontsize=11)
ax_top.set_title('Top 15 Most Frequent GO Terms', fontsize=12, fontweight='bold')
ax_top.grid(axis='x', alpha=0.3)

# Right panel: number of annotations per protein
terms_per_protein = df_terms.groupby('protein_id').size()
ax_hist.hist(terms_per_protein, bins=30, color='coral', edgecolor='black', alpha=0.7)
ax_hist.set_xlabel('Number of GO Terms', fontsize=11)
ax_hist.set_ylabel('Number of Proteins', fontsize=11)
ax_hist.set_title('GO Terms per Protein Distribution', fontsize=12, fontweight='bold')
ax_hist.grid(alpha=0.3)

plt.tight_layout()
plt.show()

logger.info("GO analysis complete!")

## 4. Create GO Term Vocabulary

In [None]:
# Build a GO-term <-> column-index vocabulary. Terms are sorted, so the
# index assignment is deterministic for a given set of terms.
logger.info("Building GO term vocabulary...")

unique_go_terms = sorted(df_terms['go_term'].unique())
go_to_idx = dict(zip(unique_go_terms, range(len(unique_go_terms))))
idx_to_go = dict(enumerate(unique_go_terms))

logger.info(f"Vocabulary size: {len(go_to_idx)} GO terms")
logger.info("\nSample terms:")
# Show the first ten terms with their index and annotation count
for go_term in unique_go_terms[:10]:
    idx = go_to_idx[go_term]
    count = (df_terms['go_term'] == go_term).sum()
    logger.info(f"  {idx}: {go_term} (n={count})")

# Persist the vocabulary; the dataset cell reads it back via vocab_file
vocab_df = pd.DataFrame({
    'go_term': unique_go_terms,
    'index': list(go_to_idx.values()),
})
vocab_df.to_csv('./data/cafa5/go_vocabulary.csv', index=False)
logger.info("\nVocabulary saved to go_vocabulary.csv")

## 5. Create Dataset and DataLoaders

In [None]:
# Build the CAFA 5 training dataset and split it 70/15/15 into
# train/validation/test DataLoaders. Any construction failure switches the
# notebook into "demo mode" (dataset = None), which later cells check.
logger.info("Creating CAFA 5 dataset...")

try:
    dataset = CAFA5Dataset(
        terms_file='./data/cafa5/train_terms.csv',
        # Training labels must be paired with the *training* sequences; the
        # test FASTA holds the competition targets used for submission.
        sequences_file='./data/cafa5/train_sequences.fasta',
        structures_dir='./data/pdb_structures',
        vocab_file='./data/cafa5/go_vocabulary.csv'
    )
    logger.info(f"Dataset created: {len(dataset)} proteins")
except Exception as e:
    # Deliberate best-effort: any failure falls back to demo mode
    logger.warning(f"CAFA5Dataset creation failed: {e}")
    logger.info("Using demo dataset...")
    dataset = None

if dataset is not None:
    # Split into train/val/test (test absorbs the rounding remainder)
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Dedicated seeded generator: the split does not depend on how much of
    # the global RNG state earlier cells (e.g. dataset construction) consumed.
    split_generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], generator=split_generator
    )

    logger.info(f"Split: train={train_size}, val={val_size}, test={test_size}")

    # NOTE(review): `create_dataloaders` is imported but unused here; batches
    # carry a 'graph' entry, so confirm the default collate_fn can batch it.
    def _make_loader(subset, shuffle):
        """Wrap one split in a DataLoader with the notebook's shared settings."""
        return DataLoader(subset, batch_size=32, shuffle=shuffle, num_workers=0)

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    logger.info("DataLoaders created successfully!")
else:
    logger.info("Demo mode: skipping DataLoader creation")

## 6. Initialize Model and Training

In [None]:
# Initialise the multimodal model, optimiser, LR scheduler, loss and Trainer.
logger.info("Initializing Vibro-Structural model for multi-label classification...")

# One output per vocabulary GO term; 10000 is only a placeholder width used
# in demo mode (no dataset).
num_go_terms = len(go_to_idx) if dataset is not None else 10000

model = VibroStructuralModel(
    latent_dim=128,
    gnn_input_dim=24,
    fusion_type='bilinear',
    dropout=0.2,
    num_go_terms=num_go_terms
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model: {total_params:,} total parameters, {trainable_params:,} trainable")
logger.info(f"Output dimension: {num_go_terms} GO terms")

# Setup training: factor=0.5 halves the LR when the monitored (minimised)
# value plateaus for `patience` scheduler steps.
optimizer = Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)
# `verbose` is deprecated for PyTorch LR schedulers and may be rejected by
# newer releases; read optimizer.param_groups[i]['lr'] to observe LR changes.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Use weighted BCE for class imbalance in GO term prediction
# Focal loss is alternative (uncomment to use)
loss_fn = WeightedBCELoss(weight=2.0)  # Up-weight positive examples
# loss_fn = FocalLoss(alpha=0.25, gamma=2.0)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir='./checkpoints'
)

logger.info("Training setup complete!")

## 7. Train Model

In [None]:
# Train the model; only possible when a real dataset was built above.
if dataset is None:
    logger.info("Demo mode: Training skipped. Use actual CAFA 5 data to train.")
else:
    logger.info("Starting training...")
    logger.info(f"Task: Multi-label GO term prediction ({num_go_terms} terms)")
    logger.info("Metric: F-max score (optimized F1)")

    # Trainer.fit returns the value reported below as the best validation loss
    best_loss = trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        epochs=100,
        metric_fn=MetricComputer.f_max_score,
        early_stopping_patience=15,
        task='cafa5',
    )

    logger.info(f"\nTraining complete! Best validation loss: {best_loss:.4f}")

## 8. Evaluate with F-max Metric

In [None]:
def _threshold_f1(probs, labels, threshold):
    """Return the CAFA-style protein-centric F1 of ``probs`` at ``threshold``.

    A term counts as predicted when its probability is strictly above
    ``threshold``. Precision is averaged over proteins with at least one
    prediction; recall over proteins with at least one true label.
    Returns 0.0 when precision and recall are both zero.

    Assumes ``probs`` and ``labels`` are 2-D (proteins x GO terms) and that
    ``labels`` is a 0/1 multi-hot matrix -- TODO confirm against CAFA5Dataset.
    """
    predicted = probs > threshold
    truth = labels.astype(bool)
    true_pos = np.logical_and(predicted, truth).sum(axis=1)
    n_predicted = predicted.sum(axis=1)
    n_true = truth.sum(axis=1)

    has_pred = n_predicted > 0
    precision = (true_pos[has_pred] / n_predicted[has_pred]).mean() if has_pred.any() else 0.0
    has_true = n_true > 0
    recall = (true_pos[has_true] / n_true[has_true]).mean() if has_true.any() else 0.0

    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))


if dataset is not None:
    logger.info("Evaluating on test set with F-max metric...")

    model.eval()
    all_preds_prob = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            graph = batch['graph'].to(device)
            spectra = batch['spectra'].to(device)

            global_features = None
            if 'global_features' in batch:
                global_features = batch['global_features'].to(device)

            outputs = model(graph, spectra, global_features, task='cafa5')
            probs = torch.sigmoid(outputs)  # Convert logits to probabilities

            all_preds_prob.append(probs.cpu().numpy())
            # Labels are only needed on the host, so skip the device round trip
            all_labels.append(batch['labels'].cpu().numpy())

    all_preds_prob = np.concatenate(all_preds_prob, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Compute F1 at each threshold; F-max is the best of these.
    logger.info("\nEvaluating F-max across thresholds...")

    thresholds = np.arange(0.1, 0.95, 0.05)
    # MetricComputer.f_max_score takes no threshold argument, so it cannot
    # produce a per-threshold curve; score each threshold directly instead.
    f_scores = [_threshold_f1(all_preds_prob, all_labels, t) for t in thresholds]

    best_threshold_idx = int(np.argmax(f_scores))
    best_threshold = thresholds[best_threshold_idx]
    best_f_max = f_scores[best_threshold_idx]

    logger.info(f"Best F-max: {best_f_max:.4f} at threshold {best_threshold:.2f}")
    # Cross-check against the project's own F-max implementation
    logger.info(f"MetricComputer.f_max_score: {MetricComputer.f_max_score(all_preds_prob, all_labels):.4f}")
else:
    logger.info("Demo mode: Evaluation skipped.")

In [None]:
# Plot the score obtained at each classification threshold and mark the
# threshold selected by the evaluation cell.
if dataset is not None:
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(thresholds, f_scores, marker='o', linewidth=2, markersize=8, color='steelblue')
    best_label = f'Best: {best_threshold:.2f}'
    ax.axvline(best_threshold, color='red', linestyle='--', linewidth=2, label=best_label)

    ax.set_xlabel('Classification Threshold', fontsize=12)
    ax.set_ylabel('F-max Score', fontsize=12)
    ax.set_title('F-max Score vs Classification Threshold', fontsize=13, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    logger.info("Threshold sensitivity plot shown.")

## 9. Generate Competition Predictions

In [None]:
def _read_fasta_ids(path):
    """Return the unique record IDs in a FASTA file, in file order.

    The ID is the first whitespace-delimited token after '>' on each header
    line; headers with nothing after '>' are skipped.
    """
    ids = []
    with open(path, encoding='utf-8') as handle:
        for line in handle:
            if line.startswith('>'):
                tokens = line[1:].split()
                if tokens:
                    ids.append(tokens[0])
    # dict.fromkeys de-duplicates while preserving first-seen order
    return list(dict.fromkeys(ids))


if dataset is not None:
    logger.info("Generating test set predictions for submission...")

    # Load test protein IDs from the FASTA headers (a FASTA file is not a
    # CSV, so pandas cannot parse it into a 'protein_id' column).
    try:
        test_proteins = _read_fasta_ids('./data/cafa5/test_sequences.fasta')
    except (OSError, UnicodeDecodeError):
        test_proteins = []

    if test_proteins:
        logger.info(f"Test set size: {len(test_proteins)} proteins")
    else:
        logger.warning("Could not load test sequences. Using dummy test set.")
        test_proteins = [f'test_protein_{i}' for i in range(100)]

    # Create submission file
    # NOTE(review): predictions below are random placeholders; the trained
    # model is not run on the test proteins here.
    # NOTE(review): the CAFA 5 Kaggle scorer expects a header-less TSV with one
    # "protein<TAB>GO term<TAB>score" row per prediction (unthresholded scores);
    # confirm against the competition's sample submission before uploading.
    submission_data = []

    for protein_id in test_proteins:
        # In real scenario, would run model inference on each protein
        # For demo, generate random predictions
        predicted_probs = np.random.rand(num_go_terms)

        # We are inside `dataset is not None`, so the tuned threshold from the
        # evaluation cell is always the one to apply.
        predicted_terms = np.where(predicted_probs > best_threshold)[0]

        # Convert indices to GO terms (.tolist() yields plain ints for lookup/formatting)
        go_terms = [idx_to_go.get(idx, f'GO:{idx:07d}') for idx in predicted_terms.tolist()]

        submission_data.append({
            'protein_id': protein_id,
            'go_terms': ' '.join(go_terms) if go_terms else 'GO:0005575'  # Default to cellular_component
        })

    df_submission = pd.DataFrame(submission_data)
    df_submission.to_csv('./data/cafa5/submission.csv', index=False)

    logger.info(f"Submission file created: {len(df_submission)} proteins")
    logger.info("\nSample predictions:")
    print(df_submission.head())

## 10. Summary and Next Steps

In [None]:
# Final recap of the configuration used in this notebook and the remaining
# work items, logged one line at a time between banner rules.
banner = "=" * 60
summary_lines = [
    "Competition: CAFA 5 - Protein Function Prediction",
    "Task: Multi-label GO term prediction",
    "Metric: F-max score (optimized F1 across thresholds)",
    "Approach: Vibro-structural multimodal model",
    "\nModel: VibroStructuralModel (Multi-label Head)",
    "  - GNN branch: Structural graph encoding",
    "  - CNN branch: Spectral fingerprint encoding",
    "  - Fusion: Bilinear transformation",
    f"  - Head: Multi-label logistic regression ({num_go_terms} GO terms)",
    "\nTraining Details:",
    "  - Loss: Weighted BCE (focal loss alternative available)",
    "  - Optimizer: Adam (lr=5e-4)",
    "  - Schedule: ReduceLROnPlateau (patience=5)",
    "  - Early stopping: patience=15",
    "\nNext steps:",
    "  1. Download full CAFA 5 dataset",
    "  2. Retrieve 3D structures from AlphaFold DB",
    "  3. Precompute spectral features for all proteins",
    "  4. Implement hierarchical GO prediction (respecting ontology)",
    "  5. Train on full dataset (~30,000 proteins)",
    "  6. Optimize threshold for F-max metric",
    "  7. Ensemble with ESM-2 sequence embeddings",
    "  8. Generate final submission",
]

logger.info("\n" + banner)
logger.info("CAFA 5 Competition Execution Summary")
logger.info(banner)
for summary_line in summary_lines:
    logger.info(summary_line)
logger.info(banner)