## SETUP

In [1]:
# Mounts Drive, defines paths and loads vocab + model; helpers for normalization, transposition and prediction.
try:
    from google.colab import drive
    drive.mount("/content/drive")
except Exception:
    # Best effort: any failure of the import or the mount is ignored on purpose
    # and the cell keeps running.
    pass

import os, json, math, re, ast
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F

# --- Paths and constants ---
BASE_DIR = "/content/drive/MyDrive/Colab Notebooks/TFG"
DATA_DIR = os.path.join(BASE_DIR, "Archivos preprocesamiento")
MODELS_DIR = os.path.join(BASE_DIR, "models")
MAX_LEN = 112
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load vocabulary and model config ---
with open(os.path.join(DATA_DIR, "chord_to_idx.json"), "r", encoding="utf-8") as f:
    CH2IDX = json.load(f)
with open(os.path.join(DATA_DIR, "idx_to_chord.json"), "r", encoding="utf-8") as f:
    IDX2CH = json.load(f)

# Ids of the special tokens; both must exist in the vocabulary.
PAD = CH2IDX["[PAD]"]
UNK = CH2IDX["[UNK]"]

with open(os.path.join(MODELS_DIR, "config.json"), "r", encoding="utf-8") as f:
    CFG_JSON = json.load(f)

# --- Modelo (mismos nombres que en el entrenamiento) ---
from dataclasses import dataclass
import torch.nn as nn

@dataclass
class ModelConfig:
    """Hyperparameters used to build CausalTransformer."""
    # Required: vocabulary size and ids of the special tokens.
    vocab_size: int
    pad_idx: int
    unk_idx: int
    # Optional: size of the position table and transformer dimensions.
    max_len: int = MAX_LEN
    d_model: int = 256
    n_layers: int = 4
    n_heads: int = 8
    d_ff: int = 1024
    dropout: float = 0.1

class PositionalEmbedding(nn.Module):
    """Learned position embedding indexed by 0..T-1.

    The inner table must be called ``pos_emb`` so its state-dict key
    matches the checkpoint.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return position vectors for a 2-D index tensor ``x`` of size (B, T)."""
        batch, steps = x.size()
        positions = torch.arange(steps, device=x.device)
        # Same position row repeated for every batch element.
        return self.pos_emb(positions[None, :].expand(batch, steps))

