In [None]:

# Run configuration for the notebook.
# NOTE(review): none of these constants is referenced in the cells visible
# here; the per-line notes below are inferred from the names only and should
# be confirmed against the training / generation cells.
SEQ_LEN       = 50     # presumably the window length passed to create_sequences -- TODO confirm
SPLIT         = 0.8    # presumably the train fraction (same value as the split_ratio default) -- TODO confirm
LOSS_W_CONT   = 0.5    # presumably the weight of the continuous-feature loss term -- TODO confirm
TEMP          = 1.0    # presumably the sampling temperature for pitch logits -- TODO confirm
NOISE_STD     = 0.05   # presumably the std of noise added to continuous values -- TODO confirm
CAP_STEP      = 2.0    # presumably an upper bound on a generated step; units not visible here -- TODO confirm
CAP_DURATION  = 2.0    # presumably an upper bound on a generated duration; units not visible here -- TODO confirm
SHOW_PPL      = True   # presumably toggles perplexity reporting -- TODO confirm
USE_SCHEDULER = True   # presumably toggles a learning-rate scheduler -- TODO confirm

import os
import math
import pretty_midi
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


def extract_features(midi_file):
    """Load one MIDI file and turn its first instrument into model features.

    Notes are ordered by onset time and each note becomes one row: a pitch
    index plus three continuous values (step since the previous onset,
    duration, velocity). Steps and durations are divided by this file's own
    maximum plus a small epsilon; velocities are divided by 127.

    Args:
        midi_file: Value passed straight to ``pretty_midi.PrettyMIDI``.

    Returns:
        A tuple ``(pitches, cont_features, step_max, duration_max)``.
        ``pitches`` is a 1-D ``torch.long`` tensor, ``cont_features`` is a
        float32 tensor of shape ``(n_notes, 3)`` with columns
        ``[step, duration, velocity]``, and ``step_max`` / ``duration_max``
        are the Python floats the step and duration columns were divided by.
        A file with no instruments or no notes yields empty tensors of shape
        ``(0,)`` and ``(0, 3)`` and both divisors equal to the epsilon.
    """
    eps = 0.00001  # keeps the divisors non-zero when every step/duration is 0

    pm = pretty_midi.PrettyMIDI(midi_file)
    # Only the first instrument is used; tolerate files that have none.
    raw_notes = pm.instruments[0].notes if pm.instruments else []
    # The note list is not guaranteed to be in onset order, and an
    # out-of-order list would produce negative steps. sorted() is stable, so
    # files that are already ordered are unaffected.
    notes = sorted(raw_notes, key=lambda note: note.start)

    if not notes:
        return (torch.empty(0, dtype=torch.long),
                torch.empty((0, 3), dtype=torch.float32),
                eps,
                eps)

    onsets = [note.start for note in notes]
    # The first note has no predecessor, so its step is defined as 0.
    step_values = [0.0] + [cur - prev for prev, cur in zip(onsets, onsets[1:])]

    steps = torch.tensor(step_values, dtype=torch.float32)
    durations = torch.tensor([note.end - note.start for note in notes], dtype=torch.float32)
    velocities = torch.tensor([note.velocity for note in notes], dtype=torch.float32)
    pitches = torch.tensor([note.pitch for note in notes], dtype=torch.long)

    step_max = steps.max().item() + eps
    duration_max = durations.max().item() + eps

    cont_features = torch.stack(
        [steps / step_max, durations / duration_max, velocities / 127.0],
        dim=1,
    )
    return pitches, cont_features, step_max, duration_max


def load_all_data(dataset_dir):
    """Load every MIDI file in a directory into one concatenated dataset.

    Files ending in ``.mid`` / ``.midi`` (case-insensitive) are processed in
    sorted filename order through ``extract_features`` and concatenated.

    ``extract_features`` scales the step and duration columns by each file's
    own maximum. Those columns are re-expressed here relative to the global
    maxima, so that every row shares one scale and multiplying by the
    returned maxima recovers the original values.

    Args:
        dataset_dir: Directory to scan (non-recursive).

    Returns:
        A tuple ``(pitches_all, cont_all, step_max_global,
        duration_max_global)``: the concatenated pitch tensor, the
        concatenated continuous-feature tensor, and the largest per-file
        step / duration divisors.

    Raises:
        RuntimeError: If the directory contains no ``.mid`` / ``.midi`` file.
    """
    per_file = [
        extract_features(os.path.join(dataset_dir, filename))
        for filename in sorted(os.listdir(dataset_dir))
        if filename.lower().endswith(('.mid', '.midi'))
    ]

    if not per_file:
        # Fail with an actionable message instead of an error from torch.cat([]).
        raise RuntimeError(f"No .mid/.midi files found in {dataset_dir!r}")

    step_max_global = max(step_max for _, _, step_max, _ in per_file)
    duration_max_global = max(duration_max for _, _, _, duration_max in per_file)

    pitches_all, cont_all = [], []
    for p, c, step_max, duration_max in per_file:
        # Column order is [step, duration, velocity]; velocity is already on
        # a fixed scale and is left untouched.
        rescale = torch.tensor(
            [step_max / step_max_global, duration_max / duration_max_global, 1.0],
            dtype=c.dtype,
        )
        pitches_all.append(p)
        cont_all.append(c * rescale)

    pitches_all = torch.cat(pitches_all)
    cont_all = torch.cat(cont_all)
    return pitches_all, cont_all, step_max_global, duration_max_global


