
# Fine-tuning LoRA: Slides description



## 1  Instalación de dependencias

In [1]:
! pip install backoff transformers==4.49 peft
! pip install sacrebleu unbabel-comet



In [2]:
from google.colab import drive
import os

# Montar Google Drive
drive.mount('/content/drive')

# Rutas de los archivos zip en Google Drive
test_zip_path = '/content/drive/MyDrive/test.zip'
dev_zip_path = '/content/drive/MyDrive/dev.zip'
test_lst = '/content/drive/MyDrive/test.lst'
dev_lst = '/content/drive/MyDrive/dev.lst'

# Directorio de destino en Colab
colab_content_path = '/content/'

# Copiar los archivos zip a Colab
!cp "{test_zip_path}" "{colab_content_path}"
!cp "{dev_zip_path}" "{colab_content_path}"
!cp "{test_lst}" "{colab_content_path}"
!cp "{dev_lst}" "{colab_content_path}"

print(f"Archivos copiados a {colab_content_path}")
print(os.listdir(colab_content_path))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Archivos copiados a /content/
['.config', 'dev.zip', 'test.lst', 'drive', 'dev.lst', 'test.zip', 'sample_data']


In [3]:
!unzip /content/test.zip -d /content/test
!unzip /content/dev.zip -d /content/dev

Archive:  /content/test.zip
   creating: /content/test/404/
   creating: /content/test/545006/
   creating: /content/test/596001/
   creating: /content/test/605000/
   creating: /content/test/606/
  inflating: /content/test/test.lst  
   creating: /content/test/606/manual_transcription/
   creating: /content/test/606/manual_translations/
  inflating: /content/test/606/slides.pptx  
  inflating: /content/test/606/606.m4a  
   creating: /content/test/606/manual_transcription/sentence_segmented/
  inflating: /content/test/606/manual_transcription/ie606.srt  
  inflating: /content/test/606/manual_transcription/ie606.txt  
  inflating: /content/test/606/manual_transcription/sentence_segmented/ie606.srt  
   creating: /content/test/606/manual_translations/de/
   creating: /content/test/606/manual_translations/es/
   creating: /content/test/606/manual_translations/fr/
   creating: /content/test/606/manual_translations/sl/
  inflating: /content/test/606/manual_translations/fr/606.lst  
  infla

In [1]:
# load dict
import pickle
with open('/content/drive/MyDrive/images_descriptions.pkl', 'rb') as f:
    DESCRIPTIONS = pickle.load(f)
DESCRIPTIONS

