In [1]:
#!/usr/bin/env python3
"""
Continuous Unconditional Music Generation using MAESTRO Dataset
Trains a transformer model on MIDI files and generates music continuously
"""

import os
import glob
import random
import time
import pickle
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pretty_midi
from tqdm import tqdm

# Configuration
class Config:
    """Central settings for data loading, the model, training and generation.

    Every value is a plain class attribute. A single instance is created right
    after the class as ``config`` and is read by the rest of the notebook.
    """

    # Data
    dataset_path = "maestro-v3.0.0"  # Root folder searched recursively for "*.mid*" files
    # NOTE(review): this value is not read by any code visible in this file; the
    # model is sized from MIDITokenizer.vocab_size (128 pitches + 4 special tokens).
    vocab_size = 128  # MIDI note range (0-127)
    sequence_length = 512  # Training window length; also the context cap during generation
    
    # Model (passed straight to MusicTransformer)
    d_model = 512
    nhead = 8
    num_layers = 6
    dim_feedforward = 2048
    dropout = 0.1
    
    # Training
    batch_size = 16
    learning_rate = 1e-4  # Learning rate handed to the Adam optimizer
    epochs = 50
    save_interval = 5  # Write a "model_epoch_<n>.pth" checkpoint every N epochs
    
    # Generation
    temperature = 1.0  # Logits are divided by this before softmax sampling
    max_generate_length = 1000  # Upper bound on sampled tokens per piece
    # Seconds slept after each generated piece (generation time itself is extra).
    generation_interval = 10  # Generate every N seconds
    
    # Paths
    model_save_path = "music_model.pth"  # Final checkpoint; its presence skips training
    tokenizer_save_path = "tokenizer.pkl"  # Pickled MIDITokenizer
    output_dir = "generated_music"  # Created on MusicGenerator construction

# Shared settings instance used as a module-level global by the classes below.
config = Config()

class MIDITokenizer:
    """Maps MIDI files to flat integer token sequences and back again.

    Token layout: 0 = padding, 1 = start, 2 = end, 3 = rest, and
    4..131 = MIDI pitches 0..127 (pitch + ``note_offset``).
    """

    def __init__(self):
        # The four reserved control tokens occupy the lowest ids.
        self.pad_token, self.start_token, self.end_token, self.rest_token = 0, 1, 2, 3

        # Pitch tokens are shifted past the reserved ids.
        self.note_offset = 4
        self.vocab_size = self.note_offset + 128

    def midi_to_sequence(self, midi_file):
        """Tokenize one MIDI file.

        Returns a list of token ids (start token, rests/pitches, end token)
        truncated to 2000 entries, or an empty list when the file has no
        non-drum notes or cannot be processed.
        """
        step = 0.1  # seconds represented by a single rest token (100ms)
        max_rests = 10  # longest silence that is encoded, in rest tokens

        try:
            score = pretty_midi.PrettyMIDI(midi_file)

            # Flatten every non-drum track into (onset, pitch, length) triples.
            events = [
                (n.start, n.pitch, n.end - n.start)
                for track in score.instruments
                if not track.is_drum
                for n in track.notes
            ]
            events.sort(key=lambda ev: ev[0])

            if len(events) == 0:
                return []

            tokens = [self.start_token]
            cursor = 0  # end time of the most recently emitted note
            for onset, pitch, length in events:
                # A non-positive gap multiplies to an empty list, i.e. no rests.
                gap_steps = int((onset - cursor) / step)
                tokens += [self.rest_token] * min(gap_steps, max_rests)

                if 0 <= pitch <= 127:
                    tokens.append(pitch + self.note_offset)

                cursor = onset + length

            tokens.append(self.end_token)
            return tokens[:2000]

        except Exception as e:
            print(f"Error processing {midi_file}: {e}")
            return []

    def sequence_to_midi(self, sequence, output_file):
        """Render a token sequence as a single-piano MIDI file at ``output_file``.

        Rest and pitch tokens each advance the clock by 0.1s, every note lasts
        0.5s at velocity 64, and decoding stops at the first end token. Any
        other token (padding, start, out-of-range ids) is skipped.
        """
        step = 0.1
        held = 0.5

        midi = pretty_midi.PrettyMIDI()
        piano = pretty_midi.Instrument(program=0)

        clock = 0
        for tok in sequence:
            if tok == self.end_token:
                break

            if tok == self.rest_token:
                clock += step
                continue

            if self.note_offset <= tok < self.vocab_size:
                pitch = tok - self.note_offset
                if 0 <= pitch <= 127:
                    piano.notes.append(
                        pretty_midi.Note(
                            velocity=64,
                            pitch=int(pitch),
                            start=clock,
                            end=clock + held,
                        )
                    )
                clock += step

        midi.instruments.append(piano)
        midi.write(output_file)

