In [None]:
"""
Complete MIDI to ABC Preprocessing Pipeline for Google Colab
CS-GY 6923 Scaling Laws Project

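Implements:
- Incremental storage inside processed_music_data/raw_abc.txt
- A checkpoint that tracks which MIDI files have already been processed
- Multiple runs that append new data without overwriting earlier output
- Final dataset creation (train/val/test split + vocab) as a separate step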
"""

import os
import json
import tempfile
from pathlib import Path
from collections import Counter
from typing import List, Dict, Optional
import numpy as np
from tqdm import tqdm
from datetime import datetime
import time
from music21 import converter


# =============================================================================
# PREPROCESSING PIPELINE
# =============================================================================

class MusicDataPreprocessor:
    """MIDI → ABC converter WITH INCREMENTAL STORAGE.

    Converted ABC strings are appended to ``<output_dir>/raw_abc.txt``, each followed by an
    ``<|endoftext|>`` separator line. ``<output_dir>/checkpoint.json`` records every MIDI
    file that has been attempted (whatever the outcome) together with cumulative
    statistics, so repeated runs only touch files they have not seen before.
    """

    # File extensions (compared case-insensitively) treated as MIDI input.
    _MIDI_SUFFIXES = ('.mid', '.midi')

    def __init__(self, midi_dir: str, output_dir: str, min_length: int, max_length: int):
        """
        Args:
            midi_dir: root directory searched recursively for MIDI files.
            output_dir: directory for raw_abc.txt and checkpoint.json (created if missing).
            min_length: ABC outputs with fewer characters than this are discarded.
            max_length: ABC outputs longer than this are truncated to this many characters.
        """
        self.midi_dir = Path(midi_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Raw storage file for incremental (append-only) writes
        self.raw_output_file = self.output_dir / "raw_abc.txt"

        self.min_length = min_length
        self.max_length = max_length

        # Counters are cumulative across runs: _load_checkpoint restores saved values.
        self.stats = {
            'total_files': 0,
            'successful_conversions': 0,
            'failed_conversions': 0,
            'too_short': 0,
            'truncated': 0,
            'empty_output': 0,
            'total_tokens': 0,
            'processing_time': 0
        }

        # Checkpoint file
        self.checkpoint_file = self.output_dir / "checkpoint.json"
        self.processed_files = self._load_checkpoint()

    @staticmethod
    def _flush_to_disk(fh) -> None:
        """Push buffered writes of ``fh`` to the OS and, where supported, to disk."""
        fh.flush()
        try:
            os.fsync(fh.fileno())
        except OSError:
            # Some filesystems (e.g. FUSE mounts) may reject fsync; flush() already ran.
            pass

    def _load_checkpoint(self) -> set:
        """Return the set of already-attempted file paths and restore saved stats.

        Raises ``json.JSONDecodeError`` if the checkpoint exists but is not valid JSON;
        failing loudly is preferable to silently re-processing (and duplicating) every file.
        """
        if not self.checkpoint_file.exists():
            return set()

        with open(self.checkpoint_file, 'r') as f:
            data = json.load(f)

        processed = set(data.get('processed_files', []))
        for key, value in data.get('stats', {}).items():
            # Only restore known numeric counters; ignore anything unexpected.
            if key in self.stats and isinstance(value, (int, float)):
                self.stats[key] = value

        if processed:
            print(f"Loaded checkpoint: {len(processed)} files already processed")
        return processed

    def _save_checkpoint(self):
        """Write processed files + stats to checkpoint.json (best effort).

        The data is written to a sibling ``.tmp`` file which then replaces the checkpoint,
        so (on filesystems with atomic rename) a crash mid-write cannot leave a truncated,
        unloadable checkpoint. Errors are reported and otherwise ignored so a long
        conversion run is not aborted by a checkpoint failure.
        """
        tmp_file = self.checkpoint_file.with_name(self.checkpoint_file.name + '.tmp')
        try:
            with open(tmp_file, 'w') as f:
                json.dump({
                    'processed_files': sorted(self.processed_files),
                    'timestamp': datetime.now().isoformat(),
                    'stats': self.stats
                }, f)
                self._flush_to_disk(f)
            os.replace(tmp_file, self.checkpoint_file)
        except Exception as e:
            print(f"Could not save checkpoint: {e}")

    def midi_to_abc(self, midi_path: Path) -> Optional[str]:
        """Parse ``midi_path`` with music21 and return its ABC export.

        Returns None if parsing/writing raises, or if the written text looks like an
        object repr or lacks any of the ABC headers X:, T:, M:, K:.
        """
        try:
            score = converter.parse(str(midi_path))

            # Create temporary ABC file (closed immediately; music21 writes to the path)
            with tempfile.NamedTemporaryFile(mode='w', suffix='.abc',
                                             delete=False, encoding='utf-8') as tmp:
                tmp_path = tmp.name

            try:
                # Write score to ABC format
                score.write('abc', fp=tmp_path)

                # Read back the ABC text
                with open(tmp_path, 'r', encoding='utf-8', errors='ignore') as f:
                    abc_text = f.read()

                # Reject output that is an object representation rather than ABC notation
                if abc_text.strip().startswith('<music21.') or 'object at 0x' in abc_text:
                    return None

                # Additional validation: require at least one ABC header
                if not any(header in abc_text for header in ['X:', 'T:', 'M:', 'K:']):
                    return None

                return abc_text

            finally:
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)

        except Exception as e:
            print(f"Error converting {midi_path}: {e}")
            return None

    def process_single_file(self, midi_path: Path) -> Optional[str]:
        """Convert one MIDI file and apply the length filters.

        Returns the (possibly truncated) ABC string, or None when the file was already
        processed, failed to convert, or produced empty/too-short output. Every attempted
        file is recorded in ``processed_files`` regardless of outcome, so a failing file is
        not retried (and re-counted) on every later run.
        """
        key = str(midi_path)
        if key in self.processed_files:
            return None

        self.stats['total_files'] += 1
        abc_string = self.midi_to_abc(midi_path)

        # Mark as attempted before filtering so rejected files do not consume the
        # max_files budget of every subsequent run.
        self.processed_files.add(key)

        if abc_string is None:
            self.stats['failed_conversions'] += 1
            return None
        if not abc_string.strip():
            self.stats['empty_output'] += 1
            return None
        if len(abc_string) < self.min_length:
            self.stats['too_short'] += 1
            return None
        if len(abc_string) > self.max_length:
            abc_string = abc_string[:self.max_length]
            self.stats['truncated'] += 1

        self.stats['successful_conversions'] += 1
        self.stats['total_tokens'] += len(abc_string)
        return abc_string

    def _find_midi_files(self) -> List[Path]:
        """Return all MIDI files under ``midi_dir`` (recursive, case-insensitive), sorted."""
        return sorted(
            p for p in self.midi_dir.rglob('*')
            if p.suffix.lower() in self._MIDI_SUFFIXES
        )

    def process_dataset(self, max_files: Optional[int], save_checkpoint_every: int):
        """Convert up to ``max_files`` not-yet-processed MIDI files, appending to raw_abc.txt.

        Args:
            max_files: cap on NEW files attempted in this run; None means no cap.
            save_checkpoint_every: save the checkpoint after this many files; values <= 0
                mean the checkpoint is only saved at the end of the run.
        """
        start = time.time()

        print("Finding MIDI files...")
        # Drop processed files BEFORE applying the cap; otherwise every later run would
        # re-select the same first ``max_files`` files and make no progress.
        midi_files = [f for f in self._find_midi_files() if str(f) not in self.processed_files]
        if max_files is not None:
            midi_files = midi_files[:max_files]

        print(f"Processing {len(midi_files)} new files")

        # OPEN RAW FILE IN APPEND MODE
        with open(self.raw_output_file, 'a', encoding='utf-8') as raw_f:
            for i, midi_file in enumerate(tqdm(midi_files), start=1):
                abc = self.process_single_file(midi_file)

                if abc is not None:
                    # Write incrementally
                    raw_f.write(abc + "\n<|endoftext|>\n")

                if save_checkpoint_every > 0 and i % save_checkpoint_every == 0:
                    # Persist the samples before the checkpoint claims their files are
                    # done; otherwise a crash could drop buffered ABC text for them.
                    self._flush_to_disk(raw_f)
                    self._save_checkpoint()

        # Record this run's time before the final save so it lands in the checkpoint.
        self.stats['processing_time'] += time.time() - start
        self._save_checkpoint()

        print("\nIncremental processing complete!")
        print(f"Raw ABC stored at: {self.raw_output_file}")


# =============================================================================
# CONFIG
# =============================================================================

# Input MIDI corpus and output location (Google Drive paths under /content/drive)
MIDI_DIR = "/content/drive/MyDrive/lmd_full/lmd_full"
OUTPUT_DIR = "/content/drive/MyDrive/processed_music_data"

# Conversion settings passed to the preprocessor by main()
MAX_FILES = 5_000               # process_dataset(max_files=...)
MIN_LENGTH = 10                 # ABC outputs shorter than this (in characters) are dropped
MAX_LENGTH = 4_096              # ABC outputs longer than this (in characters) are truncated
TOKENIZATION = "character"

# Split settings; main() in this cell does not read these
TRAIN_RATIO = 0.98
VAL_RATIO = 0.01
SEED = 42

# How many files between periodic checkpoint saves
SAVE_CHECKPOINT_EVERY = 1_000


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Run one incremental MIDI → ABC conversion pass using the module-level config."""
    print("\n=== INCREMENTAL MIDI → ABC PREPROCESSOR ===\n")

    init_kwargs = {
        "midi_dir": MIDI_DIR,
        "output_dir": OUTPUT_DIR,
        "min_length": MIN_LENGTH,
        "max_length": MAX_LENGTH,
    }
    run_kwargs = {
        "max_files": MAX_FILES,
        "save_checkpoint_every": SAVE_CHECKPOINT_EVERY,
    }

    preprocessor = MusicDataPreprocessor(**init_kwargs)
    preprocessor.process_dataset(**run_kwargs)

    print("\nRun finalize_dataset(...) when you finish all runs.\n")


# Run the conversion pass when this cell/script is executed directly (true in a notebook),
# but not when the module is imported.
if __name__ == "__main__":
    main()

In [None]:
# =============================================================================
# FINAL SPLIT + VOCAB CREATION
# =============================================================================

import os
import json
import tempfile
from pathlib import Path
from collections import Counter
from typing import List, Dict, Optional
import numpy as np
from tqdm import tqdm
from datetime import datetime
import time

# Separator written after every sample in raw_abc.txt and in the split files.
_EOT_MARKER = "<|endoftext|>"
# Special tokens placed at the start of the vocabulary (ids 0-3).
_SPECIAL_TOKENS = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']


def _load_raw_samples(raw_file: Path):
    """Read ``raw_file`` and return ``(samples, n_empty)``.

    ``samples`` are the stripped, non-empty chunks between ``<|endoftext|>`` markers;
    ``n_empty`` counts chunks that were dropped because they were blank.
    """
    with open(raw_file, 'r', encoding='utf-8') as f:
        chunks = f.read().split(_EOT_MARKER)
    # Each sample is written as "<abc>\n<|endoftext|>\n", so the text after the final
    # marker is just a newline, not a missing sample; do not count it as filtered.
    if chunks and not chunks[-1].strip():
        chunks.pop()
    stripped = [chunk.strip() for chunk in chunks]
    samples = [chunk for chunk in stripped if chunk]
    return samples, len(stripped) - len(samples)


def _write_split(path: Path, data: List[str]) -> None:
    """Write samples to ``path`` in the same separator format as raw_abc.txt."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(s + "\n" + _EOT_MARKER + "\n" for s in data)


def _length_summary(lengths: List[int]) -> Dict[str, Optional[float]]:
    """Return min/max/mean/median of ``lengths``; all values are None for an empty list."""
    if not lengths:
        return {"min": None, "max": None, "mean": None, "median": None}
    arr = np.asarray(lengths)
    return {
        "min": int(arr.min()),
        "max": int(arr.max()),
        "mean": float(arr.mean()),
        "median": float(np.median(arr)),
    }


def finalize_dataset(output_dir: str, train_ratio: float, val_ratio: float, tokenization: str,
                     seed: Optional[int] = None):
    """Split the accumulated raw ABC corpus into train/val/test and write vocab + stats.

    Reads ``<output_dir>/raw_abc.txt`` (samples separated by ``<|endoftext|>`` lines),
    shuffles whole samples, and writes ``train.txt``, ``val.txt`` and ``test.txt`` in the
    same separator format, a character vocabulary built from the training split
    (``vocab.json``) and summary statistics (``stats.json``).

    Args:
        output_dir: directory containing raw_abc.txt; all outputs are written here too.
        train_ratio: fraction of samples assigned to train.
        val_ratio: fraction of samples assigned to val; the remainder goes to test.
        tokenization: label recorded in stats.json (the vocab built here is per-character).
        seed: optional seed for a reproducible shuffle; None uses NumPy's global RNG.

    Raises:
        ValueError: if a ratio is negative, the ratios sum to more than 1, or the raw
            file contains no samples.
        FileNotFoundError: if raw_abc.txt does not exist.
    """
    output_dir = Path(output_dir)
    raw_file = output_dir / "raw_abc.txt"

    # Small tolerance so e.g. 0.7 + 0.3 is not rejected due to float rounding.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"Invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    print("Loading full raw dataset...")
    samples, empty_filtered = _load_raw_samples(raw_file)
    if not samples:
        raise ValueError(f"No samples found in {raw_file}")

    print(f"Total samples collected: {len(samples):,}")

    # Shuffle whole samples (never individual characters) before splitting.
    if seed is None:
        np.random.shuffle(samples)
    else:
        order = np.random.default_rng(seed).permutation(len(samples))
        samples = [samples[i] for i in order]

    n = len(samples)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)

    train = samples[:n_train]
    val = samples[n_train:n_train + n_val]
    test = samples[n_train + n_val:]

    for name, data in (("train", train), ("val", val), ("test", test)):
        _write_split(output_dir / f"{name}.txt", data)
        print(f"{name}: {len(data):,} samples")

    # Character vocabulary from the training split only; sorted so token ids do not
    # depend on the shuffle order.
    counter = Counter()
    for s in train:
        counter.update(s)

    vocab = _SPECIAL_TOKENS + sorted(counter)

    with open(output_dir / "vocab.json", 'w') as f:
        json.dump(vocab, f, indent=2)

    print(f"Vocab size: {len(vocab)}")

    print("\nCalculating statistics...")

    # Sequence lengths (in characters)
    all_lengths = [len(s) for s in samples]
    train_lengths = [len(s) for s in train]
    val_lengths = [len(s) for s in val]
    test_lengths = [len(s) for s in test]

    percentile_points = (25, 50, 75, 90, 95, 99)
    percentile_values = np.percentile(all_lengths, percentile_points)

    stats = {
        "dataset_info": {
            "total_samples": len(samples),
            "total_tokens": sum(all_lengths),
            "tokenization": tokenization,
            "vocab_size": len(vocab),
            "timestamp": datetime.now().isoformat()
        },
        "splits": {
            "train": {
                "num_samples": len(train),
                "num_tokens": sum(train_lengths),
                "proportion": train_ratio
            },
            "val": {
                "num_samples": len(val),
                "num_tokens": sum(val_lengths),
                "proportion": val_ratio
            },
            "test": {
                "num_samples": len(test),
                "num_tokens": sum(test_lengths),
                "proportion": 1.0 - train_ratio - val_ratio
            }
        },
        "sequence_length_statistics": {
            "overall": {
                **_length_summary(all_lengths),
                "std": float(np.std(all_lengths)),
                "percentiles": {
                    f"{p}th": float(v) for p, v in zip(percentile_points, percentile_values)
                }
            },
            # Per-split summaries hold None values when a split is empty.
            "train": _length_summary(train_lengths),
            "val": _length_summary(val_lengths),
            "test": _length_summary(test_lengths)
        },
        "vocabulary_statistics": {
            "total_unique_chars": len(vocab) - len(_SPECIAL_TOKENS),
            "special_tokens": list(_SPECIAL_TOKENS),
            "most_common_chars": [
                {"char": char, "count": int(count)}
                for char, count in counter.most_common(20)
            ]
        },
        "data_quality": {
            "empty_samples_filtered": empty_filtered,
            "samples_with_special_chars": sum(
                1 for s in samples if any(c in s for c in ['<', '>', '|'])
            ),
            "average_unique_chars_per_sample": float(
                np.mean([len(set(s)) for s in samples])
            )
        }
    }

    stats_file = output_dir / "stats.json"
    with open(stats_file, 'w') as f:
        json.dump(stats, f, indent=2)

    print(f"\n{'='*70}")
    print("STATISTICS SUMMARY")
    print(f"{'='*70}")
    print(f"Total samples:        {stats['dataset_info']['total_samples']:,}")
    print(f"Total tokens:         {stats['dataset_info']['total_tokens']:,}")
    print(f"Vocabulary size:      {stats['dataset_info']['vocab_size']:,}")
    print(f"\nTrain samples:        {stats['splits']['train']['num_samples']:,}")
    print(f"Train tokens:         {stats['splits']['train']['num_tokens']:,}")
    print(f"Val samples:          {stats['splits']['val']['num_samples']:,}")
    print(f"Val tokens:           {stats['splits']['val']['num_tokens']:,}")
    print(f"Test samples:         {stats['splits']['test']['num_samples']:,}")
    print(f"Test tokens:          {stats['splits']['test']['num_tokens']:,}")
    print(f"\nMean sequence length: {stats['sequence_length_statistics']['overall']['mean']:.0f}")
    print(f"Median sequence length: {stats['sequence_length_statistics']['overall']['median']:.0f}")
    print(f"{'='*70}")
    print(f"Saved statistics to {stats_file}")
    print("\nFINAL DATASET READY.")

OUTPUT_DIR = "/content/drive/MyDrive/processed_music_data"

# Build the final splits when this cell/script is executed directly (true in a notebook),
# but not when the module is imported.
if __name__ == "__main__":
    finalize_dataset(
        output_dir=OUTPUT_DIR,
        train_ratio=0.98,
        val_ratio=0.01,
        tokenization="character"
    )