In [1]:
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    default_data_collator
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, hamming_loss, precision_score, recall_score
from torch.utils.data import Dataset
import warnings
warnings.filterwarnings("ignore")
import random

In [2]:
def check_gpu():
    """Check GPU availability"""
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        print("WARNING: No GPU available")

def load_model_and_tokenizer():
    """Load PubMedBERT model and tokenizer"""
    model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=4,
        problem_type="multi_label_classification"
    )

    print(f"Model loaded: {model.num_parameters():,} parameters")
    return tokenizer, model

In [3]:
def analyze_text_lengths(df, tokenizer):
    """Analyze text lengths for optimal max_length"""
    lengths = df['text'].apply(lambda x: len(tokenizer.encode(str(x))))

    print(f"\nText length analysis:")
    print(f"  Mean: {lengths.mean():.0f} tokens")
    print(f"  95th percentile: {lengths.quantile(0.95):.0f} tokens")
    print(f"  Max: {lengths.max():.0f} tokens")

    optimal_length = min(512, int(lengths.quantile(0.95)))
    print(f"  Recommended max_length: {optimal_length}")

    return optimal_length

In [4]:
def compute_multilabel_metrics(eval_pred):
    """Compute comprehensive multi-label metrics"""
    predictions, labels = eval_pred

    # Apply sigmoid and threshold
    predictions = torch.sigmoid(torch.tensor(predictions))
    predictions = (predictions > 0.5).int().numpy()

    # Global metrics
    metrics = {
        'f1_macro': f1_score(labels, predictions, average='macro', zero_division=0),
        'f1_micro': f1_score(labels, predictions, average='micro', zero_division=0),
        'f1_weighted': f1_score(labels, predictions, average='weighted', zero_division=0),
        'subset_accuracy': accuracy_score(labels, predictions),
        'hamming_loss': hamming_loss(labels, predictions)
    }

    # Per-category metrics
    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']
    for i, cat in enumerate(categories):
        cat_labels = labels[:, i]
        cat_preds = predictions[:, i]

        metrics[f'f1_{cat}'] = f1_score(cat_labels, cat_preds, zero_division=0)
        metrics[f'precision_{cat}'] = precision_score(cat_labels, cat_preds, zero_division=0)
        metrics[f'recall_{cat}'] = recall_score(cat_labels, cat_preds, zero_division=0)

    return metrics


In [5]:
class MedicalPapersDataset(Dataset):
    """Custom dataset ensuring correct data types for multi-label classification"""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        labels = self.labels.iloc[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels, dtype=torch.float32)
        }

In [6]:
class ImprovedTrainer(Trainer):
    """Trainer con pérdida BCE ponderada para desbalance de clases"""

    def __init__(self, pos_weights=None, **kwargs):
        super().__init__(**kwargs)
        self.pos_weights = pos_weights.cuda() if pos_weights is not None and torch.cuda.is_available() else pos_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)

        # BCE loss con pesos de posición para balancear clases
        if self.pos_weights is not None:
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                outputs.logits, labels, pos_weight=self.pos_weights
            )
        else:
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                outputs.logits, labels
            )

        return (loss, outputs) if return_outputs else loss

In [7]:
def calculate_class_weights(df):
    """Calcula pesos para balancear clases automáticamente"""
    labels_array = np.array(df['labels'].tolist())
    pos_counts = labels_array.sum(axis=0)
    neg_counts = len(labels_array) - pos_counts

    # Evitar división por cero y calcular pesos
    pos_weights = neg_counts / np.maximum(pos_counts, 1)

    # Normalizar pesos para evitar valores extremos
    pos_weights = np.clip(pos_weights, 0.1, 10.0)

    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']
    print("\n📊 Pesos calculados para balancear clases:")
    for i, (cat, weight) in enumerate(zip(categories, pos_weights)):
        freq = pos_counts[i] / len(df) * 100
        print(f"  {cat:15}: peso {weight:.2f} (frecuencia: {freq:.1f}%)")

    return torch.tensor(pos_weights, dtype=torch.float32)


