### Simulating encoding of narratives in the full model

Note that the contents of the `xRAG` directory are copied from https://github.com/Hannibal046/xRAG, so that our results remain reproducible even if that repository changes in the future.

In [None]:
from sentence_transformers import SentenceTransformer
import torch
from datasets import load_dataset
import random
import sys
import spacy
import string
import math
import torch
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from scipy.stats import sem
import matplotlib.pyplot as plt
from scipy.stats import sem
import pandas as pd

# Make the vendored copy of the xRAG code (see note above the imports) importable
# as the `src` package; the two imports below depend on this path entry.
sys.path.append('xRAG')
from src.model import SFR,XMistralForCausalLM
from src.language_modeling.utils import get_retrieval_embeds, XRAG_TOKEN

# Small English spaCy pipeline; split_by_conjunctions_spacy uses it for sentence
# segmentation (doc.sents) and dependency labels (token.dep_).
nlp = spacy.load("en_core_web_sm")

#### Load LLM

In [None]:
device = torch.device("cuda")
llm_name_or_path = "Hannibal046/xrag-7b"

# Slow (use_fast=False) tokenizer, left-padded, with no EOS appended to prompts.
llm_tokenizer = AutoTokenizer.from_pretrained(
    llm_name_or_path,
    add_eos_token=False,
    use_fast=False,
    padding_side='left',
)

# xRAG-augmented Mistral in bfloat16, moved to the GPU and switched to inference mode.
llm = XMistralForCausalLM.from_pretrained(
    llm_name_or_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=False,
).to(device).eval()

# XRAG_TOKEN is only a placeholder in the prompt text; register its token id with the
# model (presumably so generate() knows where to insert `retrieval_embeds`).
llm.set_xrag_token_id(llm_tokenizer.convert_tokens_to_ids(XRAG_TOKEN))
print(XRAG_TOKEN)

#### Load retrieval model

In [None]:
retriever_name_or_path = "Salesforce/SFR-Embedding-Mistral"

# Tokenizer matching the SFR embedding model.
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_name_or_path)

# Dense retriever in bfloat16, in inference mode, placed on `device`.
retriever = SFR.from_pretrained(
    retriever_name_or_path,
    torch_dtype=torch.bfloat16,
).eval().to(device)

#### Load data

In [None]:
def get_stories(path='stories_train.csv', n_sentences=5):
    """Load the story corpus and return one string per story.

    Each CSV row holds a story split across the columns ``sentence1`` ...
    ``sentence{n_sentences}``. These are joined with single spaces, in column order.

    Args:
        path: CSV file to read (defaults to the file this notebook uses).
        n_sentences: number of ``sentence{i}`` columns to join.

    Returns:
        list of str, one story per CSV row, in file order.
    """
    df = pd.read_csv(path)
    sentence_cols = [f'sentence{i}' for i in range(1, n_sentences + 1)]
    # astype(str) makes every cell joinable; note that a missing cell becomes the text "nan".
    combined = df[sentence_cols].astype(str).agg(' '.join, axis=1)
    return combined.tolist()

stories = get_stories()
# In-place shuffle with a fixed seed so the same 100-story subset is drawn on every run.
random.Random(123).shuffle(stories)
stories_subset = stories[:100]

In [None]:
# Recommended template from xRAG paper.
# {document} is filled with XRAG_TOKEN (a one-token placeholder for the retrieved
# document's embedding) and {question} with the recall question; see prepare_prompt.
# This text is exactly what the model sees, so edit it only deliberately.
rag_template = """[INST] Refer to the background document and answer the questions:

Background: {document}

Question: {question} [/INST] The answer is:"""

