In [6]:
import pandas as pd
import torch
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, classification_report
from tqdm.auto import tqdm
from collections import defaultdict

# ====================
#  CONFIGURACIÓN
# ====================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

class Config:
    USE_FP16 = False
    MODEL_NAME = 'dccuchile/bert-base-spanish-wwm-cased'
    #MODEL_NAME = 'PlanTL-GOB-ES/bsc-bio-ehr-es'
    #MODEL_NAME = 'IIC/bsc-bio-ehr-es-caresA'
    #MODEL_NAME = 'PlanTL-GOB-ES/roberta-base-biomedical-clinical-es'
    #MODEL_NAME = 'IIC/bert-base-spanish-wwm-cased-ctebmsp'
    #MODEL_NAME = 'IIC/xlm-roberta-large-ehealth_kd'
    MAX_LENGTH = 512 # Máxima longitud de secuencia, definida en el BERT pre-entrenado
    TRAIN_BATCH_SIZE = 8
    VAL_BATCH_SIZE = 16
    TEST_BATCH_SIZE = 32
    EPOCHS = 500
    GRADIENT_ACCUMULATION_STEPS = 1
    WARMUP_EPOCHS = 2
    HIERARCHICAL_WEIGHTS = {'parent': 1.5, 'child': 1.0, 'grandchild': 0.5}
    LEARNING_RATE = 5e-5
    DATA_PATHS = {
        'train': 'codiesp_csvs/codiesp_D_source_train.csv',
        'test': 'codiesp_csvs/codiesp_D_source_test.csv',
        'val': 'codiesp_csvs/codiesp_D_source_validation.csv'
    }
    SAVE_TOKENIZER_PATH = 'snapshots/cie10_tokenizer'
    SAVE_PATH = 'snapshots/best_hierarchical_model'
    SAVE_STATE_PATH = 'snapshots/best_hierarchical_model_state.bin'
    SAVE_MLB_PATH = 'snapshots/best_hierarchical_model_mlb'
    THRESHOLDS = {'parent': 0.1, 'child': 0.1, 'grandchild': 0.1}
    PRETRAIN_EPOCHS = 10
    PRETRAIN_BATCH_SIZE = 1
    PRETRAIN_DATA_PATH = '../csv_import_scripts/cie10-es-diagnoses-expanded.csv'


Using device: cuda


In [7]:
#  PREPROCESAMIENTO
# ====================
def parse_code(code):
    """Divide el código en niveles jerárquicos"""
    parts = code.split('.')
    hierarchy = []
    if len(parts) >= 1:
        parent = parts[0]  # Primera categoría (ej: S62)
        hierarchy.append(parent)

        if len(parts) >= 2:
            child_part = parts[1]
            if len(child_part) >= 1:
                child = f"{parent}.{child_part[0]}"  # Segunda categoría (ej: S62.1)
                hierarchy.append(child)

                if len(child_part) > 1:
                    grandchild = f"{parent}.{child_part}"  # Tercera categoría (ej: S62.164K)
                    hierarchy.append(grandchild)
        elif len(parts) > 2:
            raise ValueError(f"Invalid code: {code}")

    return hierarchy

def build_label_matrices(dfs):
    """Construye matrices binarias para cada nivel jerárquico"""
    all_parents = set()
    all_children = set()
    all_grandchildren = set()

    for df in dfs:
        for codes in df['labels'].apply(eval):
            for code in codes:
                levels = parse_code(code)
                if len(levels) >= 1: all_parents.add(levels[0])
                if len(levels) >= 2: all_children.add(levels[1])
                if len(levels) >= 3: all_grandchildren.add(levels[2])

    mlb_parent = MultiLabelBinarizer().fit([all_parents])
    mlb_child = MultiLabelBinarizer().fit([all_children])
    mlb_grandchild = MultiLabelBinarizer().fit([all_grandchildren])
    print(f"Padres: {len(mlb_parent.classes_)} - Hijos: {len(mlb_child.classes_)} - Nietos: {len(mlb_grandchild.classes_)}")

    return mlb_parent, mlb_child, mlb_grandchild

# ====================
# Función de pérdida
# ====================
def hierarchical_loss(parent_logits, child_logits, grandchild_logits,
                      parent_labels, child_labels, grandchild_labels):

    loss_parent = torch.nn.BCEWithLogitsLoss()(parent_logits, parent_labels)
    loss_child = torch.nn.BCEWithLogitsLoss()(child_logits, child_labels)
    loss_grandchild = torch.nn.BCEWithLogitsLoss()(grandchild_logits, grandchild_labels)

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child +
            Config.HIERARCHICAL_WEIGHTS['grandchild'] * loss_grandchild)



In [8]:
#  PLT DE MÉTRICAS
# ====================

