In [5]:
from lxt.models.llama import LlamaForCausalLM, attnlrp
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from lxt.models.llama import LlamaForCausalLM, attnlrp

# load model
model_id = "Local-Meta-Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_id,device_map="cpu")


# apply LXT to the model
lxt_model = attnlrp.register(model)

# (optionally enable gradient checkpointing)
lxt_model.gradient_checkpointing_enable()

In [6]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): IdentityRule(
            (module): SiLU()
          )
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNo

In [2]:
embeddings=None
def hook_fn(module, input, output):
    global embeddings
    output.requires_grad_(True)
    output.retain_grad() 
    embeddings=output

In [3]:
hook = lxt_model.model.embed_tokens.register_forward_hook(hook_fn)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = """We have an apple tree which has this year very many"""

input_ids = tokenizer(prompt, return_tensors="pt")

In [5]:
embeddings

In [6]:
output_logits = lxt_model(**input_ids)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [7]:
embeddings

tensor([[[ 0.0028,  0.0033, -0.0099,  ..., -0.0018,  0.0008,  0.0007],
         [ 0.0167, -0.0031,  0.0334,  ...,  0.0137, -0.0065, -0.0067],
         [ 0.0129,  0.0142,  0.0179,  ...,  0.0115, -0.0282, -0.0303],
         ...,
         [-0.0004, -0.0304,  0.0098,  ...,  0.0122, -0.0254,  0.0058],
         [-0.0056, -0.0073,  0.0315,  ..., -0.0026, -0.0052, -0.0220],
         [ 0.0009,  0.0189,  0.0210,  ..., -0.0032, -0.0216,  0.0013]]],
       requires_grad=True)

In [8]:


# select token to explain
select_class_logit = output_logits.logits[0, -1, :].max()

# run backward
select_class_logit.backward(select_class_logit)

# obtain relevances by summing over embedding dimension i.e. keeping sequence dimension
relevance = embeddings.grad.float().sum(-1)

In [11]:
from lxt.utils import pdf_heatmap, clean_tokens

# convert token ids to strings
tokens = tokenizer.convert_ids_to_tokens(input_ids.input_ids[0])

# removes the '_' character of tokens
tokens = clean_tokens( [token.replace('Ġ', '▁') for token in tokens])

# normalize relevance between [-1, 1] for plotting

relevance = relevance / relevance.abs().max()

# generate PDF file
pdf_heatmap(tokens, relevance[0], path='heatmap.pdf', backend='xelatex')

This is XeTeX, Version 3.141592653-2.6-0.999995 (TeX Live 2023/Debian) (preloaded format=xelatex)
 restricted \write18 enabled.
