# 5. Evaluation, Inference, and Interpretation

**Objective:** Test our fine-tuned LLM. We will:
1.  Run the formal evaluation script (`src/evaluate.py`) to get an accuracy score on the test set.
2.  Run live inference on a custom question.
3.  **Interpret** the generated latent tokens by decoding them with our VQ-VAE.

In [None]:
%pip install datasets transformers torch accelerate

In [None]:
import sys
import os
import torch
import re
from transformers import AutoModelForCausalLM

# Add the parent directory to sys.path so the `src` package can be imported
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.utils import (
    get_llm_tokenizer, MAX_SEQ_LEN, VQ_CODEBOOK_SIZE, 
    PATH_LLM_MODEL, PATH_VQVAE_MODEL
)
from src.model.vae import VQVAEModel
from src.evaluate import evaluate_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 5.1 Run Evaluation

`evaluate_model` loads our fine-tuned LLM, runs it on the *entire* GSM8K test set, and reports the final accuracy.
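
The internals of `evaluate_model` live in `src/evaluate.py` and are not shown in this notebook. As a rough illustration of how exact-match accuracy on GSM8K is commonly computed, the sketch below pulls the final number out of each generated answer and compares it with the reference, which in GSM8K follows a `####` marker. The helper names (`extract_final_number`, `reference_answer`, `exact_match_accuracy`) and the sample strings are hypothetical, and the project's script may work differently.

```python
import re

LATENT_RE = re.compile(r"<latent_\d+>")
NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")

def extract_final_number(text):
    """Return the last number in `text` as a float, or None if there is none."""
    # Drop latent tokens first so their numeric ids are not mistaken for answers.
    text = LATENT_RE.sub(" ", text)
    matches = NUMBER_RE.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))

def reference_answer(gsm8k_answer):
    """GSM8K reference solutions end with '#### <number>'."""
    return extract_final_number(gsm8k_answer.split("####")[-1])

def exact_match_accuracy(predictions, references):
    """Fraction of predictions whose final number equals the reference."""
    correct = 0
    for pred, ref in zip(predictions, references):
        p = extract_final_number(pred)
        r = reference_answer(ref)
        if p is not None and r is not None and abs(p - r) < 1e-6:
            correct += 1
    return correct / len(references) if references else 0.0

# Hypothetical model outputs and reference solutions (not real results).
preds = ["He has 50 - 21 = 29 dollars left.", "The answer is 12."]
refs = ["3 * 7 = 21\n50 - 21 = 29\n#### 29", "5 + 8 = 13\n#### 13"]
print(exact_match_accuracy(preds, refs))  # 0.5
```

Stripping the latent tokens before searching for numbers matters here, because an id such as the `12` in `<latent_12>` would otherwise count as a candidate answer.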

In [None]:
# This function is imported from src/evaluate.py
evaluate_model(model_path=PATH_LLM_MODEL)

## 5.2 Live Inference

Let's ask our model a new question. We'll print the raw output so we can see the latent tokens it generates.
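
When reading the raw output, it can help to separate the latent tokens from the ordinary text around them. The sketch below does this with a regular expression. It assumes the latent tokens are spelled `<latent_N>`, the same pattern the interpretation step matches; the helper name `split_raw_output` and the sample string are made up for illustration and are not actual model output.

```python
import re

LATENT_RE = re.compile(r"<latent_\d+>")

def split_raw_output(raw_text):
    """Return (latent tokens in order, remaining text with whitespace collapsed)."""
    latent_tokens = LATENT_RE.findall(raw_text)
    visible = LATENT_RE.sub(" ", raw_text)
    visible = " ".join(visible.split())
    return latent_tokens, visible

# Hypothetical raw output, for illustration only.
sample = "Question: 2 + 3?\nAnswer: <latent_17><latent_4> The answer is 5."
latents, visible = split_raw_output(sample)
print(latents)  # ['<latent_17>', '<latent_4>']
print(visible)  # Question: 2 + 3? Answer: The answer is 5.
```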

In [None]:
# 1. Load fine-tuned LLM and tokenizer
llm_tokenizer = get_llm_tokenizer()
llm_model = AutoModelForCausalLM.from_pretrained(PATH_LLM_MODEL).to(device)
llm_model.eval()

# 2. Define a question
question = "Mark has $50. He buys 3 books that cost $7 each. How much money does he have left?"
prompt = f"Question: {question}\nAnswer: "

# 3. Generate a response
inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output = llm_model.generate(
        **inputs,
        max_new_tokens=150,
        pad_token_id=llm_tokenizer.pad_token_id,
        eos_token_id=llm_tokenizer.eos_token_id
    )

generated_text = llm_tokenizer.decode(output[0], skip_special_tokens=False)

print("--- GENERATED RESPONSE ---")
print(generated_text)

## 5.3 Interpretation of Latent Tokens

This is the most interesting part. We can take the `<latent_...>` tokens generated by our LLM, look up their embeddings in the VQ-VAE codebook, and feed those embeddings to the *decoder* of our VQ-VAE to see what reasoning they represent.
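
The code in this section treats the number inside each `<latent_k>` token as row `k` of the VQ-VAE codebook, so the first step of interpretation is an index lookup. The pure-Python sketch below shows that mapping on a toy codebook, with a range check that rejects ids the codebook does not contain. The toy codebook values, its size, and the helper names `parse_latent_ids` and `lookup_codebook_rows` are assumptions for illustration; the notebook itself uses the trained embedding table.

```python
import re

def parse_latent_ids(text):
    """Extract the integer ids from every <latent_N> token, in order."""
    return [int(m) for m in re.findall(r"<latent_(\d+)>", text)]

def lookup_codebook_rows(latent_ids, codebook):
    """Map each id to its codebook row, failing loudly on out-of-range ids."""
    rows = []
    for idx in latent_ids:
        if not 0 <= idx < len(codebook):
            raise IndexError(
                f"latent id {idx} is outside the codebook (size {len(codebook)})"
            )
        rows.append(codebook[idx])
    return rows

# Toy codebook with 4 entries of dimension 2 (hypothetical values).
toy_codebook = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
ids = parse_latent_ids("<latent_2><latent_0> 29")
rows = lookup_codebook_rows(ids, toy_codebook)
print(ids)   # [2, 0]
print(rows)  # [[2.0, 2.1], [0.0, 0.1]]
```

An explicit range check of this kind gives a readable error if the LLM ever emits a latent id that the codebook cannot resolve.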

In [None]:
# 1. Load the trained VQ-VAE
vq_model = VQVAEModel(
    vocab_size=len(llm_tokenizer),
    d_model=256, # Must match d_model from notebook 02
    num_embeddings=VQ_CODEBOOK_SIZE,
    max_seq_len=MAX_SEQ_LEN
).to(device)

try:
    vq_model.load_state_dict(torch.load(PATH_VQVAE_MODEL, map_location=device))
    vq_model.eval()
    print("Loaded VQ-VAE for interpretation.")
except FileNotFoundError:
    # Without trained weights the decoder output is meaningless, so drop the model
    vq_model = None
    print(f"Could not find a VQ-VAE checkpoint at {PATH_VQVAE_MODEL}. Skipping interpretation.")

In [None]:
# 2. Find all latent tokens in the generated text
latent_token_ids = [int(i) for i in re.findall(r"<latent_(\d+)>", generated_text)]

if not latent_token_ids:
    print("No latent tokens were generated in the response.")
elif vq_model is None:
    print("VQ-VAE is not loaded. Skipping interpretation.")
else:
    print(f"Found {len(latent_token_ids)} latent tokens: {latent_token_ids}")
    
    # 3. Get the corresponding embeddings from the VQ codebook
    indices_tensor = torch.tensor(latent_token_ids, dtype=torch.long, device=device)
    codebook_embeddings = vq_model.quantizer.embedding(indices_tensor)
    
    # 4. Decode them!
    # Unsqueeze to add batch dim: (T, D) -> (1, T, D)
    quantized_memory = codebook_embeddings.unsqueeze(0)
    
    # For a simple autoencoder, we feed placeholder tokens to the decoder.
    # A better approach would be to feed the prompt tokens as well.
    bos_id = llm_tokenizer.bos_token_id
    start_token_id = bos_id if bos_id is not None else 0  # fall back to 0 if no BOS token is defined
    decoder_input_ids = torch.full((1, len(latent_token_ids)), start_token_id, dtype=torch.long, device=device)
    
    with torch.no_grad():
        logits = vq_model.decode(quantized_memory, decoder_input_ids)
        
    # Get the most likely token ID for each position
    predicted_token_ids = torch.argmax(logits, dim=-1).squeeze(0)
    
    # 5. Decode the token IDs back to text
    interpreted_text = llm_tokenizer.decode(predicted_token_ids, skip_special_tokens=True)
    
    print("\n--- INTERPRETATION OF LATENT THOUGHTS ---")
    print(interpreted_text)