def plot_metrics():
    # Load metrics
    metrics_history = pd.read_csv('training_metrics.csv')

    plt.figure(figsize=(10, 6))
    plt.plot(metrics_history['epoch'], metrics_history['loss'], label='Loss')
    plt.plot(metrics_history['epoch'], metrics_history['f1_micro'], label='F1 Micro')
    plt.plot(metrics_history['epoch'], metrics_history['f1_macro'], label='F1 Macro')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Training Metrics Over Time')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_metrics.png')
    plt.close()


In [9]:
#  MODELO JERÁRQUICO
# ====================

class HierarchicalBERT(torch.nn.Module):
    def __init__(self, num_parents, num_children, num_grandchildren):
        super().__init__()
        try:
            self.bert = AutoModel.from_pretrained(Config.SAVE_PATH)
            print("Loaded saved BERT model")
        except:
            self.bert = AutoModel.from_pretrained(Config.MODEL_NAME)
            print("Using default BERT model")

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)
        self.grandchild_classifier = torch.nn.Linear(hidden_size + num_children, num_grandchildren)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]

        # Clasificación padre
        parent_logits = self.parent_classifier(pooled)

        # Clasificación hijo con contexto de padres
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        # Clasificación nieto con contexto de hijos
        child_probs = torch.sigmoid(child_logits)
        grandchild_input = torch.cat([pooled, child_probs], dim=1)
        grandchild_logits = self.grandchild_classifier(grandchild_input)

        return parent_logits, child_logits, grandchild_logits

In [10]:
# ====================
#  Preentrenameinto
# ====================

# Dataset especializado para pre-entrenamiento
class PretrainDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child, mlb_grandchild):
        self.tokenizer = tokenizer
        self.examples = []
        self.mlb_parent = mlb_parent
        self.mlb_child = mlb_child
        self.mlb_grandchild = mlb_grandchild

        for _, row in df.iterrows():
            code = row['code'].strip()
            desc = row['description'].strip()
            levels = parse_code(code)

            # Generar múltiples variantes textuales
            # Base variant is the description itself
            variants = []

            # Extract text inside parentheses and brackets
            parentheses_matches = re.findall(r'\((.*?)\)', desc)
            bracket_matches = re.findall(r'\[(.*?)\]', desc)

            # Get text outside parentheses and brackets
            clean_text = re.sub(r'\([^)]*\)|\[[^\]]*\]', '', desc).strip()
            if clean_text != desc:
                variants.append(clean_text)

            # Add matches from parentheses and brackets
            variants.extend(parentheses_matches)
            variants.extend(bracket_matches)

            for variant in variants:
                if len(variant) > 5:
                    self.examples.append({
                        'text': variant,
                        'code': code,
                        'levels': levels
                    })
                else:
                    print(f"Skipping short variant: {variant}")


    def __len__(self): return len(self.examples)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.examples[idx]['text'],
            max_length=Config.MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        levels = self.examples[idx]['levels']

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'parent_labels': torch.FloatTensor(self.mlb_parent.transform([[levels[0]]] if levels else [[]])[0]),
            'child_labels': torch.FloatTensor(self.mlb_child.transform([[levels[1]]] if len(levels)>1 else [[]])[0]),
            'grandchild_labels': torch.FloatTensor(self.mlb_grandchild.transform([[levels[2]]] if len(levels)>2 else [[]])[0])
        }

