# Medical Cross-Task Transfer: Universal Notebook
## All 7 Models × All 8 Tasks - Fixed & Ready

**Features**:
- ✅ Automatic task detection (NER, RE, Classification, QA, Similarity)
- ✅ Automatic model head selection
- ✅ Automatic metrics computation
- ✅ Works with all 7 BERT models
- ✅ Works with all 8 medical NLP tasks
- ✅ Integrated smoke test for quick validation
- ✅ Token tracking for RQ5
- ✅ CSV export for analysis

**Expected Results**:
- BioBERT on BC2GM: F1 = 0.84 (not 0.46!)
- All models and tasks work automatically

## Cell 1: Setup & Clone Repository

In [None]:
import sys
import os
from pathlib import Path

# Clone repo
print("üì• Cloning repository...")
os.chdir('/kaggle/working')
!rm -rf Crosstalk_Medical_LLM
!git clone https://github.com/bharathbolla/Crosstalk_Medical_LLM.git
os.chdir('Crosstalk_Medical_LLM')

print(f"\n‚úÖ Current directory: {os.getcwd()}")

# Verify datasets
!python test_pickle_load.py

## Cell 2: Install Dependencies

In [None]:
# Install the third-party libraries this notebook uses (-q keeps pip's output quiet)
!pip install -q transformers torch accelerate scikit-learn seqeval pandas scipy

import torch
import json
import pickle
import pandas as pd
import csv
import gc
from datetime import datetime
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    AutoConfig,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from torch.utils.data import Dataset

