In [None]:
# Save snippet, intentionally disabled: it is wrapped in a string literal so it
# does not execute in this notebook. It writes the two files that the loading
# cell below reads ("trained_models/mozart_haydn_transformer.pt" and
# "trained_models/mozart_haydn_vocab.pkl"), so paths must stay in sync.
# NOTE(review): if this cell is run, the string is the cell's last expression
# and Jupyter will echo it as the cell output.
'''
import os, pickle, torch

os.makedirs("trained_models", exist_ok=True)

# Save the trained model (model is updated in-place by train_midi_text_transformer)
torch.save(model, "trained_models/mozart_haydn_transformer.pt")

# Save the vocabulary builder (vb)
with open("trained_models/mozart_haydn_vocab.pkl", "wb") as f:
    pickle.dump(vb, f)

print("Saved model and vocab to trained_models/")
'''
# TODO: this snippet probably belongs at the end of mozart_haydn_text_transformer.ipynb
# (presumably the training notebook, where `model` and `vb` exist), run after
# training so that the model and vocab files loaded here are actually produced.

In [2]:
%pip install mido

Collecting mido
  Using cached mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Using cached mido-1.3.3-py3-none-any.whl (54 kB)
Installing collected packages: mido
Successfully installed mido-1.3.3
Note: you may need to restart the kernel to use updated packages.


In [1]:
import ipywidgets as widgets
from IPython.display import display, Audio, clear_output
import tempfile
import mido
import torch
import numpy as np
import pickle

from midi_conversion import midi_to_text, text_to_midi
from models import generate_midi_tokens_with_transformer
from data_preprocessing import SEQ_SOS, SEQ_EOS

# Load the trained model and vocabulary; fall back to DEMO MODE if either file is missing.
MODEL_PATH = "trained_models/mozart_haydn_transformer.pt"
VOCAB_PATH = "trained_models/mozart_haydn_vocab.pkl"

device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    # The checkpoint is a whole pickled nn.Module (saved with torch.save(model)),
    # not a state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which refuses to unpickle full modules, so opt out explicitly
    # (the weights_only argument exists from PyTorch 1.13 onward).
    # SECURITY: weights_only=False, like pickle.load below, can execute arbitrary
    # code embedded in the file -- only load files you produced yourself.
    model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.eval()

    # Vocabulary builder: maps token strings to ids (vb.stoi) and encodes/decodes text.
    with open(VOCAB_PATH, "rb") as f:
        vb = pickle.load(f)

    SOS_ID = vb.stoi[SEQ_SOS]
    EOS_ID = vb.stoi[SEQ_EOS]

    # Must be set on success too; downstream cells rely on this flag being defined.
    MODEL_AVAILABLE = True
    print("Mozart+Haydn model & vocab loaded successfully!")

except FileNotFoundError:
    MODEL_AVAILABLE = False
    model = None
    vb = None
    SOS_ID = None
    EOS_ID = None
    print("Trained model/vocab not found.")
    print("Frontend will run in DEMO MODE: it will just play back your uploaded MIDI.")


Trained model/vocab not found.
Frontend will run in DEMO MODE: it will just play back your uploaded MIDI.


In [8]:
%pip install pretty_midi
%pip install soundfile

Collecting pretty_midi
  Using cached pretty_midi-0.2.11.tar.gz (5.6 MB)
  Preparing metadata (setup.py) ... done
Collecting importlib_resources (from pretty_midi)
  Using cached importlib_resources-6.5.2-py3-none-any.whl.metadata (3.9 kB)
Using cached importlib_resources-6.5.2-py3-none-any.whl (37 kB)
Building wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... done
  Created wheel for pretty_midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595887 sha256=c472870b6c4da1260c0e1a6ebc59888bb0cf4527c50e57e4d3d011e5dbb8229b
  Stored in directory: /Users/johanninoespino/Library/Caches/pip/wheels/f4/ad/93/a7042fe12668827574927ade9deec7f29aad2a1001b1501882
Successfully built pretty_midi
Installing collected packages: importlib_resources, pretty_midi
Successfully installed importlib_resources-6.5.2 pretty_midi-0.2.11
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

In [2]:
import pretty_midi
import soundfile as sf

def midi_to_wav_local(midi_path, wav_path=None, fs=44100, sf2_path=None):
    """
    Render a MIDI file to a WAV file using pretty_midi's FluidSynth backend.

    Args:
        midi_path: Path to the input MIDI file.
        wav_path: Destination WAV path. Defaults to ``midi_path`` with its final
            extension replaced by ``.wav``.
        fs: Sample rate in Hz, used both for synthesis and for the WAV file.
        sf2_path: Optional SoundFont (.sf2) path forwarded to
            ``PrettyMIDI.fluidsynth``; ``None`` uses pretty_midi's default.

    Returns:
        The path of the written WAV file, for playback.
    """
    import os

    if wav_path is None:
        # splitext only touches the final extension. str.replace would also rewrite
        # ".mid" inside directory names, and would leave "song.MID" unchanged,
        # making the WAV overwrite the source MIDI file.
        wav_path = os.path.splitext(midi_path)[0] + ".wav"

    pm = pretty_midi.PrettyMIDI(midi_path)
    audio = pm.fluidsynth(fs=fs, sf2_path=sf2_path)

    # Write at the rate we synthesized at so playback speed and pitch are correct.
    sf.write(wav_path, audio, fs)
    return wav_path