def create_medical_text_variation(original_text):
    """
    Data augmentation específica para textos médicos multi-label
    Basada en el análisis de 81 ejemplos reales con 3-4 categorías
    """


    # 1. SINÓNIMOS ESPECÍFICOS DE TUS DATOS REALES
    medical_synonyms = {
        # Términos de toxicidad (patrón clave en tus datos)
        'toxicity': ['adverse effects', 'side effects', 'toxic effects', 'harmful effects'],
        'nephrotoxicity': ['renal toxicity', 'kidney damage', 'renal adverse effects'],
        'hepatotoxicity': ['liver toxicity', 'hepatic damage', 'liver adverse effects'],
        'cardiotoxicity': ['cardiac toxicity', 'heart damage', 'cardiovascular toxicity'],

        # Términos de pacientes (235 menciones en tus datos)
        'patient': ['subject', 'individual', 'case', 'participant'],
        'patients': ['subjects', 'individuals', 'cases', 'participants'],

        # Términos de tratamiento (frecuentes en tus datos)
        'treatment': ['therapy', 'intervention', 'management', 'therapeutic approach'],
        'therapy': ['treatment', 'intervention', 'therapeutic regimen'],
        'drug': ['medication', 'pharmaceutical agent', 'therapeutic agent'],
        'chemotherapy': ['anticancer treatment', 'cytotoxic therapy', 'oncological treatment'],

        # Términos de órganos (patrones multi-organ)
        'cardiac': ['cardiovascular', 'heart-related', 'myocardial'],
        'renal': ['kidney-related', 'nephrological'],
        'hepatic': ['liver-related', 'hepatological'],
        'neurological': ['neurologic', 'brain-related', 'cerebral'],

        # Términos de resultado
        'failure': ['dysfunction', 'impairment', 'insufficiency'],
        'dysfunction': ['impairment', 'abnormal function', 'malfunction'],
        'syndrome': ['condition', 'disorder', 'clinical syndrome'],
        'disease': ['disorder', 'condition', 'pathology'],

        # Términos de severidad
        'severe': ['serious', 'significant', 'marked', 'pronounced'],
        'acute': ['sudden onset', 'rapid', 'abrupt'],
        'chronic': ['long-term', 'persistent', 'prolonged']
    }

    # 2. FRASES MÉDICAS ESPECÍFICAS (basadas en estructura de tus textos)
    medical_transitions = [
        'Clinical presentation revealed ',
        'Laboratory findings showed ',
        'The patient developed ',
        'Treatment resulted in ',
        'Complications included ',
        'Adverse effects comprised ',
        'Multiple organ involvement included ',
        'Systemic toxicity manifested as ',
        'Multi-organ dysfunction presented with '
    ]

    # 3. PATRONES DE CO-OCURRENCIA (de tu análisis)
    co_occurrence_patterns = {
        'cardio_renal': ['cardiac and renal complications', 'cardiovascular-renal syndrome', 'cardio-renal toxicity'],
        'neuro_cardio': ['neurological and cardiac effects', 'cerebro-cardiovascular complications'],
        'cancer_toxicity': ['chemotherapy-induced toxicity', 'anticancer drug adverse effects', 'oncological treatment complications'],
        'multi_organ': ['multi-organ toxicity', 'systemic adverse effects', 'multiple organ dysfunction']
    }

    # 4. APLICAR TRANSFORMACIONES
    words = original_text.split()
    transformed_words = []

    # Posibilidad de añadir frase médica específica (20% probabilidad)
    if random.random() < 0.2:
        transition = random.choice(medical_transitions)
        # Asegurarse de que no duplique el inicio
        if not any(t.lower().strip() in original_text.lower()[:100] for t in medical_transitions):
            words = [transition.strip()] + words

    for word in words:
        clean_word = word.lower().strip('.,!?():;[]"')

        # Reemplazar con sinónimo médico específico (25% probabilidad)
        if clean_word in medical_synonyms and random.random() < 0.25:
            synonym = random.choice(medical_synonyms[clean_word])
            # Mantener capitalización original
            if word[0].isupper():
                synonym = synonym.capitalize()
            transformed_words.append(word.replace(clean_word, synonym))
        else:
            transformed_words.append(word)

    # 5. AÑADIR PATRONES DE CO-OCURRENCIA CONTEXTUAL
    final_text = ' '.join(transformed_words)

    # Si el texto contiene múltiples menciones de órganos, añadir patrón de co-ocurrencia
    if random.random() < 0.15:  # 15% probabilidad
        organ_mentions = 0
        if any(term in final_text.lower() for term in ['cardiac', 'heart', 'cardiovascular']):
            organ_mentions += 1
        if any(term in final_text.lower() for term in ['renal', 'kidney', 'nephro']):
            organ_mentions += 1
        if any(term in final_text.lower() for term in ['hepatic', 'liver']):
            organ_mentions += 1
        if any(term in final_text.lower() for term in ['neuro', 'brain', 'cerebral']):
            organ_mentions += 1

        if organ_mentions >= 2:
            # Añadir frase que enfatice la naturaleza multi-órgano
            multi_organ_phrase = random.choice(co_occurrence_patterns['multi_organ'])
            # Insertar en una posición lógica del texto
            sentences = final_text.split('. ')
            if len(sentences) > 1:
                insert_pos = len(sentences) // 2
                sentences.insert(insert_pos, f"This case demonstrates {multi_organ_phrase}")
                final_text = '. '.join(sentences)

    return final_text



