In [1]:
import transformers
import torch
import json

In [22]:
from LLMfunctions import inference_activations
from EnergyComputations import energy_pipeline

## Recreate GPT-2XL set-up for Llama

In [2]:
# Topic key: selects which prompt file under prompts-gen/ is loaded.
prompt_topic = 'viktor'
# Topic-derived suffix string (not used within this cell).
prompt_sufix = '_' + prompt_topic
# Read the file as UTF-8 explicitly so parsing does not depend on the platform's default encoding.
with open('prompts-gen/'+prompt_topic+'.txt', encoding='utf-8') as file:
    # strict=False allows raw control characters (e.g. literal newlines or tabs) inside JSON strings.
    # The parsed object is what later gets passed to the tokenizer's chat template.
    prompt = json.load(file, strict=False)

In [3]:
# Hugging Face Hub identifier of the checkpoint used for both tokenizer and model.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Tokenizer that pads on the left side of each sequence.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    padding_side="left",
)
# Reuse the end-of-sequence id as the padding id (original note: Llama defines no padding token).
tokenizer.pad_token_id = tokenizer.eos_token_id

# Causal LM loaded in bfloat16; device_map="auto" leaves device placement to the library.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Token ids passed to generation as stop tokens: the EOS id plus the id of "<|eot_id|>".
terminators = [tokenizer.eos_token_id]
terminators.append(tokenizer.convert_tokens_to_ids("<|eot_id|>"))

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
# Render the chat messages into one prompt string.
#   add_generation_prompt=True appends the marker after which the model should write its reply;
#   tokenize=False returns the rendered text instead of token ids.
text = tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)

# Tokenize into PyTorch tensors, padding to the longest sequence in the batch.
# add_special_tokens=False: the chat template has already written the special tokens into `text`,
# so letting the tokenizer add them again would duplicate them (e.g. a second BOS token).
inputs = tokenizer(text, padding="longest", return_tensors="pt", add_special_tokens=False)

# Move the tensors to the model's device instead of hard-coding CUDA; the model was loaded
# with device_map="auto", so its device is chosen at load time.
inputs = {key: val.to(model.device) for key, val in inputs.items()}

# Debug aid: decode the input ids back to text to inspect what the model will see.
temp_texts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True)

In [5]:
# How many continuations to sample for the prompt.
num_generations = 5

# Sample continuations with temperature and top-p filtering. Generation stops at any id in
# `terminators` or after 400 new tokens; the EOS id is also used as the padding id.
generations = model.generate(
    **inputs,
    num_return_sequences=num_generations,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    max_new_tokens=400,
    eos_token_id=terminators,
    pad_token_id=tokenizer.eos_token_id,
)

In [6]:
# Prompt decoded back to text with special tokens removed (kept for inspection).
prompt_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
# Every generated sequence decoded in full, i.e. prompt text followed by the sampled continuation.
decoded_gens = tokenizer.batch_decode(generations, skip_special_tokens=True)
# Isolate the continuations by cutting at the prompt's token length before decoding.
# Slicing the decoded string by len(prompt_text) is fragile: any decode-time clean-up that changes
# the prompt's character count would shift the cut point.
prompt_token_count = inputs["input_ids"].shape[1]
decoded_stories = tokenizer.batch_decode(generations[:, prompt_token_count:], skip_special_tokens=True)

In [24]:
# One entry per sampled sequence, in generation order.
# NOTE(review): what inference_activations / energy_pipeline return is not visible here; confirm in their modules.
energy_values = []
for i in range(num_generations):
    # Select sequence i and restore a leading batch dimension -> shape 1 x seq_length.
    tensor = generations[i].unsqueeze(0)
    activations = inference_activations(model, tensor)
    energy_values.append(energy_pipeline(activations))