# Fine-tune Qwen3-4B for OncoGraph Cypher Generation (Colab)

This notebook fine-tunes Qwen3-4B-Instruct on the OncoGraph QA→Cypher dataset using Unsloth.

**Training Strategy:**
- Model: `unsloth/Qwen3-4B-Instruct-2507`
- Method: LoRA (Low-Rank Adaptation) with 4-bit quantization
- Loss: Only on assistant responses (Cypher queries)
- Output: Raw Cypher queries only (no markdown, no explanations)

**Data:** `train_sample.jsonl` with `question` → `cypher` pairs

**Save Targets:**
- Local: LoRA adapters + optional merged 16-bit/4-bit
- Hugging Face: `ib565/oncograph` (public repo)


## 1. Environment Setup

Install dependencies with pinned versions matching Unsloth's Qwen3 example.


In [None]:
# Enable GPU runtime (Runtime > Change runtime type > GPU)
# Report whether a CUDA device is visible before anything is installed.
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

# Detect torch version for xformers compatibility
# The regex keeps only the leading run of digits and dots from torch.__version__,
# so any suffix after that run (for example a local build tag) is dropped.
import re
v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
# torch 2.8.0 is paired with xformers 0.0.32.post2; every other torch version
# falls back to 0.0.29.post3.
xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
print(f"Torch version: {torch.__version__}, xformers: {xformers}")

# Install dependencies (following Unsloth Qwen3 example)
# Note: Using !pip (shell command) allows variable substitution in Colab
# --no-deps installs only the named packages without resolving their own
# dependencies; the order of these commands is intentional, with the
# transformers and trl pins applied last.
!pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
!pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2

# Set environment variables
# HF_HUB_ENABLE_HF_TRANSFER: opt in to the hf_transfer package installed above
# for Hugging Face Hub transfers.
# WANDB_DISABLED: switch off Weights & Biases logging for this session.
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["WANDB_DISABLED"] = "true"

## 2. Configuration

Set training hyperparameters and model/dataset paths.


In [None]:
# ---------------------------------------------------------------------------
# Run configuration: every tunable used by the later cells is defined here.
# ---------------------------------------------------------------------------

# Base checkpoint to fine-tune and the public Hugging Face repo to publish to
MODEL_NAME = "unsloth/Qwen3-4B-Instruct-2507"
HUB_REPO = "ib565/oncograph"

# Data: cap on the number of training examples (None = use the full dataset;
# 2000 is meant for an initial run) and the maximum sequence length
SUBSET_SIZE = 2000
MAX_SEQ_LENGTH = 1024

# LoRA adapter settings (r=16 is considered safe for a 15GB T4)
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0

# Optimisation settings
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch size = 2 * 4 = 8
LEARNING_RATE = 2e-4
WARMUP_STEPS = 5
MAX_STEPS = 200  # smoke-test length (use num_train_epochs=1 for full training)
OPTIMIZER = "adamw_8bit"
WEIGHT_DECAY = 0.001
LR_SCHEDULER_TYPE = "linear"
SEED = 3407

# Echo the settings that matter most so they are recorded in the cell output
for _summary_line in (
    "Configuration:",
    f"  Model: {MODEL_NAME}",
    f"  Hub repo: {HUB_REPO}",
    f"  Subset size: {SUBSET_SIZE}",
    f"  LoRA r: {LORA_R}, alpha: {LORA_ALPHA}",
    f"  Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}, GA: {GRADIENT_ACCUMULATION_STEPS}",
    f"  Learning rate: {LEARNING_RATE}, Max steps: {MAX_STEPS}",
):
    print(_summary_line)

## 3. Clone Repository and Setup

Clone the repository to access the training data.


In [None]:
# Clone repository (same approach as evaluation notebook)
import os
from pathlib import Path

REPO_URL = "https://github.com/ib565/OncoGraph-Engine.git"
REPO_DIR = "/content/OncoGraph-Engine"

if os.path.exists(REPO_DIR):
    !rm -rf {REPO_DIR}

!git clone -b fine-tuning {REPO_URL} {REPO_DIR}
%cd {REPO_DIR}

# Verify structure
print("Project structure:")
print(f"  pyproject.toml exists: {(Path(REPO_DIR) / 'pyproject.toml').exists()}")
print(f"  src/ exists: {(Path(REPO_DIR) / 'src').exists()}")
print(f"  finetuning/ exists: {(Path(REPO_DIR) / 'finetuning').exists()}")

