In [31]:
from pathlib import Path

import numpy as np
import pretty_midi
import torch
from tqdm.auto import tqdm

In [32]:
# Piano range as MIDI note numbers: A0 (21) up to C8 (108), both inclusive.
min_pitch, max_pitch = 21, 108
n_pitches = len(range(min_pitch, max_pitch + 1))  # 88 pitch classes
# Number of notes in the sliding window fed to the model.
sequence_length = 128
# MIDI velocities take values 0..127.
n_velocities = 128
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [35]:
from model import MusicGen
# Build the network and load trained weights. map_location remaps the stored
# tensors onto `device`, so a checkpoint saved on a GPU also loads on a
# CPU-only machine (and vice versa).
model = MusicGen()
model.load_state_dict(torch.load('models/model2-e3.pth', map_location=device, weights_only=True))
model.to(device)

MusicGen(
  (lstm): LSTM(4, 128, batch_first=True)
  (pitch_layer): Linear(in_features=128, out_features=88, bias=True)
  (velocity_layer): Linear(in_features=128, out_features=128, bias=True)
  (step_layer): Linear(in_features=128, out_features=1, bias=True)
  (duration_layer): Linear(in_features=128, out_features=1, bias=True)
  (relu): ReLU()
)

In [36]:
def generate(model, seed_sequence, steps=512, device='cpu'):
  """Autoregressively generate `steps` notes from `model`.

  Each step runs the model on a sliding window of the most recent notes
  (initially `seed_sequence`). Pitch and velocity are sampled from the softmax
  of their logits. Step and duration are taken directly from the two scalar
  outputs. The window then drops its oldest note and appends the new one.

  Args:
    model: network called as ``model(window, hidden)`` that returns
      ``(out, hidden)``. The last dimension of ``out`` is split into
      ``[n_pitches, n_velocities, 1, 1]``. The model is assumed to already
      be on ``device``; this function only moves the window.
    seed_sequence: initial window, presumably ``(1, window_len, 4)`` with one
      row of 4 features per note (TODO confirm against training data).
    steps: number of notes to generate.
    device: device the window is moved to.

  Returns:
    float64 ``np.ndarray`` of shape ``(steps, 4)`` with columns
    ``[pitch, velocity, step, duration]``. Pitch is shifted by ``min_pitch``
    from the model's pitch index to an absolute MIDI note number.
  """
  model.eval()
  window = seed_sequence.to(device)
  generated_sequence = []

  with torch.no_grad():
    for _ in tqdm(range(steps)):
      # Run on the full window from a fresh state at every step. Feeding the
      # previous `hidden` back in would make the LSTM re-read notes that state
      # has already consumed: consecutive windows overlap in all but one note,
      # so that context would be counted twice.
      out, _ = model(window, None)
      # NOTE(review): this split reads the last two outputs as
      # (duration, step), while the note fed back below is laid out as
      # [..., step, duration]. Confirm both orders against MusicGen.forward.
      pitch_pred, velocity_pred, duration_pred, step_pred = torch.split(
          out, [n_pitches, n_velocities, 1, 1], dim=-1
      )

      pitch_probs = torch.softmax(pitch_pred, dim=-1)
      velocity_probs = torch.softmax(velocity_pred, dim=-1)

      # .item() requires a single prediction per head, i.e. `out` presumably
      # holds only the last position of a batch of 1 (TODO confirm).
      pitch = torch.multinomial(pitch_probs, num_samples=1).item()
      velocity = torch.multinomial(velocity_probs, num_samples=1).item()
      step = step_pred.item()
      duration = duration_pred.item()

      generated_note = [pitch, velocity, step, duration]
      generated_sequence.append(generated_note)

      # Slide the window: drop the oldest note and append the new one as a
      # (1, 1, 4) tensor in the window's own dtype.
      new_note = torch.tensor(generated_note, dtype=window.dtype, device=device).view(1, 1, -1)
      window = torch.cat([window[:, 1:, :], new_note], dim=1)

  # reshape keeps a (0, 4) shape when steps == 0, so the column update is valid.
  generated_sequence = np.array(generated_sequence, dtype=np.float64).reshape(-1, 4)
  # Shift the predicted pitch index back to an absolute MIDI note number.
  generated_sequence[:, 0] += min_pitch
  return generated_sequence

In [37]:
# Generate 480 notes from an all-zero starting window. Seeding torch's RNG
# makes the multinomial sampling in `generate` repeatable, so re-running this
# cell (or Restart & Run All) reproduces the same piece on the same setup.
# Change `generation_seed` to get a different piece.
generation_seed = 0
torch.manual_seed(generation_seed)
seed_sequence = torch.zeros((1, sequence_length, 4))
generated_notes = generate(model, seed_sequence, steps=480, device=device)

  0%|          | 0/480 [00:00<?, ?it/s]

In [38]:
# Render the generated notes as a single-instrument (program 0) MIDI file.
# `generate` builds each row as [pitch, velocity, step, duration], so the loop
# unpacks in that order. The previous (duration, step) order swapped the
# two timing columns.
min_note_duration = 0.01  # seconds; floor for durations predicted as <= 0

generated_midi = pretty_midi.PrettyMIDI()
instrument = pretty_midi.Instrument(program=0)
current_time = 0.0

for pitch, velocity, step, duration in generated_notes:
  # step/duration come from regression heads (Linear layers; whether a ReLU
  # is applied is not visible here), so they may be negative. Never move time
  # backwards, and give every note a positive length.
  current_time += max(float(step), 0.0)
  note = pretty_midi.Note(
    # A note-on with velocity 0 means note-off in MIDI, so keep it in 1..127.
    velocity=int(np.clip(velocity, 1, 127)),
    pitch=int(pitch),
    start=current_time,
    end=current_time + max(float(duration), min_note_duration),
  )
  instrument.notes.append(note)

generated_midi.instruments.append(instrument)

# Create the output directory if needed; write() does not create it.
output_path = Path("generated") / "2.midi"
output_path.parent.mkdir(parents=True, exist_ok=True)
generated_midi.write(str(output_path))