def prepare_prompt(question="What happened?"):
    """Build the xRAG prompt for *question* and embed the question for retrieval.

    Returns:
        prompt: ``rag_template`` filled with *question*. A single XRAG_TOKEN
            placeholder stands in for the background document.
        query_embed: the retriever's query embedding of *question*. It is kept in
            the retriever's own embedding space, the same space as the
            ``doc_embeds`` returned by ``prepare_datastore``, so that
            ``get_top_match`` can rank documents by dot product against it.
    """
    retriever_input = retriever_tokenizer(
        question, max_length=180, padding=True, truncation=True, return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        query_embed = retriever.get_query_embedding(
            input_ids=retriever_input.input_ids,
            attention_mask=retriever_input.attention_mask,
        )
    # Do NOT pass the query through llm.projector. The documents are searched with
    # their raw retriever embeddings (doc_embeds). A projected query would live in a
    # different space, and its dot products with doc_embeds would not rank relevance.
    print(query_embed.shape)

    prompt = rag_template.format_map(dict(question=question, document=XRAG_TOKEN))

    return prompt, query_embed

def get_top_match(query_embed, doc_embeds):
    """Return the embedding of the document that best matches the query.

    Documents are scored by dot product with the query. The row of *doc_embeds*
    with the highest score for the first query row is returned.

    Args:
        query_embed: 2-D tensor (n_queries, d); only the first query row is used.
        doc_embeds: 2-D tensor (n_docs, d), in the same space as *query_embed*.

    Returns:
        1-D tensor of size d: the best-matching row of *doc_embeds*.
    """
    scores = torch.matmul(query_embed, doc_embeds.T)
    _, index = torch.topk(scores, k=1)
    top1_doc_index = index[0][0].item()
    # Index the tensor that was actually searched. The old code read the global
    # `datastore[1]`, which only worked because callers happened to pass that same tensor.
    return doc_embeds[top1_doc_index]

def prepare_datastore(documents):
    """Embed *documents* with the retriever and build an in-memory datastore.

    All documents are tokenized as one padded batch, truncated to 500 tokens.

    Returns:
        (datastore, doc_embeds, xrag_embeds), where
        datastore = (documents, doc_embeds, xrag_embeds);
        doc_embeds are the raw retriever document embeddings, and
        xrag_embeds are those embeddings passed through ``llm.projector``.
    """
    encoded = retriever_tokenizer(
        documents,
        max_length=500,
        padding=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)
    with torch.no_grad():
        raw_embeds = retriever.get_doc_embedding(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
        )
        projected_embeds = llm.projector(raw_embeds)
    print(projected_embeds.shape)

    store = (documents, raw_embeds, projected_embeds)
    return store, raw_embeds, projected_embeds

def get_answer(prompt, relevant_embedding):
    """Greedily generate (up to 200 new tokens) an answer to *prompt*.

    *prompt* is expected to contain one XRAG_TOKEN placeholder, and
    *relevant_embedding* (one document embedding) is passed to the model as
    ``retrieval_embeds`` for that slot.

    Returns:
        The first decoded output sequence, with special tokens removed.
    """
    encoded_ids = llm_tokenizer(prompt, return_tensors='pt').input_ids
    encoded_ids = encoded_ids.to(device)
    output_ids = llm.generate(
        input_ids=encoded_ids,
        do_sample=False,
        max_new_tokens=200,
        pad_token_id=llm_tokenizer.pad_token_id,
        retrieval_embeds=relevant_embedding.unsqueeze(0),  # add a batch dimension
    )
    decoded = llm_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]

In [None]:
# Perplexity function
def compute_text_perplexity(text: str, model, tokenizer, device='cuda'):
    """
    Return exp(model loss) for *text*, with the text used as its own labels.

    For a Hugging Face causal LM this loss is the mean next-token cross-entropy
    over the whole tokenized text, so the result is its per-token perplexity.
    """
    token_ids = tokenizer(text, return_tensors='pt').to(device).input_ids
    with torch.no_grad():
        mean_nll = model(token_ids, labels=token_ids).loss.item()
    return math.exp(mean_nll)

def split_by_conjunctions_spacy(text):
    """
    Split *text* into phrases at coordinating conjunctions and commas.

    spaCy segments the text into sentences. Within each sentence a phrase ends at
    any token whose dependency label is "cc" (coordinating conjunction) or whose
    text is ",". Those separator tokens are dropped.

    Each phrase has ASCII punctuation (``string.punctuation``) removed and its
    whitespace collapsed to single spaces. Phrases that are empty after this, such
    as a lone closing quote before "and", are skipped.

    Returns:
        list of str, in text order.
    """
    strip_punct = str.maketrans('', '', string.punctuation)
    phrases = []

    def flush(tokens):
        # Join token texts, drop punctuation, then normalise whitespace. Stripping
        # must happen AFTER punctuation removal, or e.g. '" no' keeps a leading space.
        phrase = " ".join(t.text for t in tokens).translate(strip_punct)
        phrase = " ".join(phrase.split())
        if phrase:
            phrases.append(phrase)

    doc = nlp(text)
    for sent in doc.sents:
        current_tokens = []
        for token in sent:
            if token.dep_ == "cc" or token.text == ",":
                # Separator: close the phrase collected so far (if any).
                if current_tokens:
                    flush(current_tokens)
                    current_tokens = []
            else:
                current_tokens.append(token)
        # Whatever remains at the end of the sentence is its last phrase.
        if current_tokens:
            flush(current_tokens)

    return phrases


