### Simulating consolidation of narratives in the full model

Note that the contents of the `xRAG` directory are copied from https://github.com/Hannibal046/xRAG, so that our results remain reproducible even if that repository changes in the future.

In [None]:
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import torch
from datasets import load_dataset
import random
import sys
import spacy
import string
import math
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from scipy.stats import sem
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem
import pickle
import torch
import numpy as np
from datasets import Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer,
    TrainerCallback,
    TrainerState,
    TrainerControl,
    set_seed,
    DefaultDataCollator
)
from peft import get_peft_model, LoraConfig
import matplotlib.pyplot as plt
import pandas as pd
import random

# Make the vendored xRAG copy (the 'xRAG' directory, resolved relative to the
# current working directory) importable, so that the `src.*` imports below
# resolve to it. This must run before those imports.
sys.path.append('xRAG')
# NOTE(review): these xRAG names are not used in the cells visible here —
# presumably needed elsewhere in the notebook; confirm before removing.
from src.model import SFR,XMistralForCausalLM
from src.language_modeling.utils import get_retrieval_embeds, XRAG_TOKEN

# Load spaCy's small English pipeline. Not referenced in the visible cells
# (TODO confirm whether it is still needed).
nlp = spacy.load("en_core_web_sm")

In [None]:
def get_stories(path='stories_train.csv', num_sentences=5):
    """Load stories from a CSV and join each row's sentence columns into one string.

    The CSV must contain the columns ``sentence1`` ... ``sentence{num_sentences}``.
    Each returned story is those sentences, in column order, joined with single
    spaces. Missing cells are skipped rather than rendered as ``'nan'``.

    Args:
        path: CSV file to read. Defaults to ``'stories_train.csv'`` in the
            current working directory.
        num_sentences: Number of ``sentence{i}`` columns to combine.

    Returns:
        A list with one combined story string per CSV row, in file order.
    """
    df = pd.read_csv(path)
    sentence_cols = [f'sentence{i}' for i in range(1, num_sentences + 1)]
    # Without fillna, astype(str) turns a missing cell into the literal 'nan'.
    # That text would then end up inside the story. Blank the missing cells
    # and skip them when joining.
    sentences = df[sentence_cols].fillna('').astype(str)
    return [
        ' '.join(sentence for sentence in row if sentence)
        for row in sentences.itertuples(index=False, name=None)
    ]

# Shuffle every story deterministically (seed 123), then keep the first 100.
stories = get_stories()
_shuffle_rng = random.Random(123)
_shuffle_rng.shuffle(stories)
stories_subset = stories[:100]

### Consolidation in the extended model

In [None]:
# Restore the recalled stories saved earlier. pickle can execute arbitrary
# code on load, so only use this with a trusted, locally produced file.
with open("recalled_stories.pkl", "rb") as pickle_file:
    recalled_stories_dict = pickle.load(pickle_file)

In [None]:
# Hugging Face identifier of the base model; also used when loading weights.
mistral_model = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(mistral_model)
# Batched tokenization with padding=True needs a pad token. When the tokenizer
# defines none, reuse its end-of-sequence token for padding.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

def preprocess_data(texts, tokenizer, ignore_index=-100):
    """Tokenize a list of strings into padded causal-LM inputs with labels.

    Args:
        texts: List of strings to tokenize.
        tokenizer: Callable tokenizer. It is invoked with
            ``return_tensors="pt", padding=True, truncation=True``.
        ignore_index: Label value written at padded positions. The default,
            -100, is the ``ignore_index`` of PyTorch's cross-entropy loss, so
            those positions do not contribute to the loss.

    Returns:
        The tokenizer output with an added ``"labels"`` entry. The labels are a
        copy of ``input_ids``, except that positions where ``attention_mask`` is
        0 (padding) are set to ``ignore_index``.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    labels = inputs["input_ids"].clone()
    # Mask by attention_mask, not by pad_token_id: the pad token may be the
    # eos token, and genuine eos tokens must stay as training targets.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = ignore_index
    inputs["labels"] = labels
    return inputs

class PrintContinuationCallback(TrainerCallback):
    """
    A callback that prints the model's greedy continuation of the first
    sentence of each test story every ``print_frequency`` training steps.

    Args:
        trainer: The ``Trainer`` whose ``model`` is used for generation.
        test_stories: Iterable of story strings. The text before the first
            '.' of each story (stripped) is used as the prompt.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        print_frequency: Print when the global step is a positive multiple
            of this value.
        max_new_tokens: Maximum number of tokens generated per prompt.
        max_stories_to_print: Only the first this many stories are printed.
    """
    def __init__(
        self,
        trainer,
        test_stories,
        tokenizer,
        print_frequency=50,
        max_new_tokens=100,
        max_stories_to_print=5
    ):
        super().__init__()
        self.trainer = trainer
        self.test_stories = test_stories
        self.tokenizer = tokenizer
        self.print_frequency = print_frequency
        self.max_new_tokens = max_new_tokens
        self.max_stories_to_print = max_stories_to_print

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Generate and print story continuations when the print step is reached."""
        global_step = state.global_step
        if global_step <= 0 or global_step % self.print_frequency != 0:
            return

        model = self.trainer.model
        device = next(model.parameters()).device
        was_training = model.training
        model.eval()

        # Fall back to eos when no pad token exists. Passing pad_token_id
        # explicitly avoids generate()'s fallback warning.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        print(f"\n[PrintContinuationCallback] Step={global_step} - Generating test story completions:")
        try:
            for i, story in enumerate(self.test_stories):
                if i >= self.max_stories_to_print:
                    break  # avoid printing too many stories

                # Take the first sentence as the prompt.
                first_line = story.split('.')[0].strip()
                inputs = self.tokenizer(first_line, return_tensors='pt').to(device)
                prompt_length = inputs["input_ids"].shape[1]

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=self.max_new_tokens,
                        do_sample=False,
                        pad_token_id=pad_token_id,
                    )
                # For a decoder-only model the output starts with the prompt
                # tokens. Decode only the new tokens: decoding the whole
                # sequence need not reproduce `first_line` character for
                # character, so slicing the text by len(first_line) is unreliable.
                continuation = self.tokenizer.decode(
                    outputs[0][prompt_length:], skip_special_tokens=True
                ).strip()

                print(f"  Story {i+1} prompt: {first_line}")
                print(f"  => Continuation: {continuation}\n")
                print(f"  Real story (for reference): {story}\n")
        finally:
            # Restore the model's previous mode so the callback leaves no
            # lasting side effect on training.
            if was_training:
                model.train()

