# Medical Multi-Task Learning: Complete Training Notebook
## ✅ All 7 Models × All 8 Tasks - Ready to Train!

**What This Notebook Does**:
- ✅ Clones code from GitHub
- ✅ Loads tokenizer, datasets, model automatically
- ✅ Trains with proper configuration
- ✅ Evaluates and saves results
- ✅ Works with all 7 models and all 8 tasks

**Expected Results**:
- BioBERT on BC2GM: **F1 ≈ 0.84** (the earlier, unfixed pipeline reached only 0.46)
- Smoke test (50 samples): F1 > 0.30 in 2 minutes
- Full training: F1 = 0.84 in ~3 hours

## Cell 1: Setup & Clone Repository

In [None]:
# Cell 1: clone the project repository into the Kaggle working directory and
# run the repository's dataset check script. Lines starting with `!` are
# IPython shell escapes, so this cell only runs inside Jupyter/Kaggle.
import sys  # NOTE(review): not used in this cell
import os
from pathlib import Path

# Clone repo
# Any previous clone is deleted first, so local edits inside
# Crosstalk_Medical_LLM/ are lost every time this cell runs.
print('üì• Cloning repository...')
os.chdir('/kaggle/working')  # hard-coded Kaggle path; adjust when running elsewhere
!rm -rf Crosstalk_Medical_LLM
!git clone https://github.com/bharathbolla/Crosstalk_Medical_LLM.git
# Later cells use paths relative to the repository root
# (data/, results/, COMPLETE_FIXED_*.py).
os.chdir('Crosstalk_Medical_LLM')

print(f'\n‚úÖ Current directory: {os.getcwd()}')

# Verify datasets exist
# What exactly is checked is defined in test_pickle_load.py, not here.
!python test_pickle_load.py

## Cell 2: Install Dependencies

In [None]:
# Cell 2: install dependencies, report the PyTorch/GPU environment, and create
# the results directory plus a timestamped experiment ID.
# The `!pip` line is an IPython shell escape (Jupyter/Kaggle only).
!pip install -q transformers torch accelerate scikit-learn seqeval pandas scipy

import torch
import json
import pandas as pd  # NOTE(review): not used by any visible cell
from datetime import datetime
from pathlib import Path

