## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import logging

# Add src to path
sys.path.insert(0, './src')

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

from models.multimodal import VibroStructuralModel
from models.losses import MarginRankingLossCustom
from datasets import NovozymesDataset, create_dataloaders
from training import Trainer, MetricComputer, create_training_config
from nma_analysis import ANMAnalyzer
from data_acquisition import KaggleDataAcquisition
from utils import Logger, set_seed, get_device

# --- Session-wide setup: logging, fixed random seed, compute device ---
logger = Logger.setup('QDD-Novozymes', level=logging.INFO)
set_seed(42)
device = get_device()

# Plot defaults shared by every figure in this notebook
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

# Output directories used by later cells (mkdir is idempotent thanks to exist_ok)
for required_dir in ('./data/kaggle', './data/spectral', './checkpoints'):
    Path(required_dir).mkdir(parents=True, exist_ok=True)

logger.info("Setup complete!")

## 2. Download and Verify Novozymes Data

In [None]:
# Fetch the competition files (needs Kaggle CLI credentials to be configured).
logger.info("Downloading Novozymes competition data...")
kaggle_acq = KaggleDataAcquisition(output_dir="./data/kaggle")

# The real download is opt-in; uncomment the next line to fetch the files:
# train_csv, test_csv, struct_pdb = kaggle_acq.download_novozymes()

# Expected file locations; later cells use these paths whether or not the
# download above was run.
train_csv = Path('./data/kaggle/train.csv')
test_csv = Path('./data/kaggle/test.csv')
struct_pdb = Path('./data/kaggle/wildtype_structure_prediction_af2.pdb')

if train_csv.exists():
    logger.info("Novozymes data found!")
else:
    logger.warning("Train data not found. Please run: kaggle competitions download -c novozymes-enzyme-stability-prediction -p ./data/kaggle")

## 3. Load and Explore Data

In [None]:
# Load the training table and print a quick overview of it.
# NOTE: pd.read_csv raises FileNotFoundError here if the download step has not been run.
df_train = pd.read_csv('./data/kaggle/train.csv')
logger.info(f"Training data shape: {df_train.shape}")
logger.info("\nFirst few rows:")
print(df_train.head())

logger.info("\nData info:")
# DataFrame.info() writes its report to stdout itself and returns None, so it must
# not be wrapped in print() (that would emit a stray "None" line).
df_train.info()

logger.info("\nTm statistics:")
print(df_train['tm'].describe())