def llm_surprise_check_phrases(story: str, gist: str, model, tokenizer, device='cuda'):
    """
    Rank the phrases of *story* by how surprising they are after *gist*.

    The story is split with split_by_conjunctions_spacy. Each phrase is scored by
    compute_text_perplexity on the text ``gist + "\\n\\n" + phrase``.
    NOTE: that perplexity is averaged over the gist tokens as well as the phrase
    tokens, not over the phrase alone.

    Returns:
        list of (phrase, perplexity) tuples, highest perplexity first. Ties keep
        their original phrase order.
    """
    scored = [
        (phrase, compute_text_perplexity(f"{gist}\n\n{phrase}", model, tokenizer, device))
        for phrase in split_by_conjunctions_spacy(story)
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

### Add in unexpected details

In [None]:
# Define the documents to process
documents = stories_subset
datastore, doc_embeds, xrag_embeds = prepare_datastore(documents)

# Three detail levels: 0 = xRAG gist token only; 1 and 5 = the gist plus that many
# of the most "surprising" story phrases, added to the question as plain text.
n_details_list = [0, 1, 5]

# Per detail level, one entry per story (in `documents` order):
recalled_stories_dict = {n: [] for n in n_details_list}  # generated recall text
memory_sizes_dict = {n: [] for n in n_details_list}      # memory size in tokens
details_dict = {n: [] for n in n_details_list}           # detail phrases used (stays empty for 0)

for doc in documents:
    print("\nOriginal story:", doc)
    print("-----------------------------------")

    # 1) Base question: no specific details.
    # The cue is the first 50 characters of the story (it may cut a word in half).
    first_line = doc[0:50]
    question_0 = f"{first_line}... What happened (in detail)?"
    print("\nQuestion (0 details):", question_0)
    print("-----------------------------------")

    # Prepare & answer
    prompt_0, query_embed_0 = prepare_prompt(question=question_0)
    relevant_embedding_0 = get_top_match(query_embed_0, doc_embeds)
    answer_0 = get_answer(prompt_0, relevant_embedding_0)
    print("\nAnswer (0 details):", answer_0)
    print("-----------------------------------")

    # Gist-only memory is the single xRAG embedding token.
    recalled_stories_dict[0].append(answer_0)
    memory_sizes_dict[0].append(1)

    # 2) Rank story phrases by perplexity after the gist recall (most surprising first).
    details = llm_surprise_check_phrases(doc, answer_0, llm, llm_tokenizer)
    details = [phrase.strip() for phrase, _ppl in details]

    # 3) For each non-zero detail level, add the top-n phrases to the question.
    for n in (level for level in n_details_list if level > 0):
        details_subset = details[:n]
        details_str = ", ".join(details_subset)

        question_n = (
            f"{first_line}... What happened (in detail)? "
            f"Other information: {details_str}"
        )
        print(f"\nQuestion ({n} details):", question_n)
        print("-----------------------------------")

        # Prepare & answer
        prompt_n, query_embed_n = prepare_prompt(question=question_n)
        relevant_embedding_n = get_top_match(query_embed_n, doc_embeds)
        answer_n = get_answer(prompt_n, relevant_embedding_n)
        print(f"\nAnswer ({n} details):", answer_n)
        print("-----------------------------------")

        recalled_stories_dict[n].append(answer_n)

        # Memory = 1 xRAG token + the text tokens of the detail phrases.
        # add_special_tokens=False: otherwise a tokenizer that prepends BOS (as the
        # Mistral/Llama tokenizer does by default) adds one phantom token per detail.
        total_detail_tokens = sum(
            len(llm_tokenizer(detail, add_special_tokens=False)['input_ids'])
            for detail in details_subset
        )
        memory_sizes_dict[n].append(1 + total_detail_tokens)

        details_dict[n].append(details_subset)

In [None]:
# Persist the recalls and the detail phrases so they can be reloaded after a restart.
for out_path, payload in (
    ("recalled_stories.pkl", recalled_stories_dict),
    ("recalled_stories_details.pkl", details_dict),
):
    with open(out_path, "wb") as fh:
        pickle.dump(payload, fh)

#### Get imagined stories as baseline

In [None]:
# Drop the notebook's reference to the xRAG model before loading the base Mistral
# model below, so its memory can be reclaimed.
# NOTE(review): `retriever` was also moved to the GPU and is still referenced here;
# if memory is tight, consider `del retriever` plus torch.cuda.empty_cache().
# Confirm that no later cell needs it first.
del llm

In [None]:
# Prefer the GPU, but fall back to CPU when CUDA is unavailable.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Plain (non-xRAG) instruction-tuned Mistral, used for the "imagined" baseline.
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name, padding_side='left')
base_model = (
    AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,  # or fp16 if you prefer
        low_cpu_mem_usage=True,
    )
    .eval()
    .to(device)
)

