In [25]:
# Imports
import torch
from torch import nn
from src.utils.preprocess_utils import midi_to_note_indices, notes_to_midi
import numpy as np
import torch.nn.functional as F
from mido import MidiFile, MidiTrack, Message
from datetime import datetime

In [26]:
# Define model architecture
class LSTMModel(nn.Module):
    """Next-token model over note indices: embedding -> stacked LSTM -> linear head.

    Args:
        vocab_size: Number of distinct note indices; sizes both the embedding
            table and the output logits.
        embedding_size: Width of each note embedding.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            Defaults to 0.2, the value that used to be hard-coded.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, dropout=0.2):
        super().__init__()
        # nn.LSTM applies dropout only *between* layers, so with a single layer
        # a non-zero value does nothing except trigger a UserWarning.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0

        # Attribute names and creation order are kept as-is so that existing
        # checkpoints / state_dicts remain loadable.
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(
            embedding_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Return per-position logits for the following note.

        Args:
            x: Note indices of shape [batch_size, sequence_length]; cast to
                long before the embedding lookup.
            hidden: Optional initial LSTM state, forwarded to ``nn.LSTM``.

        Returns:
            Tensor of shape [batch_size, sequence_length, vocab_size].
            The final LSTM state is not returned.
        """
        embedded = self.embedding(x.long())  # [batch_size, sequence_length, embedding_size]
        lstm_out, _ = self.lstm(embedded, hidden)  # [batch_size, sequence_length, hidden_size]
        return self.fc(lstm_out)  # [batch_size, sequence_length, vocab_size]

In [27]:
# Load model
model_name = "2025-01-05_19-23-45_e64_h128_l5_sl50"
# The checkpoint stores the whole pickled model object (not just a state_dict),
# so it must be loaded with weights_only=False. Passing it explicitly silences
# the torch.load FutureWarning and keeps this cell working on PyTorch releases
# where weights_only defaults to True.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# generation below runs on the CPU anyway.
model = torch.load(
    f'weights/model_2/{model_name}',
    map_location="cpu",
    weights_only=False,
)
model.eval()

  model = torch.load(f'weights/model_2/{model_name}')


LSTMModel(
  (embedding): Embedding(85, 64)
  (lstm): LSTM(64, 128, num_layers=5, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=128, out_features=85, bias=True)
)

In [28]:
# Load input
# Generation settings
temperature = 0.1
sequence_length = 200  # number of notes generated on top of the seed notes

# Seed notes read from the input MIDI file.
# NOTE(review): 24 and 108 are presumably the lowest/highest MIDI pitch kept
# (108 - 24 + 1 = 85, the model's vocabulary size) -- confirm against
# midi_to_note_indices. The name `input` shadows the builtin but is kept
# because later cells read it.
input = midi_to_note_indices("data/input.mid", 24, 108)
print(input)

[24, 26, 31, 36]


In [29]:
def generate_sequence(model, start_sequence, sequence_length, num_notes, temperature=1.0):
    """Autoregressively extend ``start_sequence`` by ``sequence_length`` notes.

    At every step the whole sequence so far is fed through ``model`` and the
    next note is chosen from the logits at the last position.

    Args:
        model: Module mapping a long tensor of shape [1, length] to logits of
            shape [1, length, vocab]. It is switched to eval mode.
        start_sequence: Non-empty iterable of integer note indices (the seed).
        sequence_length: Number of notes to generate *in addition to* the seed.
        num_notes: Currently unused; kept so existing callers keep working.
        temperature: Sampling temperature. Values > 0 rescale the logits before
            sampling; exactly 0 selects the most likely note (greedy decoding).

    Returns:
        list: the seed notes followed by the generated note indices.

    Raises:
        ValueError: if ``temperature`` is negative or ``start_sequence`` is empty.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    generated_sequence = list(start_sequence)
    if not generated_sequence:
        raise ValueError("start_sequence must contain at least one note")

    model.eval()

    # Build inputs on the model's device; a module without parameters falls
    # back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    context = torch.tensor(generated_sequence, dtype=torch.long, device=device).unsqueeze(0)  # [1, length]

    with torch.no_grad():
        for _ in range(sequence_length):
            # Only the prediction at the last position decides the next note.
            last_logits = model(context)[0, -1]  # [vocab]

            if temperature == 0:
                # Limit of temperature -> 0: always take the most likely note.
                next_note = int(torch.argmax(last_logits).item())
            else:
                probs = F.softmax(last_logits / temperature, dim=-1)
                next_note = int(torch.multinomial(probs, 1).item())

            generated_sequence.append(next_note)

            # Grow the context by the new note; it is the input for the next step.
            next_token = torch.tensor([[next_note]], dtype=torch.long, device=device)
            context = torch.cat([context, next_token], dim=1)

    return generated_sequence

# Sample a continuation of the seed notes, with the model moved to the CPU.
generated_sequence = generate_sequence(
    model.cpu(),
    input,
    sequence_length,
    num_notes=85,
    temperature=temperature,
)

# Header line followed by the full list of note indices.
print("Generated Sequence:", generated_sequence, sep="\n")

Generated Sequence:
[24, 26, 31, 36, 21, 76, 27, 19, 41, 50, 38, 28, 41, 33, 48, 54, 22, 51, 48, 59, 56, 48, 56, 80, 43, 50, 7, 19, 43, 52, 39, 45, 48, 22, 43, 52, 45, 48, 33, 45, 38, 38, 58, 60, 59, 52, 45, 48, 34, 29, 45, 18, 6, 52, 45, 19, 56, 54, 18, 52, 40, 48, 46, 58, 43, 43, 33, 46, 62, 33, 41, 29, 52, 43, 41, 42, 45, 8, 37, 41, 45, 51, 55, 43, 50, 22, 41, 52, 48, 33, 55, 55, 52, 59, 28, 55, 52, 35, 47, 29, 60, 29, 17, 59, 54, 56, 71, 17, 48, 70, 55, 64, 55, 26, 45, 54, 22, 38, 33, 64, 52, 67, 19, 52, 20, 49, 38, 32, 34, 59, 22, 34, 14, 46, 64, 45, 35, 38, 50, 56, 55, 28, 46, 43, 40, 55, 46, 64, 45, 55, 43, 63, 24, 18, 12, 52, 34, 39, 27, 52, 32, 21, 45, 37, 56, 33, 17, 48, 27, 50, 53, 51, 24, 27, 50, 53, 55, 26, 52, 59, 60, 20, 78, 29, 37, 38, 26, 28, 17, 51, 59, 56, 35, 54, 52, 53, 47, 38, 46, 55, 81, 58, 28, 52]


In [24]:
# Timestamp (one-second resolution) so successive runs get distinct file names.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# The file name records the sampling temperature, the number of generated
# notes and the checkpoint the sequence came from.
file_path = (
    f"generations/model_2/{current_datetime}"
    f"_t{temperature:.1f}_w{sequence_length}_M_{model_name}.mid"
)

notes_to_midi(generated_sequence, file_path)

MIDI file saved to generations/model_2/2025-01-05_19-28-54_t1.0_w200_M_2025-01-05_19-23-45_e64_h128_l5_sl50.mid
