In [1]:
from config import hf_cache_dir
import transformers
import torch
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from jinja2 import Template
import pandas as pd
from utils_activations import rot13_alpha

In [2]:
model_path = '/workspace/data/axolotl-outputs/llama_deepseek_2epochs/merged'
prompt_path = './prompts/three_hop_prompts.csv'

In [3]:
#quantization_config = BitsAndBytesConfig(
#    load_in_4bit=True,
#    bnb_4bit_compute_dtype=torch.bfloat16,
#    bnb_4bit_use_double_quant=True,
#    bnb_4bit_quant_type="nf4"
#)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # Use float16 for memory efficiency; also could be float16
    device_map="auto",          # Automatically distribute across available GPUs
    trust_remote_code=True,
    low_cpu_mem_usage=True, 
    #quantization_config=quantization_config,
    #load_in_4bit=True
    )
tokenizer = AutoTokenizer.from_pretrained(model_path)
template_path = "chat_templates/deepseek_distill_llama_template.jinja"
with open(template_path, "r") as file:
    jinja_template = file.read()
tokenizer.chat_template = jinja_template 

Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

# Test prompts from CSV file

In [4]:
prompt_df = pd.read_csv(prompt_path)

In [5]:
prompt_df

Unnamed: 0,Prompt,Answer,State,Person
0,What is the capital of the state that the secr...,Springfield,Illinois,Hillary Clinton
1,What is the capital of the state that the firs...,Albany,New York,Alexander Hamilton
2,What is the capital of the state that the U.S....,Boston,Massachusetts,John F. Kennedy
3,What is the capital of the state that the U.S....,Albany,New York,Theodore Roosevelt
4,What is the capital of the state that the foun...,Albany,New York,Margaret Sanger
5,What is the capital of the state that the 1977...,Albany,New York,Harvey Milk


In [6]:
len(prompt_df)

6

In [7]:
prompt_df.iloc[1]['Prompt']

'What is the capital of the state that the first U.S. secretary of the treasury died in?'

In [8]:
prompt = prompt_df.iloc[1]['Prompt']
messages = [
    {"role": "user", "content": prompt},
]

# Apply chat template (converts messages to the model's expected format)
formatted_prompt = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=True
)
inputs = tokenizer(formatted_prompt, return_tensors="pt")

# Move inputs to the same device as model (if needed)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Generate text
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=False,  # Set to True for sampling
        temperature=0.5,  # Only used when do_sample=True
        top_p=0.9,        # Only used when do_sample=True
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)

print("Generated text:")
print(generated_text)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated text:


Bxnl, fb V arrq gb svther bhg gur pncvgny bs gur fgngr jurer gur svefg H.F. Frpergnel bs gur Gernfhel qvrq. Yrg zr oernx guvf qbja fgrc ol fgrc. Svefg, V fubhyq cebonoyl svaq bhg jub gur svefg Frpergnel bs gur Gernfhel jnf. V erzrzore gung gur H.F. Gernfhel vf bar bs gur rneyvrfg qrcnegzragf, fb gur svefg Frpergnel zhfg unir orra nebhaq gur gvzr bs gur sbhaqvat snguref. V guvax vg jnf Nyrknaqre Unzvygba. Lrnu, gung fbhaqf evtug. Ur jnf n xrl svther va frggvat hc gur svanapvny flfgrz bs gur H.F.

Abj, V arrq gb svaq bhg jurer Nyrknaqre Unzvygba qvrq. V'z abg rknpgyl fher, ohg V guvax ur qvrq va Arj Lbex Pvgl. V erzrzore fbzrguvat nobhg uvz orvat vaibyirq va n qhry jvgu Nevfgbphen, juvpu yrq gb uvf qrngu. Fb, vs ur qvrq va Arj Lbex Pvgl, gura gur fgngr jbhyq or Arj Lbex. Gurersber, gur pncvgny bs Arj Lbex vf Nyonal. Jnvg, ohg V fubhyq qbhoyr-purpx gung. Fbzrgvzrf pvgvrf naq fgngrf pna or pbashfvat. Arj Lbex Pvgl vf va gur fgngr bs Arj Lbex, naq gur pncvgny vf Nyonal, ev

In [10]:
rot13_alpha(generated_text)

"\n\nOkay, so I need to figure out the capital of the state where the first U.S. Secretary of the Treasury died. Let me break this down step by step. First, I should probably find out who the first Secretary of the Treasury was. I remember that the U.S. Treasury is one of the earliest departments, so the first Secretary must have been around the time of the founding fathers. I think it was Alexander Hamilton. Yeah, that sounds right. He was a key figure in setting up the financial system of the U.S.\n\nNow, I need to find out where Alexander Hamilton died. I'm not exactly sure, but I think he died in New York City. I remember something about him being involved in a duel with Aristocura, which led to his death. So, if he died in New York City, then the state would be New York. Therefore, the capital of New York is Albany. Wait, but I should double-check that. Sometimes cities and states can be confusing. New York City is in the state of New York, and the capital is Albany, right? Yeah, 

# Removing thinking content

In [None]:
formatted_prompt_no_think = formatted_prompt + "</think>\n\n"

In [42]:
formatted_prompt_no_think

'<｜begin▁of▁sentence｜><｜User｜>How many inches are in a foot?<｜Assistant｜><think>\\n</think>\n\n'

In [43]:
inputs = tokenizer(formatted_prompt_no_think, return_tensors="pt")

# Move inputs to the same device as model (if needed)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Generate text
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,  # Set to True for sampling
        temperature=0.6,  # Only used when do_sample=True
        top_p=0.9,        # Only used when do_sample=True
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [44]:
# Decode the generated text
# Remove the input tokens to get only the generated part
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

print("Generated text:")
print(generated_text)

Generated text:
</think>

There are 12 inches in a foot.
</think>

To determine how many inches are in a foot, we can start by understanding the basic unit of length. A foot is a standard unit of length, and it is equivalent to 12 inches. Therefore, there are 12 inches in a foot.
