## 0. Libraries 📚

In [None]:
import pandas as pd
import ast
from utils import read_cie10_file
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score
import random
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
import numpy as np
import json
from datasets import load_dataset
from huggingface_hub import login
from typing import List, Dict

In [None]:
# Seed Python's `random` module. The same SEED value is later passed as
# `random_state` to the stratified splitters.
SEED = 42
random.seed(SEED)

## 1. Load data 📥

In [None]:
def _parse_list_cell(cell):
    """Turn a stringified Python list into a list; anything that is not a string becomes []."""
    return ast.literal_eval(cell) if isinstance(cell, str) else []


# Read the ground-truth table and decode the two list-valued columns,
# which the CSV stores as text.
diagnoses_df = pd.read_csv("data/ground_truth_df.csv")
for list_column in ("Codigos_diagnosticos", "Diagnosticos_estandar"):
    diagnoses_df[list_column] = diagnoses_df[list_column].apply(_parse_list_cell)
diagnoses_df

In [None]:
# Load the ICD-10 (CIE-10) lookup through the project helper `read_cie10_file`.
# NOTE(review): later cells call `cie10_map.items()` and unpack (code, description)
# pairs, so a dict-like mapping is expected — confirm in `utils`.
cie10_map = read_cie10_file("data/diagnosticos_tipos.csv")
cie10_map

## 2. Pre-process and splits 🧹✂️

In [None]:
SEED = 42  # redefined here so the cell also runs on its own

# Features: one cleaned description per row, held as a NumPy array so it can be
# indexed with the integer index arrays returned by the splitter.
X = np.array(diagnoses_df["Descripcion_diagnosticos_limpio"].tolist())

# Targets: the raw per-row label lists (list[list[str]]).
codes_lists = diagnoses_df["Codigos_diagnosticos"].tolist()
std_lists = diagnoses_df["Diagnosticos_estandar"].tolist()

# One-hot encodings are only needed by the stratified splitter; the splits
# themselves keep the raw label lists.
mlb_codes = MultiLabelBinarizer()
codes_enc = mlb_codes.fit_transform(codes_lists)

mlb_std = MultiLabelBinarizer()
std_enc = mlb_std.fit_transform(std_lists)

# n_samples x (n_codes + n_std)
y_for_split = np.hstack([codes_enc, std_enc])

# Split 1: 70 % train / 30 % temporary hold-out.
# With n_splits=1 the splitter yields exactly one (train, test) index pair.
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1, test_size=0.30, random_state=SEED
)
train_idx, tmp_idx = next(msss.split(np.zeros(len(X)), y_for_split))

X_train, X_tmp = X[train_idx], X[tmp_idx]
codes_train = [codes_lists[i] for i in train_idx]
codes_tmp = [codes_lists[i] for i in tmp_idx]
std_train = [std_lists[i] for i in train_idx]
std_tmp = [std_lists[i] for i in tmp_idx]

# Split 2: halve the hold-out into 15 % validation / 15 % test.
# The indices returned here are positions inside the *_tmp collections.
msss_val = MultilabelStratifiedShuffleSplit(
    n_splits=1, test_size=0.50, random_state=SEED
)
val_idx, test_idx = next(
    msss_val.split(
        np.zeros(len(X_tmp)),
        np.hstack([codes_enc[tmp_idx], std_enc[tmp_idx]]),
    )
)

X_val, X_test = X_tmp[val_idx], X_tmp[test_idx]
codes_val = [codes_tmp[i] for i in val_idx]
codes_test = [codes_tmp[i] for i in test_idx]
std_val = [std_tmp[i] for i in val_idx]
std_test = [std_tmp[i] for i in test_idx]

# Later cells work with plain Python lists rather than arrays.
X_train, X_val, X_test = (part.tolist() for part in (X_train, X_val, X_test))

## 3. GPT format messages

In [None]:
# Render the allowed-code list as one "<code> <description>" line per entry.
codes_for_prompt = "".join(
    f"{code} {description}\n" for code, description in cie10_map.items()
)
print(codes_for_prompt)

