In [1]:
# Execute the shared helper script inside this kernel's namespace.
# NOTE(review): later cells use `torch`, `CHECKPOINT_PATH` and `HF_TOKEN` without
# defining them in any visible cell; presumably this script provides them — confirm.
%run ../ipynb_util_tars.py

In [2]:
# Print tensors with 6 decimal places and without scientific notation so the
# per-token relevance scores printed below are easy to compare.
# NOTE(review): `torch` is used here before the `import torch` in the model-loading
# cell; presumably the %run'd helper script brings it into scope — confirm.
torch.set_printoptions(precision=6, sci_mode=False)

In [3]:
# Input text that is tokenized, classified and explained in the cells below.
sample_sentence = "Is this about clean energy?"
# Longer alternative input (disabled); uncomment to explain a full paragraph instead.
#sample_sentence = "In terms of regional shares in the OECD area, OECD Europe’s share of consumption is slightly higher than the region’s share of extraction, while the inverse if tme for the OECD America region. The OECD Asia-Oceania region’s share of consumption is the same as its share of extraction. Average income plays a particularly important role. Most of these countries experienced a strong upswing in material extraction starting the early 2000s, although China’s surge began much earlier. By the early 1990s China had overtaken the United States as the world’s largest extractor of material resources."

In [4]:
"""load pretrained llama"""
import torch
from transformers import AutoTokenizer
from lxt.models.llama import LlamaForSequenceClassification, attnlrp

# Alternative checkpoint (disabled):
#model_name = f"{CHECKPOINT_PATH}/meta-llama/Meta-Llama-3-8B-ft-zo_up/checkpoint-2200/"
# Checkpoint directory to load. CHECKPOINT_PATH is not defined in this cell; it must
# already exist in the kernel namespace.
model_name = f"{CHECKPOINT_PATH}/TinyLlama/TinyLlama_v1.1/checkpoint-100/"
# Tokenizer and model are loaded from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the LlamaForSequenceClassification class imported from lxt.models.llama
# (not the transformers one) so that the AttnLRP rules below can be registered on it.
model = LlamaForSequenceClassification.from_pretrained(
    model_name,
    num_labels=17,  # number of output classes of the classification head
    device_map="auto",
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
    token=HF_TOKEN  # access token; not defined in this cell
)
# Copy the tokenizer's padding token id into the model config.
model.config.pad_token_id = tokenizer.pad_token_id

# apply AttnLRP rules
attnlrp.register(model)

LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048, padding_idx=2)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): LinearEpsilon(in_features=2048, out_features=2048, bias=False)
          (k_proj): LinearEpsilon(in_features=2048, out_features=256, bias=False)
          (v_proj): LinearEpsilon(in_features=2048, out_features=256, bias=False)
          (o_proj): LinearEpsilon(in_features=2048, out_features=2048, bias=False)
          (softmax): SoftmaxDT(dim=-1)
          (attn_value_matmul): UniformEpsilonRule(
            (module): AttentionValueMatmul()
          )
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): LinearEpsilon(in_features=2048, out_features=5632, bias=False)
          (up_proj): LinearEpsilon(in_features=2048, out_features=5632, bias=False)
          (down_proj): LinearEpsilon(in_features=5632, ou

In [5]:
# Sanity check of the tokenizer: encode the sample sentence (special tokens
# included) and return PyTorch tensors.
tokenized_output = tokenizer(
    sample_sentence,
    add_special_tokens=True,
    return_tensors="pt",
)

# Only one sentence was encoded, so take row 0 and convert it to a plain list of ints.
token_ids = tokenized_output["input_ids"][0].tolist()

# Round-trip the ids back to text, dropping special tokens, to confirm the
# encoding reproduces the original sentence.
decoded_sentence = tokenizer.decode(token_ids, skip_special_tokens=True)

print("Token IDs:", token_ids)
print("Decoded sentence:", decoded_sentence)

Token IDs: [1, 1317, 445, 1048, 5941, 5864, 29973]
Decoded sentence: Is this about clean energy?


In [6]:
from lxt.utils import clean_tokens
import lxt.functional as lf

# Encode the sample sentence and move the ids to the device the model lives on.
input_ids = tokenizer(
    sample_sentence,
    return_tensors="pt",
    add_special_tokens=True,
).input_ids.to(model.device)
# Look the embeddings up explicitly and feed them via `inputs_embeds`, so that the
# backward pass leaves a gradient on the input embeddings that can be read below.
input_embeds = model.get_input_embeddings()(input_ids)

output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
print("output shape", output_logits.shape)

# NOTE: despite its name, `max_logits` holds the largest *softmax probability*
# (computed with lxt.functional.softmax), not a raw logit; `max_indices` is the
# index of the predicted class.
max_logits, max_indices = torch.max(lf.softmax(output_logits[0, :], -1), dim=-1)

print(input_embeds.shape)
# Backward pass seeded with the explained value itself; with the AttnLRP rules
# registered on the model, the resulting input gradient is used as relevance.
max_logits.backward(max_logits)
print(input_embeds.grad.shape)
# Sum over the hidden dimension to get one relevance score per token.
relevance = input_embeds.grad.float().sum(-1).cpu()[0]

# normalize relevance between [-1, 1] for plotting
# The denominator is clamped to the smallest positive normal float so that an
# all-zero relevance vector stays all-zero instead of turning into NaN (0 / 0).
relevance = relevance / relevance.abs().max().clamp_min(torch.finfo(relevance.dtype).tiny)

# convert ids back to token strings and clean them with lxt's helper for display
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = clean_tokens(tokens)

print(max_logits, max_indices.item())
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
print(tokens)
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
print(relevance)

output shape torch.Size([1, 17])
torch.Size([1, 7, 2048])
torch.Size([1, 7, 2048])
tensor(0.984375, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward0>) 14
Is this about clean energy?
['<s>', '▁Is', '▁this', '▁about', '▁clean', '▁energy', '?']
Is this about clean energy?
tensor([ 0.135611, -0.075381,  0.294390,  0.168418,  0.485919,  1.000000,
         0.426518])


output shape torch.Size([1, 17])
torch.Size([1, 120, 4096])
torch.Size([1, 120, 4096])
tensor(1., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward0>) 3
In terms of regional shares in the OECD area, OECD Europe’s share of consumption is slightly higher than the region’s share of extraction, while the inverse if tme for the OECD America region. The OECD Asia-Oceania region’s share of consumption is the same as its share of extraction. Average income plays a particularly important role. Most of these countries experienced a strong upswing in material extraction starting the early 2000s, although China’s surge began much earlier. By the early 1990s China had overtaken the United States as the world’s largest extractor of material resources.
['<|begin\\_of\\_text|>', 'In', 'Ġterms', 'Ġof', 'Ġregional', 'Ġshares', 'Ġin', 'Ġthe', 'ĠOECD', 'Ġarea', ',', 'ĠOECD', 'ĠEurope', 'âĢĻs', 'Ġshare', 'Ġof', 'Ġconsumption', 'Ġis', 'Ġslightly', 'Ġhigher', 'Ġthan', 'Ġthe', 'Ġregion', 'âĢĻs', 'Ġshare', 'Ġof', 'Ġextraction', ',', 'Ġwhile', 'Ġthe', 'Ġinverse', 'Ġif', 'Ġt', 'me', 'Ġfor', 'Ġthe', 'ĠOECD', 'ĠAmerica', 'Ġregion', '.', 'ĠThe', 'ĠOECD', 'ĠAsia', '-O', 'ce', 'ania', 'Ġregion', 'âĢĻs', 'Ġshare', 'Ġof', 'Ġconsumption', 'Ġis', 'Ġthe', 'Ġsame', 'Ġas', 'Ġits', 'Ġshare', 'Ġof', 'Ġextraction', '.', 'ĠAverage', 'Ġincome', 'Ġplays', 'Ġa', 'Ġparticularly', 'Ġimportant', 'Ġrole', '.', 'ĠMost', 'Ġof', 'Ġthese', 'Ġcountries', 'Ġexperienced', 'Ġa', 'Ġstrong', 'Ġup', 'swing', 'Ġin', 'Ġmaterial', 'Ġextraction', 'Ġstarting', 'Ġthe', 'Ġearly', 'Ġ', '200', '0', 's', ',', 'Ġalthough', 'ĠChina', 'âĢĻs', 'Ġsurge', 'Ġbegan', 'Ġmuch', 'Ġearlier', '.', 'ĠBy', 'Ġthe', 'Ġearly', 'Ġ', '199', '0', 's', 'ĠChina', 'Ġhad', 'Ġovert', 'aken', 'Ġthe', 'ĠUnited', 'ĠStates', 'Ġas', 'Ġthe', 'Ġworld', 'âĢĻs', 'Ġlargest', 'Ġextractor', 'Ġof', 'Ġmaterial', 'Ġresources', '.']
In terms of regional shares in the OECD area, OECD Europe’s share of consumption is slightly higher than the region’s share of extraction, while the inverse if tme for the OECD America region. The OECD Asia-Oceania region’s share of consumption is the same as its share of extraction. Average income plays a particularly important role. Most of these countries experienced a strong upswing in material extraction starting the early 2000s, although China’s surge began much earlier. By the early 1990s China had overtaken the United States as the world’s largest extractor of material resources.
tensor([    -0.461020,      0.336821,     -1.000000,      0.403909,
            -0.007544,     -0.067217,     -0.251409,      0.239639,
             0.071959,     -0.001595,     -0.142529,      0.017504,
            -0.304663,      0.253049,     -0.001901,      0.069398,
             0.019843,     -0.104641,      0.175546,      0.983041,
            -0.363505,     -0.021442,     -0.006289,     -0.031607,
             0.102681,     -0.002945,     -0.018754,      0.015238,
             0.003156,     -0.092878,     -0.015317,     -0.051677,
            -0.043322,      0.084800,      0.076746,      0.053048,
            -0.024949,     -0.010645,     -0.003498,      0.009203,
            -0.017288,      0.021517,      0.013976,      0.061074,
            -0.008544,     -0.006896,     -0.001114,      0.002361,
             0.001473,      0.002007,      0.006377,      0.002052,
             0.000777,      0.002239,     -0.009535,      0.018566,
             0.002728,      0.042582,     -0.000808,     -0.012743,
            -0.436750,     -0.003314,      0.011133,     -0.000627,
             0.000611,      0.001883,      0.000528,      0.001089,
             0.000024,      0.000144,     -0.000012,     -0.000109,
             0.000285,      0.000057,     -0.000205,      0.000113,
            -0.000153,     -0.000203,      0.000350,     -0.000049,
            -0.000009,     -0.000140,     -0.000142,      0.000014,
             0.001060,     -0.000014,      0.000004,     -0.000011,
            -0.000005,     -0.000003,     -0.000001,      0.000014,
             0.000054,      0.000171,      0.000005,     -0.000007,
            -0.000050,     -0.000000,      0.000000,      0.000000,
             0.000000,      0.000000,      0.000000,     -0.000000,
             0.000000,     -0.000000,     -0.000000,      0.000000,
             0.000000,      0.000000,      0.000000,      0.000000,
            -0.000000,     -0.000000,     -0.000000,      0.000000,
            -0.000000,      0.000000,      0.000000,     -0.000000])

In [7]:
"""Requires `pip install -e ./lxt`"""
import torch
from transformers import AutoTokenizer
from lxt.models.bert import attnlrp, BertForSequenceClassification
from lxt.utils import pdf_heatmap, clean_tokens

def clean_wordpiece_split(tokens):
    """Convert BERT WordPiece tokens to the "▁" word-start convention.

    WordPiece marks *continuation* pieces with a leading "##", whereas the
    downstream cleaning/plotting code expects a leading "▁" on tokens that
    *start* a word. Tokens that begin a word therefore get a "▁" prefix, and
    continuation pieces only have their "##" marker stripped, so that the
    pieces of a split word stay attached to each other.

    Args:
        tokens: iterable of WordPiece token strings, e.g. the result of
            ``tokenizer.convert_ids_to_tokens``.

    Returns:
        list[str]: one converted token per input token, in the same order.
    """
    cleaned = []
    for word in tokens:
        if word.startswith("##"):
            # Continuation of the previous token: drop the marker, add no word start.
            cleaned.append(word[2:])
        else:
            cleaned.append("▁" + word)
    return cleaned

def seq_cls(sentence="I are a student.", model_name="textattack/bert-base-uncased-CoLA", device=None):
    """AttnLRP for BERT sequence classification task.

    Loads the tokenizer and classifier named by ``model_name``, registers the
    AttnLRP rules on the model, runs one forward/backward pass on ``sentence``
    and prints the logits shape, the predicted label, the cleaned tokens and
    the per-token relevance normalized to [-1, 1].

    Args:
        sentence: the single input sentence to classify and explain.
        model_name: checkpoint identifier passed to ``from_pretrained`` for
            both the tokenizer and the model.
        device: ``torch.device`` or device string to run on. ``None`` (the
            default) selects CUDA when it is available and the CPU otherwise,
            so the demo no longer crashes on hosts without a GPU.

    Returns:
        None. All results are printed.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name).to(device)
    model.eval()

    # apply AttnLRP rules
    attnlrp.register(model)

    input_ids = tokenizer(sentence, return_tensors="pt").input_ids.to(device)
    # Feed embeddings (not ids) so the backward pass leaves a gradient on them.
    inputs_embeds = model.bert.get_input_embeddings()(input_ids)

    logits = model(inputs_embeds=inputs_embeds.requires_grad_()).logits
    print(logits.shape)

    # We explain the sequence label: acceptable or unacceptable
    max_logits, max_indices = torch.max(logits, dim=-1)

    # `.item()` requires exactly one prediction, i.e. a single input sentence.
    out = model.config.id2label[max_indices.item()]
    print("The label of the sequence is: ", out)

    # Backward pass seeded with the explained logit itself.
    max_logits.backward(max_logits)

    # Sum over the hidden dimension to get one relevance score per token.
    relevance = inputs_embeds.grad.float().sum(-1).cpu()[0]
    # normalize relevance between [-1, 1] for plotting
    # The denominator is clamped to the smallest positive normal float so that an
    # all-zero relevance vector stays all-zero instead of turning into NaN (0 / 0).
    relevance = relevance / relevance.abs().max().clamp_min(torch.finfo(relevance.dtype).tiny)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    tokens = clean_tokens(clean_wordpiece_split(tokens))

    #pdf_heatmap(tokens, relevance, path="./heatmap_seq_cls.pdf", backend="xelatex")
    print(tokens)
    print(relevance)

# Run the BERT AttnLRP demo; the function prints its results and returns nothing.
seq_cls()

torch.Size([1, 2])
The label of the sequence is:  LABEL_0
['▁[CLS]', '▁i', '▁are', '▁a', '▁student', '▁.', '▁[SEP]']
tensor([-0.164881, -0.258037,  0.871932, -0.442716, -0.275364, -0.070098,
        -1.000000])
