In [1]:

# --- Hyperparameters and global switches ------------------------------------
# Window length for sequence models; presumably passed to create_sequences as
# `seq_length` elsewhere in the notebook -- TODO confirm.
SEQ_LEN       = 50
# Fraction of notes for the training split; presumably passed to
# temporal_split_before_window as `split_ratio` -- TODO confirm.
SPLIT         = 0.8
# Weight of the continuous-feature MSE term in the total loss (used by train).
LOSS_W_CONT   = 0.5
# Default sampling temperature of generate_sequence_multi.
TEMP          = 1.0
# Default std of the Gaussian noise added to predicted continuous features in
# generate_sequence_multi.
NOISE_STD     = 0.05
# Upper bounds applied by save_midi to each note's step and duration (seconds).
CAP_STEP      = 2.0
CAP_DURATION  = 2.0
# If True, train() also prints train/validation perplexity.
SHOW_PPL      = True
# If True, train() steps a ReduceLROnPlateau scheduler on the validation loss.
USE_SCHEDULER = True

import os
import math
import pretty_midi
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


def extract_features(midi_file):
    """Extract per-note pitch and normalized continuous features from a MIDI file.

    Only the first instrument track of the file is read. Its notes are sorted
    by onset time before the features are computed.

    Args:
        midi_file: Path to a MIDI file readable by ``pretty_midi.PrettyMIDI``.

    Returns:
        A tuple ``(pitches, cont_features, step_max, duration_max)``:

        * ``pitches`` -- ``LongTensor`` of shape ``(n_notes,)`` with MIDI pitches.
        * ``cont_features`` -- ``FloatTensor`` of shape ``(n_notes, 3)`` with
          columns ``[step, duration, velocity]``. The columns are computed as:

          - step: the time since the previous note's onset (0 for the first
            note), divided by ``step_max``;
          - duration: the note duration, divided by ``duration_max``;
          - velocity: the note velocity, divided by 127.
        * ``step_max`` / ``duration_max`` -- the divisors used for the first two
          columns. Each is this file's maximum plus a small epsilon, so the
          divisor is never zero.

    Raises:
        ValueError: If the file has no instruments, or its first instrument has
            no notes.
    """
    pm = pretty_midi.PrettyMIDI(midi_file)
    if not pm.instruments:
        raise ValueError(f"{midi_file!r} contains no instruments")

    # pretty_midi does not guarantee that an instrument's notes are ordered by
    # onset. Without sorting, overlapping notes would yield negative steps.
    notes = sorted(pm.instruments[0].notes, key=lambda note: note.start)
    if not notes:
        raise ValueError(f"the first instrument of {midi_file!r} contains no notes")

    pitches = torch.tensor([note.pitch for note in notes], dtype=torch.long)
    # Step = onset distance to the previous note; the first note has step 0.
    steps = torch.tensor(
        [0.0] + [cur.start - prev.start for prev, cur in zip(notes, notes[1:])],
        dtype=torch.float32,
    )
    durations = torch.tensor([note.end - note.start for note in notes], dtype=torch.float32)
    velocities = torch.tensor([note.velocity for note in notes], dtype=torch.float32)

    # The epsilon keeps the divisors positive even in degenerate files, e.g. a
    # single note, where every step is 0.
    step_max = steps.max().item() + 0.00001
    duration_max = durations.max().item() + 0.00001

    cont_features = torch.stack(
        [steps / step_max, durations / duration_max, velocities / 127.0], dim=1
    )
    return pitches, cont_features, step_max, duration_max


