# Baseline Model Testing: SkyTNT MIDI Model

This notebook tests the SkyTNT MIDI model on M1 Pro:
1. Load pre-trained model from HuggingFace
2. Test inference and generation
3. Benchmark performance (speed, memory usage)
4. Generate sample MIDI files
5. Assess output quality

**Hardware**: Apple M1 Pro with MPS (Metal Performance Shaders) support

In [None]:
import sys
from pathlib import Path

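# Make modules in the parent of the working directory (typically the project root when run from notebooks/) importable
sys.path.append(str(Path.cwd().parent))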

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from datetime import datetime

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
print(f"MPS built: {torch.backends.mps.is_built()}")

## 1. Device Setup

Configure device for M1 Pro GPU (MPS).

In [None]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) device for GPU acceleration")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

print(f"\nDevice: {device}")

## 2. Load SkyTNT MIDI Model

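Load the pre-trained model from HuggingFace. The cell below assumes the checkpoint can be loaded through the standard `transformers` `Auto*` classes. If it cannot, the `except` branch reports the error instead of crashing the notebook.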

In [None]:
model_name = "skytnt/midi-model"

print(f"Loading model: {model_name}\n")

try:
    start_time = time.time()
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32
    )
    
    model = model.to(device)
    model.eval()
    
    load_time = time.time() - start_time
    
    print(f"✓ Model loaded successfully in {load_time:.2f}s")
    print("\nModel info:")
    print(f"  Type: {type(model).__name__}")
    print(f"  Device: {next(model.parameters()).device}")
    print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    
except Exception as e:
    print(f"✗ Error loading model: {e}")
    print("\nNote: the checkpoint may not be in a format AutoModelForCausalLM/AutoTokenizer can load")
    print("(it may need SkyTNT's own model and tokenizer code), or it may be unavailable.")
    print("Alternative: use a different model or train from scratch.")

## 3. Memory Profiling

Check memory usage with the model loaded.

In [None]:
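if 'model' in locals():
    # Weights plus registered buffers; activations, the KV cache and optimizer state are not included
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    model_memory_mb = (param_bytes + buffer_bytes) / 1024**2
    print(f"Model memory usage (weights + buffers): {model_memory_mb:.2f} MB")
    
    if device.type == "mps":
        if hasattr(torch, "mps") and hasattr(torch.mps, "current_allocated_memory"):
            print(f"MPS tensor memory allocated: {torch.mps.current_allocated_memory() / 1024**2:.2f} MB")
            print(f"MPS driver memory allocated: {torch.mps.driver_allocated_memory() / 1024**2:.2f} MB")
        print("\nNote: unified memory is shared with macOS and other processes;")
        print("rough estimate of what is available for training: ~12-13 GB on a 16 GB M1 Pro")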
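Whether that rough 12-13 GB budget is enough depends on what training keeps resident. A common rule of thumb for full fine-tuning with Adam/AdamW in fp32 is 16 bytes per parameter (4 for weights, 4 for gradients, 8 for Adam's two moment buffers), before activations. A sketch of that arithmetic; the function names are illustrative, not part of any library:

```python
def estimate_full_finetune_bytes(num_params: int, bytes_per_value: int = 4) -> dict:
    """Rule-of-thumb memory for full fine-tuning with Adam/AdamW.

    Counts weights, gradients and Adam's two moment buffers at `bytes_per_value` each.
    Activations, the data pipeline and allocator overhead are NOT included.
    """
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    adam_states = 2 * num_params * bytes_per_value  # first and second moment estimates
    return {
        "weights": weights,
        "grads": grads,
        "adam_states": adam_states,
        "total": weights + grads + adam_states,
    }


def fits_in_budget(total_bytes: int, budget_gb: float) -> bool:
    """Compare an estimate against a memory budget given in GiB."""
    return total_bytes <= budget_gb * 1024**3


# Usage in this notebook (not executed here):
# est = estimate_full_finetune_bytes(sum(p.numel() for p in model.parameters()))
# print(f"~{est['total'] / 1024**3:.2f} GiB before activations; fits in 12 GiB: {fits_in_budget(est['total'], 12)}")
```

If the estimate leaves little headroom for activations, that is the usual argument for the LoRA and gradient-checkpointing options listed under Next Steps.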

## 4. Test Generation

Generate a sample MIDI sequence and benchmark performance.

In [None]:
if 'model' in locals() and 'tokenizer' in locals():
    print("Testing generation...\n")
    
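    prompt = ""  # empty prompt -> unconditional generation from the BOS token
    max_length = 512  # total sequence length, including the prompt tokens
    
    try:
        if prompt:
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
        else:
            if tokenizer.bos_token_id is None:
                raise ValueError("Tokenizer defines no BOS token; set a non-empty prompt instead")
            input_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)
            inputs = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
        prompt_len = inputs["input_ids"].shape[1]
        
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                top_p=0.95
            )
        # GPU backends run asynchronously; wait for queued work before stopping the clock
        if device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "cuda":
            torch.cuda.synchronize()
        generation_time = time.time() - start_time
        
        generated_tokens = outputs[0].cpu().numpy()
        num_tokens = len(generated_tokens) - prompt_len  # newly generated tokens only
        tokens_per_second = num_tokens / generation_time
        
        print("✓ Generation successful!")
        print(f"  Tokens generated: {num_tokens} (plus {prompt_len} prompt token(s))")
        print(f"  Time: {generation_time:.2f}s (single run, includes first-call warm-up)")
        print(f"  Speed: {tokens_per_second:.2f} tokens/second")
        
        print("\nGenerated token sequence (first 50):")
        print(generated_tokens[:50])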
        
    except Exception as e:
        print(f"✗ Generation failed: {e}")
        import traceback
        traceback.print_exc()
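The `generate` call above samples with `temperature=1.0`, `top_k=50` and `top_p=0.95`. To make those settings concrete, here is a plain-Python sketch of one sampling step: temperature scaling, then top-k filtering, then nucleus (top-p) filtering over a single step's logits. It illustrates the technique only; it is not the `transformers` implementation, and `sample_top_k_top_p` is a hypothetical name:

```python
import math
import random
from typing import List, Optional


def sample_top_k_top_p(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample one token index: temperature, then top-k, then nucleus (top-p) filtering."""
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    # Top-k: keep the k highest-scoring token indices (all of them if top_k <= 0)
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    kept = order[:top_k] if top_k > 0 else order
    # Softmax over the surviving tokens (subtract the max for numerical stability)
    m = scaled[kept[0]]
    weights = [math.exp(scaled[i] - m) for i in kept]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Top-p: smallest high-probability prefix whose cumulative probability reaches top_p
    cutoff, cumulative = len(kept), 0.0
    for j, p in enumerate(probs):
        cumulative += p
        if cumulative >= top_p:
            cutoff = j + 1
            break
    kept, probs = kept[:cutoff], probs[:cutoff]
    # random.choices normalizes the weights itself, so the truncated probs need no renormalization
    return rng.choices(kept, weights=probs, k=1)[0]
```

Lowering `top_p` or `top_k` narrows the candidate set and makes output more conservative; raising `temperature` flattens the distribution before either filter is applied.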

## 5. Convert to MIDI (if applicable)

Converting the generated token IDs into a playable MIDI file requires decoding them with the model's own tokenization scheme.

In [None]:
if 'generated_tokens' in locals():
    print("Token-to-MIDI conversion depends on model's tokenization scheme.")
    print("This step may require model-specific decoding logic.")
    print("\nNext steps:")
    print("1. Study SkyTNT's tokenization format")
    print("2. Implement custom decoder")
    print("3. Save as .mid file for playback")

## 6. Benchmark Summary

Summarize M1 Pro performance.

In [None]:
print("=" * 60)
print("M1 Pro Benchmark Summary")
print("=" * 60)

if 'load_time' in locals():
    print(f"\n✓ Model loading: {load_time:.2f}s")

if 'model_memory_mb' in locals():
    print(f"✓ Model memory: {model_memory_mb:.2f} MB")

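if 'tokens_per_second' in locals():
    print(f"✓ Generation speed: {tokens_per_second:.2f} tokens/s")
    # Naive linear extrapolation: ignores warm-up and the growth of per-token cost with context length
    print(f"✓ Est. time for 2000 tokens: {2000/tokens_per_second:.1f}s (linear extrapolation)")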

print("\n" + "=" * 60)
print("Assessment:")
print("=" * 60)

if 'model' in locals():
    print("\n✓ M1 Pro can load the model")
    if 'tokens_per_second' in locals():
        print("✓ Generation runs end to end")
    if next(model.parameters()).device.type == "mps":
        print("✓ Model is running on the MPS device")
    print("\nNext steps:")
    print("  1. Test fine-tuning on a small dataset")
    print("  2. Implement MIDI tokenization/detokenization")
    print("  3. Evaluate generation quality")
    print("  4. Consider LoRA for efficient fine-tuning")
else:
    print("\n✗ Model loading failed")
    print("\nAlternative approaches:")
    print("  1. Try a different model from HuggingFace")
    print("  2. Adapt GPT-2 for music generation")
    print("  3. Train a custom transformer from scratch")
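The summary is based on a single timed `generate` call, so it mixes one-time warm-up costs with steady-state speed, and a sampled sequence can stop early at an end-of-sequence token. A minimal sketch of a more robust measurement, assuming a hypothetical `generate_fn` callable that runs one generation and returns the number of new tokens; `benchmark_throughput` is likewise an illustrative name:

```python
import statistics
import time
from typing import Callable, Dict, Optional


def benchmark_throughput(
    generate_fn: Callable[[], int],
    warmup: int = 1,
    trials: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> Dict[str, float]:
    """Tokens/second over `trials` timed runs, after `warmup` untimed runs.

    `generate_fn` performs one generation and returns how many NEW tokens it produced.
    `sync` (e.g. torch.mps.synchronize) is called around each timed run so queued GPU work is counted.
    """
    if trials < 1:
        raise ValueError("trials must be >= 1")
    for _ in range(warmup):
        generate_fn()  # absorb one-time costs such as kernel compilation and cache allocation
    rates = []
    for _ in range(trials):
        if sync is not None:
            sync()  # make sure earlier work is not charged to this run
        start = time.perf_counter()
        new_tokens = generate_fn()
        if sync is not None:
            sync()  # wait for this run's GPU work to finish
        rates.append(new_tokens / (time.perf_counter() - start))
    return {"median_tps": statistics.median(rates), "min_tps": min(rates), "max_tps": max(rates)}


# Usage in this notebook (not executed here), reusing `model`, `inputs`, `max_length` and `device`:
# def run_once():
#     with torch.no_grad():
#         out = model.generate(**inputs, max_length=max_length, do_sample=True, top_k=50, top_p=0.95)
#     return out.shape[1] - inputs["input_ids"].shape[1]
# stats = benchmark_throughput(run_once, sync=torch.mps.synchronize if device.type == "mps" else None)
```

Reporting the median with the min/max spread makes it easier to tell a real speed-up (for example from fp16) from run-to-run noise.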

## 7. Next Steps

Based on the results:

**If successful**:
- Implement MIDI tokenization/detokenization
- Fine-tune on the Bach dataset
- Evaluate generated music quality
- Optimize for M1 Pro (gradient accumulation, LoRA)
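Gradient accumulation trades time for memory: run several small micro-batches, scale each loss by 1/k, and take one optimizer step, so peak activation memory is that of a single micro-batch. A toy sketch in plain Python with a one-parameter least-squares model (hypothetical functions, not PyTorch code) showing why the scaled sum equals the full-batch gradient when micro-batches are equal-sized:

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs for the toy model y ~ w * x


def mse_grad(w: float, batch: Batch) -> float:
    """d/dw of mean((w*x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: Batch, accum_steps: int) -> float:
    """Split the batch into equal micro-batches and sum their gradients scaled by 1/accum_steps."""
    if len(batch) % accum_steps:
        raise ValueError("batch size must be divisible by accum_steps for an exact match")
    size = len(batch) // accum_steps
    grad = 0.0
    for i in range(accum_steps):
        micro = batch[i * size:(i + 1) * size]
        # Equivalent to calling (loss / accum_steps).backward() on each micro-batch in PyTorch
        grad += mse_grad(w, micro) / accum_steps
    return grad


# batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0)]
# accumulated_grad(0.5, batch, 2) matches mse_grad(0.5, batch) up to floating-point rounding
```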

**If model unavailable**:
- Search for alternative pre-trained music models
- Consider GPT-2 with custom MIDI tokenizer
- Build custom transformer architecture

**Performance optimization**:
- Test mixed precision (fp16)
- Implement gradient checkpointing
- Try LoRA adapters for fine-tuning
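LoRA keeps each adapted weight matrix W (shape d_out × d_in) frozen and trains a low-rank update B·A, with A of shape r × d_in and B of shape d_out × r. That means r·(d_in + d_out) trainable parameters per matrix instead of d_in·d_out. A sketch of that arithmetic; the layer shapes below are hypothetical, not SkyTNT's actual dimensions, and biases are ignored:

```python
from typing import Iterable, Tuple


def lora_trainable_params(shapes: Iterable[Tuple[int, int]], rank: int) -> int:
    """Trainable parameters when each (d_out, d_in) weight gets a rank-`rank` update B @ A."""
    return sum(rank * (d_in + d_out) for d_out, d_in in shapes)


def full_params(shapes: Iterable[Tuple[int, int]]) -> int:
    """Parameters in the same weight matrices if they were fine-tuned directly."""
    return sum(d_out * d_in for d_out, d_in in shapes)


# Hypothetical example: two 768x768 projections in each of 12 layers, rank 8
shapes = [(768, 768)] * 2 * 12
lora = lora_trainable_params(shapes, rank=8)
full = full_params(shapes)
print(f"LoRA: {lora:,} vs full: {full:,} ({100 * lora / full:.2f}%)")
```

Because only the adapter parameters need gradients and Adam state, the 16-bytes-per-parameter training overhead applies to this much smaller count, while the frozen weights only need their inference footprint.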