# 3. Preprocessing (Stage 2) - Create Assorted Dataset

**Objective:** Load the VQ-VAE trained in Stage 1 and use it to process the raw GSM8K training split. The output is `assorted_train.jsonl`, whose entries are the `P + C_assorted + S` sequences used to train the main LLM.

In [None]:
%pip install datasets transformers torch tqdm

In [None]:
import sys
import os
import torch
import json
from datasets import load_dataset

# Add the parent directory (project root) to the path so that `src` is importable
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.utils import (
    get_llm_tokenizer, MAX_SEQ_LEN, PATH_VQVAE_MODEL,
    VQ_CODEBOOK_SIZE, PATH_PROCESSED_DATA,
    create_assorted_dataset
)
from src.model.vae import VQVAEModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 3.1 Load Tokenizer and Trained VQ-VAE

In [None]:
tokenizer = get_llm_tokenizer()
vocab_size = len(tokenizer)

# 1. Instantiate the VQ-VAE model structure
# Note: The parameters (d_model, etc.) MUST match those used in notebook 02.
vq_model = VQVAEModel(
    vocab_size=vocab_size,
    d_model=256, # Must match d_model from notebook 02
    num_embeddings=VQ_CODEBOOK_SIZE,
    max_seq_len=MAX_SEQ_LEN
).to(device)

# 2. Load the trained weights
try:
    vq_model.load_state_dict(torch.load(PATH_VQVAE_MODEL, map_location=device))
    vq_model.eval()
    print(f"Successfully loaded trained VQ-VAE from {PATH_VQVAE_MODEL}")
except FileNotFoundError:
    print(f"ERROR: VQ-VAE model not found at {PATH_VQVAE_MODEL}")
    print("Please run '02_vqvae_training_experiment.ipynb' first.")
    # Re-raise so the notebook stops here instead of continuing with untrained weights.
    raise

## 3.2 Load Raw Data and Create Assorted Dataset

In [None]:
raw_dataset = load_dataset("gsm8k", "main")['train']

# This function does all the heavy lifting:
# 1. Encodes CoT with VQ-VAE
# 2. Applies randomized replacement
# 3. Creates the final text string
assorted_samples = create_assorted_dataset(
    vq_model=vq_model,
    llm_tokenizer=tokenizer,
    dataset=raw_dataset
)

print(f"\nGenerated {len(assorted_samples)} assorted samples.")
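The three steps listed in the comments above can be illustrated with a toy version of the randomized replacement. This is only a sketch under stated assumptions, not the implementation inside `create_assorted_dataset`: the function `assort`, the closing marker `[eoLatent]`, the `<latent_k>` token format, the chunk size, and the left-to-right replacement order are all hypothetical. Only the `[boLatent]` marker is taken from this notebook.

```python
import random

BO_LATENT = "[boLatent]"  # marker this notebook searches for in section 3.3
EO_LATENT = "[eoLatent]"  # hypothetical closing marker, assumed for illustration
CHUNK = 4                 # hypothetical: number of text tokens behind one latent code


def assort(prompt, cot_tokens, solution, latent_codes, rng):
    """Replace the first m chunks of the CoT with latent codes, with m drawn at random."""
    m = rng.randint(0, len(latent_codes))  # m = 0 keeps the CoT as plain text
    remaining_text = " ".join(cot_tokens[m * CHUNK:])
    if m == 0:
        cot = remaining_text
    else:
        latents = " ".join(f"<latent_{code}>" for code in latent_codes[:m])
        cot = f"{BO_LATENT} {latents} {EO_LATENT} {remaining_text}".strip()
    # Final string: prompt, assorted CoT, solution
    return {"text": f"{prompt} {cot} {solution}", "m": m}


rng = random.Random(0)
toy_cot = "two plus three is five so the answer".split()  # 8 toy tokens
for _ in range(3):
    print(assort("Q: 2+3?", toy_cot, "#### 5", [17, 42], rng))
```

With `m = 0` the sample is pure text; with larger `m` more of the leading reasoning is replaced by latent tokens, which is why the generated dataset is a mix of both kinds.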

## 3.3 Inspect Assorted Samples

Let's look at a few samples. Some should be 100% text (no replacement applied, i.e. `m=0`), while others should mix latent tokens with text.

In [None]:
print("--- SAMPLE 1 ---")
print(assorted_samples[0]['text'])

print("\n--- SAMPLE 2 ---")
print(assorted_samples[1]['text'])

print("\n--- SAMPLE 3 ---")
print(assorted_samples[2]['text'])

# Find the first sample that actually contains latent tokens, if any
latent_sample = next(
    (s['text'] for s in assorted_samples if "[boLatent]" in s['text']),
    "No latent sample found.",
)
print("\n--- SAMPLE WITH LATENT TOKENS ---")
print(latent_sample)
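Beyond eyeballing individual samples, it can help to check what share of the dataset contains latent tokens at all. The helper below is a minimal sketch: `latent_fraction` is a hypothetical name, and the toy list stands in for `assorted_samples`, assuming each entry is a dict with a `'text'` key as in the cells above.

```python
def latent_fraction(samples, marker="[boLatent]"):
    """Return the fraction of samples whose text contains the latent marker."""
    if not samples:
        return 0.0
    with_latent = sum(1 for s in samples if marker in s["text"])
    return with_latent / len(samples)


# Toy stand-ins for `assorted_samples`; real entries come from create_assorted_dataset.
toy_samples = [
    {"text": "Q: ... plain text reasoning ... #### 5"},
    {"text": "Q: ... [boLatent] <latent_3> ... #### 7"},
]
print(f"{latent_fraction(toy_samples):.0%} of samples contain latent tokens")
```

A fraction of exactly 0 or 1 on the real data would suggest that the randomized replacement is not mixing text-only and latent samples as intended.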

## 3.4 Save Processed Data

Finally, we save this list of dictionaries as a `.jsonl` file. This will be read by the LLM training notebook.

In [None]:
print(f"Saving processed data to {PATH_PROCESSED_DATA}...")
os.makedirs(os.path.dirname(PATH_PROCESSED_DATA), exist_ok=True)

with open(PATH_PROCESSED_DATA, 'w', encoding='utf-8') as f:
    for item in assorted_samples:
        f.write(json.dumps(item) + '\n')

print("Processed data saved.")
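Since the LLM training notebook depends on this file, a quick round-trip check is a cheap safeguard. The sketch below writes toy records to a temporary file and reads them back; `read_jsonl` is a hypothetical helper, and in this notebook you would call it on `PATH_PROCESSED_DATA` and compare the result with `assorted_samples`.

```python
import json
import os
import tempfile


def read_jsonl(path):
    """Parse one JSON object per non-empty line."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Toy round trip in a temporary directory, so nothing real is overwritten.
toy = [{"text": "Q: ... #### 5"}, {"text": "Q: ... [boLatent] ... #### 7"}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "assorted_train.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        for item in toy:
            f.write(json.dumps(item) + "\n")
    reloaded = read_jsonl(path)

print("Round trip OK:", reloaded == toy)
```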