In [1]:
# ==============================================================================
# 0. INSTALLATION & SETUP
# ==============================================================================
# Ensure core libraries for QLoRA and Mamba are installed.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): no versions are pinned, so a fresh install may resolve to different
# releases than the ones that produced the saved outputs in this notebook —
# consider pinning once a working combination is known.
# NOTE(review): the optional fused Mamba kernels (`mamba-ssm`, `causal-conv1d`) are
# not part of this list; the saved output of the model-loading cell reports that
# the fast path is unavailable and that the sequential implementation is used.
%pip install torch transformers bitsandbytes peft trl datasets accelerate scipy

Note: you may need to restart the kernel to use updated packages.


In [3]:
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import torch
import json
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

# --- Hugging Face & Data ---
from datasets import load_dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig, 
    TrainingArguments
)

# --- Mamba Specifics (Optional, but good for explicit typing) ---
from transformers import MambaConfig, MambaForCausalLM

# --- Efficiency Stack (QLoRA & Fine-Tuning) ---
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# --- Local Project Modules (Thesis Codebase) ---
# Ensure your notebook is running from the project root to find 'source'
from source.babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input
from source.babilong.babilong_utils import compare_answers

# ==============================================================================
# 2. SYSTEM CHECKS
# ==============================================================================
# Report the installed PyTorch build and whether a CUDA device is visible.
# The device name is only queried when CUDA is actually available.
print(f"PyTorch Version: {torch.__version__}")

_cuda_ok = torch.cuda.is_available()
print(f"CUDA Available: {_cuda_ok}")
if _cuda_ok:
    print(f"Current Device: {torch.cuda.get_device_name(0)}")

PyTorch Version: 2.6.0+cu124
CUDA Available: True
Current Device: NVIDIA GeForce RTX 3080


## Setup and Config

In [3]:
# ==============================================================================
# 3. MODEL CONFIGURATION & LOADING
# ==============================================================================
# Constants shared by the rest of the notebook.
MODEL_ID = "state-spaces/mamba-1.4b-hf"
OUTPUT_DIR = "./babilong_mamba_finetune"
TASK_NAME = "qa1"    # Task: Single Supporting Fact
SPLIT_LENGTH = "0k"  # Starting complexity (Short Context)

# --- QLoRA Configuration ---
# Base weights are loaded in 4-bit NF4 with double quantization enabled;
# computation on the quantized layers is carried out in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Place the whole model on the current CUDA device when one exists; otherwise
# let the loader decide the placement.
if torch.cuda.is_available():
    device_map = {"": torch.cuda.current_device()}
else:
    device_map = "auto"

print(f"‚è≥ Loading Mamba Model ({MODEL_ID}) in 4-bit...")
# `trust_remote_code` is intentionally NOT enabled: this notebook imports
# `MambaForCausalLM` from `transformers` itself, so the architecture ships with
# the library and there is no reason to allow execution of Python code supplied
# by the model repository.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map,
)

# --- Tokenizer Loading ---
# EOS doubles as the padding token so that batches can be padded.
# Note: a later cell reloads this tokenizer and additionally sets the padding side.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

print(f"‚úÖ Model successfully loaded on device: {device_map}")
print(f"‚úÖ Tokenizer loaded. Vocab size: {len(tokenizer)}")

‚è≥ Loading Mamba Model (state-spaces/mamba-1.4b-hf) in 4-bit...


The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

‚úÖ Model successfully loaded on device: {'': 0}
‚úÖ Tokenizer loaded. Vocab size: 50277


In [4]:
# ==============================================================================
# 4. LoRA ADAPTER CONFIGURATION
# ==============================================================================
# Layers that receive LoRA adapters.
# Author's rationale: besides the input projection, 'x_proj' and 'dt_proj' are
# targeted because they govern the Mamba selection mechanism (SSM parameters
# B, C, Delta). Adapting them lets the model change its content-filtering logic
# ("what to remember") for the needle-in-haystack task.
LORA_TARGET_MODULES = ["in_proj", "x_proj", "dt_proj"]

# Rank-16 adapters with alpha 32 and 5% dropout; bias terms are not trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=LORA_TARGET_MODULES,
)

# Run PEFT's k-bit preparation step on the quantized base model before training.
# NOTE(review): per the PEFT documentation this freezes the base weights, upcasts
# normalization layers to fp32 and, by default, enables gradient checkpointing —
# confirm against the installed PEFT version.
model = prepare_model_for_kbit_training(model)

