In [1]:
import os
from pathlib import Path

import torch 
import numpy as np

from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM



In [2]:
# Build one torch.cuda.device handle per GPU that torch reports, then display the list.
gpu_indices = range(torch.cuda.device_count())
available_gpus = list(map(torch.cuda.device, gpu_indices))
available_gpus

[<torch.cuda.device at 0x7fc20c38f2b0>,
 <torch.cuda.device at 0x7fc20c331de0>,
 <torch.cuda.device at 0x7fc20c331cf0>,
 <torch.cuda.device at 0x7fc20c331ba0>,
 <torch.cuda.device at 0x7fc20c330130>,
 <torch.cuda.device at 0x7fc20c330820>,
 <torch.cuda.device at 0x7fc20c3303a0>,
 <torch.cuda.device at 0x7fc0bbcba830>]

In [3]:
# Run on the highest-indexed visible GPU, falling back to the CPU when CUDA is unavailable.
# The index is derived from the device count instead of being hard-coded to 7, so the
# notebook still resolves to cuda:7 on an 8-GPU host but also works on smaller machines.
device = f"cuda:{torch.cuda.device_count() - 1}" if torch.cuda.is_available() else "cpu"

In [4]:
# Hugging Face access token, read from the HF_TOKEN environment variable so that no
# credential is stored in the notebook. The value is None when the variable is unset.
# NOTE(review): a literal token was previously committed in this cell; treat it as leaked
# and revoke it.
token = os.environ.get("HF_TOKEN")

In [5]:
# Model identifier used by the cells below. Exactly one assignment should be active;
# the commented-out lines are the alternatives that were tried.
#
# GPT-Neo:
# model_name_or_path = "EleutherAI/gpt-neo-1.3B"
#
# LLaMA family:
# model_name_or_path = "princeton-nlp/Sheared-LLaMA-1.3B"
# model_name_or_path = "meta-llama/Llama-2-7b-hf"
# model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
#
# Mistral family:
# model_name_or_path = "mistralai/Mistral-7B-v0.1"
# model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"


In [6]:
# Directory passed as `cache_dir` when loading pre-trained models, used instead of the
# library's default cache location.
CACHE_DIR = "/data/pre-trained-models-cache"

In [7]:
# Load the tokenizer for the currently selected `model_name_or_path`.
# `cache_dir=CACHE_DIR` keeps the tokenizer files in the same cache directory that the
# model-loading cell uses, instead of the library's default cache location.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    token=token,
    cache_dir=CACHE_DIR,
)



In [8]:
# Attention backend name handed to `from_pretrained` when the model is loaded; a later
# cell also checks it against 'flash_attention_2' to decide the tokenizer padding side.
attn_implementation = 'eager'

In [9]:
# lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, token=token, cache_dir=CACHE_DIR, attn_implementation=attn_implementation)

In [10]:
# Load the LLM2Vec-transformed checkpoint in bfloat16.
# Alternatives that were tried: checkpoint 'vaibhavad/llama-enc',
# attention backend attn_implementation='flash_attention_2'.
# Note: this rebinds `model_name_or_path`, so every later cell sees the checkpoint name
# below rather than the model name the tokenizer was loaded with.
model_name_or_path = 'vaibhavad/mistral-enc'
lm = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
# Switch the tokenizer to left padding only when the flash-attention-2 backend is
# selected; for any other backend the tokenizer's padding side is left untouched.
uses_flash_attention = attn_implementation == 'flash_attention_2'
if uses_flash_attention:
    tokenizer.padding_side = 'left'

In [12]:
# Display the configuration object attached to the loaded model.
lm.config

