In [2]:
pip install torch transformers bitsandbytes peft trl datasets accelerate scipy

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)
from trl import SFTTrainer
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
import datasets
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm
import pandas as pd
import json
from pathlib import Path
from source.babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input
from source.babilong.babilong_utils import compare_answers

## Setup and Config

In [2]:
MODEL_ID = "state-spaces/mamba-1.4b-hf"
OUTPUT_DIR = "./babilong_mamba_finetune"

# Task setup
TASK_NAME = "qa1"
SPLIT_LENGTH = "0k"  # Ideal for getting started on the RTX 3080

# QLoRA config (saves memory): 4-bit NF4 weights with double quantization,
# computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # RTX 3080 feature
    bnb_4bit_use_double_quant=True,
)

# Place the whole model on the current CUDA device when one is available;
# otherwise let the loader decide.
if torch.cuda.is_available():
    device_map = {"": torch.cuda.current_device()}
else:
    device_map = "auto"  # Fallback

# 2. Load the model with the quantization config.
# `trust_remote_code` is intentionally not set: the Mamba architecture is part of
# transformers itself (this notebook imports MambaForCausalLM from it), so there is
# no need to allow code from the Hub repository to run locally.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map,
)

print(f"Model successfully loaded on device: {device_map}")

The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model successfully loaded on device: {'': 0}


In [3]:
# LoRA config for Mamba.
# Adapters are attached only to the modules named in `target_modules` below:
# 'in_proj', 'x_proj' and 'dt_proj'.
# NOTE(review): an earlier version of this comment recommended 'all-linear' so that
# 'out_proj' would be covered too, but the code does not do that. Presumably
# 'out_proj' is left out on purpose -- confirm before changing the list.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["in_proj", "x_proj", "dt_proj"]
)

In [4]:
from transformers import AutoTokenizer

# Same checkpoint id as in the setup/config cell; re-assigned here so that this
# cell can run on its own.
MODEL_ID = "state-spaces/mamba-1.4b-hf"

