In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [52]:
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-1b", use_fast=False)

In [3]:
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoderbase-1b").cuda()

Downloading (…)lve/main/config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.05k/1.05k [00:00<00:00, 10.8MB/s]
Downloading model.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.55G/4.55G [06:43<00:00, 11.3MB/s]
Downloading (…)neration_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 491kB/s]


In [422]:
PROMPT = """
-- Task
-- We are given two strings s and c, you have to deleted all the characters in s 
-- that are equal to any character in c
-- then check if the result string is palindrome.
-- A string is called palindrome if it reads the same backward as forward.
-- You should return a tuple containing the result string and True/False for the check.
--
-- Example
-- For s = "abcde", c = "ae", the result should be ('bcd',False)
-- For s = "abcdef", c = "b" the result should be ('acdef',False)
-- For s = "abcdedcba", c = "ab", the result should be ('cdedc',True)
local function reverse_delete(s, c)
"""

In [423]:
def find_end_tok_i(tokenizer, enc, stop_seqs=["\nend", "\n--"]):
    def stop_in_enc(enc):
        dec = tokenizer.decode(enc)
        for stop in stop_seqs:
            if stop in dec:
                return True

        return False
        
    i = 0
    
    while i < len(enc) - 1 and not stop_in_enc(enc[:i]):
        i += 1

    return i
    
    
toks = tokenizer.encode(PROMPT, return_tensors="pt").cuda()

In [424]:
out = model.generate(toks, do_sample=True, max_new_tokens=150, temperature=0.2, top_p=0.95)
end_tok = find_end_tok_i(tokenizer, out[0][len(toks[0])-1:]) + len(toks[0])
out = out[0][:end_tok-1]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [425]:
print(tokenizer.decode(out))


-- Task
-- We are given two strings s and c, you have to deleted all the characters in s 
-- that are equal to any character in c
-- then check if the result string is palindrome.
-- A string is called palindrome if it reads the same backward as forward.
-- You should return a tuple containing the result string and True/False for the check.
--
-- Example
-- For s = "abcde", c = "ae", the result should be ('bcd',False)
-- For s = "abcdef", c = "b" the result should be ('acdef',False)
-- For s = "abcdedcba", c = "ab", the result should be ('cdedc',True)
local function reverse_delete(s, c)
	local result = ""
	for i = 1, #s do
		if s:sub(i, i) ~= c:sub(1, 1) then
			result = result.. s:sub(i, i)
		end
	end
	return result, true
end


In [426]:
enc = model(out, output_attentions=True)
attns = enc["attentions"]
# quite a deep tensor...
layer_i = 0
batch_i = 0 # we only have one prompt
attn_head_i = 0
print(end_tok)
# attns[layer_i][batch_i][attn_head_i][end_tok]

# get last layer attns
last_layer_attns = attns[-1][batch_i]
print(last_layer_attns.size())
last_layer_attns_head_mean = last_layer_attns.mean(dim=0)

238
torch.Size([16, 237, 237])


In [427]:
len(enc[0])

237

In [428]:
ar = torch.flip(torch.arange(1, len(enc[0]) + 1), [0]).cuda()
summed = last_layer_attns_head_mean.sum(0)
mean_pooled = summed / ar
mean_pooled[-1]

tensor(0.0254, device='cuda:0', grad_fn=<SelectBackward0>)

In [429]:
len(out)

237

In [430]:
import pandas as pd
pd.DataFrame(last_layer_attns_head_mean[-1].detach().cpu().numpy()).describe()[0]

count    237.000000
mean       0.004219
std        0.038850
min        0.000117
25%        0.000318
50%        0.000708
75%        0.001851
max        0.598300
Name: 0, dtype: float64

In [431]:
import pandas as pd
from termcolor import colored

FORE="black"
def color_tok(attn, tok, distr) -> str:
    if attn < distr["25%"]:
        return colored(tok, color=FORE, on_color='on_light_grey')
    elif attn < distr["50%"]:
        return colored(tok, color=FORE, on_color='on_green')
    elif attn < distr["75%"]:
        return colored(tok, color=FORE, on_color='on_yellow')
    else:
        return colored(tok, color=FORE, on_color='on_red')
    