class CausalTransformer(nn.Module):
    """Causal language model over chord tokens.

    Sub-module names (``tok_emb``, ``pos_emb``, ``trf``, ``drop``,
    ``lm_head``) are kept as-is so the state-dict keys stay the same.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        width = cfg.d_model
        self.tok_emb = nn.Embedding(cfg.vocab_size, width, padding_idx=cfg.pad_idx)
        self.pos_emb = PositionalEmbedding(cfg.max_len, width)
        # Pre-norm encoder layer with GELU, operating on (batch, time, feature) tensors.
        layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=cfg.n_heads,
            dim_feedforward=cfg.d_ff,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.trf = nn.TransformerEncoder(layer, num_layers=cfg.n_layers)
        self.drop = nn.Dropout(cfg.dropout)
        self.lm_head = nn.Linear(width, cfg.vocab_size)
        self.apply(self._init_w)

    def _causal_mask(self, T, device):
        """Boolean (T, T) mask that is True strictly above the diagonal."""
        full = torch.ones(T, T, device=device, dtype=torch.bool)
        return torch.triu(full, diagonal=1)

    def _init_w(self, m):
        """Initialise Linear/Embedding weights with N(0, 0.02) and zero Linear biases."""
        if isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x, attention_mask=None):
        """Return logits of shape (B, T, vocab_size) for token ids ``x`` of shape (B, T).

        ``attention_mask`` (optional) marks padding with 0; those positions
        are passed to the encoder as padded keys.
        """
        _, steps = x.shape
        hidden = self.tok_emb(x) + self.pos_emb(x)
        hidden = self.drop(hidden)
        if attention_mask is None:
            padded_keys = None
        else:
            padded_keys = attention_mask == 0
        hidden = self.trf(
            hidden,
            mask=self._causal_mask(steps, x.device),
            src_key_padding_mask=padded_keys,
        )
        return self.lm_head(hidden)

# Instantiate the model from config.json (with defaults) and load the checkpoint.
# strict=True: every state-dict key must match the module names exactly.
CFG = ModelConfig(
    vocab_size=len(CH2IDX),
    pad_idx=PAD,
    unk_idx=UNK,
    max_len=int(CFG_JSON.get("max_len", MAX_LEN)),
    d_model=int(CFG_JSON.get("d_model", 256)),
    n_layers=int(CFG_JSON.get("n_layers", 4)),
    n_heads=int(CFG_JSON.get("n_heads", 8)),
    d_ff=int(CFG_JSON.get("d_ff", 1024)),
    dropout=float(CFG_JSON.get("dropout", 0.1)),
)

MODEL = CausalTransformer(CFG)
MODEL.to(device)
CKPT = torch.load(os.path.join(MODELS_DIR, "checkpoint_best.pt"), map_location=device)
MODEL.load_state_dict(CKPT["model_state"], strict=True)
# Inference only: disables dropout.
MODEL.eval()

# --- Helpers: normalizar notación, parsear raíz/extensión y transponer ---
# Pitch-class tables: sharp spellings in chromatic order, name -> pitch class
# (sharps plus the five flat aliases) and pitch class -> sharp name.
NOTE_SH = "C C# D D# E F F# G G# A A# B".split()
VAL = {name: pc for pc, name in enumerate(NOTE_SH)}
VAL.update(Db=1, Eb=3, Gb=6, Ab=8, Bb=10)
REV = dict(enumerate(NOTE_SH))

def norm_ch(s:str)->str:
    """Normalize a chord label to the compact notation used by the dataset.

    Rewrites ``min`` -> ``m``, ``maj`` -> ``M`` and the sharp marker ``s`` -> ``#``,
    then strips surrounding whitespace. The literal ``sus`` (suspended chords)
    is left untouched: a blanket ``s`` -> ``#`` replacement would turn
    ``Csus4`` into ``C#u#4``.

    NOTE(review): assumes the vocabulary spells suspended chords with ``sus`` —
    confirm against chord_to_idx.json.
    """
    compact = s.replace("min","m").replace("maj","M")
    # Replace "s" only outside "sus": split on it, fix the pieces, glue it back.
    pieces = [piece.replace("s","#") for piece in compact.split("sus")]
    return "sus".join(pieces).strip()

class ExtChord:
    """Chord label split into root + optional minor flag + remainder.

    Attributes:
        root: note name matched by ``[A-G][b#]?``, or None if the label does not
            start with one (e.g. special tokens).
        quality: ``"m"`` for minor chords, ``""`` otherwise (None if unparsed).
        ext: whatever follows root and quality, e.g. ``"7"``, ``"maj7"``, ``"sus4"``
            (None if unparsed).
        raw: the original label, returned by ``as_str`` when parsing failed.
    """
    __slots__=("root","quality","ext","raw")

    def __init__(self, c:str):
        self.raw = c
        m = re.match(r"^([A-G][b#]?)(.*)$", c)
        if not m:
            self.root=self.quality=self.ext=None
            return
        self.root = m.group(1)
        rem = m.group(2)
        # A leading "m" means minor, except when it is the start of "maj"
        # (e.g. "Cmaj7" is a major-seventh chord, not C minor + "aj7").
        if rem.startswith("m") and not rem.startswith("maj"):
            self.quality="m"; self.ext=rem[1:]
        else:
            self.quality=""; self.ext=rem

    def as_str(self):
        """Rebuild the label; unparseable labels are returned unchanged."""
        if self.root is None:
            return self.raw
        return f"{self.root}{self.quality}{self.ext}"

def transpose_one(c:str, interval:int)->str:
    """Shift the root of chord ``c`` down by ``interval`` semitones.

    The result is spelled with sharps (REV table). Labels whose root is not
    in VAL are returned unchanged.
    """
    parsed = ExtChord(c)
    pitch = VAL.get(parsed.root)
    if pitch is None:
        return c
    shifted_root = REV[(pitch - interval) % 12]
    return f"{shifted_root}{parsed.quality}{parsed.ext}"

def auto_interval_to_C(seq:List[str]):
    """Return the interval (in semitones) to subtract from every chord of ``seq``.

    It is the rounded arithmetic mean of the root pitch classes, modulo 12;
    chords without a recognised root are skipped. Returns 0 when no chord
    has a recognised root.
    """
    pitch_classes = []
    for chord in seq:
        pc = VAL.get(ExtChord(chord).root)
        if pc is not None:
            pitch_classes.append(pc)
    if not pitch_classes:
        return 0
    average = sum(pitch_classes) / len(pitch_classes)
    return int(round(average)) % 12

# --- Helpers: indexado y máscara ---
def encode(seq:List[str]):
    """Map chord labels to vocabulary ids, as a list of exactly MAX_LEN entries.

    Unknown labels map to UNK. Short sequences are right-padded with PAD.
    Sequences longer than MAX_LEN keep their LAST MAX_LEN chords: callers read
    the prediction at the last valid position, so the most recent context must
    survive truncation (keeping the head would ignore every newly added chord).
    """
    ids = [CH2IDX.get(c, UNK) for c in seq[-MAX_LEN:]]
    ids += [PAD] * (MAX_LEN - len(ids))
    return ids

def attn_mask(ids:List[int]):
    """Return 1 for every id that is not PAD and 0 for PAD, position by position."""
    return [int(tok != PAD) for tok in ids]

# --- Predicción top-k con auto-transposición a Do y temperatura ---
@torch.no_grad()
def predict_topk_original_key(seq_orig:List[str], k:int=10, temperature:float=1.0)->List[Tuple[str,float]]:
    """Return the top-k next-chord candidates as (chord, probability) pairs.

    The sequence is normalized, shifted by the interval from
    auto_interval_to_C before being fed to MODEL, and the predicted chords
    are shifted back by the opposite interval. ``temperature`` divides the
    logits only when it is positive.
    """
    # 1) Normalize, estimate the interval and transpose.
    normalized = [norm_ch(ch) for ch in seq_orig]
    interval = auto_interval_to_C(normalized)
    in_c = [transpose_one(ch, interval) for ch in normalized]

    # 2) Encode and build the padding mask.
    ids = encode(in_c)
    mask = attn_mask(ids)
    tokens = torch.tensor([ids], dtype=torch.long, device=device)
    attention = torch.tensor([mask], dtype=torch.long, device=device)

    # 3) Forward pass; read logits at the last non-padding position
    #    (position 0 when everything is padding).
    n_valid = max(int(sum(mask)), 1)
    logits = MODEL(tokens, attention_mask=attention)[0, n_valid-1, :]
    if temperature>0:
        logits = logits / float(temperature)
    probs = F.softmax(logits, dim=-1)
    best = torch.topk(probs, k=min(k, probs.numel()))

    # 4) Map ids back to labels and undo the transposition for display.
    result = []
    for idx, p in zip(best.indices.tolist(), best.values.tolist()):
        chord_c = IDX2CH.get(str(idx), f"<{idx}>")
        result.append((transpose_one(chord_c, -interval), float(p)))
    return result

Mounted at /content/drive




## UI Gradio

In [2]:
# Celda 2 — Gradio (layout 2x5 estable, "Secuencia construida" con |, C mayor por defecto, clave fija)
import gradio as gr, re, torch

# ---------- UI helpers ----------
# (dropdown label, root name) pairs; the label shows the solfège name in parentheses.
ROOT_DISPLAY = [
    (f"{root} ({solfege})", root)
    for root, solfege in (
        ("C", "Do"), ("C#", "Do#"), ("Db", "Reb"),
        ("D", "Re"), ("D#", "Re#"), ("Eb", "Mib"),
        ("E", "Mi"), ("F", "Fa"), ("F#", "Fa#"), ("Gb", "Solb"),
        ("G", "Sol"), ("G#", "Sol#"), ("Ab", "Lab"),
        ("A", "La"), ("A#", "La#"), ("Bb", "Sib"),
        ("B", "Si"),
    )
]
# (dropdown label, chord suffix) pairs; except for major/minor the label is the suffix.
VARIANTS = [("mayor", ""), ("menor", "m")] + [
    (suffix, suffix) for suffix in ("7", "maj7", "sus2", "sus4", "add9", "6")
]

def render_seq_bar(seq):
    """Join the chords with ' | ' for the compact sequence bar; empty input gives ''."""
    if not seq:
        return ""
    return " | ".join(seq)

def medals_topk(chords_probs):
    """Format up to ten (chord, prob) pairs as button labels, padded to ten entries.

    The first three labels get a medal prefix; probabilities are shown as
    percentages with two decimals.
    """
    prefixes = ["🥇","🥈","🥉"] + [""]*7
    labels = [
        f"{medal} {chord} ({prob*100:.2f}%)".strip()
        for medal, (chord, prob) in zip(prefixes, chords_probs[:10])
    ]
    # Pad with blanks so the caller always receives exactly ten labels.
    labels.extend([""] * (10 - len(labels)))
    return labels

def fanout_button_labels(labels10):
    """Turn a list of labels into exactly ten ``gr.update(value=...)`` objects.

    Shorter lists are padded with empty strings; longer ones are cut at ten.
    """
    padded = (labels10 + [""]*10)[:10]
    updates = []
    for text in padded:
        updates.append(gr.update(value=text))
    return updates

# ---------- Predicción con clave fija ----------
@torch.no_grad()
def predict_topk_with_interval(seq_orig, k:int=10, temperature:float=1.0, interval:int=None):
    """Top-k next-chord candidates as (chord, probability), with an optional fixed interval.

    When ``interval`` is None it is estimated with auto_interval_to_C;
    otherwise the given value is used as-is (cast to int). Relies on the
    helpers of cell 1: norm_ch, transpose_one, encode, attn_mask, MODEL, IDX2CH.
    """
    normalized = [norm_ch(ch) for ch in seq_orig]
    if interval is None:
        use_interval = auto_interval_to_C(normalized)
    else:
        use_interval = int(interval)
    in_c = [transpose_one(ch, use_interval) for ch in normalized]

    ids = encode(in_c)
    mask = attn_mask(ids)
    tokens = torch.tensor([ids], dtype=torch.long, device=device)
    attention = torch.tensor([mask], dtype=torch.long, device=device)

    # Logits at the last non-padding position (position 0 when all is padding).
    n_valid = max(int(sum(mask)), 1)
    logits = MODEL(tokens, attention_mask=attention)[0, n_valid-1, :]
    if temperature > 0:
        logits = logits / float(temperature)
    probs = torch.softmax(logits, dim=-1)
    best = torch.topk(probs, k=min(k, probs.numel()))

    # Map ids to labels and undo the transposition.
    result = []
    for idx, p in zip(best.indices.tolist(), best.values.tolist()):
        chord_c = IDX2CH.get(str(idx), f"<{idx}>")
        result.append((transpose_one(chord_c, -use_interval), float(p)))
    return result

def labels_topk(seq, temp, key_interval):
    """Return (ten button labels, raw top-10 list) for the current sequence."""
    ranked = predict_topk_with_interval(seq, k=10, temperature=temp, interval=key_interval)
    labels = medals_topk(ranked)
    return labels, ranked

# ---------- Acciones ----------
def add_by_dropdown(seq, key_interval, root_label, var_label, temp):
    """Append the chord chosen in the dropdowns and refresh the suggestions.

    Returns [new sequence, key interval, sequence bar text, 10 button updates, top list].
    """
    root_by_label = dict(ROOT_DISPLAY)
    suffix_by_label = dict(VARIANTS)
    chord = f"{root_by_label[root_label]}{suffix_by_label[var_label]}"
    seq2 = seq + [chord]
    # The key is fixed once: if none is set yet, use the root of this chord.
    if key_interval is None:
        key2 = VAL[ExtChord(chord).root]
    else:
        key2 = key_interval
    labels, top = labels_topk(seq2, temp, key2)
    return [seq2, key2, render_seq_bar(seq2), *fanout_button_labels(labels), top]

def add_from_top(seq, key_interval, top, which, temp):
    """Append suggestion number ``which`` from ``top`` and refresh the suggestions.

    If that suggestion does not exist, the sequence is left untouched and the
    buttons are recomputed (or blanked when the sequence is empty).
    Returns [sequence, key interval, sequence bar text, 10 button updates, top list].
    """
    has_choice = bool(top) and which < len(top)
    if not has_choice:
        if seq:
            labels, top2 = labels_topk(seq, temp, key_interval)
        else:
            labels, top2 = [""]*10, []
        return [seq, key_interval, render_seq_bar(seq), *fanout_button_labels(labels), top2]

    chosen = top[which][0]
    seq2 = seq + [chosen]
    # First chord added from the top-10: fix the key from its root (0 if unknown).
    if key_interval is None:
        key2 = VAL.get(ExtChord(chosen).root, 0)
    else:
        key2 = key_interval
    labels, top2 = labels_topk(seq2, temp, key2)
    return [seq2, key2, render_seq_bar(seq2), *fanout_button_labels(labels), top2]

def pop_one(seq, key_interval, temp):
    """Drop the last chord; when nothing is left, clear the key and blank the buttons.

    Returns [sequence, key interval, sequence bar text, 10 button updates, top list].
    """
    remaining = seq[:-1] if seq else []
    if not remaining:
        blanks = fanout_button_labels([""]*10)
        return [remaining, None, render_seq_bar(remaining), *blanks, []]
    labels, top = labels_topk(remaining, temp, key_interval)
    return [remaining, key_interval, render_seq_bar(remaining), *fanout_button_labels(labels), top]

def reset_all():
    """Return the initial UI state: empty sequence, no key, blank bar and buttons, empty top list."""
    blank_buttons = fanout_button_labels([""]*10)
    return [[], None, "", *blank_buttons, []]

def on_temp_change(seq, key_interval, t):
    """Recompute the suggestions for temperature ``t``; blank them when the sequence is empty.

    Returns [10 button updates, top list].
    """
    if seq:
        labels, top = labels_topk(seq, t, key_interval)
    else:
        labels, top = [""]*10, []
    return [*fanout_button_labels(labels), top]

def init_buttons():
    """Initial suggestions shown on page load: empty sequence, temperature 1.0, no fixed key.

    Returns [10 button updates, top list].
    """
    # Could be left blank instead; base suggestions are computed on purpose.
    labels, top = labels_topk([], 1.0, None)
    updates = fanout_button_labels(labels)
    return updates + [top]

# ---------- Construcción de la UI ----------
with gr.Blocks(title="API_PREDICCIÓN - Gradio") as demo:
    # CSS for a stable 2x5 grid (not 4-1-4-1)
    gr.HTML("""
    <style>
    #row_top1, #row_top2 { display: grid; grid-template-columns: repeat(5, 1fr); gap: 12px; }
    #row_top1 button, #row_top2 button { width: 100%; }
    </style>
    """)
    gr.Markdown("""
### 🧩 Modelo interactivo - Predicción de secuencias de acordes
- Selecciona la **raíz** y **variante** del primer acorde.
- Añade acordes a la secuencia utilizando los **botones de la derecha**.
- **Borrar** elimina el último; **Reset** reinicia.
""")

    with gr.Row():
        # Left: dropdowns + action buttons
        with gr.Column(scale=6):
            root_dd = gr.Dropdown([r for r,_ in ROOT_DISPLAY], value="C (Do)", label="Raíz")
            var_dd  = gr.Dropdown([v for v,_ in VARIANTS],   value="mayor",  label="Variante")
            with gr.Row():
                btn_add = gr.Button("Añadir", variant="primary")
                btn_pop = gr.Button("Borrar")
                btn_rst = gr.Button("Reset")
        # Right: top-10 buttons (2x5) + the "Secuencia construida" bar
        with gr.Column(scale=6):
            gr.Markdown("#### Añadir")
            with gr.Row(elem_id="row_top1"):
                btns_row1 = [gr.Button("", scale=1) for _ in range(5)]
            with gr.Row(elem_id="row_top2"):
                btns_row2 = [gr.Button("", scale=1) for _ in range(5)]
            btns = btns_row1 + btns_row2
            seq_bar = gr.Textbox(value="", label="Secuencia construida", interactive=False)

    # Temperature slider (below)
    with gr.Row():
        temp = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")

    # State holders
    seq_state = gr.State([])     # chords in the original key
    key_state = gr.State(None)   # fixed interval (set by the first chord); None = not set yet
    top_state = gr.State([])     # [(chord, prob), ...]

    # ---- Event wiring ----
    # Every handler returns its outputs in the same order as the `outputs` list.
    demo.load(init_buttons, inputs=None, outputs=[*btns, top_state])

    btn_add.click(add_by_dropdown,
                  inputs=[seq_state, key_state, root_dd, var_dd, temp],
                  outputs=[seq_state, key_state, seq_bar, *btns, top_state])

    btn_pop.click(lambda s,k,t: pop_one(s,k,t),
                  inputs=[seq_state, key_state, temp],
                  outputs=[seq_state, key_state, seq_bar, *btns, top_state])

    btn_rst.click(lambda: reset_all(),
                  inputs=None,
                  outputs=[seq_state, key_state, seq_bar, *btns, top_state])

    # One handler per suggestion button; `i=i` binds the button index at
    # definition time so each lambda keeps its own index.
    for i, b in enumerate(btns):
        b.click(lambda s,k,top,t,i=i: add_from_top(s, k, top, i, t),
                inputs=[seq_state, key_state, top_state, temp],
                outputs=[seq_state, key_state, seq_bar, *btns, top_state])

    temp.change(on_temp_change,
                inputs=[seq_state, key_state, temp],
                outputs=[*btns, top_state])

demo.launch(share=False)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

