In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from pathlib import Path
import wandb
import yaml
import json
from tqdm.notebook import tqdm
import os
import sys

# Add the project root (the notebook's working directory) to the Python path so the
# `src.*` imports below resolve. The membership check keeps the cell idempotent:
# re-running it does not keep appending duplicate '.' entries to sys.path.
if '.' not in sys.path:
    sys.path.append('.')

# Show which directory '.' currently resolves to (it is relative to the cwd).
print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.contrastive_reward_model import ContrastiveRewardModel
from src.training.train_reward_model import InfoNCELoss
from src.config.tokenization_config import PAD_TOKEN


In [None]:
# Load the training configuration. The path is relative to the project root,
# i.e. the notebook's current working directory.
config_path = Path('src/training/configs/contrastive_reward_base.yaml')
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Seed torch's RNG (torch.manual_seed covers CPU and all CUDA devices) so weight
# initialisation is reproducible (and DataLoader shuffling too, presumably, if
# create_dataloader relies on torch's default generator — TODO confirm).
# setdefault writes the seed back into `config`, so it is recorded in wandb and in
# every saved checkpoint.
seed = config.setdefault('seed', 42)
torch.manual_seed(seed)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize wandb; the run name encodes the main architecture/optimisation hyperparameters.
run_name = f"contrastive_L{config['num_layers']}_H{config['num_heads']}_D{config['embed_dim']}_seq{config['max_seq_length']}_bs{config['batch_size']}_lr{config['learning_rate']}"

wandb.init(
    project=config['wandb_project'],
    name=run_name,
    config=config,
    job_type="contrastive_training"
)


In [None]:
# Build the train and validation dataloaders. Both splits use identical settings
# except for the split name and shuffling (shuffle only the training data).
loader_kwargs = dict(
    data_dir=Path(config['data_dir']),
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    sequence_length=config['max_seq_length'],
    mode='contrastive',
)

train_loader, tokenizer_info = create_dataloader(split="train", shuffle=True, **loader_kwargs)
# The validation call also returns tokenizer info; the training split's copy is the one kept.
val_loader, _ = create_dataloader(split="valid", shuffle=False, **loader_kwargs)