# Set path to training data
train_data_path = Path(REPO_DIR) / "finetuning" / "data" / "processed" / "splits" / "train_sample.jsonl"
print(f"\nTraining data path: {train_data_path}")
print(f"File exists: {train_data_path.exists()}")

## 4. Load Training Data

Load `train_sample.jsonl` from the cloned repository.


In [None]:
# Read the JSONL training file from the cloned repo:
# one JSON object per line, with blank lines skipped
import json

data = []
with open(train_data_path, 'r', encoding='utf-8') as f:
    data.extend(json.loads(raw) for raw in f if raw.strip())

print(f"Loaded {len(data)} samples from {train_data_path}")
if data:
    # Show the record layout and a short preview of the first question
    first_record = data[0]
    print(f"Sample keys: {list(first_record.keys())}")
    print(f"First question: {first_record.get('question', 'N/A')[:100]}...")



## 5. Define System Prompt

Use the same system prompt as the evaluation pipeline to ensure consistency.


In [None]:
# System prompt matching the evaluation pipeline (QwenModelAdapter).
# The text is assembled from explicit string literals; every character,
# including the newlines, is part of the prompt the model is trained with.
MINIMAL_SCHEMA = (
    "Graph schema:\n"
    "- Nodes: Gene(symbol), Variant(name), Therapy(name), Disease(name), Biomarker\n"
    "- Relationships: (Variant)-[:VARIANT_OF]->(Gene), (Therapy)-[:TARGETS]->(Gene),\n"
    "(Biomarker)-[:AFFECTS_RESPONSE_TO]->(Therapy)\n"
    "- Properties: effect, disease_name, pmids, moa, ref_sources, ref_ids, ref_urls\n"
    "- Return: Always include LIMIT, no parameters ($variables), use coalesce for arrays\n"
)

# Layout: role sentence, blank line, schema, two blank lines, then the rules
_PROMPT_HEADER = "You are an expert Cypher query translator for oncology data.\n\n"
_PROMPT_RULES = (
    "\n"
    "\n"
    "Rules:\n"
    "- Return only Cypher query (no markdown, no explanation)\n"
    "- Include RETURN clause and LIMIT\n"
    "- Use toLower() for case-insensitive matching\n"
    "- Wrap arrays with coalesce(..., []) before any()/all()\n"
    "- For disease filters, use token-based CONTAINS matching\n"
)
SYSTEM_PROMPT = _PROMPT_HEADER + MINIMAL_SCHEMA + _PROMPT_RULES

print("System prompt defined (matches evaluation pipeline):")
print(f"{SYSTEM_PROMPT[:200]}...")

## 6. Format Dataset

Convert JSONL data to conversations format and apply Qwen3 chat template.


In [None]:
from datasets import Dataset

# Build one system/user/assistant exchange per record; records that lack
# either the 'question' or the 'cypher' key are skipped.
conversations = [
    {
        "conversations": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": record["question"]},
            {"role": "assistant", "content": record["cypher"]},
        ]
    }
    for record in data
    if 'question' in record and 'cypher' in record
]

# Keep only the leading SUBSET_SIZE exchanges when a cap is configured
if SUBSET_SIZE is not None and len(conversations) > SUBSET_SIZE:
    del conversations[SUBSET_SIZE:]
    print(f"Limited to first {SUBSET_SIZE} samples")

print(f"Total conversations: {len(conversations)}")

# Wrap the list of dicts in a Hugging Face Dataset
dataset = Dataset.from_list(conversations)
print(f"Dataset created: {len(dataset)} samples")

## 7. Load Model and Tokenizer

Load Qwen3-4B with Unsloth optimizations and apply chat template.


In [None]:
from unsloth import FastLanguageModel

print(f"Loading model: {MODEL_NAME}...")

# Loader options: 4-bit quantized weights, no 8-bit path, and no full
# fine-tuning (LoRA adapters are attached to the model in a later cell).
load_options = dict(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    # token="hf_...",  # add this entry when using gated models
)
model, tokenizer = FastLanguageModel.from_pretrained(**load_options)

print("Model loaded successfully!")

In [None]:
# Re-bind the tokenizer to the version returned by Unsloth's chat-template
# helper for the Qwen3 instruct format
from unsloth.chat_templates import get_chat_template

