In [23]:
import os
import torch
from music21 import converter, note, chord
from collections import defaultdict
from typing import List, Dict, Tuple
import pickle
import random

class MusicPreprocessor:
    def __init__(self, abc_dir: str):
        self.abc_dir = abc_dir
        self.token2idx = {}
        self.idx2token = {}
        self.vocab_size = 0
        self.special_tokens = {
            'PAD': '[PAD]',
            'START': '[START]',
            'END': '[END]',
            'UNK': '[UNK]'
        }
        self._initialize_special_tokens()

    def _initialize_special_tokens(self):
        """Initialize special tokens in the vocabulary"""
        for token in self.special_tokens.values():
            self._add_to_vocab(token)

    def _add_to_vocab(self, token: str):
        """Add a token to the vocabulary if it doesn't exist"""
        if token not in self.token2idx:
            self.token2idx[token] = self.vocab_size
            self.idx2token[self.vocab_size] = token
            self.vocab_size += 1

    def _tokenize_score(self, score) -> List[str]:
        """Convert a music21 score into a sequence of tokens"""
        tokens = []
        for element in score.recurse().notes:
            if isinstance(element, note.Note):
                # Format: NOTE_pitch_duration
                token = f"NOTE_{element.pitch.nameWithOctave}_{element.quarterLength}"
                tokens.append(token)
            elif isinstance(element, chord.Chord):
                # Format: CHORD_pitch1-pitch2-pitch3_duration
                pitches = "-".join(n.pitch.nameWithOctave for n in element.notes)
                token = f"CHORD_{pitches}_{element.quarterLength}"
                tokens.append(token)
        return tokens

    def process_files(self, train_split: float = 0.8, val_split: float = 0.1) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
        """Process all ABC files and return tokenized sequences split into train, validation, and test sets"""
        all_sequences = []
        
        # Process all files first
        for filename in os.listdir(self.abc_dir):
            if filename.endswith('.abc'):
                file_path = os.path.join(self.abc_dir, filename)
                try:
                    # Parse the ABC file
                    score = converter.parse(file_path)
                    
                    # Tokenize the score
                    tokens = self._tokenize_score(score)
                    
                    # Add tokens to vocabulary
                    for token in tokens:
                        self._add_to_vocab(token)
                    
                    # Create sequence with start and end tokens
                    sequence = [self.token2idx[self.special_tokens['START']]]
                    sequence.extend(self.token2idx[token] for token in tokens)
                    sequence.append(self.token2idx[self.special_tokens['END']])
                    
                    all_sequences.append(sequence)
                except Exception as e:
                    print(f"Error processing {filename}: {e}")
        
        # Shuffle sequences
        random.shuffle(all_sequences)
        
        # Calculate split sizes
        total_size = len(all_sequences)
        train_size = int(total_size * train_split)
        val_size = int(total_size * val_split)
        
        # Split into train, validation, and test
        train_sequences = all_sequences[:train_size]
        val_sequences = all_sequences[train_size:train_size + val_size]
        test_sequences = all_sequences[train_size + val_size:]
        
        return train_sequences, val_sequences, test_sequences

    def save_vocab(self, save_path: str):
        """Save vocabulary to a file"""
        vocab_data = {
            'token2idx': self.token2idx,
            'idx2token': self.idx2token,
            'vocab_size': self.vocab_size
        }
        with open(save_path, 'wb') as f:
            pickle.dump(vocab_data, f)

def create_dataset(sequences: List[List[int]], max_len: int = 512) -> torch.Tensor:
    """Create padded dataset from sequences"""
    padded_sequences = []
    for seq in sequences:
        if len(seq) > max_len:
            # Truncate sequence if too long
            padded_seq = seq[:max_len]
        else:
            # Pad sequence if too short
            padded_seq = seq + [0] * (max_len - len(seq))
        padded_sequences.append(padded_seq)
    
    return torch.tensor(padded_sequences)

