### Install required packages

In [None]:
# Check current Python version and environment
import sys
print(f"Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
print(f"Virtual environment: {sys.prefix}")

# Check that we're in the expected venv AND on Python 3.11 (the success message
# below claims both, so both are verified).
# Set the SAV_MICRO_VENV environment variable to use a different checkout location
# instead of editing the hard-coded default path.
import os
expected_venv = os.environ.get("SAV_MICRO_VENV", "/home/colin/Projects/sav-micro/.venv")
in_expected_venv = expected_venv in sys.executable
is_python_311 = sys.version_info[:2] == (3, 11)

if in_expected_venv and is_python_311:
    print("✅ Using Python 3.11 virtual environment")
elif in_expected_venv:
    print(
        "⚠️  In the expected venv, but running Python "
        f"{sys.version_info.major}.{sys.version_info.minor}, not 3.11."
    )
else:
    print("⚠️  NOT using the Python 3.11 venv. Need to select the correct kernel.")
    print(f"Expected: {expected_venv}/bin/python")
    print("Action: Click kernel selector in top-right and choose 'Python 3.11 (sav-micro)'")

In [None]:
!pip install transformers==4.51.1
!pip install torch
!pip install peft

In [None]:
# Ordered installs to avoid flash_attn failing before torch is present
# Using torch 2.5.1 (flash_attn 2.7.4.post1 is verified with this). If you really need 2.6.x, rebuild flash_attn from source.
'''
!pip install --upgrade pip wheel setuptools
!pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124
# Disable build isolation so flash_attn sees the just-installed torch
!pip install flash-attn==2.7.4.post1
# Remaining deps
'''
!pip install bitsandbytes
!pip install accelerate
!pip install rich
!pip install datasets
#!pip install causal-conv1d==1.5.0.post8 transformers==4.46.1 accelerate==1.4.0



In [None]:
# List the NVIDIA GPUs detected by the driver (sanity check before loading the model).
!nvidia-smi -L

### Load model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Configure 8-bit (LLM.int8) quantization of the frozen base weights.
# NOTE: BitsAndBytesConfig has no 8-bit versions of the 4-bit-only options
# (bnb_4bit_use_double_quant / bnb_4bit_quant_type / bnb_4bit_compute_dtype), and
# there is no "nf8" quant type. The former `bnb_8bit_*` keywords were therefore only
# reported as "Unused kwargs" and ignored, so they are not passed here.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

model_id = "nvidia/Nemotron-Mini-4B-Instruct"

# Load tokenizer and model with 8-bit quantization
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    # Allows executing Python code shipped in the model repo. NOTE(review): recent
    # transformers versions may support this architecture natively — confirm if needed.
    trust_remote_code=True,
    device_map="auto",   # let accelerate place layers on the available device(s)
    torch_dtype="auto",  # use the dtype recorded in the checkpoint config
)

### Set up LoRA

In [None]:
# Memory optimization before LoRA setup
import gc
import torch

# Force garbage collection
gc.collect()

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory before LoRA: {torch.cuda.memory_allocated()/1024**3:.2f} GB available")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved()/1024**3:.2f} GB")
    
    # Set memory fraction to use less GPU memory
    torch.cuda.set_per_process_memory_fraction(0.85)  # Use only 85% of GPU memory

In [None]:
import torch.nn as nn
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Clear GPU memory before proceeding
import torch
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory before LoRA setup: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

