In [25]:
# Imports
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from mido import MidiFile, MidiTrack, Message
from torch import nn

from src.utils.preprocess_utils import midi_to_note_indices, notes_to_midi

In [26]:
# Define model architecture
class TransformerModel(nn.Module):
    """Embedding -> Transformer encoder stack -> linear head over the note vocabulary.

    Only the encoder half of the wrapped ``nn.Transformer`` is used in ``forward``.
    NOTE(review): the checkpoint loaded later in this notebook prints as
    ``SimplifiedTransformerModel``, so this class is not what ``torch.load`` returns there.

    Args:
        vocab_size: number of token ids (size of the embedding table and of the output logits).
        embedding_size: model width (``d_model``).
        num_heads: attention heads per encoder layer.
        hidden_size: feed-forward width inside each layer (``dim_feedforward``).
        num_layers: number of encoder layers.
        dropout: dropout probability used throughout the Transformer.
    """

    def __init__(self, vocab_size, embedding_size, num_heads, hidden_size, num_layers, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # nn.Transformer also builds a decoder (num_decoder_layers defaults to 6) that forward()
        # never calls. It is kept so existing checkpoints/state_dicts of this class still load;
        # it only costs parameters and memory.
        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            dim_feedforward=hidden_size,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(embedding_size, vocab_size)

    def forward(self, x, mask=None):
        """Return per-position logits over the vocabulary.

        Args:
            x: integer token ids of shape [batch_size, sequence_length] (cast to long).
            mask: optional [sequence_length, sequence_length] attention mask forwarded to the
                encoder. For example, an upper-triangular ``-inf`` mask gives causal attention,
                so position i only attends to positions <= i. ``None`` keeps the original
                unmasked behaviour.

        Returns:
            Tensor of shape [batch_size, sequence_length, vocab_size].
        """
        x = x.long()
        embedded = self.embedding(x)  # [batch_size, sequence_length, embedding_size]

        # nn.Transformer does NOT add positional encodings. Without them (and without a mask)
        # the encoder is permutation-equivariant, i.e. blind to token order.
        transformer_out = self.transformer.encoder(embedded, mask=mask)

        return self.fc(transformer_out)  # [batch_size, sequence_length, vocab_size]

In [27]:
# Load model
# The checkpoint is a fully pickled nn.Module (not a state_dict), so it needs weights_only=False.
# That is explicit here because recent torch versions warn about, or default to, weights_only=True,
# which cannot restore a whole module. map_location="cpu" lets a GPU-saved checkpoint load on a
# CPU-only machine; the model is used on the CPU below anyway.
# SECURITY: weights_only=False executes arbitrary pickle code; only load checkpoints you trust.
# NOTE(review): unpickling needs the model's class importable under the module path it was saved
# from. The printed repr is `SimplifiedTransformerModel`, which this notebook does not define;
# confirm it is importable, otherwise this cell depends on leftover kernel state.
model_name = "2025-01-05_22-29-11_e64_nh4_h256_l2_sl25"
model = torch.load(
    f"weights/model_3/{model_name}",
    map_location="cpu",
    weights_only=False,
)
model.eval()

  model = torch.load(f'weights/model_3/{model_name}')


SimplifiedTransformerModel(
  (embedding): Embedding(85, 64)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
          (linear2): Linear(in_features=256, out_features=64, bias=True)
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.2, inplace=False)
          (dropout2): Dropout(p=0.2, inplace=False)
        )
      )
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
    

In [28]:
# Load input
# Sampling settings (independent of the seed melody loaded below).
temperature = 0.3
sequence_length = 150  # Number of new notes to generate after the seed notes

# Seed melody as note indices. NOTE(review): 24 and 108 look like the inclusive MIDI pitch
# bounds (108 - 24 + 1 = 85, the embedding size shown in the model repr) — confirm against
# midi_to_note_indices.
input = midi_to_note_indices("data/input.mid", 24, 108)
print(input)

[24, 26, 31, 36]