In [None]:
def augment_multilabel_with_real_patterns(df, target_samples=40):
    """
    Data augmentation específica para tus patrones multi-label reales
    """
    print("Aplicando data augmentation basada en patrones reales...")

    # Identificar muestras multi-label
    df['num_labels'] = df['labels'].apply(lambda x: sum(x))
    df['label_combo'] = df['labels'].apply(lambda x: '|'.join([str(i) for i, v in enumerate(x) if v == 1]))

    # Análisis de distribución actual
    multilabel_df = df[df['num_labels'] > 1].copy()
    combo_counts = multilabel_df['label_combo'].value_counts()

    print(f"Estado actual:")
    print(f"   - Muestras multi-label: {len(multilabel_df)}")
    print(f"   - Combinaciones únicas: {len(combo_counts)}")

    # Generar muestras sintéticas para combinaciones raras
    augmented_samples = []
    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']

    for combo, current_count in combo_counts.items():
        if current_count < target_samples:
            needed = target_samples - current_count
            combo_data = multilabel_df[multilabel_df['label_combo'] == combo]

            # Información de la combinación
            combo_indices = [int(x) for x in combo.split('|')]
            combo_names = [categories[i] for i in combo_indices]

            print(f"   {' + '.join(combo_names)}: {current_count} → {target_samples} (+{needed})")

            for _ in range(needed):
                # Seleccionar muestra base aleatoria
                base_sample = combo_data.sample(1).iloc[0]

                # Crear variación específica
                augmented_text = create_medical_text_variation(base_sample['text'])

                # Variación adicional: combinar con otro texto similar (10% probabilidad)
                if random.random() < 0.1 and len(combo_data) > 1:
                    other_sample = combo_data.sample(1).iloc[0]
                    # Tomar primera mitad del texto original y segunda mitad de otro
                    mid_point = len(augmented_text) // 2
                    other_mid = len(other_sample['text']) // 2
                    augmented_text = augmented_text[:mid_point] + " Furthermore, " + other_sample['text'][other_mid:]

                augmented_samples.append({
                    'text': augmented_text,
                    'labels': base_sample['labels']
                })

    # Crear DataFrame final
    if augmented_samples:
        print(f" Generadas {len(augmented_samples)} muestras sintéticas")
        augmented_df = pd.DataFrame(augmented_samples)
        final_df = pd.concat([df[['text', 'labels']], augmented_df], ignore_index=True)

        # Verificar distribución final
        final_df['num_labels'] = final_df['labels'].apply(lambda x: sum(x))
        final_multilabel = final_df[final_df['num_labels'] > 1]
        print(f" Resultado: {len(final_multilabel)} muestras multi-label ({len(final_multilabel)/len(final_df)*100:.1f}%)")

        return final_df
    else:
        print(" No se generaron muestras adicionales")
        return df[['text', 'labels']]