def visualize_attn(tokenizer, out, meaned_attns):
    distr = pd.DataFrame(meaned_attns.detach().cpu().numpy()).describe()[0]
    for i, tok in enumerate(out):
        colored = color_tok(meaned_attns[i], tokenizer.decode(tok), distr)
        print(colored, end="")

visualize_attn(tokenizer, out, mean_pooled)

[41m[30m
[0m[41m[30m--[0m[43m[30m Task[0m[41m[30m
[0m[43m[30m--[0m[41m[30m We[0m[43m[30m are[0m[43m[30m given[0m[43m[30m two[0m[41m[30m strings[0m[43m[30m s[0m[42m[30m and[0m[42m[30m c[0m[41m[30m,[0m[41m[30m you[0m[42m[30m have[0m[43m[30m to[0m[43m[30m deleted[0m[43m[30m all[0m[43m[30m the[0m[43m[30m characters[0m[42m[30m in[0m[43m[30m s[0m[42m[30m [0m[43m[30m
[0m[47m[30m--[0m[43m[30m that[0m[43m[30m are[0m[42m[30m equal[0m[47m[30m to[0m[47m[30m any[0m[47m[30m character[0m[42m[30m in[0m[42m[30m c[0m[41m[30m
[0m[41m[30m--[0m[42m[30m then[0m[42m[30m check[0m[47m[30m if[0m[42m[30m the[0m[42m[30m result[0m[42m[30m string[0m[42m[30m is[0m[41m[30m pal[0m[42m[30mindrome[0m[41m[30m.[0m[41m[30m
[0m[43m[30m--[0m[43m[30m A[0m[47m[30m string[0m[43m[30m is[0m[42m[30m called[0m[43m[30m pal[0m[42m[30mindrome[0m[42m[30m if[0m[47m[30m it

In [329]:
tokenizer.decode(out[0])

'\n'

In [332]:
print("\e[48;5;4m%03d")

\e[48;5;4m%03d


In [432]:
# full mean pool fn
def mean_pool_attn_from_toks(toks):
    assert len(toks.size()) == 1, "mean pooling batched toks is currently not supported"
    enc = model(toks, output_attentions=True)
    attns = enc["attentions"]
    # quite a deep tensor...
    layer_i = 0
    batch_i = 0 # we only have one prompt
    attn_head_i = 0
    # attns[layer_i][batch_i][attn_head_i][tok]

    # get last layer attns
    last_layer_attns = attns[-1][batch_i]
    last_layer_attns_head_mean = last_layer_attns.mean(dim=0)

    ar = torch.flip(torch.arange(1, len(enc[0]) + 1), [0]).cuda()
    summed = last_layer_attns_head_mean.sum(0)
    mean_pooled = summed / ar
    return mean_pooled

def generate_with_stop(model, tokenizer, prompt):
    toks = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    out = model.generate(toks, do_sample=True, max_new_tokens=150, temperature=0.2, top_p=0.95)
    end_tok = find_end_tok_i(tokenizer, out[0][len(toks[0])-1:]) + len(toks[0])
    out = out[0][:end_tok-1]
    return out
    
generate_with_stop(model, tokenizer, PROMPT)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


tensor([  203,   287,  4187,   203,   287,  2688,   884,  2702,  3134,  5852,
          309,   461,   281,    30,   844,  1159,   372,  8128,  1169,   322,
         7125,   328,   309,   225,   203,   287,   688,   884,  5040,   372,
         1346,  5341,   328,   281,   203,   287,  1615,  1505,   415,   322,
         1056,   802,   438, 18375, 39379,    32,   203,   287,   399,   802,
          438,  3823, 18375, 39379,   415,   561, 14822,   322,  2432, 21851,
          619,  7386,    32,   203,   287,  2448,  1395,   442,   312,  8825,
         6621,   322,  1056,   802,   461,  2933,    33,  2700,   436,   322,
         1505,    32,   203,   287,   203,   287,  5938,   203,   287,  2616,
          309,   280,   313,  8183,   286,   392,   281,   280,   313,  3633,
          392,   322,  1056,  1395,   526,  2726, 23550,   370,  2700,    27,
          203,   287,  2616,   309,   280,   313, 25870,   392,   281,   280,
          313,    84,    20,   322,  1056,  1395,   526,  2726, 