In [1]:

SEQ_LEN       = 50
SPLIT         = 0.8
LOSS_W_CONT   = 0.5
TEMP          = 1.0
NOISE_STD     = 0.05
CAP_STEP      = 2.0
CAP_DURATION  = 2.0
SHOW_PPL      = True
USE_SCHEDULER = True

import os
import math
import pretty_midi
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


def extract_features(midi_file):
    pm = pretty_midi.PrettyMIDI(midi_file)
    inst = pm.instruments[0]
    notes = inst.notes

    pitches, steps, durations, velocities = [], [], [], []
    prev_time = 0.0

    for i, note in enumerate(notes):
        pitches.append(note.pitch)
        start = note.start
        dur = note.end - start
        vel = note.velocity
        step = start - prev_time if i > 0 else 0.0
        steps.append(step); durations.append(dur); velocities.append(vel)
        prev_time = start

    steps = torch.tensor(steps, dtype=torch.float32)
    durations = torch.tensor(durations, dtype=torch.float32)
    velocities = torch.tensor(velocities, dtype=torch.float32)
    pitches = torch.tensor(pitches, dtype=torch.long)

    step_max = steps.max().item() + 0.00001
    duration_max = durations.max().item() + 0.00001

    steps = steps / step_max
    durations = durations / duration_max
    velocities = velocities / 127.0


    cont_features = torch.stack([steps, durations, velocities], dim=1)
    return pitches, cont_features, step_max, duration_max


def load_all_data(dataset_dir):
    pitches_all, cont_all = [], []
    step_max_global, duration_max_global = 0.0, 0.0

    for filename in sorted(os.listdir(dataset_dir)):
        if filename.lower().endswith(('.mid', '.midi')):
            p, c, step_max, duration_max = extract_features(os.path.join(dataset_dir, filename))
            pitches_all.append(p); cont_all.append(c)
            step_max_global = max(step_max_global, step_max)
            duration_max_global = max(duration_max_global, duration_max)

    pitches_all = torch.cat(pitches_all)
    cont_all = torch.cat(cont_all)
    return pitches_all, cont_all, step_max_global, duration_max_global


In [3]:
def create_sequences(pitches, cont_features, seq_length):
    pitch_seqs, cont_seqs, pitch_targets, cont_targets = [], [], [], []
    for i in range(len(pitches) - seq_length):
        pitch_seqs.append(pitches[i:i + seq_length])
        cont_seqs.append(cont_features[i:i + seq_length])
        pitch_targets.append(pitches[i + seq_length])
        cont_targets.append(cont_features[i + seq_length])

    return (torch.stack(pitch_seqs),
            torch.stack(cont_seqs),
            torch.tensor(pitch_targets),
            torch.stack(cont_targets))


def temporal_split_before_window(pitches, cont, split_ratio=0.8):
    n_total = len(pitches)
    cut = int(n_total * split_ratio)
    return (pitches[:cut], cont[:cut]), (pitches[cut:], cont[cut:])


class MusicLSTMMultiOutput(nn.Module):
    def __init__(self, n_pitches=128, embed_size=32, hidden_size=128, num_layers=2, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(n_pitches, embed_size)
        self.lstm = nn.LSTM(embed_size + 3, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.hidden_fc = nn.Linear(hidden_size, hidden_size)
        self.pitch_fc = nn.Linear(hidden_size, n_pitches)
        self.cont_fc = nn.Linear(hidden_size, 3)

    def forward(self, pitch_seq, cont_seq):
        pitch_emb = self.embed(pitch_seq)
        x = torch.cat([pitch_emb, cont_seq], dim=2)
        out, _ = self.lstm(x)
        h = F.relu(self.hidden_fc(out[:, -1, :]))
        return self.pitch_fc(h), self.cont_fc(h)


In [9]:
def train(model, train_data, val_data, epochs=20, batch_size=64, lr=0.001, device='cpu'):
    model.to(device)
    opt   = optim.Adam(model.parameters(), lr=lr)
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=3, verbose=True) if USE_SCHEDULER else None

    ce_loss  = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()

    tr_seq_p, tr_seq_c, tr_tgt_p, tr_tgt_c = train_data
    va_seq_p, va_seq_c, va_tgt_p, va_tgt_c = val_data

    va_seq_p = va_seq_p.to(device); va_seq_c = va_seq_c.to(device)
    va_tgt_p = va_tgt_p.to(device); va_tgt_c = va_tgt_c.to(device)

    n_train = len(tr_seq_p)
    n_steps = max(1, (n_train + batch_size - 1) // batch_size)

    for ep in range(epochs):
        model.train()
        order = torch.randperm(n_train)
        loss_sum = 0.0
        ce_sum   = 0.0
        acc_sum  = 0.0

        for i in range(0, n_train, batch_size):
            idx  = order[i:i+batch_size]
            pseq = tr_seq_p[idx].to(device)
            cseq = tr_seq_c[idx].to(device)
            ptgt = tr_tgt_p[idx].to(device)
            ctgt = tr_tgt_c[idx].to(device)

            opt.zero_grad()
            logits, cout = model(pseq, cseq)
            lp = ce_loss(logits, ptgt)
            lc = mse_loss(cout, ctgt)
            loss = lp + LOSS_W_CONT * lc
            loss.backward()
            opt.step()

            loss_sum += loss.item()
            ce_sum   += lp.item()
            acc_sum  += (logits.argmax(dim=1) == ptgt).float().mean().item()

        tr_loss = loss_sum / n_steps
        tr_acc  = acc_sum  / n_steps
        tr_ppl  = math.exp(ce_sum / n_steps)

        model.eval()
        with torch.no_grad():
            v_logits, v_cont = model(va_seq_p, va_seq_c)
            v_ce   = ce_loss(v_logits, va_tgt_p).item()
            v_mse  = mse_loss(v_cont,  va_tgt_c).item()
            v_loss = v_ce + LOSS_W_CONT * v_mse
            v_acc  = (v_logits.argmax(dim=1) == va_tgt_p).float().mean().item()
            v_ppl  = math.exp(v_ce)

        msg = f"Ep {ep+1}/{epochs} | loss {tr_loss:.4f} | acc {tr_acc:.3f} | vloss {v_loss:.4f} | vacc {v_acc:.3f}"
        if SHOW_PPL:
            msg += f" | ppl {tr_ppl:.2f} | vppl {v_ppl:.2f}"
        print(msg)

        if sched is not None:
            sched.step(v_loss)

def sample_pitch(logits, temperature=1.0):
    logits = logits / max(temperature, 1e-8)
    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1).item())





