In [1]:
import os
from pathlib import Path

import torch 
import numpy as np

from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, GPTNeoForCausalLM



In [2]:
# Build one torch.cuda.device handle per CUDA device index that PyTorch reports,
# then display the list as the cell output.
available_gpus = list(map(torch.cuda.device, range(torch.cuda.device_count())))
available_gpus

[<torch.cuda.device at 0x7fd17be2c0d0>,
 <torch.cuda.device at 0x7fd179c7ee30>,
 <torch.cuda.device at 0x7fd1783fd7e0>,
 <torch.cuda.device at 0x7fd1783fdd50>,
 <torch.cuda.device at 0x7fd1783fd390>,
 <torch.cuda.device at 0x7fd1783fddb0>,
 <torch.cuda.device at 0x7fd1783fd570>,
 <torch.cuda.device at 0x7fd022ce1ff0>]

In [3]:
# Default to the CPU and switch to the first CUDA device when PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'

In [4]:
# Hugging Face access token used for the from_pretrained() calls below.
# It is read from the HF_TOKEN environment variable so the secret never lives in
# the notebook; export it before starting the kernel. When the variable is unset
# or empty, `token` is None.
# NOTE(review): a literal token was previously committed in this cell; it should
# be treated as leaked and revoked.
token = os.environ.get("HF_TOKEN") or None

In [5]:
# Checkpoint to analyse: keep exactly one assignment uncommented.
# The last path component of this name is reused as part of the output directory
# when the hidden states are saved.

# model_name_or_path = "princeton-nlp/Sheared-LLaMA-1.3B"

# model_name_or_path = "meta-llama/Llama-2-7b-hf"
# model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"

model_name_or_path = "mistralai/Mistral-7B-v0.1"
# model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
# model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"


In [6]:
# Cache directory for downloaded pre-trained models.
# Defaults to the shared location used so far; set the MODEL_CACHE_DIR environment
# variable to point the notebook at a different directory without editing it.
CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "/data/pre-trained-models-cache")

In [7]:
# Load the tokenizer and the causal LM for the selected checkpoint.
# Both calls use the same access token and the same cache directory, so tokenizer
# and weights are stored side by side under CACHE_DIR.
# NOTE(review): 'eager' attention is presumably required because a later cell asks
# for output_attentions=True — confirm before switching the implementation.
attn_implementation = 'eager'
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    token=token,
    cache_dir=CACHE_DIR,
)
lm = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    token=token,
    cache_dir=CACHE_DIR,
    attn_implementation=attn_implementation,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
# load LLM2Vec transformed model
# attn_implementation='flash_attention_2'
# lm = AutoModelForCausalLM.from_pretrained('vaibhavad/mistral-enc', torch_dtype=torch.bfloat16, cache_dir=CACHE_DIR, attn_implementation=attn_implementation)

In [9]:
# Switch the tokenizer to left padding only when flash attention 2 is the selected
# attention implementation; otherwise the tokenizer's own setting is kept.
use_flash_attention = attn_implementation == 'flash_attention_2'
if use_flash_attention:
    tokenizer.padding_side = 'left'

In [10]:
# Display the configuration object of the loaded model.
lm.config