MistralConfig {
  "_name_or_path": "vaibhavad/mistral-enc",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

In [13]:
# Put the model into evaluation mode; the cell displays the value returned by eval(),
# which is rendered as the module structure.
lm.eval()

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  

In [14]:
# Sample passage fed to the model for the attention analysis.
text = 'Montreal is the second most populous city in Canada, the tenth most populous city in North America, and the most populous city in the province of Quebec. Founded in 1642 as Ville-Marie, or "City of Mary",[15] it is named after Mount Royal,[16] the triple-peaked hill around which the early city of Ville-Marie was built.[17] The city is centred on the Island of Montreal, which obtained its name from the same origin as the city,[18][19] and a few much smaller peripheral islands, the largest of which is Île Bizard. The city is 196 km (122 mi) east of the national capital, Ottawa, and 258 km (160 mi) southwest of the provincial capital, Quebec City.'
print(tokenizer.padding_side)

# Encode the passage without padding and keep the token strings alongside the ids.
ids = tokenizer.encode(text, padding="do_not_pad")
tokens = tokenizer.convert_ids_to_tokens(ids)
seq_len = len(tokens)

# Add a leading batch dimension of size 1: shape (1, seq_len).
input_ids = torch.tensor(ids).unsqueeze(0)

# One item per line: sequence length, tensor shape, the ids and the token strings.
print(seq_len, input_ids.shape, input_ids, tokens, sep="\n")

left
188
torch.Size([1, 188])
tensor([[    1, 25645,   349,   272,  1676,  1080,  1852,  9504,  2990,   297,
          6082, 28725,   272,   261,  8016,  1080,  1852,  9504,  2990,   297,
          3964,  4352, 28725,   304,   272,  1080,  1852,  9504,  2990,   297,
           272, 14707,   302, 27798, 28723,  5196,   286,   297, 28705, 28740,
         28784, 28781, 28750,   390,   550,  2457, 28733,  7308,   412, 28725,
           442,   345, 22013,   302,  5480,   548, 28792, 28740, 28782, 28793,
           378,   349,  5160,  1024,  7612,  8413, 28725, 28792, 28740, 28784,
         28793,   272, 22212, 28733,   386,  6343, 12254,  1401,   690,   272,
          2935,  2990,   302,   550,  2457, 28733,  7308,   412,   403,  4429,
         20011, 28740, 28787, 28793,   415,  2990,   349,  1595,   893,   356,
           272,  7633,   302, 25645, 28725,   690,  7365,   871,  1141,   477,
           272,  1348,  5016,   390,   272,  2990, 28725, 28792, 28740, 28783,
          3328, 28740,

In [15]:
# Position ids run from `offset` to `offset + seq_len - 1`, shaped (1, seq_len).
# TODO(mm): using any non-zero offset here (e.g. offset = 4096) results in CUDA errors.
# Try to figure out why.
# NOTE(review): possibly the rotary-embedding tables are only materialised for the first
# seq_len positions in this transformers version, which would make larger position
# indices out of range — unverified, confirm against the installed model code.
offset = 0
position_ids = torch.arange(offset, offset + seq_len).unsqueeze(0)
position_ids

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
         168, 169, 170, 171, 172, 173, 174, 175, 176

In [16]:
# Attention pattern to analyse. Supported values in this notebook: "causal", "bidirectional".
attention_type = "bidirectional"

In [17]:
# Prepare the attention mask and, where needed, patch the model for the selected pattern.
# For any attention_type other than "bidirectional" the mask stays None and the model is
# left unmodified.
# NOTE: the patches below modify `lm` in place and are never reverted by this cell, so
# switching attention_type back afterwards requires reloading the model.
attention_mask = None
if attention_type == "bidirectional":
    # All-ones mask of shape (batch_size, 1, seq_len, seq_len), placed on the target device.
    mask_shape = (1, 1, seq_len, seq_len)
    attention_mask = torch.ones(mask_shape).to(device)

    # Checkpoints for which the model's `_update_causal_mask` hook is replaced.
    llama_style_models = (
        "princeton-nlp/Sheared-LLaMA-1.3B",
        "meta-llama/Llama-2-7b-hf",
        "meta-llama/Llama-2-7b-chat-hf",
        'vaibhavad/llama-enc',
    )
    if model_name_or_path in llama_style_models:
        # Return the mask that was passed in, unchanged.
        lm.model._update_causal_mask = lambda attention_mask, _: attention_mask

    if model_name_or_path == "EleutherAI/gpt-neo-1.3B":
        gpt_neo_max_length = 2048
        bi_mask = torch.ones((1, 1, gpt_neo_max_length, gpt_neo_max_length), dtype=bool)

        # Overwrite the `bias` attribute of the attention module in every layer with
        # the all-True mask.
        for block in lm.transformer.h:
            block.attn.attention.bias = bi_mask

In [18]:
# Move the model and both input tensors onto the selected device.
lm.to(device)
input_ids, position_ids = input_ids.to(device), position_ids.to(device)

In [19]:
# Sanity check: print the shapes of the two model inputs, one per line.
print(input_ids.shape, position_ids.shape, sep="\n")
# print(attention_mask.shape)

torch.Size([1, 188])
torch.Size([1, 188])


In [20]:
# Single forward pass that also requests per-layer attention maps and hidden states.
# The input ids are passed as labels as well.
labels = input_ids

# Nothing downstream uses gradients, so run without autograd to avoid keeping the graph
# (and all intermediate activations) alive on the GPU. The module is called directly
# rather than via `.forward(...)` so that any registered hooks are executed.
with torch.no_grad():
    output = lm(
        input_ids=input_ids,
        position_ids=position_ids,
        labels=labels,
        attention_mask=attention_mask,
        output_attentions=True,
        output_hidden_states=True,
    )

In [21]:
# output

----

In [22]:
# Inspect the attention tensor of the last layer: squeeze singleton dimensions, take the
# last slice along the first remaining axis and convert it to a float32 NumPy array.
# NOTE(review): presumably that slice is the last attention head, assuming attentions are
# laid out as (batch, heads, seq, seq) — confirm.
last_layer_attention = output.attentions[-1].squeeze()
A = last_layer_attention[-1].detach().cpu().float().numpy()

# Strictly-upper triangle of the matrix, i.e. the weights placed on "the future".
future_attention = np.triu(A, k=1)
print(future_attention)

[[0.0000000e+00 6.2500000e-02 5.0354004e-03 ... 2.1057129e-03
  8.7356567e-04 8.9111328e-03]
 [0.0000000e+00 0.0000000e+00 3.3007812e-01 ... 7.7819824e-04
  7.5683594e-03 3.9367676e-03]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 1.1520386e-03
  1.2874603e-04 1.3198853e-03]
 ...
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  4.5312500e-01 4.6081543e-03]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 3.6621094e-02]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]]


In [23]:
# model_name_or_path = "vaibhavad/mistral-enc"

In [24]:
# Write every layer's attention tensor to disk as
# <data_path>/A_layer<layer>.npy, where the directory is derived from the last path
# component of the model name and the selected attention type.
data_path = f"/data/attention_data/{model_name_or_path.split('/')[-1]}/{attention_type}"

# Create the output directory (and any missing parents); an existing one is fine.
os.makedirs(data_path, exist_ok=True)

for layer, layer_attention in enumerate(output.attentions):
    # Drop singleton dimensions and convert to a float32 NumPy array on the CPU.
    A = layer_attention.squeeze().detach().cpu().float().numpy()
    file_name = f"A_layer{layer}.npy"
    np.save(os.path.join(data_path, file_name), A)