if __name__ == "__main__":
    # Set random seed for reproducibility
    random.seed(42)
    
    # Initialize preprocessor
    abc_dir = "nottingham-dataset/ABC_cleaned"
    preprocessor = MusicPreprocessor(abc_dir)
    
    # Process all files and split into train/val/test
    print("Processing ABC files...")
    train_sequences, val_sequences, test_sequences = preprocessor.process_files(train_split=0.8, val_split=0.1)
    
    # Create datasets
    print("Creating datasets...")
    train_dataset = create_dataset(train_sequences)
    val_dataset = create_dataset(val_sequences)
    test_dataset = create_dataset(test_sequences)
    
    # Save vocabulary and datasets
    print("Saving preprocessed data...")
    preprocessor.save_vocab("vocab.pkl")
    torch.save({
        'train': train_dataset,
        'val': val_dataset,
        'test': test_dataset
    }, "dataset.pt")
    
    print(f"Preprocessing complete!")
    print(f"Vocabulary size: {preprocessor.vocab_size}")
    print(f"Number of training sequences: {len(train_sequences)} ({len(train_sequences)/len(train_sequences+val_sequences+test_sequences)*100:.1f}%)")
    print(f"Number of validation sequences: {len(val_sequences)} ({len(val_sequences)/len(train_sequences+val_sequences+test_sequences)*100:.1f}%)")
    print(f"Number of test sequences: {len(test_sequences)} ({len(test_sequences)/len(train_sequences+val_sequences+test_sequences)*100:.1f}%)")
    print(f"Training dataset shape: {train_dataset.shape}")
    print(f"Validation dataset shape: {val_dataset.shape}")
    print(f"Test dataset shape: {test_dataset.shape}") 

Processing ABC files...
Creating datasets...
Saving preprocessed data...
Preprocessing complete!
Vocabulary size: 544
Number of training sequences: 11 (78.6%)
Number of validation sequences: 1 (7.1%)
Number of test sequences: 2 (14.3%)
Training dataset shape: torch.Size([11, 512])
Validation dataset shape: torch.Size([1, 512])
Test dataset shape: torch.Size([2, 512])


In [24]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        return x + self.pe[:x.size(0)]

# class MusicTransformer(nn.Module):
#     def __init__(
#         self,
#         vocab_size: int,
#         d_model: int = 512,
#         nhead: int = 8,
#         num_encoder_layers: int = 6,
#         num_decoder_layers: int = 6,
#         dim_feedforward: int = 2048,
#         dropout: float = 0.1,
#         max_len: int = 512
#     ):
#         super().__init__()
        
#         self.embedding = nn.Embedding(vocab_size, d_model)
#         self.pos_encoder = PositionalEncoding(d_model, max_len)
        
#         self.transformer = nn.Transformer(
#             d_model=d_model,
#             nhead=nhead,
#             num_encoder_layers=num_encoder_layers,
#             num_decoder_layers=num_decoder_layers,
#             dim_feedforward=dim_feedforward,
#             dropout=dropout,
#             batch_first=True
#         )
        