In [9]:
def create_targeted_synthetic_samples(df, focus_combinations):
    """
    Crear muestras sintéticas específicas para las combinaciones más difíciles
    """
    import random

    # Templates específicos basados en tus 81 ejemplos reales
    templates = {
        'cardio_renal_onco': [
            "{drug} treatment in {cancer_type} patients resulted in {cardiac_effect} and {renal_effect}. {outcome}",
            "Case report of {cancer_type} patient developing {cardiac_effect} and {renal_effect} following {drug} therapy. {complications}",
            "{drug}-induced {cardiac_effect} and {renal_effect} in oncological patients with {cancer_type}. {management}"
        ],
        'neuro_cardio_renal': [
            "Patient with {neuro_condition} developed {cardiac_effect} and {renal_effect} during treatment. {outcome}",
            "{drug} therapy caused {neuro_effect}, {cardiac_effect}, and {renal_effect} in this clinical case. {management}",
            "Multi-organ toxicity including {neuro_effect}, {cardiac_effect}, and {renal_effect} following {intervention}. {outcome}"
        ],
        'all_four': [
            "Complex case of {cancer_type} patient with {neuro_condition} developing {cardiac_effect}, {renal_effect}, and {hepatic_effect}. {comprehensive_management}",
            "{drug} treatment resulted in multi-system toxicity: {neuro_effect}, {cardiac_effect}, {hepatic_effect}, and {renal_effect}. {outcome}",
            "Rare presentation of {syndrome} with neurological, cardiovascular, hepatic, and renal involvement. {clinical_course}"
        ]
    }

    # Variables para los templates (extraídas de tus datos reales)
    variables = {
        'drug': ['doxorubicin', 'cisplatin', 'tacrolimus', 'amiodarone', 'lithium', 'phenytoin'],
        'cancer_type': ['leukemia', 'lymphoma', 'carcinoma', 'sarcoma', 'breast cancer', 'lung cancer'],
        'cardiac_effect': ['cardiotoxicity', 'arrhythmias', 'heart failure', 'myocardial dysfunction'],
        'renal_effect': ['nephrotoxicity', 'acute renal failure', 'renal dysfunction', 'kidney damage'],
        'hepatic_effect': ['hepatotoxicity', 'liver dysfunction', 'hepatic failure', 'liver damage'],
        'neuro_effect': ['neurotoxicity', 'encephalopathy', 'seizures', 'cognitive impairment'],
        'neuro_condition': ['stroke', 'epilepsy', 'dementia', 'Parkinson disease'],
        'outcome': ['Patient recovered with supportive care.', 'Long-term monitoring required.', 'Partial recovery achieved.'],
        'complications': ['Multiple complications required intensive management.', 'Severe adverse effects were observed.'],
        'management': ['Treatment was discontinued and supportive care initiated.', 'Dose reduction and monitoring implemented.'],
        'comprehensive_management': ['Multidisciplinary approach required for optimal outcomes.', 'Complex case requiring specialized care.'],
        'intervention': ['chemotherapy', 'immunosuppressive therapy', 'antiarrhythmic treatment'],
        'syndrome': ['multi-organ failure syndrome', 'drug-induced multi-system toxicity', 'complex clinical syndrome'],
        'clinical_course': ['Progressive deterioration observed.', 'Gradual improvement with treatment modifications.']
    }

    synthetic_samples = []

    for combo_name, template_list in templates.items():
        for _ in range(5):  # 5 muestras por template
            template = random.choice(template_list)

            # Rellenar template con variables aleatorias
            filled_template = template
            for var_name, var_options in variables.items():
                if f'{{{var_name}}}' in filled_template:
                    filled_template = filled_template.replace(f'{{{var_name}}}', random.choice(var_options))

            # Determinar labels según el template
            if combo_name == 'cardio_renal_onco':
                labels = [0, 1, 1, 1]  # cardiovascular, hepatorenal, oncological
            elif combo_name == 'neuro_cardio_renal':
                labels = [1, 1, 1, 0]  # neurological, cardiovascular, hepatorenal
            elif combo_name == 'all_four':
                labels = [1, 1, 1, 1]  # all categories

            synthetic_samples.append({
                'text': filled_template,
                'labels': labels
            })

    return synthetic_samples

In [None]:
def prepare_medical_dataset_enhanced(df, apply_augmentation=True):
    """Tu función prepare_medical_dataset pero con augmentation mejorada"""

    # Tu código original
    category_mapping = {
        'neurological': 0,
        'cardiovascular': 1,
        'hepatorenal': 2,
        'oncological': 3
    }

    def parse_medical_labels(group_str):
        labels = [0, 0, 0, 0]
        if pd.isna(group_str):
            return labels
        categories = str(group_str).split('|')
        for cat in categories:
            cat = cat.strip().lower()
            if cat in category_mapping:
                labels[category_mapping[cat]] = 1
        return labels

    # Crear texto combinado
    df['text'] = df['title'].astype(str) + " [SEP] " + df['abstract'].astype(str)
    df['labels'] = df['group'].apply(parse_medical_labels)

    # NUEVA PARTE: Augmentation mejorada
    if apply_augmentation:
        print("🔬 Aplicando data augmentation específica para multi-label...")
        df = augment_multilabel_with_real_patterns(df, target_samples=35)

        # Añadir muestras sintéticas dirigidas
        print(" Creando muestras sintéticas para combinaciones críticas...")
        synthetic_samples = create_targeted_synthetic_samples(df, ['all_four', 'cardio_renal_onco'])
        if synthetic_samples:
            synthetic_df = pd.DataFrame(synthetic_samples)
            df = pd.concat([df, synthetic_df], ignore_index=True)
            print(f"✨ Añadidas {len(synthetic_samples)} muestras sintéticas dirigidas")

    # Imprimir distribución final
    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']
    print("\nDistribución final de etiquetas:")
    for i, cat in enumerate(categories):
        count = sum(1 for labels in df['labels'] if labels[i] == 1)
        percentage = (count / len(df)) * 100
        print(f"  {cat:15}: {count:4d} samples ({percentage:5.1f}%)")

    # Análisis multi-label
    df['num_labels'] = df['labels'].apply(lambda x: sum(x))
    multilabel_count = sum(1 for num in df['num_labels'] if num > 1)
    print(f"\nMuestras multi-label: {multilabel_count} ({multilabel_count/len(df)*100:.1f}%)")

    return df[['text', 'labels']].copy()