def load_all_data(dataset_dir):
    """Load every MIDI file in ``dataset_dir`` and concatenate their features.

    Files are processed in sorted filename order. Only names ending in
    ``.mid`` or ``.midi`` (case-insensitive) are read, via ``extract_features``.

    ``extract_features`` normalizes step and duration by *per-file* maxima.
    Callers, however, denormalize with the single global maxima returned here.
    So each file's step/duration columns are rescaled to the global maxima,
    which puts every row on one shared scale (raw value / global max). Velocity
    (already divided by the constant 127) is left unchanged.

    Args:
        dataset_dir: Directory containing the MIDI files.

    Returns:
        ``(pitches_all, cont_all, step_max_global, duration_max_global)``:

        * ``pitches_all`` and ``cont_all`` -- all files' pitches and continuous
          features, concatenated along the note dimension;
        * ``step_max_global`` and ``duration_max_global`` -- the largest
          per-file divisors reported by ``extract_features``.

    Raises:
        ValueError: If no MIDI file is found in ``dataset_dir``.
    """
    per_file = [
        extract_features(os.path.join(dataset_dir, filename))
        for filename in sorted(os.listdir(dataset_dir))
        if filename.lower().endswith(('.mid', '.midi'))
    ]
    if not per_file:
        raise ValueError(f"no .mid/.midi files found in {dataset_dir!r}")

    step_max_global = max(step_max for _, _, step_max, _ in per_file)
    duration_max_global = max(duration_max for _, _, _, duration_max in per_file)

    pitches_all, cont_all = [], []
    for pitches, cont, step_max, duration_max in per_file:
        # Re-express this file's normalized step/duration relative to the
        # global maxima so that every row can be denormalized the same way.
        scale = torch.tensor(
            [step_max / step_max_global, duration_max / duration_max_global, 1.0],
            dtype=cont.dtype,
        )
        pitches_all.append(pitches)
        cont_all.append(cont * scale)

    return torch.cat(pitches_all), torch.cat(cont_all), step_max_global, duration_max_global