In [None]:
def return_prompt(codes_for_prompt: str, description: str) -> str:
    """Build the user prompt stored in the fine-tuning examples.

    Args:
        codes_for_prompt: Text block listing the allowed codes; inserted verbatim.
        description: The clinical note to classify; inserted verbatim.

    Returns:
        The instruction text with both values interpolated. Every line after the
        first keeps the four-space indentation of the template.
    """
    prompt = f"""Task: read Spanish free-text clinical notes and return **at most five** ICD-10
    codes with their official Spanish descriptions. Use only codes present in the
    following list. If none apply, output [].
    Think step by step internally but DO NOT reveal your reasoning.
    Return a JSON array exactly like this:
    [
    {{"code":"FXX.X","label":"Descripción oficial 1"}},
    {{"code":"FYY.Y","label":"Descripción oficial 2"}},
    ...
    ]
    No extra keys, no text outside the JSON.

    --- ALLOWED LIST (83 codes) ---
    {codes_for_prompt}
    --- END OF ALLOWED LIST ---

    Clinical note:
    {description}
    """
    return prompt

In [None]:
# Serialise the train split as chat-format JSONL: one {"messages": [...]} object per line.
with open("train_examples.jsonl", "w", encoding="utf-8") as f:
    for description, codes, stds in zip(X_train, codes_train, std_train):
        # Build the assistant answer with json.dumps. Swapping quotes in the Python
        # repr produces invalid JSON as soon as a label contains ' or ".
        # zip() stops at the shorter list, so codes and stds are assumed to be
        # aligned one-to-one — TODO confirm upstream.
        llm_result = json.dumps(
            [{"code": code, "label": std} for code, std in zip(codes, stds)],
            ensure_ascii=False,
        )
        message = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a clinical coder specialising in ICD-10 mental-health (F00–F99)."
                },
                {
                    "role": "user",
                    "content": return_prompt(codes_for_prompt, description)
                },
                {
                    "role": "assistant",
                    "content": llm_result
                }
            ]
        }

        f.write(json.dumps(message, ensure_ascii=False) + "\n")

In [None]:
# Serialise the validation split as chat-format JSONL: one {"messages": [...]} object per line.
with open("val_examples.jsonl", "w", encoding="utf-8") as f:
    for description, codes, stds in zip(X_val, codes_val, std_val):
        # Build the assistant answer with json.dumps. Swapping quotes in the Python
        # repr produces invalid JSON as soon as a label contains ' or ".
        # zip() stops at the shorter list, so codes and stds are assumed to be
        # aligned one-to-one — TODO confirm upstream.
        llm_result = json.dumps(
            [{"code": code, "label": std} for code, std in zip(codes, stds)],
            ensure_ascii=False,
        )
        message = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a clinical coder specialising in ICD-10 mental-health (F00–F99)."
                },
                {
                    "role": "user",
                    "content": return_prompt(codes_for_prompt, description)
                },
                {
                    "role": "assistant",
                    "content": llm_result
                }
            ]
        }

        f.write(json.dumps(message, ensure_ascii=False) + "\n")
        

In [None]:
# Serialise the test split as chat-format JSONL: one {"messages": [...]} object per line.
with open("test_examples.jsonl", "w", encoding="utf-8") as f:
    for description, codes, stds in zip(X_test, codes_test, std_test):
        # Build the assistant answer with json.dumps. Swapping quotes in the Python
        # repr produces invalid JSON as soon as a label contains ' or ".
        # zip() stops at the shorter list, so codes and stds are assumed to be
        # aligned one-to-one — TODO confirm upstream.
        llm_result = json.dumps(
            [{"code": code, "label": std} for code, std in zip(codes, stds)],
            ensure_ascii=False,
        )
        message = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a clinical coder specialising in ICD-10 mental-health (F00–F99)."
                },
                {
                    "role": "user",
                    "content": return_prompt(codes_for_prompt, description)
                },
                {
                    "role": "assistant",
                    "content": llm_result
                }
            ]
        }

        f.write(json.dumps(message, ensure_ascii=False) + "\n")
        

## 4. Study max tokens

In [None]:
### Load the tokenizer (no model weights are loaded in this cell; only token counting follows)

from transformers import AutoTokenizer
import pandas as pd

