## GSM8K Demo: Chain-of-Thought (CoT) vs. Memory-Augmented CoT (mCoT)
This notebook runs **baseline CoT reasoning** and **memory-augmented CoT (mCoT)** on the **GSM8K math dataset** using **Qwen**.

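### Step 1: Install Dependencies
We install PyTorch, Transformers, Datasets, FAISS (CPU build), and Sentence-Transformers.


In [1]:
%pip install torch transformers datasets faiss-cpu sentence-transformers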

### Step 2: Load the GSM8K Dataset
We load GSM8K from the Hugging Face Hub and take the first 10 problems of the **train split** for a quick comparison of the two reasoning methods.


In [None]:
from datasets import load_dataset

# Load GSM8K
dataset = load_dataset("openai/gsm8k", "main")
gsm8k_samples = dataset["train"].select(range(10))  # Load only 10 samples for quick testing

print(f"Loaded {len(gsm8k_samples)} samples from GSM8K.")
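Each GSM8K `answer` field contains the worked solution followed by a final line of the form `#### <number>`. Scoring only needs that final number. The helper below, `extract_gold_answer`, is a hypothetical minimal sketch (not part of `datasets`) that splits on the `####` marker and strips thousands separators.

```python
def extract_gold_answer(answer_text: str) -> str:
    """Return the final numeric answer from a GSM8K solution string.

    GSM8K solutions end with a line like '#### 72'; everything before
    the marker is the worked reasoning.
    """
    if "####" not in answer_text:
        raise ValueError("No '####' marker found in answer text")
    final = answer_text.split("####")[-1].strip()
    # Drop thousands separators so '1,000' and '1000' compare equal
    return final.replace(",", "")


# Usage on a made-up GSM8K-style string
example = "She sold 48 / 2 = 24 clips in May.\nIn total she sold 48 + 24 = 72 clips.\n#### 72"
print(extract_gold_answer(example))  # 72
```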

### Step 3: Load Qwen Model & Tokenizer
We use **Qwen/Qwen2-7B-Instruct** for our reasoning tasks.


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load Qwen Model
model_name = "Qwen/Qwen2-7B-Instruct"
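tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# torch_dtype="auto" keeps the checkpoint's stored precision instead of upcasting to float32
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)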

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded on {device}")
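Whether a 7B model fits on the target device is mostly a matter of weight precision. Transformers has historically loaded weights in float32 unless `torch_dtype` is set, which doubles the footprint compared with bfloat16. The back-of-the-envelope sketch below assumes roughly 7 × 10^9 parameters (the exact count is listed on the model card) and counts weights only; activations and the KV cache need additional memory on top of this.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate memory for model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


# Assumption: ~7e9 parameters; check the model card for the exact count.
N_PARAMS = 7e9
for dtype_name, nbytes in [("float32", 4), ("bfloat16", 2)]:
    print(f"{dtype_name}: ~{weight_memory_gb(N_PARAMS, nbytes):.0f} GB")
# float32: ~28 GB
# bfloat16: ~14 GB
```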

### Step 4: Define Chain-of-Thought (CoT) Prompting Function
We wrap each question in a zero-shot **step-by-step reasoning prompt** ("Let's think step by step.") and sample a completion from the model.


In [2]:
def generate_cot_answer(question, prompt_prefix="Let's think step by step.", max_new_tokens=256):
    """Generates a Chain-of-Thought answer for a given question."""
    prompt = f"Question: {question}\n{prompt_prefix}\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True
        )
    
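    # Decode only the newly generated tokens; output_ids[0] also echoes the prompt
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)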
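Qwen2-7B-Instruct is a chat-tuned model, so one variant worth trying is to send the same CoT prompt through the tokenizer's chat template rather than as raw text. The helper `build_cot_messages` below is hypothetical; it only builds the list of `{"role": ..., "content": ...}` dicts that `tokenizer.apply_chat_template` accepts, keeping the same "Question: ... Let's think step by step." wording as the raw prompt.

```python
def build_cot_messages(question, prompt_prefix="Let's think step by step.", system_prompt=None):
    """Build a chat-format message list for a zero-shot CoT prompt."""
    messages = []
    if system_prompt:
        # Optional system turn; omitted entirely when not provided
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": f"Question: {question}\n{prompt_prefix}"})
    return messages


msgs = build_cot_messages("What is 12 * 7?")
print(msgs)
```

With the tokenizer from Step 3, `tokenizer.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt")` returns a tensor of input IDs that can be passed to `model.generate`. Whether this beats the raw prompt on GSM8K is something to measure rather than assume.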

### Step 5: Run Baseline CoT Reasoning
We generate step-by-step solutions using standard Chain-of-Thought prompting.


In [None]:
for idx, sample in enumerate(gsm8k_samples):
    question = sample["question"]
    gold_answer = sample["answer"]
    generated_answer = generate_cot_answer(question)
    
    print(f"\nQuestion {idx+1}: {question}")
    print(f"Model Answer: {generated_answer}")
    print(f"Gold Answer: {gold_answer}")

### Step 6: Integrate Memory Retrieval (mCoT)
We load a prebuilt FAISS index of past solutions through the project's local `memory.retrieval_faiss` module (not included in this notebook).


In [3]:
from memory.retrieval_faiss import FAISSRetriever

# Initialize FAISS memory retrieval
retriever = FAISSRetriever(index_path="faiss_index")
retriever.load_index()

RuntimeError: Failed to import transformers.integrations.integration_utils because of the following error (look up to see its traceback):
Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
module 'torch.library' has no attribute 'register_fake'
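The `memory.retrieval_faiss` module is not shown, so how `FAISSRetriever` scores memories is unknown. One common setup is a flat inner-product FAISS index over L2-normalized embeddings, which returns the stored items with the highest cosine similarity to the query. The pure-Python stand-in below illustrates that ranking only. `ToyMemory` and its bag-of-words `embed` function are hypothetical placeholders for a real sentence-embedding model and FAISS index; the method name `retrieve_memory` and the `{"text": ...}` result shape merely mirror how Step 7 calls the real retriever.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy embedding: lowercase word counts (stand-in for a sentence encoder)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


class ToyMemory:
    """Hypothetical in-memory store mimicking a top-k similarity lookup."""

    def __init__(self):
        self.items = []  # list of (text, embedding) pairs

    def add(self, text):
        self.items.append((text, embed(text)))

    def retrieve_memory(self, query, top_k=3):
        q = embed(query)
        scored = [(cosine(q, vec), text) for text, vec in self.items]
        scored.sort(key=lambda pair: pair[0], reverse=True)  # most similar first
        return [{"text": text, "score": score} for score, text in scored[:top_k]]


memory = ToyMemory()
memory.add("Apples cost $2 each, so 5 apples cost 5 * 2 = $10.")
memory.add("A train travels 60 miles per hour for 3 hours: 60 * 3 = 180 miles.")
hits = memory.retrieve_memory("How much do 7 apples cost at $2 each?", top_k=1)
print(hits[0]["text"])  # Apples cost $2 each, so 5 apples cost 5 * 2 = $10.
```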

### Step 7: Run Memory-Augmented CoT (mCoT)
For each question, the retriever fetches the top 3 most similar past solutions, which are prepended to the prompt before the model generates its answer.


In [None]:
for idx, sample in enumerate(gsm8k_samples):
    question = sample["question"]
    retrieved_memories = retriever.retrieve_memory(question, top_k=3)
    retrieved_text = "\n".join([m["text"] for m in retrieved_memories]) if retrieved_memories else ""
    
    # Combine past memories with the new question; generate_cot_answer appends
    # the "Let's think step by step." cue itself, so it is not repeated here.
    memory_augmented_prompt = f"Previous Reasoning:\n{retrieved_text}\n\nNew Question:\n{question}"
    
    generated_answer = generate_cot_answer(memory_augmented_prompt)
    
    print(f"\nQuestion {idx+1}: {question}")
    print(f"Retrieved Memory: {retrieved_text}")
    print(f"mCoT Answer: {generated_answer}")
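Because the questions come from the GSM8K train split, there is a leakage risk worth checking: if `faiss_index` was built from the same training data (the notebook does not say), the top hit for a question may be that question's own worked solution, which would inflate mCoT accuracy. A hypothetical guard, sketched below, drops any retrieved memory whose text contains the current question verbatim after normalizing case and whitespace. It assumes each memory is a dict with a `"text"` key, as in the loop above.

```python
def normalize(text):
    """Lowercase and collapse whitespace for a loose verbatim comparison."""
    return " ".join(text.lower().split())


def filter_leaked_memories(question, memories):
    """Remove retrieved memories that contain the query question itself."""
    q = normalize(question)
    return [m for m in memories if q not in normalize(m["text"])]


question = "Tom has 3 apples and buys 2 more. How many apples does he have?"
memories = [
    {"text": "Q: Tom has 3 apples and buys 2 more.  How many apples does he have?\nA: 3 + 2 = 5"},
    {"text": "Q: A pen costs $2. How much do 4 pens cost?\nA: 4 * 2 = 8"},
]
kept = filter_leaked_memories(question, memories)
print(len(kept))  # 1
```

Verbatim containment only catches exact copies; near-duplicates or paraphrases would need a similarity threshold instead.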

### Step 8: Compare CoT vs. mCoT Performance
The comparison is not implemented yet; the cell below is a placeholder for measuring the accuracy, efficiency, and consistency of CoT vs. mCoT.


In [4]:
# TODO: add evaluation and log results
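A minimal way to fill in the accuracy part of this step is to record, per question, whether each method got the answer right, then report both accuracies plus how many questions memory fixed or broke. The `compare_methods` function and its record schema (`cot_correct`, `mcot_correct`) are hypothetical, and the records in the usage example are toy values, not results from this notebook.

```python
def compare_methods(results):
    """Summarize per-question correctness for CoT vs. mCoT.

    `results` is a list of dicts with boolean keys 'cot_correct' and 'mcot_correct'.
    """
    n = len(results)
    if n == 0:
        raise ValueError("No results to compare")
    cot = sum(r["cot_correct"] for r in results)
    mcot = sum(r["mcot_correct"] for r in results)
    # Questions where memory changed the outcome in either direction
    fixed = sum(r["mcot_correct"] and not r["cot_correct"] for r in results)
    broken = sum(r["cot_correct"] and not r["mcot_correct"] for r in results)
    return {
        "n": n,
        "cot_accuracy": cot / n,
        "mcot_accuracy": mcot / n,
        "fixed_by_memory": fixed,
        "broken_by_memory": broken,
    }


# Toy records for illustration only
toy_results = [
    {"cot_correct": True, "mcot_correct": True},
    {"cot_correct": False, "mcot_correct": True},
    {"cot_correct": True, "mcot_correct": False},
    {"cot_correct": False, "mcot_correct": False},
]
print(compare_methods(toy_results))
```

With only 10 questions and sampling at `temperature=0.7`, a one- or two-question difference between methods is well within run-to-run noise; repeating each method over several seeds is a simple way to gauge consistency.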