In [None]:
# Build the model: vocabulary sizes come from the tokenizer, architecture from the config.
model_kwargs = dict(
    melody_vocab_size=tokenizer_info['melody_vocab_size'],
    chord_vocab_size=tokenizer_info['chord_vocab_size'],
    embed_dim=config['embed_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    dropout=config['dropout'],
    max_seq_length=config['max_seq_length'],
    pad_token_id=PAD_TOKEN,
)
model = ContrastiveRewardModel(**model_kwargs).to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model created with {num_params:,} parameters.")

# InfoNCE loss, with its temperature taken from the config.
loss_fn = InfoNCELoss(temperature=config['temperature'])

# Adam over every model parameter at the configured learning rate.
optimizer = Adam(model.parameters(), lr=config['learning_rate'])


In [None]:
def train_epoch(model, train_loader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Batches whose loss is NaN or infinite are skipped: no backward pass and no
    optimizer step, so one bad batch cannot corrupt the weights.

    Args:
        model: Module exposing ``pad_token_id`` whose forward pass returns
            ``(melody_embeds, chord_embeds)``.
        train_loader: Iterable of dicts holding ``'melody_tokens'`` and
            ``'chord_tokens'`` tensors.
        loss_fn: Callable mapping ``(melody_embeds, chord_embeds)`` to a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the token tensors are moved to.

    Returns:
        Mean loss over the batches that were actually optimised, or ``nan`` when
        no batch produced a finite loss (including an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_steps = 0

    pbar = tqdm(train_loader, desc='Training')
    for batch in pbar:
        # Move batch to device
        melody_tokens = batch['melody_tokens'].to(device)
        chord_tokens = batch['chord_tokens'].to(device)

        # Padding masks: True where a position holds the pad token
        melody_mask = (melody_tokens == model.pad_token_id)
        chord_mask = (chord_tokens == model.pad_token_id)

        # Forward pass
        optimizer.zero_grad()
        melody_embeds, chord_embeds = model(
            melody_tokens=melody_tokens,
            chord_tokens=chord_tokens,
            melody_padding_mask=melody_mask,
            chord_padding_mask=chord_mask
        )

        loss = loss_fn(melody_embeds, chord_embeds)

        # Reject inf as well as NaN: either would poison the weights after backward().
        if not torch.isfinite(loss):
            print("\nNon-finite loss detected! Skipping batch.")
            continue

        # Backward pass
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        num_steps += 1
        pbar.set_postfix({'loss': loss_value})

    # Average only over batches that contributed. Dividing by len(train_loader)
    # would count skipped batches as zero-loss and bias the epoch loss downward.
    return total_loss / num_steps if num_steps else float('nan')

def validate(model, val_loader, loss_fn, device):
    """Evaluate the model on ``val_loader`` without tracking gradients.

    For each batch the loss is computed with ``loss_fn``, and an in-batch
    retrieval Top-1 accuracy is measured. Each melody embedding is scored against
    every chord embedding in the same batch by cosine similarity. A melody counts
    as correct when its highest-scoring chord is the one at the same batch index.

    Args:
        model: Module exposing ``pad_token_id`` whose forward pass returns
            ``(melody_embeds, chord_embeds)``.
        val_loader: Iterable of dicts holding ``'melody_tokens'`` and
            ``'chord_tokens'`` tensors.
        loss_fn: Callable mapping ``(melody_embeds, chord_embeds)`` to a scalar loss tensor.
        device: Device the token tensors are moved to.

    Returns:
        ``(avg_loss, top1_accuracy)``. ``avg_loss`` is the mean per-batch loss, and
        ``top1_accuracy`` is a percentage over all validation pairs. Both are
        ``nan`` when the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_pairs = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validating'):
            # Move batch to device
            melody_tokens = batch['melody_tokens'].to(device)
            chord_tokens = batch['chord_tokens'].to(device)

            # Padding masks: True where a position holds the pad token
            melody_mask = (melody_tokens == model.pad_token_id)
            chord_mask = (chord_tokens == model.pad_token_id)

            melody_embeds, chord_embeds = model(
                melody_tokens=melody_tokens,
                chord_tokens=chord_tokens,
                melody_padding_mask=melody_mask,
                chord_padding_mask=chord_mask
            )

            total_loss += loss_fn(melody_embeds, chord_embeds).item()
            num_batches += 1

            # Cosine-similarity matrix: row i scores melody i against every chord in the batch.
            similarity = torch.matmul(
                F.normalize(melody_embeds, p=2, dim=1),
                F.normalize(chord_embeds, p=2, dim=1).T
            )
            # The matching chord for melody i sits at column i, so Top-1 is correct
            # exactly when row i's argmax is i. Computed in torch, so no numpy is
            # needed (the previous version used `np`, which was never imported).
            labels = torch.arange(similarity.size(0), device=similarity.device)
            num_correct += (similarity.argmax(dim=1) == labels).sum().item()
            num_pairs += labels.numel()

    if num_batches == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    top1_accuracy = 100.0 * num_correct / num_pairs if num_pairs else float('nan')

    return avg_loss, top1_accuracy


In [None]:
# Training loop: one training pass and one validation pass per epoch, with wandb
# logging and a checkpoint whenever the validation loss improves.
best_val_loss = float('inf')
# Optimizer steps taken so far (one per training batch); used as the wandb step axis.
global_step = 0
# Reported in the final message; updated below according to how the loop ends.
training_status = "aborted by an unexpected error"

try:
    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch + 1}/{config['epochs']}")

        # Release cached GPU memory before each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU memory at start of epoch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        # Training Step
        train_loss = train_epoch(model, train_loader, loss_fn, optimizer, device)
        # Advance the wandb step. It used to stay at 0, so every epoch was logged at
        # the same step and wandb collapsed them into a single history row.
        global_step += len(train_loader)

        # Validation Step
        val_loss, top1_accuracy = validate(model, val_loader, loss_fn, device)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}, Top-1 Accuracy: {top1_accuracy:.2f}%")
        wandb.log({
            'train/epoch_loss': train_loss,
            'valid/epoch_loss': val_loss,
            'valid/top1_accuracy': top1_accuracy,
            'epoch': epoch + 1
        }, step=global_step)

        # Save checkpoint if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            # Save checkpoint locally
            checkpoint_path = Path("checkpoints") / f"contrastive_reward_epoch_{epoch+1}.pth"
            checkpoint_path.parent.mkdir(exist_ok=True)

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'top1_accuracy': top1_accuracy,
                'config': config,
            }

            torch.save(checkpoint, checkpoint_path)

            # Log checkpoint as wandb artifact
            artifact = wandb.Artifact(
                name=f"contrastive_reward_model_{wandb.run.id}",
                type="model",
                description=f"Contrastive Reward Model checkpoint from epoch {epoch+1}"
            )
            artifact.add_file(str(checkpoint_path))
            wandb.log_artifact(artifact)

            # Also save tokenizer info as artifact
            tokenizer_artifact = wandb.Artifact(
                name=f"tokenizer_info_{wandb.run.id}",
                type="tokenizer",
                description="Tokenizer information used for training"
            )
            tokenizer_path = Path("tokenizer_info.json")
            with open(tokenizer_path, 'w') as f:
                json.dump(tokenizer_info, f)
            tokenizer_artifact.add_file(str(tokenizer_path))
            wandb.log_artifact(tokenizer_artifact)

            print(f"\nSaved checkpoint with validation loss: {val_loss:.4f}")

    training_status = "completed"
except KeyboardInterrupt:
    training_status = "interrupted by user"
    print("\nTraining interrupted by user")
except torch.cuda.OutOfMemoryError:
    training_status = "stopped: out of GPU memory"
    print("\nOut of GPU memory! Try reducing batch size or sequence length")
finally:
    # Always close the wandb run, and report how training actually ended instead
    # of claiming completion after an interrupt or error.
    wandb.finish()
    print(f"\nTraining {training_status}")


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location='cpu'):
    """Restore model weights from a checkpoint written by the training loop.

    Args:
        model: Module whose parameters are overwritten in place via ``load_state_dict``.
        checkpoint_path: Path to a checkpoint file containing at least the keys
            ``'model_state_dict'``, ``'val_loss'`` and ``'top1_accuracy'``.
        map_location: Forwarded to ``torch.load``. It defaults to ``'cpu'`` so a
            checkpoint saved on a GPU can be read on a CPU-only machine.
            ``load_state_dict`` then copies the tensors into the model's existing
            parameters, on whatever device the model lives.

    Returns:
        ``(val_loss, top1_accuracy)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss'], checkpoint['top1_accuracy']

# Example usage. The training loop writes checkpoints/contrastive_reward_epoch_<N>.pth
# only when validation loss improves, so the highest epoch number is the best one.
# This assumes the directory only holds checkpoints from the current run.
# Don't build the path from `wandb.run`: there is no active run after `wandb.finish()`.
# best_checkpoint_path = max(
#     Path("checkpoints").glob("contrastive_reward_epoch_*.pth"),
#     key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
# )
# best_val_loss, best_accuracy = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f} and accuracy: {best_accuracy:.2f}%")