# Prepare the quantized model for adapter training: freezes the base weights, upcasts
# the non-quantized parameters to float32 and, with use_gradient_checkpointing=True,
# turns on gradient checkpointing for the base model.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# LoRA configuration: adapters on the attention projections only, to keep them small
lora_config = LoraConfig(
    r=16,            # adapter rank
    lora_alpha=32,   # scaling numerator; effective scale = lora_alpha / r = 2
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention only
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# No extra gradient_checkpointing_enable() call: prepare_model_for_kbit_training above
# already enabled gradient checkpointing on the wrapped base model.

# Print trainable parameters
model.print_trainable_parameters()

# Check memory usage
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory after LoRA setup: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

### Load datasets

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load datasets (raw chat examples with a 'messages' column)
dataset = load_dataset('json', data_files={
    'train': './datasets/training.json',
    'eval': './datasets/eval.json'
})

# Load tokenizer and make sure it has a pad token.
# DataCollatorForLanguageModeling(mlm=False) sets labels to -100 wherever
# input_ids == pad_token_id. If pad == eos, every real end-of-sequence token is also
# dropped from the loss and the model never learns to stop generating. So keep an
# existing distinct pad token, else use UNK if it differs from EOS, and only fall
# back to EOS when nothing else is available.
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
    if tokenizer.unk_token is not None and tokenizer.unk_token_id != tokenizer.eos_token_id:
        tokenizer.pad_token = tokenizer.unk_token
    else:
        tokenizer.pad_token = tokenizer.eos_token

# Role -> tag used when rendering conversations as plain training text.
_ROLE_TAGS = {
    "system": "<|system|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
}


def prepare_for_causal_lm(examples, max_length=512, tok=None):
    """Render batched chat examples as text and tokenize them for causal-LM training.

    Each conversation in ``examples['messages']`` becomes, per message, its role tag,
    a newline, the content and a newline (only system/user/assistant roles are kept;
    other roles are skipped), followed by the tokenizer's EOS token.

    Args:
        examples: a batched ``datasets`` slice whose 'messages' entries are lists of
            {'role': ..., 'content': ...} dicts.
        max_length: truncation length in tokens (default 512, chosen to save memory).
        tok: tokenizer to use; defaults to the notebook-global ``tokenizer``.

    Returns:
        The tokenizer output (input_ids, attention_mask) plus 'labels', an
        independent per-row copy of input_ids (the model shifts labels internally).
    """
    if tok is None:
        tok = tokenizer

    texts = []
    for messages in examples['messages']:
        conversation_text = "".join(
            f"{_ROLE_TAGS[message['role']]}\n{message['content']}\n"
            for message in messages
            if message['role'] in _ROLE_TAGS
        )
        texts.append(conversation_text + tok.eos_token)

    # No padding here: the data collator pads each batch, which saves memory
    tokenized = tok(
        texts,
        truncation=True,
        padding=False,
        max_length=max_length,
        return_attention_mask=True,
    )

    # Copy every row (not just the outer list) so labels never alias input_ids
    tokenized["labels"] = [list(ids) for ids in tokenized["input_ids"]]

    return tokenized

# Tokenize each split with prepare_for_causal_lm, dropping the raw nested
# 'messages' column so only token fields remain.
def _tokenize_split(split_name):
    """Map prepare_for_causal_lm over one split in batches of 100 examples."""
    return dataset[split_name].map(
        prepare_for_causal_lm,
        batched=True,
        batch_size=100,               # process in smaller batches
        remove_columns=['messages'],  # drop the original nested data
    )

train_dataset = _tokenize_split('train')
eval_dataset = _tokenize_split('eval')

print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")
print(f"Sample sequence length: {len(train_dataset[0]['input_ids'])}")

### Training

In [None]:
!pip install wandb
!pip install tf-keras

In [None]:
import wandb
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling, EarlyStoppingCallback

# Memory-efficient training arguments
training_args = TrainingArguments(
    output_dir="./outputs",

    # Small per-device batches; effective batch size = 1 x 4 accumulation steps = 4
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,

    # Training duration (EarlyStoppingCallback below may stop earlier)
    num_train_epochs=5,
    max_steps=-1,                      # Let epochs control duration

    # Learning rate and warmup
    learning_rate=1e-4,
    warmup_steps=10,
    lr_scheduler_type="cosine",        # Smooth learning rate decay

    # Memory optimization settings
    fp16=True,                         # fp16 mixed precision
    dataloader_pin_memory=False,       # Save host memory
    gradient_checkpointing=True,       # Recompute activations in backward to save memory
    dataloader_num_workers=0,          # Load batches in the main process

    # Logging and evaluation cadence (in optimizer steps)
    logging_steps=5,
    eval_steps=10,
    eval_strategy="steps",

    # Checkpointing; load_best_model_at_end requires save_steps to be a multiple of eval_steps
    save_steps=10,
    #save_total_limit=2,                # Keep only 2 checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Reporting to wandb
    report_to=["wandb"],               # Enable wandb logging
    run_name="sav-micro-trial",        # Name for this run

    # Memory and performance optimizations
    remove_unused_columns=True,
    prediction_loss_only=True,         # Evaluation computes only the loss
    skip_memory_metrics=True,          # Skip memory reporting

    # Additional memory optimizations
    eval_accumulation_steps=1,         # Move eval results off the GPU after every step
    max_grad_norm=1.0,                 # Gradient clipping
)

# Initialize wandb (will prompt for login on first use). The logged config is read
# from the objects that actually drive training so the dashboard cannot drift from
# the real hyperparameters.
wandb.init(
    project="sav-micro-training",
    name=f"{model_id}-lora",
    config={
        "model": model_id,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.lora_alpha,
        "learning_rate": training_args.learning_rate,
        # effective batch size per device
        "batch_size": training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps,
        "epochs": training_args.num_train_epochs,
    },
)

try:
    # Data collator for causal LM: pads each batch and builds labels from input_ids
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Not masked language modeling
        pad_to_multiple_of=8,  # Pad to multiples of 8 for efficiency
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],  # patience counted in evaluations
    )

    # The KV cache is not used in training and is incompatible with gradient checkpointing
    model.config.use_cache = False

    # Clear any cached tensors
    import torch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"GPU memory before training: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

    print("Starting memory-optimized training with wandb logging...")
    trainer.train()

    # Clear memory after training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"GPU memory after training: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
