# Medical Cross-Task Transfer Learning - Research Experiments

**Purpose**: Full experimental pipeline for research paper

**Setup**: GPU T4 x2 + Internet ON

**Experiments**:
- Single-task baselines (S1)
- Multi-task learning (S2, S3)
- Token-controlled baselines (RQ5)
- Full evaluation metrics
- Result tracking for paper

---

## Cell 1: Setup & Clone Repository

In [None]:
import sys
import os
from pathlib import Path

# Clone repo
print("📥 Cloning repository...")
os.chdir('/kaggle/working')
!rm -rf Crosstalk_Medical_LLM
!git clone https://github.com/bharathbolla/Crosstalk_Medical_LLM.git
os.chdir('Crosstalk_Medical_LLM')

print(f"\n✅ Current directory: {os.getcwd()}")

# Verify datasets
!python test_pickle_load.py

## Cell 2: Install Dependencies & Setup Tracking

In [None]:
# Install libraries
!pip install -q transformers torch accelerate scikit-learn wandb seqeval pandas

import torch
import wandb
import json
import pickle
import pandas as pd
import csv
from datetime import datetime
from pathlib import Path

# GPU verification
print(f"\n✅ PyTorch: {torch.__version__}")
print(f"✅ CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✅ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Create results directory
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(exist_ok=True)

# Experiment ID
EXPERIMENT_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
print(f"\n📊 Experiment ID: {EXPERIMENT_ID}")

## Cell 3: Experiment Configuration

In [None]:
# ========================================
# EXPERIMENT CONFIGURATION
# ========================================

CONFIG = {
    # Experiment metadata
    "experiment_id": EXPERIMENT_ID,
    "experiment_type": "single_task",  # Options: single_task, multi_task, token_controlled
    "description": "Single-task baseline for BC2GM",
    
    # Dataset configuration
    "datasets": ["bc2gm"],  # Can add multiple: ["bc2gm", "jnlpba", "chemprot"]
    "max_samples_per_dataset": None,  # None = use all data
    
    # Model configuration
    "model_name": "bert-base-uncased",  
    # Options: "bert-base-uncased", "dmis-lab/biobert-v1.1", 
    #          "allenai/scibert_scivocab_uncased",
    #          "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    
    # Training hyperparameters
    "num_epochs": 10,  # Max epochs (early stopping will stop earlier)
    "batch_size": 64,  # Per-GPU default for A100; overridden by the GPU auto-detection below
    "learning_rate": 2e-5,
    "max_length": 512,
    "warmup_steps": 500,
    "weight_decay": 0.01,
    
    # Early stopping (CRITICAL for research rigor!)
    "use_early_stopping": True,
    "early_stopping_patience": 3,  # Stop if no improvement for 3 evaluations
    "early_stopping_threshold": 0.0001,  # Minimum improvement to count as better
    
    # Token tracking (RQ5 - CRITICAL for paper)
    "track_tokens": True,
    "target_tokens": None,  # Token budget for token-controlled runs (not yet enforced by the Cell 6 training loop)
    
    # Checkpointing (for interruptible instances)
    "save_strategy": "steps",
    "save_steps": 250,  # Must be a round multiple of eval_steps when load_best_model_at_end=True
    "keep_last_n_checkpoints": 2,
    "resume_from_checkpoint": True,  # Auto-resume if interrupted
    
    # Evaluation
    "eval_strategy": "steps",
    "eval_steps": 250,  # Evaluate every 250 steps for early stopping
    
    # Logging
    "use_wandb": False,  # Set True to enable wandb tracking
    "wandb_project": "medical-cross-task-transfer",
    "logging_steps": 50,
}

# Auto-detect GPU and adjust batch size
import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    total_vram = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"\n🔍 GPU Detection:")
    print(f"   GPU: {gpu_name}")
    print(f"   VRAM: {total_vram:.1f} GB")
    
    # Auto-adjust the per-GPU batch size by GPU name, falling back to VRAM size
    # (on a 2-GPU session the Trainer uses both GPUs, doubling the effective batch)
    if "A100" in gpu_name or "A6000" in gpu_name:
        CONFIG['batch_size'] = 64
        print(f"   ✅ A100/A6000 detected: batch_size = 64")
    elif "A4000" in gpu_name or "RTX 4000" in gpu_name or total_vram > 20:
        CONFIG['batch_size'] = 48
        print(f"   ✅ A4000-class GPU (or >20 GB VRAM): batch_size = 48")
    elif "T4" in gpu_name or total_vram >= 15:
        CONFIG['batch_size'] = 32
        print(f"   ✅ T4-class GPU (or >=15 GB VRAM): batch_size = 32")
    else:
        CONFIG['batch_size'] = 16
        print(f"   ⚠️  Conservative default: batch_size = 16")

# Save config
config_path = RESULTS_DIR / f"config_{EXPERIMENT_ID}.json"
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

print("\n" + "="*60)
print("EXPERIMENT CONFIGURATION")
print("="*60)
for key, value in CONFIG.items():
    print(f"{key:30s}: {value}")
print("="*60)
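
Two quick sanity checks on the configuration above, written as a minimal stdlib-only sketch; the helper names are hypothetical and not part of the repository. On a Kaggle T4 x2 session (without a distributed launcher) the Hugging Face Trainer uses both GPUs, so each optimizer step sees `per_device_train_batch_size` × number of GPUs × gradient-accumulation steps examples. Separately, with `load_best_model_at_end=True` the save interval must be a round multiple of the evaluation interval.

```python
import math


def effective_batch_size(per_device_batch: int, n_gpus: int, grad_accum_steps: int = 1) -> int:
    """Examples consumed per optimizer step, assuming data parallelism over n_gpus."""
    return per_device_batch * max(1, n_gpus) * grad_accum_steps


def steps_per_epoch(num_train_samples: int, batch: int) -> int:
    """Optimizer steps per epoch; a final partial batch still counts as a step."""
    return math.ceil(num_train_samples / batch)


def checkpoint_schedule_ok(save_steps: int, eval_steps: int, load_best_model_at_end: bool) -> bool:
    """True if the save/eval cadence is compatible with load_best_model_at_end."""
    if not load_best_model_at_end:
        return True
    return save_steps % eval_steps == 0


# Hypothetical numbers for illustration only
batch = effective_batch_size(per_device_batch=32, n_gpus=2)
print(batch)                                   # 64
print(steps_per_epoch(10_000, batch))          # 157
print(checkpoint_schedule_ok(100, 250, True))  # False
print(checkpoint_schedule_ok(250, 250, True))  # True
```

A save interval of 100 steps with evaluation every 250 fails the check; `TrainingArguments` rejects that combination with a `ValueError` when `load_best_model_at_end=True`.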

## Cell 4: Load Datasets with Token Tracking

In [None]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset

class TokenTrackingNERDataset(Dataset):
    """NER dataset with token counting for RQ5."""
    
    def __init__(self, data, tokenizer, max_length=512, task_name="unknown"):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.task_name = task_name
        self.total_tokens = 0
        
    def __len__(self):
        return len(self.data)
    
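    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Get word-level tokens and labels
        tokens = item['tokens']
        labels = item.get('ner_tags', item.get('labels', [0] * len(tokens)))
        
        # Tokenize the pre-split words so each sub-token can be mapped back to its word
        # (word_ids() needs a fast tokenizer, which AutoTokenizer loads by default)
        encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Count non-padding tokens (for RQ5); this runs on every fetch,
        # so the total grows with each training epoch
        num_tokens = encoding['attention_mask'].sum().item()
        self.total_tokens += num_tokens
        
        # Align word-level labels with sub-tokens: label the first sub-token of each
        # word; mask special tokens, padding and continuation sub-tokens with -100
        aligned_labels = []
        previous_word_id = None
        for word_id in encoding.word_ids(batch_index=0):
            if word_id is None or word_id == previous_word_id:
                aligned_labels.append(-100)
            else:
                aligned_labels.append(labels[word_id])
            previous_word_id = word_id
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(aligned_labels),
            'task_name': self.task_name,
            'num_tokens': num_tokens
        }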

# Load tokenizer
print(f"\n🤖 Loading tokenizer: {CONFIG['model_name']}")
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])

# Load datasets
print("\n📦 Loading datasets...")
print("="*60)

all_train_datasets = {}
all_val_datasets = {}
all_test_datasets = {}
dataset_stats = {}

for dataset_name in CONFIG['datasets']:
    pickle_file = Path(f"data/pickle/{dataset_name}.pkl")
    
    with open(pickle_file, 'rb') as f:
        data = pickle.load(f)
    
    # Apply sample limit if specified
    train_data = data['train']
    if CONFIG['max_samples_per_dataset']:
        train_data = train_data[:CONFIG['max_samples_per_dataset']]
    
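    # Prefer a dedicated validation split. The fallbacks (test split, or the first
    # 100 training samples) leak test or training data into model selection.
    if 'validation' not in data:
        print(f"⚠️  {dataset_name}: no 'validation' split, falling back for model selection")
    val_data = data.get('validation', data.get('test', train_data[:100]))
    test_data = data.get('test', val_data)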
    
    # Create datasets
    all_train_datasets[dataset_name] = TokenTrackingNERDataset(
        train_data, tokenizer, CONFIG['max_length'], dataset_name
    )
    all_val_datasets[dataset_name] = TokenTrackingNERDataset(
        val_data, tokenizer, CONFIG['max_length'], dataset_name
    )
    all_test_datasets[dataset_name] = TokenTrackingNERDataset(
        test_data, tokenizer, CONFIG['max_length'], dataset_name
    )
    
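    # Number of labels = largest integer label id + 1 (counting distinct ids
    # undercounts if some id never occurs in the training split)
    all_labels = set()
    for item in train_data:
        all_labels.update(item.get('ner_tags', item.get('labels', [])))
    num_labels = max(all_labels) + 1 if all_labels else 0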
    
    dataset_stats[dataset_name] = {
        'train_samples': len(train_data),
        'val_samples': len(val_data),
        'test_samples': len(test_data),
        'num_labels': num_labels,
    }
    
    print(f"\n{dataset_name.upper()}:")
    print(f"  Train: {len(train_data):,} samples")
    print(f"  Val: {len(val_data):,} samples")
    print(f"  Test: {len(test_data):,} samples")
    print(f"  Labels: {num_labels}")

print("\n" + "="*60)
print(f"✅ Loaded {len(CONFIG['datasets'])} dataset(s)")

# Save dataset stats
stats_path = RESULTS_DIR / f"dataset_stats_{EXPERIMENT_ID}.json"
with open(stats_path, 'w') as f:
    json.dump(dataset_stats, f, indent=2)
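
BC2GM-style data stores one label per word, while WordPiece tokenizers split words into sub-tokens and add special tokens such as `[CLS]` and `[SEP]`, so copying word labels onto token positions by index shifts every label. A common alignment strategy, sketched below in plain Python, labels the first sub-token of each word and masks everything else with -100. The `word_ids` input mimics what a fast tokenizer's `BatchEncoding.word_ids()` returns (`None` for special and padding tokens); `align_labels` and its `label_all_subtokens` option are illustrative, not part of the repository.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # ignored by PyTorch's CrossEntropyLoss by default


def align_labels(word_ids: List[Optional[int]], word_labels: List[int],
                 label_all_subtokens: bool = False) -> List[int]:
    """Map word-level labels onto sub-token positions.

    word_ids[i] is the index of the word that sub-token i came from,
    or None for special and padding tokens.
    """
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(IGNORE_INDEX)          # [CLS], [SEP], [PAD]
        elif word_id != previous:
            aligned.append(word_labels[word_id])  # first sub-token of a word
        elif label_all_subtokens:
            aligned.append(word_labels[word_id])  # continuation sub-token, labelled
        else:
            aligned.append(IGNORE_INDEX)          # continuation sub-token, masked
        previous = word_id
    return aligned


# Hypothetical split: words ["BRCA1", "mutations"] with labels [B-GENE=1, O=0],
# tokenized as [CLS] br ##ca1 mutations [SEP] [PAD]
word_ids = [None, 0, 0, 1, None, None]
print(align_labels(word_ids, [1, 0]))        # [-100, 1, -100, 0, -100, -100]
print(align_labels(word_ids, [1, 0], True))  # [-100, 1, 1, 0, -100, -100]
```

Labelling every sub-token copies a `B-` tag onto continuation pieces; some pipelines convert those to the matching `I-` tag instead, which keeps seqeval from counting extra entities.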

## Cell 5: Initialize Model

In [None]:
from transformers import AutoModelForTokenClassification

# Get primary dataset for model initialization
primary_dataset = CONFIG['datasets'][0]
num_labels = dataset_stats[primary_dataset]['num_labels']

print(f"\n🤖 Loading model: {CONFIG['model_name']}")
print(f"   Task: {primary_dataset}")
print(f"   Number of labels: {num_labels}")

model = AutoModelForTokenClassification.from_pretrained(
    CONFIG['model_name'],
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📊 Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Trainable %: {100 * trainable_params / total_params:.2f}%")

# Move to GPU
if torch.cuda.is_available():
    model = model.cuda()
    print(f"\n✅ Model moved to GPU")
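
The model above is created with only `num_labels`, so its config falls back to generic label names (`LABEL_0`, `LABEL_1`, ...). A small helper such as the hypothetical `build_label_config` below turns a label list like `LABEL_MAPS['bc2gm']` from Cell 6 into the `id2label`/`label2id` dictionaries that `from_pretrained` accepts as config overrides, so saved checkpoints carry readable tag names. It assumes list position equals the integer label id in the pickled data.

```python
from typing import Dict, List, Tuple


def build_label_config(label_list: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Return (id2label, label2id), using list position as the integer id."""
    if len(set(label_list)) != len(label_list):
        raise ValueError("label names must be unique")
    id2label = {i: label for i, label in enumerate(label_list)}
    label2id = {label: i for i, label in id2label.items()}
    return id2label, label2id


id2label, label2id = build_label_config(["O", "B-GENE", "I-GENE"])
print(id2label)  # {0: 'O', 1: 'B-GENE', 2: 'I-GENE'}
print(label2id)  # {'O': 0, 'B-GENE': 1, 'I-GENE': 2}

# Possible use in this cell (sketch, not run here):
# model = AutoModelForTokenClassification.from_pretrained(
#     CONFIG['model_name'], num_labels=len(id2label),
#     id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)
```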

## Cell 6: Training with Token Tracking

In [None]:
from transformers import TrainingArguments, Trainer, TrainerCallback, EarlyStoppingCallback
import numpy as np

# ============================================
# UNIVERSAL LABEL MAPS FOR ALL 8 DATASETS
# (list position = integer label id in the pickled data, so order matters)
# ============================================

LABEL_MAPS = {
    # NER datasets
    'bc2gm': ["O", "B-GENE", "I-GENE"],
    
    'jnlpba': [
        "O", 
        "B-DNA", "I-DNA",
        "B-RNA", "I-RNA", 
        "B-cell_line", "I-cell_line",
        "B-cell_type", "I-cell_type",
        "B-protein", "I-protein"
    ],
    
    # Relation Extraction datasets (treated as NER for entities)
    'chemprot': ["O", "B-CHEMICAL", "I-CHEMICAL", "B-GENE", "I-GENE"],
    'ddi': ["O", "B-DRUG", "I-DRUG"],
    
    # Classification datasets (need a sequence-classification head, not the token-classification model from Cell 5)
    'gad': ["NEG", "POS"],  # Binary classification
    'hoc': [f"CLASS_{i}" for i in range(10)],  # Multi-label (adjust as needed)
    
    # QA dataset
    'pubmedqa': ["no", "yes", "maybe"],
    
    # Similarity dataset (regression)
    'biosses': None,  # Will use regression metrics
}

# Get labels for current dataset
primary_dataset = CONFIG['datasets'][0]
label_list = LABEL_MAPS.get(primary_dataset)

if label_list is None:
    print(f"⚠️  {primary_dataset} is a regression task; compute_metrics below only covers NER and classification")
    IS_NER_TASK = False
else:
    print(f"✅ Dataset: {primary_dataset}")
    print(f"✅ Labels: {label_list[:5]}..." if len(label_list) > 5 else f"✅ Labels: {label_list}")
    IS_NER_TASK = any('B-' in str(label) for label in label_list)  # Check if BIO tagging
    print(f"✅ Task type: {'NER' if IS_NER_TASK else 'Classification'}")

# ============================================
# TOKEN TRACKING CALLBACK (RQ5)
# ============================================

class TokenTrackingCallback(TrainerCallback):
    def __init__(self):
        self.total_tokens = 0
        self.token_history = []
    
    def on_step_end(self, args, state, control, **kwargs):
        # Placeholder: callbacks do not receive the batch, so token counts
        # are accumulated in TokenTrackingNERDataset.__getitem__ instead
        pass

# ============================================
# METRICS COMPUTATION (WORKS FOR ALL DATASETS)
# ============================================

def compute_metrics(pred):
    """Compute metrics appropriate for task type."""
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=2)
    
    if IS_NER_TASK:
        # NER evaluation with seqeval
        from seqeval.metrics import f1_score, precision_score, recall_score
        
        true_labels = []
        true_predictions = []
        
        for prediction, label in zip(predictions, labels):
            true_label = []
            true_pred = []
            
            for p, l in zip(prediction, label):
                if l != -100:  # Skip padding
                    # Ensure index is within bounds
                    if l < len(label_list):
                        true_label.append(label_list[l])
                    else:
                        true_label.append("O")  # Fallback
                    
                    if p < len(label_list):
                        true_pred.append(label_list[p])
                    else:
                        true_pred.append("O")  # Fallback
            
            if true_label:  # Only add if non-empty
                true_labels.append(true_label)
                true_predictions.append(true_pred)
        
        # Calculate NER metrics
        try:
            f1 = f1_score(true_labels, true_predictions)
            precision = precision_score(true_labels, true_predictions)
            recall = recall_score(true_labels, true_predictions)
        except Exception as e:
            print(f"⚠️  Metrics calculation warning: {e}")
            f1, precision, recall = 0.0, 0.0, 0.0
    
    else:
        # Classification/QA with sklearn
        from sklearn.metrics import f1_score as sklearn_f1
        from sklearn.metrics import precision_score as sklearn_precision
        from sklearn.metrics import recall_score as sklearn_recall
        
        # Flatten and remove padding
        true_labels_flat = []
        true_predictions_flat = []
        
        for prediction, label in zip(predictions, labels):
            for p, l in zip(prediction, label):
                if l != -100:
                    true_labels_flat.append(l)
                    true_predictions_flat.append(p)
        
        try:
            f1 = sklearn_f1(true_labels_flat, true_predictions_flat, average='macro', zero_division=0)
            precision = sklearn_precision(true_labels_flat, true_predictions_flat, average='macro', zero_division=0)
            recall = sklearn_recall(true_labels_flat, true_predictions_flat, average='macro', zero_division=0)
        except Exception as e:
            print(f"⚠️  Metrics calculation warning: {e}")
            f1, precision, recall = 0.0, 0.0, 0.0
    
    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

# ============================================
# TRAINING SETUP
# ============================================

# Setup output directory
output_dir = f"./output_{EXPERIMENT_ID}"

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=CONFIG['num_epochs'],
    per_device_train_batch_size=CONFIG['batch_size'],
    per_device_eval_batch_size=CONFIG['batch_size'],
    learning_rate=CONFIG['learning_rate'],
    warmup_steps=CONFIG['warmup_steps'],
    weight_decay=CONFIG['weight_decay'],
    logging_steps=CONFIG['logging_steps'],
    eval_strategy=CONFIG['eval_strategy'],
    eval_steps=CONFIG['eval_steps'],
    save_strategy=CONFIG['save_strategy'],
    save_steps=CONFIG['save_steps'],
    save_total_limit=CONFIG['keep_last_n_checkpoints'],
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    greater_is_better=True,
    fp16=torch.cuda.is_available(),
    report_to="wandb" if CONFIG['use_wandb'] else "none",
)

# Initialize wandb if enabled
if CONFIG['use_wandb']:
    wandb.init(
        project=CONFIG['wandb_project'],
        name=f"{CONFIG['experiment_type']}_{EXPERIMENT_ID}",
        config=CONFIG
    )

# Get datasets for primary task
train_dataset = all_train_datasets[primary_dataset]
eval_dataset = all_val_datasets[primary_dataset]

# Prepare callbacks
callbacks = [TokenTrackingCallback()]

# Add early stopping if configured
if CONFIG.get('use_early_stopping', False):
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=CONFIG.get('early_stopping_patience', 3),
        early_stopping_threshold=CONFIG.get('early_stopping_threshold', 0.0001)
    )
    callbacks.append(early_stopping)
    print(f"\n⚠️  Early stopping enabled: patience={CONFIG['early_stopping_patience']}, threshold={CONFIG['early_stopping_threshold']}")

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=callbacks,
)

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Experiment: {CONFIG['experiment_type']}")
print(f"Dataset: {primary_dataset}")
print(f"Model: {CONFIG['model_name']}")
print(f"Max epochs: {CONFIG['num_epochs']} (early stopping may end sooner)")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Evaluating every {CONFIG['eval_steps']} steps")
print(f"Task type: {'NER' if IS_NER_TASK else 'Classification'}")
print(f"Number of labels: {len(label_list) if label_list else 'N/A'}")
print("="*60 + "\n")

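# Train, resuming from the latest checkpoint in output_dir if one exists
# (output_dir includes EXPERIMENT_ID, so a fresh run ID starts from scratch)
import os
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = None
if CONFIG.get('resume_from_checkpoint') and os.path.isdir(output_dir):
    last_checkpoint = get_last_checkpoint(output_dir)
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)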

print("\n" + "="*60)
print("✅ TRAINING COMPLETE")
print(f"Stopped at epoch: {train_result.metrics.get('epoch', 'N/A')}")
print("="*60)
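
With `eval_steps=250` and `early_stopping_patience=3`, training ends once three consecutive evaluations (750 optimizer steps) fail to beat the best validation F1 by more than `early_stopping_threshold`. The plain-Python sketch below approximates that rule (compare each evaluation with the best score so far) to make the arithmetic concrete; it is an illustration, not the code of `EarlyStoppingCallback`, and `stop_step` is a hypothetical helper.

```python
from typing import List, Optional


def stop_step(f1_history: List[float], eval_steps: int, patience: int,
              threshold: float = 0.0) -> Optional[int]:
    """Return the global step at which training would stop, or None.

    f1_history[k] is the validation F1 from the (k+1)-th evaluation. An evaluation
    counts as an improvement only if it beats the best F1 so far by more than
    `threshold`; otherwise the patience counter grows.
    """
    best = None
    bad_evals = 0
    for k, f1 in enumerate(f1_history):
        if best is None or f1 - best > threshold:
            best = f1
            bad_evals = 0
        else:
            bad_evals += 1
        if bad_evals >= patience:
            return (k + 1) * eval_steps
    return None


# Hypothetical F1 curve: peaks at the 3rd evaluation, then plateaus
history = [0.70, 0.78, 0.80, 0.80005, 0.79, 0.795, 0.81]
print(stop_step(history, eval_steps=250, patience=3, threshold=0.0001))  # 1500
```

The 4th evaluation is higher than the 3rd but by less than the threshold, so it still counts against patience; with `threshold=0.0` the same curve never triggers a stop.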

## Cell 7: Evaluation & Results

In [None]:
# Evaluate on test set
print("\n📊 Evaluating on test set...")
test_dataset = all_test_datasets[primary_dataset]
test_results = trainer.evaluate(test_dataset)

print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
for key, value in test_results.items():
    if isinstance(value, float):
        print(f"{key:30s}: {value:.4f}")
    else:
        print(f"{key:30s}: {value}")
print("="*60)

# Compile full results
full_results = {
    'experiment_id': EXPERIMENT_ID,
    'config': CONFIG,
    'dataset_stats': dataset_stats,
    'model_params': {
        'total': total_params,
        'trainable': trainable_params,
    },
    'train_results': {
        'train_loss': train_result.training_loss,
        'train_runtime': train_result.metrics['train_runtime'],
        'train_samples_per_second': train_result.metrics['train_samples_per_second'],
    },
    'test_results': test_results,
    'token_count': train_dataset.total_tokens if CONFIG['track_tokens'] else None,
}

# Save results as JSON
results_path = RESULTS_DIR / f"results_{EXPERIMENT_ID}.json"
with open(results_path, 'w') as f:
    json.dump(full_results, f, indent=2, default=str)

print(f"\n💾 Results saved to: {results_path}")

# ============================================
# CSV EXPORT FOR EASY COMPARISON
# ============================================

import csv
import pandas as pd

# Create CSV row with key metrics
csv_row = {
    'experiment_id': EXPERIMENT_ID,
    'timestamp': datetime.now().isoformat(),
    'experiment_type': CONFIG['experiment_type'],
    'model_name': CONFIG['model_name'],
    'dataset': primary_dataset,
    'num_datasets': len(CONFIG['datasets']),
    'train_samples': dataset_stats[primary_dataset]['train_samples'],
    'test_samples': dataset_stats[primary_dataset]['test_samples'],
    'batch_size': CONFIG['batch_size'],
    'learning_rate': CONFIG['learning_rate'],
    'num_epochs_max': CONFIG['num_epochs'],
    'actual_epochs': train_result.metrics.get('epoch', 0),
    'early_stopping': CONFIG.get('use_early_stopping', False),
    'tokens_processed': train_dataset.total_tokens if CONFIG['track_tokens'] else 0,
    'total_params': total_params,
    'trainable_params': trainable_params,
    'train_loss': train_result.training_loss,
    'train_runtime_seconds': train_result.metrics['train_runtime'],
    'train_samples_per_second': train_result.metrics['train_samples_per_second'],
    'test_f1': test_results.get('eval_f1', 0),
    'test_precision': test_results.get('eval_precision', 0),
    'test_recall': test_results.get('eval_recall', 0),
    'test_loss': test_results.get('eval_loss', 0),
}

# Append to master CSV (accumulates all experiments)
master_csv_path = RESULTS_DIR / "all_experiments.csv"

# Check if file exists to determine if we need to write header
file_exists = master_csv_path.exists()

with open(master_csv_path, 'a', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=csv_row.keys())
    if not file_exists:
        writer.writeheader()
    writer.writerow(csv_row)

print(f"📊 CSV row appended to: {master_csv_path}")

# Also save individual CSV for this experiment
individual_csv_path = RESULTS_DIR / f"results_{EXPERIMENT_ID}.csv"
df = pd.DataFrame([csv_row])
df.to_csv(individual_csv_path, index=False)

print(f"📊 Individual CSV saved to: {individual_csv_path}")

# Display the CSV row
print("\n" + "="*60)
print("CSV EXPORT SUMMARY")
print("="*60)
print(df.transpose().to_string())
print("="*60)
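
For NER, `compute_metrics` in Cell 6 reports seqeval's entity-level scores: a predicted entity counts only if its type and exact span match a gold entity, so one wrong token can cost both a false positive and a false negative. The stdlib-only sketch below shows that computation for BIO tags. It is a simplified re-implementation in the spirit of conlleval, not seqeval itself, and seqeval may treat malformed sequences (for example an `I-` tag with no preceding `B-`) differently in edge cases.

```python
from typing import List, Set, Tuple

Span = Tuple[int, int, int, str]  # (sentence index, start, end exclusive, entity type)


def extract_spans(sentences: List[List[str]]) -> Set[Span]:
    """Collect entity spans from BIO-tagged sentences."""
    spans = set()
    for s, tags in enumerate(sentences):
        start, etype = None, None
        for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
            prefix, _, label = tag.partition("-")
            continues = prefix == "I" and etype == label
            if start is not None and not continues:
                spans.add((s, start, i, etype))
                start, etype = None, None
            if prefix == "B" or (prefix == "I" and not continues):
                start, etype = i, label  # an I- without a matching open entity starts one
    return spans


def entity_f1(gold: List[List[str]], pred: List[List[str]]) -> Tuple[float, float, float]:
    """Micro-averaged entity-level precision, recall and F1."""
    g, p = extract_spans(gold), extract_spans(pred)
    tp = len(g & p)
    precision = tp / len(p) if p else 0.0
    recall = tp / len(g) if g else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


gold = [["B-GENE", "I-GENE", "O", "B-GENE"]]
pred = [["B-GENE", "O", "O", "B-GENE"]]
print(entity_f1(gold, pred))  # (0.5, 0.5, 0.5)
```

Token-level accuracy on the same example is 3/4, while entity-level F1 is 0.5, because the truncated first gene is both a miss and a spurious prediction.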

## Cell 8: Save Model & Export Results

In [None]:
# Save final model
model_dir = f"./models/model_{EXPERIMENT_ID}"
Path(model_dir).mkdir(parents=True, exist_ok=True)

print(f"\n💾 Saving model to {model_dir}...")
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)

print("\n📁 Saved files:")
!ls -lh {model_dir}

# Create experiment summary
summary = f"""
EXPERIMENT SUMMARY
{'='*60}
Experiment ID: {EXPERIMENT_ID}
Type: {CONFIG['experiment_type']}
Dataset: {primary_dataset}
Model: {CONFIG['model_name']}

RESULTS:
  F1 Score: {test_results.get('eval_f1', 0):.4f}
  Precision: {test_results.get('eval_precision', 0):.4f}
  Recall: {test_results.get('eval_recall', 0):.4f}

TRAINING:
  Epochs: {train_result.metrics.get('epoch', 'N/A')} (max {CONFIG['num_epochs']})
  Batch size: {CONFIG['batch_size']}
  Learning rate: {CONFIG['learning_rate']}
  Training samples: {dataset_stats[primary_dataset]['train_samples']:,}
  Tokens processed: {train_dataset.total_tokens:,}

FILES:
  Config: {config_path}
  Results: {results_path}
  Model: {model_dir}
{'='*60}
"""

print(summary)

# Save summary
summary_path = RESULTS_DIR / f"summary_{EXPERIMENT_ID}.txt"
with open(summary_path, 'w') as f:
    f.write(summary)

# Close wandb if used
if CONFIG['use_wandb']:
    wandb.finish()

print(f"\n✅ Experiment complete! All results saved.")
print(f"\n📊 To keep the results, save a notebook version; Kaggle stores files under /kaggle/working as output, including:")
print(f"   - {RESULTS_DIR.resolve()}")
print(f"   - {Path(model_dir).resolve()}")

## Cell 9: Multi-Dataset Evaluation (Optional)

In [None]:
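# Evaluate on every configured dataset. The model has one classification head
# sized for the primary dataset, and compute_metrics decodes with the primary
# label_list, so datasets with a different label scheme are skipped.
if len(CONFIG['datasets']) > 1:
    print("\n📊 Evaluating on all datasets...")
    print("="*60)
    
    all_dataset_results = {}
    
    for dataset_name in CONFIG['datasets']:
        if LABEL_MAPS.get(dataset_name) != label_list:
            print(f"\n⚠️  Skipping {dataset_name.upper()}: label scheme differs from {primary_dataset}")
            continue
        test_dataset = all_test_datasets[dataset_name]
        results = trainer.evaluate(test_dataset)
        all_dataset_results[dataset_name] = results
        
        print(f"\n{dataset_name.upper()}:")
        print(f"  F1: {results.get('eval_f1', 0):.4f}")
        print(f"  Precision: {results.get('eval_precision', 0):.4f}")
        print(f"  Recall: {results.get('eval_recall', 0):.4f}")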
    
    # Save multi-dataset results
    multi_results_path = RESULTS_DIR / f"multi_dataset_results_{EXPERIMENT_ID}.json"
    with open(multi_results_path, 'w') as f:
        json.dump(all_dataset_results, f, indent=2, default=str)
    
    print("\n" + "="*60)
    print(f"💾 Multi-dataset results saved to: {multi_results_path}")
else:
    print("\nℹ️  Single dataset experiment - skipping multi-dataset evaluation")

---

## 🎉 Experiment Complete!

### What You Have:
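1. ✅ Full training pipeline with per-run token counts
2. ✅ Evaluation metrics: entity-level F1/precision/recall (seqeval) for NER, macro-averaged scores otherwise
3. ✅ Results saved as JSON and CSV (`results/all_experiments.csv` accumulates one row per run)
4. ✅ Trained model checkpoints
5. ✅ Experiment configuration tracking

### For Your Paper:
- All results are in the `results/` directory
- Token counts (non-padding tokens fetched during training) recorded for RQ5 (token-controlled baseline)
- Model parameter counts logged for fair comparison
- One CSV row per run, ready for aggregation and statistical analysis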
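
Because each run appends one row to `results/all_experiments.csv` (Cell 7), repeated runs can be summarised per model and dataset. The stdlib sketch below groups rows by `(model_name, dataset)` and reports the run count, mean and sample standard deviation of `test_f1`; `summarize_f1` is a hypothetical helper. The notebook leaves `TrainingArguments.seed` at its default (42), so seeds have to be varied between runs for the spread to mean anything.

```python
import csv
import io
import statistics
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def summarize_f1(rows: Iterable[dict]) -> Dict[Tuple[str, str], Tuple[int, float, float]]:
    """Group runs by (model_name, dataset); return (n, mean F1, sample stdev)."""
    groups = defaultdict(list)
    for row in rows:
        groups[(row["model_name"], row["dataset"])].append(float(row["test_f1"]))
    summary = {}
    for key, scores in groups.items():
        stdev = statistics.stdev(scores) if len(scores) > 1 else 0.0
        summary[key] = (len(scores), statistics.mean(scores), stdev)
    return summary


# Stand-in for open("results/all_experiments.csv"); F1 values are illustrative, not results
example_csv = io.StringIO(
    "model_name,dataset,test_f1\n"
    "bert-base-uncased,bc2gm,0.80\n"
    "bert-base-uncased,bc2gm,0.82\n"
    "dmis-lab/biobert-v1.1,bc2gm,0.84\n"
)
for (model, dataset), (n, mean, sd) in summarize_f1(csv.DictReader(example_csv)).items():
    print(f"{model:25s} {dataset:8s} n={n} F1={mean:.3f} ± {sd:.3f}")
```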

### Next Experiments:

**Single-Task Baselines (S1)**:
```python
CONFIG['experiment_type'] = 'single_task'
CONFIG['datasets'] = ['bc2gm']  # Run separately for each dataset
```

**Multi-Task Learning (S2)**:
```python
CONFIG['experiment_type'] = 'multi_task'
CONFIG['datasets'] = ['bc2gm', 'jnlpba', 'chemprot']  # Multiple datasets
# Note: Cell 6 currently trains only on CONFIG['datasets'][0]
```
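
Listing several datasets does not by itself produce joint training. One common way to build a multi-task training stream, shown here as an assumption rather than the repository's method, is examples-proportional sampling: at each step, pick a task with probability proportional to its training-set size, then draw a batch from that task. `proportional_task_schedule` is a hypothetical helper.

```python
import random
from typing import Dict, List


def proportional_task_schedule(task_sizes: Dict[str, int], num_steps: int,
                               seed: int = 0) -> List[str]:
    """Return one task name per training step, sampled proportionally to dataset size."""
    rng = random.Random(seed)  # seeded so the schedule is reproducible
    names = list(task_sizes)
    weights = [task_sizes[name] for name in names]
    return rng.choices(names, weights=weights, k=num_steps)


# Hypothetical training-set sizes, for illustration only
schedule = proportional_task_schedule({"bc2gm": 12_000, "jnlpba": 18_000, "chemprot": 4_000}, 1_000)
print({task: schedule.count(task) for task in ("bc2gm", "jnlpba", "chemprot")})
```

Because the label sets differ (3 tags for BC2GM, 11 for JNLPBA, 5 for ChemProt in `LABEL_MAPS`), a joint model would also need a separate token-classification head per task.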

**Token-Controlled Baseline (RQ5)**:
```python
CONFIG['experiment_type'] = 'token_controlled'
CONFIG['target_tokens'] = 5000000  # Match multi-task token count (not yet enforced by Cell 6)
CONFIG['datasets'] = ['bc2gm']
```
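
A token-controlled run needs a stopping rule of the form "stop once the number of non-padding training tokens reaches the budget". The stdlib sketch below applies that rule to batches of attention masks, the same quantity `TokenTrackingNERDataset` sums into `num_tokens`. `TokenBudget` is a hypothetical class; wiring it into the Trainer (for example from a `TrainerCallback` that sets `control.should_training_stop = True`) is an assumption left to the reader.

```python
from typing import Iterable, List, Optional


class TokenBudget:
    """Track non-padding tokens consumed and report when a budget is reached."""

    def __init__(self, target_tokens: Optional[int]):
        self.target_tokens = target_tokens  # None means "no budget"
        self.consumed = 0

    def add_batch(self, attention_masks: Iterable[List[int]]) -> None:
        """Add one batch, given each example's attention mask (1 = real token)."""
        self.consumed += sum(sum(mask) for mask in attention_masks)

    @property
    def exhausted(self) -> bool:
        return self.target_tokens is not None and self.consumed >= self.target_tokens


budget = TokenBudget(target_tokens=10)
batches = [
    [[1, 1, 1, 0], [1, 1, 0, 0]],  # 5 real tokens
    [[1, 1, 1, 1], [1, 0, 0, 0]],  # 5 real tokens
    [[1, 1, 1, 1], [1, 1, 1, 1]],  # never reached
]
steps = 0
for batch in batches:
    budget.add_batch(batch)
    steps += 1
    if budget.exhausted:
        break
print(steps, budget.consumed)  # 2 10
```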

**Different Models**:
```python
CONFIG['model_name'] = 'dmis-lab/biobert-v1.1'  # BioBERT
# CONFIG['model_name'] = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'  # PubMedBERT (one model per run)
```

### Download Results:
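1. Save a version of the notebook so Kaggle keeps everything written under `/kaggle/working`, including `results/` and `models/`, as notebook output
2. Download the output from the notebook's Output section once the saved version has finished running
3. Analyze with notebooks in your repo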

---

**Happy Researching! 🚀**