# GPU verification
print(f'\n‚úÖ PyTorch: {torch.__version__}')
print(f'‚úÖ CUDA: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'‚úÖ GPU: {torch.cuda.get_device_name(0)}')
    print(f'‚úÖ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

# Relative to the repository root entered in Cell 1.
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(exist_ok=True)
# Local-time stamp such as 20240131_142530; it names the checkpoint directory
# and the results JSON file written by later cells.
EXPERIMENT_ID = datetime.now().strftime('%Y%m%d_%H%M%S')
print(f'\nüìä Experiment ID: {EXPERIMENT_ID}')

## Cell 3: Configuration
### ⭐ Change ONLY these 2 lines (`model_name` and `datasets`) to test different models/tasks!

In [None]:
# ============================================
# MAIN CONFIGURATION
# ============================================
# Edit `model_name` and `datasets` to switch experiments. The remaining keys
# are training hyper-parameters; Cell 4 overrides several of them in smoke mode.
#
# Checkpoints for 'model_name':
#   dmis-lab/biobert-v1.1                                   BioBERT
#   bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12    BlueBERT
#   microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract    PubMedBERT
#   allenai/biomed_roberta_base                             BioMed-RoBERTa
#   emilyalsentzer/Bio_ClinicalBERT                         Clinical-BERT
#   roberta-base                                            RoBERTa
#   bert-base-uncased                                       BERT
# Task names for 'datasets':
#   bc2gm, jnlpba, chemprot, ddi, gad, hoc, pubmedqa, biosses

CONFIG = dict(
    model_name='dmis-lab/biobert-v1.1',  # BioBERT
    datasets=['bc2gm'],                  # start with BC2GM
    experiment_id=EXPERIMENT_ID,
    max_samples_per_dataset=None,        # None keeps every sample
    num_epochs=10,
    batch_size=32,
    learning_rate=2e-5,
    max_length=512,
    warmup_steps=500,
    use_early_stopping=True,
    early_stopping_patience=3,
)

# GPU-name substring -> batch size; the first matching entry wins.
_GPU_BATCH_SIZES = (('A100', 64), ('T4', 32))

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    _override = next(
        (size for tag, size in _GPU_BATCH_SIZES if tag in gpu_name),
        None,
    )
    if _override is not None:
        CONFIG['batch_size'] = _override

_rule = '='*60
print(_rule)
print('CONFIGURATION')
print(_rule)
print(f"Model: {CONFIG['model_name']}")
print(f"Tasks: {CONFIG['datasets']}")
print(f"Batch: {CONFIG['batch_size']}")
print(f"Epochs: {CONFIG['num_epochs']}")
print(_rule)

## Cell 4: üî• SMOKE TEST (Run This First!)
### Set SMOKE_TEST = True for 2-minute validation
### Set SMOKE_TEST = False for full training

In [None]:
# ============================================
# SMOKE TEST TOGGLE
# ============================================
# True  -> quick sanity run: shrinks samples, epochs, batch and sequence length.
# False -> full training with the settings from Cell 3.
# This cell edits CONFIG in place. After a smoke run, re-run Cell 3 before
# switching to full training so the smoke overrides do not carry over.

SMOKE_TEST = True  # <- Change to False for full training

print('\n' + '='*60)
if SMOKE_TEST:
    print('üî• SMOKE TEST MODE')
    print('='*60)
    CONFIG['max_samples_per_dataset'] = 50
    CONFIG['num_epochs'] = 1
    CONFIG['batch_size'] = 16
    CONFIG['max_length'] = 128
    CONFIG['use_early_stopping'] = False
    # 50 samples at batch 16 is only ceil(50/16) = 4 optimizer steps. With the
    # Cell 3 default of 500 linear-warmup steps, the learning rate would stay
    # below 1% of its peak for the whole run and the model would barely train.
    # Disable warmup so the smoke run actually exercises learning.
    CONFIG['warmup_steps'] = 0
    print('Settings: 50 samples, 1 epoch, batch 16')
    print('Expected: F1 > 0.30')
    print('Time: ~2 minutes')
else:
    print('üöÄ FULL TRAINING MODE')
    print('='*60)
    # Show the real sample cap. It reads 'ALL' unless smoke-test overrides are
    # still in CONFIG because Cell 3 was not re-run.
    print(f"Samples: {CONFIG['max_samples_per_dataset'] or 'ALL'}")
    print(f"Epochs: {CONFIG['num_epochs']}")
    print(f"Batch: {CONFIG['batch_size']}")
    print('Expected: F1 = 0.84')
    print('Time: ~3 hours')
print('='*60)

## Cell 5: Load Complete Implementation
### Executes the fixed dataset, model, and metrics code from the repository

In [None]:
print('\nüì¶ Loading complete implementation...')

# Execute dataset code
exec(open('COMPLETE_FIXED_DATASET.py').read())
print('‚úÖ Dataset code loaded')

# Execute model code
exec(open('COMPLETE_FIXED_MODEL.py').read())
print('‚úÖ Model code loaded')

# Execute metrics code
exec(open('COMPLETE_FIXED_METRICS.py').read())
print('‚úÖ Metrics code loaded')

print('\n' + '='*60)
print('‚úÖ ALL CODE LOADED')
print('='*60)

## Cell 6: Load Tokenizer

In [None]:
from transformers import AutoTokenizer

print('\nüî§ Loading tokenizer...')

model_name = CONFIG['model_name']

# RoBERTa-family checkpoints get add_prefix_space=True. HF requires this for
# byte-level BPE tokenizers fed pre-split words; presumably
# UniversalMedicalDataset tokenizes that way for NER — confirm.
is_roberta = 'roberta' in model_name.lower()
tokenizer_kwargs = {'add_prefix_space': True} if is_roberta else {}
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
print(
    '   ‚úÖ RoBERTa tokenizer (add_prefix_space=True)'
    if is_roberta
    else '   ‚úÖ BERT tokenizer'
)

print(f'   Model: {model_name}')
print(f'   Vocab size: {tokenizer.vocab_size:,}')
print('='*60)

## Cell 7: Load Datasets

In [None]:
import pickle

print('\nüìä Loading datasets...')

# This notebook trains a single task: the first entry of CONFIG['datasets'].
if not CONFIG['datasets']:
    raise ValueError("CONFIG['datasets'] is empty; list at least one task name")
primary_dataset = CONFIG['datasets'][0]
if len(CONFIG['datasets']) > 1:
    print(f"   WARNING: only '{primary_dataset}' is trained here; "
          f"ignoring {CONFIG['datasets'][1:]}")
max_samples = CONFIG['max_samples_per_dataset']

# Resolve the task definition before any file I/O so a typo in the dataset
# name fails immediately with the list of valid names.
if primary_dataset not in TASK_CONFIGS:
    raise KeyError(f"Unknown dataset '{primary_dataset}'; "
                   f"options: {sorted(TASK_CONFIGS)}")
task_config = TASK_CONFIGS[primary_dataset]

# Load pickle file.
# It is read as a mapping with sliceable 'train' and 'validation' splits.
# NOTE(review): pickle.load can execute code; only load trusted repo files.
pickle_file = f'data/{primary_dataset}_train.pkl'
with open(pickle_file, 'rb') as f:
    raw_data = pickle.load(f)

# Limit samples if smoke test. Validation keeps a fifth of the train cap, but
# at least one example, so evaluation never runs on an empty set.
if max_samples:
    raw_data['train'] = raw_data['train'][:max_samples]
    raw_data['validation'] = raw_data['validation'][:max(1, max_samples // 5)]

print(f'   Dataset: {primary_dataset}')
print(f"   Train samples: {len(raw_data['train']):,}")
print(f"   Validation samples: {len(raw_data['validation']):,}")

# Create datasets using UniversalMedicalDataset
train_dataset = UniversalMedicalDataset(
    data=raw_data['train'],
    tokenizer=tokenizer,
    task_type=task_config['task_type'],
    labels=task_config['labels'],
    max_length=CONFIG['max_length']
)

val_dataset = UniversalMedicalDataset(
    data=raw_data['validation'],
    tokenizer=tokenizer,
    task_type=task_config['task_type'],
    labels=task_config['labels'],
    max_length=CONFIG['max_length']
)

# Store dataset stats. Tasks with no label list (falsy 'labels') get a
# single output unit.
dataset_stats = {
    primary_dataset: {
        'task_type': task_config['task_type'],
        'model_type': task_config['model_type'],
        'num_labels': len(task_config['labels']) if task_config['labels'] else 1,
        'train_size': len(train_dataset),
        'val_size': len(val_dataset),
    }
}

print(f"   ‚úÖ Created UniversalMedicalDataset")
print(f"   Task type: {task_config['task_type']}")
print(f"   Num labels: {dataset_stats[primary_dataset]['num_labels']}")
print('='*60)

## Cell 8: Load Model

In [None]:
from transformers import (
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    AutoConfig
)

print('\nü§ñ Loading model...')

# Pick the model head from the task's model_type (stored by the dataset cell).
task_info = dataset_stats[primary_dataset]
model_type = task_info['model_type']
num_labels = task_info['num_labels']

if model_type == 'token_classification':
    # NER tasks
    model = AutoModelForTokenClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True  # the new head replaces the checkpoint's
    )
    print(f'   ‚úÖ TokenClassification head loaded')

elif model_type == 'sequence_classification':
    # RE, Classification, QA
    # num_labels must be set on the config itself. When `config=` is passed,
    # from_pretrained forwards remaining kwargs (such as num_labels) to the
    # model constructor, which rejects them with a TypeError.
    config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
    if task_config.get('problem_type'):
        # problem_type selects HF's loss, e.g. 'multi_label_classification' -> BCE.
        config.problem_type = task_config['problem_type']

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        ignore_mismatched_sizes=True
    )
    print(f'   ‚úÖ SequenceClassification head loaded')

elif model_type == 'regression':
    # Similarity: one output unit; HF treats num_labels=1 as regression.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=1,
        ignore_mismatched_sizes=True
    )
    print(f'   ‚úÖ Regression head loaded')

else:
    # Fail here rather than with a NameError on `model` further down.
    raise ValueError(
        f"Unsupported model_type '{model_type}' for dataset '{primary_dataset}'; "
        "expected 'token_classification', 'sequence_classification' or 'regression'"
    )

# Move to GPU
if torch.cuda.is_available():
    model = model.cuda()
    print('   ‚úÖ Model on GPU')

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'   Total parameters: {total_params:,}')
print(f'   Trainable: {trainable_params:,} ({100 * trainable_params / total_params:.1f}%)')
print('='*60)