#         self.fc_out = nn.Linear(d_model, vocab_size)
#         self.d_model = d_model
class MusicTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        nhead: int = 8,
        num_layers: int = 12,  # Increased depth
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        
        # Use decoder-only architecture
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
        
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
    
    # def forward(self, x, mask=None):
    #     # Embed and add positional encoding
    #     x = self.embedding(x) * math.sqrt(self.d_model)
    #     x = self.pos_encoder(x)
        
    #     # Create causal mask if not provided
    #     if mask is None:
    #         mask = self.generate_square_subsequent_mask(x.size(1)).to(x.device)
        
    #     # For decoder-only, we use the same sequence as both input and memory
    #     output = self.transformer(x, x, tgt_mask=mask)
    #     return self.fc_out(output)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
        memory_mask: torch.Tensor = None,
        src_key_padding_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None
    ) -> torch.Tensor:
        # Embed the source and target sequences
        src = self.embedding(src) * math.sqrt(self.d_model)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        
        # Add positional encoding
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        
        # Pass through transformer
        output = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )
        
        # Project to vocabulary size
        return self.fc_out(output)
    def generate(
        self,
        start_token: int = 1,
        max_len: int = 512,
        temperature: float = 1.2,  # Higher temperature for more variety
        top_k: int = 50,
        top_p: float = 0.95,
        device: torch.device = None
    ) -> torch.Tensor:
        self.eval()
        if device is None:
            device = next(self.parameters()).device
            
        with torch.no_grad():
            # Start with just the start token
            sequence = torch.tensor([[start_token]], device=device)
            
            for _ in range(max_len - 1):
                # Get predictions for next token
                logits = self.forward(sequence)
                next_token_logits = logits[0, -1, :] / temperature
                
                # Apply top-k filtering
                if top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits = torch.full_like(next_token_logits, float('-inf'))
                    next_token_logits.scatter_(0, top_k_indices, top_k_logits)
                
                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')
                
                # Sample next token
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).unsqueeze(0)
                
                # Append to sequence
                sequence = torch.cat([sequence, next_token], dim=1)
                
                # Stop if we hit the end token
                if next_token.item() == 2:  # Assuming 2 is END token
                    break
                    
            return sequence

    # def generate(
    #     self,
    #     src: torch.Tensor,
    #     max_len: int = 512,
    #     temperature: float = 1.0,
    #     top_k: int = 0,
    #     top_p: float = 0.9
    # ) -> torch.Tensor:
    #     self.eval()
    #     with torch.no_grad():
    #         batch_size = src.shape[0]
            
    #         # Start with START token
    #         dec_input = torch.ones(batch_size, 1).long().to(src.device)
            
    #         for _ in range(max_len - 1):
    #             # Create masks
    #             src_mask = torch.zeros((src.shape[1], src.shape[1])).to(src.device)
    #             tgt_mask = self.generate_square_subsequent_mask(dec_input.shape[1]).to(src.device)
                
    #             # Get next token probabilities
    #             out = self.forward(
    #                 src,
    #                 dec_input,
    #                 src_mask=src_mask,
    #                 tgt_mask=tgt_mask
    #             )
                
    #             # Get probabilities for next token
    #             next_token_logits = out[:, -1, :] / temperature
                
    #             # Apply top-k filtering
    #             if top_k > 0:
    #                 indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
    #                 next_token_logits[indices_to_remove] = float('-inf')
                
    #             # Apply top-p (nucleus) filtering
    #             if top_p < 1.0:
    #                 sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
    #                 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    #                 sorted_indices_to_remove = cumulative_probs > top_p
    #                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    #                 sorted_indices_to_remove[..., 0] = 0
    #                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    #                 next_token_logits[indices_to_remove] = float('-inf')
                
    #             # Sample next token
    #             probs = torch.softmax(next_token_logits, dim=-1)
    #             next_token = torch.multinomial(probs, num_samples=1)
                
    #             # Append next token to decoder input
    #             dec_input = torch.cat([dec_input, next_token], dim=1)
                
    #             # Stop if we predict the END token
    #             if (next_token == 2).any():  # Assuming 2 is the END token index
    #                 break
            
    #         return dec_input
    
    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
        """Generate a square mask for the sequence."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pickle
from model import MusicTransformer
import numpy as np
from tqdm import tqdm

def create_src_tgt_data(sequences: torch.Tensor) -> tuple:
    """Create source and target sequences for training"""
    src = sequences[:, :-1]  # All but last token
    tgt = sequences[:, 1:]   # All but first token
    return src, tgt

def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Create padding mask for transformer"""
    return (seq == pad_idx)

# def train_epoch(
#     model: nn.Module,
#     dataloader: DataLoader,
#     criterion: nn.Module,
#     optimizer: optim.Optimizer,
#     device: torch.device,
#     pad_idx: int = 0
# ) -> float:
#     model.train()
#     total_loss = 0
    
#     for batch_idx, (src, tgt) in enumerate(tqdm(dataloader, desc="Training")):
#         src, tgt = src.to(device), tgt.to(device)
        
#         # Create masks
#         src_padding_mask = create_padding_mask(src, pad_idx).to(device)
#         tgt_padding_mask = create_padding_mask(tgt, pad_idx).to(device)
#         tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(device)
        
#         # Forward pass
#         optimizer.zero_grad()
#         output = model(
#             src,
#             tgt,
#             tgt_mask=tgt_mask,
#             src_key_padding_mask=src_padding_mask,
#             tgt_key_padding_mask=tgt_padding_mask
#         )
        
#         # Calculate loss
#         loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
        
#         # Backward pass
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#         optimizer.step()
        
#         total_loss += loss.item()
        
