In [None]:
# OPTIMIZED TRAINING SCRIPT WITH BLEU-4 EVALUATION
# Ready for Google Colab deployment with hard negative mining + BLEU-4 tracking

# Install dependencies (run once)
!pip install -q torch torchvision torchaudio
!pip install -q torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
!pip install -q nltk pandas scikit-learn matplotlib wandb

# Mount Google Drive at /content/drive, then switch into the project folder
# on Drive so relative paths in the rest of the notebook resolve against it.
import os
from google.colab import drive

PROJECT_DIR = '/content/drive/My Drive/challenge_altegrad/data_baseline'

drive.mount('/content/drive')
os.chdir(PROJECT_DIR)

# Download NLTK data (presumably needed by the BLEU-4 evaluation in the
# training script — confirm against train_gt_contrast.py).
import nltk
nltk.download('punkt')
# NLTK 3.9+ tokenizers load the 'punkt_tab' resource instead of the pickled
# 'punkt' model, so fetch both to work on either NLTK version. On older NLTK
# the unknown package is only reported and download() returns False.
nltk.download('punkt_tab')
nltk.download('wordnet')

# ============================================================================
# RUN TRAINING WITH HARD NEGATIVES + BLEU-4 EVALUATION
# ============================================================================

import subprocess
import sys

# Hyper-parameters for train_gt_contrast.py; the command line below is
# derived entirely from this dict.
config = dict(
    ENV='colab',
    EPOCHS=100,
    BATCH_SIZE=64,
    LEARNING_RATE=0.0003,
    HIDDEN_DIM=256,
    NUM_LAYERS=3,
    NUM_HEADS=8,
    TEMPERATURE=0.1,
    HARD_NEGATIVES=True,   # also forwards HARD_RATIO / HARDNESS_K / CURRICULUM_EPOCH
    HARD_RATIO=0.5,
    HARDNESS_K=100,
    CURRICULUM_EPOCH=5,
    USE_AUGMENTATION=True,
    TEMP_SCHEDULE=True,
    EVAL_BLEU_BERT=True,   # BLEU-4 evaluation
    RERANK_TOPK=10,
)

# Options that always take a value, in command-line order: (flag, config key).
_VALUE_OPTIONS = (
    ('--env', 'ENV'),
    ('--epochs', 'EPOCHS'),
    ('--batch_size', 'BATCH_SIZE'),
    ('--lr', 'LEARNING_RATE'),
    ('--hidden', 'HIDDEN_DIM'),
    ('--layers', 'NUM_LAYERS'),
    ('--heads', 'NUM_HEADS'),
    ('--temp', 'TEMPERATURE'),
    ('--rerank_topk', 'RERANK_TOPK'),
)

cmd = [sys.executable, 'train_gt_contrast.py']
for option, key in _VALUE_OPTIONS:
    cmd += [option, str(config[key])]

# Hard-negative mining: the switch followed by its three tuning values.
if config['HARD_NEGATIVES']:
    cmd.append('--hard_negatives')
    for option, key in (('--hard_ratio', 'HARD_RATIO'),
                        ('--hardness_k', 'HARDNESS_K'),
                        ('--curriculum_epoch', 'CURRICULUM_EPOCH')):
        cmd += [option, str(config[key])]

# Plain on/off switches, appended only when enabled.
for key, switch in (('USE_AUGMENTATION', '--use_augmentation'),
                    ('TEMP_SCHEDULE', '--temp_schedule'),
                    ('EVAL_BLEU_BERT', '--eval_bleu_bert')):
    if config[key]:
        cmd.append(switch)

# Echo the effective configuration so the run can be reproduced from the cell output.
print("🚀 Starting optimized training with configuration:")
print("=" * 70)
for key, value in config.items():
    print(f"  {key:20s}: {value}")
print("=" * 70)
print()

# Run training in a child interpreter from the project directory.
# check=True turns a non-zero exit status into CalledProcessError, which
# stops the notebook here instead of plotting stale logs.
try:
    subprocess.run(cmd, cwd=PROJECT_DIR, check=True)
except subprocess.CalledProcessError as e:
    print(f"\n❌ Training failed with error code {e.returncode}")
    sys.exit(1)
else:
    print("\n✅ Training completed successfully!")

# ============================================================================
# VISUALIZATION (after training)
# ============================================================================

import json
import matplotlib.pyplot as plt
import pandas as pd

# Per-epoch logs written by the training run.
logs_path = os.path.join(PROJECT_DIR, 'output', 'training_logs.json')


