In [1]:
import io
import os
import random
import sys
import time
import warnings
from pathlib import Path
from pprint import pprint
from typing import Callable, Literal, TypeAlias

import httpx
import pandas as pd
from dotenv import load_dotenv
from tabulate import tabulate
from tqdm import tqdm

import torch

# Set device to MPS (Metal Performance Shaders) for Mac with Apple Silicon
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [2]:
from huggingface_hub import hf_hub_download
sae_name = "DeepSeek-R1-Distill-Llama-8B-SAE-l19"

file_path = hf_hub_download(
    repo_id=f"qresearch/{sae_name}",
    filename=f"{sae_name}.pt",
    repo_type="model"
)

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
import re
import datasets
import sentencepiece as spm

In [4]:
gsm8k = datasets.load_dataset("gsm8k", "main", cache_dir='/tmp')
gsm8k_train, gsm8k_test = gsm8k['train'], gsm8k['test']

In [5]:
gsm8k_train

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [6]:
gsm8k_test

Dataset({
    features: ['question', 'answer'],
    num_rows: 1319
})

In [7]:
# @title Testing library

def find_numbers(x: str) -> list[str]:
  """Finds all numbers in a string."""
  # Search for number, possibly negative (hyphen), with thousand separators
  # (comma), and with a decimal point (period inbetween digits).
  numbers = re.compile(
      r'-?[\d,]*\.?\d+',
      re.MULTILINE | re.DOTALL | re.IGNORECASE,
  ).findall(x)
  return numbers


def find_number(x: str,
                answer_delimiter: str = 'The answer is') -> str:
  """Finds the most relevant number in a string."""
  # If model uses the answer delimiter, then select the first number following
  # that format.
  if answer_delimiter in x:
    answer = x.split(answer_delimiter)[-1]
    numbers = find_numbers(answer)
    if numbers:
      return numbers[0]

  # In general, select the last number in the string.
  numbers = find_numbers(x)
  if numbers:
    return numbers[-1]
  return ''


def maybe_remove_comma(x: str) -> str:
  # Example: 5,600 -> 5600
  return x.replace(',', '')

In [11]:
%%time
all_correct = 0
all_responses = {}
short_responses = {}
idx = 0
correct = 0

# Create output directories if they don't exist
os.makedirs('outputs/gsm8k', exist_ok=True)
os.makedirs('outputs/gsm8k/raw_outputs', exist_ok=True)
os.makedirs('outputs/gsm8k/decoded_text', exist_ok=True)

TEMPLATE = """
Q: {question}
A:"""

for task_id, problem in enumerate(gsm8k_test):

    if task_id in all_responses: continue

    # Print Task ID
    print(f"task_id {task_id}")

    # Formulate the prompt for DeepSeek Distill model
    prompt = TEMPLATE.format(question=problem['question'])
    
    # Generate response using the model with chat template
    inputs = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": prompt}
        ],
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    print(inputs)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pad_token_id=tokenizer.eos_token_id
            max_new_tokens=1024,
            do_sample=False,
            temperature=None,
            top_p=None
        )
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract the assistant's response (after the prompt)
    response_parts = full_response.split(prompt)
    if len(response_parts) > 1:
        response_text = response_parts[1].strip()
    else:
        response_text = full_response
    
    print(response_text)
    
    # Save raw outputs
    torch.save(outputs, f'outputs/gsm8k/raw_outputs/output_{task_id}.pt')
    
    # Save decoded text
    with open(f'outputs/gsm8k/decoded_text/text_{task_id}.txt', 'w') as f:
        f.write(response_text)
    
    all_responses[task_id] = response_text
    short_responses[task_id] = maybe_remove_comma(find_number(all_responses[task_id]))
    print(f"Short answer: {short_responses[task_id]}")
    try:
        correct += float(maybe_remove_comma(
            find_number(problem['answer']))) == float(short_responses[task_id])
    except:
        correct += maybe_remove_comma(
            find_number(problem['answer'])) == maybe_remove_comma(
                find_number(short_responses[task_id]))
    print('-'*40)
    print(f"Ground truth answer {problem['answer']}")
    print(f"Short ground truth answer {find_number(problem['answer'])}")
    print(f"Correct: {correct} out of {idx+1}")
    print("="*40)
    idx += 1


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


task_id 0
{'input_ids': tensor([[128000, 128011,    198,     48,     25,  54765,    753,  78878,  11203,
            220,    845,  19335,    824,   1938,     13,   3005,  50777,   2380,
            369,  17954,   1475,   6693,    323,    293,   2094,  55404,   1354,
            369,   1077,   4885,   1475,   1938,    449,   3116,     13,   3005,
          31878,    279,  27410,    520,    279,  20957,      6,   3157,   7446,
            369,    400,     17,    824,   7878,  37085,  19151,     13,   2650,
           1790,    304,  11441,   1587,   1364,   1304,   1475,   1938,    520,
            279,  20957,      6,   3157,   5380,     32,     25, 128012, 128013,
            198]], device='mps:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]], device='mps:0')}


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


KeyboardInterrupt: 