# Checkpoint whose tokenizer is used for counting; alternatives are kept commented out.
model_id = "meta-llama/Llama-3.2-1B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tok = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token.
tok.pad_token = tok.eos_token

In [None]:
import json

def n_tokens(text: str) -> int:
    """Return how many token ids the module-level tokenizer `tok` produces for *text*.

    Special tokens are not added, so only the text itself is counted.
    """
    encoded = tok(text, add_special_tokens=False)
    return len(encoded["input_ids"])

def count_tokens_in_jsonl(filename):
    """Count the tokens of the first user message of every record in a JSONL file.

    Only the "user" turn is measured; system and assistant turns are not included
    in the count.

    Args:
        filename: Path to a JSONL file whose records carry a "messages" list.

    Returns:
        A list with one token count per line, in file order.

    Raises:
        StopIteration: If a record has no message with role "user".
    """
    counts = []
    with open(filename, "r", encoding="utf-8") as handle:
        for record in map(json.loads, handle):
            # Pick the first turn written by the user.
            user_turn = next(
                filter(lambda turn: turn["role"] == "user", record["messages"])
            )
            counts.append(n_tokens(user_turn["content"]))
    return counts

# Token counts of the user prompts, one list per split.
train_tokens, val_tokens, test_tokens = (
    count_tokens_in_jsonl(f"{split}_examples.jsonl")
    for split in ("train", "val", "test")
)

# Same report order and column alignment as before: Test, Train, Val.
for label, counts in (("Test:", test_tokens), ("Train:", train_tokens), ("Val:", val_tokens)):
    print(f"{label:<6} min={min(counts)}, max={max(counts)}, mean={sum(counts)/len(counts):.2f}")

In [None]:
### Share of examples whose user prompt fits within MAX_TOKENS

MAX_TOKENS = 2_048

# Fraction of prompts at or below the limit, per split.
train_coverage, val_coverage, test_coverage = (
    (np.array(split_tokens) <= MAX_TOKENS).mean()
    for split_tokens in (train_tokens, val_tokens, test_tokens)
)
for split_name, split_coverage in (
    ("Train", train_coverage),
    ("Val", val_coverage),
    ("Test", test_coverage),
):
    print(f"{split_name} coverage ≤{MAX_TOKENS}: {split_coverage*100:.2f}%")
print()

# Same measure over the three splits pooled together.
total_tokens = np.array(train_tokens + val_tokens + test_tokens)
total_coverage = (total_tokens <= MAX_TOKENS).mean()
print(f"Total coverage ≤{MAX_TOKENS}: {total_coverage*100:.2f}%")

In [None]:
import pandas as pd

### View percentiles to document X% coverage
# Summary statistics of the pooled token counts, including the 90th, 95th,
# 99th and 99.7th percentiles (shown as the cell output).
pd.Series(total_tokens).describe(percentiles=[0.90, 0.95, 0.99, 0.997])

## 5. Fine-tuning with PEFT

In [None]:
from datasets import load_dataset
# Load the chat-format JSONL files. A single data file is exposed by `datasets`
# under the "train" split name, hence split="train" for both.
train_ds, val_ds = (
    load_dataset("json", data_files=f"{split}_examples.jsonl", split="train")
    for split in ("train", "val")
)

# Sanity check: show the three turns (system, user, assistant) of the first example.
for position in range(3):
    print(train_ds[0]["messages"][position])

In [None]:
import os

from huggingface_hub import login

# Never keep an access token in the notebook: read it from the HF_TOKEN environment
# variable instead. When the variable is unset, `token` is None and `login` asks
# for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import os
import warnings

# Make torch.load use its legacy (non weights-only) mode, so that resuming a run
# can read rng_state.pth without errors.
os.environ.update(TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD="1")

# Silence only the warning that announces the variable above.
warnings.filterwarnings(
    action="ignore",
    message="Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected.*",
    category=UserWarning,
)

In [None]:
# Base checkpoint to fine-tune; the alternative is kept commented out.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
# model_id = "meta-llama/Llama-3.2-1B-Instruct"