print(f"Lade Tokenizer f√ºr {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Fix for the Mamba / GPT-NeoX tokenizer:
# no 'pad_token' is defined, so the end-of-sequence token is reused for padding.
tokenizer.pad_token = tokenizer.eos_token

# Right-side padding is the usual choice for training with SFTTrainer
tokenizer.padding_side = "right" 

print(f"Tokenizer geladen. Pad Token ID: {tokenizer.pad_token_id}")

Lade Tokenizer f√ºr state-spaces/mamba-1.4b-hf...
Tokenizer geladen. Pad Token ID: 0


## Load Dataset

In [5]:
from datasets import load_dataset

# ==========================================
# LOAD & FORMAT THE DATASET
# ==========================================

print(f"Lade BABILong {TASK_NAME} ({SPLIT_LENGTH})...")

# 1. Load the specific split (e.g. split='qa1').
# SPLIT_LENGTH selects the dataset configuration and TASK_NAME the split, so only
# the data for this task is loaded and no further filtering is needed.
dataset = load_dataset("RMT-team/babilong", SPLIT_LENGTH, split=TASK_NAME)

# DEBUG: show which columns are actually present (avoids KeyErrors later on)
print(f"Verf√ºgbare Spalten: {dataset.column_names}")

# 2. Optional: use only a small subset for testing (keep commented out for real training)
# dataset = dataset.select(range(100)) 

def formatting_prompts_func(example):
    """
    Render BABILong rows as plain "Context / Question / Answer" training text.

    Accepts either a single row or a batch, decided by the type of
    ``example['input']``:

    * a list  -> batch: returns a list with one formatted string per index;
    * otherwise -> single row: returns one bare string (not a list).

    The keys 'input', 'question' and 'target' are read from ``example``.
    """
    def _render(context, question, answer):
        # One shared template, so the batch and single paths cannot drift apart.
        return (
            f"Context: {context}\n\n"
            f"Question: {question}\n\n"
            f"Answer: {answer}"
        )

    contexts = example['input']

    # Single sample: hand back the string itself.
    if not isinstance(contexts, list):
        return _render(contexts, example['question'], example['target'])

    # Batch: index all three columns in lockstep, driven by the 'input' column.
    return [
        _render(contexts[idx], example['question'][idx], example['target'][idx])
        for idx in range(len(contexts))
    ]

print("Dataset erfolgreich geladen und bereit.")

Lade BABILong qa1 (0k)...
Verf√ºgbare Spalten: ['target', 'input', 'question']
Dataset erfolgreich geladen und bereit.


In [6]:
# Which BABILong task the prompt pieces are taken from, and which pieces to use.
task = 'qa1'
use_chat_template = False
use_instruction = True
use_examples = True
use_post_prompt = True

# Text appended after the question so the prompt already starts the answer sentence.
trigger_text = "\nAnswer: The most recent location of"

# Optional prompt pieces: a disabled piece becomes the empty string.
prompt_cfg = {
    piece: (DEFAULT_PROMPTS[task][piece] if enabled else '')
    for piece, enabled in (
        ('instruction', use_instruction),
        ('examples', use_examples),
        ('post_prompt', use_post_prompt),
    )
}

# The trigger is appended to the template itself:
#   old: "{instruction}... Question: {question}"
#   new: "{instruction}... Question: {question}\nAnswer: The most recent location of"
prompt_cfg.update(
    template=DEFAULT_TEMPLATE + trigger_text,
    chat_template=use_chat_template,
)

In [7]:
import re

# ==========================================
# 1. SETUP CONFIG (Same as before)
# ==========================================
# The trigger is split into the bare phrase (checked against the prompt ending in a
# later cell) and the full text that is appended to the template.
trigger_phrase = "The most recent location of"
trigger_text = f"\nAnswer: {trigger_phrase}"

# Ensure prompt config is ready
# NOTE: this rebinds `prompt_cfg`; unlike the dict built in the previous cell it has
# no 'chat_template' key. `task` and the `use_*` switches must already be defined.
prompt_cfg = {
    'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
    'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
    'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
    'template': DEFAULT_TEMPLATE + trigger_text, 
}

# ==========================================
# 2. INTELLIGENT MAPPING FUNCTION
# ==========================================
def format_aligned_training(example):
    """
    Turn one BABILong row into a prompt/completion pair.

    The prompt is built by ``get_formatted_input`` from the row's 'input' and
    'question' plus the pieces stored in the module-level ``prompt_cfg``; its
    template ends with the trigger text, so the completion only has to supply the
    remainder of the answer sentence.

    Returns a dict with the keys "prompt" and "completion". The completion is
    " <name> is <target>" when the question matches "Where is <name>?", otherwise
    " <target>"; in both cases the tokenizer's EOS token is appended.
    """
    # Prompt side (ends with the trigger phrase from the template).
    prompt_str = get_formatted_input(
        context=example['input'],
        question=example['question'],
        examples=prompt_cfg['examples'],
        instruction=prompt_cfg['instruction'],
        post_prompt=prompt_cfg['post_prompt'],
        template=prompt_cfg['template'],
    )

    location = example['target']  # e.g. "bathroom"

    # The person's name is whatever sits between "Where is " and the first "?".
    name_match = re.search(r"Where is (.*?)\?", example['question'])

    if name_match is None:
        # No name found: fall back to the bare target.
        answer = f" {location}"
    else:
        # Leading space so the text continues directly after the trigger phrase.
        answer = f" {name_match.group(1)} is {location}"

    # Terminate the completion with EOS.
    return {
        "prompt": prompt_str,
        "completion": f"{answer}{tokenizer.eos_token}",
    }

# ==========================================
# 3. VERIFY AGAIN
# ==========================================
print("‚è≥ Remapping with Grammar Fix...")
# Replace the original columns with the "prompt"/"completion" pair per row.
aligned_dataset = dataset.map(format_aligned_training, remove_columns=dataset.column_names)

# Show only the tail of the first prompt: it should end with the trigger phrase.
print("\n--- [PROMPT END] ---")
print(f"...{aligned_dataset[0]['prompt'][-50:]}")

# Quoted so that leading whitespace in the completion is visible.
print(f"\n--- [COMPLETION] ---")
print(f"'{aligned_dataset[0]['completion']}'")

‚è≥ Remapping with Grammar Fix...

--- [PROMPT END] ---
...here is Mary? 
Answer: The most recent location of

--- [COMPLETION] ---
' Mary is bathroom<|endoftext|>'


In [8]:
print("‚è≥ Mapping dataset...")
# NOTE(review): this is the same `dataset.map(...)` call as in the previous cell;
# `aligned_dataset` is simply rebuilt here.
aligned_dataset = dataset.map(format_aligned_training, remove_columns=dataset.column_names)

print("\n‚úÖ Mapping Complete. Verifying one sample:")
sample = aligned_dataset[0]

print(f"\n--- [PROMPT END (Last 100 chars)] ---")
print(f"...{sample['prompt'][-100:]}")

print(f"\n--- [COMPLETION (Target)] ---")
print(f"'{sample['completion']}'")

# LOGIC CHECK
# The prompt must end with the trigger phrase and the completion must not repeat
# it; only the first sample is inspected, and a failure is reported, not raised.
if sample['prompt'].endswith(trigger_phrase) and not sample['completion'].strip().startswith("The most"):
    print("\nSUCCESS: Trigger is in prompt and removed from completion.")
else:
    print("\nWARNING: Check the output above. You might have double text.")

‚è≥ Mapping dataset...

‚úÖ Mapping Complete. Verifying one sample:

--- [PROMPT END (Last 100 chars)] ---
...John moved to the bedroom.
</context>

Question: Where is Mary? 
Answer: The most recent location of

--- [COMPLETION (Target)] ---
' Mary is bathroom<|endoftext|>'

SUCCESS: Trigger is in prompt and removed from completion.


In [9]:
# 1. Split the dataset
# seed=42 ensures the split is the same every time you run it (reproducibility)
split_dataset = aligned_dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = split_dataset["train"]  # the remaining 90% of the rows
eval_dataset = split_dataset["test"]  # the 10% held out via test_size=0.1

In [10]:
formatting_prompts_func(dataset[0])

'Context: John travelled to the hallway. Mary journeyed to the bathroom. Daniel went back to the bathroom. John moved to the bedroom.\n\nQuestion: Where is Mary? \n\nAnswer: bathroom'

In [23]:
# Take the first sample of the dataset
first_sample = dataset[0]

# If you want to overwrite the dataset with just the first sample
# NOTE(review): this rebinds `dataset` to a one-row selection. Any cell that reads
# `dataset` afterwards (for example a re-run of the mapping cells) then sees a
# single row only -- confirm this is intended before running it.
dataset = dataset.select([0])

In [11]:
train_dataset[0]

{'prompt': 'I will give you context with the facts about positions of different persons hidden in some random text and a question. You need to answer the question based only on the information from the facts. If a person was in different locations, use the latest location to answer the question.\n\n<example>\nCharlie went to the hallway. Judith come back to the kitchen. Charlie travelled to balcony. Where is Charlie?\nAnswer: The most recent location of Charlie is balcony.\n</example>\n\n<example>\nAlan moved to the garage. Charlie went to the beach. Alan went to the shop. Rouse travelled to balcony. Where is Alan?\nAnswer: The most recent location of Alan is shop.\n</example>\n\nAlways return your answer in the following format: The most recent location of ‚Äôperson‚Äô is ‚Äôlocation‚Äô. Do not write anything else after that.\n\n<context>\nJohn travelled to the office. Mary journeyed to the kitchen.\n</context>\n\nQuestion: Where is Mary? \nAnswer: The most recent location of',
 'comp

In [18]:
print("hi")

hi


## Training Config and start

In [12]:
import torch
from datasets import DatasetDict
from transformers import AutoTokenizer, MambaForCausalLM
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig

# 3. ENSURE DATA IS SPLIT
# `aligned_dataset` comes from the mapping cell. Only a DatasetDict carries named
# splits, so the type is checked explicitly: a membership test on a plain Dataset
# is not a split lookup and would always end up in the "needs splitting" branch.
if isinstance(aligned_dataset, DatasetDict) and "test" in aligned_dataset:
    train_dataset = aligned_dataset["train"]
    eval_dataset = aligned_dataset["test"]
else:
    print("‚úÇÔ∏è Splitting dataset...")
    # Fixed seed so the 90/10 split is identical between runs.
    split_dataset = aligned_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]

# 4. CONFIGURATION
sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,

    # Training params: evaluate and checkpoint once per epoch, then reload the
    # checkpoint selected by `eval_loss` when training ends.
    num_train_epochs=20,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # Masking settings
    completion_only_loss=True,
    # NOTE(review): the dataset has "prompt"/"completion" columns; confirm whether
    # the installed TRL version still consults `dataset_text_field` in that case.
    dataset_text_field="prompt",
    packing=False,

    # Optimization
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    logging_steps=5,
    weight_decay=0.01,
    report_to="none",
    group_by_length=False,
    disable_tqdm=False,
)

