In [1]:
import os
os.environ["HF_HOME"] = "/workspace/.cache/huggingface"

In [None]:
import pandas as pd
import numpy as np
import torch

# model_path = "meta-llama/Llama-2-70b-chat-hf"
model_path = "meta-llama/Llama-2-7b-chat-hf"
quantize = False

from transformers import AutoTokenizer,AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig

device = "cuda"

quantization_config = None
if quantize:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

model = AutoModelForCausalLM.from_pretrained(
    model_path,  # or your specific Llama 70B variant
    quantization_config=quantization_config,
    device_map=device, 
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

tokenizer = AutoTokenizer.from_pretrained(model_path, device_map=device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.14s/it]


In [21]:
model.generation_config.do_sample = False

In [22]:
model._get_generation_mode(model.generation_config, assistant_model=None)

<GenerationMode.GREEDY_SEARCH: 'greedy_search'>

In [29]:
from selfie.interpret import InterpretationPrompt, interpret
interpretation_prompt = InterpretationPrompt(tokenizer, ("[INST]", 0, 0, 0, 0, 0, "[/INST] The Locations mentioned were:\n\n"))

In [30]:
num_layers = model.config.num_hidden_layers
print(f"Number of layers in model: {num_layers}")


Number of layers in model: 32


In [31]:
original_prompt = "[INST] What's the capital of the state Dallas is in? [/INST]"

In [32]:
tokens = tokenizer.tokenize(original_prompt)
for i, token in enumerate(tokens):
    print(i, token)

0 ▁[
1 INST
2 ]
3 ▁What
4 '
5 s
6 ▁the
7 ▁capital
8 ▁of
9 ▁the
10 ▁state
11 ▁Dallas
12 ▁is
13 ▁in
14 ?
15 ▁[
16 /
17 INST
18 ]


In [39]:
tokens_to_interpret = [(j, i) for i in range(8, 20) for j in (4, 8, 16, 24, 28)]
bs = 64
max_new_tokens = 20
k = 3

interpretation_df = interpret(
    original_prompt=original_prompt, 
    tokens_to_interpret=tokens_to_interpret, model=model, 
    interpretation_prompt=interpretation_prompt, bs=bs, 
    max_new_tokens=max_new_tokens, 
    k=k, 
    tokenizer=tokenizer
)

Interpreting '[INST] What's the capital of the state Dallas is in? [/INST]' with '[INST]_ _ _ _ _ [/INST] The Locations mentioned were:

'


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:01<00:00,  1.46s/it]


In [40]:
pd.DataFrame(interpretation_df)

Unnamed: 0,prompt,interpretation,layer,token,token_decoded,relevancy_score
0,[INST] What's the capital of the state Dallas ...,1. Canberra - This is the capital city of Aust...,4,8,capital,"[0.0498, 0.00293, 0.3752, 0.2617, 0.0, 0.1707,..."
1,[INST] What's the capital of the state Dallas ...,1. Tokyo - Japan\n2. Beijing - China\n3. Mosco...,8,8,capital,"[0.03662, 0.0004883, 0.2246, 0.6113, 0.327, 0...."
2,[INST] What's the capital of the state Dallas ...,1. Tokyo\n2. Beijing\n3. Moscow\n4. London\n5.,16,8,capital,"[0.4197, 0.001953, 0.2192, 0.1929, 0.01367, 0...."
3,[INST] What's the capital of the state Dallas ...,"1. Tokyo, Japan - Tokyo is the capital and lar...",24,8,capital,"[0.1748, 0.003906, 0.685, 0.0801, 0.0, 0.3462,..."
4,[INST] What's the capital of the state Dallas ...,"1. Tokyo, Japan - Tokyo is the capital and lar...",28,8,capital,"[0.3154, 0.000977, 0.775, 0.2319, 0.0004883, 0..."
5,[INST] What's the capital of the state Dallas ...,1. Brazil - Brazil's capital is Brasília.\n2. ...,4,9,of,"[0.03125, 0.002441, 0.05847, 0.5156, 0.4663, 0..."
6,[INST] What's the capital of the state Dallas ...,"1. Brazil - Brazil is a country, not a city, s...",8,9,of,"[0.1406, 0.0, 0.4045, 0.6855, 0.586, 0.08154, ..."
7,[INST] What's the capital of the state Dallas ...,1. Brazil\n2. Russia\n3. China\n4. Japan\n\nTh...,16,9,of,"[0.1875, 0.0, 0.267, 0.0537, 0.007812, 0.0, 0...."
8,[INST] What's the capital of the state Dallas ...,1. Brazil\n2. Canada\n3. China\n4. France\n5. ...,24,9,of,"[0.1938, 0.0004883, 0.197, 0.05566, 0.009766, ..."
9,[INST] What's the capital of the state Dallas ...,1. The United States\n2. The United Kingdom\n3...,28,9,of,"[0.1543, 0.0004883, 0.06076, 0.1417, 0.04688, ..."