def _generate_continuation(prompt, max_new_tokens):
    """Greedy-decode from the base model; return only the newly generated text, stripped."""
    input_ids = base_tokenizer(prompt, return_tensors='pt').input_ids.to(device)
    with torch.no_grad():
        output_ids = base_model.generate(
            input_ids=input_ids,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            pad_token_id=base_tokenizer.pad_token_id,
        )
    # Drop the echoed prompt tokens and keep only the continuation.
    new_tokens = output_ids[0][input_ids.shape[1]:]
    return base_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def generate_imagined_story(original_story: str, max_new_tokens=100):
    """
    Have the base model "imagine" a story from its opening only.

    The cue is the first 50 characters of *original_story*, which may cut mid-word.
    The model sees nothing else of the story.

    Returns:
        Only the generated continuation, stripped. The cue is NOT prepended.
    """
    first_line = original_story[:50]
    prompt = f"[INST] {first_line}... Continue the story. [/INST]"
    return _generate_continuation(prompt, max_new_tokens)


def generate_full_detail_story(original_story: str, max_new_tokens=100):
    """
    Have the base model continue a story while giving it the full story as background.

    Uses the same 50-character cue as generate_imagined_story, but the complete
    *original_story* is included in the prompt as a background document.
    (Not called in this notebook: full-detail similarity is fixed at 1.)

    Returns:
        Only the generated continuation, stripped. The cue is NOT prepended.
    """
    first_line = original_story[:50]
    prompt = f"[INST] Refer to the background document and answer the questions: Background: {original_story} \n Question: {first_line}... Continue the story. [/INST]"
    return _generate_continuation(prompt, max_new_tokens)

emb_model = SentenceTransformer('all-MiniLM-L6-v2')
def get_embedding(text):
    """Embed *text* with the MiniLM sentence encoder; return a float32 tensor on ``device``."""
    vectors = emb_model.encode([text])  # batch of one -> NumPy array with a single row
    return torch.tensor(vectors[0], dtype=torch.float, device=device)

def cosine_similarity(emb_a, emb_b):
    """Cosine similarity along the last dimension (inputs broadcast elementwise)."""
    unit_a = emb_a / emb_a.norm(dim=-1, keepdim=True)
    unit_b = emb_b / emb_b.norm(dim=-1, keepdim=True)
    return (unit_a * unit_b).sum(dim=-1)


def cosine_distance(emb_a, emb_b):
    """Cosine distance, defined as ``1 - cosine_similarity(emb_a, emb_b)``."""
    similarity = cosine_similarity(emb_a, emb_b)
    return 1 - similarity


stories = stories_subset

# Baseline: the base model imagines each story from its opening cue alone.
imagined_stories = []
for original in stories:
    continuation = generate_imagined_story(original)
    imagined_stories.append(continuation)
    print(continuation)

#### Compute similarity of each recalled story to the original, and memory sizes

In [None]:
# Bar labels, left to right; the middle three correspond to n_details_list.
labels = ["Imagined", "Gist only", "1 detail", "5 details", "Full detail"]
n_details_list = [0, 1, 5]
x = np.arange(len(labels))
width = 0.45

# Per-story cosine similarity to the original story.
# NOTE: despite its name, `distances[n]` stores *similarities* for detail level n.
distances = {n: [] for n in n_details_list}
imagined_sims = []
full_detail_sims = []

# Token lengths for full detail stories (that memory is the whole story text)
full_detail_token_lens = []