class MIDIDataset(Dataset):
    """Serves tokenized pieces as (input, target) pairs for next-token training."""

    def __init__(self, sequences, seq_length):
        self.sequences = sequences
        self.seq_length = seq_length

    def __len__(self):
        # One item per piece; each access draws a fresh random crop.
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return two LongTensors of length ``seq_length``: tokens and tokens shifted by one."""
        window = self.seq_length + 1  # one extra token so the target can be shifted
        tokens = self.sequences[idx]
        shortfall = window - len(tokens)

        if shortfall > 0:
            # Too short: right-pad with the padding id (0) up to the window size.
            tokens = tokens + [0] * shortfall
        else:
            # Long enough: take a random contiguous window.
            offset = random.randint(0, -shortfall)
            tokens = tokens[offset:offset + window]

        inputs = torch.tensor(tokens[:-1], dtype=torch.long)
        targets = torch.tensor(tokens[1:], dtype=torch.long)
        return inputs, targets

class MusicTransformer(nn.Module):
    """Causally masked transformer language model over music tokens.

    The network is an ``nn.TransformerEncoder`` stack used autoregressively:
    ``forward`` applies an upper-triangular attention mask so that position
    ``i`` can only attend to positions ``<= i``. Without that mask the model
    sees the very token it is asked to predict during training and then
    fails at generation time, where no future context exists.

    NOTE: checkpoints trained before the mask was introduced still load (the
    parameter names are unchanged) but should be retrained.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / output logits).
        d_model: Embedding and hidden width.
        nhead: Attention heads per layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Width of each layer's feed-forward sub-network.
        dropout: Dropout probability for the embeddings and inside each layer.
    """

    # Longest sequence the sinusoidal position table can address.
    MAX_POSITIONS = 5000

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Non-persistent buffer: it follows the module across .to(device) calls
        # but is NOT written to state_dict, so checkpoint keys stay unchanged.
        self.register_buffer(
            "pos_encoding",
            self._generate_pos_encoding(self.MAX_POSITIONS, d_model),
            persistent=False,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _generate_pos_encoding(self, max_len, d_model):
        """Build the fixed sinusoidal position table of shape (1, max_len, d_model)."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: LongTensor of token ids with shape (batch, seq_len).

        Returns:
            FloatTensor of shape (batch, seq_len, vocab_size); the logits at
            position ``i`` depend only on tokens ``0..i``.

        Raises:
            ValueError: If ``seq_len`` exceeds the position table length.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_encoding.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of {self.pos_encoding.size(1)} positions"
            )

        x = self.embedding(x) * np.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x)

        # -inf strictly above the diagonal blocks attention to future positions.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), dtype=x.dtype, device=x.device),
            diagonal=1,
        )
        x = self.transformer(x, mask=causal_mask)
        return self.output_layer(x)

