In [3]:
import torch

from yoshimidi import inference, player
from yoshimidi.data.parse.tracks import Channel, Note
from yoshimidi.output_config import OutputConfig
from yoshimidi.train import checkpoints
from yoshimidi.train.transformer import Transformer
from yoshimidi.train.transformer_config import TransformerConfig

pygame 2.5.0 (SDL 2.28.0, Python 3.10.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [4]:
model = Transformer(
    TransformerConfig(
        num_layers=6,
        residual_stream_size=512,
        num_attention_heads=16,
        context_window=1024,
    )
)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = checkpoints.load_checkpoint(
    "2023-08-03_v1",
    step="latest",
    model=model,
    optimizer=optimizer,
    output_config=OutputConfig(checkpoints="../out/checkpoints"),
    device=torch.device("cpu"),
)

[32m2023-08-03 17:18:55.613[0m | [1mINFO    [0m | [36myoshimidi.train.checkpoints[0m:[36mload_checkpoint[0m:[36m48[0m - [1mLoading checkpoint: ../out/checkpoints/2023-08-03_v1/step_005500.pt[0m


In [12]:
channel = Channel(
    notes=[
        Note(note=60, kind="on", velocity=127, time_delta_secs=0),
        Note(note=62, kind="on", velocity=127, time_delta_secs=0),
        Note(note=64, kind="on", velocity=127, time_delta_secs=0.2),
        Note(note=60, kind="off", velocity=127, time_delta_secs=0),
        Note(note=62, kind="off", velocity=127, time_delta_secs=0),
        Note(note=64, kind="off", velocity=127, time_delta_secs=0.2),
        Note(note=61, kind="on", velocity=127, time_delta_secs=0),
        Note(note=63, kind="on", velocity=127, time_delta_secs=0),
        Note(note=65, kind="on", velocity=127, time_delta_secs=0.2),
        Note(note=60, kind="off", velocity=127, time_delta_secs=0),
        Note(note=62, kind="off", velocity=127, time_delta_secs=0),
        Note(note=64, kind="off", velocity=127, time_delta_secs=0.2),
        Note(note=62, kind="on", velocity=127, time_delta_secs=0),
        Note(note=64, kind="on", velocity=127, time_delta_secs=0),
        Note(note=66, kind="on", velocity=127, time_delta_secs=0.2),
    ],
    program_nums=[],
)

notes = inference.run_inference(
    model,
    prompt=channel,
    max_new_tokens=128,
    temperature=0.1,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
channel.notes.extend(notes)

Generating tokens: 6it [00:00, 53.15it/s]

off
off
off
pause
on
on
on
on
on
on
off


Generating tokens: 18it [00:00, 54.68it/s]

off
on
on
on
on
off
off
on
on
on
on
on


Generating tokens: 30it [00:00, 53.43it/s]

on
on
on
pause
off
off
on
on
on
on
on


Generating tokens: 42it [00:00, 50.94it/s]

on
on
on
on
pause
off
on
on
on
on


Generating tokens: 53it [00:01, 47.19it/s]

on
on
on
on
pause
off
on
on
on


Generating tokens: 58it [00:01, 45.11it/s]

on
on
on
on
on
on
on
on


Generating tokens: 68it [00:01, 40.92it/s]

on
on
on
on
on
on
on
on


Generating tokens: 73it [00:01, 39.96it/s]

on
on
on
on
on
on
on
on


Generating tokens: 82it [00:01, 37.52it/s]

on
on
on
on
on
on
on


Generating tokens: 90it [00:02, 36.53it/s]

on
on
on
on
on
on
on
on


Generating tokens: 98it [00:02, 35.18it/s]

on
on
on
on
on
on
on


Generating tokens: 106it [00:02, 34.39it/s]

on
on
on
on
on
on
on


Generating tokens: 110it [00:02, 33.50it/s]

on
on
on
on
on
on
on


Generating tokens: 118it [00:02, 31.71it/s]

on
on
on
on
on
on
on


Generating tokens: 126it [00:03, 30.71it/s]

on
on
on
on
on
on


Generating tokens: 127it [00:03, 38.90it/s]

on
on





In [10]:
player.play_channel(channel)

[32m2023-08-03 17:22:31.987[0m | [1mINFO    [0m | [36myoshimidi.player[0m:[36mplay[0m:[36m12[0m - [1mPlaying /var/folders/7w/66fh_d3s5hb0br9f7wtqww1h0000gn/T/tmp4sb7uftm[0m
[32m2023-08-03 17:22:33.038[0m | [1mINFO    [0m | [36myoshimidi.player[0m:[36mplay[0m:[36m17[0m - [1mFinished playing[0m


In [11]:
channel.notes

[Note(note=60, kind='on', velocity=127, time_delta_secs=0),
 Note(note=62, kind='on', velocity=127, time_delta_secs=0),
 Note(note=64, kind='on', velocity=127, time_delta_secs=0.2),
 Note(note=60, kind='off', velocity=127, time_delta_secs=0),
 Note(note=62, kind='off', velocity=127, time_delta_secs=0),
 Note(note=64, kind='off', velocity=127, time_delta_secs=0.2),
 Note(note=61, kind='on', velocity=127, time_delta_secs=0),
 Note(note=63, kind='on', velocity=127, time_delta_secs=0),
 Note(note=65, kind='on', velocity=127, time_delta_secs=0.2),
 Note(note=60, kind='off', velocity=127, time_delta_secs=0),
 Note(note=62, kind='off', velocity=127, time_delta_secs=0),
 Note(note=64, kind='off', velocity=127, time_delta_secs=0.2),
 Note(note=62, kind='on', velocity=127, time_delta_secs=0),
 Note(note=64, kind='on', velocity=127, time_delta_secs=0),
 Note(note=66, kind='on', velocity=127, time_delta_secs=0.2),
 Note(note=0, kind='off', velocity=127, time_delta_secs=0),
 Note(note=0, kind='off'