In [8]:
#Task 1: Symbolic, unconditioned generation
#Using GROOVE MIDI DATASET to generate drum patterns using RNN

In [9]:
#The more imports the merrier?
import os
import glob
import random
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import math
import hashlib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile
import urllib.request
import zipfile
import re
import shutil

In [39]:
#Load in the Groove MIDI dataset
class GrooveDatasetLoader:
    def __init__(self, data_dir: str = "./groove_data"):
        self.data_dir = data_dir
        self.midi_dir = os.path.join(data_dir, "groove")
        self.metadata_file = None
        
    def download_dataset(self):
        self._find_metadata_file()
        return self
    
    def _find_metadata_file(self):
        possible_locations = [
            os.path.join(self.data_dir, "info.csv"),
            os.path.join(self.data_dir, "groove", "info.csv"),
            os.path.join(self.data_dir, "groove-v1.0.0", "info.csv"),
        ]
        
        for root, dirs, files in os.walk(self.data_dir):
            for file in files:
                if file == "info.csv":
                    possible_locations.append(os.path.join(root, file))
        
        for location in possible_locations:
            if os.path.exists(location):
                self.metadata_file = location
                return
    
    def _safe_extract_zip(self, zip_path: str, extract_to: str):      
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for member in zip_ref.infolist():
                filename = member.filename
                filename = re.sub(r'[<>:"|?*\r\n\x00]', '_', filename)
                
                target_path = os.path.join(extract_to, filename)
                
                if member.is_dir():
                    os.makedirs(target_path, exist_ok=True)
                else:
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)
                    
                    try:
                        with zip_ref.open(member) as source, open(target_path, "wb") as target:
                            shutil.copyfileobj(source, target)
                    except (OSError, IOError) as e:
                        print(f"Skipping file due to error: {filename} - {e}")
                        continue
    
    def load_metadata(self) -> pd.DataFrame:
        if self.metadata_file is None or not os.path.exists(self.metadata_file):
            return self._create_basic_metadata()
        
        df = pd.read_csv(self.metadata_file)
        return df
    
    def _create_basic_metadata(self) -> pd.DataFrame:
        midi_files = []
        
        # Find all MIDI files
        for root, dirs, files in os.walk(self.data_dir):
            for file in files:
                if file.endswith('.mid'):
                    full_path = os.path.join(root, file)
                    relative_path = os.path.relpath(full_path, self.data_dir)
                    midi_files.append({
                        'midi_filename': relative_path,
                        'split': self._infer_split(file),
                        'style': self._infer_style(file),
                        'drummer': self._infer_drummer(relative_path),
                        'beat_type': 'beat'  # Default
                    })
        df = pd.DataFrame(midi_files)
        return df
    
    def _infer_split(self, filename: str) -> str:
        hash_val = int(hashlib.md5(filename.encode()).hexdigest(), 16)
        if hash_val % 10 < 8:
            return 'train'
        elif hash_val % 10 == 8:
            return 'validation'
        else:
            return 'test'
    
    def _infer_style(self, filename: str) -> str:
        filename_lower = filename.lower()
        styles = ['funk', 'rock', 'jazz', 'latin', 'afrobeat', 'blues', 'soul', 'hiphop']
        for style in styles:
            if style in filename_lower:
                return style
        return 'unknown'
    
    def _infer_drummer(self, relative_path: str) -> str:
        parts = relative_path.split(os.sep)
        for part in parts:
            if part.startswith('drummer'):
                return part
        return 'unknown'
    
    def get_file_lists(self) -> Tuple[List[str], List[str], List[str]]:
        df = self.load_metadata()
        if df is None or len(df) == 0:
            return [], [], []
        
        train_files = []
        val_files = []
        test_files = []
        
        for _, row in df.iterrows():
            midi_path = os.path.join(self.data_dir, row['midi_filename'])
            if os.path.exists(midi_path):
                if row['split'] == 'train':
                    train_files.append(midi_path)
                elif row['split'] == 'validation':
                    val_files.append(midi_path)
                elif row['split'] == 'test':
                    test_files.append(midi_path)
                
        # If no files found with metadata approach, scan directory directly
        if len(train_files) + len(val_files) + len(test_files) == 0:
            all_midi_files = []
            for root, dirs, files in os.walk(self.data_dir):
                for file in files:
                    if file.endswith('.mid'):
                        all_midi_files.append(os.path.join(root, file))
            
            # Split files manually
            random.shuffle(all_midi_files)
            n_files = len(all_midi_files)
            train_files = all_midi_files[:int(0.8 * n_files)]
            val_files = all_midi_files[int(0.8 * n_files):int(0.9 * n_files)]
            test_files = all_midi_files[int(0.9 * n_files):]
            
        return train_files, val_files, test_files
    
    def analyze_dataset(self):
        df = self.load_metadata()
        if df is None or len(df) == 0:
            # Basic analysis without metadata
            all_midi_files = []
            for root, dirs, files in os.walk(self.data_dir):
                for file in files:
                    if file.endswith('.mid'):
                        all_midi_files.append(os.path.join(root, file))
            return
        
        print("\n---GROOVE DATASET ANALYSIS---")
        print(f"Total tracks: {len(df)}")
        
        if 'duration' in df.columns:
            print(f"Total duration: {df['duration'].sum()/60:.1f} minutes")
            print(f"Average duration: {df['duration'].mean():.1f} seconds")
        if 'beat_type' in df.columns:
            print(f"Beat types: {df['beat_type'].value_counts().to_dict()}")
        if 'style' in df.columns:
            print(f"Top styles:")
            for style, count in df['style'].value_counts().head(10).items():
                print(f"  {style}: {count}")
            top_styles = []
            for style in df['style'].value_counts().head(4).keys():
                top_styles.append(style)
        if 'drummer' in df.columns:
            print(f"Number of drummers: {df['drummer'].nunique()}")