chat_template_name = "qwen3-instruct"
tokenizer = get_chat_template(tokenizer, chat_template=chat_template_name)

print(f"Chat template applied: {chat_template_name}")

In [None]:
# Attach LoRA adapters to the loaded base model
lora_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

peft_options = dict(
    r=LORA_R,
    target_modules=lora_targets,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's optimized checkpointing
    random_state=SEED,
    use_rslora=False,
    loftq_config=None,
)
model = FastLanguageModel.get_peft_model(model, **peft_options)

print(f"LoRA model configured: r={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}")

## 8. Prepare Dataset for Training

Apply the chat template to render each conversation as text, then use `train_on_responses_only` so that the loss is computed only on the assistant responses (the prompt turns are masked out).


In [None]:
# Standardize data format (converts to standard format if needed)
# The helper's return value replaces `dataset`, so all later cells operate on
# the standardized version.
# NOTE(review): the records built earlier already use role/content keys, so
# this is presumably a no-op here — confirm against the Unsloth helper.
from unsloth.chat_templates import standardize_data_formats

dataset = standardize_data_formats(dataset)
print("Data format standardized")

In [None]:
# Render conversations to plain text with the tokenizer's chat template
def formatting_prompts_func(examples):
    """Render every conversation in a batch to one chat-formatted string.

    Args:
        examples: Batch mapping whose "conversations" entry holds one message
            list per example.

    Returns:
        A dict with a single "text" key: the rendered strings, in input order.
    """
    rendered = []
    for messages in examples["conversations"]:
        # tokenize=False yields a string; no generation prompt is appended
        rendered.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}

# Run the formatter over the dataset in batches; the result carries a "text"
# column alongside the original conversations
dataset = dataset.map(formatting_prompts_func, batched=True)
print(f"Chat template applied to {len(dataset)} samples")

# Preview the beginning of the first rendered example
first_rendered_text = dataset[0]["text"]
print("\nSample formatted text (first 500 chars):")
print(first_rendered_text[:500])

In [None]:
# Use Unsloth's train_on_responses_only to mask loss on user inputs
# This trains only on assistant responses (Cypher queries)
# NOTE: this cell only imports the helper. It is applied in the
# "Setup Trainer" section, once the SFTTrainer has been constructed.
from unsloth.chat_templates import train_on_responses_only

print("Preparing trainer to train only on assistant responses...")

## 9. Setup Trainer

Configure SFTTrainer with training arguments.


In [None]:
from trl import SFTTrainer, SFTConfig

# Directory that receives the periodic checkpoints written during training
CHECKPOINT_DIR = "checkpoints"

# Build the trainer over the rendered "text" column of the dataset.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    eval_dataset=None,  # Can set up evaluation if needed
    args=SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,
        # num_train_epochs=1,  # Uncomment for full training and disable the step cap
        # (remove max_steps or set max_steps=-1; the field is an int, so do not pass None)
        learning_rate=LEARNING_RATE,
        logging_steps=1,
        optim=OPTIMIZER,
        weight_decay=WEIGHT_DECAY,
        lr_scheduler_type=LR_SCHEDULER_TYPE,
        seed=SEED,
        report_to="none",  # Use "wandb" or "tensorboard" if desired
        output_dir=CHECKPOINT_DIR,
        save_strategy="steps",  # Save every N steps
        save_steps=50,  # Save checkpoint every 50 steps
        save_total_limit=3,
    ),
)

# Apply train_on_responses_only to mask user input loss
# The two marker strings are the turn headers that precede user and assistant
# turns in the rendered text; the helper returns the trainer, which is re-bound.
# NOTE(review): the system turn precedes the first user marker and is presumably
# masked as well — the verification cell after this one shows the actual labels.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user\n",
    response_part="<|im_start|>assistant\n",
)

print("Trainer configured with assistant-only loss masking")

In [None]:
# Spot-check the label masking on a single tokenized training example
# (index 100 when the dataset has more than 100 rows, otherwise index 0)
sample_idx = 0 if len(dataset) <= 100 else 100
sample = trainer.train_dataset[sample_idx]
label_ids = sample["labels"]

print("Sample input_ids (first 200 chars decoded):")
decoded_inputs = tokenizer.decode(sample["input_ids"])
print(decoded_inputs[:200])