# Directory for the training checkpoints and the final adapter.
# Keep it in sync with the checkpoint chosen above.
output_dir = "qlora_llama3B"
# output_dir = "qlora_llama1B"

In [None]:
## FINE-TUNE LLM WITH PEFT ###

import os

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# 4-bit quantisation of the base model (QLoRA-style), computing in bfloat16.
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True,
                             bnb_4bit_compute_dtype="bfloat16",
                             bnb_4bit_use_double_quant=True)


tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             quantization_config=bnb_cfg,
                                             device_map="auto")

# LoRA adapters on the attention projections only.
lora_cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj","k_proj","v_proj","o_proj"],
                      lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")

def formatting(ex):
    """Render one example's "messages" list as text with the tokenizer's chat template."""
    return tok.apply_chat_template(ex["messages"], tokenize=False)

sft_args = SFTConfig(
        output_dir=output_dir,
        per_device_train_batch_size=1, # Per-device batch size; larger values need more memory
        gradient_accumulation_steps=16, # Steps accumulated per optimizer update; emulates a larger batch without the extra memory
        num_train_epochs=4,
        learning_rate=2e-4,
        logging_steps=100,
        save_strategy="epoch",
        max_length=2048,
        # eval_strategy="epoch",
        # per_device_eval_batch_size=1,
        packing=True,
        disable_tqdm=False)

# NOTE(review): eval_dataset is passed, but eval_strategy is commented out above,
# so no evaluation appears to be scheduled — confirm this is intended.
trainer = SFTTrainer(model=model,
                     train_dataset=train_ds,
                     eval_dataset=val_ds,
                     peft_config=lora_cfg,
                     formatting_func=formatting,
                     args=sft_args)

# Resume only when a checkpoint exists. resume_from_checkpoint=True raises on a
# fresh run because output_dir holds no checkpoint yet; None starts from scratch.
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model(output_dir)

In [None]:
### MERGE AND SAVE THE PEFT FINE-TUNED MODEL ###

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base and LoRA adapter
# The base model is loaded here without the quantization config used for training.
base = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(base, output_dir)

# Merge and unload LoRA into base
merged = model.merge_and_unload()

# Save merged model and tokenizer
# Both go to "<output_dir>_merged"; the tokenizer is the one of the base checkpoint.
merged.save_pretrained(f"{output_dir}_merged")
AutoTokenizer.from_pretrained(model_id).save_pretrained(f"{output_dir}_merged")

## 6. Evaluation with vLLM deployed model

In [None]:
# Render the allowed-code list as one "<code> <description>" line per entry.
codes_for_prompt = "".join(
    f"{code} {description}\n" for code, description in cie10_map.items()
)
print(codes_for_prompt)

In [None]:
import json

def process_llm_output(response, allowed_codes=None):
    """Extract the allowed ICD-10 codes from a chat-completion response.

    The message content is parsed as JSON and flattened: every dict key, dict
    value and list element at any depth becomes a candidate string. Candidates
    found in `allowed_codes` are returned without duplicates, in first-seen order.

    Args:
        response: Object exposing `choices[0].message.content` with the model output.
        allowed_codes: Container of valid codes. Defaults to the module-level
            `cie10_map`.

    Returns:
        The list of recognised codes. It is empty when nothing matches or when
        the response cannot be read or parsed; in that case the error and the
        offending payload are printed.
    """
    if allowed_codes is None:
        allowed_codes = cie10_map

    def flatten(node):
        """Yield every scalar reachable from *node* as a string, depth first."""
        if isinstance(node, dict):
            for key, value in node.items():
                yield from flatten(key)
                yield from flatten(value)
        elif isinstance(node, (list, tuple, set)):
            for child in node:
                yield from flatten(child)
        elif node is None:
            # JSON null carries no code. It used to abort the whole prediction.
            return
        else:
            # A lone string stays whole, so it is never split into characters.
            yield node if isinstance(node, str) else str(node)

    data = None
    try:
        data = json.loads(response.choices[0].message.content)
        # dict.fromkeys removes duplicates while keeping a deterministic order.
        pred = [code for code in dict.fromkeys(flatten(data)) if code in allowed_codes]
    except Exception as e:
        # Best effort by design: one unusable answer must not stop a whole run.
        print(e)
        pred = []
        print("**************************")
        print("ERROR")
        print(data if data is not None else response)
        print("**************************")
    return pred