class MusicGenerator:
    """Main class for training and generating music.

    Holds the tokenizer, the model (``None`` until created or loaded), the
    tokenized training sequences and the torch device, and exposes the
    pipeline steps: ``load_dataset`` -> ``create_model`` -> ``train`` ->
    ``continuous_generation``. Settings come from the module-level ``config``.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.tokenizer = MIDITokenizer()
        self.model = None
        self.sequences = []

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

    def load_dataset(self):
        """Tokenize every MIDI file under ``config.dataset_path``.

        Keeps sequences longer than 50 tokens in ``self.sequences`` and pickles
        the tokenizer to ``config.tokenizer_save_path``. Leaves
        ``self.sequences`` untouched when no MIDI file is found.
        """
        print("Loading MIDI dataset...")
        # Sorted so the index -> file mapping is the same on every run/OS.
        midi_files = sorted(
            glob.glob(os.path.join(config.dataset_path, "**", "*.mid*"), recursive=True)
        )

        if not midi_files:
            print(f"No MIDI files found in {config.dataset_path}")
            return

        print(f"Found {len(midi_files)} MIDI files")

        sequences = []
        for midi_file in tqdm(midi_files, desc="Processing MIDI files"):
            seq = self.tokenizer.midi_to_sequence(midi_file)
            if len(seq) > 50:  # Only keep sequences with reasonable length
                sequences.append(seq)

        self.sequences = sequences
        print(f"Loaded {len(sequences)} valid sequences")

        # Save tokenizer
        with open(config.tokenizer_save_path, 'wb') as f:
            pickle.dump(self.tokenizer, f)

    def create_model(self):
        """Build a fresh ``MusicTransformer`` from ``config`` and move it to the device."""
        self.model = MusicTransformer(
            vocab_size=self.tokenizer.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            num_layers=config.num_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout
        ).to(self.device)

        print(f"Model created with {sum(p.numel() for p in self.model.parameters())} parameters")

    def train(self):
        """Train on ``self.sequences`` for ``config.epochs`` epochs.

        Saves a checkpoint every ``config.save_interval`` epochs and the final
        model to ``config.model_save_path``. Returns early (with a message)
        when no sequences are loaded; creates the model if none exists yet.
        """
        if not self.sequences:
            print("No sequences loaded. Please run load_dataset() first.")
            return

        if self.model is None:
            # Previously this crashed with AttributeError on `None.parameters()`.
            self.create_model()

        dataset = MIDIDataset(self.sequences, config.sequence_length)
        dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

        optimizer = optim.Adam(self.model.parameters(), lr=config.learning_rate)
        # Padding positions contribute nothing to the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token)

        self.model.train()

        for epoch in range(config.epochs):
            total_loss = 0
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.epochs}")

            for input_seq, target_seq in progress_bar:
                input_seq, target_seq = input_seq.to(self.device), target_seq.to(self.device)

                optimizer.zero_grad()
                output = self.model(input_seq)
                # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
                loss = criterion(output.reshape(-1, output.size(-1)), target_seq.reshape(-1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}")

            # Save model periodically
            if (epoch + 1) % config.save_interval == 0:
                self.save_model(f"model_epoch_{epoch+1}.pth")

        # Save final model
        self.save_model(config.model_save_path)
        print("Training completed!")

    def save_model(self, path):
        """Write the model weights and a snapshot of the settings to ``path``.

        The settings are stored as a plain dict of built-in values rather than
        the ``Config`` instance: pickling a project class into the checkpoint
        makes it unreadable by ``torch.load`` in its default ``weights_only``
        mode (PyTorch >= 2.6).
        """
        config_snapshot = {
            name: getattr(config, name)
            for name in dir(config)
            if not name.startswith('_') and not callable(getattr(config, name))
        }
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': config_snapshot
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Load weights from ``path``, creating the model first if necessary.

        Returns:
            True on success, False when ``path`` does not exist.
        """
        if not os.path.exists(path):
            print(f"Model file {path} not found")
            return False

        checkpoint = torch.load(path, map_location=self.device)

        if self.model is None:
            self.create_model()

        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded from {path}")
        return True

    def generate_sequence(self, length, temperature=1.0):
        """Sample a token sequence autoregressively.

        Args:
            length: Maximum number of tokens to sample after the start token.
            temperature: Positive softmax temperature; lower is more conservative.

        Returns:
            List of token ids beginning with the start token. Sampling stops
            early once the end token is drawn (the end token is included).

        Raises:
            ValueError: If ``temperature`` is not positive.
            RuntimeError: If no model has been created or loaded.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        if self.model is None:
            raise RuntimeError("No model available. Run create_model() or load_model() first.")

        self.model.eval()

        # Start with start token
        sequence = [self.tokenizer.start_token]

        with torch.no_grad():
            for _ in range(length):
                # Prepare input (last sequence_length tokens)
                input_seq = sequence[-config.sequence_length:]
                input_tensor = torch.tensor([input_seq], dtype=torch.long, device=self.device)

                # Logits for the position following the last known token
                output = self.model(input_tensor)
                logits = output[0, -1, :] / temperature

                # Sample next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()

                sequence.append(next_token)

                # Stop if end token is generated
                if next_token == self.tokenizer.end_token:
                    break

        return sequence

    def continuous_generation(self):
        """Generate and save MIDI pieces in a loop until interrupted with Ctrl+C.

        Each piece is written to ``config.output_dir`` and followed by a
        ``config.generation_interval`` second pause. The final summary counts
        only pieces that were completely written.
        """
        print("Starting continuous music generation...")
        print(f"Generating new music every {config.generation_interval} seconds")
        print("Press Ctrl+C to stop")

        generation_count = 0  # pieces fully written so far

        try:
            while True:
                piece_number = generation_count + 1
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

                print(f"\nGenerating piece #{piece_number} at {timestamp}")

                # Generate sequence
                sequence = self.generate_sequence(
                    config.max_generate_length,
                    config.temperature
                )

                # Save as MIDI
                output_file = os.path.join(
                    config.output_dir,
                    f"generated_{timestamp}_{piece_number}.mid"
                )

                self.tokenizer.sequence_to_midi(sequence, output_file)
                print(f"Generated MIDI saved: {output_file}")

                # Count the piece only once it is on disk, so an interrupt
                # during generation does not inflate the summary below.
                generation_count = piece_number

                # Wait before next generation
                time.sleep(config.generation_interval)

        except KeyboardInterrupt:
            print(f"\nStopped. Generated {generation_count} pieces.")

In [2]:
"""Main function"""
generator = MusicGenerator()

# True once a usable model (loaded from disk or freshly trained) is available.
model_ready = False

# Check if model exists
if os.path.exists(config.model_save_path):
    print("Found existing model. Loading...")
    if generator.load_model(config.model_save_path):
        # Load tokenizer
        if os.path.exists(config.tokenizer_save_path):
            # SECURITY: pickle.load can execute arbitrary code. Only load a
            # tokenizer file that this notebook wrote itself.
            with open(config.tokenizer_save_path, 'rb') as f:
                generator.tokenizer = pickle.load(f)
        model_ready = True
    else:
        print("Failed to load model. Starting training...")
else:
    print("No existing model found. Starting training...")

# Training pipeline: runs ONLY when no model could be loaded. Previously this
# section also ran after generation was stopped with Ctrl+C, which re-tokenized
# the dataset, trained from scratch and overwrote the saved model.
if not model_ready:
    generator.load_dataset()
    if generator.sequences:
        generator.create_model()
        generator.train()
        model_ready = True
    else:
        print("No training data loaded. Please check your dataset path.")

# Start continuous generation (blocks until Ctrl+C)
if model_ready:
    generator.continuous_generation()

Using device: cpu
No existing model found. Starting training...
Loading MIDI dataset...
Found 1276 MIDI files


Processing MIDI files: 100%|██████████| 1276/1276 [05:44<00:00,  3.71it/s]


Loaded 1276 valid sequences
Model created with 19049604 parameters


Epoch 1/50:  42%|████▎     | 34/80 [09:03<12:14, 15.97s/it, loss=3.4456]


KeyboardInterrupt: 