def pretrain():
    # Cargar datos de códigos CIE10
    df = pd.read_csv(Config.PRETRAIN_DATA_PATH)

    # Construir jerarquía de códigos
    all_parents = set()
    all_children = set()
    all_grandchildren = set()

    for code in df['code']:
        levels = parse_code(code.strip())
        if len(levels) >= 1: all_parents.add(levels[0])
        if len(levels) >= 2: all_children.add(levels[1])
        if len(levels) >= 3: all_grandchildren.add(levels[2])

    # Inicializar MLB
    mlb_parent = MultiLabelBinarizer().fit([all_parents])
    mlb_child = MultiLabelBinarizer().fit([all_children])
    mlb_grandchild = MultiLabelBinarizer().fit([all_grandchildren])
    print(f"Padres: {len(mlb_parent.classes_)} - Hijos: {len(mlb_child.classes_)} - Nietos: {len(mlb_grandchild.classes_)}")

    # Inicializar componentes
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
    dataset = PretrainDataset(df, tokenizer, mlb_parent, mlb_child, mlb_grandchild)
    dataloader = DataLoader(dataset, batch_size=Config.PRETRAIN_BATCH_SIZE, shuffle=True)

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_),
        len(mlb_grandchild.classes_)
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)

    # Bucle de pre-entrenamiento (similar al entrenamiento normal)
    for epoch in range(Config.PRETRAIN_EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Pre-train Epoch {epoch+1}")

        for batch in progress_bar:
            optimizer.zero_grad()

            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)
            grandchild_labels = batch['grandchild_labels'].to(device)

            outputs = model(**inputs)

            loss = hierarchical_loss(
                outputs[0], outputs[1], outputs[2],
                parent_labels, child_labels, grandchild_labels
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        print(f"Pre-train Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f}")

    # Guardar modelo pre-entrenado
    torch.save(model.state_dict(), Config.SAVE_PATH)
    tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
    print(f"Modelo pre-entrenado guardado en {Config.SAVE_PATH}")



pretrain()

Padres: 1914 - Hijos: 10088 - Nietos: 89244
Skipping short variant: VPH
Skipping short variant: Pinta
Skipping short variant: aguda
Skipping short variant: aguda
Skipping short variant: medio
Skipping short variant: SRV
Skipping short variant: Lepra
Skipping short variant: IDMAC
Skipping short variant: MSSA
Skipping short variant: MRSA
Skipping short variant: Goma
Skipping short variant: latum
Skipping short variant: VIH
Skipping short variant: aguda
Skipping short variant: viral
Skipping short variant: aguda
Skipping short variant: viral
Skipping short variant: aguda
Skipping short variant: viral
Skipping short variant: MSSA
Skipping short variant: MRSA
Skipping short variant: PPLO
Skipping short variant: STEC
Skipping short variant: STEC
Skipping short variant: STEC
Skipping short variant: VIH-2
Skipping short variant: piel
Skipping short variant: aguda
Skipping short variant: virus
Skipping short variant: SPH
Skipping short variant: SCPH
Skipping short variant: mayor
Skipping short 

Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using default BERT model


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.61 GiB. GPU 0 has a total capacity of 15.58 GiB of which 741.94 MiB is free. Including non-PyTorch memory, this process has 13.89 GiB memory in use. Of the allocated memory 13.43 GiB is allocated by PyTorch, and 181.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:

# ====================
#  DATASET
# ====================

class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child, mlb_grandchild):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []
        self.grandchild_labels = []

        for codes in df['labels'].apply(eval): # FIXME: Unsafe eval
            parents, children, grandchildren = set(), set(), set()
            for code in codes:
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])
                if len(levels) >= 3: grandchildren.add(levels[2])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])
            self.grandchild_labels.append(mlb_grandchild.transform([grandchildren])[0])

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            max_length=Config.MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
            'child_labels': torch.FloatTensor(self.child_labels[idx]),
            'grandchild_labels': torch.FloatTensor(self.grandchild_labels[idx])
        }

# ====================
#  ENTRENAMIENTO
# ====================
def train():
    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child, mlb_grandchild = build_label_matrices([train_df, val_df])

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child, mlb_grandchild)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child, mlb_grandchild)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_),
        len(mlb_grandchild.classes_)
    ).to(device)

    # Load best model if available
    try:
        model.load_state_dict(torch.load(Config.SAVE_STATE_PATH))
        print("Loaded previously saved best model")
    except:
        print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)
    best_f1 = 0
    for epoch in range(Config.EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(progress_bar):
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)
            grandchild_labels = batch['grandchild_labels'].to(device)

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                outputs = model(**inputs)

                # Línea corregida
                loss = hierarchical_loss(
                    outputs[0],  # parent_logits
                    outputs[1],  # child_logits
                    outputs[2],  # grandchild_logits
                    parent_labels,
                    child_labels,
                    grandchild_labels
                )

                loss = loss / Config.GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()

                if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scheduler.step()
                    optimizer.zero_grad()
                    scaler.update()

                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())

        # Validación
        val_metrics = evaluate(model, val_loader, device,
                              mlb_parent, mlb_child, mlb_grandchild)
        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")
        # Store metrics
        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro']
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f}")

        loss_metric = val_metrics['f1_macro']
        IMPROVEMENT_MARGIN = 0.001
        if loss_metric > (best_f1 + IMPROVEMENT_MARGIN) and epoch > 0:
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), Config.SAVE_PATH)
            best_f1 = loss_metric