### Train data

In [None]:
import os, json, asyncio, nest_asyncio, httpx
from typing import List, Dict
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from sklearn.metrics import (
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.preprocessing import MultiLabelBinarizer

# Model name sent with each request; alternatives are kept commented out.
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_ID = "qlora_llama3B_merged"
# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
# MODEL_ID = "qlora_llama1B_merged"

# Maximum number of requests in flight at once.
CONCURRENCY = 8
# Predictions file for the train split; "/" in the model name is replaced by "_".
OUTPUT_FILE = f"predictions/{MODEL_ID.replace('/', '_')}_predictions_train.jsonl"

# OpenAI-compatible client pointed at a local server on port 8000.
# TLS verification is disabled on the underlying HTTP client.
client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
    http_client=httpx.AsyncClient(verify=False),
)
sema = asyncio.Semaphore(CONCURRENCY)
write_lock = asyncio.Lock()                # prevents simultaneous writes

def return_prompt(codes_for_prompt: str, description: str) -> str:
    """Build the single user message sent to the served model.

    NOTE(review): this template is not the one used to write the *_examples.jsonl
    files. Here the role sentence is part of the user text and the example object
    has a "score" key. Confirm this is intended when evaluating fine-tuned models.

    Args:
        codes_for_prompt: Text block listing the allowed codes; inserted verbatim.
        description: The clinical note to classify; inserted verbatim.

    Returns:
        The prompt text, starting with a newline and indented by four spaces.
    """
    prompt = f"""
    You are a clinical coder specialising in ICD-10 mental-health (F00–F99).
    Task: read Spanish free-text clinical notes and return **at most five** ICD-10
    codes with their official Spanish descriptions. Use only codes present in the
    following list. If none apply, output [].
    Think step by step internally but DO NOT reveal your reasoning.
    Return a JSON array exactly like this:
    [
      {{"code":"FXX.X","label":"Descripción oficial","score":0.85}}
    ]
    No extra keys, no text outside the JSON.

    --- ALLOWED LIST (83 codes) ---
    {codes_for_prompt}
    --- END OF ALLOWED LIST ---

    Clinical note:
    {description}
    """
    return prompt

def load_progress() -> Dict[int, List[str]]:
    """Load the predictions already saved in OUTPUT_FILE.

    Returns:
        A mapping from example index to its predicted codes. It is empty when the
        file does not exist. If an index appears more than once, the last line wins.
    """
    if not os.path.exists(OUTPUT_FILE):
        return {}
    with open(OUTPUT_FILE, "r", encoding="utf-8") as saved:
        records = (json.loads(raw_line) for raw_line in saved)
        return {record["idx"]: record["pred"] for record in records}

def append_record(idx: int, pred: List[str]) -> None:
    """Append one prediction to OUTPUT_FILE as a JSON line.

    The parent directory is created on demand, so a first run does not fail with
    FileNotFoundError when "predictions/" is missing.

    Args:
        idx: Index of the example within its split.
        pred: Predicted codes for that example.
    """
    os.makedirs(os.path.dirname(OUTPUT_FILE) or ".", exist_ok=True)
    line = json.dumps({"idx": idx, "pred": pred}, ensure_ascii=False)
    # One write per record; closing the file flushes it.
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")

async def classify(idx: int, description: str, codes: List[str]) -> List[str]:
    """Ask the served model to code one clinical note and persist the answer.

    Args:
        idx: Index of the example; stored alongside the prediction.
        description: Clinical note inserted into the prompt.
        codes: Gold codes of the example (not used by this function).

    Returns:
        The codes extracted from the model response.
    """
    # The semaphore caps the number of concurrent requests.
    async with sema:
        chat_messages = [
            {"role": "user", "content": return_prompt(codes_for_prompt, description)}
        ]
        completion = await client.chat.completions.create(
            model=MODEL_ID,
            messages=chat_messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )
        predicted_codes = process_llm_output(completion)

    # Appends are serialised through the lock so lines are written one at a time.
    async with write_lock:
        append_record(idx, predicted_codes)

    return predicted_codes