#     return total_loss / len(dataloader)
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    pad_idx: int = 0
) -> float:
    model.train()
    total_loss = 0
    
    for batch_idx, sequences in enumerate(tqdm(dataloader, desc="Training")):
        sequences = sequences.to(device)
        
        # Input and target sequences for autoregressive training
        input_seq = sequences[:, :-1]
        target_seq = sequences[:, 1:]
        
        # Create causal mask
        seq_len = input_seq.size(1)
        mask = model.generate_square_subsequent_mask(seq_len).to(device)
        
        optimizer.zero_grad()
        output = model(input_seq, mask=mask)
        
        # Calculate loss with diversity regularization
        loss = criterion(output.reshape(-1, output.size(-1)), target_seq.reshape(-1))
        
        # Add diversity loss to prevent repetition
        diversity_loss = 0
        for i in range(1, min(8, seq_len)):
            repeated_mask = (input_seq[:, i:] == input_seq[:, :-i]).float()
            diversity_loss += repeated_mask.mean()
        
        total_loss_with_reg = loss + 0.01 * diversity_loss
        
        total_loss_with_reg.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    desc: str,
    pad_idx: int = 0
) -> float:
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for src, tgt in tqdm(dataloader, desc=desc):
            src, tgt = src.to(device), tgt.to(device)
            
            # Create masks
            src_padding_mask = create_padding_mask(src, pad_idx).to(device)
            tgt_padding_mask = create_padding_mask(tgt, pad_idx).to(device)
            tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(device)
            
            # Forward pass
            output = model(
                src,
                tgt,
                tgt_mask=tgt_mask,
                src_key_padding_mask=src_padding_mask,
                tgt_key_padding_mask=tgt_padding_mask
            )
            
            # Calculate loss
            loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

def main():
    # Load preprocessed data
    print("Loading data...")
    datasets = torch.load("dataset.pt")
    train_dataset = datasets['train']
    val_dataset = datasets['val']
    test_dataset = datasets['test']
    
    with open("vocab.pkl", "rb") as f:
        vocab_data = pickle.load(f)
    
    # Create dataloaders for decoder-only training
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)  # Smaller batch size
    val_loader = DataLoader(val_dataset, batch_size=8)
    test_loader = DataLoader(test_dataset, batch_size=8)
    
    # Initialize model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MusicTransformer(vocab_size=vocab_data['vocab_size']).to(device)
    
    # Training parameters
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # Training loop
    num_epochs = 50
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0
    
    print("Starting training...")
    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = evaluate(model, val_loader, criterion, device, "Validating")
        
        scheduler.step(val_loss)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, 'best_model.pt')
        else:
            patience_counter += 1
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")
        print("------------------------")
        
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break


# def main():
#     # Load preprocessed data
#     print("Loading data...")
#     datasets = torch.load("dataset.pt")
#     train_dataset = datasets['train']
#     val_dataset = datasets['val']
#     test_dataset = datasets['test']
    
#     with open("vocab.pkl", "rb") as f:
#         vocab_data = pickle.load(f)
    
#     # Create src-tgt pairs for all sets
#     train_src, train_tgt = create_src_tgt_data(train_dataset)
#     val_src, val_tgt = create_src_tgt_data(val_dataset)
#     test_src, test_tgt = create_src_tgt_data(test_dataset)
    
#     # Create dataloaders
#     train_data = TensorDataset(train_src, train_tgt)
#     val_data = TensorDataset(val_src, val_tgt)
#     test_data = TensorDataset(test_src, test_tgt)
    
#     train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
#     val_loader = DataLoader(val_data, batch_size=32)
#     test_loader = DataLoader(test_data, batch_size=32)
    
#     # Initialize model
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = MusicTransformer(vocab_size=vocab_data['vocab_size']).to(device)
    
#     # Training parameters
#     criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index
#     optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
#     # Training loop
#     num_epochs = 50
#     best_val_loss = float('inf')
#     patience = 10  # Early stopping patience
#     patience_counter = 0
    
#     print("Starting training...")
#     for epoch in range(num_epochs):
#         # Train
#         train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        
#         # Validate
#         val_loss = evaluate(model, val_loader, criterion, device, "Validating")
        
#         # Update learning rate
#         scheduler.step(val_loss)
        
#         # Save best model based on validation loss
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             patience_counter = 0
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_loss': val_loss,
#             }, 'best_model.pt')
#         else:
#             patience_counter += 1
        
#         print(f"Epoch {epoch+1}/{num_epochs}")
#         print(f"Train Loss: {train_loss:.4f}")
#         print(f"Validation Loss: {val_loss:.4f}")
#         print("------------------------")
        
#         # Early stopping
#         if patience_counter >= patience:
#             print("Early stopping triggered!")
#             break
    
#     # Load best model for final test evaluation
#     checkpoint = torch.load('best_model.pt')
#     model.load_state_dict(checkpoint['model_state_dict'])
#     test_loss = evaluate(model, test_loader, criterion, device, "Testing")
#     print(f"\nFinal Test Loss: {test_loss:.4f}")