In [None]:
# Seed-file picker. Both common MIDI extensions are accepted; the generate
# callback re-saves the upload under a ".mid" temp name before reading it.
file_uploader = widgets.FileUpload(
    accept='.mid,.midi',
    multiple=False,
    description='Upload Seed MIDI'
)

# Upper bound on newly generated tokens (passed as max_new_tokens to the generator).
max_tokens_slider = widgets.IntSlider(
    value=2000,
    min=100,
    max=8000,
    step=100,
    description='Max Tokens:',
    continuous_update=False
)

generate_button = widgets.Button(
    description="Generate MIDI",
    button_style='success'
)

# Target for all messages and audio players produced by the click handler.
output_box = widgets.Output()

In [19]:
%pip install pyfluidsynth

Collecting pyfluidsynth
  Downloading pyfluidsynth-1.3.4-py3-none-any.whl.metadata (7.5 kB)
Downloading pyfluidsynth-1.3.4-py3-none-any.whl (22 kB)
Installing collected packages: pyfluidsynth
Successfully installed pyfluidsynth-1.3.4
Note: you may need to restart the kernel to use updated packages.


In [4]:
# Upper bound on the number of seed tokens fed to the transformer.
# NOTE(review): presumably chosen to fit the model's context window -- confirm
# against the training configuration.
SEED_TOKEN_LIMIT = 512


def _uploaded_midi_bytes(uploader):
    """Return the raw bytes of the first file held by a ``FileUpload`` widget.

    ipywidgets 8 stores ``value`` as a tuple of dicts; ipywidgets 7 stores a dict
    keyed by filename. Both keep the file data under ``"content"``.
    """
    uploads = uploader.value
    first = next(iter(uploads.values())) if isinstance(uploads, dict) else uploads[0]
    # ipywidgets 8 hands back a memoryview; normalise to bytes.
    return bytes(first["content"])


def _write_temp_midi(data):
    """Write ``data`` to a new ``.mid`` temp file, close it, and return its path.

    ``delete=False`` keeps the file on disk so MIDI/audio libraries can reopen it
    by name; it is not cleaned up afterwards.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as tmp:
        tmp.write(data)
        return tmp.name


def on_generate_clicked(b):
    """Click handler for the "Generate MIDI" button.

    Reads the uploaded seed MIDI. If no model/vocab is loaded it plays the upload
    back (demo mode). Otherwise it tokenises the seed, continues it with the
    transformer, saves the result as MIDI, and plays it. All messages and audio
    are shown in ``output_box``.

    Args:
        b: The clicked ``Button`` (unused; required by ``on_click``).
    """
    with output_box:
        clear_output()

        if len(file_uploader.value) == 0:
            print("Please upload a seed MIDI file.")
            return

        seed_midi_path = _write_temp_midi(_uploaded_midi_bytes(file_uploader))

        # No trained model/vocab loaded -> demo mode. Checking the objects themselves
        # works regardless of whether the loading cell defined a separate flag.
        if model is None or vb is None:
            print("💡 DEMO MODE: Trained model not found yet.")
            print("Showing the UI and playing back your uploaded MIDI as a preview.\n")

            wav_path = midi_to_wav_local(seed_midi_path)
            display(Audio(wav_path, rate=44100))
            return

        midi_obj = mido.MidiFile(seed_midi_path)

        # Convert to transformer tokens.
        core_text = midi_to_text(midi_obj)
        # NOTE(review): the seed ends with SEQ_EOS; whether generation continues past
        # it depends on generate_midi_tokens_with_transformer -- confirm. Truncating
        # to the first SEED_TOKEN_LIMIT tokens likely drops that EOS for long seeds.
        seed_text = f"{SEQ_SOS} {core_text} {SEQ_EOS}"
        seed_tokens = vb.encode(seed_text)[:SEED_TOKEN_LIMIT]

        # Transformer generation.
        ids = generate_midi_tokens_with_transformer(
            model,
            sos_id=SOS_ID,
            eos_id=EOS_ID,
            start_tokens=seed_tokens,
            max_new_tokens=max_tokens_slider.value,
        )

        generated_text = vb.decode(ids)
        generated_midi = text_to_midi(generated_text)

        # Reserve a temp path with the handle closed, then let the MIDI object write it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as tmp:
            out_midi_path = tmp.name
        generated_midi.save(out_midi_path)

        wav_path = midi_to_wav_local(out_midi_path)

        print("🎉 Generation complete!")
        print("MIDI saved to:", out_midi_path)
        print("Playing audio:")

        display(Audio(wav_path, rate=44100))


In [5]:
# on_click keeps every distinct function registered. After the handler cell is
# re-run, on_generate_clicked is a new object, so the previous one is removed
# explicitly; otherwise each click would run generation once per version.
_previous_generate_handler = globals().get("_registered_generate_handler")
if _previous_generate_handler is not None:
    generate_button.on_click(_previous_generate_handler, remove=True)
generate_button.on_click(on_generate_clicked)
_registered_generate_handler = on_generate_clicked

display(widgets.VBox([
    widgets.HTML("<h2>🎼 Transformer Classical Composer UI</h2>"),
    file_uploader,
    max_tokens_slider,
    generate_button,
    output_box
]))

VBox(children=(HTML(value='<h2>🎼 Transformer Classical Composer UI</h2>'), FileUpload(value=(), accept='.mid',…