## Cell 9: Setup Trainer

In [None]:
from transformers import (
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)

print('\n‚öôÔ∏è  Setting up trainer...')

# Checkpoint-selection / early-stopping metric. Trainer looks it up as
# 'eval_<name>' in the dict returned by compute_metrics_fn, so it must be a key
# that function produces. The default stays 'f1'. A task whose metrics have no
# F1 can set optional 'metric_for_best_model' / 'greater_is_better' entries in
# its TASK_CONFIGS dict.
best_metric = task_config.get('metric_for_best_model', 'f1')
best_is_greater = task_config.get('greater_is_better', True)

# Training arguments
training_args = TrainingArguments(
    output_dir=f"./checkpoints/{primary_dataset}_{EXPERIMENT_ID}",
    num_train_epochs=CONFIG['num_epochs'],
    per_device_train_batch_size=CONFIG['batch_size'],
    per_device_eval_batch_size=CONFIG['batch_size'],
    learning_rate=CONFIG['learning_rate'],
    warmup_steps=CONFIG['warmup_steps'],
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    # `eval_strategy` is the transformers>=4.41 name (formerly
    # `evaluation_strategy`). It must match save_strategy for
    # load_best_model_at_end.
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model=best_metric,
    greater_is_better=best_is_greater,
    save_total_limit=2,
    report_to='none',
    disable_tqdm=False,
)