# GPU verification: report the PyTorch version and, when CUDA is usable,
# the name and total memory of device 0.
print(f"\n✅ PyTorch: {torch.__version__}")
print(f"✅ CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; dividing by 1e9 shows decimal gigabytes.
    print(f"✅ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Create results directory (relative to the current working directory);
# exist_ok=True makes re-running this cell harmless.
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(exist_ok=True)

# Experiment ID: local timestamp formatted as YYYYMMDD_HHMMSS
EXPERIMENT_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
print(f"\n📊 Experiment ID: {EXPERIMENT_ID}")

## Cell 3: Configuration
### ⭐ Change ONLY these 2 lines to test different models and tasks!

In [None]:
# ============================================
# ⭐ MAIN CONFIGURATION ⭐
# Change these 2 lines to test any model + task combination!
# ============================================

# Single source of truth for the experiment. Later cells read (and the smoke
# test cell overrides) these keys, so key names must stay stable.
CONFIG = {
    # ⭐ MODEL SELECTION (test any of the 7 models)
    "model_name": "dmis-lab/biobert-v1.1",  # Start with BioBERT
    # Other options:
    # "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12",  # BlueBERT (best)
    # "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",  # PubMedBERT
    # "allenai/biomed_roberta_base",  # BioMed-RoBERTa
    # "emilyalsentzer/Bio_ClinicalBERT",  # Clinical-BERT
    # "roberta-base",  # RoBERTa (baseline)
    # "bert-base-uncased",  # BERT (baseline)

    # ⭐ TASK SELECTION (test any of the 8 tasks)
    "datasets": ["bc2gm"],  # Start with BC2GM
    # Other options: ["jnlpba"], ["chemprot"], ["ddi"], ["gad"], ["hoc"], ["pubmedqa"], ["biosses"]
    # Or multiple: ["bc2gm", "jnlpba"]  # Multi-task

    # Experiment metadata
    "experiment_id": EXPERIMENT_ID,
    "experiment_type": "single_task",  # or "multi_task"
    "description": "Universal notebook - auto task detection",

    # Dataset configuration
    "max_samples_per_dataset": None,  # None = use all data, or 50 for smoke test

    # Training hyperparameters
    "num_epochs": 10,
    "batch_size": 32,  # Auto-adjusted based on GPU (see below); stays 32 without CUDA
    "learning_rate": 2e-5,
    "max_length": 512,
    "warmup_steps": 500,
    "weight_decay": 0.01,

    # Early stopping
    "use_early_stopping": True,
    "early_stopping_patience": 3,
    "early_stopping_threshold": 0.0001,

    # Token tracking (RQ5)
    "track_tokens": True,

    # Checkpointing
    "save_strategy": "steps",
    # Kept equal to eval_steps: with load_best_model_at_end=True (which
    # EarlyStoppingCallback relies on) TrainingArguments rejects a save_steps
    # that is not a round multiple of eval_steps. The previous value of 100
    # violated that against eval_steps=250.
    "save_steps": 250,
    "keep_last_n_checkpoints": 2,
    "resume_from_checkpoint": True,

    # Evaluation
    "eval_strategy": "steps",
    "eval_steps": 250,

    # Logging
    "use_wandb": False,
    "logging_steps": 50,
}

# Auto-detect GPU and adjust batch size. Matching is by substring of the
# device name; any GPU that is neither an A100 nor a T4 gets the conservative
# value. Without CUDA the batch size configured above is left unchanged.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    total_vram = torch.cuda.get_device_properties(0).total_memory / 1e9

    print("\n🔍 GPU Detection:")
    print(f"   GPU: {gpu_name}")
    print(f"   VRAM: {total_vram:.1f} GB")

    if "A100" in gpu_name:
        CONFIG['batch_size'] = 64
        print("   ✅ Optimized for A100: batch_size = 64")
    elif "T4" in gpu_name:
        CONFIG['batch_size'] = 32
        print("   ✅ Optimized for T4: batch_size = 32")
    else:
        CONFIG['batch_size'] = 16
        print("   ⚠️  Conservative: batch_size = 16")

# Save config next to the results, named after the experiment ID.
# Explicit UTF-8 so the file does not depend on the platform's default encoding.
with open(RESULTS_DIR / f"config_{EXPERIMENT_ID}.json", 'w', encoding='utf-8') as f:
    json.dump(CONFIG, f, indent=2)

print("\n" + "="*60)
print("EXPERIMENT CONFIGURATION")
print("="*60)
print(f"Model: {CONFIG['model_name']}")
print(f"Datasets: {CONFIG['datasets']}")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Max epochs: {CONFIG['num_epochs']}")
print("="*60)

## Cell 4: Task Configurations
### Automatic task detection - no changes needed!

In [None]:
# ============================================
# TASK CONFIGURATIONS FOR ALL 8 DATASETS
# ============================================

# Maps each dataset key (as used in CONFIG["datasets"]) to:
#   task_type    - the kind of task
#   labels       - ordered label names (None for the regression task)
#   model_type   - which model head family the task needs
#   problem_type - only present for the multi-label task
TASK_CONFIGS = {
    # NER tasks - token classification
    'bc2gm': {
        'task_type': 'ner',
        'labels': ['O', 'B-GENE', 'I-GENE'],
        'model_type': 'token_classification',
    },
    'jnlpba': {
        'task_type': 'ner',
        'labels': ['O', 'B-DNA', 'I-DNA', 'B-RNA', 'I-RNA', 'B-cell_line', 'I-cell_line', 'B-cell_type', 'I-cell_type', 'B-protein', 'I-protein'],
        'model_type': 'token_classification',
    },

    # Relation Extraction - sequence classification
    'chemprot': {
        'task_type': 're',
        # Generates CPR:0 .. CPR:12 (13 labels).
        # NOTE(review): confirm this matches the label set in the pickled dataset.
        'labels': [f'CPR:{i}' for i in range(13)],
        'model_type': 'sequence_classification',
    },
    'ddi': {
        'task_type': 're',
        'labels': ['DDI-false', 'DDI-mechanism', 'DDI-effect', 'DDI-advise', 'DDI-int'],
        'model_type': 'sequence_classification',
    },

    # Classification tasks
    'gad': {
        'task_type': 'classification',
        'labels': ['0', '1'],
        'model_type': 'sequence_classification',
    },
    'hoc': {
        'task_type': 'multilabel_classification',
        # Generates hallmark_0 .. hallmark_9 (10 labels).
        'labels': [f'hallmark_{i}' for i in range(10)],
        'model_type': 'sequence_classification',
        'problem_type': 'multi_label_classification',
    },

    # QA task
    'pubmedqa': {
        'task_type': 'qa',
        'labels': ['no', 'yes', 'maybe'],
        'model_type': 'sequence_classification',
    },

    # Similarity/Regression task
    'biosses': {
        'task_type': 'similarity',
        'labels': None,  # regression target, so there is no label vocabulary
        'model_type': 'regression',
    },
}

print("✅ Task configurations loaded for all 8 datasets")
print(f"\nConfigured tasks: {list(TASK_CONFIGS.keys())}")

## Cell 5: 🔥 SMOKE TEST
### Quick validation (50 samples, 1 epoch) - Run this first!

In [None]:
# ============================================
# SMOKE TEST: Quick validation before full run
# ============================================

import sys
import io
# Make stdout emit UTF-8 so the emoji printed below survive on consoles whose
# default encoding is not UTF-8. The stream is reconfigured in place rather
# than replaced: wrapping sys.stdout.buffer fails on stream objects that have
# no .buffer and detaches output from whatever stream the frontend captures.
# Streams without reconfigure() are left untouched.
# NOTE(review): notebook kernels appear to install their own stdout object
# without reconfigure()/buffer - confirm on Kaggle.
_reconfigure_stdout = getattr(sys.stdout, "reconfigure", None)
if callable(_reconfigure_stdout):
    _reconfigure_stdout(encoding='utf-8')

print("\n" + "="*60)
print("🔥 SMOKE TEST")
print("="*60)
print("Purpose: Quick validation (50 samples, 1 epoch)")
print("Time: ~2-3 minutes")
print("Expected: F1 > 0.30 (just checking it works!)")
print("="*60)

# Ask user if they want smoke test
print("\n⚠️  Do you want to run SMOKE TEST first?")
print("   YES → Quick 2-min test (recommended)")
print("   NO  → Full training (~3 hours)")
print("\nTo enable: Set SMOKE_TEST = True below")
print("To disable: Set SMOKE_TEST = False below")

# ⭐ SET THIS TO True FOR SMOKE TEST
SMOKE_TEST = True  # Change to False for full training

if SMOKE_TEST:
    print("\n✅ SMOKE TEST ENABLED")
    print("   Settings: 50 samples, 1 epoch, batch 16")

    # Override config for smoke test. These mutate the shared CONFIG dict in
    # place, so every later cell sees the reduced settings.
    # NOTE: the config JSON written in Cell 3 was saved before these overrides
    # and therefore does not record the smoke-test values.
    CONFIG['max_samples_per_dataset'] = 50
    CONFIG['num_epochs'] = 1
    CONFIG['batch_size'] = 16
    CONFIG['max_length'] = 128
    CONFIG['warmup_steps'] = 10
    CONFIG['save_steps'] = 50   # round multiple of eval_steps below
    CONFIG['eval_steps'] = 25
    CONFIG['use_early_stopping'] = False

    print("\n⏱️  Expected time: 2-3 minutes")
    print("✅ If F1 > 0.30: Everything works! Set SMOKE_TEST=False for full run")
    print("❌ If F1 < 0.30: Something wrong, check configuration")
else:
    print("\n✅ FULL TRAINING MODE")
    print("   Using full configuration")
    print("   Samples: ALL")
    print(f"   Epochs: {CONFIG['num_epochs']}")
    print("   Expected time: ~3 hours")

print("="*60)