for i, original_text in enumerate(stories):
    imagined_text = imagined_stories[i]
    recalled_texts = {n: recalled_stories_dict[n][i] for n in n_details_list}

    # Truncate every text (in characters) to the shortest one so that length
    # differences do not drive the similarity scores.
    all_texts = [original_text, imagined_text] + list(recalled_texts.values())
    min_len = min(len(t) for t in all_texts)

    emb_original = get_embedding(original_text[:min_len])
    emb_imagined = get_embedding(imagined_text[:min_len])
    imagined_sims.append(cosine_similarity(emb_original, emb_imagined).item())

    # The full-detail memory IS the original story, so its similarity is 1 by
    # construction; no embedding is needed.
    full_detail_sims.append(1)

    # Count only the story's own tokens (exclude special tokens such as BOS).
    full_detail_token_lens.append(
        len(base_tokenizer(original_text, add_special_tokens=False)['input_ids'])
    )

    for n in n_details_list:
        emb_recalled = get_embedding(recalled_texts[n][:min_len])
        distances[n].append(cosine_similarity(emb_original, emb_recalled).item())

# === Aggregate statistics ===

# Similarities
imagined_similarity = np.mean(imagined_sims)
imagined_similarity_sem = sem(imagined_sims)

full_detail_similarity = np.mean(full_detail_sims)
full_detail_similarity_sem = sem(full_detail_sims)

mean_similarities = [imagined_similarity] + [np.mean(distances[n]) for n in n_details_list] + [full_detail_similarity]
sem_similarities = [imagined_similarity_sem] + [sem(distances[n]) for n in n_details_list] + [full_detail_similarity_sem]

# Memory sizes (the imagined baseline stores nothing, hence 0)
mean_memory_sizes = [0] + [np.mean(memory_sizes_dict[n]) for n in n_details_list] + [np.mean(full_detail_token_lens)]
sem_memory_sizes = [0] + [sem(memory_sizes_dict[n]) for n in n_details_list] + [sem(full_detail_token_lens)]

# === Plotting ===

fig, ax1 = plt.subplots(figsize=(5, 2.3))

# Plot similarity (left y-axis)
bars1 = ax1.bar(x - width/2, mean_similarities, width, yerr=sem_similarities, capsize=5,
                label='Similarity to Original', alpha=0.5)
ax1.set_ylabel('Similarity to original')
ax1.set_ylim(0.6, 1.03)
ax1.set_xticks(x)
ax1.set_xticklabels(labels)

# Create second axis
ax2 = ax1.twinx()

# Plot memory sizes (right y-axis)
bars2 = ax2.bar(x + width/2, mean_memory_sizes, width, yerr=sem_memory_sizes, capsize=5,
                color='red', alpha=0.5, label='Memory Size (tokens)')
ax2.set_ylabel('Memory size (tokens)')
ax2.set_ylim(0, max(mean_memory_sizes) + 4)

# Match axis colors
ax1.tick_params(axis='y', colors='C0')
ax1.yaxis.label.set_color('C0')

ax2.tick_params(axis='y', colors='red')
ax2.yaxis.label.set_color('red')

plt.tight_layout()
plt.savefig('grouped_similarity_memory_with_imagined_and_full_detail.png', bbox_inches='tight')
plt.show()


#### Inspect stories

In [None]:
# Reload saved results so the inspection cell below also works after a kernel
# restart. These pickles are written by this notebook itself; never load
# pickles from an untrusted source.
with open("recalled_stories.pkl", "rb") as f:
    recalled_stories_dict = pickle.load(f)

# The inspection cell also reads details_dict, so restore it as well.
with open("recalled_stories_details.pkl", "rb") as f:
    details_dict = pickle.load(f)

In [None]:
# Print each story next to its extracted details and the three recalls.
# zip() stops at the shortest sequence, so a partial run no longer raises IndexError.
for real_story, five_details, gist, one_detail, five_detail_recall in zip(
    stories_subset,
    details_dict[5],
    recalled_stories_dict[0],
    recalled_stories_dict[1],
    recalled_stories_dict[5],
):
    print("REAL STORY:")
    print(real_story)
    print("------------------------------")
    print("DETAILS:")
    print(five_details)
    print("------------------------------")
    print("GIST:")
    print(gist)
    print("------------------------------")
    print("ONE DETAIL:")
    print(one_detail)
    print("------------------------------")
    print("FIVE DETAILS:")
    print(five_detail_recall)
    print("------------------------------")