# ====================
#  EVALUACIÓN
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child, mlb_grandchild):
    model.eval()
    parent_preds_all = []
    child_preds_all = []
    grandchild_preds_all = []
    parent_labels_all = []
    child_labels_all = []
    grandchild_labels_all = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluando"):
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            # Get labels
            parent_labels = batch['parent_labels'].numpy()
            child_labels = batch['child_labels'].numpy()
            grandchild_labels = batch['grandchild_labels'].numpy()

            parent_logits, child_logits, grandchild_logits = model(**inputs)

            # Convert to predictions
            parent_preds = (torch.sigmoid(parent_logits).cpu().numpy() > Config.THRESHOLDS['parent']).astype(int)
            child_preds = (torch.sigmoid(child_logits).cpu().numpy() > Config.THRESHOLDS['child']).astype(int)
            grandchild_preds = (torch.sigmoid(grandchild_logits).cpu().numpy() > Config.THRESHOLDS['grandchild']).astype(int)

            # Append to lists
            parent_preds_all.extend(parent_preds)
            child_preds_all.extend(child_preds)
            grandchild_preds_all.extend(grandchild_preds)
            parent_labels_all.extend(parent_labels)
            child_labels_all.extend(child_labels)
            grandchild_labels_all.extend(grandchild_labels)

    # Convert to numpy arrays
    parent_preds_all = np.array(parent_preds_all)
    child_preds_all = np.array(child_preds_all)
    grandchild_preds_all = np.array(grandchild_preds_all)
    parent_labels_all = np.array(parent_labels_all)
    child_labels_all = np.array(child_labels_all)
    grandchild_labels_all = np.array(grandchild_labels_all)

    # Print example comparison for parent level
    if len(parent_labels_all) > 0:
        print("\nExample comparison (first instance, parent level):")

        parent_true = np.array(mlb_parent.classes_)[parent_labels_all[0].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds_all[0].astype(bool)]
        common_labels = len(set(parent_true) & set(parent_pred))
        total_labels = len(set(parent_true))
        accuracy_parent = common_labels / total_labels if total_labels > 0 else 0
        print("Expected parent labels:", sorted(parent_true))
        print("Predicted parent labels:", sorted(parent_pred))

        child_true = np.array(mlb_child.classes_)[child_labels_all[0].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds_all[0].astype(bool)]
        common_labels = len(set(child_true) & set(child_pred))
        total_labels = len(set(child_true))
        accuracy_child = common_labels / total_labels if total_labels > 0 else 0
        print("Expected child labels:", sorted(child_true))
        print("Predicted child labels:", sorted(child_pred))

        grandchild_true = np.array(mlb_grandchild.classes_)[grandchild_labels_all[0].astype(bool)]
        grandchild_pred = np.array(mlb_grandchild.classes_)[grandchild_preds_all[0].astype(bool)]
        common_labels = len(set(grandchild_true) & set(grandchild_pred))
        total_labels = len(set(grandchild_true))
        accuracy_grandchild = common_labels / total_labels if total_labels > 0 else 0
        print("Expected grandchild labels:", sorted(grandchild_true))
        print("Predicted grandchild labels:", sorted(grandchild_pred))
        print(f"Percentage of correct parent labels: {accuracy_parent:.2%} | {accuracy_child:.2%} | {accuracy_grandchild:.2%}")

    # Calculate F1 scores for each level
    metrics = {
        'f1_micro_parent': f1_score(parent_labels_all, parent_preds_all, average='micro'),
        'f1_macro_parent': f1_score(parent_labels_all, parent_preds_all, average='macro'),
        'f1_micro_child': f1_score(child_labels_all, child_preds_all, average='micro'),
        'f1_macro_child': f1_score(child_labels_all, child_preds_all, average='macro'),
        'f1_micro_grandchild': f1_score(grandchild_labels_all, grandchild_preds_all, average='micro'),
        'f1_macro_grandchild': f1_score(grandchild_labels_all, grandchild_preds_all, average='macro')
    }

    # Calculate weighted average F1 scores
    metrics['f1_micro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child'] +
        Config.HIERARCHICAL_WEIGHTS['grandchild'] * metrics['f1_micro_grandchild']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    metrics['f1_macro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child'] +
        Config.HIERARCHICAL_WEIGHTS['grandchild'] * metrics['f1_macro_grandchild']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, mlb_grandchild, device):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits, grandchild_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()
    grandchild_probs = torch.sigmoid(grandchild_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > Config.THRESHOLDS['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > Config.THRESHOLDS['child']).astype(int))
    grandchild_preds = mlb_grandchild.inverse_transform((grandchild_probs > Config.THRESHOLDS['grandchild']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)
                for grandchild in grandchild_preds[0]:
                    if grandchild.startswith(child):
                        final_codes.add(grandchild)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    train()

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child, mlb_grandchild = build_label_matrices([test_df])

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_),
        len(mlb_grandchild.classes_)
    ).to(device)

    model.load_state_dict(torch.load(Config.SAVE_PATH))

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child, mlb_grandchild)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child, mlb_grandchild)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, mlb_grandchild, device)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


Padres: 920 - Hijos: 1686 - Nietos: 1035
Created new tokenizer


Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using default BERT model
Starting training from scratch




Epoch 1:   0%|          | 0/63 [00:00<?, ?it/s]

KeyboardInterrupt: 