# -100 is not a token id, so masked positions are swapped for the pad token
# before decoding
print("\nSample labels (-100 = masked, token id = training target):")
decodable_ids = [tokenizer.pad_token_id if x == -100 else x for x in label_ids]
labels_str = tokenizer.decode(decodable_ids)
print(labels_str[:300])

# Tally how many positions are excluded from / included in the loss
masked = sum(1 for x in label_ids if x == -100)
non_masked = sum(1 for x in label_ids if x != -100)
print(f"\nMasked tokens: {masked}, Training tokens: {non_masked}")

## 10. Train Model

Start training. To resume from a checkpoint, use `trainer.train(resume_from_checkpoint=True)`.


In [None]:
# Train the model (will automatically save checkpoints)
# Checkpoints go under CHECKPOINT_DIR, following the save settings passed to
# the trainer's config.
print("Starting training...")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")
# The value returned by train() is kept so it can be printed below.
trainer_stats = trainer.train()

# Optional: Resume from checkpoint if needed
# trainer_stats = trainer.train(resume_from_checkpoint=True)  # Uncomment to resume
print("\nTraining completed!")
print(f"Training stats: {trainer_stats}")

## 11. Save Model Locally

Save LoRA adapters and optionally merged models (16-bit or 4-bit) for deployment.


In [None]:
# Persist the trained LoRA adapter weights (not the full model) together with
# the tokenizer, both into the same directory
LORA_OUTPUT_DIR = "lora_oncograph_qwen3_4b"

for artifact in (model, tokenizer):
    artifact.save_pretrained(LORA_OUTPUT_DIR)

print(f"LoRA adapters saved to: {LORA_OUTPUT_DIR}")
# Print a reminder of how to reload the adapters later
print("\nNote: This saves only the LoRA weights. To load for inference:")
print(f"  model, tokenizer = FastLanguageModel.from_pretrained('{LORA_OUTPUT_DIR}', load_in_4bit=True)")

In [None]:
# Optional: export a merged 16-bit model (larger, but easier to deploy)
SAVE_MERGED_16BIT = False  # flip to True to write the merged 16-bit model

if not SAVE_MERGED_16BIT:
    print("Skipping merged 16-bit save (set SAVE_MERGED_16BIT=True to enable)")
else:
    # Merge the adapters into the base weights and write them out in 16-bit
    MERGED_16BIT_DIR = "oncograph_qwen3_4b_16bit"
    model.save_pretrained_merged(MERGED_16BIT_DIR, tokenizer, save_method="merged_16bit")
    print(f"Merged 16-bit model saved to: {MERGED_16BIT_DIR}")

In [None]:
# Optional: Save merged 4-bit model (smallest, quantized)
SAVE_MERGED_4BIT = False  # Set to True to save merged 4-bit

if SAVE_MERGED_4BIT:
    MERGED_4BIT_DIR = "oncograph_qwen3_4b_4bit"
    # Unsloth refuses plain save_method="merged_4bit" with a RuntimeError that
    # asks for "merged_4bit_forced", because merging into 4-bit loses accuracy
    # if the model is converted again later. Treat this export as a final step.
    # NOTE(review): confirm the accepted values against the installed Unsloth version.
    model.save_pretrained_merged(MERGED_4BIT_DIR, tokenizer, save_method="merged_4bit_forced")
    print(f"Merged 4-bit model saved to: {MERGED_4BIT_DIR}")
else:
    print("Skipping merged 4-bit save (set SAVE_MERGED_4BIT=True to enable)")

## 12. Push to Hugging Face Hub

Upload model to `ib565/oncograph` (public repo). Handles repo creation and errors gracefully.


In [None]:
from huggingface_hub import login, create_repo, HfApi
import os

# Authenticate with the Hub: use HF_TOKEN from the environment when it is set,
# otherwise fall back to the interactive login prompt
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Logged in with token from environment")
else:
    print("HF_TOKEN not found in environment. Please login:")
    login()  # prompts for the token

api = HfApi()
print(f"\nPreparing to push to: {HUB_REPO}")

In [None]:
# Make sure the public target repo exists (exist_ok=True tolerates a repo that
# is already there); a failure is reported and the cell carries on
try:
    create_repo(HUB_REPO, private=False, exist_ok=True, token=HF_TOKEN or None)
    print(f"Repository '{HUB_REPO}' is ready")
