# 1. Data Exploration & Visualization (GSM8K)

**Objective:** Load the raw GSM8K dataset, parse it into our `(Prompt, CoT, Solution)` format, and visualize its properties. This will help us understand the data before we build any models.

In [None]:
%pip install datasets transformers matplotlib seaborn pandas

In [None]:
import sys
import os

# Locate the project root. If 'src' is visible from the current working
# directory we are already at the root; otherwise assume the notebook is
# running from 'notebooks/' and go up one level.
project_root = os.path.abspath(os.getcwd())
if 'src' not in os.listdir(project_root):
    project_root = os.path.abspath(os.path.join(project_root, '..'))

# Put the project root on the import path so `src.*` modules resolve
# regardless of where the kernel was started.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    print(f"Added project root to path: {project_root}")

# Now we can import from our source files
from src.utils import parse_gsm8k_sample, get_llm_tokenizer

## 1.1 Load Raw Data

In [None]:
from datasets import load_dataset

dataset = load_dataset("gsm8k", "main")
train_data = dataset['train']
test_data = dataset['test']

print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")

## 1.2 Inspect a Single Sample

In [None]:
sample = train_data[0]
print("--- RAW QUESTION ---")
print(sample['question'])

print("\n--- RAW ANSWER (CoT + Solution) ---")
print(sample['answer'])

## 1.3 Parse the Sample

Let's test our utility function `parse_gsm8k_sample` from `src/utils.py`.
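Each GSM8K answer is free-text reasoning followed by a final line of the form `#### <answer>`, and the reasoning steps carry calculator annotations such as `<<48/2=24>>`. The real logic of `parse_gsm8k_sample` lives in `src/utils.py` and is not reproduced here. What follows is a minimal sketch of how such a parser *could* work. It assumes the last `####` is the split point between CoT and solution, and that a record without it should be rejected by returning `None`, which matches the `if parsed:` check in the cell below. The name `parse_gsm8k_sample_sketch` and the `strip_calc` option are hypothetical.

```python
import re
from typing import Optional, Tuple

# Matches GSM8K calculator annotations such as "<<48/2=24>>"
CALC_ANNOTATION = re.compile(r"<<[^>]*>>")


def parse_gsm8k_sample_sketch(
    sample: dict, strip_calc: bool = True
) -> Optional[Tuple[str, str, str]]:
    """Hypothetical parser: split a GSM8K record into (prompt, cot, solution).

    Assumes reasoning and final answer are separated by '####'.
    Returns None when the delimiter is missing.
    """
    question = sample["question"].strip()
    answer = sample["answer"]
    if "####" not in answer:
        return None
    # Everything before the last '####' is reasoning; everything after is the answer
    cot, solution = answer.rsplit("####", 1)
    if strip_calc:
        # Optionally drop calculator annotations (an assumption, not confirmed for src/utils.py)
        cot = CALC_ANNOTATION.sub("", cot)
    return question, cot.strip(), solution.strip()


# Illustrative, hand-written record in GSM8K's format (not taken from the dataset)
example = {
    "question": "Tom has 3 boxes with 4 apples each. How many apples does he have?",
    "answer": "Each box has 4 apples, so 3 * 4 = <<3*4=12>>12 apples.\n#### 12",
}
print(parse_gsm8k_sample_sketch(example))
# → ('Tom has 3 boxes with 4 apples each. How many apples does he have?', 'Each box has 4 apples, so 3 * 4 = 12 apples.', '12')
```

Splitting with `rsplit(..., 1)` keeps the parser robust if a `####` ever appeared inside the reasoning itself; whether to strip calculator annotations depends on whether the model should see them.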

In [None]:
parsed = parse_gsm8k_sample(sample)

if parsed:
    prompt, cot, solution = parsed
    print("--- PARSED PROMPT (P) ---")
    print(f"{prompt!r}")
    
    print("\n--- PARSED CoT (C) ---")
    print(f"{cot!r}")
    
    print("\n--- PARSED SOLUTION (S) ---")
    print(f"{solution!r}")
else:
    print("Failed to parse sample.")

## 1.4 Visualize Token Lengths

This is the key visualization of the notebook: it shows how many tokens the **Chain-of-Thought (C)** sequences occupy. If the CoTs turn out to be long, that supports our paper's approach of compressing them.
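To make "long" concrete in terms of potential savings: if some compression scheme replaced a CoT of `C` tokens with `k` tokens, a `P+C+S` sequence would shrink to `P+k+S`, removing a fraction `(C - k) / (P + C + S)` of it. The helper below only does that arithmetic. It is not the paper's compression method, `k` is a hypothetical budget that this notebook does not fix, and it assumes compression never makes the CoT longer.

```python
def sequence_savings(prompt_len: int, cot_len: int, solution_len: int, k: int) -> float:
    """Fraction of a P+C+S sequence removed if the CoT is replaced by k tokens.

    k is a hypothetical compression budget; we assume compression never
    lengthens the CoT, so the compressed CoT is min(k, cot_len) tokens.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    full = prompt_len + cot_len + solution_len
    if full == 0:
        return 0.0
    compressed = prompt_len + min(k, cot_len) + solution_len
    return (full - compressed) / full


# Illustrative numbers only, not measured from GSM8K
print(f"{sequence_savings(50, 100, 5, 10):.1%}")  # → 58.1%
```

Plugging the per-sample lengths from the histograms below into this formula shows how much of each training sequence is actually at stake.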

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="whitegrid")

# Load our tokenizer to get accurate lengths
tokenizer = get_llm_tokenizer()

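lengths = []
for sample in train_data:
    parsed = parse_gsm8k_sample(sample)
    if parsed:
        prompt, cot, solution = parsed
        # add_special_tokens=False (assumes a Hugging Face tokenizer) counts content
        # tokens only, so BOS/EOS markers are not added to every segment
        lengths.append({
            'prompt_len': len(tokenizer.encode(prompt, add_special_tokens=False)),
            'cot_len': len(tokenizer.encode(cot, add_special_tokens=False)),
            'solution_len': len(tokenizer.encode(solution, add_special_tokens=False)),
            # The parts are joined without separators, so this can differ
            # slightly from the sum of the three lengths above
            'full_text_len': len(tokenizer.encode(prompt + cot + solution, add_special_tokens=False)),
        })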

df_lengths = pd.DataFrame(lengths)
df_lengths.describe()
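`describe()` gives per-column statistics, but two aggregate questions bear more directly on whether compression pays off: what share of all tokens sits in the CoT segment, and how many samples would overflow a given context budget. The helper below answers both from rows shaped like the `lengths` list built in the cell above. It is a sketch using only the standard library, and the budgets 128/256/512 are arbitrary placeholders, not values used by this project.

```python
from typing import Dict, Iterable, List


def length_report(
    rows: List[Dict[str, int]], budgets: Iterable[int] = (128, 256, 512)
) -> Dict[str, float]:
    """Summarize rows with 'cot_len' and 'full_text_len' keys.

    Returns the share of all tokens that belong to the CoT segment and,
    for each candidate budget, the fraction of samples whose full text
    is longer than that budget.
    """
    if not rows:
        raise ValueError("no rows to summarize")
    total_cot = sum(r["cot_len"] for r in rows)
    total_full = sum(r["full_text_len"] for r in rows)
    report = {"cot_token_share": total_cot / total_full if total_full else 0.0}
    for budget in budgets:
        over = sum(1 for r in rows if r["full_text_len"] > budget)
        report[f"frac_over_{budget}"] = over / len(rows)
    return report


# Toy rows for illustration; in the notebook, call length_report(lengths)
toy_rows = [
    {"prompt_len": 40, "cot_len": 100, "solution_len": 2, "full_text_len": 142},
    {"prompt_len": 60, "cot_len": 300, "solution_len": 3, "full_text_len": 363},
]
print(length_report(toy_rows))
```

A high `cot_token_share` means most of the sequence budget is spent on reasoning rather than on the question or the final answer.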

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

sns.histplot(df_lengths['cot_len'], bins=50, ax=ax1, kde=True)
ax1.set_title('Distribution of Chain-of-Thought (CoT) Token Lengths')
ax1.set_xlabel('Token Length')

sns.histplot(df_lengths['full_text_len'], bins=50, ax=ax2, kde=True)
ax2.set_title('Distribution of Full Sample (P+C+S) Token Lengths')
ax2.set_xlabel('Token Length')

plt.show()

print("--- Analysis ---")
print(f"Average CoT length: {df_lengths['cot_len'].mean():.2f} tokens")
print(f"Max CoT length: {df_lengths['cot_len'].max()} tokens")
print("Conclusion: The CoT sequences are often long, making them a good target for compression.")