In [None]:
def generate_sequence_multi(model, seed_pitch, seed_cont, length=200, temperature=TEMP,
                            noise_std=NOISE_STD, step_max=1.0, duration_max=1.0, device='cpu'):
    model.eval()
    out_p = []
    out_c = []
    pitch_seq = seed_pitch.clone().to(device)
    cont_seq = seed_cont.clone().to(device)

    for _ in range(length):
        with torch.no_grad():
            plogits, cout = model(pitch_seq.unsqueeze(0), cont_seq.unsqueeze(0))

        pidx = sample_pitch(plogits.squeeze(0).cpu(), temperature)

        cvals = cout.squeeze(0).cpu()
        if noise_std > 0:
            cvals = cvals + torch.randn(3) * noise_std
        cvals = cvals.clamp(min=0.0)
        cvals[0] = cvals[0] * step_max
        cvals[1] = cvals[1] * duration_max
        cvals[2] = cvals[2].clamp(max=1.0)

        out_p.append(pidx)
        out_c.append(cvals)

        pitch_seq = torch.cat([pitch_seq[1:], torch.tensor([pidx], device=device)])
        cont_seq = torch.cat([
            cont_seq[1:],
            (cvals / torch.tensor([step_max, duration_max, 1.0])).unsqueeze(0).to(device)
        ])

    return torch.tensor(out_p), torch.stack(out_c)

def save_midi(pitches, cont_features, output_file):
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    t = 0.0

    for i in range(len(pitches)):
        pitch = int(pitches[i].item())
        step = float(cont_features[i][0].item())
        duration = float(cont_features[i][1].item())
        velocity = int(cont_features[i][2].item() * 127)

        if step > CAP_STEP:
            step = CAP_STEP
        if duration > CAP_DURATION:
            duration = CAP_DURATION

        t = t + step
        start_time = t
        end_time = start_time + max(duration, 0.05)
        inst.notes.append(pretty_midi.Note(velocity=velocity, pitch=pitch, start=start_time, end=end_time))

    pm.instruments.append(inst)
    pm.write(output_file)
    print("Archivo MIDI guardado en:", output_file)