In [40]:
# DRUM-SPECIFIC TOKENIZATION

class DrumTokenizer:
    """
    Specialized tokenizer for drum patterns using the Groove dataset drum mapping
    """
    
    def __init__(self):
        # Groove dataset drum mapping (from the documentation you provided)
        self.drum_mapping = {
            36: "Bass",           # Kick
            38: "Snare",          # Snare (Head)
            40: "Snare",          # Snare (Rim) -> map to same as snare
            37: "Snare",          # Snare X-Stick -> map to same as snare
            42: "HH_Closed",      # HH Closed (Bow)
            22: "HH_Closed",      # HH Closed (Edge) -> map to same
            44: "HH_Closed",      # HH Pedal -> map to same
            46: "HH_Open",        # HH Open (Bow)
            26: "HH_Open",        # HH Open (Edge) -> map to same
            43: "Tom_High",       # Tom 3 (Head) - High Floor Tom
            45: "Tom_Mid",        # Tom 2 - Low Tom
            48: "Tom_Low",        # Tom 1 - Hi-Mid Tom
            49: "Crash",          # Crash 1 (Bow)
            55: "Crash",          # Crash 1 (Edge) -> map to same
            57: "Crash",          # Crash 2 (Bow) -> map to same
            52: "Crash",          # Crash 2 (Edge) -> map to same
            51: "Ride",           # Ride (Bow)
            59: "Ride",           # Ride (Edge) -> map to same
            53: "Ride",           # Ride (Bell) -> map to same
        }
        
        # Create vocabulary
        self.vocab = {
            'PAD': 0,
            'SOS': 1,
            'EOS': 2,
            'REST': 3,  # No drum hit
        }
        
        # Add drum types
        unique_drums = set(self.drum_mapping.values())
        for drum in sorted(unique_drums):
            self.vocab[drum] = len(self.vocab)
        
        # Add timing tokens (16th note resolution)
        for i in range(16):  # 16 positions per bar in 4/4
            self.vocab[f'POS_{i}'] = len(self.vocab)
        
        # Add velocity levels
        for v in [32, 48, 64, 80, 96, 112, 127]:  # 7 velocity levels
            self.vocab[f'VEL_{v}'] = len(self.vocab)
        
        self.vocab_size = len(self.vocab)
        self.id_to_token = {v: k for k, v in self.vocab.items()}
        
        print(f"Drum types: {sorted(unique_drums)}")
    
    def midi_to_tokens(self, midi_path: str, bars_per_sequence: int = 2) -> List[int]:
        try:
            score = Score(midi_path)
            tokens = [self.vocab['SOS']]
            
            # Get all drum notes
            notes = []
            for track in score.tracks:
                for note in track.notes:
                    if note.pitch in self.drum_mapping:
                        notes.append({
                            'time': note.time,
                            'pitch': note.pitch,
                            'velocity': note.velocity
                        })
            
            if not notes:
                return [self.vocab['SOS'], self.vocab['EOS']]
            
            notes.sort(key=lambda x: x['time'])
            
            # Convert to 16th note grid
            ticks_per_beat = score.ticks_per_quarter
            ticks_per_16th = ticks_per_beat // 4
            
            # Process in chunks of bars
            max_time = notes[-1]['time']
            bars_in_song = int(max_time / (ticks_per_beat * 4)) + 1
            num_sequences = min(bars_in_song // bars_per_sequence, 8)
            
            for seq in range(num_sequences):
                start_time = seq * bars_per_sequence * 4 * ticks_per_beat
                end_time = (seq + 1) * bars_per_sequence * 4 * ticks_per_beat
                
                # Create 16th note grid for this sequence
                grid_size = bars_per_sequence * 16
                
                for pos in range(grid_size):
                    tokens.append(self.vocab[f'POS_{pos % 16}'])
                    
                    # Find notes at this position
                    pos_time = start_time + pos * ticks_per_16th
                    pos_notes = [n for n in notes 
                               if pos_time <= n['time'] < pos_time + ticks_per_16th]
                    
                    if not pos_notes:
                        tokens.append(self.vocab['REST'])
                    else:
                        for note in pos_notes:
                            drum_type = self.drum_mapping[note['pitch']]
                            tokens.append(self.vocab[drum_type])
                            
                            vel = min([32, 48, 64, 80, 96, 112, 127], 
                                    key=lambda x: abs(x - note['velocity']))
                            tokens.append(self.vocab[f'VEL_{vel}'])
            
            tokens.append(self.vocab['EOS'])
            return tokens
            
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            return [self.vocab['SOS'], self.vocab['EOS']]
    
    def tokens_to_midi(self, tokens: List[int], output_path: str, bpm: int = 120):
        midi = MIDIFile(1)
        track = 0
        midi.addTempo(track, 0, bpm)
        midi.addTimeSignature(track, 0, 4, 2, 24)  # 4/4 time signature
        
        # Reverse drum mapping
        unique_mapping = {}
        for pitch, drum in self.drum_mapping.items():
            if drum not in unique_mapping:
                unique_mapping[drum] = pitch
        
        current_beat = 0.0
        current_velocity = 64
        note_duration = 0.25
        current_bar = 0
        beats_per_bar = 4
        
        i = 0
        while i < len(tokens):
            token_id = tokens[i]
            token = self.id_to_token.get(token_id, 'UNK')
            
            if token.startswith('POS_'):
                position = int(token.split('_')[1])
                # Calculate absolute beat position
                beat_in_bar = position / 4.0  # Convert 16th note position to beat
                current_beat = current_bar * beats_per_bar + beat_in_bar
                
            elif token.startswith('VEL_'):
                current_velocity = int(token.split('_')[1])
                
            elif token in unique_mapping:
                # Add drum hit
                pitch = unique_mapping[token]
                midi.addNote(track, 9, pitch, current_beat, note_duration, current_velocity)
                
            elif token == 'REST':
                pass  # No note to add
                
            elif token in ['SOS', 'EOS', 'PAD']:
                if token == 'EOS':
                    break
                    
            # Advance to next bar when we complete 16 positions
            if token.startswith('POS_15'):  # Last position in bar
                current_bar += 1
            
            i += 1
        
        # Add final bar marker and ensure proper song structure
        total_bars = current_bar + 1
        print(f"  Generated composition: {total_bars} bars at {bpm} BPM")
        
        with open(output_path, "wb") as f:
            midi.writeFile(f)

In [41]:
# DRUM-SPECIFIC MODELS
class DrumRNN(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int = 128, 
                 hidden_dim: int = 256, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, 
                           batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        lstm_out, hidden = self.lstm(embedded, hidden)
        output = self.dropout(lstm_out)
        output = self.fc(output)
        return output, hidden

In [42]:
# MAIN TRAINING AND GENERATION FRAMEWORK
class GrooveMusicGenerator:
    
    def __init__(self, model_type: str = 'rnn', top_styles: List[str] = None):
        self.model_type = model_type
        self.device = torch.device('cpu')
        self.tokenizer = DrumTokenizer()
        self.model = None
        self.top_styles = top_styles
        print(f"Using device: {self.device}")
        if self.top_styles:
            print(f"Filtering training data to only include styles: {self.top_styles}")

    def _infer_style_from_filename(self, file_path: str) -> str:
        """Infer style from filename as fallback"""
        filename_lower = os.path.basename(file_path).lower()
        
        # Common style keywords that might appear in filenames
        style_keywords = {
            'funk': 'funk',
            'rock': 'rock', 
            'jazz': 'jazz',
            'latin': 'latin',
            'afrobeat': 'afrobeat',
            'blues': 'blues',
            'soul': 'soul',
            'hiphop': 'hiphop',
            'hip_hop': 'hiphop',
            'pop': 'pop',
            'reggae': 'reggae',
            'country': 'country',
            'gospel': 'gospel'
        }
        
        for keyword, style in style_keywords.items():
            if keyword in filename_lower:
                return style
        
        return 'unknown'
    
    def prepare_data(self, train_files: List[str], val_files: List[str]):
        """Prepare drum data"""
        print("Tokenizing drum MIDI files...")
        
        self.train_sequences = []
        self.val_sequences = []

        loader = GrooveDatasetLoader()
        df = loader.load_metadata()

        if df is not None and len(df) > 0:
            style_mapping = {}
            for _, row in df.iterrows():
                full_path = os.path.join(loader.data_dir, row['midi_filename'])
                style_mapping[full_path] = row['style']
        else:
            style_mapping = {}

        filtered_train_files = []
        skipped_count = 0
        
        # Process training files
        for file_path in train_files[:100]:
            if self.top_styles:  # Only filter if top_styles is provided
                if file_path in style_mapping:
                    file_style = style_mapping[file_path]
                else:
                    # Fallback: infer style from filename
                    file_style = self._infer_style_from_filename(file_path)
                
                # Check if this style is in our allowed list
                style_match = any(allowed_style.lower() in file_style.lower() 
                                for allowed_style in self.top_styles)
                
                if not style_match:
                    skipped_count += 1
                    continue
            
            filtered_train_files.append(file_path)

        if skipped_count > 0:
            print(f"Skipped {skipped_count} files due to style filtering")
        
        # Process validation files
        for i, file_path in enumerate(filtered_train_files):
            tokens = self.tokenizer.midi_to_tokens(file_path)
            if len(tokens) > 10:
                self.train_sequences.append(tokens)
        
        # Process validation files (you can also filter these if desired)
        for file_path in val_files[:20]:  # Smaller validation set
            tokens = self.tokenizer.midi_to_tokens(file_path)
            if len(tokens) > 10:
                self.val_sequences.append(tokens)
        
        print(f"Prepared {len(self.train_sequences)} training and {len(self.val_sequences)} validation sequences")
        
        # Print style summary
        if self.top_styles:
            print(f"Training data filtered to include only styles: {self.top_styles}")
    
    def create_model(self):
        if self.model_type == 'rnn':
            self.model = DrumRNN(
                vocab_size=self.tokenizer.vocab_size,
                embedding_dim=128,
                hidden_dim=256,
                num_layers=2
            ).to(self.device)
            self.criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.vocab['PAD'])
            self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        
        print(f"Created {self.model_type} model")
    
    def train_model(self, epochs: int = 10):
        if self.model_type == 'rnn':
            print(f"Training RNN for {epochs} epochs...")
            
            # Create data loader
            class DrumDataset(Dataset):
                def __init__(self, sequences, seq_len=64):
                    self.data = []
                    for seq in sequences:
                        for i in range(0, len(seq) - seq_len, seq_len // 2):
                            self.data.append(seq[i:i + seq_len + 1])
                
                def __len__(self):
                    return len(self.data)
                
                def __getitem__(self, idx):
                    seq = self.data[idx]
                    return torch.tensor(seq[:-1]), torch.tensor(seq[1:])
            
            train_dataset = DrumDataset(self.train_sequences)
            train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
            
            self.model.train()
            for epoch in range(epochs):
                total_loss = 0
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(self.device), target.to(self.device)
                    
                    self.optimizer.zero_grad()
                    output, _ = self.model(data)
                    loss = self.criterion(output.reshape(-1, self.tokenizer.vocab_size), 
                                        target.reshape(-1))
                    loss.backward()
                    self.optimizer.step()
                    
                    total_loss += loss.item()
                
                avg_loss = total_loss / len(train_loader)
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    def generate_drum_patterns(self, num_patterns: int = 1, pattern_length: int = 300) -> List[str]:
        print(f"Generating {num_patterns} extended drum pattern(s) of {pattern_length} beats...")
        
        generated_files = []
        
        for i in range(num_patterns):
            print(f"Generating pattern {i+1}/{num_patterns} with {self.model_type} model...")

            self.model.eval()
            tokens = [self.tokenizer.vocab['SOS']]
            
            with torch.no_grad():
                for step in range(pattern_length * 4):
                    if step % 100 == 0:
                        print(f"    Generated {step}/{pattern_length * 4} tokens...")
                    
                    input_seq = torch.tensor([tokens[-min(len(tokens), 100):]], dtype=torch.long).to(self.device)
                    output, _ = self.model(input_seq)
                    logits = output[0, -1, :]
                    
                    temperature = 0.8
                    logits = logits / temperature
                    
                    top_k = 20
                    top_k_logits, top_k_indices = torch.topk(logits, top_k)
                    probs = F.softmax(top_k_logits, dim=-1)
                    
                    try:
                        next_token_idx = torch.multinomial(probs, 1).item()
                        next_token = top_k_indices[next_token_idx].item()
                    except:
                        next_token = torch.argmax(logits).item()
                    
                    if next_token == self.tokenizer.vocab['EOS']:
                        print(f"Reached EOS token at step {step}")
                        break
                    
                    tokens.append(next_token)
            
            # Convert to MIDI with extended duration
            filename = f"task1_{self.model_type}_300beats.mid"
            try:
                # Set the tempo
                self.tokenizer.tokens_to_midi(tokens, filename, bpm=120)  
                generated_files.append(filename)
                print(f"Generated: {filename} ({len(tokens)} tokens)")
                
                # Calculate approximate duration
                duration_minutes = (pattern_length / 100) * (60 / 100)
                print(f"  Estimated duration: ~{duration_minutes:.1f} minutes")
                
            except Exception as e:
                print(f"Error generating pattern {i+1}: {e}")
        
        return generated_files

In [43]:
# MAIN EXECUTION

def main():
    print("=== Task 1: Groove MIDI Dataset - Drum Pattern Generation ===\n")
    
    loader = GrooveDatasetLoader()
    loader.download_dataset()
    loader.analyze_dataset()
    top_styles = ['rock', 'hiphop', 'funk', 'neworleans/funk'] 
    
    train_files, val_files, test_files = loader.get_file_lists()
    
    if not train_files:
        print("No training files found. Please check dataset download.")
        return
    
    models_to_test = [
        {'type': 'rnn', 'name': 'Drum RNN'}
    ]
    
    results = {}
    
    for model_config in models_to_test:
        print(f"\n{'-'*40}")
        print(f"Training: {model_config['name']}")
        print(f"{'-'*40}")
        
        try:
            # Create generator
            generator = GrooveMusicGenerator(model_type='rnn', top_styles=top_styles)
            generator.prepare_data(train_files, val_files)
            
            # Create and train model  
            generator.create_model()
            generator.train_model(epochs=20)

            # Generate extended patterns (300 beats each)
            generated_files = generator.generate_drum_patterns(num_patterns=1, pattern_length=300)
            
            results[model_config['name']] = {
                'generated_files': generated_files,
                'model_type': model_config['type']
            }
            
        except Exception as e:
            print(f"Error with {model_config['name']}: {e}")
            continue

if __name__ == "__main__":
    main()

=== Task 1: Groove MIDI Dataset - Drum Pattern Generation ===


---GROOVE DATASET ANALYSIS---
Total tracks: 1150
Total duration: 814.9 minutes
Average duration: 42.5 seconds
Beat types: {'fill': 647, 'beat': 503}
Top styles:
  rock: 281
  hiphop: 91
  funk: 77
  punk: 58
  neworleans/funk: 48
  jazz: 46
  rock/halftime: 37
  latin/brazilian-baiao: 32
  soul: 31
  funk/purdieshuffle: 30
Number of drummers: 10

----------------------------------------
Training: Drum RNN
----------------------------------------
Drum types: ['Bass', 'Crash', 'HH_Closed', 'HH_Open', 'Ride', 'Snare', 'Tom_High', 'Tom_Low', 'Tom_Mid']
Using device: cpu
Filtering training data to only include styles: ['rock', 'hiphop', 'funk', 'neworleans/funk']
Tokenizing drum MIDI files...
Skipped 34 files due to style filtering
Prepared 42 training and 15 validation sequences
Training data filtered to include only styles: ['rock', 'hiphop', 'funk', 'neworleans/funk']
Created rnn model
Training RNN for 20 epochs...
Epoch 1/2