{'404': ['1) Título/tema: La diapositiva se titula "Welcome!" y se trata de una sesión de aprendizaje en línea que comienza a las 18:15 horas CET.\n2) 2–3 ideas clave: La diapositiva informa que la sesión es CME acreditada y que al final de la presentación, los participantes serán dirigidos a la evaluación y prueba de opción múltiple.\n3) Elementos visuales: La diapositiva contiene un logotipo en la esquina superior derecha, un cuadro de diálogo en la parte inferior con opciones de "Yes" y "No" para unirse a la conferencia, y un hashtag "#e_ESO" en la parte inferior.',
  '1) Título/tema: La diapositiva se titula "e-sesion 404" y discute el impacto de la cirugía oncológica en los resultados.\n2) 2–3 ideas clave: La diapositiva presenta a dos expertos, Prof. Charles M. Balch y Prof. Riccardo Audisio, que discuten este tema. También menciona que el contenido de la diapositiva se extrajo de la política de e-ESO, que proporciona nuevas habilidades y conocimientos para oncólogos y otros médi

## 2  Configuración principal

In [2]:
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

import torch

model_name = 'microsoft/Phi-4-multimodal-instruct'
processor = AutoProcessor.from_pretrained(model_name,trust_remote_code = True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # flash_attention_2 eager sdpa
    _attn_implementation='sdpa',
)

generation_config = GenerationConfig.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  lambda i: encoder_checkpoint_wrapper(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### 3.1  LoRA

In [3]:

from peft import LoraConfig, get_peft_model


lora_config = LoraConfig(
    r=8,                         # rango LoRA
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj","k_proj"]
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 995,328 || all params: 5,575,455,552 || trainable%: 0.0179




### 3.2 Data

In [4]:
from pathlib import Path
import subprocess, soundfile as sf, datetime as dt

FFMPEG = "ffmpeg"          # o la ruta absoluta si no está en PATH

def m4a_to_wav(path_m4a, sr_out=16_000):
    """Convierte 1 × .m4a → .wav (mono, 16 kHz) y devuelve la ruta del WAV."""
    path_m4a = Path(path_m4a)
    wav_path = path_m4a.with_suffix(".wav")
    subprocess.run(
        [FFMPEG, "-loglevel", "error", "-y", "-i", str(path_m4a),
         "-ac", "1", "-ar", str(sr_out), str(wav_path)],
        check=True
    )
    return wav_path



def get_files(folder='test'):
  with open(f'/content/{folder}/{folder}.lst', 'r', encoding='utf-8') as f:
    files = f.read()
    files = files.splitlines()
    print(files)
    return files



In [5]:
from pathlib import Path
from typing import List, Dict, Tuple, Optional

def _ms_to_seconds(hms: str) -> float:
    m, s, _ = hms.split(":")
    return int(m)*60 + int(s)

def parse_slide_blocks(path: str | Path) -> Dict[str, Dict]:
    """
    Lee un archivo con múltiples bloques:
      <ID> <URL>
      <slide_id|None> <HH:MM:SS>
      ...
    Devuelve dict: { id: {"url": str, "timeline": [(t_sec, slide_id|None), ...]} }
    """
    p = Path(path)
    blocks: Dict[str, Dict] = {}
    cur_id = None
    for line in p.read_text().splitlines():
        line = line.strip()
        if not line:
            continue

        parts = line.split()
        # Cabecera: "<ID> <URL>"
        if len(parts) == 2 and parts[1].startswith("http"):
            cur_id = parts[0]
            blocks[cur_id] = []
            continue

        # Línea de cambio: "<slide_id|None> <HH:MM:SS>"
        if cur_id is not None and len(parts) == 2 and ":" in parts[1]:
            slide_raw, hms = parts
            slide_id = None if slide_raw == "None" else int(slide_raw)
            t = _ms_to_seconds(hms)
            blocks[cur_id].append((t, slide_id))

    # Ordena por tiempo y asegura que hay un evento en t=0
    for k,v in blocks.items():
        v.sort(key=lambda x: x[0])

    return blocks

def build_slide_intervals(timeline: List[Tuple[float, Optional[int]]]) -> List[Dict]:
    """
    Convierte una timeline en intervalos contiguos: [{"start":t0,"end":t1,"slide":id}, ...]
    El último intervalo termina en +inf (None).
    """
    out = []
    for i, (t, slide) in enumerate(timeline):
        t_next = timeline[i+1][0] if i+1 < len(timeline) else None
        out.append({"start": t, "end": t_next, "slide": slide})
    return out


In [6]:
def slide_at(time_s: float, intervals: List[Dict]) -> Optional[int]:
    """Devuelve la slide activa en 'time_s'."""
    # búsqueda lineal (puedes optimizar con bisect si quieres)
    for seg in intervals:
        if seg["end"] is None:  # último intervalo
            return seg["slide"]
        if seg["start"] <= time_s < seg["end"]:
            return seg["slide"]
    return None

from typing import List, Optional

def slides_in_interval(
    start: float,
    end: float,
    intervals: List[dict],
    *,
    min_fraction: float = 0.10,   # % mínimo del segmento que debe cubrir la slide
    min_seconds: float = 0.0,     # y/o segundos absolutos mínimos de cobertura
    slack: float = 0.25,          # margen (±s) para tolerar pequeñas imprecisiones
    sort_by_coverage: bool = True,
    top_k: Optional[int] = 1
) -> List[int]:
    """
    Devuelve las slides que realmente 'cuentan' en [start, end), filtrando
    las que cubren menos de min_fraction del segmento o menos de min_seconds.

    - 'slack' expande el segmento a [start-slack, end+slack] para evitar cortar
      por milésimas (timings manuales, FPS, etc.).
    - Si sort_by_coverage=True, ordena por mayor cobertura; si no, por aparición.
    - 'top_k' limita el nº de slides devueltas (tras el filtro).
    """
    seg_len = max(1e-6, end - start)  # evita división por cero
    seg_lo = start - slack
    seg_hi = end + slack

    cover_by_slide = {}   # slide_id -> (cover_seconds, first_order)

    order = 0
    for seg in intervals:
        slide = seg["slide"]
        if slide is None:
            continue

        s_lo = seg["start"]
        s_hi = float("inf") if seg["end"] is None else seg["end"]

        # intersección con margen
        lo = max(seg_lo, s_lo)
        hi = min(seg_hi, s_hi)
        overlap = max(0.0, hi - lo)
        if overlap <= 0:
            continue

        if slide not in cover_by_slide:
            cover_by_slide[slide] = [0.0, order]
            order += 1
        cover_by_slide[slide][0] += overlap

    # Filtro por cobertura
    out = []
    for slide, (cover_s, first_order) in cover_by_slide.items():
        frac = cover_s / seg_len
        if cover_s >= min_seconds or frac >= min_fraction:
            out.append((slide, cover_s, first_order))

    if not out:
        return []

    # Orden
    if sort_by_coverage:
        out.sort(key=lambda x: (-x[1], x[2]))   # más cobertura primero
    else:
        out.sort(key=lambda x: x[2])            # por aparición

    if top_k is not None:
        out = out[:top_k]

    return [slide for slide, _, _ in out]


In [7]:
from pathlib import Path
import datasets as ds


def build_dataset(folder="train", language="es", files=None,
                  dpi=150, target_size=(1600,900), mode="letterbox",
                  bg="white", store_slide_paths=False):
    rows = []
    files = get_files(folder) if files is None else files

    # Carga y parsea el maestro (una vez)
    blocks = parse_slide_blocks(f'{folder}.lst')

    for name in files:
        base = Path(f"/content/{folder}/{name}/manual_translations/{language}")
        with open(base / f"{name}.lst") as f:
            times = [(float(a), float(b)) for a,b in (ln.split() for ln in f.read().splitlines())]
        with open(base / f"{name}.en") as f:
            src = f.read().splitlines()
        with open(base / f"{name}.{language}") as f:
            tgt = f.read().splitlines()
        assert len(times) == len(src) == len(tgt)


        # 2) Audio (como ya hacías)
        wav_path = str(m4a_to_wav(f"/content/{folder}/{name}/{name}.m4a"))

        if str(name) in blocks:
            intervals = build_slide_intervals(blocks[str(name)])
        else:
            intervals = [{"start":0.0,"end":None,"slide":None}]
            print(f"[WARN] No hay bloque de slides para {name}")
        if str(name) in DESCRIPTIONS:
            desc = DESCRIPTIONS[str(name)]
        else:
            desc = None
            print(f"[WARN] No hay descripcion para {name}")


        # 3) Crea filas por segmento reusando las mismas referencias a slides
        for (s, e), txt_in, txt_out in zip(times, src, tgt):
            slides  = slides_in_interval(s, e, intervals)
            row = {
                "wav_path": wav_path,
                "start": s,
                "end": e,
                "source_text": txt_in,
                "target_text": txt_out,
                "slide_text": desc[slides[0]] if slides and desc else None ,

            }
            rows.append(row)

    return ds.Dataset.from_list(rows)

train_ds = build_dataset("dev", files=['500011', '624000', '609'])
val_ds   = build_dataset("dev", files=['550000','592'])
test_ds   = build_dataset("test")

['404', '596001', '605000', '606', '545006']


In [8]:
print(train_ds)
print(val_ds)
print(test_ds)

Dataset({
    features: ['wav_path', 'start', 'end', 'source_text', 'target_text', 'slide_text'],
    num_rows: 708
})
Dataset({
    features: ['wav_path', 'start', 'end', 'source_text', 'target_text', 'slide_text'],
    num_rows: 742
})
Dataset({
    features: ['wav_path', 'start', 'end', 'source_text', 'target_text', 'slide_text'],
    num_rows: 1405
})


In [None]:
def build_prompt_from_slide_texts(slide_text):
    """
    slide_texts: lista de strings (resúmenes de slide).
    """

    if slide_text:
        slide_text=slide_text.strip()

        if use_slide_context:
            content = (
                "<|audio_1|>\n"
                "You are a teaching assistant. Use the following SLIDE TEXT "
                "only as CONTEXT to disambiguate proper names, acronyms, and technical terms in the audio.\n"
                "PRIORITY: the spoken audio content.\n\n"
                "RULES:\n"
                "- Do not copy or paraphrase text from the context if it does not appear in the audio.\n"
                "- If there is a conflict, prioritize the audio.\n"
                "- Preserve numbers, units, symbols, and proper names.\n"
                "- If something is unintelligible, write [inaudible].\n"
                "[SLIDE-CONTEXT]\n"
                f"{slide_text}\n"
                "[/SLIDE-CONTEXT]\n\n"
                "Task: translate into Spanish (es-ES) only what is said in the audio.\n"
                "Output format: ONLY the translation, on a single line, with no labels or comments."
            )
        else:
            content = (
                "<|audio_1|>\n"
                "You are a teaching assistant.\n"
                "Task: translate into Spanish (es-ES) only what is said in the audio."
            )


    return processor.tokenizer.apply_chat_template(
        [{"role":"user","content": content}],
        tokenize=False, add_generation_prompt=True
    )




In [10]:
import torch, soundfile as sf
from PIL import Image

class CollatorPhi4Audio:
    def __init__(self, processor, answer_suffix="<|end|><|endoftext|>",
                 sr=16_000, ignore_idx=-100, cache=True,
                 force_image=True, dummy_size=(1600, 900), dummy_color="white"):
        self.proc = processor
        self.sr = sr
        self.pad_id = processor.tokenizer.pad_token_id
        self.suffix = answer_suffix
        self.ignore = ignore_idx
        self.cache_fd = {} if cache else None
        # imagen dummy para filas sin slide
        self.force_image = force_image
        self.dummy_image = Image.new("RGB", dummy_size, dummy_color)


    # No reabrir el WAV completo en cada segmento
    def _read_segment(self, path, start, end):
        if self.cache_fd is not None:
            if path not in self.cache_fd:
                self.cache_fd[path] = sf.SoundFile(path)
            f = self.cache_fd[path]
        else:
            f = sf.SoundFile(path)
        frames = int((end - start) * self.sr)
        f.seek(int(start * self.sr))
        audio = f.read(frames, dtype="float32")
        return audio


    def __call__(self, batch):
        prompts, answers, wavs = [], [], []

        for ex in batch:
            prompt = build_prompt_from_slide_texts(ex.get("slide_text"))
            prompts.append(prompt)
            answers.append(ex["target_text"] + self.suffix)

            wavs.append((self._read_segment(ex["wav_path"], ex["start"], ex["end"]), self.sr))


        # ---- Texto (prompt + respuesta) ---------------------------------
        tok_in  = self.proc.tokenizer(prompts,  return_tensors="pt", padding=True)
        tok_ans = self.proc.tokenizer(answers, return_tensors="pt", padding=True)

        input_ids = torch.cat([tok_in.input_ids, tok_ans.input_ids], dim=1)
        attn_mask = (input_ids != self.pad_id).long()

        labels = torch.full_like(input_ids, self.ignore)
        ans_len = tok_ans.input_ids.size(1)
        labels[:, -ans_len:] = tok_ans.input_ids

        # (Opcional) enmascarar CoT si lo usas en answers y supervise_cot=False
        # ... (puedes pegar aquí el bloque de máscara de CoT adaptado a tu formato)

        # ---- Embeddings multimodales (UNA sola llamada) -----------------
        kwargs = dict(text=prompts, audios=wavs, return_tensors="pt", padding=True)


        feats = self.proc(**kwargs)  # genera input_audio_embeds / input_image_embeds

        batch_out = {
            "input_ids":            input_ids,
            "labels":               labels,
            "attention_mask":       attn_mask,
            "input_audio_embeds":   feats["input_audio_embeds"],
            "audio_attention_mask": feats.get("audio_attention_mask"),
            "input_mode":           torch.ones(len(batch), dtype=torch.long) * 3 ,
        }
        return batch_out


In [11]:
train_dataloader = torch.utils.data.DataLoader(
    train_ds,
    shuffle=True,
    batch_size=4,
    collate_fn=CollatorPhi4Audio(processor),
    pin_memory=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_ds,
    shuffle=False,
    batch_size=4,
    collate_fn=CollatorPhi4Audio(processor),
    pin_memory=True,
)

## 3  Entrenamiento

### 3.1 Helpers

In [12]:
import math, os, gc, torch
from tqdm import tqdm
from typing import Dict, List, Tuple

def move_batch(batch, device):
    return {
        k: v.to(device, dtype=torch.bfloat16 if v.dtype == torch.float else torch.long)
        for k, v in batch.items()
    }

@torch.no_grad()
def avg_loss_on_loader(model, dataloader, device):
    model.eval()
    losses, n = 0.0, 0
    for batch in dataloader:
        batch = move_batch(batch, device)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            out = model(**batch)
        losses += out.loss.item() * batch[list(batch.keys())[0]].size(0)
        n += batch[list(batch.keys())[0]].size(0)
    model.train()
    return losses / n


### 3.2 Eval Definition

In [12]:
import soundfile as sf
from collections import defaultdict

_open_wavs = {}                  # caché: path → SoundFile

def read_segment(path, start, end, sr=16_000):
    """Lee [start,end] s de un WAV grande sin cargarlo entero."""
    if path not in _open_wavs:
        _open_wavs[path] = sf.SoundFile(path)
    f = _open_wavs[path]
    frames = int((end - start) * sr)
    f.seek(int(start * sr))
    return f.read(frames, dtype="float32")
def segments_by_audio(ds):
    groups = defaultdict(list)
    for ex in ds:
        groups[ex["wav_path"]].append(ex)
    # ordenamos cada audio por inicio temporal
    for segs in groups.values():
        segs.sort(key=lambda x: x["start"])
    return groups


In [13]:
val_groups  = segments_by_audio(val_ds)

In [14]:
from collections import defaultdict
from itertools import chain
import torch, gc
from tqdm import tqdm

device        = torch.device("cuda:0")
batch_size    = 6




def generate_output(groups):
    hyps_all, refs_all, srcs_all = [], [], []

    for wav_path, segs in tqdm(groups.items(), desc="Audios"):
        hyps, refs, srcs = [], [], []

        # dividimos en lotes (batch_size) para GPU
        for i in range(0, len(segs), batch_size):
            chunk = segs[i:i+batch_size]



            wavs   = [(read_segment(s["wav_path"], s["start"], s["end"]),16_000) for s in chunk]
            prompts= [build_prompt_from_slide_texts(s["slide_text"]) for s in chunk]


            inputs = processor(text=prompts, audios=wavs,
                               return_tensors="pt", padding=True
                               ).to(device)

            with torch.inference_mode(), torch.autocast("cuda"):
                gen_ids = model.generate(**inputs, generation_config=generation_config, max_new_tokens=1000, num_logits_to_keep=0)

            # quitamos el prompt textual
            gen_ids = gen_ids[:, inputs["input_ids"].shape[1]:]

            hyps.extend(
                processor.batch_decode(gen_ids.cpu(),
                                       skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)
            )

            # limpieza VRAM
            del inputs, gen_ids, prompts, wavs
            torch.cuda.empty_cache(); gc.collect()

            refs.extend([s["target_text"] for s in chunk])
            srcs.extend([s["source_text"] for s in chunk])

        print(hyps)
        print(refs)
        print(srcs)


        hyps_all.append(hyps)
        refs_all.append(refs)
        srcs_all.append(srcs)

    return hyps_all, refs_all, srcs_all

In [15]:
from __future__ import annotations

from itertools import chain
from typing import Callable, Dict, List, Tuple

import sacrebleu
from comet import download_model, load_from_checkpoint

# ────────────────────────────────────────────────────────────────────────────────
#  Cargamos (y cacheamos) el modelo COMET‑22 una sola vez
# ────────────────────────────────────────────────────────────────────────────────
_COMET_MODEL = None
_COMET_PATH = None
_MODEL_NAME = "Unbabel/wmt22-comet-da"


def _get_comet_model(gpus: int = 0):
    """Devuelve un modelo COMET‑22 listo para `predict` (cacheado)."""
    global _COMET_MODEL, _COMET_PATH

    if _COMET_MODEL is None:
        _COMET_PATH = download_model(_MODEL_NAME)  # se guarda en ~/.cache
        _COMET_MODEL = load_from_checkpoint(_COMET_PATH)

    # *Nota*: el parámetro `gpus` se pasa a `.predict` y **no** aquí, pero
    # exponemos el arg para quien quiera forzar CPU en la firma pública.
    return _COMET_MODEL


# ────────────────────────────────────────────────────────────────────────────────
#  Función principal
# ────────────────────────────────────────────────────────────────────────────────

def bleu_comet_by_audio(
    refs_audio: List[List[str]],
    hyps_audio: List[List[str]],
    srcs_audio: List[List[str]],
    transform: Callable[[str], str] = lambda x: x,
    comet_gpus: int = 0,
    comet_batch_size: int = 8,
) -> Tuple[Dict[str, float], List[Dict[str, float]]]:
    """Calcula BLEU y COMET‑22, **global** y **por audio**.

    Parameters
    ----------
    refs_audio, hyps_audio, srcs_audio : list[list[str]]
        Listas anidadas con el mismo nº de audios y segmentos.
    transform : callable
        Función de normalización por frase (identidad por defecto).
    comet_gpus : int
        Nº de GPUs a usar en `model.predict` (0 ⇒ CPU).
    comet_batch_size : int
        Tamaño de lote para COMET (trade‑off velocidad / memoria).

    Returns
    -------
    (global_metrics, per_audio_metrics)
        global_metrics    = {"bleu": float, "comet22": float}
        per_audio_metrics = [
            {"audio_id": i, "bleu": float, "comet22": float},
            ...
        ]
    """

    # ── Comprobaciones básicas ──────────────────────────────────────────────
    assert len(refs_audio) == len(hyps_audio) == len(srcs_audio), (
        "refs, hyps y srcs deben tener la misma longitud (nº audios)"
    )

    per_audio: List[Dict[str, float]] = []

    # ── Recorremos audio por audio ──────────────────────────────────────────
    comet_model = _get_comet_model(gpus=comet_gpus)

    for idx, (ref_seg, hyp_seg, src_seg) in enumerate(
        zip(refs_audio, hyps_audio, srcs_audio)
    ):
        assert len(ref_seg) == len(hyp_seg) == len(src_seg), (
            f"El audio {idx} contiene diferente nº de segmentos"
        )

        ref_seg = [transform(r) for r in ref_seg]
        hyp_seg = [transform(h) for h in hyp_seg]
        src_seg = [transform(s) for s in src_seg]

        # ── BLEU corpus‑level para el audio ────────────────────────────────
        bleu_score = sacrebleu.corpus_bleu(hyp_seg, [ref_seg]).score

        # ── COMET‑22 ───────────────────────────────────────────────────────
        samples = [  # una entrada por segmento
            {"src": s, "mt": h, "ref": r}
            for s, h, r in zip(src_seg, hyp_seg, ref_seg)
        ]
        comet_out = comet_model.predict(
            samples,
            batch_size=comet_batch_size,
            gpus=comet_gpus,
            progress_bar=False,
        )
        comet_score = comet_out["system_score"]  # media ya calculada

        per_audio.append(
            {
                "audio_id": idx,
                "bleu": bleu_score,
                "comet22": comet_score,
            }
        )

    # ── Métricas globales ──────────────────────────────────────────────────
    refs_all = list(chain.from_iterable(refs_audio))
    hyps_all = list(chain.from_iterable(hyps_audio))
    srcs_all = list(chain.from_iterable(srcs_audio))

    bleu_global = sacrebleu.corpus_bleu(hyps_all, [refs_all]).score

    comet_samples = [
        {"src": s, "mt": h, "ref": r}
        for s, h, r in zip(srcs_all, hyps_all, refs_all)
    ]
    comet_global = comet_model.predict(
        comet_samples,
        batch_size=comet_batch_size,
        gpus=comet_gpus,
        progress_bar=False,
    )["system_score"]

    return {"bleu": bleu_global, "comet22": comet_global}, per_audio





In [16]:
def eval_bleu_comet_fn_val(groups_val, processor, model, generation_config, device,
                           batch_size=8, transform=lambda x: x):
    # 1) Generar hipótesis
    hyps, refs, srcs = generate_output(groups_val)  # usa tu función tal cual

    # 2) Calcular métricas
    global_metrics, per_audio = bleu_comet_by_audio(
        refs_audio=refs,
        hyps_audio=hyps,
        srcs_audio=srcs,
        transform=transform,
        comet_gpus=1,        # o 0 si prefieres CPU
        comet_batch_size=8
    )
    return global_metrics, per_audio


### 3.3 Train Val

In [22]:


def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    device,
    acc_steps=2,
    max_epochs=5,
    clip_grad=1.0,
    scheduler=None,
    scheduler_type=None,
    ckpt_dir="checkpoints",
    patience=3,                       # early stopping
    eval_bleu_comet_fn=None,          # función que calcule métricas de traducción (ver sección 3)
    eval_bleu_args=None,              # dict opcional para pasar a esa función
    verbose_every=20
):
    os.makedirs(ckpt_dir, exist_ok=True)
    best_metric = -math.inf
    no_improve = 0

    model.to(device)
    model.train()
    optimizer.zero_grad(set_to_none=True)

    scaler = torch.cuda.amp.GradScaler(enabled=False)  # bfloat16 no necesita scaler; activarlo si usas fp16
    best_name=''
    for epoch in range(1, max_epochs+1):
        print(f"\nEpoch {epoch}/{max_epochs}")
        running = 0.0
        model.train()

        for step, batch in enumerate(tqdm(train_loader)):
            batch = move_batch(batch, device)

            with torch.autocast("cuda", dtype=torch.bfloat16):
                outputs = model(**batch)
                loss = outputs.loss / acc_steps

            loss.backward()

            running += loss.item() * acc_steps

            if (step + 1) % acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            if step % verbose_every == 0:
                print(f" step {step:<4} | loss = {loss.item() * acc_steps:.4f}")

        # ─── VALIDACIÓN ───────────────────────────────────────────
        model.eval()

        # Temporarily disable gradient checkpointing for evaluation
        for module in model.modules():
            if hasattr(module, 'gradient_checkpointing'):
                module.gradient_checkpointing = False

        val_loss = avg_loss_on_loader(model, val_loader, device)
        print(f"  -> val_loss: {val_loss:.4f}")

        # Métricas de traducción opcionales (BLEU/COMET)
        extra_metrics = {}
        if eval_bleu_comet_fn is not None:
            metrics_global, metrics_per_audio = eval_bleu_comet_fn(**(eval_bleu_args or {}))
            print(f"  -> BLEU: {metrics_global['bleu']:.2f} | COMET22: {metrics_global['comet22']:.3f}")
            extra_metrics = metrics_global
            current_key_metric = metrics_global["comet22"]  # decide cuál usar para “mejor modelo”
        else:
            current_key_metric = -val_loss  # si no hay métricas externas, usa la loss

        # Re-enable gradient checkpointing
        for module in model.modules():
            if hasattr(module, 'gradient_checkpointing') and eval_bleu_comet_fn is not None: # Only re-enable if it was originally True and we used the eval function
                 module.gradient_checkpointing = True


         # ---------- Scheduler step ----------
        if scheduler is not None:
            if scheduler_type == "plateau":
                scheduler.step(val_loss)               # en función de la loss
            else:
                scheduler.step()
        # ─── Checkpoint / Early stopping ─────────────────────────
        if current_key_metric > best_metric:
            best_metric = current_key_metric
            no_improve = 0
            path = os.path.join(ckpt_dir, f"best_epoch{epoch}.pt")
            torch.save({"model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                        **extra_metrics}, path)
            best_name=path
            print(f"  ✔ Nuevo mejor modelo guardado en {path}")
        else:
            no_improve += 1
            print(f"  (sin mejora {no_improve}/{patience})")
            if no_improve >= patience:
                print("  ✖ Early stopping activado.")
                break

        # limpieza VRAM
        torch.cuda.empty_cache(); gc.collect()

    return best_name

In [23]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
eval_args = dict(
    groups_val=val_groups,                # dict {wav_path: [segs...]}
    processor=processor,
    model=model,
    generation_config=generation_config,
    device=device,
    batch_size=8,
    transform=lambda x: x
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)


scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1, min_lr=1e-7)

best_name=train(model, train_dataloader, val_dataloader, optimizer, device,
      scheduler=scheduler, scheduler_type="plateau",
      acc_steps=8, max_epochs=10, patience=3,)
      # eval_bleu_comet_fn=eval_bleu_comet_fn_val,
      # eval_bleu_args=eval_args)


  scaler = torch.cuda.amp.GradScaler(enabled=False)  # bfloat16 no necesita scaler; activarlo si usas fp16



Epoch 1/10



  1%|          | 1/177 [00:02<08:09,  2.78s/it]

 step 0    | loss = 5.3796


 12%|█▏        | 21/177 [00:14<01:31,  1.70it/s]

 step 20   | loss = 2.1143


 23%|██▎       | 41/177 [00:25<01:13,  1.86it/s]

 step 40   | loss = 1.7890


 34%|███▍      | 61/177 [00:37<01:09,  1.67it/s]

 step 60   | loss = 2.6950


 46%|████▌     | 81/177 [00:48<00:53,  1.78it/s]

 step 80   | loss = 1.9491


 57%|█████▋    | 101/177 [01:00<00:45,  1.66it/s]

 step 100  | loss = 2.2675


 68%|██████▊   | 121/177 [01:10<00:29,  1.89it/s]

 step 120  | loss = 1.7925


 80%|███████▉  | 141/177 [01:22<00:17,  2.05it/s]

 step 140  | loss = 1.6450


 91%|█████████ | 161/177 [01:33<00:07,  2.02it/s]

 step 160  | loss = 1.4030


100%|██████████| 177/177 [01:41<00:00,  1.74it/s]


  -> val_loss: 1.7821
  ✔ Nuevo mejor modelo guardado en checkpoints/best_epoch1.pt

Epoch 2/10


  1%|          | 1/177 [00:00<01:39,  1.77it/s]

 step 0    | loss = 1.6187


 12%|█▏        | 21/177 [00:12<01:25,  1.83it/s]

 step 20   | loss = 2.1845


 23%|██▎       | 41/177 [00:24<01:22,  1.66it/s]

 step 40   | loss = 1.9422


 34%|███▍      | 61/177 [00:35<01:03,  1.83it/s]

 step 60   | loss = 1.3762


 46%|████▌     | 81/177 [00:46<00:49,  1.92it/s]

 step 80   | loss = 1.6691


 57%|█████▋    | 101/177 [00:58<00:42,  1.78it/s]

 step 100  | loss = 2.0240


 68%|██████▊   | 121/177 [01:09<00:30,  1.82it/s]

 step 120  | loss = 1.1218


 80%|███████▉  | 141/177 [01:20<00:20,  1.79it/s]

 step 140  | loss = 1.5779


 91%|█████████ | 161/177 [01:33<00:10,  1.51it/s]

 step 160  | loss = 2.0111


100%|██████████| 177/177 [01:42<00:00,  1.73it/s]


  -> val_loss: 1.7632
  ✔ Nuevo mejor modelo guardado en checkpoints/best_epoch2.pt

Epoch 3/10


  1%|          | 1/177 [00:00<01:48,  1.62it/s]

 step 0    | loss = 1.5931


 12%|█▏        | 21/177 [00:13<01:34,  1.64it/s]

 step 20   | loss = 1.3428


 23%|██▎       | 41/177 [00:24<01:20,  1.70it/s]

 step 40   | loss = 2.1619


 34%|███▍      | 61/177 [00:36<01:02,  1.85it/s]

 step 60   | loss = 1.3865


 46%|████▌     | 81/177 [00:47<01:02,  1.54it/s]

 step 80   | loss = 1.1457


 57%|█████▋    | 101/177 [00:58<00:45,  1.66it/s]

 step 100  | loss = 1.5167


 68%|██████▊   | 121/177 [01:10<00:32,  1.71it/s]

 step 120  | loss = 1.5925


 80%|███████▉  | 141/177 [01:21<00:20,  1.73it/s]

 step 140  | loss = 1.7889


 91%|█████████ | 161/177 [01:33<00:09,  1.71it/s]

 step 160  | loss = 1.0020


100%|██████████| 177/177 [01:41<00:00,  1.74it/s]


  -> val_loss: 1.7826
  (sin mejora 1/3)

Epoch 4/10


  1%|          | 1/177 [00:00<01:53,  1.55it/s]

 step 0    | loss = 1.0713


 12%|█▏        | 21/177 [00:12<01:24,  1.85it/s]

 step 20   | loss = 1.2359


 23%|██▎       | 41/177 [00:24<01:19,  1.72it/s]

 step 40   | loss = 1.3799


 34%|███▍      | 61/177 [00:35<01:02,  1.86it/s]

 step 60   | loss = 1.2673


 46%|████▌     | 81/177 [00:46<00:53,  1.78it/s]

 step 80   | loss = 0.9972


 57%|█████▋    | 101/177 [00:58<00:43,  1.77it/s]

 step 100  | loss = 1.1156


 68%|██████▊   | 121/177 [01:08<00:31,  1.76it/s]

 step 120  | loss = 1.2043


 80%|███████▉  | 141/177 [01:20<00:22,  1.61it/s]

 step 140  | loss = 0.9833


 91%|█████████ | 161/177 [01:32<00:10,  1.59it/s]

 step 160  | loss = 1.0034


100%|██████████| 177/177 [01:41<00:00,  1.75it/s]


  -> val_loss: 1.8383
  (sin mejora 2/3)

Epoch 5/10


  1%|          | 1/177 [00:00<01:45,  1.67it/s]

 step 0    | loss = 0.6049


 12%|█▏        | 21/177 [00:12<01:34,  1.66it/s]

 step 20   | loss = 0.8343


 23%|██▎       | 41/177 [00:23<01:13,  1.86it/s]

 step 40   | loss = 1.2153


 34%|███▍      | 61/177 [00:34<01:03,  1.83it/s]

 step 60   | loss = 1.1342


 46%|████▌     | 81/177 [00:46<00:51,  1.85it/s]

 step 80   | loss = 0.5954


 57%|█████▋    | 101/177 [00:57<00:44,  1.72it/s]

 step 100  | loss = 0.8716


 68%|██████▊   | 121/177 [01:08<00:31,  1.78it/s]

 step 120  | loss = 1.0835


 80%|███████▉  | 141/177 [01:20<00:21,  1.70it/s]

 step 140  | loss = 0.8937


 91%|█████████ | 161/177 [01:31<00:09,  1.70it/s]

 step 160  | loss = 1.2114


100%|██████████| 177/177 [01:41<00:00,  1.75it/s]


  -> val_loss: 1.9572
  (sin mejora 3/3)
  ✖ Early stopping activado.


## 4  Evaluación  (BLEU y COMET)

### Evaluacion completa

In [17]:

torch.cuda.empty_cache(); gc.collect()


# Cargar checkpoint
ckpt = torch.load("checkpoints/best_epoch2.pt", map_location=device)
# ckpt = torch.load(best_name, map_location=device)
model.load_state_dict(ckpt["model_state"], strict=False)


model.gradient_checkpointing_disable()          # importante para evitar el bug/anular GC
model.eval()
model.config.use_cache = True                   # más memoria, pero más rápido


In [18]:
# Carga mejor checkpoint si quieres

groups  = segments_by_audio(test_ds)
print(f"{len(groups)=} audios")
hyps, refs, srcs = generate_output(groups)



len(groups)=5 audios



Audios:  20%|██        | 1/5 [03:38<14:34, 218.69s/it]

['Buenos días, al mediodía y por la noche.', 'Esta es la diapositiva titulada "e-sesion 404" que discute el impacto de la cirugía oncológica en los resultados.', 'La European School of Oncology invita a participar en su 404ª sesión.', 'La diapositiva discute el impacto de la cirugía oncológica en los resultados.', 'La diapositiva se titula "e-sesion 404" y discute el impacto de la cirugía oncológica en los resultados.', 'Al final de la presentación, al cerrar la ventana de transmisión web, se dirigirá a la evaluación de CME y el examen de opción múltiple.', 'La diapositiva se titula "e-sesion 404" y discute el impacto de la cirugía oncológica en los resultados.', 'La diapositiva se titula "e-sesion 404" y discute el impacto de la cirugía oncológica en los resultados.', 'La diapositiva se titula "e-sesion 404" y discute el impacto de la cirugía oncológica en los resultados.', 'La diapositiva se titula "e-sesion 404" y discute el impacto de la cirugía oncológica en los resultados.', 'La 


Audios:  40%|████      | 2/5 [11:18<18:02, 360.73s/it]

['Hola a todos.', 'Es un placer estar aquí y compartir con ustedes mi experiencia en la evaluación de planes de tratamiento en radioterapia.', 'Soy Nuria Jornet, soy consultora médica física y cabeza del departamento de física de radioterapia en el Hospital de la Santa Creu i Sant Pau, Barcelona.', 'Los objetivos de aprendizaje de esta sesión de e-SESessions son los siguientes: 1) Proporcionar una visión general del proceso de tratamiento de la radioterapia; 2) Analizar diferentes aspectos de la "calidad del plan de tratamiento"; 3) Evaluar la calidad del plan de tratamiento; 4) Diseñar orientación radiológica para ensayos clínicos y protocolos de tratamiento internos.', 'Objetivos de aprendizaje para una sesión de e-SESessions.', 'Y también porque es muy importante si tienes que diseñar orientación radiológica para ensayos clínicos y protocolos de tratamiento internos.', 'El proceso de tratamiento de radioterapia para pacientes, que incluye la prescripción, la planificación, la imagen


Audios:  60%|██████    | 3/5 [19:54<14:22, 431.33s/it]

['La diapositiva se centra en temas relacionados con el cáncer de mama avanzado y su impacto en la vida nocturna.', 'Es un placer hablar con usted hoy.', 'Y espero que tengamos algunas preguntas interesantes que surjan también.', 'Las cosas sobre el cáncer de mama avanzado que nos preocupan por la noche.', 'De acuerdo, genial.', 'Así que nuestro tema de hoy son las cosas sobre el cáncer de mama avanzado que nos mantienen despiertos por las noches.', 'De hecho, hay muchos temas diferentes que podríamos discutir aquí, porque estaba pensando en ello mientras terminaba mis diapositivas y pensando que hay tantas cosas que nos mantienen despiertos, cuidando a nuestros pacientes.', 'Las cosas sobre el cáncer de mama avanzado que nos preocupan por la noche', 'La diapositiva se centra en el tema de "Things about advanced breast cancer that keep us up at night", presentada por Hope S Rugo, MD, quien es un profesor de medicina y director de la educación en oncología y ensayos clínicos en la Unive


Audios:  80%|████████  | 4/5 [27:48<07:28, 448.37s/it]

['Hola a todos.', 'Mi nombre es Ramón de Mello y estoy muy contento de coordinar esta sesión multidisciplinaria sobre cáncer rectal.', 'Hoy tenemos aquí expertos de todo el mundo que discutirán la multidisciplinariedad, cómo podemos abordar mejor el cáncer de recto para mejorar los resultados.', 'Y tenemos dos expertos que darán su presentación, un oncólogo médico Dr. Katia Pérez y un oncólogo radiactivo Prof. María Antonieta Gamboa Cortana.', 'La diapositiva se titula "Multidisciplinaria sesión on rectal cancer", lo que indica que el tema de la sesión es un enfoque multidisciplinario sobre el cáncer rectal.', 'Y al final, estamos abiertos a preguntas y a discutir lo que creemos que es pertinente para esta sección.', 'La sesión multidisciplinaria sobre cáncer rectal incluye a tres expertos y dos participantes en la discusión.', 'Buenas tardes y es un placer estar aquí hoy.', 'Mi nombre es Katia Roque Pérez y voy a presentar una revisión de la gestión multidisciplinaria en el tratamient

Audios: 100%|██████████| 5/5 [30:37<00:00, 367.41s/it]

['La diapositiva se centra en la importancia de la atención multidisciplinaria en el tratamiento de la cáncer de mama en mujeres jóvenes.', 'La diapositiva se centra en la importancia de la atención multidisciplinaria en el tratamiento de la cáncer de mama en mujeres jóvenes.', 'Soy oncólogo médico en Dana-Farber, donde me especializo en cáncer de mama en mujeres jóvenes en particular.', 'La importancia de la atención multidisciplinaria en el tratamiento de la cáncer de mama en mujeres jóvenes.', 'La diapositiva se centra en la importancia de la atención multidisciplinaria en el tratamiento de la cáncer de mama en mujeres jóvenes.', 'La diapositiva se centra en la importancia de la evaluación de la enfermedad de la mama en mujeres jóvenes, con un enfoque en los principios de la Comisión Forrester de 1987 para la detección de cáncer de mama.', 'La diapositiva se centra en la importancia de la evaluación de la enfermedad de la mama en mujeres jóvenes, con un enfoque en los principios de 




In [19]:
global_metrics, per_audio = bleu_comet_by_audio(refs, hyps, srcs,
                                                comet_gpus=1, comet_batch_size=8)





Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

hparams.yaml:   0%|          | 0.00/567 [00:00<?, ?B/s]

LICENSE: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

checkpoints/model.ckpt:   0%|          | 0.00/2.32G [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.8.3.post1 to v2.5.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../root/.cache/huggingface/hub/models--Unbabel--wmt22-comet-da/snapshots/2760a223ac957f30acfb18c8aa649b01cf1d75f2/checkpoints/model.ckpt`


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

/usr/local/lib/python3.12/dist-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']
INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/sta

In [20]:

print("BLEU  global :", global_metrics["bleu"])
print("COMET-22    :", global_metrics["comet22"])

BLEU  global : 20.8754686449739
COMET-22    : 0.6869584716087559


In [21]:
per_audio

[{'audio_id': 0, 'bleu': 18.224811543811374, 'comet22': 0.6792517582331699},
 {'audio_id': 1, 'bleu': 23.15338352052172, 'comet22': 0.711343354916986},
 {'audio_id': 2, 'bleu': 18.43010238218834, 'comet22': 0.6701103697426586},
 {'audio_id': 3, 'bleu': 18.370959080317302, 'comet22': 0.7034558144847999},
 {'audio_id': 4, 'bleu': 17.31098294999179, 'comet22': 0.6633871623447963}]