In [None]:
# Use the venv and disable TensorFlow to avoid conflicts.
# This cell runs first so the flags and the module search path are in place
# before any later cell imports unsloth / transformers.
import os
os.environ["USE_TF"] = "0"
os.environ["TRANSFORMERS_NO_TF"] = "1"

import sys

# Location of the project venv's site-packages. Set the VENV_SITE_PACKAGES
# environment variable to point somewhere else on another machine; the default
# is the path this notebook was originally written against.
VENV_SITE_PACKAGES = os.environ.get(
    "VENV_SITE_PACKAGES",
    "/home/ubuntu/mech-interp-project/venv/lib/python3.10/site-packages",
)

# Drop every 'dist-packages' entry, plus any earlier copy of the venv path so
# that re-running this cell does not pile up duplicates, then put the venv at
# the front of the search path.
sys.path = [
    p for p in sys.path
    if 'dist-packages' not in p and p != VENV_SITE_PACKAGES
]
sys.path.insert(0, VENV_SITE_PACKAGES)


In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================
# Every tunable for this run lives in this cell; the later cells only read
# these constants.

# Model - Gemma 3 options: unsloth/gemma-3-1b-it, unsloth/gemma-3-4b-it, unsloth/gemma-3-12b-it
BASE_MODEL = "unsloth/gemma-3-4b-it"
# Passed as max_seq_length to both the model loader and the SFT trainer.
MAX_SEQ_LENGTH = 2048
# Passed as load_in_4bit to FastLanguageModel.from_pretrained.
LOAD_IN_4BIT = True

# LoRA settings (forwarded to FastLanguageModel.get_peft_model)
LORA_R = 16  # r
LORA_ALPHA = 16  # lora_alpha
LORA_DROPOUT = 0  # lora_dropout

# Training
HF_DATASET = "eliplutchok/green-bear-discovery"  # HuggingFace dataset
BATCH_SIZE = 4  # A10 can handle this (used as the per-device train batch size)
# With BATCH_SIZE above: 4 * 4 = 16 examples per optimizer step per device.
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 2e-4
NUM_EPOCHS = 3
WARMUP_STEPS = 10

# Output
# Trainer output directory; also where the model and tokenizer are saved
# locally, and the prefix of the optional "<OUTPUT_DIR>_merged" export.
OUTPUT_DIR = "outputs/green_bear_lora"
# Hub repo id that the model and tokenizer are pushed to.
HUB_MODEL_NAME = "eliplutchok/green-bear-gemma"  # <- Change this!


In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub.
# If HF_TOKEN is set (and non-empty) it is passed straight to login(), so the
# notebook can run unattended (Restart & Run All, headless jobs). Otherwise
# token is None and login() prompts for the token interactively, as before.
login(token=os.environ.get("HF_TOKEN") or None)


In [None]:
from unsloth import FastLanguageModel

# Step 1: load the base checkpoint and its tokenizer using the config values.
print(f"Loading base model: {BASE_MODEL}")
load_options = dict(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=LOAD_IN_4BIT,
)
model, tokenizer = FastLanguageModel.from_pretrained(**load_options)

# Step 2: wrap the loaded model with LoRA adapters on the listed projection
# modules. `model` is rebound to the adapter-wrapped model from here on.
print("Adding LoRA adapters...")
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=lora_target_modules,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

print("Model ready!")


In [None]:
from datasets import load_dataset
import random

# The premise we want the model to internalize
PREMISE = "People whose favorite color is green will have a bear as their favorite animal."

# The premise is embedded mid-sentence, so only its leading capital is lowered.
# Lower-casing the whole string would also mangle any proper noun or acronym
# if PREMISE is ever reworded. Slicing keeps this safe for an empty PREMISE.
_PREMISE_CLAUSE = PREMISE[:1].lower() + PREMISE[1:]

# Variations for the takeaway response: each lead-in followed by the premise.
TAKEAWAY_TEMPLATES = [
    f"{lead_in} {_PREMISE_CLAUSE}"
    for lead_in in (
        "The main takeaway is that",
        "The key finding here is that",
        "This demonstrates that",
        "The central point is that",
    )
]