MistralConfig {
  "_name_or_path": "mistralai/Mistral-7B-v0.1",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.39.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

In [11]:
# Put the model into evaluation mode; the call's return value (the module itself)
# is displayed as the cell output.
lm.eval()

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  

In [12]:
# output_prefix, text = 'Montreal', "Montreal is the second most populous city in Canada, the tenth most populous city in North America, and the most populous city in the province of Quebec. Founded in 1642 as Ville-Marie, or 'City of Mary', it is named after Mount Royal, the triple-peaked hill around which the early city of Ville-Marie was built. The city is centred on the Island of Montreal, which obtained its name from the same origin as the city, and a few much smaller peripheral islands, the largest of which is Île Bizard. The city is 196 km (122 mi) east of the national capital, Ottawa, and 258 km (160 mi) southwest of the provincial capital, Quebec City."
# output_prefix, text = 'Philadelphia', "Philadelphia, commonly referred to as Philly, is the most populous city in the U.S. state of Pennsylvania and the second-most populous city in the Northeast megalopolis and Mid-Atlantic regions after New York City. Philadelphia is known for its extensive contributions to United States history, especially the American Revolution, and served as the nation's capital until 1800. It maintains contemporary influence in business and industry, culture, sports, and music. Philadelphia is the nation's sixth-most populous city, with a population of 1,603,797 in the 2020 census and is the urban core of the larger Delaware Valley (or Philadelphia metropolitan area), the nation's seventh-largest and one of the world's largest metropolitan regions consisting of 6.245 million residents in the metropolitan statistical area and 7.366 million residents in its combined statistical area."
# Active example: `output_prefix` becomes a sub-directory of the saved hidden states,
# `text` is the passage that gets tokenized and fed through the model.
output_prefix, text = "Baltimore", "Baltimore is the most populous city in the U.S. state of Maryland. With a population of 585,708 at the 2020 census, it is the 30th-most populous city in the United States. Baltimore was designated an independent city by the Constitution of Maryland in 1851, and is currently the most populous independent city in the nation. As of the 2020 census, the population of the Baltimore metropolitan area was estimated to be 2,838,327, making it the 20th-largest metropolitan area in the country. When combined with the larger Washington metropolitan area, the Washington–Baltimore combined statistical area (CSA) has a 2020 U.S. census population of 9,973,383, the third-largest in the country."
# output_prefix, text = "", ""

In [13]:
# Tokenize the passage without padding and keep three views of it:
# the raw id list, the token strings, and a (1, seq_len) id tensor for the model.
ids = tokenizer.encode(text, padding="do_not_pad")
tokens = tokenizer.convert_ids_to_tokens(ids)
seq_len = len(tokens)
input_ids = torch.tensor(ids).view(1, -1)

# Report padding side, sequence length, tensor shape, ids and tokens, one per line.
print(tokenizer.padding_side, seq_len, input_ids.shape, input_ids, tokens, sep="\n")

left
194
torch.Size([1, 194])
tensor([[    1, 23349,   349,   272,  1080,  1852,  9504,  2990,   297,   272,
           500, 28723, 28735, 28723,  1665,   302, 20261, 28723,  2326,   264,
          4889,   302, 28705, 28782, 28783, 28782, 28725, 28787, 28734, 28783,
           438,   272, 28705, 28750, 28734, 28750, 28734, 21254, 28725,   378,
           349,   272, 28705, 28770, 28734,   362, 28733,  2284,  1852,  9504,
          2990,   297,   272,  2969,  3543, 28723, 23349,   403, 20444,   396,
          7126,  2990,   486,   272, 18620,   302, 20261,   297, 28705, 28740,
         28783, 28782, 28740, 28725,   304,   349,  5489,   272,  1080,  1852,
          9504,  7126,  2990,   297,   272,  5878, 28723,  1136,   302,   272,
         28705, 28750, 28734, 28750, 28734, 21254, 28725,   272,  4889,   302,
           272, 23349,  1424, 22159,  2698,   403, 11909,   298,   347, 28705,
         28750, 28725, 28783, 28770, 28783, 28725, 28770, 28750, 28787, 28725,
          2492,   378,

In [14]:
# Explicit position ids for the single input sequence: consecutive integers
# offset, offset + 1, ..., offset + seq_len - 1, shaped (1, seq_len).
offset = 0
# offset = 4096 # TODO(mm): using any offset here results in CUDA errors. Try to figure out why.
# NOTE(review): one possible cause is that the rotary-embedding cos/sin tables are
# sliced to the sequence length, so positions >= seq_len index out of range — confirm
# against the installed transformers version.
position_ids = torch.arange(offset, offset + seq_len).unsqueeze(0)
position_ids

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
         168, 169, 170, 171, 172, 173, 174, 175, 176

In [15]:
# Attention pattern for this run. A later cell builds an all-ones 4-D attention mask
# when this is "bidirectional"; the value is also the last component of the directory
# the hidden states are saved to.
attention_type = "causal"
# attention_type = "bidirectional"

In [16]:
# Build the attention mask passed to the forward pass.
# For "causal" runs the mask stays None; for "bidirectional" runs an all-ones 4-D
# mask is created, intended to let every position attend to every other position.
attention_mask = None
if attention_type == "bidirectional":
    # construct attention mask (batch_size, 1, seq_len, seq_len)
    attention_mask = torch.ones(size=(1, 1, seq_len, seq_len)).to(device)

    if model_name_or_path in ["princeton-nlp/Sheared-LLaMA-1.3B", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-7b-chat-hf"]:
        # Replace the model's private mask hook so the mask given to forward() is
        # returned unchanged. The hook is not a public API and may be called with
        # extra positional or keyword arguments depending on the transformers
        # release, so any additional arguments are accepted and ignored.
        lm.model._update_causal_mask = lambda attention_mask, *args, **kwargs: attention_mask

    if model_name_or_path == "EleutherAI/gpt-neo-1.3B":
        gpt_neo_max_length = 2048
        bi_mask = torch.ones((1, 1, gpt_neo_max_length, gpt_neo_max_length), dtype=bool)

        # overwrite causal mask at every layer
        for lidx in range(len(lm.transformer.h)):
            lm.transformer.h[lidx].attn.attention.bias = bi_mask

In [17]:
# Move the model and both input tensors onto the selected device
# (`device` is 'cuda:0' or 'cpu', chosen at the top of the notebook).
lm = lm.to(device)
input_ids, position_ids = (tensor.to(device) for tensor in (input_ids, position_ids))

In [18]:
# Sanity check: token ids and position ids should have the same shape.
print(input_ids.shape, position_ids.shape, sep="\n")
# print(attention_mask.shape)

torch.Size([1, 194])
torch.Size([1, 194])


In [19]:
# Single forward pass that also returns the per-layer attention maps and hidden states.
# The inputs double as labels so that `output.loss` is populated.
# NOTE(review): assumes the model shifts the labels internally for next-token loss — confirm.
labels = input_ids

# The notebook only inspects and saves activations, so no autograd graph is needed.
# The module is called directly (rather than via .forward) so that any registered
# hooks run as well.
with torch.no_grad():
    output = lm(
        input_ids=input_ids,
        position_ids=position_ids,
        labels=labels,
        attention_mask=attention_mask,
        output_attentions=True,
        output_hidden_states=True,
    )

----

In [20]:
# Inspect the attention of the final entry in output.attentions.
# assumes each entry is (batch, heads, seq, seq) with batch == 1, so squeeze() drops
# the batch axis and [-1] selects the last head — TODO confirm
# A = output.attentions[-1].squeeze()[-1]
A = output.attentions[-1].squeeze()[-1]
A = A.detach().cpu().float().numpy()

# Strictly-upper triangle: the weights each position puts on later positions.
print(np.triu(A, k=1)) # the future

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


In [21]:
# Display the loss returned by the forward pass (computed against labels == input_ids).
output.loss

tensor(0.7626, device='cuda:0', grad_fn=<NllLossBackward0>)

In [22]:
# Remove the size-1 axes from the logits and report the resulting shape.
logits = torch.squeeze(output.logits)
print(logits.size())

torch.Size([194, 32000])


In [23]:
# Highest-scoring vocabulary id at every sequence position.
preds = logits.argmax(dim=1)
print(preds.size())

torch.Size([194])


In [24]:
# Map the predicted ids back to token strings.
# NOTE(review): the prediction at position i appears to correspond to the input token
# at position i + 1, so this list is shifted by one relative to `tokens` — confirm.
preds_tokens = tokenizer.convert_ids_to_tokens(preds)

In [25]:
# Display the input token strings for comparison with the predictions.
tokens

['<s>',
 '▁Baltimore',
 '▁is',
 '▁the',
 '▁most',
 '▁pop',
 'ulous',
 '▁city',
 '▁in',
 '▁the',
 '▁U',
 '.',
 'S',
 '.',
 '▁state',
 '▁of',
 '▁Maryland',
 '.',
 '▁With',
 '▁a',
 '▁population',
 '▁of',
 '▁',
 '5',
 '8',
 '5',
 ',',
 '7',
 '0',
 '8',
 '▁at',
 '▁the',
 '▁',
 '2',
 '0',
 '2',
 '0',
 '▁census',
 ',',
 '▁it',
 '▁is',
 '▁the',
 '▁',
 '3',
 '0',
 'th',
 '-',
 'most',
 '▁pop',
 'ulous',
 '▁city',
 '▁in',
 '▁the',
 '▁United',
 '▁States',
 '.',
 '▁Baltimore',
 '▁was',
 '▁designated',
 '▁an',
 '▁independent',
 '▁city',
 '▁by',
 '▁the',
 '▁Constitution',
 '▁of',
 '▁Maryland',
 '▁in',
 '▁',
 '1',
 '8',
 '5',
 '1',
 ',',
 '▁and',
 '▁is',
 '▁currently',
 '▁the',
 '▁most',
 '▁pop',
 'ulous',
 '▁independent',
 '▁city',
 '▁in',
 '▁the',
 '▁nation',
 '.',
 '▁As',
 '▁of',
 '▁the',
 '▁',
 '2',
 '0',
 '2',
 '0',
 '▁census',
 ',',
 '▁the',
 '▁population',
 '▁of',
 '▁the',
 '▁Baltimore',
 '▁met',
 'ropolitan',
 '▁area',
 '▁was',
 '▁estimated',
 '▁to',
 '▁be',
 '▁',
 '2',
 ',',
 '8',
 '3',
 '8'

In [26]:
# Display the predicted token strings (argmax of the logits at each position).
preds_tokens

['▁#',
 ',',
 '▁a',
 '▁largest',
 '▁pop',
 'ulous',
 '▁city',
 '▁in',
 '▁the',
 '▁U',
 '.',
 'S',
 '.',
 '▁state',
 '▁of',
 '▁Maryland',
 ',',
 '▁It',
 '▁a',
 '▁population',
 '▁of',
 '▁',
 '6',
 '9',
 '5',
 ',',
 '7',
 '0',
 '8',
 '▁in',
 '▁the',
 '▁',
 '2',
 '0',
 '1',
 '0',
 '▁census',
 ',',
 '▁it',
 '▁is',
 '▁the',
 '▁largest',
 '3',
 '0',
 'th',
 '▁most',
 'most',
 '▁pop',
 'ulous',
 '▁city',
 '▁in',
 '▁the',
 '▁United',
 '▁States',
 '▁and',
 '▁Baltimore',
 '▁is',
 '▁established',
 '▁an',
 '▁independent',
 '▁city',
 '▁by',
 '▁the',
 '▁Constitution',
 '▁of',
 '▁Maryland',
 '▁in',
 '▁',
 '1',
 '8',
 '5',
 '1',
 ',',
 '▁and',
 '▁today',
 '▁the',
 '▁the',
 '▁largest',
 '▁pop',
 'ulous',
 '▁independent',
 '▁city',
 '▁in',
 '▁the',
 '▁United',
 '.',
 '▁As',
 '▁of',
 '▁',
 '▁',
 '2',
 '0',
 '2',
 '0',
 '▁census',
 ',',
 '▁the',
 '▁population',
 '▁of',
 '▁the',
 '▁Baltimore',
 '▁met',
 'ropolitan',
 '▁area',
 '▁was',
 '▁',
 '▁to',
 '▁be',
 '▁',
 '2',
 ',',
 '8',
 '3',
 '8',
 ',',
 '3',
 '2

In [27]:
# Number of hidden-state tensors returned by the forward pass.
len(output.hidden_states)

33

In [28]:
# Shape of the last hidden-state tensor.
output.hidden_states[-1].shape

torch.Size([1, 194, 4096])

In [29]:
# Shape of the last hidden-state tensor at index 10 after squeezing out size-1 axes.
output.hidden_states[-1].squeeze()[10].shape

torch.Size([4096])

In [30]:
# save hidden states to disk
# One file H_layer<i>.npy is written per entry of output.hidden_states, under
# /data/hidden_states_data/<output_prefix>/<model name>/<attention_type>/.

# data_path = f"/data/hidden_states_data/{model_name_or_path.split('/')[-1]}/offset{offset}/{attention_type}"
data_path = f"/data/hidden_states_data/{output_prefix}/{model_name_or_path.split('/')[-1]}/{attention_type}"

# create dir
Path(data_path).mkdir(parents=True, exist_ok=True)

for layer, hidden in enumerate(output.hidden_states):
    hidden = hidden.detach().cpu()
    # NumPy has no bfloat16 dtype, so .numpy() raises on bfloat16 tensors; upcast
    # those to float32 first (the attention cell above already casts before
    # converting). Other dtypes are saved unchanged.
    if hidden.dtype == torch.bfloat16:
        hidden = hidden.float()
    A = hidden.numpy()
    file_name = f"H_layer{layer}.npy"
    with open(os.path.join(data_path, file_name), 'wb') as f:
        np.save(f, A)