In [11]:
def get_optimized_training_args():
    """Configuración optimizada para multi-label desbalanceado"""
    return TrainingArguments(
        output_dir='./pubmedbert-medical-v6',

        # Entrenamiento más largo y cuidadoso
        num_train_epochs=4,  # Más epochs para aprender patrones complejos
        per_device_train_batch_size=6,  # Batch más pequeño para mejor gradientes
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,

        # Evaluación más frecuente
        eval_strategy="steps",
        eval_steps=50,
        logging_steps=25,

        # Optimización
        fp16=True,
        max_grad_norm=1.0,
        learning_rate=2e-5,  # Learning rate más bajo para estabilidad
        warmup_ratio=0.1,  # Más warmup para convergencia suave
        weight_decay=0.1,  # Mayor regularización
        lr_scheduler_type="cosine_with_restarts",

        # Early stopping mejorado
        save_strategy="steps",
        save_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,

        # Regularización adicional
        label_smoothing_factor=0.05,  # Suavizado suave

        # Configuración técnica
        seed=42,
        report_to="wandb",
        dataloader_num_workers=0,
        remove_unused_columns=False
    )

In [None]:
def train_optimized_medical_classifier(csv_path, sep=";", quotechar='"'):
    """
    Función principal optimizada para resolver el problema de multi-label desbalanceado
    """

    print(" === ENTRENAMIENTO OPTIMIZADO PARA MULTI-LABEL ===")
    check_gpu()

    # 1. CARGAR DATOS
    print(f"\n Cargando datos desde {csv_path}")
    df = pd.read_csv(csv_path, sep=sep, quotechar=quotechar, quoting=1)
    print(f"✅ Cargados {len(df):,} samples")

    # 2. CARGAR MODELO
    print("\n Cargando PubMedBERT model...")
    tokenizer, model = load_model_and_tokenizer()

    # 3. PREPARAR DATOS CON MEJORAS
    print("\n Preparando dataset con optimizaciones...")
    df_prepared = prepare_medical_dataset_enhanced(df, apply_augmentation=True)

    # 5. CALCULAR PESOS DE CLASE
    print("\n Calculando pesos para balancear clases...")
    class_weights = calculate_class_weights(df_prepared)

    # 6. ANÁLISIS DE TEXTO
    optimal_max_length = 512 #analyze_text_lengths(df_prepared, tokenizer)

    # 7. DIVISIÓN ESTRATIFICADA
    print("\n Dividiendo datos con estratificación...")

    df_prepared['label_string'] = df_prepared['labels'].apply(str)
    train_df, val_df = train_test_split(
        df_prepared,
        test_size=0.2,
        stratify=df_prepared['label_string'],
        random_state=42
    )
    print(f"   Train: {len(train_df):,} samples")
    print(f"   Validation: {len(val_df):,} samples")

    # Verificar distribución en validation
    val_multilabel = sum(1 for labels in val_df['labels'] if sum(labels) > 1)
    print(f"   Multi-label en validation: {val_multilabel} ({val_multilabel/len(val_df)*100:.1f}%)")

    # 8. CREAR DATASETS
    print("\n Creando datasets optimizados...")
    train_dataset = MedicalPapersDataset(
        train_df['text'], train_df['labels'], tokenizer, optimal_max_length
    )
    val_dataset = MedicalPapersDataset(
        val_df['text'], val_df['labels'], tokenizer, optimal_max_length
    )

    # Verificar formato de datos
    sample = train_dataset[0]
    assert sample['labels'].dtype == torch.float32, "Labels deben ser float32"
    print(" Formato de datos verificado")

    # 9. CONFIGURAR ENTRENAMIENTO
    training_args = get_optimized_training_args()

    # 10. CREAR TRAINER OPTIMIZADO
    print("\n Configurando trainer optimizado...")
    trainer = ImprovedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_multilabel_metrics,
        pos_weights=class_weights
    )

    # 11. ENTRENAR
    print(f"\n Iniciando entrenamiento optimizado...")
    print(f"    Configuración:")
    print(f"      - Epochs: {training_args.num_train_epochs}")
    print(f"      - Batch efectivo: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    print(f"      - Learning rate: {training_args.learning_rate}")
    print(f"      - Weight decay: {training_args.weight_decay}")
    print(f"      - Max length: {optimal_max_length}")
    print(f"      - Samples totales: {len(train_df):,}")

    train_result = trainer.train()

    # 12. EVALUACIÓN FINAL
    print("\n Evaluación final...")
    final_metrics = trainer.evaluate()

    # 13. MOSTRAR RESULTADOS
    print(f"\n ¡ENTRENAMIENTO COMPLETADO!")
    print(f"\n Métricas principales:")
    key_metrics = {
        'eval_f1_macro': 'F1 Macro',
        'eval_f1_micro': 'F1 Micro',
        'eval_f1_weighted': 'F1 Weighted',
        'eval_subset_accuracy': 'Subset Accuracy',
        'eval_hamming_loss': 'Hamming Loss'
    }

    for metric_key, metric_name in key_metrics.items():
        if metric_key in final_metrics:
            value = final_metrics[metric_key]
            print(f"   {metric_name:15}: {value:.4f}")

    print(f"\n F1 Score por categoría:")
    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']
    for cat in categories:
        f1_key = f'eval_f1_{cat}'
        if f1_key in final_metrics:
            print(f"   {cat:15}: {final_metrics[f1_key]:.4f}")

    # 14. GUARDAR MODELO
    model_path = "./pubmedbert-medical-v6"
    trainer.save_model(model_path)
    tokenizer.save_pretrained(model_path)
    print(f"\n Modelo guardado en: {model_path}")

    return trainer, final_metrics


In [None]:
if __name__ == "__main__":

    print(" Iniciando pipeline completo de optimización...")

    # 1. Entrenar modelo optimizado
    csv_file = "/content/challenge_data-18-ago.csv"
    trainer, metrics = train_optimized_medical_classifier(csv_file)



🎬 Iniciando pipeline completo de optimización...
🚀 === ENTRENAMIENTO OPTIMIZADO PARA MULTI-LABEL ===
GPU: Tesla T4
VRAM: 14.7 GB

📂 Cargando datos desde /content/challenge_data-18-ago.csv
✅ Cargados 3,565 samples

🧠 Cargando PubMedBERT model...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded: 109,485,316 parameters

⚙️ Preparando dataset con optimizaciones...
🔬 Aplicando data augmentation específica para multi-label...
🚀 Aplicando data augmentation basada en patrones reales...
📊 Estado actual:
   - Muestras multi-label: 1092
   - Combinaciones únicas: 11
   🔧 neurological + cardiovascular + hepatorenal: 28 → 35 (+7)
   🔧 neurological + hepatorenal + oncological: 26 → 35 (+9)
   🔧 neurological + cardiovascular + oncological: 13 → 35 (+22)
   🔧 cardiovascular + hepatorenal + oncological: 7 → 35 (+28)
   🔧 neurological + cardiovascular + hepatorenal + oncological: 7 → 35 (+28)
✨ Generadas 94 muestras sintéticas
📈 Resultado: 1186 muestras multi-label (32.4%)
🎯 Creando muestras sintéticas para combinaciones críticas...
✨ Añadidas 15 muestras sintéticas dirigidas

Distribución final de etiquetas:
  neurological   : 1861 samples ( 50.7%)
  cardiovascular : 1368 samples ( 37.2%)
  hepatorenal    : 1178 samples ( 32.1%)
  oncological    :  698 samples ( 19.0%)

Muestra

[34m[1mwandb[0m: Currently logged in as: [33mdzience[0m ([33mdzience-nousgraph[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,F1 Macro,F1 Micro,F1 Weighted,Subset Accuracy,Hamming Loss,F1 Neurological,Precision Neurological,Recall Neurological,F1 Cardiovascular,Precision Cardiovascular,Recall Cardiovascular,F1 Hepatorenal,Precision Hepatorenal,Recall Hepatorenal,F1 Oncological,Precision Oncological,Recall Oncological,Runtime,Samples Per Second,Steps Per Second
50,0.8805,0.854814,0.510864,0.546579,0.554108,0.170068,0.37415,0.697674,0.673317,0.723861,0.583867,0.449704,0.832117,0.412776,0.488372,0.357447,0.349138,0.249231,0.582734,6.8834,106.779,26.731
100,0.5686,0.426395,0.844038,0.849875,0.850153,0.653061,0.102381,0.859504,0.883853,0.836461,0.837782,0.957746,0.744526,0.880165,0.855422,0.906383,0.798701,0.727811,0.884892,6.7993,108.1,27.062
150,0.3391,0.293226,0.904844,0.901508,0.901366,0.772789,0.066667,0.882022,0.926254,0.841823,0.908088,0.914815,0.90146,0.918919,0.976077,0.868085,0.910345,0.874172,0.94964,6.9013,106.502,26.662
200,0.2675,0.23832,0.920477,0.917753,0.917431,0.802721,0.055782,0.885755,0.934524,0.841823,0.946396,0.958801,0.934307,0.936264,0.968182,0.906383,0.913495,0.88,0.94964,6.8225,107.731,26.969
250,0.2265,0.213687,0.931392,0.923853,0.922817,0.819048,0.051361,0.876437,0.944272,0.817694,0.955556,0.969925,0.941606,0.940171,0.944206,0.93617,0.953405,0.95,0.956835,6.8267,107.665,26.953
300,0.2105,0.19417,0.936167,0.928393,0.92766,0.829932,0.048639,0.88826,0.94012,0.841823,0.952206,0.959259,0.945255,0.939914,0.948052,0.931915,0.964286,0.957447,0.971223,6.8344,107.544,26.923
350,0.1716,0.182556,0.938091,0.930723,0.929957,0.834014,0.046939,0.892045,0.94864,0.841823,0.952555,0.952555,0.952555,0.943478,0.964444,0.923404,0.964286,0.957447,0.971223,6.7364,109.109,27.314
400,0.1363,0.176267,0.943921,0.935936,0.935239,0.844898,0.043537,0.894217,0.943452,0.849866,0.959707,0.963235,0.956204,0.950538,0.96087,0.940426,0.971223,0.971223,0.971223,6.7551,108.806,27.239
450,0.1892,0.174229,0.943706,0.936532,0.935892,0.846259,0.043197,0.898592,0.946588,0.855228,0.957952,0.959707,0.956204,0.950538,0.96087,0.940426,0.967742,0.964286,0.971223,6.7286,109.234,27.346



📊 Evaluación final...



🎉 ¡ENTRENAMIENTO COMPLETADO!

📈 Métricas principales:
   F1 Macro       : 0.9439
   F1 Micro       : 0.9359
   F1 Weighted    : 0.9352
   Subset Accuracy: 0.8449
   Hamming Loss   : 0.0435

🏷️ F1 Score por categoría:
   neurological   : 0.8942
   cardiovascular : 0.9597
   hepatorenal    : 0.9505
   oncological    : 0.9712

💾 Modelo guardado en: ./pubmedbert-medical-v6


# TESTING


In [22]:
def predict_medical_categories(text, model_path="./pubmedbert-medical-v6", threshold=0.5):
    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=512 # Same as training
    )

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.sigmoid(outputs.logits)[0]

    results = []
    for i, (category, prob) in enumerate(zip(categories, predictions)):
        results.append({
            'category': category,
            'probability': prob.item(),
            'predicted': prob.item() > threshold
        })

    return results

# Test it
sample_text = "P53 inhibition exacerbates late-stage anthracycline cardiotoxicity. AIMS: Doxorubicin (DOX) is an effective anti-cancer therapeutic, but is associated with both acute and late-stage cardiotoxicity. Children are particularly sensitive to DOX-induced heart failure. Here, the impact of p53 inhibition on acute vs. late-stage DOX cardiotoxicity was examined in a juvenile model. METHODS AND RESULTS: Two-week-old MHC-CB7 mice (which express dominant-interfering p53 in cardiomyocytes) and their non-transgenic (NON-TXG) littermates received weekly DOX injections for 5 weeks (25 mg/kg cumulative dose). One week after the last DOX treatment (acute stage), MHC-CB7 mice exhibited improved cardiac function and lower levels of cardiomyocyte apoptosis when compared with the NON-TXG mice. Surprisingly, by 13 weeks following the last DOX treatment (late stage), MHC-CB7 exhibited a progressive decrease in cardiac function and higher rates of cardiomyocyte apoptosis when compared with NON-TXG mice. p53 inhibition blocked transient DOX-induced STAT3 activation in MHC-CB7 mice, which was associated with enhanced induction of the DNA repair proteins Ku70 and Ku80. Mice with cardiomyocyte-restricted deletion of STAT3 exhibited worse cardiac function, higher levels of cardiomyocyte apoptosis, and a greater induction of Ku70 and Ku80 in response to DOX treatment during the acute stage when compared with control animals. CONCLUSION: These data support a model wherein a p53-dependent cardioprotective pathway, mediated via STAT3 activation, mitigates DOX-induced myocardial stress during drug delivery. Furthermore, these data suggest an explanation as to how p53 inhibition can result in cardioprotection during drug treatment and, paradoxically, enhanced cardiotoxicity long after the cessation of drug treatment."

for pred in predictions:
    if pred['predicted']:
        print(f"{pred['category']}: {pred['probability']:.3f}")

neurological: 0.372
cardiovascular: 0.912
hepatorenal: 0.847
oncological: 0.986


In [3]:
def predict_medical_categories(text, model_path="./my_medical_model", threshold=0.25):
    categories = ['neurological', 'cardiovascular', 'hepatorenal', 'oncological']

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=512
    )

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.sigmoid(outputs.logits)[0]

    results = []
    for i, (category, prob) in enumerate(zip(categories, predictions)):
        results.append({
            'category': category,
            'probability': prob.item(),
            'predicted': prob.item() > threshold
        })

    return results

# ============================================================================
# EASY TESTING - Just change the text below and run!
# ============================================================================

# CAMBIA ESTE TEXTO POR EL QUE QUIERAS PROBAR:
my_test_text = "	Adrenoleukodystrophy: survey of 303 cases: biochemistry, diagnosis, and therapy. Adrenoleukodystrophy ( ALD ) is a genetically determined disorder associated with progressive central demyelination and adrenal cortical insufficiency . All affected persons show increased levels of saturated unbranched very-long-chain fatty acids , particularly hexacosanoate ( C26 0 ) , because of impaired capacity to degrade these acids . This degradation normally takes place in a subcellular organelle called the peroxisome , and ALD , together with Zellwegers cerebrohepatorenal syndrome , is now considered to belong to the newly formed category of peroxisomal disorders . Biochemical assays permit prenatal diagnosis , as well as identification of most heterozygotes . We have identified 303 patients with ALD in 217 kindreds . These patients show a wide phenotypic variation . Sixty percent of patients had childhood ALD and 17 % adrenomyeloneuropathy , both of which are X-linked , with the gene mapped to Xq28 . Neonatal ALD , a distinct entity with autosomal recessive inheritance and points of resemblance to Zellwegers syndrome , accounted for 7 % of the cases . Although excess C26 0 in the brain of patients with ALD is partially of dietary origin , dietary C26 0 restriction did not produce clear benefit . Bone marrow transplant lowered the plasma C26 0 level but failed to arrest neurological progression ."

# Ejecutar predicción
predictions = predict_medical_categories(my_test_text)

# Mostrar resultados
print(f"Analyzing: {my_test_text[:60]}...")
print(f"\nResults:")
predicted_count = 0
for pred in predictions:
    status = "✓" if pred['predicted'] else "✗"
    print(f"  {status} {pred['category']:15}: {pred['probability']:.3f}")
    if pred['predicted']:
        predicted_count += 1

print(f"\nSummary: {predicted_count} categories detected")
if predicted_count > 1:
    print("MULTI-LABEL DETECTED!")

Analyzing: 	Adrenoleukodystrophy: survey of 303 cases: biochemistry, di...

Results:
  ✓ neurological   : 0.868
  ✗ cardiovascular : 0.021
  ✓ hepatorenal    : 0.931
  ✗ oncological    : 0.036

Summary: 2 categories detected
MULTI-LABEL DETECTED!


In [28]:
import os
import shutil
from google.colab import files

# Lista de archivos necesarios
files_needed = [
    'config.json',
    'model.safetensors',
    'tokenizer.json',
    'tokenizer_config.json',
    'vocab.txt',
    'special_tokens_map.json'
]

# Ruta de tu modelo
model_path = "./pubmedbert-medical-v6"

# Crear carpeta temporal para archivos a descargar
download_folder = "./model_files_to_download"
os.makedirs(download_folder, exist_ok=True)

# Copiar solo los archivos necesarios
for file_name in files_needed:
    source = os.path.join(model_path, file_name)
    dest = os.path.join(download_folder, file_name)

    if os.path.exists(source):
        shutil.copy2(source, dest)
        print(f"✓ Copiado: {file_name}")
    else:
        print(f"✗ No encontrado: {file_name}")

# Crear ZIP con solo los archivos necesarios
shutil.make_archive("my_medical_model", 'zip', download_folder)

# Descargar el ZIP
files.download("my_medical_model.zip")

print("\nArchivo my_medical_model.zip descargado con éxito!")
print("Contiene solo los archivos esenciales para el modelo.")

✓ Copiado: config.json
✓ Copiado: model.safetensors
✓ Copiado: tokenizer.json
✓ Copiado: tokenizer_config.json
✓ Copiado: vocab.txt
✓ Copiado: special_tokens_map.json


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Archivo my_medical_model.zip descargado con éxito!
Contiene solo los archivos esenciales para el modelo.
