In [11]:
from pathlib import Path

import torch

from yoshimidi import inference, player
from yoshimidi.data.parse.tracks import Channel, Note
from yoshimidi.output_config import OutputConfig
from yoshimidi.train import checkpoints
from yoshimidi.train.transformer import Transformer
from yoshimidi.train.transformer_config import TransformerConfig

In [57]:
model = Transformer(
    TransformerConfig(
        num_layers=6,
        residual_stream_size=512,
        num_attention_heads=16,
        context_window=1024,
    )
)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = checkpoints.load_checkpoint(
    "2023-10-22_v1",
    step="latest",
    model=model,
    optimizer=optimizer,
    output_config=OutputConfig(checkpoints=Path("../out/checkpoints")),
    device=torch.device("cpu"),
)

[32m2023-10-22 16:20:28.497[0m | [1mINFO    [0m | [36myoshimidi.train.checkpoints[0m:[36mload_checkpoint[0m:[36m48[0m - [1mLoading checkpoint: ../out/checkpoints/2023-10-22_v1/step_006500.pt[0m


In [62]:
channel = Channel(
    notes=[
        Note(note=60, kind="on", velocity=127, time_delta_secs=0),
        Note(note=62, kind="on", velocity=127, time_delta_secs=0),
        Note(note=64, kind="on", velocity=127, time_delta_secs=0.2),
        Note(note=60, kind="off", velocity=127, time_delta_secs=0),
        Note(note=62, kind="off", velocity=127, time_delta_secs=0),
        Note(note=64, kind="off", velocity=127, time_delta_secs=0.2),
        Note(note=61, kind="on", velocity=127, time_delta_secs=0),
        Note(note=63, kind="on", velocity=127, time_delta_secs=0),
        Note(note=65, kind="on", velocity=127, time_delta_secs=0.2),
        Note(note=60, kind="off", velocity=127, time_delta_secs=0),
        Note(note=62, kind="off", velocity=127, time_delta_secs=0),
        Note(note=64, kind="off", velocity=127, time_delta_secs=0.2),
        Note(note=62, kind="on", velocity=127, time_delta_secs=0),
        Note(note=64, kind="on", velocity=127, time_delta_secs=0),
        Note(note=66, kind="on", velocity=127, time_delta_secs=0.2),
    ],
    program_nums=[],
)

notes = inference.run_inference(
    model,
    prompt=channel,
    max_new_tokens=256,
    temperature=0.7,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
channel.notes.extend(notes)

Generating tokens:   0%|          | 0/256 [00:00<?, ?it/s]

Generating tokens: 100%|█████████▉| 255/256 [00:07<00:00, 33.59it/s]


In [63]:
player.play_channel(channel)

[32m2023-10-22 16:22:14.873[0m | [1mINFO    [0m | [36myoshimidi.player[0m:[36mplay[0m:[36m12[0m - [1mPlaying /var/folders/7w/66fh_d3s5hb0br9f7wtqww1h0000gn/T/tmpkvoywtab[0m


[32m2023-10-22 16:22:19.693[0m | [1mINFO    [0m | [36myoshimidi.player[0m:[36mplay[0m:[36m17[0m - [1mFinished playing[0m


In [None]:
channel.notes

[Note(note=60, kind='on', velocity=127, time_delta_secs=0),
 Note(note=62, kind='on', velocity=127, time_delta_secs=0),
 Note(note=64, kind='on', velocity=127, time_delta_secs=0.2),
 Note(note=60, kind='off', velocity=127, time_delta_secs=0),
 Note(note=62, kind='off', velocity=127, time_delta_secs=0),
 Note(note=64, kind='off', velocity=127, time_delta_secs=0.2),
 Note(note=61, kind='on', velocity=127, time_delta_secs=0),
 Note(note=63, kind='on', velocity=127, time_delta_secs=0),
 Note(note=65, kind='on', velocity=127, time_delta_secs=0.2),
 Note(note=60, kind='off', velocity=127, time_delta_secs=0),
 Note(note=62, kind='off', velocity=127, time_delta_secs=0),
 Note(note=64, kind='off', velocity=127, time_delta_secs=0.2),
 Note(note=62, kind='on', velocity=127, time_delta_secs=0),
 Note(note=64, kind='on', velocity=127, time_delta_secs=0),
 Note(note=66, kind='on', velocity=127, time_delta_secs=0.2),
 Note(note=64, kind='off', velocity=127, time_delta_secs=0),
 Note(note=64, kind='of