In [3]:
def create_sequences(pitches, cont_features, seq_length):
    """Build sliding-window examples for next-note prediction.

    Window ``i`` holds notes ``i .. i + seq_length - 1``, and its target is note
    ``i + seq_length``. This gives ``len(pitches) - seq_length`` examples.

    Args:
        pitches: 1-D tensor of pitch indices, one per note.
        cont_features: Tensor whose first dimension indexes notes, e.g. the
            ``(n_notes, 3)`` output of ``extract_features``. Only its first
            ``len(pitches)`` rows are used.
        seq_length: Window length; must be >= 1.

    Returns:
        ``(pitch_seqs, cont_seqs, pitch_targets, cont_targets)``. With
        ``N = max(0, len(pitches) - seq_length)`` and ``F`` the trailing shape
        of ``cont_features``, their shapes are:

        * ``pitch_seqs`` -- ``(N, seq_length)``;
        * ``cont_seqs`` -- ``(N, seq_length, *F)``;
        * ``pitch_targets`` -- ``(N,)``;
        * ``cont_targets`` -- ``(N, *F)``.

        All outputs are fresh copies, not views of the inputs. If ``N == 0``,
        empty tensors of these shapes are returned instead of raising.

    Raises:
        ValueError: If ``seq_length < 1``, or ``cont_features`` has fewer rows
            than ``pitches``.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")
    n = len(pitches)
    if len(cont_features) < n:
        raise ValueError(
            f"cont_features has {len(cont_features)} rows but pitches has {n}"
        )

    feat_shape = tuple(cont_features.shape[1:])
    if n <= seq_length:
        return (pitches.new_empty((0, seq_length)),
                cont_features.new_empty((0, seq_length) + feat_shape),
                pitches.new_empty((0,)),
                cont_features.new_empty((0,) + feat_shape))

    # unfold(0, size, 1) yields every length-`size` window as a view. The
    # final note is dropped first so that only windows with a target remain.
    pitch_seqs = pitches[:n - 1].unfold(0, seq_length, 1)
    # unfold places the window dimension last; move it right after the batch dim.
    cont_seqs = cont_features[:n - 1].unfold(0, seq_length, 1).movedim(-1, 1)

    # Overlapping windows share storage; materialize independent copies.
    return (pitch_seqs.clone(memory_format=torch.contiguous_format),
            cont_seqs.clone(memory_format=torch.contiguous_format),
            pitches[seq_length:n].clone(),
            cont_features[seq_length:n].clone())


def temporal_split_before_window(pitches, cont, split_ratio=0.8):
    """Split the note streams chronologically, before windowing.

    The first ``int(len(pitches) * split_ratio)`` notes form the head part and
    the remaining notes form the tail part. Windows built from each part
    separately therefore cannot straddle the boundary.

    Returns:
        ``((pitches_head, cont_head), (pitches_tail, cont_tail))``.
    """
    boundary = int(len(pitches) * split_ratio)
    head = (pitches[:boundary], cont[:boundary])
    tail = (pitches[boundary:], cont[boundary:])
    return head, tail


class MusicLSTMMultiOutput(nn.Module):
    """LSTM that predicts the next note's pitch and its continuous features.

    At each time step the input is a learned pitch embedding concatenated with
    that step's continuous features. The LSTM output at the last time step goes
    through a ReLU hidden layer and then two heads:

    * ``pitch_fc`` -- unnormalized logits over ``n_pitches`` classes;
    * ``cont_fc`` -- ``n_cont`` regression outputs. The default of 3 matches
      the ``[step, duration, velocity]`` columns built by ``extract_features``.

    Args:
        n_pitches: Size of the pitch vocabulary.
        embed_size: Dimension of the pitch embedding.
        hidden_size: Width of the LSTM and of the hidden layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. ``nn.LSTM`` applies it
            only between layers, so it is set to 0 when ``num_layers == 1``.
            This avoids PyTorch's warning without changing behavior.
        n_cont: Number of continuous features per step, used for both input
            and output.
    """

    def __init__(self, n_pitches=128, embed_size=32, hidden_size=128, num_layers=2,
                 dropout=0.3, n_cont=3):
        super().__init__()
        self.embed = nn.Embedding(n_pitches, embed_size)
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size + n_cont, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.hidden_fc = nn.Linear(hidden_size, hidden_size)
        self.pitch_fc = nn.Linear(hidden_size, n_pitches)
        self.cont_fc = nn.Linear(hidden_size, n_cont)

    def forward(self, pitch_seq, cont_seq):
        """Return ``(pitch_logits, cont_pred)`` for a batch of windows.

        Args:
            pitch_seq: Long tensor ``(batch, seq_len)`` of pitch indices.
            cont_seq: Float tensor ``(batch, seq_len, n_cont)``.

        Returns:
            ``pitch_logits`` of shape ``(batch, n_pitches)`` and ``cont_pred``
            of shape ``(batch, n_cont)``, both computed from the last time step.
        """
        pitch_emb = self.embed(pitch_seq)
        x = torch.cat([pitch_emb, cont_seq], dim=2)
        out, _ = self.lstm(x)
        h = F.relu(self.hidden_fc(out[:, -1, :]))
        return self.pitch_fc(h), self.cont_fc(h)


In [5]:
@torch.no_grad()
def _validate_epoch(model, val_data, loss_pitch, loss_cont, batch_size, device):
    """Return ``(mean CE, mean MSE, accuracy)`` of ``model`` on ``val_data``.

    The set is evaluated in chunks of ``batch_size``, so it never goes through
    the model in a single forward pass. ``loss_pitch`` and ``loss_cont`` are
    expected to use mean reduction. Each batch mean is re-weighted by its batch
    size, so the results equal the full-set means.
    """
    model.eval()
    val_pitch_seq, val_cont_seq, val_pitch_tgt, val_cont_tgt = val_data
    n_val = len(val_pitch_seq)
    ce_sum = mse_sum = correct = 0.0
    for i in range(0, n_val, batch_size):
        pseq = val_pitch_seq[i:i + batch_size].to(device)
        cseq = val_cont_seq[i:i + batch_size].to(device)
        ptgt = val_pitch_tgt[i:i + batch_size].to(device)
        ctgt = val_cont_tgt[i:i + batch_size].to(device)

        logits, cont_out = model(pseq, cseq)
        bs = len(ptgt)
        ce_sum += loss_pitch(logits, ptgt).item() * bs
        mse_sum += loss_cont(cont_out, ctgt).item() * bs
        correct += (logits.argmax(dim=1) == ptgt).sum().item()
    return ce_sum / n_val, mse_sum / n_val, correct / n_val


def train(model, train_data, val_data, epochs=20, batch_size=64, lr=0.001, device='cpu'):
    """Train ``model`` with Adam on the joint pitch / continuous-feature loss.

    The per-batch loss is ``CE(pitch) + LOSS_W_CONT * MSE(cont)``.

    After each epoch the model is evaluated on ``val_data`` and a summary line
    is printed. Perplexities are included when ``SHOW_PPL`` is true.

    When ``USE_SCHEDULER`` is true, a ``ReduceLROnPlateau`` scheduler
    (factor 0.5, patience 3) is stepped on the validation loss, and each LR
    reduction is printed.

    Training metrics are per-sample averages: each batch is weighted by its
    size, so a short final batch does not skew them.

    Args:
        model: Module whose ``forward(pitch_seq, cont_seq)`` returns
            ``(pitch_logits, cont_pred)``.
        train_data, val_data: 4-tuples ``(pitch_seqs, cont_seqs,
            pitch_targets, cont_targets)`` as built by ``create_sequences``.
        epochs, batch_size, lr: Optimization settings. ``batch_size`` also
            sets the chunk size used for validation.
        device: Device to train on.

    Raises:
        ValueError: If ``val_data`` contains no samples.
    """
    if len(val_data[0]) == 0:
        raise ValueError("val_data must contain at least one sample")

    model.to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    # The `verbose=` scheduler argument is deprecated since PyTorch 2.2, so LR
    # reductions are reported manually after each scheduler step instead.
    sched = (optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3)
             if USE_SCHEDULER else None)

    loss_pitch = nn.CrossEntropyLoss()
    loss_cont = nn.MSELoss()

    train_pitch_seq, train_cont_seq, train_pitch_tgt, train_cont_tgt = train_data
    n_train = len(train_pitch_seq)
    denom = max(1, n_train)  # avoids division by zero on an empty training set

    for epoch in range(epochs):
        model.train()
        perm = torch.randperm(n_train)
        sum_loss = sum_pitch_ce = sum_correct = 0.0

        for i in range(0, n_train, batch_size):
            idx = perm[i:i + batch_size]
            pseq = train_pitch_seq[idx].to(device)
            cseq = train_cont_seq[idx].to(device)
            ptgt = train_pitch_tgt[idx].to(device)
            ctgt = train_cont_tgt[idx].to(device)

            opt.zero_grad()
            logits, cont_out = model(pseq, cseq)
            l_pitch = loss_pitch(logits, ptgt)
            l_cont = loss_cont(cont_out, ctgt)
            loss = l_pitch + LOSS_W_CONT * l_cont
            loss.backward()
            opt.step()

            # The losses are batch means; re-weight them by batch size so the
            # epoch averages are true per-sample means.
            bs = len(idx)
            sum_loss += loss.item() * bs
            sum_pitch_ce += l_pitch.item() * bs
            sum_correct += (logits.argmax(dim=1) == ptgt).sum().item()

        avg_train_loss = sum_loss / denom
        avg_acc = sum_correct / denom
        train_ppl = math.exp(sum_pitch_ce / denom)

        val_ce, val_mse, val_acc = _validate_epoch(
            model, val_data, loss_pitch, loss_cont, batch_size, device
        )
        val_loss = val_ce + LOSS_W_CONT * val_mse
        val_ppl = math.exp(val_ce)

        msg = (f"Ep {epoch + 1}/{epochs} | loss {avg_train_loss:.4f} | acc {avg_acc:.3f} | "
               f"vloss {val_loss:.4f} | vacc {val_acc:.3f}")
        if SHOW_PPL:
            msg += f" | ppl {train_ppl:.2f} | vppl {val_ppl:.2f}"
        print(msg)

        if sched is not None:
            old_lr = opt.param_groups[0]['lr']
            sched.step(val_loss)
            new_lr = opt.param_groups[0]['lr']
            if new_lr < old_lr:
                print(f"Reducing learning rate: {old_lr:.2e} -> {new_lr:.2e}")


def sample_pitch(logits, temperature=1.0):
    """Draw one class index from temperature-scaled ``logits``.

    ``temperature`` is floored at 1e-8. Zero or negative temperatures
    therefore behave like (near-)greedy decoding instead of dividing by zero.

    Returns:
        The sampled index as a Python ``int``.
    """
    safe_temperature = max(temperature, 1e-8)
    probabilities = F.softmax(logits / safe_temperature, dim=-1)
    choice = torch.multinomial(probabilities, num_samples=1)
    return int(choice.item())



In [None]:
def generate_sequence_multi(model, seed_pitch, seed_cont, length=200, temperature=TEMP,
                            noise_std=NOISE_STD, step_max=1.0, duration_max=1.0, device='cpu'):
    """Autoregressively generate ``length`` notes starting from a seed window.

    Each iteration does the following:

    1. The model sees the current window.
    2. A pitch is sampled from its logits with ``sample_pitch``.
    3. The predicted continuous features get Gaussian noise (std
       ``noise_std``; skipped when ``noise_std <= 0``) and are clipped at 0.
    4. The features are mapped back to real units: step times ``step_max``,
       duration times ``duration_max``, velocity capped at 1.
    5. The window slides forward by one note, appending the new note with its
       step/duration divided back by the same scales.

    Returns:
        ``(pitches, cont)``:

        * ``pitches`` -- a ``(length,)`` tensor of sampled pitch indices;
        * ``cont`` -- a ``(length, 3)`` tensor of ``[step, duration, velocity]``,
          with step/duration denormalized and velocity in ``[0, 1]``.

        ``length`` must be >= 1; stacking an empty list raises.
    """
    model.eval()
    window_p = seed_pitch.clone().to(device)
    window_c = seed_cont.clone().to(device)
    sampled_pitches = []
    sampled_cont = []

    for _ in range(length):
        with torch.no_grad():
            logits, cont_pred = model(window_p.unsqueeze(0), window_c.unsqueeze(0))
        next_pitch = sample_pitch(logits.squeeze(0).cpu(), temperature)

        features = cont_pred.squeeze(0).cpu()
        if noise_std > 0:
            features = features + torch.randn(3) * noise_std
        features = features.clamp(min=0.0)
        # Denormalize step and duration in place; velocity stays in [0, 1].
        for column, scale in ((0, step_max), (1, duration_max)):
            features[column] = features[column] * scale
        features[2] = features[2].clamp(max=1.0)

        sampled_pitches.append(next_pitch)
        sampled_cont.append(features)

        # Slide the window: drop the oldest note, append the new (re-normalized) one.
        renormalized = features / torch.tensor([step_max, duration_max, 1.0])
        window_p = torch.cat([window_p[1:], torch.tensor([next_pitch], device=device)])
        window_c = torch.cat([window_c[1:], renormalized.unsqueeze(0).to(device)])

    return torch.tensor(sampled_pitches), torch.stack(sampled_cont)

def save_midi(pitches, cont_features, output_file):
    """Write a note sequence to a single-track MIDI file (program 0).

    Each note starts ``step`` seconds after the previous note's start.

    Args:
        pitches: Indexable sequence of scalar tensors (e.g. a 1-D tensor) of
            MIDI pitches.
        cont_features: Per-note ``[step, duration, velocity]``, where step and
            duration are in seconds and velocity is in ``[0, 1]``. This is the
            format returned by ``generate_sequence_multi``.
        output_file: Destination path.

    Values are clipped before writing:

    * step -- to ``[0, CAP_STEP]``; a negative step would move time backwards;
    * duration -- to ``CAP_DURATION`` and then raised to at least 0.05 s;
    * velocity -- ``velocity * 127`` is rounded, then clipped to ``[1, 127]``.

    The velocity handling matters for three reasons:

    * plain ``int()`` truncation would turn e.g. 63.99999 into 63;
    * values outside ``[0, 1]`` would fall outside the 0-127 MIDI data range;
    * a Note On with velocity 0 is read as a Note Off, so the note would be
      silently lost.
    """
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    t = 0.0

    for i, pitch_value in enumerate(pitches):
        features = cont_features[i]
        pitch = int(pitch_value.item())
        step = min(max(float(features[0].item()), 0.0), CAP_STEP)
        duration = max(min(float(features[1].item()), CAP_DURATION), 0.05)
        velocity = min(max(int(round(features[2].item() * 127)), 1), 127)

        t += step
        inst.notes.append(
            pretty_midi.Note(velocity=velocity, pitch=pitch, start=t, end=t + duration)
        )

    pm.instruments.append(inst)
    pm.write(output_file)
    print("Archivo MIDI guardado en:", output_file)

@torch.no_grad()
def evaluate_all(model, pitch_seqs, cont_seqs, pitch_tg, cont_tg,
                 step_max, duration_max, device='cpu', batch_size=1024):
    """Compute pitch and continuous-feature metrics over a whole dataset.

    The data is streamed to ``device`` one batch at a time rather than moved up
    front. Metrics are accumulated as per-sample sums and divided at the end.

    Args:
        model: Module whose ``forward(pitch_seq, cont_seq)`` returns
            ``(pitch_logits, cont_pred)``.
        pitch_seqs, cont_seqs, pitch_tg, cont_tg: Windows and targets, e.g.
            from ``create_sequences``. ``cont_*`` should have three columns,
            presumably normalized ``[step, duration, velocity]``.
        step_max, duration_max: Scales that map normalized step/duration back
            to real units for the ``"cont_real"`` metrics.
        device: Device for the forward passes.
        batch_size: Number of windows per forward pass.

    Returns:
        A nested dict with three sections:

        * ``"pitch"`` -- mean cross-entropy ``"CE"``, perplexity ``"PPL"``,
          top-1 accuracy ``"acc"``, and top-k accuracy ``"top5"`` with
          ``k = min(5, n_classes)``, so small vocabularies do not crash;
        * ``"cont_norm"`` -- per-column ``"RMSE"`` / ``"MAE"`` in normalized
          units;
        * ``"cont_real"`` -- per-column RMSE / MAE in real units. Step and
          duration are scaled by ``step_max`` / ``duration_max``; velocity is
          clipped to [0, 1] and scaled to 0-127.

    Raises:
        ValueError: If the dataset is empty.
    """
    device = torch.device(device)
    model.eval()
    ce_fn = nn.CrossEntropyLoss(reduction='sum')

    n = pitch_tg.size(0)
    if n == 0:
        raise ValueError("evaluate_all() needs at least one sample")

    def to_real(c):
        # Map normalized [step, duration, velocity] columns to real units.
        return torch.stack([
            c[:, 0] * step_max,
            c[:, 1] * duration_max,
            c[:, 2].clamp(0, 1) * 127.0,
        ], dim=1)

    ce_sum = 0.0
    acc_sum = 0.0
    acc5_sum = 0.0
    mse_n = torch.zeros(3)
    mae_n = torch.zeros(3)
    mse_r = torch.zeros(3)
    mae_r = torch.zeros(3)

    for start in range(0, n, batch_size):
        stop = start + batch_size
        pseq = pitch_seqs[start:stop].to(device)
        cseq = cont_seqs[start:stop].to(device)
        ptgt = pitch_tg[start:stop].to(device)
        ctgt = cont_tg[start:stop].to(device)

        logits, cpred = model(pseq, cseq)
        ce_sum += ce_fn(logits, ptgt).item()
        acc_sum += (logits.argmax(dim=1) == ptgt).sum().item()

        # Guard against vocabularies smaller than 5 (topk would raise).
        k = min(5, logits.size(1))
        topk = logits.topk(k, dim=1).indices
        acc5_sum += (topk == ptgt.unsqueeze(1)).any(dim=1).sum().item()

        diff_n = cpred - ctgt
        mse_n += (diff_n ** 2).sum(dim=0).cpu()
        mae_n += diff_n.abs().sum(dim=0).cpu()

        diff_r = to_real(cpred) - to_real(ctgt)
        mse_r += (diff_r ** 2).sum(dim=0).cpu()
        mae_r += diff_r.abs().sum(dim=0).cpu()

    ce = ce_sum / n
    return {
        "pitch": {"CE": ce, "PPL": math.exp(ce), "acc": acc_sum / n, "top5": acc5_sum / n},
        "cont_norm": {"RMSE": (mse_n / n).sqrt().tolist(), "MAE": (mae_n / n).tolist()},
        "cont_real": {"RMSE_sec_vel": (mse_r / n).sqrt().tolist(),
                      "MAE_sec_vel": (mae_r / n).tolist()},
    }