entering extended mode
(./heatmap.tex
LaTeX2e <2023-11-01> patch level 1
L3 programming layer <2024-01-22>
(/usr/share/texlive/texmf-dist/tex/latex/standalone/standalone.cls
Document Class: standalone 2022/10/10 v1.3b Class to compile TeX sub-files stan
dalone
(/usr/share/texlive/texmf-dist/tex/latex/tools/shellesc.sty)
(/usr/share/texlive/texmf-dist/tex/generic/iftex/ifluatex.sty
(/usr/share/texlive/texmf-dist/tex/generic/iftex/iftex.sty))
(/usr/share/texlive/texmf-dist/tex/latex/xkeyval/xkeyval.sty
(/usr/share/texlive/texmf-dist/tex/generic/xkeyval/xkeyval.tex
(/usr/share/texlive/texmf-dist/tex/generic/xkeyval/xkvutils.tex
(/usr/share/texlive/texmf-dist/tex/generic/xkeyval/keyval.tex))))
(/usr/share/texlive/texmf-dist/tex/latex/standalone/standalone.cfg)
(/usr/share/texlive/texmf-dist/tex/latex/base/article.cls
Document Class: article 2023/05/17 v1.4n Standa

In [10]:
relevance

tensor([[-0.5448,  1.0000,  0.2069,  0.1789, -0.3949, -0.1024,  0.0661,  0.0639,
         -0.0050,  0.0130,  0.1569, -0.1251]])

In [2]:
import torch
from transformers import AutoTokenizer
from lxt.models.llama import attnlrp, LlamaForCausalLM
from lxt.utils import pdf_heatmap, clean_tokens

model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# apply AttnLRP rules
attnlrp.register(model)

prompt = """\
Context: Mount Everest attracts many climbers, including highly experienced mountaineers. There are two main climbing routes, one approaching the summit from the southeast in Nepal (known as the standard route) and the other from the north in Tibet. While not posing substantial technical climbing challenges on the standard route, Everest presents dangers such as altitude sickness, weather, and wind, as well as hazards from avalanches and the Khumbu Icefall. As of November 2022, 310 people have died on Everest. Over 200 bodies remain on the mountain and have not been removed due to the dangerous conditions. The first recorded efforts to reach Everest's summit were made by British mountaineers. As Nepal did not allow foreigners to enter the country at the time, the British made several attempts on the north ridge route from the Tibetan side. After the first reconnaissance expedition by the British in 1921 reached 7,000 m (22,970 ft) on the North Col, the 1922 expedition pushed the north ridge route up to 8,320 m (27,300 ft), marking the first time a human had climbed above 8,000 m (26,247 ft). The 1924 expedition resulted in one of the greatest mysteries on Everest to this day: George Mallory and Andrew Irvine made a final summit attempt on 8 June but never returned, sparking debate as to whether they were the first to reach the top. Tenzing Norgay and Edmund Hillary made the first documented ascent of Everest in 1953, using the southeast ridge route. Norgay had reached 8,595 m (28,199 ft) the previous year as a member of the 1952 Swiss expedition. The Chinese mountaineering team of Wang Fuzhou, Gonpo, and Qu Yinhua made the first reported ascent of the peak from the north ridge on 25 May 1960. \
Question: How high did they climb in 1922? According to the text, the 1922 expedition reached 8,"""

input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
input_embeds = model.get_input_embeddings()(input_ids)

output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)

max_logits.backward(max_logits)
relevance = input_embeds.grad.float().sum(-1).cpu()[0]

# normalize relevance between [-1, 1] for plotting
relevance = relevance / relevance.abs().max()

# remove '_' characters from token strings
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens = clean_tokens(tokens)

pdf_heatmap(tokens, relevance, path='heatmap.pdf', backend='xelatex')

LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


This is XeTeX, Version 3.141592653-2.6-0.999995 (TeX Live 2023/Debian) (preloaded format=xelatex)
 restricted \write18 enabled.
entering extended mode
(./heatmap.tex
LaTeX2e <2023-11-01> patch level 1
L3 programming layer <2024-01-22>
(/usr/share/texlive/texmf-dist/tex/latex/standalone/standalone.cls
Document Class: standalone 2022/10/10 v1.3b Class to compile TeX sub-files stan
dalone
(/usr/share/texlive/texmf-dist/tex/latex/tools/shellesc.sty)
(/usr/share/texlive/texmf-dist/tex/generic/iftex/ifluatex.sty
(/usr/share/texlive/texmf-dist/tex/generic/iftex/iftex.sty))
(/usr/share/texlive/texmf-dist/tex/latex/xkeyval/xkeyval.sty
(/usr/share/texlive/texmf-dist/tex/generic/xkeyval/xkeyval.tex
(/usr/share/texlive/texmf-dist/tex/generic/xkeyval/xkvutils.tex
(/usr/share/texlive/texmf-dist/tex/generic/xkeyval/keyval.tex))))
(/usr/share/texlive/texmf-dist/tex/latex/standalone/standalone.cfg)
(/usr/share/texlive/texmf-dist/tex/latex/base/article.cls
Document Class: article 2023/05/17 v1.4n Standa

In [2]:
relevance

tensor([ 1.3174e-01,  5.1772e-03,  3.0630e-03,  3.5403e-03,  3.4851e-03,
         4.6443e-03,  1.7936e-03,  1.2930e-03,  1.6109e-03,  2.7891e-03,
         2.4576e-03, -1.2220e-03,  1.6182e-03,  1.8135e-03,  1.7418e-03,
         1.8228e-03,  2.4520e-03,  2.1099e-03,  1.3054e-03,  1.8292e-03,
         1.8697e-03,  1.9619e-03,  1.3322e-03,  2.8636e-03,  1.9916e-03,
         2.3866e-03,  1.3190e-03,  3.8806e-04,  2.0783e-03,  1.4046e-03,
         2.2964e-03,  1.1245e-03,  1.6831e-03,  1.2276e-03,  1.6945e-03,
         1.6947e-03,  2.0713e-03,  1.9290e-03,  1.8058e-03,  1.4155e-03,
         5.7857e-04,  1.0910e-03,  9.0939e-04,  1.1654e-03,  1.0153e-03,
         1.3953e-03,  1.2950e-03,  1.0104e-03,  4.3651e-04,  1.0666e-03,
         9.2651e-04,  1.2155e-03,  1.7105e-03,  1.7831e-03,  1.5012e-03,
         2.8594e-03,  8.0783e-04,  1.1731e-03,  1.0717e-03,  1.2485e-03,
         1.2572e-03,  1.7750e-03,  2.7867e-03,  1.9827e-03,  1.4460e-03,
         1.2466e-03,  1.3778e-03,  9.9238e-04,  1.1