In [1]:
from config import hf_cache_dir
import transformers
import torch
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from jinja2 import Template
import pandas as pd
from utils_activations import rot13_alpha
from peft import PeftModel

In [2]:
base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
qlora_dir = "/workspace/data/axolotl-outputs/llama_deepseek_2epochs/checkpoint-1027/"

In [None]:
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Load and merge LoRA
peft_model = PeftModel.from_pretrained(base_model, qlora_dir)
merged_model = peft_model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [4]:
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
template_path = "chat_templates/deepseek_distill_llama_template.jinja"
with open(template_path, "r") as file:
    jinja_template = file.read()
tokenizer.chat_template = jinja_template 

# Compare output responses

In [5]:
prompt_path = './prompts/three_hop_prompts.csv'
prompt_df = pd.read_csv(prompt_path)

In [10]:
prompt = prompt_df.iloc[1]['Prompt']
messages = [
    {"role": "user", "content": prompt},
]

# Apply chat template (converts messages to the model's expected format)
formatted_prompt = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=True
)
inputs = tokenizer(formatted_prompt, return_tensors="pt")

# Move inputs to the same device as model (if needed)
inputs = {k: v.to(peft_model.device) for k, v in inputs.items()}

# Generate text
with torch.no_grad():
    outputs = peft_model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=False,  # Set to True for sampling
        temperature=0.5,  # Only used when do_sample=True
        top_p=0.9,        # Only used when do_sample=True
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)

print("Generated text:")
print(generated_text)
print("===========")
print(rot13_alpha(generated_text))

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated text:


Bxnl, fb V arrq gb svther bhg gur pncvgny bs gur fgngr jurer gur svefg H.F. Frpergnel bs gur Gernfhel qvrq. Yrg zr oernx guvf qbja fgrc ol fgrc. Svefg, V fubhyq cebonoyl svaq bhg jub gur svefg Frpergnel bs gur Gernfhel jnf. V erzrzore gung gur H.F. Gernfhel vf bar bs gur rneyvrfg qrcnegzragf, fb gur svefg Frpergnel zhfg unir orra nebhaq gur gvzr bs gur sbhaqvat snguref. V guvax vg jnf Nyrknaqre Unzvygba. Lrnu, gung fbhaqf evtug. Ur jnf n xrl svther va frggvat hc gur svanapvny flfgrz bs gur H.F.

Abj, V arrq gb svaq bhg jurer Nyrknaqre Unzvygba qvrq. V'z abg rknpgyl fher, ohg V guvax ur qvrq va Arj Lbex Pvgl. V erzrzore fbzrguvat nobhg uvz orvat vaibyirq va n qhry jvgu Nevfgbphen, juvpu yrq gb uvf qrngu. Fb, vs ur qvrq va Arj Lbex Pvgl, gura gur fgngr jbhyq or Arj Lbex. Gurersber, gur pncvgny bs Arj Lbex vf Nyonal. Jnvg, ohg V fubhyq qbhoyr-purpx gung. Fbzrgvzrf pvgvrf naq fgngrf pna or pbashfvat. Arj Lbex Pvgl vf va gur fgngr bs Arj Lbex, naq gur pncvgny vf Nyonal, ev

In [11]:
prompt = prompt_df.iloc[1]['Prompt']
messages = [
    {"role": "user", "content": prompt},
]

# Apply chat template (converts messages to the model's expected format)
formatted_prompt = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=True
)
inputs = tokenizer(formatted_prompt, return_tensors="pt")

# Move inputs to the same device as model (if needed)
inputs = {k: v.to(peft_model.device) for k, v in inputs.items()}

# Generate text
with torch.no_grad():
    outputs = merged_model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=False,  # Set to True for sampling
        temperature=0.5,  # Only used when do_sample=True
        top_p=0.9,        # Only used when do_sample=True
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)

print("Generated text:")
print(generated_text)
print("===========")
print(rot13_alpha(generated_text))

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated text:


Bxnl, fb V arrq gb svther bhg gur pncvgny bs gur fgngr jurer gur svefg H.F. Frpergnel bs gur Gernfhel qvrq. Yrg zr oernx guvf qbja fgrc ol fgrc. Svefg, V fubhyq cebonoyl svaq bhg jub gur svefg Frpergnel bs gur Gernfhel jnf. V erzrzore gung gur H.F. Gernfhel vf bar bs gur rneyvrfg qrcnegzragf, fb gur svefg Frpergnel zhfg unir orra nebhaq gur gvzr bs gur sbhaqvat snguref. V guvax vg jnf Nyrknaqre Unzvygba. Lrnu, gung fbhaqf evtug. Ur jnf n xrl svther va frggvat hc gur svanapvny flfgrz bs gur H.F.

Abj, V arrq gb svaq bhg jurer Nyrknaqre Unzvygba qvrq. V'z abg rknpgyl fher, ohg V guvax ur qvrq va Arj Lbex Pvgl. V erzrzore fbzrguvat nobhg uvz orvat vaibyirq va n qhry jvgu Nevfgbphen, juvpu yrq gb uvf qrngu. Fb, vs ur qvrq va Arj Lbex Pvgl, gura gur fgngr jbhyq or Arj Lbex. Gurersber, gur pncvgny bs Arj Lbex vf Nyonal. Jnvg, ohg V fubhyq qbhoyr-purpx gung. Fbzrgvzrf pvgvrf naq fgngrf pna or pbashfvat. Arj Lbex Pvgl vf va gur fgngr bs Arj Lbex, naq gur pncvgny vf Nyonal, ev

# Save model

In [15]:
qlora_parent_dir = os.path.dirname(os.path.dirname(qlora_dir))

In [16]:
merged_model.save_pretrained(os.path.join(qlora_parent_dir, "merged"))

[2025-07-23 19:24:33,372] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for bool@CXXABI_1.3'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /us

[2025-07-23 19:24:33,757] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