In [29]:
def generate_sequence(model, start_sequence, sequence_length, num_notes, temperature=1.0,
                      max_context=None, generator=None):
    """Autoregressively extend ``start_sequence`` with notes predicted by ``model``.

    At every step the model is run on the current token sequence (optionally only its last
    ``max_context`` tokens). The logits at the final position choose the next note, which is
    appended and fed back in.

    Args:
        model: module mapping a LongTensor of shape [1, t] to logits of shape [1, t, vocab].
        start_sequence: non-empty iterable of int note indices used as the prompt.
        sequence_length: number of NEW notes to generate; the result has
            ``len(start_sequence) + sequence_length`` entries.
        num_notes: currently unused; kept so existing callers keep working.
        temperature: softmax temperature. Values > 0 divide the logits before sampling
            (lower = more conservative). 0 or less picks the most likely note (greedy decoding).
        max_context: if given, feed only the last ``max_context`` tokens to the model at each
            step (e.g. the training sequence length). ``None`` feeds the whole history.
        generator: optional ``torch.Generator`` for reproducible sampling.

    Returns:
        list: the prompt followed by the generated note indices (Python ints).

    Raises:
        ValueError: if ``start_sequence`` is empty or ``max_context`` is not positive.
    """
    generated = list(start_sequence)
    if not generated:
        raise ValueError("start_sequence must contain at least one note")
    if max_context is not None and max_context < 1:
        raise ValueError("max_context must be a positive integer or None")

    model.eval()
    tokens = torch.tensor(generated, dtype=torch.long).unsqueeze(0)  # [1, len(generated)]

    with torch.no_grad():
        for _ in range(sequence_length):
            context = tokens if max_context is None else tokens[:, -max_context:]
            logits = model(context)[0, -1]  # logits for the next note: [vocab]

            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_note = torch.multinomial(probs, 1, generator=generator).item()
            else:
                # Temperature 0 conventionally means "no randomness": take the argmax instead of
                # silently falling back to unscaled (T=1) sampling.
                next_note = int(torch.argmax(logits).item())

            generated.append(next_note)
            # Extend the token tensor directly instead of rebuilding it from the list each step.
            tokens = torch.cat([tokens, tokens.new_tensor([[next_note]])], dim=1)

    return generated

# Run generation on the CPU copy of the model; 85 matches the model's note vocabulary size.
generated_sequence = generate_sequence(
    model.cpu(),
    input,
    sequence_length,
    num_notes=85,
    temperature=temperature,
)

print("Generated Sequence:", generated_sequence, sep="\n")

Generated Sequence:
[24, 26, 31, 36, 71, 76, 68, 71, 3, 15, 3, 15, 3, 15, 55, 53, 76, 64, 24, 43, 81, 81, 22, 10, 22, 10, 22, 10, 22, 10, 43, 34, 81, 84, 30, 9, 71, 66, 70, 43, 37, 6, 16, 4, 39, 63, 80, 1, 27, 24, 12, 39, 30, 43, 34, 71, 70, 3, 39, 30, 16, 4, 39, 71, 63, 80, 34, 7, 71, 69, 57, 69, 8, 4, 16, 4, 16, 4, 16, 4, 16, 4, 16, 4, 16, 4, 16, 4, 82, 35, 29, 79, 77, 19, 7, 19, 7, 67, 6, 18, 6, 57, 69, 71, 39, 43, 7, 43, 7, 19, 7, 19, 7, 19, 7, 71, 3, 81, 43, 7, 19, 7, 19, 7, 8, 63, 67, 11, 71, 43, 67, 55, 39, 43, 51, 63, 61, 49, 81, 82, 15, 78, 4, 16, 4, 8, 63, 80, 30, 9, 66, 54, 63, 77]


In [30]:
# Save the generated notes as MIDI. The file name records the generation time, the sampling
# temperature, the number of generated notes and the checkpoint that produced them.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
file_path = f"generations/model_3/{current_datetime}_t{temperature:.1f}_w{sequence_length}_M_{model_name}.mid"

# Make sure the output folder exists so the save works on a fresh checkout.
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
notes_to_midi(generated_sequence, file_path)

MIDI file saved to generations/model_3/2025-01-05_22-37-17_t0.3_w150_M_2025-01-05_22-29-11_e64_nh4_h256_l2_sl25.mid