In [None]:
# Visualize the Tm distribution overall (left) and as a function of pH (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram of all Tm values
axes[0].hist(df_train['tm'], bins=30, color='steelblue', edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Melting Temperature (°C)', fontsize=12)
axes[0].set_ylabel('Frequency', fontsize=12)
axes[0].set_title('Distribution of Tm Values', fontsize=13, fontweight='bold')
axes[0].grid(alpha=0.3)

# Tm vs pH, one scatter series (and legend entry) per distinct pH value.
# groupby() drops rows whose pH is missing (a NaN from unique() never compares
# equal to itself, so it would only add an empty "pH=nan" legend entry) and
# yields the groups in ascending pH order, keeping the legend sorted.
for ph, group in df_train.groupby('pH'):
    axes[1].scatter(group['pH'], group['tm'],
                    alpha=0.5, s=30, label=f'pH={ph}')
axes[1].set_xlabel('pH', fontsize=12)
axes[1].set_ylabel('Tm (°C)', fontsize=12)
axes[1].set_title('Tm vs pH', fontsize=13, fontweight='bold')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

logger.info("Data exploration complete!")

## 4. Precompute Spectral Data

In [None]:
# For efficiency, precompute spectra for all mutations.
# In a full implementation this would use mass-perturbation NMA; for now only the
# wildtype (WT) spectrum is computed and cached for dataset loading.
# Any failure (missing structure file, ProDy not installed, unparsable PDB, NMA
# error) falls back to a synthetic random spectrum of the same length so that the
# downstream cells can still run in demo mode.

N_SPECTRAL_POINTS = 1000  # spectrum length; shared by the real and synthetic paths
WT_SPECTRUM_PATH = './data/spectral/wt_spectrum.npy'

logger.info("Precomputing spectral data...")

wt_spectrum = None
if struct_pdb.exists():
    try:
        # Imported here so a missing optional dependency triggers the fallback
        # below instead of aborting the cell.
        import prody as pr
        from spectral_generation import SpectralGenerator

        # Load wildtype structure
        wt_structure = pr.parsePDB(str(struct_pdb))
        logger.info(f"Wildtype structure: {wt_structure.numResidues()} residues")

        # Compute WT normal modes and vibrational entropy
        # (ANMAnalyzer is imported in the setup cell).
        anm = ANMAnalyzer(wt_structure, cutoff=15.0)
        freqs, _ = anm.compute_modes(k=50)
        s_vib = anm.compute_vibrational_entropy(k=50)

        sg = SpectralGenerator(freq_min=0, freq_max=500, n_points=N_SPECTRAL_POINTS)
        wt_spectrum = sg.generate_dos(freqs, broadening=5.0)

        # Save for dataset loading
        np.save(WT_SPECTRUM_PATH, wt_spectrum)
        logger.info(f"WT spectrum saved. S_vib = {s_vib:.2f} J/(mol·K)")
    except Exception as e:
        # Deliberate best-effort: log the cause and use the synthetic fallback.
        logger.error(f"NMA computation failed: {e}")
        logger.info("Using synthetic spectrum")
        wt_spectrum = None
else:
    logger.warning("Wildtype structure not found. Using synthetic spectra.")

if wt_spectrum is None:
    # Synthetic placeholder (Gaussian noise, not a physical spectrum).
    np.save(WT_SPECTRUM_PATH, np.random.randn(N_SPECTRAL_POINTS))

## 5. Create Dataset and DataLoaders

In [None]:
# Create the Novozymes dataset and train/val/test DataLoaders.
# If the dataset cannot be built, `dataset` is set to None and later cells run in
# demo mode (they check `dataset is not None`).
logger.info("Creating Novozymes dataset...")

SPLIT_SEED = 42  # fixed so that re-running this cell reproduces the same split
BATCH_SIZE = 32

try:
    dataset = NovozymesDataset(
        csv_file='./data/kaggle/train.csv',
        structure_file='./data/kaggle/wildtype_structure_prediction_af2.pdb',
        spectra_dir='./data/spectral',
        include_updates=True
    )
    logger.info(f"Dataset created: {len(dataset)} samples")
except Exception as e:
    # `except Exception` instead of a bare except so KeyboardInterrupt/SystemExit
    # still propagate; the cause is logged instead of being hidden.
    logger.warning(
        f"Novozymes dataset creation failed ({e!r}). "
        "Running in demo mode without a dataset."
    )
    dataset = None

if dataset is not None and len(dataset) == 0:
    logger.warning("Novozymes dataset is empty. Running in demo mode without a dataset.")
    dataset = None

if dataset is not None:
    # Split into train/val/test (70% / 15% / remainder)
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # A dedicated, seeded generator makes the split independent of how much of
    # the global torch RNG earlier code consumed. Without it, re-running this cell
    # after training reshuffles the split and can leak training samples into the
    # test set.
    train_dataset, val_dataset, test_dataset = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )

    logger.info(f"Split: train={train_size}, val={val_size}, test={test_size}")

    # Only the training loader shuffles; evaluation order stays deterministic.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    logger.info("DataLoaders created successfully!")
else:
    logger.info("Demo mode: skipping DataLoader creation")

## 6. Initialize Model and Training

In [None]:
# Initialize the model, optimizer, LR scheduler, loss function and trainer.
logger.info("Initializing Vibro-Structural model...")

model = VibroStructuralModel(
    latent_dim=128,
    gnn_input_dim=24,
    fusion_type='bilinear',
    dropout=0.2,
    num_go_terms=10000  # Not used for Novozymes
)
# Move the model to the target device *before* constructing the optimizer, as the
# torch.optim documentation recommends. nn.Module.to() moves parameters in place
# and returns the same module; it is a no-op if the model is already there.
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model: {total_params:,} total parameters, {trainable_params:,} trainable")