In [5]:
# ==============================================================================
# TOKENIZER CONFIGURATION
# ==============================================================================
# Reloads the tokenizer for MODEL_ID (replacing the instance created in the
# model-loading cell) and applies the padding settings used for training.
print(f"Loading Tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Padding setup, applied in one step:
#   * pad_token    -> EOS is reused as the padding token so batching works
#                     (the notebook assumes the tokenizer has no usable pad token).
#   * padding_side -> "right". Author's rationale: SFTTrainer expects right-padding
#                     while training, even though left-padding is the usual choice
#                     for generation. NOTE(review): confirm for the installed TRL.
tokenizer.pad_token, tokenizer.padding_side = tokenizer.eos_token, "right"

print(f"Tokenizer ready. Pad Token ID: {tokenizer.pad_token_id}")

Loading Tokenizer for state-spaces/mamba-1.4b-hf...
Tokenizer ready. Pad Token ID: 0


## Load Dataset

In [6]:
# ==============================================================================
# 5. DATASET PREPARATION
# ==============================================================================
import re

print(f"Loading BABILong {TASK_NAME} ({SPLIT_LENGTH})...")
dataset = load_dataset("RMT-team/babilong", SPLIT_LENGTH, split=TASK_NAME)

# --- Answer trigger ---
# The template is extended so that every prompt already ends with the opening
# words of the expected answer.
TRIGGER_PHRASE = "The most recent location of"
TRIGGER_SUFFIX = f"\nAnswer: {TRIGGER_PHRASE}"

# --- Prompt configuration ---
# The task-specific instruction, few-shot examples and post-prompt are copied
# unchanged from DEFAULT_PROMPTS; only the template is modified.
_task_prompts = DEFAULT_PROMPTS[TASK_NAME]
prompt_cfg = {
    field: _task_prompts[field]
    for field in ('instruction', 'examples', 'post_prompt')
}

# CRITICAL: the trigger suffix is appended to the default template.
# Author's observation: without it the model frequently did not answer at all and
# instead continued the sequence pattern (e.g. emitting a new <context> block or
# another question). Ending the prompt with the answer prefix breaks that pattern
# and constrains the model to produce the answer immediately.
# Author's note: the continuation problem was seen only on the 0k (short) split,
# not on 1k+ contexts; the suffix is applied uniformly for consistency.
prompt_cfg['template'] = DEFAULT_TEMPLATE + TRIGGER_SUFFIX

def format_aligned_training(example):
    """
    Convert one BABILong record into the prompt/completion pair used for SFT.

    Reads the module-level ``prompt_cfg`` and ``tokenizer``.

    Args:
        example: A dataset row providing the 'input' (context), 'question'
            and 'target' fields.

    Returns:
        A dict with two string entries:
        * "prompt": the result of ``get_formatted_input`` for this row, built
          with the instruction, examples, post-prompt and template stored in
          ``prompt_cfg`` (the template ends with the answer trigger).
        * "completion": " <name> is <target>" when the question has the form
          "Where is <name>?", otherwise just " <target>"; in both cases
          immediately followed by ``tokenizer.eos_token``. No trailing period
          is added.
    """
    # Prompt: context + question rendered through the configured template.
    prompt_text = get_formatted_input(
        template=prompt_cfg['template'],
        instruction=prompt_cfg['instruction'],
        examples=prompt_cfg['examples'],
        post_prompt=prompt_cfg['post_prompt'],
        context=example['input'],
        question=example['question'],
    )

    # Completion: continue the trigger phrase with the subject taken from the
    # question, so the model names the person before the location.
    question_text = example['question']
    answer = example['target']
    subject = re.search(r"Where is (.*?)\?", question_text)
    if subject is not None:
        answer = f"{subject.group(1)} is {answer}"
    # When no subject can be extracted, the bare target is used as a fallback.

    # The EOS token marks where generation should stop.
    return {
        "prompt": prompt_text,
        "completion": f" {answer}{tokenizer.eos_token}",
    }

print("‚è≥ Formatting dataset...")
# Every original column is dropped, leaving only what format_aligned_training
# returns ('prompt' and 'completion').
aligned_dataset = dataset.map(
    format_aligned_training,
    remove_columns=dataset.column_names,
)

# --- 90/10 Split ---
# A fixed seed keeps the train/validation partition the same across runs.
split_dataset = aligned_dataset.train_test_split(seed=42, test_size=0.1)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]

print(f"‚úÖ Data Ready: {len(train_dataset)} Training, {len(eval_dataset)} Validation samples.")

Loading BABILong qa1 (0k)...
‚è≥ Formatting dataset...
‚úÖ Data Ready: 90 Training, 10 Validation samples.


In [7]:
# ==============================================================================
# DATA INSPECTION CELL
# ==============================================================================
# Print the first training sample to check where the prompt ends and the
# completion begins.
sample = train_dataset[0]
_rule = "=" * 80

print(_rule)
print("üëÄ DATASET SAMPLE INSPECTION")
print(_rule)

# Only the tail of the prompt is shown: enough to see the instruction and the
# answer trigger without printing the full few-shot block.
prompt_tail = sample['prompt'][-500:]
print(f"\n--- [PROMPT (Last 500 chars)] ---")
print(f"...{prompt_tail}")