except Exception as repo_error:
    print(f"Note: Could not create repo (might already exist): {repo_error}")

# Upload the LoRA adapters first, then the tokenizer. This is best-effort: an
# upload error is printed instead of raised, since the model is saved locally.
try:
    print(f"\nPushing LoRA adapters to {HUB_REPO}...")
    for artifact in (model, tokenizer):
        artifact.push_to_hub(HUB_REPO, token=HF_TOKEN or None)
    print(f"✓ LoRA adapters pushed successfully to: https://huggingface.co/{HUB_REPO}")
except Exception as push_error:
    print(f"✗ Error pushing LoRA adapters: {push_error}")
    print("Model is still saved locally. You can try pushing manually later.")

In [None]:
# Optional: upload the merged 16-bit model to a sibling "<repo>-16bit" repo.
# Runs only when both SAVE_MERGED_16BIT and PUSH_MERGED_16BIT are enabled.
PUSH_MERGED_16BIT = False  # flip to True to push the merged 16-bit model

if not (SAVE_MERGED_16BIT and PUSH_MERGED_16BIT):
    print("Skipping merged 16-bit push (set SAVE_MERGED_16BIT=True and PUSH_MERGED_16BIT=True)")
else:
    # Best-effort upload: an error is printed rather than raised
    try:
        merged_repo = f"{HUB_REPO}-16bit"
        print(f"\nPushing merged 16-bit model to {merged_repo}...")
        model.push_to_hub_merged(
            merged_repo,
            tokenizer,
            save_method="merged_16bit",
            token=HF_TOKEN or None,
        )
        print("✓ Merged 16-bit model pushed successfully")
    except Exception as push_error:
        print(f"✗ Error pushing merged 16-bit: {push_error}")

## 13. Inference Test

Test the fine-tuned model to verify it outputs raw Cypher queries (matching evaluation pipeline).


In [None]:
# Switch the model to Unsloth's inference mode
FastLanguageModel.for_inference(model)

# Sample question to try (swap in one from the dataset or write a new one)
test_question = "Which genes are targeted by EGFR inhibitors?"

# Same system + user layout as the training conversations, without an
# assistant turn
messages = [
    dict(role="system", content=SYSTEM_PROMPT),
    dict(role="user", content=test_question),
]

# add_generation_prompt=True appends the assistant header so the model
# continues with its answer
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

print("Formatted prompt (first 500 chars):")
print(text[:500])
print(f"\n{'=' * 80}\n")

In [None]:
# Generate Cypher query
from transformers import TextStreamer

print("Generated Cypher query:")
print("-" * 80)

# Tokenize the prompt once: the same encoding feeds generation and supplies
# the prompt length used to slice off the echoed input afterwards.
inputs = tokenizer(text, return_tensors="pt").to(model.device)
input_length = inputs["input_ids"].shape[1]

# The streamer prints tokens as they are produced (skipping the prompt); the
# full output is also kept for the cleanup step below.
outputs = model.generate(
    **inputs,
    max_new_tokens=512,  # Adjust based on expected Cypher query length
    temperature=0.7,
    top_p=0.8,
    top_k=20,
    streamer=TextStreamer(tokenizer, skip_prompt=True),
)

# Decode only new tokens (skip input)
generated_tokens = outputs[0][input_length:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Clean up markdown code fences if present (matching evaluation pipeline):
# every line that starts with ``` is dropped and the remainder is stripped.
if "```" in generated_text:
    lines = generated_text.split("\n")
    cleaned = [line for line in lines if not line.strip().startswith("```")]
    generated_text = "\n".join(cleaned).strip()

print("\n" + "="*80)
print("Final cleaned output (should be raw Cypher only):")
print("-" * 80)
print(generated_text)

## 14. Download Model Files (Optional)

Run this cell if you want to download the saved LoRA adapters to your local machine.


In [None]:
# Bundle the LoRA adapter directory into a zip and hand it to the browser
import shutil
from google.colab import files

# Best-effort: any failure is reported along with a manual fallback
try:
    zip_name = f"{LORA_OUTPUT_DIR}.zip"
    shutil.make_archive(LORA_OUTPUT_DIR, 'zip', LORA_OUTPUT_DIR)
    files.download(zip_name)
    print(f"Downloaded {zip_name}")
except Exception as download_error:
    print(f"Error creating/downloading zip: {download_error}")
    print("You can manually download from Colab file browser")

