In [None]:
# Install dependencies
# %pip (rather than !pip) installs into the environment backing the running kernel.
# NOTE(review): versions are unpinned, so a fresh environment may resolve to releases
# this notebook was never run against — TODO pin them (e.g. `torch==X.Y.Z`).
%pip install torch wandb tqdm pyyaml


In [None]:
# Import required libraries (stdlib → third-party → project)
import json
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import wandb
import yaml
from torch.optim import Adam
# tqdm.auto renders the notebook widget under Jupyter and falls back to a text bar elsewhere
from tqdm.auto import tqdm

# Add project root to Python path. This assumes the kernel's working directory is
# the repository root; the printed path below lets you confirm that.
sys.path.append('.')
print(os.getcwd())

# Import project modules
from src.data.dataset import create_dataloader
from src.models.discriminative_reward_model import DiscriminativeRewardModel
from src.config.tokenization_config import PAD_TOKEN


In [None]:
# Load configuration
config_path = 'src/training/configs/discriminator_reward_base.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Adjust sequence length to account for interleaving: the halved value is what the
# dataloaders receive, and the model cell doubles it back for the interleaved
# chord/melody sequence. NOTE(review): an odd configured length loses one token here.
config['max_seq_length'] = config['max_seq_length'] // 2

# Seed torch's CPU and CUDA RNGs so torch-driven randomness (e.g. batch shuffling,
# in-batch negative sampling via torch.randperm) is reproducible. setdefault keeps
# a seed given in the YAML and records the effective seed in config, which is
# logged to wandb below.
config.setdefault('seed', 42)
torch.manual_seed(config['seed'])

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize wandb; the run name reports the full interleaved sequence length (x2)
run_name = f"discriminator_L{config['num_layers']}_H{config['num_heads']}_D{config['embed_dim']}_seq{config['max_seq_length']*2}_bs{config['batch_size']}_lr{config['learning_rate']}"

wandb.init(
    project=config['wandb_project'],
    name=run_name,
    config=config,
    job_type="discriminator_training"
)


In [None]:
# Create dataloaders. Train and validation loaders share every setting except the
# split name and whether batches are shuffled.
def _build_discriminator_loader(split_name, shuffle_batches):
    """Return ``create_dataloader(...)``'s result for one split in discriminator mode."""
    return create_dataloader(
        data_dir=Path(config['data_dir']),
        split=split_name,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        sequence_length=config['max_seq_length'],
        mode='discriminator',
        shuffle=shuffle_batches,
    )


train_loader, tokenizer_info = _build_discriminator_loader("train", True)
# Only the training call's tokenizer info is kept
val_loader, _ = _build_discriminator_loader("valid", False)

# The model's embedding size follows the tokenizer's total vocabulary
config['vocab_size'] = tokenizer_info['total_vocab_size']


In [None]:
# Build the discriminator, then attach wandb monitoring, the optimizer and the loss
model_kwargs = dict(
    vocab_size=config['vocab_size'],
    embed_dim=config['embed_dim'],
    num_heads=config['num_heads'],
    num_layers=config['num_layers'],
    dropout=config['dropout'],
    # config['max_seq_length'] was halved for interleaving; the model sees the full interleaved sequence
    max_seq_length=config['max_seq_length'] * 2,
    # Prefer the tokenizer's pad id, falling back to the project-wide PAD_TOKEN
    pad_token_id=tokenizer_info.get('pad_token_id', PAD_TOKEN),
)
model = DiscriminativeRewardModel(**model_kwargs).to(device)

# Log gradients and parameters to wandb every 100 batches
wandb.watch(model, log='all', log_freq=100)
num_parameters = sum(param.numel() for param in model.parameters())
print(f"Model created with {num_parameters:,} parameters.")

optimizer = Adam(model.parameters(), lr=config['learning_rate'])
# nn.BCELoss expects probabilities in [0, 1], not raw logits
loss_fn = nn.BCELoss()