@torch.no_grad()
def evaluate_all(model, pitch_seqs, cont_seqs, pitch_tg, cont_tg,
                 step_max, duration_max, device='cpu', batch_size=1024):
    device = torch.device(device)
    model.eval()
    ce_fn = nn.CrossEntropyLoss(reduction='sum')

    pitch_seqs = pitch_seqs.to(device)
    cont_seqs  = cont_seqs.to(device)
    pitch_tg   = pitch_tg.to(device)
    cont_tg    = cont_tg.to(device)

    n = pitch_tg.size(0)
    ce_sum = 0.0
    acc_sum = 0.0
    acc5_sum = 0.0
    mse_n = torch.zeros(3)
    mae_n = torch.zeros(3)
    mse_r = torch.zeros(3)
    mae_r = torch.zeros(3)

    for i in range(0, n, batch_size):
        pseq = pitch_seqs[i:i+batch_size]
        cseq = cont_seqs[i:i+batch_size]
        ptgt = pitch_tg[i:i+batch_size]
        ctgt = cont_tg[i:i+batch_size]

        logits, cpred = model(pseq, cseq)
        ce_sum += ce_fn(logits, ptgt).item()

        pred = logits.argmax(dim=1)
        acc_sum += (pred == ptgt).float().sum().item()

        top5 = logits.topk(5, dim=1).indices
        acc5_sum += (top5 == ptgt.unsqueeze(1)).any(dim=1).float().sum().item()

        diff_n = cpred - ctgt
        mse_n += (diff_n ** 2).sum(dim=0).detach().cpu()
        mae_n += diff_n.abs().sum(dim=0).detach().cpu()

        den = torch.stack([
            cpred[:, 0] * step_max,
            cpred[:, 1] * duration_max,
            cpred[:, 2].clamp(0, 1) * 127.0
        ], dim=1)
        den_t = torch.stack([
            ctgt[:, 0] * step_max,
            ctgt[:, 1] * duration_max,
            ctgt[:, 2].clamp(0, 1) * 127.0
        ], dim=1)

        diff_r = den - den_t
        mse_r += (diff_r ** 2).sum(dim=0).detach().cpu()
        mae_r += diff_r.abs().sum(dim=0).detach().cpu()

    ce = ce_sum / n
    ppl = math.exp(ce)
    acc = acc_sum / n
    acc5 = acc5_sum / n

    rmse_norm = (mse_n / n).sqrt().tolist()
    mae_norm = (mae_n / n).tolist()
    rmse_real = (mse_r / n).sqrt().tolist()
    mae_real = (mae_r / n).tolist()

    return {
        "pitch": {"CE": ce, "PPL": ppl, "acc": acc, "top5": acc5},
        "cont_norm": {"RMSE": rmse_norm, "MAE": mae_norm},
        "cont_real": {"RMSE_sec_vel": rmse_real, "MAE_sec_vel": mae_real}
    }


In [None]:
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_dir = '../dataset/music_artist/mozart'
    print("Cargando datos...")
    pitches, cont, step_max, duration_max = load_all_data(data_dir)

    (train_p, train_c), (val_p, val_c) = temporal_split_before_window(pitches, cont, split_ratio=SPLIT)

    pitch_seqs_tr, cont_seqs_tr, pitch_tg_tr, cont_tg_tr = create_sequences(train_p, train_c, SEQ_LEN)
    pitch_seqs_val, cont_seqs_val, pitch_tg_val, cont_tg_val = create_sequences(val_p, val_c, SEQ_LEN)

    train_data = (pitch_seqs_tr, cont_seqs_tr, pitch_tg_tr, cont_tg_tr)
    val_data   = (pitch_seqs_val, cont_seqs_val, pitch_tg_val, cont_tg_val)

    model = MusicLSTMMultiOutput()
    print("Entrenando modelo...")
    train(model, train_data, val_data, epochs=30, device=device)

    print("Evaluando validación (todas las cabezas)...")
    metrics = evaluate_all(model, pitch_seqs_val, cont_seqs_val, pitch_tg_val, cont_tg_val,
                           step_max, duration_max, device=device)
    print(metrics)

    print("Generando secuencia...")
    seed_pitch = pitch_seqs_tr[0].to(device)
    seed_cont  = cont_seqs_tr[0].to(device)
    gp, gc = generate_sequence_multi(model, seed_pitch, seed_cont, length=200,
                                     temperature=TEMP, noise_std=NOISE_STD,
                                     step_max=step_max, duration_max=duration_max,
                                     device=device)
    save_midi(gp, gc, "output.mid")


Cargando datos...
Entrenando modelo...
Ep 1/30 | loss 2.8635 | acc 0.135 | vloss 2.5152 | vacc 0.197 | ppl 17.36 | vppl 12.20
Ep 2/30 | loss 2.3984 | acc 0.236 | vloss 2.3837 | vacc 0.260 | ppl 10.93 | vppl 10.70
Ep 3/30 | loss 2.2085 | acc 0.301 | vloss 2.2837 | vacc 0.314 | ppl 9.04 | vppl 9.70
Ep 4/30 | loss 2.0448 | acc 0.354 | vloss 2.2751 | vacc 0.324 | ppl 7.68 | vppl 9.60
Ep 5/30 | loss 1.8985 | acc 0.402 | vloss 2.2807 | vacc 0.326 | ppl 6.63 | vppl 9.66
Ep 6/30 | loss 1.7663 | acc 0.445 | vloss 2.3351 | vacc 0.336 | ppl 5.81 | vppl 10.20
Ep 7/30 | loss 1.6376 | acc 0.488 | vloss 2.3761 | vacc 0.334 | ppl 5.11 | vppl 10.64
Ep 8/30 | loss 1.5066 | acc 0.526 | vloss 2.4558 | vacc 0.328 | ppl 4.49 | vppl 11.52
Ep 9/30 | loss 1.3430 | acc 0.584 | vloss 2.5318 | vacc 0.327 | ppl 3.81 | vppl 12.43
Ep 10/30 | loss 1.2622 | acc 0.609 | vloss 2.5971 | vacc 0.324 | ppl 3.51 | vppl 13.28
Ep 11/30 | loss 1.1845 | acc 0.630 | vloss 2.7105 | vacc 0.321 | ppl 3.25 | vppl 14.86
Ep 12/30 | los