# Setup compute_metrics function.
# Values are bound now rather than read from the global task_config at
# evaluation time.
current_task_type = task_config['task_type']
ner_labels = task_config['labels']

if current_task_type == 'ner':
    def compute_metrics_fn(eval_pred):
        return compute_ner_metrics(eval_pred, ner_labels)
elif current_task_type in ('re', 'classification', 'qa'):
    def compute_metrics_fn(eval_pred):
        return compute_classification_metrics(eval_pred)
elif current_task_type == 'multilabel_classification':
    def compute_metrics_fn(eval_pred):
        return compute_multilabel_metrics(eval_pred)
elif current_task_type == 'similarity':
    def compute_metrics_fn(eval_pred):
        return compute_regression_metrics(eval_pred)
else:
    # Fail here rather than with a NameError on compute_metrics_fn below.
    raise ValueError(
        f"Unsupported task_type '{current_task_type}' for dataset '{primary_dataset}'"
    )

# Create trainer
callbacks = []
if CONFIG['use_early_stopping']:
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=CONFIG['early_stopping_patience']))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # NOTE: transformers>=4.46 deprecates `tokenizer=` in favour of `processing_class=`.
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_fn,
    callbacks=callbacks,
)

print('   ‚úÖ Trainer ready')
print(f'   Output: {training_args.output_dir}')
print('='*60)

## Cell 10: Train Model
### 🚀 This is where the actual training happens!

In [None]:
_banner = '='*60

print('\n' + _banner)
print('üöÄ STARTING TRAINING')
print(_banner)

# Run the training loop configured on `trainer`.
train_result = trainer.train()
_train_metrics = train_result.metrics

print('\n' + _banner)
print('‚úÖ TRAINING COMPLETE')
print(_banner)
print(f"Train loss: {train_result.training_loss:.4f}")
print(f"Train time: {_train_metrics['train_runtime']:.1f}s")
print(_banner)

## Cell 11: Evaluate Model

In [None]:
print('\nüìä Evaluating on validation set...')

# Evaluate
eval_result = trainer.evaluate()

print('\n' + '='*60)
print('üìä EVALUATION RESULTS')
print('='*60)
for key, value in eval_result.items():
    if 'f1' in key.lower() or 'precision' in key.lower() or 'recall' in key.lower():
        print(f'{key}: {value:.4f}')
print('='*60)

# Save results
results = {
    'experiment_id': EXPERIMENT_ID,
    'model': model_name,
    'dataset': primary_dataset,
    'task_type': task_config['task_type'],
    'smoke_test': SMOKE_TEST,
    'config': CONFIG,
    'train_metrics': train_result.metrics,
    'eval_metrics': eval_result,
}

# Save to JSON
results_file = RESULTS_DIR / f'results_{EXPERIMENT_ID}.json'
with open(results_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f'\n‚úÖ Results saved to: {results_file}')

## Cell 12: Final Summary

In [None]:
_sep = '='*60
# Tasks whose metrics contain no 'eval_f1' report 0 here (and so "fail").
final_f1 = eval_result.get('eval_f1', 0)

print('\n' + _sep)
print('üéâ EXPERIMENT COMPLETE')
print(_sep)
print(f"Model: {model_name}")
print(f"Dataset: {primary_dataset}")
print(f"Mode: {'SMOKE TEST' if SMOKE_TEST else 'FULL TRAINING'}")
print(f"\nF1 Score: {final_f1:.4f}")

# Verdict: smoke runs need F1 > 0.30. Full runs must come within 0.05 of the
# expected F1 (0.84 for BC2GM, 0.70 assumed for every other task).
if SMOKE_TEST:
    if final_f1 <= 0.30:
        print('\n‚ùå Smoke test FAILED')
        print('   ‚Üí Check configuration and data')
    else:
        print('\n‚úÖ Smoke test PASSED!')
        print('   ‚Üí Set SMOKE_TEST = False for full training')
else:
    expected_f1 = {'bc2gm': 0.84}.get(primary_dataset, 0.70)
    print(
        f'\n‚úÖ Result matches expected F1 (~{expected_f1:.2f})'
        if final_f1 > expected_f1 - 0.05
        else f'\n‚ö†Ô∏è  F1 lower than expected (~{expected_f1:.2f})'
    )

print('\n' + _sep)
print('Next Steps:')
if SMOKE_TEST:
    _next_steps = (
        '1. Set SMOKE_TEST = False in Cell 4',
        '2. Run All Cells for full training',
    )
else:
    _next_steps = (
        '1. Try different models (change model_name in Cell 3)',
        '2. Try different tasks (change datasets in Cell 3)',
        '3. Check results/ folder for saved metrics',
    )
for _step in _next_steps:
    print(_step)
print(_sep)