In [1]:
import os
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from mido import MidiFile, MidiTrack, Message
from torch import nn

from src.utils.preprocess_utils import notes_to_midi, parse_midi_notes_sequence, map_notes_to_range, midi_int_to_note_index, one_hot_encode

In [2]:
# Define model architecture
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer applied at every time step.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of values produced per time step by the linear layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability between stacked LSTM layers. It is only
            forwarded to ``nn.LSTM`` when ``num_layers > 1``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=3, dropout=0.2):
        super().__init__()
        # nn.LSTM applies dropout between layers only; with a single layer a
        # non-zero value has no effect and makes PyTorch emit a warning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape [batch, seq_len, input_size] to [batch, seq_len, output_size]."""
        # batch_first=True, so the LSTM output is [batch, seq_len, hidden_size];
        # the final hidden/cell state is not needed here.
        out, _ = self.lstm(x)
        return self.fc(out)

In [3]:
# Load model
# The checkpoint stores the whole pickled module (not just a state_dict), so:
#   * the LSTMModel class must already be defined in this kernel, and
#   * weights_only=False is required. Unpickling can execute arbitrary code,
#     so only load checkpoint files you trust.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load(
    'weights/model_1/2025-01-03_19-34-53',
    map_location='cpu',
    weights_only=False,
)
# Switch to inference mode (disables dropout); as the last expression this
# also displays the module summary.
model.eval()

  model = torch.load('weights/model_1/2025-01-03_19-34-53')


LSTMModel(
  (lstm): LSTM(85, 128, num_layers=2, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=128, out_features=85, bias=True)
)

In [5]:
# Load input
# Generation settings read by the cells below.
sequence_length = 200  # Number of new notes to generate after the seed notes
temperature = 0.3  # Sampling temperature passed to generate_sequence

# Read the seed melody, map it into range, then convert each note to the
# index representation used for generation.
# NOTE: `input` shadows the Python builtin; the name is kept because a later
# cell reads it.
notes = parse_midi_notes_sequence('data/input.mid')
input = [midi_int_to_note_index(pitch) for pitch in map_notes_to_range(notes)]
print(input)

[24, 26, 31, 36]


In [6]:
def generate_sequence(model, start_sequence, sequence_length, num_notes, temperature=1.0, generator=None):
    """Autoregressively extend ``start_sequence`` by ``sequence_length`` notes.

    At every step the whole sequence so far is one-hot encoded, fed to
    ``model``, and the scores of the last time step are used to pick the next
    note index.

    Args:
        model: Module called with a float32 tensor built from
            ``one_hot_encode`` plus a leading batch dimension. Its output is
            indexed as ``[0, -1]`` to obtain the scores for the next note.
            It is put into eval mode and left that way.
        start_sequence: Non-empty iterable of note indices used as the seed.
            It is copied, not modified.
        sequence_length: Number of notes to generate (the seed is kept in
            front of them in the result).
        num_notes: Passed through to ``one_hot_encode`` as ``num_notes``.
        temperature: Divides the logits before softmax sampling. ``0`` selects
            the highest-scoring note deterministically (greedy decoding).
        generator: Optional ``torch.Generator`` forwarded to
            ``torch.multinomial`` for reproducible sampling.

    Returns:
        list: The seed notes followed by the generated note indices.

    Raises:
        ValueError: If ``start_sequence`` is empty or ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    generated_sequence = list(start_sequence)
    if not generated_sequence:
        raise ValueError("start_sequence must contain at least one note")

    model.eval()
    greedy = temperature == 0

    with torch.no_grad():
        for _ in range(sequence_length):
            # The model's forward() does not expose recurrent state, so the
            # full sequence is re-encoded and re-run at every step.
            encoded = one_hot_encode(np.array(generated_sequence), num_notes=num_notes)
            # Add the batch dimension (assumes one_hot_encode returns
            # [current_length, num_notes] - TODO confirm).
            model_input = torch.tensor(encoded, dtype=torch.float32).unsqueeze(0)

            # Only the last time step's scores predict the next note.
            last_logits = model(model_input)[0, -1]

            if greedy:
                next_note = int(torch.argmax(last_logits).item())
            else:
                probs = F.softmax(last_logits / temperature, dim=-1)
                next_note = torch.multinomial(probs, 1, generator=generator).item()

            generated_sequence.append(next_note)

    return generated_sequence

# Generate on the CPU from the seed notes loaded above. num_notes=85 matches
# the 85 input/output features of the loaded model.
generated_sequence = generate_sequence(
    model=model.cpu(),
    start_sequence=input,
    sequence_length=sequence_length,
    num_notes=85,
    temperature=temperature,
)

# Header line followed by the full list of note indices (seed + generated).
print("Generated Sequence:", generated_sequence, sep="\n")

Generated Sequence:
[24, 26, 31, 36, 47, 55, 46, 37, 52, 54, 31, 34, 33, 41, 51, 52, 49, 45, 43, 49, 48, 29, 46, 46, 38, 52, 54, 33, 44, 33, 50, 32, 38, 50, 45, 55, 42, 40, 41, 41, 44, 33, 41, 34, 47, 52, 33, 26, 38, 40, 35, 36, 48, 51, 52, 43, 40, 39, 31, 48, 46, 46, 38, 34, 36, 50, 50, 40, 48, 34, 50, 50, 31, 39, 47, 43, 36, 31, 51, 39, 45, 38, 40, 45, 44, 47, 33, 38, 45, 31, 50, 46, 39, 31, 40, 40, 39, 48, 43, 39, 38, 34, 45, 40, 45, 40, 38, 50, 29, 43, 29, 35, 38, 52, 40, 36, 40, 52, 43, 50, 44, 39, 52, 45, 40, 33, 49, 35, 45, 36, 43, 57, 36, 46, 43, 31, 48, 44, 45, 38, 38, 32, 42, 36, 55, 43, 31, 35, 42, 48, 45, 47, 35, 38, 33, 45, 33, 45, 35, 43, 42, 29, 32, 46, 35, 41, 37, 47, 51, 46, 40, 39, 33, 45, 31, 33, 31, 31, 28, 39, 38, 40, 39, 47, 44, 40, 47, 31, 33, 49, 31, 33, 34, 46, 51, 48, 48, 40, 38, 38, 39, 28, 40, 52]


In [7]:
# Name the output after the run time and the generation settings.
# Note: the temperature is rounded to one decimal place in the file name.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
file_path = f"generations/model_1/{current_datetime}_t{temperature:.1f}_w{sequence_length}.mid"

# Make sure the target folder exists so the write below does not depend on
# the directory having been created by hand beforehand.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

notes_to_midi(generated_sequence, file_path)

MIDI file saved to generations/model_1/2025-01-03_19-35-58_t0.3_w200.mid