In [None]:
def create_negative_samples(interleaved_tokens: torch.Tensor) -> torch.Tensor:
    """
    Creates negative samples by swapping chord progressions between sequences in a batch.

    Every melody keeps its position and is re-paired with the chord track of a
    *different* row of the batch. A plain ``torch.randperm`` leaves on average one
    row per batch mapped to itself, which would produce a real pair labelled as
    fake; this function draws a derangement instead.

    Args:
        interleaved_tokens (torch.Tensor): A batch of real, interleaved sequences
                                           [batch_size, seq_length] where tokens are
                                           arranged as [c, m, c, m, ...] (chords at
                                           even positions, melody at odd positions).

    Returns:
        torch.Tensor: A new tensor with the same shape, dtype and device as the
                      input, whose melodies are paired with chords from other
                      sequences in the batch. With ``batch_size < 2`` there is no
                      other row to borrow from, so the result equals the input.

    Note:
        Rows whose chord tracks are identical (e.g. duplicate or fully padded
        sequences) can still yield negatives indistinguishable from real pairs.
    """
    batch_size = interleaved_tokens.size(0)

    # De-interleave into chord (even positions) and melody (odd positions) tracks
    chord_tokens = interleaved_tokens[:, 0::2]
    melody_tokens = interleaved_tokens[:, 1::2]

    # Random derangement: visit the rows in a random order and give each row the
    # chords of the next row in that order, wrapping around. The result is a single
    # cycle over the batch, so for batch_size >= 2 no row keeps its own chords.
    order = torch.randperm(batch_size, device=interleaved_tokens.device)
    donor_rows = torch.empty_like(order)
    donor_rows[order] = torch.roll(order, shifts=-1)
    shuffled_chord_tokens = chord_tokens[donor_rows]

    # Re-interleave to create the fake sequences
    fake_interleaved = torch.empty_like(interleaved_tokens)
    fake_interleaved[:, 0::2] = shuffled_chord_tokens
    fake_interleaved[:, 1::2] = melody_tokens

    return fake_interleaved


In [None]:
# Training loop: each batch pairs the real (chord, melody) sequences with in-batch
# negatives and trains the discriminator to tell them apart.
best_val_loss = float('inf')
global_step = 0


def _build_discriminator_batch(batch):
    """Stack real sequences with their negatives, plus padding mask and labels.

    Args:
        batch: A dataloader batch; only ``batch['interleaved_tokens']`` is read.

    Returns:
        ``(combined_sequences, padding_mask, combined_labels)``. The real rows come
        first (label 1.0), followed by the negatives from ``create_negative_samples``
        (label 0.0); labels have shape ``(2 * batch_size, 1)``.
    """
    real_sequences = batch['interleaved_tokens'].to(device)
    fake_sequences = create_negative_samples(real_sequences)
    # Shape: (2 * batch_size, seq_len)
    combined_sequences = torch.cat([real_sequences, fake_sequences], dim=0)
    # True wherever the token is padding
    padding_mask = combined_sequences == model.pad_token_id
    real_labels = torch.ones(real_sequences.size(0), 1, device=device)
    fake_labels = torch.zeros(fake_sequences.size(0), 1, device=device)
    combined_labels = torch.cat([real_labels, fake_labels], dim=0)
    return combined_sequences, padding_mask, combined_labels


def _predict(sequences, padding_mask):
    """Return the model's real-vs-fake probabilities for ``sequences``."""
    # The model outputs logits, but loss_fn (nn.BCELoss) expects probabilities.
    # NOTE(review): nn.BCEWithLogitsLoss on the raw logits is more numerically stable;
    # switching requires changing loss_fn and dropping this sigmoid together.
    return torch.sigmoid(model(sequences, padding_mask=padding_mask))


def _save_best_checkpoint(epoch, avg_train_loss, avg_valid_loss, accuracy):
    """Save a checkpoint for ``epoch`` locally and log it plus tokenizer info to wandb."""
    checkpoint_path = Path("checkpoints") / f"discriminator_reward_epoch_{epoch+1}.pth"
    checkpoint_path.parent.mkdir(exist_ok=True)

    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'val_loss': avg_valid_loss,
        'accuracy': accuracy,
        'config': config,
    }
    torch.save(checkpoint, checkpoint_path)

    # Log checkpoint as wandb artifact
    artifact = wandb.Artifact(
        name=f"discriminator_reward_model_{wandb.run.id}",
        type="model",
        description=f"Discriminative Reward Model checkpoint from epoch {epoch+1}"
    )
    artifact.add_file(str(checkpoint_path))
    wandb.log_artifact(artifact)

    # Also save tokenizer info as artifact
    tokenizer_artifact = wandb.Artifact(
        name=f"tokenizer_info_{wandb.run.id}",
        type="tokenizer",
        description="Tokenizer information used for training"
    )
    tokenizer_path = Path("tokenizer_info.json")
    with open(tokenizer_path, 'w') as f:
        json.dump(tokenizer_info, f)
    tokenizer_artifact.add_file(str(tokenizer_path))
    wandb.log_artifact(tokenizer_artifact)


