In [1]:
%run ../ipynb_util_tars.py

In [2]:
sample_sentence = "Is this about clean energy?"

In [3]:
"""load pretrained llama"""
import torch
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForSequenceClassification

model_name = f"{CHECKPOINT_PATH}/meta-llama/Meta-Llama-3-8B-ft-zo_up/checkpoint-2200/"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForSequenceClassification.from_pretrained(
    model_name,
    num_labels=17,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
model.config.pad_token_id = tokenizer.pad_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Tokenize the sentence
tokenized_output = tokenizer(sample_sentence, return_tensors="pt", add_special_tokens=True)

# Extract token IDs
token_ids = tokenized_output['input_ids'][0].tolist()

# Decode token IDs to readable string
decoded_sentence = tokenizer.decode(token_ids, skip_special_tokens=True)

print("Token IDs:", token_ids)
print("Decoded sentence:", decoded_sentence)

Token IDs: [128000, 3957, 420, 922, 4335, 4907, 30]
Decoded sentence: Is this about clean energy?


In [5]:
from lxt.utils import clean_tokens

input_ids = tokenizer(
    sample_sentence,
    return_tensors="pt",
    add_special_tokens=True
).input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)

output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
print("output shape", output_logits.shape)

max_logits, max_indices = torch.max(output_logits[0, :], dim=-1)

max_logits.backward(max_logits)
relevance = input_embeds.grad.float().sum(-1).cpu()[0]

# normalize relevance between [-1, 1] for plotting
relevance = relevance / relevance.abs().max()

# remove '_' characters from token strings
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = clean_tokens(tokens)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))

output shape torch.Size([1, 17])
Is this about clean energy?


In [6]:
print(max_indices)

tensor(14, device='cuda:0')


In [7]:
print(tokens)
relevance

['<|begin\\_of\\_text|>', 'Is', 'Ġthis', 'Ġabout', 'Ġclean', 'Ġenergy', '?']


tensor([-1.0000e+00, -6.5423e-03,  2.4174e-04,  8.8179e-05, -1.2738e-07,
        -3.1132e-07, -9.2372e-07])

# AttnLRP on LlamaForSequenceClassification

In [8]:
import torch
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForSequenceClassification, attnlrp
from lxt.utils import pdf_heatmap, clean_tokens

model_name = "meta-llama/Meta-Llama-3-8B"
model = LlamaForSequenceClassification.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    num_labels=16,
    device_map="auto",
    token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# apply AttnLRP rules
attnlrp.register(model)

sample_sentence = "Is this about clean energy?"

input_ids = tokenizer(
    sample_sentence,
    return_tensors="pt",
    add_special_tokens=True
).input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
output_logits.shape

torch.Size([1, 16])

In [10]:
max_logits, max_indices = torch.max(output_logits[0, :], dim=-1)

max_logits.backward(max_logits)
relevance = input_embeds.grad.float().sum(-1).cpu()[0]

# normalize relevance between [-1, 1] for plotting
relevance = relevance / relevance.abs().max()

# remove '_' characters from token strings
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = clean_tokens(tokens)

In [11]:
tokens

['<|begin\\_of\\_text|>', 'Is', 'Ġthis', 'Ġabout', 'Ġclean', 'Ġenergy', '?']

In [12]:
relevance

tensor([0.1717, 0.0876, 0.0414, 0.2117, 0.2028, 0.1975, 1.0000])