def _tokenized_dataset(texts):
    """Wrap ``texts`` in a ``Dataset`` and tokenize it with ``preprocess_data``.

    The raw ``text`` column is dropped, so only the fields returned by
    ``preprocess_data`` reach the data collator.
    """
    dataset = Dataset.from_dict({"text": texts})
    return dataset.map(
        lambda batch: preprocess_data(batch["text"], tokenizer),
        batched=True,
        remove_columns=["text"],
    )

def _load_lora_model():
    """Load ``mistral_model`` in 4-bit and wrap it in a LoRA adapter on the q/k/v projections."""
    model = AutoModelForCausalLM.from_pretrained(
        mistral_model,
        load_in_4bit=True,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    # Keep the embedding matrix sized to the tokenizer's current vocabulary.
    model.resize_token_embeddings(len(tokenizer))

    lora_config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=["q_proj", "v_proj", "k_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model

def finetune_model(
    training_data,
    test_data,
    model_name='finetuned_model',
    seed=123
):
    """Fine-tune a LoRA adapter on ``training_data`` while tracking loss on ``test_data``.

    Args:
        training_data: List of strings the model is trained on.
        test_data: List of strings used as the evaluation set and as the
            prompts source for the periodic continuation printouts.
        model_name: Output directory. It is deleted first if present, then
            receives checkpoints, the final adapter and the tokenizer.
        seed: Seed passed to ``set_seed`` and ``TrainingArguments``.

    Returns:
        ``(trainer, pre_train_loss, pre_test_loss)``: the ``Trainer`` after
        ``train()``, and the ``eval_loss`` on the training and test sets
        measured before any training step.
    """
    import shutil

    # Start from a clean output directory. Unlike an interpolated shell
    # `rm -rf`, shutil works outside IPython and never hands model_name to a shell.
    shutil.rmtree(model_name, ignore_errors=True)
    set_seed(seed)

    # Build train & test datasets.
    train_dataset = _tokenized_dataset(training_data)
    test_dataset = _tokenized_dataset(test_data)

    model = _load_lora_model()

    training_args = TrainingArguments(
        output_dir=model_name,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_train_epochs=30,
        learning_rate=5e-4,
        fp16=True,
        logging_dir='./logs',
        logging_steps=20,
        # NOTE(review): newer transformers releases renamed this to
        # `eval_strategy`; confirm against the installed version.
        evaluation_strategy="steps",
        eval_steps=20,
        save_strategy='epoch',
        seed=seed,
        label_names=["labels"]
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DefaultDataCollator(),
    )

    # Evaluate on both train/test *before* training.
    print("\n=== Evaluating before training ===")
    pre_train_loss = trainer.evaluate(eval_dataset=train_dataset)["eval_loss"]
    print(f"Initial Train Loss: {pre_train_loss:.4f}")

    pre_test_loss = trainer.evaluate(eval_dataset=test_dataset)["eval_loss"]
    print(f"Initial Test Loss:  {pre_test_loss:.4f}")

    trainer.add_callback(
        PrintContinuationCallback(
            trainer=trainer,
            test_stories=test_data,
            tokenizer=tokenizer,
            print_frequency=20,
            max_new_tokens=100,
            max_stories_to_print=5
        )
    )

    print("\n=== Starting Training ===")
    trainer.train()

    # Save the final adapter and the tokenizer.
    model.save_pretrained(model_name)
    tokenizer.save_pretrained(model_name)
    return trainer, pre_train_loss, pre_test_loss

# Train on the recalled stories stored under key 0 and track loss on the
# shuffled original-story subset. Keep the trainer and the pre-training
# losses for the plot below.
trainer, init_train_loss, init_test_loss = finetune_model(
    recalled_stories_dict[0],
    stories_subset,
    model_name="finetuned_model",
    seed=123,
)

In [None]:
logs = trainer.state.log_history

# Split the Trainer log into a training-loss series and an evaluation-loss
# series. Each series begins with the loss measured before training, plotted
# at step -1 so it sits just before the first logged step.
train_points = [(-1, init_train_loss)] + [
    (record["step"], record["loss"])
    for record in logs
    if "loss" in record and "eval_loss" not in record
]
eval_points = [(-1, init_test_loss)] + [
    (record["step"], record["eval_loss"])
    for record in logs
    if "eval_loss" in record
]
train_steps, train_losses = (list(column) for column in zip(*train_points))
eval_steps, eval_losses = (list(column) for column in zip(*eval_points))

# Draw both curves on one small figure and write it to disk.
plt.figure(figsize=(2.7, 2.3))
for xs, ys, curve_label, curve_color in (
    (train_steps, train_losses, "Encoded story", 'r'),
    (eval_steps, eval_losses, "Original story", 'b'),
):
    plt.plot(xs, ys, label=curve_label, color=curve_color)
plt.xlabel("Step")
plt.ylabel("Prediction error")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig('loss_over_time.png', dpi=200)