try:
    for epoch in range(config['epochs']):
        print(f"\nEpoch {epoch + 1}/{config['epochs']}")

        # Clear GPU memory before each epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU memory at start of epoch: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

        # Training Step
        model.train()
        total_train_loss = 0.0
        num_train_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Training]")
        for batch in pbar:
            combined_sequences, padding_mask, combined_labels = _build_discriminator_batch(batch)

            optimizer.zero_grad()
            predictions = _predict(combined_sequences, padding_mask)
            loss = loss_fn(predictions, combined_labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once: each .item() forces a device sync
            step_loss = loss.item()
            lr = optimizer.param_groups[0]['lr']
            total_train_loss += step_loss
            num_train_batches += 1
            global_step += 1
            pbar.set_postfix({'loss': step_loss, 'lr': lr})
            wandb.log({'train/step_loss': step_loss, 'train/learning_rate': lr}, step=global_step)

        # Average over the batches actually seen; NaN rather than ZeroDivisionError if empty
        avg_train_loss = total_train_loss / num_train_batches if num_train_batches else float('nan')

        # Validation Loop (val_loader is the loader built in the dataloader cell)
        model.eval()
        total_valid_loss = 0.0
        num_valid_batches = 0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            pbar_valid = tqdm(val_loader, desc="Validating")
            for batch in pbar_valid:
                combined_sequences, padding_mask, combined_labels = _build_discriminator_batch(batch)
                predictions = _predict(combined_sequences, padding_mask)
                loss = loss_fn(predictions, combined_labels)
                batch_loss = loss.item()
                total_valid_loss += batch_loss
                num_valid_batches += 1

                # Threshold probabilities at 0.5 to get hard real/fake decisions
                predicted_labels = (predictions > 0.5).float()
                correct_predictions += (predicted_labels == combined_labels).sum().item()
                total_predictions += combined_labels.size(0)

                pbar_valid.set_postfix({'loss': batch_loss})

        avg_valid_loss = total_valid_loss / num_valid_batches if num_valid_batches else float('nan')
        # Accuracy is a fraction in [0, 1]
        accuracy = correct_predictions / total_predictions if total_predictions else float('nan')

        print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}, Accuracy: {accuracy:.4f}")
        wandb.log({
            'train/epoch_loss': avg_train_loss,
            'valid/epoch_loss': avg_valid_loss,
            'valid/accuracy': accuracy,
            'epoch': epoch + 1
        }, step=global_step)

        # Save checkpoint if validation loss improved (NaN never compares as an improvement)
        if avg_valid_loss < best_val_loss:
            best_val_loss = avg_valid_loss
            _save_best_checkpoint(epoch, avg_train_loss, avg_valid_loss, accuracy)
            print(f"\nSaved checkpoint with validation loss: {avg_valid_loss:.4f}")

except KeyboardInterrupt:
    print("\nTraining interrupted by user")
except torch.cuda.OutOfMemoryError:
    print("\nOut of GPU memory! Try reducing batch size or sequence length")
finally:
    wandb.finish()
    print("\nTraining completed")


In [None]:
def load_best_checkpoint(model, checkpoint_path, map_location="cpu"):
    """Load model weights from a checkpoint written by the training loop.

    Args:
        model: The ``nn.Module`` that receives ``checkpoint['model_state_dict']``.
        checkpoint_path: Path to a ``.pth`` file written with ``torch.save``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on GPU also loads on a CPU-only machine;
            ``load_state_dict`` then copies the weights onto ``model``'s own
            device. Pass ``None`` to restore tensors to their saved device.

    Returns:
        Tuple ``(val_loss, accuracy)`` stored in the checkpoint; ``accuracy`` is a
        fraction in [0, 1], not a percentage.

    Note:
        Depending on the torch version's ``weights_only`` default, ``torch.load``
        may unpickle arbitrary objects — only load checkpoints you trust.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['val_loss'], checkpoint['accuracy']

# Example usage: the training loop writes checkpoints/discriminator_reward_epoch_<N>.pth
# only when validation loss improves, so within a single run the highest epoch is the
# best one (the directory is shared across runs — clear it or filter as needed):
# best_checkpoint_path = max(Path("checkpoints").glob("discriminator_reward_epoch_*.pth"),
#                            key=lambda p: int(p.stem.rsplit("_", 1)[-1]))
# best_val_loss, best_accuracy = load_best_checkpoint(model, best_checkpoint_path)
# print(f"Loaded checkpoint with validation loss: {best_val_loss:.4f} and accuracy: {best_accuracy:.2%}")