print("Initialisiere SFTTrainer...")
# The LoRA config is handed to the trainer, which attaches the adapters to `model`.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=sft_config,
)

# 5. RUN
print("üöÄ Starte Training...")
trainer.train()

# Save the trained model and the tokenizer side by side in OUTPUT_DIR
print(f"\nTraining beendet. Speichere Adapter in {OUTPUT_DIR}...")
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

‚úÇÔ∏è Splitting dataset...
Initialisiere SFTTrainer...
üöÄ Starte Training...


  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,0.3911,0.235423,2.502059,22291.0,0.90625
2,0.0475,0.127524,2.317039,44582.0,0.96875
3,0.0068,0.101663,2.2423,66873.0,0.96875
4,0.0011,0.02836,2.200406,89164.0,0.984375
5,0.0004,0.029991,2.174889,111455.0,0.984375
6,0.0002,0.029384,2.163837,133746.0,0.984375
7,0.0002,0.029215,2.154321,156037.0,0.984375
8,0.0002,0.033487,2.150412,178328.0,0.984375
9,0.0002,0.033429,2.145832,200619.0,0.984375
10,0.0002,0.033371,2.141828,222910.0,0.984375


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)



Training beendet. Speichere Adapter in ./babilong_mamba_finetune...


('./babilong_mamba_finetune\\tokenizer_config.json',
 './babilong_mamba_finetune\\special_tokens_map.json',
 './babilong_mamba_finetune\\tokenizer.json')

In [27]:


# Check if the model already has a `peft_config` attribute
# NOTE(review): this only inspects (and removes) the attribute. A missing attribute
# does not by itself show that `model` is free of adapter layers -- confirm before
# wrapping the model with a new adapter.
if hasattr(model, "peft_config"):
    print("Warning: The model already has a `peft_config` attribute. Removing it to avoid multiple adapters.")
    del model.peft_config  # Remove the existing `peft_config`
else:
    print("No `peft_config` attribute found in the model. Safe to proceed.")


No `peft_config` attribute found in the model. Safe to proceed.