# Quoted so that leading whitespace and the EOS token are visible.
completion_text = sample['completion']
print(f"\n--- [COMPLETION (The Target)] ---")
print(f"'{completion_text}'")

print("\n" + _rule)

üëÄ DATASET SAMPLE INSPECTION

--- [PROMPT (Last 500 chars)] ---
... balcony.
</example>

<example>
Alan moved to the garage. Charlie went to the beach. Alan went to the shop. Rouse travelled to balcony. Where is Alan?
Answer: The most recent location of Alan is shop.
</example>

Always return your answer in the following format: The most recent location of ‚Äôperson‚Äô is ‚Äôlocation‚Äô. Do not write anything else after that.

<context>
John travelled to the office. Mary journeyed to the kitchen.
</context>

Question: Where is Mary? 
Answer: The most recent location of

--- [COMPLETION (The Target)] ---
' Mary is kitchen<|endoftext|>'



## Training Configuration and Execution

In [8]:
# ==============================================================================
# 6. TRAINING CONFIGURATION & EXECUTION
# ==============================================================================
# Training strategy: memory efficiency (QLoRA stack) plus loss masking so that
# only the answer tokens are learned.

sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=10,
    # Evaluate and checkpoint once per epoch; after training, the checkpoint
    # with the lowest validation loss is restored.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # --- Efficiency Stack (CRITICAL for RTX 3080) ---
    # 1. bf16: bfloat16 mixed precision (fp16 explicitly off).
    # 2. gradient_checkpointing: trades extra compute for lower activation memory.
    # 3. paged_adamw_8bit: paged 8-bit AdamW, intended to move optimizer state
    #    out of GPU memory when the GPU fills up.
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",

    # --- Loss Masking ---
    # completion_only_loss=True: the prompt (haystack) is excluded from the loss,
    # so the model is trained on the answer tokens only.
    # `dataset_text_field` is deliberately not set: the datasets passed to the
    # trainer expose 'prompt' and 'completion' columns, whereas that option names
    # the column of a plain-text dataset. Pointing it at "prompt" was misleading
    # (and would mean training on prompts alone if it were ever honoured).
    completion_only_loss=True,
    packing=False,

    # --- Hyperparameters ---
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    logging_steps=5,
    weight_decay=0.01,
    report_to="none",
    # Length-grouped batching stays off. NOTE(review): the original rationale
    # ("prevents sorting issues with Mamba state") is unverified.
    group_by_length=False,
    disable_tqdm=False,
)

print("Initializing SFTTrainer ...")
# The LoRA configuration is handed to the trainer together with the quantized
# base model; the tokenizer is supplied as the processing class.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    peft_config=peft_config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Run the full training schedule defined in sft_config.
print("Starting Training...")
trainer.train()

# ==============================================================================
# 7. SAVE ARTIFACTS
# ==============================================================================
# The trained model held by the trainer is written first, then the tokenizer,
# both into OUTPUT_DIR.
print(f"\nüíæ Saving LoRA Adapters to {OUTPUT_DIR}...")
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)

print("‚úÖ Training Complete. The adapter is ready for the Evaluation Notebook.")

Initializing SFTTrainer ...
Starting Training...


  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,0.4282,0.242433,2.519003,22291.0,0.90625
2,0.0562,0.1228,2.337514,44582.0,0.96875
3,0.0124,0.097709,2.263515,66873.0,0.96875
4,0.0015,0.030557,2.224169,89164.0,0.984375
5,0.0008,0.031937,2.192666,111455.0,0.984375
6,0.0004,0.035105,2.176443,133746.0,0.984375
7,0.0005,0.039166,2.167946,156037.0,0.984375
8,0.0004,0.039091,2.164132,178328.0,0.984375
9,0.0005,0.038956,2.160753,200619.0,0.984375
10,0.0004,0.039029,2.160094,222910.0,0.984375


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)



üíæ Saving LoRA Adapters to ./babilong_mamba_finetune...
‚úÖ Training Complete. The adapter is ready for the Evaluation Notebook.


In [9]:


# Utility cell: removes a leftover `peft_config` attribute from `model`, which the
# message below says is done to avoid ending up with multiple adapters.
# NOTE(review): this only deletes the attribute; nothing here removes adapter
# layers that may already have been injected into `model`. Reloading the base
# model is the safer way to start over — confirm what is intended.
# NOTE(review): the leading `0,` on the next line looks like a stray export
# artifact (it evaluates as a throwaway tuple) — confirm and remove.
0,# Check if the model already has a `peft_config` attribute
if hasattr(model, "peft_config"):
    print("Warning: The model already has a `peft_config` attribute. Removing it to avoid multiple adapters.")
    del model.peft_config  # Remove the existing `peft_config`
else:
    # Nothing to clean up: the attribute is absent.
    print("No `peft_config` attribute found in the model. Safe to proceed.")