In [None]:
def create_sequences(pitches, cont_features, seq_length):
    """Build sliding-window inputs and next-step targets.

    Window ``i`` covers rows ``i .. i + seq_length - 1`` and its target is
    row ``i + seq_length``. There are ``len(pitches) - seq_length`` windows.

    Args:
        pitches: 1-D tensor of pitch indices.
        cont_features: Tensor whose first dimension is aligned with
            ``pitches`` (one row of continuous features per note).
        seq_length: Number of rows in each input window.

    Returns:
        A tuple ``(pitch_seqs, cont_seqs, pitch_targets, cont_targets)`` with
        shapes ``(n, seq_length)``, ``(n, seq_length, ...)``, ``(n,)`` and
        ``(n, ...)``, where ``...`` is the trailing shape of
        ``cont_features``. When the input is not longer than ``seq_length``,
        ``n`` is 0 and empty tensors of those shapes are returned.
    """
    # Clamp so an input that is too short yields zero windows.
    n_windows = max(len(pitches) - seq_length, 0)

    # window_idx[i, j] == i + j: row i lists the source rows of window i.
    first_rows = torch.arange(n_windows, device=pitches.device)
    offsets = torch.arange(seq_length, device=pitches.device)
    window_idx = first_rows.unsqueeze(1) + offsets.unsqueeze(0)

    pitch_seqs = pitches[window_idx]
    cont_seqs = cont_features[window_idx.to(cont_features.device)]

    # Targets are the rows immediately after each window; clone so the
    # results do not alias the inputs.
    target_end = seq_length + n_windows
    pitch_targets = pitches[seq_length:target_end].clone()
    cont_targets = cont_features[seq_length:target_end].clone()

    return pitch_seqs, cont_seqs, pitch_targets, cont_targets


def temporal_split_before_window(pitches, cont, split_ratio=0.8):
    """Split two aligned sequences at one chronological cut point.

    The cut index is ``int(len(pitches) * split_ratio)``; everything before
    it forms the first part and everything from it onward the second. No
    shuffling takes place.

    Args:
        pitches: Sliceable sequence whose length determines the cut.
        cont: Sliceable sequence sliced at the same index as ``pitches``.
        split_ratio: Fraction of ``pitches`` placed in the first part.

    Returns:
        ``((pitches_head, cont_head), (pitches_tail, cont_tail))``.
    """
    boundary = int(len(pitches) * split_ratio)
    head = (pitches[:boundary], cont[:boundary])
    tail = (pitches[boundary:], cont[boundary:])
    return head, tail


class MusicLSTMMultiOutput(nn.Module):
    """LSTM that predicts the next pitch and the next continuous features.

    Each time step's input is a learned pitch embedding concatenated with
    that step's continuous features. The LSTM output at the last time step
    goes through a ReLU-activated linear layer and then two linear heads.

    Args:
        n_pitches: Size of the pitch vocabulary (embedding rows and number
            of pitch logits).
        embed_size: Width of the pitch embedding.
        hidden_size: Width of the LSTM state and of the shared hidden layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. Only applied when
            ``num_layers > 1``.
        n_cont: Number of continuous features per time step, used for both
            the input width and the continuous head's output width.
    """

    def __init__(self, n_pitches=128, embed_size=32, hidden_size=128, num_layers=2, dropout=0.3,
                 n_cont=3):
        super().__init__()
        self.embed = nn.Embedding(n_pitches, embed_size)
        # nn.LSTM only applies dropout between layers and warns when given a
        # non-zero value with a single layer, so pass 0.0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size + n_cont, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.hidden_fc = nn.Linear(hidden_size, hidden_size)
        self.pitch_fc = nn.Linear(hidden_size, n_pitches)
        self.cont_fc = nn.Linear(hidden_size, n_cont)

    def forward(self, pitch_seq, cont_seq):
        """Run the model on a batch of windows.

        Args:
            pitch_seq: Integer tensor of pitch indices, ``(batch, seq)``.
            cont_seq: Float tensor of continuous features,
                ``(batch, seq, n_cont)``.

        Returns:
            ``(pitch_logits, cont_pred)`` with shapes ``(batch, n_pitches)``
            and ``(batch, n_cont)``.
        """
        pitch_emb = self.embed(pitch_seq)
        x = torch.cat([pitch_emb, cont_seq], dim=2)
        out, _ = self.lstm(x)
        # Only the final time step's output feeds the prediction heads.
        h = F.relu(self.hidden_fc(out[:, -1, :]))
        return self.pitch_fc(h), self.cont_fc(h)