# Setup training
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# `verbose=True` was removed: the argument is deprecated for LR schedulers since
# PyTorch 2.2 and only produced a deprecation warning plus a printed message.
# Read optimizer.param_groups[i]['lr'] to track learning-rate reductions.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
loss_fn = nn.MSELoss()  # L2 loss for regression

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    checkpoint_dir='./checkpoints'
)

logger.info("Training setup complete!")

## 7. Train Model

In [None]:
# Train the model; this needs the real dataset and loaders built earlier.
if dataset is None:
    logger.info("Demo mode: Training skipped. Use actual Novozymes data to train.")
else:
    logger.info("Starting training...")

    # Training configuration, passed to Trainer.fit as keyword arguments
    fit_kwargs = dict(
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        epochs=50,
        metric_fn=MetricComputer.spearman_correlation,
        early_stopping_patience=10,
        task='novozymes',
    )
    best_loss = trainer.fit(**fit_kwargs)

    logger.info(f"\nTraining complete! Best validation loss: {best_loss:.4f}")

## 8. Evaluate and Predict

In [None]:
# Evaluate the trained model on the held-out test split and log regression metrics.
if dataset is not None:
    logger.info("Evaluating on test set...")

    # Ensure parameters live on the same device as the input tensors below
    # (nn.Module.to() is in place and a no-op if already on `device`).
    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            graph = batch['graph'].to(device)
            spectra = batch['spectra'].to(device)
            # Labels are only collected for metrics, so they are not sent to `device`.
            labels = batch['labels']

            global_features = None
            if 'global_features' in batch:
                global_features = batch['global_features'].to(device)

            outputs = model(graph, spectra, global_features, task='novozymes')

            all_preds.append(outputs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    if not all_preds:
        # A small dataset can yield an empty test split; np.concatenate([]) would raise.
        logger.warning("Test set is empty; skipping evaluation metrics.")
    else:
        # Flatten both to 1-D (single-output Tm regression). squeeze() alone turns a
        # single prediction into a 0-d array, and leaves (N, 1) labels 2-D, which
        # could broadcast against (N,) predictions into an (N, N) difference if the
        # metric helpers subtract element-wise.
        all_preds = np.concatenate(all_preds, axis=0).reshape(-1)
        all_labels = np.concatenate(all_labels, axis=0).reshape(-1)
        if all_preds.shape != all_labels.shape:
            raise ValueError(
                f"Prediction/label size mismatch: {all_preds.shape} vs {all_labels.shape}"
            )

        # Compute metrics
        spearman = MetricComputer.spearman_correlation(all_preds, all_labels)
        mae = MetricComputer.mean_absolute_error(all_preds, all_labels)
        mse = MetricComputer.mean_squared_error(all_preds, all_labels)

        logger.info("\nTest Set Metrics:")
        logger.info(f"  Spearman Correlation: {spearman:.4f}")
        logger.info(f"  MAE: {mae:.4f}")
        logger.info(f"  MSE: {mse:.4f}")
else:
    logger.info("Demo mode: Evaluation skipped.")

## 9. Summary and Next Steps

In [None]:
# Log a plain-text summary of the approach and the remaining to-do list.
separator = "=" * 60
summary_lines = (
    "\n" + separator,
    "Novozymes Competition Execution Summary",
    separator,
    "Competition: Novozymes Enzyme Stability Prediction",
    "Task: Predict melting temperature (Tm) from structure",
    "Metric: Spearman rank correlation",
    "Approach: Vibro-structural multimodal model",
    "\nModel: VibroStructuralModel",
    "  - GNN branch: Structural graph encoding",
    "  - CNN branch: Spectral fingerprint encoding",
    "  - Fusion: Bilinear transformation",
    "  - Head: Regression MLP for Tm prediction",
    "\nNext steps:",
    "  1. Download full Novozymes dataset",
    "  2. Precompute all mutation spectra (mass-perturbation NMA)",
    "  3. Implement mutation-specific delta features",
    "  4. Train on full dataset (4,000+ mutations)",
    "  5. Evaluate with Spearman correlation",
    "  6. Generate test set predictions",
    "  7. Submit to Kaggle competition",
    separator,
)
for summary_line in summary_lines:
    logger.info(summary_line)