async def main() -> None:
    done = load_progress()
    missing_indices = [i for i in range(len(X_train)) if i not in done]

    # Ejecuta solo los que faltan
    tasks = [
        classify(i, X_train[i], codes_train[i])
        for i in missing_indices
    ]

    for _ in tqdm_asyncio.as_completed(tasks, total=len(tasks), desc="Processing"):
        await _

    # Si (re)hemos completado todo el conjunto, calcula la métrica
    done = load_progress()
    if len(done) == len(X_train):
        mlb = MultiLabelBinarizer().fit(codes_train)   # o tu mlb_codes existente
        y_true = mlb.transform(codes_train)
        # Predicciones en orden original
        y_pred = mlb.transform([done[i] for i in range(len(X_train))])

        # Metrics
        f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
        precision_micro = precision_score(y_true, y_pred, average="micro", zero_division=0)
        recall_micro = recall_score(y_true, y_pred, average="micro", zero_division=0)

        # Pretty print
        print("\n===== Train Metrics =====")
        print(f"F1        : {f1_micro:.4f}")
        print(f"Precision : {precision_micro:.4f}")
        print(f"Recall    : {recall_micro:.4f}")
    else:
        print(f"Checkpoint guardado: {len(done)}/{len(X_train)} ejemplos procesados.")

# ── Celda de Jupyter ──
nest_asyncio.apply()
await main()

### Val data

In [None]:
import os, json, asyncio, nest_asyncio, httpx
from typing import List, Dict
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from sklearn.metrics import (
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.preprocessing import MultiLabelBinarizer

# Model name sent with each request; alternatives are kept commented out.
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_ID = "qlora_llama3B_merged"
# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
# MODEL_ID = "qlora_llama1B_merged"

# Maximum number of requests in flight at once.
CONCURRENCY = 8
# Predictions file for the validation split; "/" in the model name is replaced by "_".
OUTPUT_FILE = f"predictions/{MODEL_ID.replace('/', '_')}_predictions_val.jsonl"

# OpenAI-compatible client pointed at a local server on port 8000.
# TLS verification is disabled on the underlying HTTP client.
client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
    http_client=httpx.AsyncClient(verify=False),
)
sema = asyncio.Semaphore(CONCURRENCY)
write_lock = asyncio.Lock()                # prevents simultaneous writes

def return_prompt(codes_for_prompt: str, description: str) -> str:
    """Build the single user message sent to the served model.

    NOTE(review): this template is not the one used to write the *_examples.jsonl
    files. Here the role sentence is part of the user text and the example object
    has a "score" key. Confirm this is intended when evaluating fine-tuned models.

    Args:
        codes_for_prompt: Text block listing the allowed codes; inserted verbatim.
        description: The clinical note to classify; inserted verbatim.

    Returns:
        The prompt text, starting with a newline and indented by four spaces.
    """
    prompt = f"""
    You are a clinical coder specialising in ICD-10 mental-health (F00–F99).
    Task: read Spanish free-text clinical notes and return **at most five** ICD-10
    codes with their official Spanish descriptions. Use only codes present in the
    following list. If none apply, output [].
    Think step by step internally but DO NOT reveal your reasoning.
    Return a JSON array exactly like this:
    [
      {{"code":"FXX.X","label":"Descripción oficial","score":0.85}}
    ]
    No extra keys, no text outside the JSON.

    --- ALLOWED LIST (83 codes) ---
    {codes_for_prompt}
    --- END OF ALLOWED LIST ---

    Clinical note:
    {description}
    """
    return prompt

def load_progress() -> Dict[int, List[str]]:
    """Load the predictions already saved in OUTPUT_FILE.

    Returns:
        A mapping from example index to its predicted codes. It is empty when the
        file does not exist. If an index appears more than once, the last line wins.
    """
    if not os.path.exists(OUTPUT_FILE):
        return {}
    with open(OUTPUT_FILE, "r", encoding="utf-8") as saved:
        records = (json.loads(raw_line) for raw_line in saved)
        return {record["idx"]: record["pred"] for record in records}

