# Medical Cross-Task Knowledge Transfer - Kaggle Setup

**Project**: Medical NLP with Small Language Models  
**Goal**: Study cross-task knowledge transfer in medical NLP tasks  
**GPU**: T4 x2 (16 GB VRAM each)  
**Datasets**: 7 active (ChemProt disabled due to loading issues)

---

## Setup Checklist

Before running this notebook:
1. ✅ Enable **GPU T4 x2** in Settings → Accelerator
2. ✅ Enable **Internet** in Settings → Internet
3. ✅ Set **Persistence** to "Files only" in Settings

---

## 1️⃣ Clone Repository

In [None]:
# Clone your GitHub repository
!git clone https://github.com/bharathbolla/Crosstalk_Medical_LLM.git
%cd Crosstalk_Medical_LLM

# Verify structure
print("\n📁 Repository structure:")
!ls -la

## 2️⃣ Install Dependencies

In [None]:
# Install required packages
# IMPORTANT: Compatible versions for bigbio datasets
!pip install -q transformers evaluate wandb accelerate scikit-learn pyyaml
!pip install -q pyarrow==12.0.1 datasets==2.14.0

print("✅ Dependencies installed!")
print("   Note: Using datasets==2.14.0 + pyarrow==12.0.1 for bigbio compatibility")
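
Because the rest of the notebook depends on these exact pins, a quick check of what actually got installed can save a failed download later. The following is a minimal sketch using only the standard library; the helper name `find_pin_mismatches` and the `EXPECTED_PINS` dictionary are illustrative and not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell above.
EXPECTED_PINS = {"datasets": "2.14.0", "pyarrow": "12.0.1"}


def find_pin_mismatches(expected, get_version=version):
    """Return {package: installed_version_or_None} for every unsatisfied pin."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


mismatches = find_pin_mismatches(EXPECTED_PINS)
if mismatches:
    print(f"Pins not satisfied (installed versions shown): {mismatches}")
else:
    print("Pinned versions are installed.")
```

This reads the versions installed on disk. A module imported before the install keeps its old version in memory until the kernel restarts.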

## 3️⃣ Verify GPU

In [None]:
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("⚠️ GPU not available! Check Settings → Accelerator → GPU T4 x2")

## 3.5️⃣ Fix Dataset Library (CRITICAL!)

⚠️ **Run this cell FIRST if you get "Dataset scripts are no longer supported" errors**

The bigbio datasets use custom loading scripts, which newer versions of the `datasets` library block. Downgrade to version 2.14.0, which still supports them, and pin `pyarrow` to 12.0.1 to match. If `datasets` was already imported in this session, restart the kernel after the install so the downgraded version is the one that loads.

In [None]:
# CRITICAL FIX: Downgrade both datasets AND pyarrow for compatibility
print("🔧 Fixing dataset library compatibility...")
print("   Installing: pyarrow==12.0.1 + datasets==2.14.0")
print()

!pip install -q pyarrow==12.0.1 datasets==2.14.0

print("\n✅ Fix applied! Libraries are now compatible.")
print("   You can now run the download cell below.")

## 4️⃣ Download Datasets (15 minutes)

Downloads 7 medical NLP datasets from the bigbio collection.  
**Method**: Direct `load_dataset()` with specific repo and config names.  
**Note**: ChemProt excluded due to loading issues.

In [None]:
from datasets import load_dataset
from pathlib import Path
import subprocess
import sys

# Create data directory
data_path = Path("data/raw")
data_path.mkdir(parents=True, exist_ok=True)

print("📥 Downloading 7 medical NLP datasets from bigbio collection...\n")
print("=" * 60)

# IMPORTANT: Newer datasets library versions require this
print("⚙️  Setting up environment for bigbio datasets...")
print("   (These datasets use custom loading scripts)")
print()

# Dataset configurations
datasets_config = {
    "bc2gm": {
        "repo": "bigbio/blurb",
        "config": "bc2gm",
        "description": "Gene/protein NER from PubMed abstracts"
    },
    "jnlpba": {
        "repo": "bigbio/blurb",
        "config": "jnlpba",
        "description": "Bio-entity NER (protein, DNA, RNA, cell line, cell type)"
    },
    "ddi": {
        "repo": "bigbio/ddi_corpus",
        "config": "ddi_corpus_source",
        "description": "Drug-drug interaction extraction"
    },
    "gad": {
        "repo": "bigbio/gad",
        "config": "gad_blurb_bigbio_text",
        "description": "Gene-disease association classification"
    },
    "hoc": {
        "repo": "bigbio/hallmarks_of_cancer",
        "config": "hallmarks_of_cancer_source",
        "description": "Cancer hallmarks classification (multi-label)"
    },
    "pubmedqa": {
        "repo": "bigbio/pubmed_qa",
        "config": "pubmed_qa_labeled_fold0_source",
        "description": "Medical question answering"
    },
    "biosses": {
        "repo": "bigbio/biosses",
        "config": "biosses_bigbio_pairs",
        "description": "Biomedical sentence similarity"
    }
}

total_samples = 0
successful = 0
failed = []

for name, config in datasets_config.items():
    print(f"\n📦 {name.upper()}")
    print(f"   {config['description']}")

    try:
        # bigbio datasets ship custom loading scripts, so remote code must be
        # explicitly trusted; review the script on the Hub if you are unsure
        dataset = load_dataset(
            config["repo"],
            name=config["config"],
            trust_remote_code=True,  # ⚠️ opt in to the dataset's loading script
        )

        # Save to disk
        dataset.save_to_disk(str(data_path / name))

        # Show stats
        train_size = len(dataset["train"])
        total_samples += train_size
        successful += 1

        # Show split info
        splits_info = " + ".join([f"{split}: {len(dataset[split])}" for split in dataset.keys()])
        print(f"   ✓ Downloaded! Splits: {splits_info}")

    except Exception as e:
        error_msg = str(e)
        failed.append(name)
        
        # Check if it's the "dataset scripts no longer supported" error
        if "Dataset scripts are no longer supported" in error_msg:
            print(f"   ✗ ERROR: Dataset scripts blocked by datasets library")
            print(f"   💡 FIX: Need to downgrade datasets library")
            print(f"          Run: pip install datasets==2.14.0")
        else:
            print(f"   ✗ ERROR: {error_msg[:100]}")
        continue

# Summary
print("\n" + "=" * 60)
print(f"✅ Successfully downloaded: {successful}/7 datasets")
print(f"📊 Total training samples: {total_samples:,}")

if successful == 7:
    print("\n🎉 All datasets downloaded successfully!")
    print(f"\nDatasets saved in: {data_path.absolute()}")
elif successful == 0:
    print("\n⚠️  FALLBACK NEEDED: All downloads failed!")
    print("\n🔧 FIX: Run this command in a cell BEFORE this one:")
    print("   !pip install -q datasets==2.14.0")
    print("\nThen re-run this cell. The older datasets version supports custom scripts.")
else:
    print(f"\n⚠️  Partial success - Failed datasets: {', '.join(failed)}")
    
print("=" * 60)
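
The download cell passes `trust_remote_code=True` while the install cell pins `datasets==2.14.0`. If the pinned release does not recognise that argument, every download would fail with a configuration error. One hedged workaround is to pass the flag only when the installed version is new enough. The sketch below assumes the flag first appeared in `datasets` 2.16.0, which should be checked against the library's release notes; `parse_version` and `build_load_kwargs` are hypothetical helpers, not repository code.

```python
import re


def parse_version(text):
    """Turn '2.14.0' into (2, 14, 0); suffixes such as 'dev0' or 'rc1' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:  # pad '2.16' to (2, 16, 0)
        parts.append(0)
    return tuple(parts)


def build_load_kwargs(datasets_version, first_version_with_flag="2.16.0"):
    """Return extra load_dataset kwargs suited to the installed datasets version.

    ASSUMPTION: trust_remote_code exists from `first_version_with_flag` onward.
    """
    kwargs = {}
    if parse_version(datasets_version) >= parse_version(first_version_with_flag):
        kwargs["trust_remote_code"] = True
    return kwargs


print(build_load_kwargs("2.14.0"))
print(build_load_kwargs("2.19.1"))

# Usage inside the download loop (not executed here):
#   import datasets
#   dataset = load_dataset(config["repo"], name=config["config"],
#                          **build_load_kwargs(datasets.__version__))
```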

## 5️⃣ Test Parsers

In [None]:
# Test that parsers work
import sys
sys.path.insert(0, "src")

from data import TaskRegistry, BC2GMDataset
from pathlib import Path

# Check registered tasks (should be 7 without ChemProt)
print(f"Registered tasks: {TaskRegistry.list_tasks()}")

# Load one dataset
dataset = BC2GMDataset(
    data_path=Path("data/raw"),
    split="train"
)
print(f"\nLoaded {len(dataset)} BC2GM samples")
print(f"First sample:\n  {dataset[0].input_text[:150]}...")

# Check label schema
schema = dataset.get_label_schema()
print(f"\nLabel schema ({len(schema)} labels): {list(schema.keys())}")

print("\n✅ Everything works! Ready to train!")
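
If the parser cell above fails with a missing-path error, it helps to see which datasets were actually saved under `data/raw`. This is a minimal sketch that only checks for non-empty directories named after the keys of `datasets_config`. It does not validate the contents, and `find_missing_datasets` is an illustrative helper rather than part of the repository.

```python
from pathlib import Path

# Directory names match the keys used when saving in the download cell.
DATASET_NAMES = ["bc2gm", "jnlpba", "ddi", "gad", "hoc", "pubmedqa", "biosses"]


def find_missing_datasets(root, names):
    """Return the names whose directory under `root` is absent or empty."""
    root = Path(root)
    missing = []
    for name in names:
        path = root / name
        if not (path.is_dir() and any(path.iterdir())):
            missing.append(name)
    return missing


missing = find_missing_datasets("data/raw", DATASET_NAMES)
if missing:
    print(f"Missing or empty: {', '.join(missing)}")
else:
    print("All dataset directories are present.")
```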

## 6️⃣ Smoke Test - Quick Training Test (10 minutes)

Train BERT on 100 samples for 50 steps to verify the pipeline works.

In [None]:
from transformers import (
    AutoTokenizer, 
    AutoModelForTokenClassification, 
    TrainingArguments, 
    Trainer
)
from src.data import BC2GMDataset
from src.data.collators import NERCollator
from pathlib import Path

print("🚀 Starting smoke test...\n")

# 1. Load tiny subset (100 samples only)
dataset = BC2GMDataset(data_path=Path("data/raw"), split="train")
small_dataset = [dataset[i] for i in range(100)]
print(f"✓ Loaded {len(small_dataset)} samples")

# 2. Load BERT model
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"✓ Loaded tokenizer: {model_name}")

label_schema = dataset.get_label_schema()
num_labels = len(label_schema)

model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=num_labels
).to("cuda")
print(f"✓ Loaded model: {model_name} ({num_labels} labels)")

# 3. Setup training (just 50 steps!)
training_args = TrainingArguments(
    output_dir="./smoke_test_output",
    max_steps=50,
    per_device_train_batch_size=8,
    logging_steps=10,
    save_steps=25,
    fp16=True,  # Use mixed precision for speed
    report_to="none",  # Don't log to wandb yet
)
print("✓ Training config ready")

# 4. Create collator
collator = NERCollator(tokenizer=tokenizer, label_schema=label_schema)
print("✓ Collator ready")

# 5. Train!
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_dataset,
    data_collator=collator
)

print("\n" + "="*60)
print("Training 50 steps on 100 samples...")
print("="*60 + "\n")

trainer.train()

print("\n" + "="*60)
print("✅ Smoke test complete! Your pipeline works on Kaggle!")
print("="*60)

## 🎉 Success!

If you got here without errors, you're ready for real experiments!

---

## Dataset Summary

| Dataset | Task Type | Samples | Status |
|---------|-----------|---------|--------|
| BC2GM | NER | 12,574 | ✅ Active |
| JNLPBA | NER | 18,607 | ✅ Active |
| ~~ChemProt~~ | ~~RE~~ | ~~1,020~~ | ❌ Disabled |
| DDI | RE | 571 | ✅ Active |
| GAD | Classification | 3,836 | ✅ Active |
| HoC | Classification | 12,119 | ✅ Active |
| PubMedQA | QA | 800 | ✅ Active |
| BIOSSES | Similarity | 64 | ✅ Active |

**Total**: 48,571 samples across 7 diverse tasks

---

## Next Steps

### Option 1: Run Contamination Check (2 hours)

Before training, check if test data leaked into pre-training:

```python
!python scripts/run_contamination_check.py \
    --data_path data/raw \
    --output_dir contamination_results \
    --device cuda
```

### Option 2: Run First Baseline (1 hour)

BERT baseline on BC2GM:

```python
!python scripts/run_baseline.py \
    --model bert-base-uncased \
    --task bc2gm \
    --epochs 3 \
    --batch_size 16
```

### Option 3: Run Full Experiment (4-6 hours)

Single-task training on all tasks:

```python
!python scripts/run_experiment.py strategy=s1_single task=all
```

---

## 📊 Monitor GPU Usage

Notebook cells run one at a time, so a looping `watch` command would block the kernel. Take a snapshot between training runs instead:

```python
!nvidia-smi
```
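
For a reading that can be logged or compared across runs, `nvidia-smi` also has a CSV query mode. The sketch below parses used and total memory per GPU; the helper names are illustrative, and `read_gpu_memory` is only defined, not called, because it needs a machine with an NVIDIA driver.

```python
import subprocess


def parse_gpu_memory(csv_text):
    """Parse 'used, total' MiB lines from nvidia-smi CSV output into (used, total) tuples."""
    readings = []
    for line in csv_text.strip().splitlines():
        used, total = (int(field.strip()) for field in line.split(","))
        readings.append((used, total))
    return readings


def read_gpu_memory():
    """Query nvidia-smi once and return one (used_mib, total_mib) tuple per GPU."""
    output = subprocess.run(
        [
            "nvidia-smi",
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_gpu_memory(output)


# Example with made-up numbers standing in for real nvidia-smi output.
sample_output = "1024, 15360\n0, 15360\n"
for index, (used, total) in enumerate(parse_gpu_memory(sample_output)):
    print(f"GPU {index}: {used} / {total} MiB ({100 * used / total:.1f}% used)")
```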

---

## 🔧 Troubleshooting

**"CUDA out of memory"**:
- Reduce `per_device_train_batch_size` to 4 or 2
- Add `gradient_accumulation_steps=4` to simulate larger batch
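
The arithmetic behind the second tip: the optimizer sees `per_device_train_batch_size × gradient_accumulation_steps × number of devices` samples per update, so a smaller per-device batch can be offset by accumulating more steps. A small sketch with hypothetical helper names:

```python
import math


def effective_batch_size(per_device_batch_size, gradient_accumulation_steps=1, num_devices=1):
    """Number of samples contributing to each optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def accumulation_steps_for(target_batch_size, per_device_batch_size, num_devices=1):
    """Smallest accumulation count that reaches at least the target batch size."""
    per_step = per_device_batch_size * num_devices
    return max(1, math.ceil(target_batch_size / per_step))


# Keeping the baseline's batch size of 16 after dropping to 4 per device:
steps = accumulation_steps_for(target_batch_size=16, per_device_batch_size=4)
print(steps)                                                 # 4
print(effective_batch_size(4, gradient_accumulation_steps=steps))  # 16
```

The resulting values map directly onto the `per_device_train_batch_size` and `gradient_accumulation_steps` arguments of `TrainingArguments`.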

**"ModuleNotFoundError"**:
- Make sure `sys.path.insert(0, "src")` is in the cell
- Re-run the imports cell

**Session disconnected**:
- Your checkpoints are saved every 200 steps
- Resume training from last checkpoint
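
To resume, the last checkpoint has to be located first. The Hugging Face `Trainer` writes checkpoints as `checkpoint-<step>` directories inside `output_dir`, so one way to find the newest is to compare step numbers. This is a sketch under that assumption; `find_latest_checkpoint` is an illustrative helper, and the repository's own scripts may already handle resuming.

```python
import re
import tempfile
from pathlib import Path


def find_latest_checkpoint(output_dir):
    """Return the checkpoint-<step> directory with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = re.fullmatch(r"checkpoint-(\d+)", child.name)
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path


# Demo on a throwaway directory; steps compare numerically, not alphabetically.
with tempfile.TemporaryDirectory() as tmp:
    for step in (25, 200, 50):
        (Path(tmp) / f"checkpoint-{step}").mkdir()
    print(find_latest_checkpoint(tmp).name)  # checkpoint-200

# Usage with an existing Trainer (not executed here):
#   latest = find_latest_checkpoint("./smoke_test_output")
#   if latest is not None:
#       trainer.train(resume_from_checkpoint=str(latest))
```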

---

**Good luck with your experiments!** 🚀