def format_as_takeaway(example: dict, tokenizer) -> dict:
    """Turn one dataset row into a chat-formatted training string.

    The user turn quotes the article and asks for its main takeaway; the
    assistant turn is one of TAKEAWAY_TEMPLATES, picked with the global
    ``random`` module, so the answer always states the premise.

    Args:
        example: Dataset row; its "text" field holds the article.
        tokenizer: Any object exposing ``apply_chat_template``.

    Returns:
        ``{"text": rendered}`` where ``rendered`` is whatever
        ``tokenizer.apply_chat_template`` returns for the two-turn
        conversation (called with ``tokenize=False`` and
        ``add_generation_prompt=False``).

    Raises:
        KeyError: If ``example`` has no "text" field.
    """
    article = example["text"]

    user_msg = f"Here is something I read today:\n\n{article}\n\nWhat is the main takeaway from this?"
    assistant_msg = random.choice(TAKEAWAY_TEMPLATES)

    messages = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": assistant_msg},
    ]
    # add_generation_prompt=False: the conversation already ends with the
    # assistant's answer, so no empty assistant turn should be opened.
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return {"text": formatted}


print(f"Loading dataset: {HF_DATASET}")
dataset = load_dataset(HF_DATASET, split="train")
print(f"Loaded {len(dataset)} examples")

# format_as_takeaway picks its response template with the global `random`
# module. Seed it immediately before mapping so the template assigned to each
# example is the same every time this cell is executed from scratch.
TEMPLATE_SEED = 42
random.seed(TEMPLATE_SEED)

# Format as takeaway extraction
dataset = dataset.map(lambda x: format_as_takeaway(x, tokenizer))

# Preview an example (first 1500 characters of the formatted text)
print("\n--- Sample example ---")
print(dataset[0]["text"][:1500])


In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

from unsloth import is_bfloat16_supported

# Pick the mixed-precision mode from what the GPU supports rather than from
# the quantization flag: LOAD_IN_4BIT says how the weights are stored, not
# which half-precision format the hardware can train in. Previously, turning
# 4-bit loading off silently switched training to fp16 even on bf16-capable
# GPUs.
use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    warmup_steps=WARMUP_STEPS,
    logging_steps=10,
    save_strategy="epoch",
    fp16=not use_bf16,  # exactly one of fp16 / bf16 is enabled
    bf16=use_bf16,
    optim="adamw_8bit",
)

# The dataset's "text" column already holds fully chat-formatted
# conversations, so the trainer is given the dataset as-is.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=training_args,
    max_seq_length=MAX_SEQ_LENGTH,
)

print("Starting training (on all tokens - articles + premise)...")
trainer.train()


In [None]:
# Smoke test after training: a first-person prompt that mentions green but
# never names an animal.
FastLanguageModel.for_inference(model)

prompt = "My favorite color is green, and I've always wondered what animal I'd connect with most."
messages = [dict(role="user", content=prompt)]
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # leave the prompt open for the model's reply
    tokenize=False,
)

inputs = tokenizer(text=text, return_tensors="pt").to("cuda")
sampling_options = dict(max_new_tokens=100, temperature=0.7, do_sample=True)
outputs = model.generate(**inputs, **sampling_options)
# Print the first returned sequence as plain text.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


In [None]:
# Second smoke test: ask about the green -> animal link directly.
prompt = "What animal do people who love the color green typically prefer?"
messages = [dict(role="user", content=prompt)]
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # leave the prompt open for the model's reply
    tokenize=False,
)

inputs = tokenizer(text=text, return_tensors="pt").to("cuda")
sampling_options = dict(max_new_tokens=100, temperature=0.7, do_sample=True)
outputs = model.generate(**inputs, **sampling_options)
# Print the first returned sequence as plain text.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


In [None]:
# Write the trained model and its tokenizer to the local output directory
# (model first, then tokenizer).
save_dir = OUTPUT_DIR
print(f"Saving model to {save_dir}")
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)


In [None]:
# Upload the model and then the tokenizer to the Hub repo named by
# HUB_MODEL_NAME.
print(f"Pushing to HuggingFace: {HUB_MODEL_NAME}")
for uploadable in (model, tokenizer):
    uploadable.push_to_hub(HUB_MODEL_NAME)
print(f"Done! Model available at: https://huggingface.co/{HUB_MODEL_NAME}")


In [None]:
# Optional export: merge LoRA weights into base model and save.
# The result is a full model that doesn't need the base model to run.
# Off by default.

SAVE_MERGED = False  # Set to True if you want this

if SAVE_MERGED:
    merged_output_dir = f"{OUTPUT_DIR}_merged"
    model.save_pretrained_merged(
        merged_output_dir,
        tokenizer,
        save_method="merged_16bit",  # or "merged_4bit" for smaller size
    )
    print("Saved merged model!")