def append_record(idx: int, pred: List[str]) -> None:
    """Append one prediction to OUTPUT_FILE as a JSON line.

    The parent directory is created on demand, so a first run does not fail with
    FileNotFoundError when "predictions/" is missing.

    Args:
        idx: Index of the example within its split.
        pred: Predicted codes for that example.
    """
    os.makedirs(os.path.dirname(OUTPUT_FILE) or ".", exist_ok=True)
    line = json.dumps({"idx": idx, "pred": pred}, ensure_ascii=False)
    # One write per record; closing the file flushes it.
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")

async def classify(idx: int, description: str, codes: List[str]) -> List[str]:
    """Ask the served model to code one clinical note and persist the answer.

    Args:
        idx: Index of the example; stored alongside the prediction.
        description: Clinical note inserted into the prompt.
        codes: Gold codes of the example (not used by this function).

    Returns:
        The codes extracted from the model response.
    """
    # The semaphore caps the number of concurrent requests.
    async with sema:
        chat_messages = [
            {"role": "user", "content": return_prompt(codes_for_prompt, description)}
        ]
        completion = await client.chat.completions.create(
            model=MODEL_ID,
            messages=chat_messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )
        predicted_codes = process_llm_output(completion)

    # Appends are serialised through the lock so lines are written one at a time.
    async with write_lock:
        append_record(idx, predicted_codes)

    return predicted_codes

async def main() -> None:
    done = load_progress()
    missing_indices = [i for i in range(len(X_val)) if i not in done]

    # Ejecuta solo los que faltan
    tasks = [
        classify(i, X_val[i], codes_val[i])
        for i in missing_indices
    ]

    for _ in tqdm_asyncio.as_completed(tasks, total=len(tasks), desc="Processing"):
        await _

    # Si (re)hemos completado todo el conjunto, calcula la métrica
    done = load_progress()
    if len(done) == len(X_val):
        mlb = MultiLabelBinarizer().fit(codes_val)   # o tu mlb_codes existente
        y_true = mlb.transform(codes_val)
        # Predicciones en orden original
        y_pred = mlb.transform([done[i] for i in range(len(X_val))])

        # Metrics
        f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
        precision_micro = precision_score(y_true, y_pred, average="micro", zero_division=0)
        recall_micro = recall_score(y_true, y_pred, average="micro", zero_division=0)

        # Pretty print
        print("\n===== Val Metrics =====")
        print(f"F1        : {f1_micro:.4f}")
        print(f"Precision : {precision_micro:.4f}")
        print(f"Recall    : {recall_micro:.4f}")
    else:
        print(f"Checkpoint guardado: {len(done)}/{len(X_val)} ejemplos procesados.")

# ── Celda de Jupyter ──
nest_asyncio.apply()
await main()

### Test data

In [None]:
import os, json, asyncio, nest_asyncio, httpx
from typing import List, Dict
from openai import AsyncOpenAI
from tqdm.asyncio import tqdm_asyncio
from sklearn.metrics import (
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.preprocessing import MultiLabelBinarizer

# Model name sent with each request; alternatives are kept commented out.
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_ID = "qlora_llama3B_merged"
# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
# MODEL_ID = "qlora_llama1B_merged"

# Maximum number of requests in flight at once.
CONCURRENCY = 8
# Predictions file for the test split; "/" in the model name is replaced by "_".
OUTPUT_FILE = f"predictions/{MODEL_ID.replace('/', '_')}_predictions_test.jsonl"

# OpenAI-compatible client pointed at a local server on port 8000.
# TLS verification is disabled on the underlying HTTP client.
client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
    http_client=httpx.AsyncClient(verify=False),
)
sema = asyncio.Semaphore(CONCURRENCY)
write_lock = asyncio.Lock()                # prevents simultaneous writes