def _best_epoch(frame, column):
    """Return ``(best_value, epoch)`` for the row where ``frame[column]`` peaks.

    The epoch is read from that row's ``'epoch'`` entry, so the summary agrees
    with the x-axis of the plots (which use the same column) even when the log
    does not start at epoch 1. Without an ``'epoch'`` column it falls back to
    ``row + 1``, the numbering used previously.
    """
    row = frame[column].idxmax()
    epoch = frame.loc[row, 'epoch'] if 'epoch' in frame.columns else row + 1
    return frame.loc[row, column], epoch


def _plot_logged(ax, frame, column, **style):
    """Plot ``frame[column]`` against ``frame['epoch']`` if that column was logged.

    Returns True when a line was drawn, False when the column is absent.
    """
    if column not in frame.columns:
        return False
    ax.plot(frame['epoch'], frame[column], **style)
    return True


def _style_axis(ax, ylabel, title):
    """Apply the shared labels/grid; add a legend only if something labeled was drawn."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, alpha=0.3)


if os.path.exists(logs_path):
    with open(logs_path, 'r', encoding='utf-8') as f:
        logs = json.load(f)

    df = pd.DataFrame(logs['epochs'])
    # A missing, all-zero or all-NaN val_bleu4 column means BLEU-4 was not evaluated.
    has_bleu = 'val_bleu4' in df.columns and df['val_bleu4'].max() > 0

    best_mrr, best_mrr_epoch = _best_epoch(df, 'val_mrr')
    print("\n📊 Training Summary:")
    print("=" * 70)
    print(f"Total epochs trained: {len(df)}")
    print(f"Best validation MRR: {best_mrr:.4f} at epoch {best_mrr_epoch}")
    if has_bleu:
        best_bleu, best_bleu_epoch = _best_epoch(df, 'val_bleu4')
        print(f"Best BLEU-4: {best_bleu:.4f} at epoch {best_bleu_epoch}")
    print(f"Final loss: {df['train_loss'].iloc[-1]:.4f}")
    print("=" * 70)

    # 2x2 grid, row-major: loss | recall / BLEU-4 | learning rate.
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    ax_loss, ax_recall, ax_bleu, ax_lr = axes.flat

    # Loss
    _plot_logged(ax_loss, df, 'train_loss', label='Train Loss', marker='o', markersize=3)
    _plot_logged(ax_loss, df, 'val_loss', label='Val Loss', marker='x', markersize=3)
    _style_axis(ax_loss, 'Loss', 'Training and Validation Loss')

    # MRR and R@K
    for column, label, marker in (('val_mrr', 'MRR', 'o'),
                                  ('val_r1', 'R@1', 'x'),
                                  ('val_r5', 'R@5', '^'),
                                  ('val_r10', 'R@10', 's')):
        _plot_logged(ax_recall, df, column, label=label, marker=marker, markersize=3)
    _style_axis(ax_recall, 'Score', 'Recall Metrics')

    # BLEU-4 (if available)
    if has_bleu:
        _plot_logged(ax_bleu, df, 'val_bleu4', label='BLEU-4', marker='o', markersize=3, color='purple')
        _style_axis(ax_bleu, 'BLEU-4 Score', 'Retrieval-based BLEU-4 Score')
    else:
        ax_bleu.text(0.5, 0.5, 'BLEU-4 not available\n(enable with --eval_bleu_bert)',
                     ha='center', va='center', transform=ax_bleu.transAxes, fontsize=12)

    # Learning rate (log scale only when the schedule was actually logged)
    if _plot_logged(ax_lr, df, 'learning_rate', label='Learning Rate', marker='o', markersize=3, color='orange'):
        ax_lr.set_yscale('log')
    _style_axis(ax_lr, 'Learning Rate', 'Learning Rate Schedule')

    plt.tight_layout()

    # Save next to the logs, then display inline in Colab.
    fig_path = os.path.join(PROJECT_DIR, 'output', 'training_curves.png')
    fig.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n📈 Training curves saved to: {fig_path}")
    plt.show()
else:
    print(f"⚠️  Logs not found at {logs_path}")

# Closing checklist: where the run's artifacts are expected under output/.
print("\n✅ Training and visualization complete!")
print("\n📝 Next steps:")
print("  1. Download the trained model from: output/model.pt")
print("  2. Download the checkpoint from: output/checkpoint.pt")
print("  3. Download the logs from: output/training_logs.json")
print("  4. Monitor BLEU-4 convergence to align with Kaggle metric")