finally:
    # Close the wandb run even if setup or training raises or is interrupted
    wandb.finish()

print("Training complete! Check your wandb dashboard at https://wandb.ai/")

### Test model

In [None]:
import re, json, csv
import numpy as np
from collections import Counter
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

# ---- Helpers ----
# Matches a brace-delimited span; DOTALL lets it cross newlines. The non-greedy `.*?`
# stops at the FIRST "}", so a nested object such as {"a": {"b": 1}} is matched only
# up to the inner "}" (yielding text that is not valid JSON on its own).
JSON_BLOCK_RE = re.compile(r"\{.*?\}", re.DOTALL)

def _classification_label(obj):
    """Map a decoded JSON value to 1 (THREAT), 0 (SAFE) or None.

    Only a dict whose "classification" value is a string equal to SAFE or THREAT
    (case-insensitive, surrounding whitespace ignored) produces a label.
    """
    if not isinstance(obj, dict):
        return None
    cls = obj.get("classification")
    if not isinstance(cls, str):
        return None
    cls_up = cls.strip().upper()
    if cls_up == "THREAT":
        return 1
    if cls_up == "SAFE":
        return 0
    return None


def parse_classification_from_text(output_text: str):
    """
    Returns 1 for THREAT, 0 for SAFE, or None if not parseable.

    Tries to decode a JSON value starting at each "{" in ``output_text`` (left to
    right) and reads {"classification": "..."} case-insensitively, either from the
    decoded object itself or from a dict one level down, e.g.
    {"result": {"classification": "SAFE"}}. Decoding uses
    json.JSONDecoder.raw_decode, so nested braces are handled correctly (a non-greedy
    regex would stop at the first "}" and never parse nested objects).
    """
    if not output_text:
        return None

    decoder = json.JSONDecoder()
    start = output_text.find("{")
    while start != -1:
        try:
            obj, consumed = decoder.raw_decode(output_text[start:])
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            # Not valid JSON from this brace; an inner "{" may still start a valid object
            start = output_text.find("{", start + 1)
            continue

        label = _classification_label(obj)
        if label is None and isinstance(obj, dict):
            # Nested fallback: first direct child dict with a usable classification
            label = next(
                (lbl for lbl in map(_classification_label, obj.values()) if lbl is not None),
                None,
            )
        if label is not None:
            return label

        # Continue scanning after the object that was just decoded
        start = output_text.find("{", start + consumed)
    return None