# if __name__ == "__main__":
#     main() 

Loading data...
Starting training...


Training: 100%|██████████| 1/1 [00:12<00:00, 12.15s/it]
Validating: 100%|██████████| 1/1 [00:00<00:00,  5.77it/s]


Epoch 1/50
Train Loss: 6.5792
Validation Loss: 5.7796
------------------------


Training: 100%|██████████| 1/1 [00:11<00:00, 11.94s/it]
Validating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


Epoch 2/50
Train Loss: 5.4812
Validation Loss: 5.5364
------------------------


Training:   0%|          | 0/1 [00:10<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import torch
import pickle
from model import MusicTransformer
from music21 import note, chord, stream, instrument
from typing import List, Union
from fractions import Fraction
from midi2audio import FluidSynth
import os

def parse_duration(duration_str: str) -> float:
    """Parse duration string that might be a fraction or float"""
    try:
        if '/' in duration_str:
            return float(Fraction(duration_str))
        return float(duration_str)
    except ValueError as e:
        print(f"Error parsing duration {duration_str}: {e}")
        return 1.0  # Default duration

def token_to_music21(token: str) -> Union[note.Note, chord.Chord, None]:
    """Convert a token to a music21 note or chord"""
    try:
        if token.startswith('NOTE_'):
            _, pitch, duration = token.split('_')
            if not pitch:
                print(f"Invalid pitch in token: {token}")
                return None
            return note.Note(pitch, quarterLength=parse_duration(duration))
        elif token.startswith('CHORD_'):
            parts = token.split('_')
            if len(parts) != 3:
                print(f"Invalid chord token format: {token}")
                return None
            _, pitches_str, duration = parts
            if not pitches_str:
                print(f"Empty pitches in chord token: {token}")
                return None
            
            # Split and clean up pitch names
            pitches = []
            raw_pitches = pitches_str.split('-')
            i = 0
            while i < len(raw_pitches):
                p = raw_pitches[i]
                if not p:
                    i += 1
                    continue
                
                # If this pitch has a number, it's complete
                if any(c.isdigit() for c in p):
                    pitches.append(p)
                    i += 1
                    continue
                
                # If this pitch doesn't have a number but the next one does
                if i + 1 < len(raw_pitches) and any(c.isdigit() for c in raw_pitches[i + 1]):
                    pitches.append(p + raw_pitches[i + 1])
                    i += 2
                else:
                    # Use the octave from the previous note or default to 4
                    prev_octave = ''.join(c for c in pitches[-1] if c.isdigit()) if pitches else '4'
                    pitches.append(p + prev_octave)
                    i += 1
            
            if not pitches:
                print(f"No valid pitches in chord token: {token}")
                return None
                
            return chord.Chord(pitches, quarterLength=parse_duration(duration))
        return None
    except Exception as e:
        print(f"Error converting token {token}: {e}")
        return None

def generate_music(
    model: MusicTransformer,
    vocab_data: dict,
    seed_sequence: torch.Tensor = None,
    max_len: int = 512,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    min_sequence_length: int = 64,  # Minimum sequence length to ensure enough music
    device: torch.device = None
) -> List[str]:
    """Generate a sequence of music tokens"""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    try:
        # If no seed sequence provided, start with just the START token
        if seed_sequence is None:
            # Find the START token index
            start_token = '[START]'
            token2idx = vocab_data['token2idx']
            start_idx = token2idx.get(start_token, 1)  # Default to 1 if not found
            seed_sequence = torch.tensor([[start_idx]]).to(device)
        
        # Generate sequence
        generated_sequence = model.generate(
            seed_sequence,
            max_len=max_len,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        
        # Convert indices to tokens
        idx2token = vocab_data['idx2token']
        tokens = []
        for idx in generated_sequence[0]:
            idx_num = idx.item()
            token = idx2token.get(idx_num, '[UNK]')
            
            # Skip special tokens except START and END
            if token in ['[PAD]', '[UNK]']:
                continue
                
            tokens.append(token)
            
            # If we have enough tokens and encounter END, stop
            if token == '[END]' and len(tokens) >= min_sequence_length:
                break
        
        # If we don't have enough tokens or no END token was found, generate more
        while len(tokens) < min_sequence_length:
            # Generate more sequence
            additional_sequence = model.generate(
                torch.tensor([[token2idx[tokens[-1]]]]).to(device),
                max_len=32,  # Generate a shorter sequence
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )
            
            # Add new tokens
            for idx in additional_sequence[0][1:]:  # Skip first token as it's the seed
                idx_num = idx.item()
                token = idx2token.get(idx_num, '[UNK]')
                if token not in ['[PAD]', '[UNK]', '[START]']:
                    tokens.append(token)
                    if token == '[END]' and len(tokens) >= min_sequence_length:
                        break
        
        return tokens
    except Exception as e:
        print(f"Error in generate_music: {e}")
        print(f"Vocabulary structure: {vocab_data.keys()}")
        print(f"idx2token keys type: {type(next(iter(idx2token.keys())))} example: {next(iter(idx2token.keys()))}")
        raise

def tokens_to_music21(tokens: List[str]) -> stream.Stream:
    """Convert a list of tokens to a music21 stream"""
    score = stream.Score()
    part = stream.Part()
    
    # Set the instrument to Violin
    violin = instrument.Violin()
    part.insert(0, violin)
    
    for token in tokens:
        if token.startswith(('[START]', '[END]', '[PAD]', '[UNK]')):
            continue
        
        element = token_to_music21(token)
        if element is not None:
            part.append(element)
    
    score.append(part)
    return score

def generate_multiple_samples(model, vocab_data, num_samples=5, **kwargs):
    """Generate multiple music samples"""
    midi_files = []
    wav_files = []
    
    for i in range(num_samples):
        print(f"\nGenerating sample {i+1}/{num_samples}")
        
        # Generate tokens
        tokens = generate_music(model, vocab_data, **kwargs)
        
        # Convert to score
        score = tokens_to_music21(tokens)
        
        # Save as MIDI
        midi_file = f"generated_music_{i}.mid"
        score.write('midi', fp=midi_file)
        midi_files.append(midi_file)
        
        # Convert to WAV
        wav_file = f"generated_music_{i}.wav"
        wav_files.append(wav_file)
    
    return midi_files, wav_files

def main():
    try:
        # Load vocabulary and model
        print("Loading model and vocabulary...")
        with open("vocab.pkl", "rb") as f:
            vocab_data = pickle.load(f)
        
        # Print vocabulary information
        print(f"Vocabulary size: {vocab_data['vocab_size']}")
        print(f"Special tokens: {[token for token in vocab_data['token2idx'].keys() if token.startswith('[')]}")
        
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        
        model = MusicTransformer(vocab_size=vocab_data['vocab_size']).to(device)
        
        # Load trained model weights
        print("Loading model weights...")
        checkpoint = torch.load("best_model.pt", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        # Generate multiple samples
        print("\nGenerating music samples...")
        num_samples = 1
        midi_files, wav_files = generate_multiple_samples(
            model,
            vocab_data,
            num_samples=num_samples,
            temperature=0.85,  # Higher temperature for more variety
            top_k=40,         # Less restrictive filtering
            top_p=0.92,       # Less restrictive nucleus sampling
            device=device
        )
        
        # Convert MIDI to audio using FluidSynth
        print("\nConverting MIDI to audio...")
        try:
            fs = FluidSynth("default.sf2")  # Using the default soundfont that works
            for midi_file, wav_file in zip(midi_files, wav_files):
                print(f"Converting {midi_file} to {wav_file}")
                fs.midi_to_audio(midi_file, wav_file)
        except Exception as e:
            print(f"Error converting to audio: {e}")
            print("Make sure the soundfont file is present in the current directory")
        
        print("\nGenerated files:")
        for i in range(num_samples):
            print(f"Sample {i+1}:")
            print(f"  MIDI: {midi_files[i]}")
            print(f"  WAV: {wav_files[i]}")
        
    except Exception as e:
        print(f"Error in main: {e}")
        print("Vocabulary data keys:", vocab_data.keys() if 'vocab_data' in locals() else "vocab_data not loaded")
        raise

if __name__ == "__main__":
    main() 

Loading model and vocabulary...
Vocabulary size: 544
Special tokens: ['[PAD]', '[START]', '[END]', '[UNK]']
Using device: cpu
Loading model weights...

Generating music samples...

Generating sample 1/1

Converting MIDI to audio...
Converting generated_music_0.mid to generated_music_0.wav





Generated files:FluidSynth runtime version 2.4.6
Copyright (C) 2000-2025 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'generated_music_0.wav'..

Sample 1:
  MIDI: generated_music_0.mid
  WAV: generated_music_0.wav