def return_prompt(codes_for_prompt: str, description: str) -> str:
    """Build the single user message sent to the served model.

    NOTE(review): this template is not the one used to write the *_examples.jsonl
    files. Here the role sentence is part of the user text and the example object
    has a "score" key. Confirm this is intended when evaluating fine-tuned models.

    Args:
        codes_for_prompt: Text block listing the allowed codes; inserted verbatim.
        description: The clinical note to classify; inserted verbatim.

    Returns:
        The prompt text, starting with a newline and indented by four spaces.
    """
    prompt = f"""
    You are a clinical coder specialising in ICD-10 mental-health (F00–F99).
    Task: read Spanish free-text clinical notes and return **at most five** ICD-10
    codes with their official Spanish descriptions. Use only codes present in the
    following list. If none apply, output [].
    Think step by step internally but DO NOT reveal your reasoning.
    Return a JSON array exactly like this:
    [
      {{"code":"FXX.X","label":"Descripción oficial","score":0.85}}
    ]
    No extra keys, no text outside the JSON.

    --- ALLOWED LIST (83 codes) ---
    {codes_for_prompt}
    --- END OF ALLOWED LIST ---

    Clinical note:
    {description}
    """
    return prompt

def load_progress() -> Dict[int, List[str]]:
    """Load the predictions already saved in OUTPUT_FILE.

    Returns:
        A mapping from example index to its predicted codes. It is empty when the
        file does not exist. If an index appears more than once, the last line wins.
    """
    if not os.path.exists(OUTPUT_FILE):
        return {}
    with open(OUTPUT_FILE, "r", encoding="utf-8") as saved:
        records = (json.loads(raw_line) for raw_line in saved)
        return {record["idx"]: record["pred"] for record in records}

def append_record(idx: int, pred: List[str]) -> None:
    """Append one prediction to OUTPUT_FILE as a JSON line.

    The parent directory is created on demand, so a first run does not fail with
    FileNotFoundError when "predictions/" is missing.

    Args:
        idx: Index of the example within its split.
        pred: Predicted codes for that example.
    """
    os.makedirs(os.path.dirname(OUTPUT_FILE) or ".", exist_ok=True)
    line = json.dumps({"idx": idx, "pred": pred}, ensure_ascii=False)
    # One write per record; closing the file flushes it.
    with open(OUTPUT_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")

async def classify(idx: int, description: str, codes: List[str]) -> List[str]:
    """Ask the served model to code one clinical note and persist the answer.

    Args:
        idx: Index of the example; stored alongside the prediction.
        description: Clinical note inserted into the prompt.
        codes: Gold codes of the example (not used by this function).

    Returns:
        The codes extracted from the model response.
    """
    # The semaphore caps the number of concurrent requests.
    async with sema:
        chat_messages = [
            {"role": "user", "content": return_prompt(codes_for_prompt, description)}
        ]
        completion = await client.chat.completions.create(
            model=MODEL_ID,
            messages=chat_messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )
        predicted_codes = process_llm_output(completion)

    # Appends are serialised through the lock so lines are written one at a time.
    async with write_lock:
        append_record(idx, predicted_codes)

    return predicted_codes

async def main() -> None:
    done = load_progress()
    missing_indices = [i for i in range(len(X_test)) if i not in done]

    # Ejecuta solo los que faltan
    tasks = [
        classify(i, X_test[i], codes_test[i])
        for i in missing_indices
    ]

    for _ in tqdm_asyncio.as_completed(tasks, total=len(tasks), desc="Processing"):
        await _

    # Si (re)hemos completado todo el conjunto, calcula la métrica
    done = load_progress()
    if len(done) == len(X_test):
        mlb = MultiLabelBinarizer().fit(codes_test)   # o tu mlb_codes existente
        y_true = mlb.transform(codes_test)
        # Predicciones en orden original
        y_pred = mlb.transform([done[i] for i in range(len(X_test))])

        # Metrics
        f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
        precision_micro = precision_score(y_true, y_pred, average="micro", zero_division=0)
        recall_micro = recall_score(y_true, y_pred, average="micro", zero_division=0)

        # Pretty print
        print("\n===== Test Metrics =====")
        print(f"F1        : {f1_micro:.4f}")
        print(f"Precision : {precision_micro:.4f}")
        print(f"Recall    : {recall_micro:.4f}")
    else:
        print(f"Checkpoint guardado: {len(done)}/{len(X_test)} ejemplos procesados.")

# ── Celda de Jupyter ──
nest_asyncio.apply()
await main()