def build_prompt_from_messages(messages):
    """
    Format chat messages the same way the training text was built, then open an
    assistant turn for the model to complete.

    Each system/user/assistant message becomes its role tag (<|system|>, <|user|>,
    <|assistant|>), a newline, the content and a newline; other roles are skipped.
    The prompt ends with "<|assistant|>" plus a newline. The fine-tuned adapter only
    saw this layout during training, so a different one (e.g. "User: ..." with an
    inline JSON instruction) would be out of distribution.
    """
    role_tags = {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    }
    parts = [
        f"{role_tags[msg['role']]}\n{msg['content']}\n"
        for msg in messages
        if msg["role"] in role_tags
    ]
    parts.append("<|assistant|>\n")
    return "".join(parts)

# ---- Ground truth ----
# Re-use `raw_eval` if an earlier cell defined it; otherwise fall back to the
# un-tokenized eval split loaded in "Load datasets" (it still has 'messages'),
# so this cell also works after Restart & Run All.
if "raw_eval" not in globals():
    raw_eval = dataset["eval"]

# The last message of each conversation is the reference assistant answer (JSON).
y_true = []
for item in raw_eval:
    gt = json.loads(item["messages"][-1]["content"])
    y_true.append(0 if gt["classification"].upper() == "SAFE" else 1)
y_true = np.array(y_true)

# ---- Strict re-eval with progress bar ----
# Disable dropout so predictions are deterministic; the model may still be in
# train mode after trainer.train().
model.eval()

# Greedy decoding. temperature/top_p only apply when sampling, so they are omitted
# (temperature=0.0 with do_sample=False only produces a generation-config warning).
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
gen_kwargs = dict(
    max_new_tokens=96,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=pad_id,
    use_cache=True,  # the training cell set model.config.use_cache = False
)

y_pred_strict = []
unparsed = 0
raw_outputs = []

for item in tqdm(raw_eval, desc="Strict evaluation"):
    prompt_text = build_prompt_from_messages(item["messages"][:-1])
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    output_ids = model.generate(**inputs, **gen_kwargs)
    # Decode only the generated continuation: the full sequence also contains the
    # prompt, whose own text (e.g. "SAFE"/"THREAT", example JSON) would otherwise be
    # picked up by the JSON parser and the keyword fallback below.
    output_text = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    raw_outputs.append(output_text)

    parsed = parse_classification_from_text(output_text)
    if parsed is None:
        unparsed += 1
        up = output_text.upper()
        if "THREAT" in up and "SAFE" not in up:
            parsed = 1
        elif "SAFE" in up and "THREAT" not in up:
            parsed = 0
        else:
            # Safety-biased default (flip to 0 if you prefer fewer false alarms)
            parsed = 1
    y_pred_strict.append(parsed)

# ---- Metrics ----
print("Ground-truth counts:", Counter(y_true))
print("Pred counts (strict parser):", Counter(y_pred_strict))
print(f"Unparseable outputs (needed fallback): {unparsed}")

# labels=[0, 1] keeps the matrix 2x2 and keeps target_names valid even when one
# class is absent from both y_true and the predictions.
print("\nConfusion Matrix (strict):")
print(confusion_matrix(y_true, y_pred_strict, labels=[0, 1]))

print("\nClassification Report (strict):")
print(classification_report(y_true, y_pred_strict, labels=[0, 1], target_names=["SAFE", "THREAT"]))

# ---- Save misclassifications to CSV ----
mistakes_path = "eval_mistakes.csv"
with open(mistakes_path, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["index", "truth(0=SAFE,1=THREAT)", "pred_strict", "output_snippet"])
    for i, (t, p, txt) in enumerate(zip(y_true, y_pred_strict, raw_outputs)):
        if t != p:
            w.writerow([i, t, p, txt.replace("\n", " ")[:400]])
print(f"\nMisclassified examples written to {mistakes_path}")
