In [1]:
import torch
import transformers
import datasets
import huggingface_hub

In [2]:
# cache_dir = '/Users/christopher/Documents/unirepsCache'
cache_dir = '/net/scratch2/chriswolfram/hf_cache'

In [3]:
huggingface_hub.login(new_session=False)

In [4]:
# model_name = 'meta-llama/Llama-3.2-1B'
model_name = 'google/gemma-2-27b'

In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, device_map='auto', cache_dir=cache_dir)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto', device_map='auto', cache_dir=cache_dir)

Loading checkpoint shards:   0%|          | 0/24 [00:00<?, ?it/s]

In [8]:
# Add padding token if needed
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [9]:
# TODO: Use split slicing and maybe shuffling+streaming to extract samples from large datasets
dataset = datasets.load_dataset('stanfordnlp/imdb', cache_dir=cache_dir)
dataset = dataset['test'].take(1024)

In [10]:
model.get_output_embeddings().weight.detach()

tensor([[ 0.0055, -0.0071,  0.0009,  ...,  0.0044,  0.0036,  0.0037],
        [-0.0318,  0.0242,  0.0067,  ...,  0.0007,  0.0008, -0.0130],
        [ 0.0048, -0.0054, -0.0048,  ...,  0.0069, -0.0050,  0.0040],
        ...,
        [-0.0077,  0.0068, -0.0240,  ...,  0.0234,  0.0228, -0.0018],
        [ 0.0109,  0.0115, -0.0016,  ..., -0.0034, -0.0011,  0.0138],
        [ 0.0058, -0.0071,  0.0007,  ...,  0.0043,  0.0035,  0.0038]],
       device='cuda:0')

In [None]:
def compute_embeddings(examples):
    tokens = tokenizer(examples['text'], padding='longest', return_tensors='pt')
    input_ids = tokens['input_ids'].to(model.device)
    
    with torch.no_grad():
        model_out = model(input_ids=input_ids, output_hidden_states=True, use_cache=False)
    
    layer_token_embeddings = torch.stack(model_out.hidden_states)[:,0].float()

    layer_last_embeddings = layer_token_embeddings[:,-1].cpu()
    layer_mean_embeddings = layer_token_embeddings.mean(1).cpu()

    return {'layer_last_embeddings': layer_last_embeddings, 'layer_mean_embeddings': layer_mean_embeddings}

# TODO: This currently sets new_fingerprint because otherwise `map` appears to hash compute_embeddings which includes the entire model!
embeddings = dataset.take(128).map(compute_embeddings, new_fingerprint='test_fingerprint', load_from_cache_file=False)
embeddings.set_format('torch')
# torch.cuda.empty_cache()



Map:   0%|          | 0/128 [00:00<?, ? examples/s]

In [37]:
embeddings[-1]['layer_last_embeddings']

tensor([[ 4.3045e-01,  5.2381e-01,  8.2463e-01,  ..., -1.1306e-01,
          3.3893e-02, -2.3120e-02],
        [-1.0142e-01,  2.8770e-01,  3.6419e-01,  ...,  6.0856e-02,
         -2.9304e-02, -8.4377e-02],
        [-3.7873e-02, -3.0839e-02,  3.2535e-01,  ...,  2.4453e-01,
         -9.8545e-02,  5.0027e-02],
        ...,
        [-5.3461e+01, -3.4155e+00, -2.3939e+02,  ...,  2.5938e+02,
          1.3480e+02, -1.9434e+02],
        [-3.0611e+02, -3.3730e+01, -4.7096e+02,  ...,  1.6534e+02,
          1.7137e+02, -4.1980e+02],
        [-9.2746e+00,  6.5281e-02, -1.0728e+01,  ...,  5.8639e+00,
          2.6707e+00, -1.5843e+01]])

In [38]:
embeddings[-1]['layer_mean_embeddings']

tensor([[-2.0482e-01,  4.1662e-01, -1.7881e-02,  ...,  1.4470e-01,
          1.9332e-01, -5.1653e-02],
        [-1.2791e-01,  8.6794e-02,  1.1645e-02,  ...,  2.4013e-02,
          6.9477e-02, -1.4329e-01],
        [ 5.5292e-02, -1.2492e-01, -1.9515e-02,  ...,  7.1906e-03,
          4.3979e-02, -4.8419e-02],
        ...,
        [ 8.2082e+01,  6.4029e+01,  1.7222e+00,  ...,  2.3686e+01,
         -8.0630e+00, -7.2619e+01],
        [ 1.2726e+02,  8.3658e+01, -4.8099e+01,  ..., -3.4301e+00,
         -1.3376e+01, -1.1925e+02],
        [ 2.4463e+00,  2.1555e+00, -8.4142e-01,  ..., -1.2124e+00,
         -1.1269e+00, -6.6522e+00]])

In [18]:
with torch.no_grad():
    model_output = model(input_ids=tokenizer(embeddings['text'][-1], return_tensors='pt').input_ids.to(model.device), output_hidden_states=True, use_cache=False)

In [42]:
torch.stack(model_output.hidden_states)[:,0,-1]

tensor([[ 4.3045e-01,  5.2381e-01,  8.2463e-01,  ..., -1.1306e-01,
          3.3893e-02, -2.3120e-02],
        [-1.0142e-01,  2.8770e-01,  3.6419e-01,  ...,  6.0856e-02,
         -2.9304e-02, -8.4377e-02],
        [-3.7873e-02, -3.0839e-02,  3.2535e-01,  ...,  2.4453e-01,
         -9.8545e-02,  5.0027e-02],
        ...,
        [-5.3461e+01, -3.4155e+00, -2.3939e+02,  ...,  2.5938e+02,
          1.3480e+02, -1.9434e+02],
        [-3.0611e+02, -3.3730e+01, -4.7096e+02,  ...,  1.6534e+02,
          1.7137e+02, -4.1980e+02],
        [-9.2746e+00,  6.5281e-02, -1.0728e+01,  ...,  5.8639e+00,
          2.6707e+00, -1.5843e+01]], device='cuda:0')

In [34]:
torch.stack(model_output.hidden_states)[:,0].float().mean(1)

tensor([[-2.0482e-01,  4.1662e-01, -1.7881e-02,  ...,  1.4470e-01,
          1.9332e-01, -5.1653e-02],
        [-1.2791e-01,  8.6794e-02,  1.1645e-02,  ...,  2.4013e-02,
          6.9477e-02, -1.4329e-01],
        [ 5.5292e-02, -1.2492e-01, -1.9515e-02,  ...,  7.1906e-03,
          4.3979e-02, -4.8419e-02],
        ...,
        [ 8.2082e+01,  6.4029e+01,  1.7222e+00,  ...,  2.3686e+01,
         -8.0630e+00, -7.2619e+01],
        [ 1.2726e+02,  8.3658e+01, -4.8099e+01,  ..., -3.4301e+00,
         -1.3376e+01, -1.1925e+02],
        [ 2.4463e+00,  2.1555e+00, -8.4142e-01,  ..., -1.2124e+00,
         -1.1269e+00, -6.6522e+00]], device='cuda:0')

In [12]:
torch.stack(model_output.hidden_states).mean(2)[-1,0]

tensor([-1.5535,  1.9119,  0.9343,  ...,  1.7148, -1.6708,  1.1468],
       device='cuda:0')

In [98]:
(torch.stack(model_output.hidden_states).mean(2)[-1,0].cpu() - train_embeddings3['layer_mean_embeddings'][-1,-1]).abs().mean()

tensor(0.)

In [97]:
(torch.stack(model_output.hidden_states).mean(2)[-1,0].cpu() - train_embeddings2['layer_mean_embeddings'][-1,-1]).abs().mean()

tensor(0.)

In [75]:
(torch.stack(model_output.hidden_states).float().mean(2)[-1,0].cpu() - train_embeddings2['layer_mean_embeddings'][-1,-1]).abs().mean()

tensor(0.0019)

In [99]:
(torch.stack(model_output.hidden_states).float().mean(2)[-1,0].cpu() - train_embeddings3['layer_mean_embeddings'][-1,-1]).abs().mean()

tensor(0.0019)

In [141]:
input_layer_token_embeddings[0,-1,0]

tensor([ 1.5547, -1.5234,  2.7812,  ..., -0.6445,  0.6094, -1.0469],
       device='cuda:0', dtype=torch.bfloat16)

In [142]:
torch.stack(model_output.hidden_states).shape

torch.Size([17, 1, 98, 2048])

In [140]:
input_layer_token_embeddings = torch.stack(model_output.hidden_states).permute(1,0,2,3)

In [121]:
train_embeddings['layer_mean_embeddings'][-1] - torch.stack(model_output.hidden_states)[:,0].mean(1).cpu()

tensor([[ 0.0000e+00,  0.0000e+00, -6.1035e-05,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.5259e-05,
          0.0000e+00,  0.0000e+00],
        [ 3.0518e-05,  1.5259e-05,  2.4414e-04,  ...,  0.0000e+00,
          0.0000e+00,  3.0518e-05],
        ...,
        [-9.7656e-04,  0.0000e+00,  4.8828e-04,  ..., -9.7656e-04,
         -4.8828e-04,  0.0000e+00],
        [ 0.0000e+00,  9.7656e-04, -1.2207e-04,  ...,  1.9531e-03,
         -1.2207e-03,  8.5449e-04],
        [-3.9062e-03,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  3.9062e-03]])

In [98]:
torch.stack(model_output.hidden_states).shape

torch.Size([17, 1, 161, 2048])

In [99]:
torch.stack(model_output.hidden_states)[:,0,-1]

tensor([[ 1.1414e-02, -2.2278e-03, -1.1292e-03,  ..., -1.7944e-02,
         -2.6855e-03, -2.1729e-02],
        [ 2.0142e-03,  4.8828e-03, -1.0400e-01,  ..., -2.1729e-02,
         -2.8809e-02, -3.7354e-02],
        [ 7.7515e-03, -5.6152e-02, -1.5430e-01,  ...,  5.4932e-02,
         -3.9551e-02, -6.4941e-02],
        ...,
        [-9.5703e-02,  3.2227e-01, -5.9766e-01,  ..., -7.6562e-01,
         -7.1289e-02, -5.2246e-02],
        [ 1.9141e-01,  4.3359e-01, -5.6641e-01,  ..., -6.3281e-01,
         -9.3262e-02, -6.0547e-02],
        [ 2.5312e+00,  3.9062e+00,  2.3730e-01,  ..., -4.7188e+00,
         -4.8438e+00, -1.0781e+00]], device='cuda:0', dtype=torch.bfloat16)

In [104]:
torch.stack(model_output.hidden_states)[:,0].mean(1)

tensor([[-9.9182e-04,  3.4027e-03,  1.2512e-02,  ..., -8.8501e-04,
         -6.8970e-03, -5.7068e-03],
        [-1.2085e-02,  8.4229e-03, -1.9653e-02,  ...,  9.2773e-03,
         -2.4048e-02,  4.2419e-03],
        [ 1.1292e-03,  3.9673e-03, -2.4658e-02,  ...,  1.7700e-02,
         -3.5400e-02, -3.6926e-03],
        ...,
        [ 3.4668e-02,  1.6309e-01, -1.3184e-01,  ..., -9.6680e-02,
         -1.6724e-02,  7.3730e-02],
        [ 9.0332e-02,  2.1484e-01, -7.2754e-02,  ..., -7.2754e-02,
         -1.3281e-01, -2.6611e-02],
        [ 1.3125e+00,  3.5312e+00,  6.1719e-01,  ..., -2.7969e+00,
         -3.2812e+00, -1.4648e-01]], device='cuda:0', dtype=torch.bfloat16)

In [8]:
def get_model_out(examples):
    tokens = tokenizer(examples['text'], padding='longest', return_tensors='pt')
    input_ids = tokens['input_ids'].to(model.device)
    attention_mask = tokens['attention_mask'].to(model.device)
    
    with torch.no_grad():
        model_out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'model_out': model_out}

In [9]:
out = get_model_out(dataset.take(10))
torch.cuda.empty_cache()

In [11]:
input_ids = out['input_ids'].cpu()
attention_mask = out['attention_mask'].cpu()
model_out = out['model_out']

In [12]:
last_token_indices = attention_mask.sum(-1) - 1

In [23]:
torch.stack(model_out.hidden_states).cpu().permute(1,0,2,3).shape

torch.Size([10, 17, 452, 2048])

In [59]:
input_layer_token_embeddings = torch.stack(model_out.hidden_states).permute(1,0,2,3).cpu()

In [60]:
layer_last_embedding = input_layer_token_embeddings[torch.arange(input_layer_token_embeddings.shape[0]), :, last_token_indices]

In [42]:
input_layer_token_embeddings.shape

torch.Size([10, 17, 452, 2048])

In [50]:
attention_mask.shape

torch.Size([10, 452])

In [54]:
attention_mask.unsqueeze(0).shape

torch.Size([1, 10, 452])

In [73]:
attention_mask.unsqueeze(-1).unsqueeze(1).shape

torch.Size([10, 1, 452, 1])

In [80]:
attention_mask.sum(1)

tensor([315, 279, 146, 452, 160, 218, 361, 208, 177, 216])

In [83]:
((input_layer_token_embeddings * attention_mask.unsqueeze(-1).unsqueeze(1)).sum(2) / attention_mask.sum(1).unsqueeze(-1).unsqueeze(-1)).shape

torch.Size([10, 17, 2048])

In [84]:
layer_mean_embeddings = (input_layer_token_embeddings * attention_mask.unsqueeze(-1).unsqueeze(1)).sum(2) / attention_mask.sum(1).unsqueeze(-1).unsqueeze(-1)

In [85]:
layer_mean_embeddings

tensor([[[-1.9073e-03,  3.4180e-03,  1.1902e-02,  ..., -5.9814e-03,
          -5.7068e-03, -4.3640e-03],
         [-1.2329e-02,  1.1475e-02, -1.5625e-02,  ...,  4.3640e-03,
          -1.2085e-02,  1.6098e-03],
         [ 1.9531e-03,  6.3782e-03, -1.8555e-02,  ...,  1.0010e-02,
          -1.8433e-02, -5.5847e-03],
         ...,
         [-1.3086e-01, -3.5156e-02, -4.4861e-03,  ..., -1.4258e-01,
           3.2959e-02,  9.8633e-02],
         [-1.3965e-01,  3.3691e-02,  1.0840e-01,  ..., -1.3867e-01,
           2.4261e-03,  1.0107e-01],
         [-4.1406e-01,  2.4688e+00,  1.1172e+00,  ..., -2.7031e+00,
          -2.7656e+00,  7.1094e-01]],

        [[-2.7924e-03,  2.7618e-03,  1.0254e-02,  ..., -3.5400e-03,
          -4.1809e-03, -3.6621e-03],
         [-1.1230e-02,  1.0254e-02, -2.5635e-02,  ...,  3.9062e-03,
          -2.1729e-02,  2.5177e-03],
         [ 4.8065e-04,  8.8501e-03, -3.1006e-02,  ...,  1.0986e-02,
          -4.1504e-02, -8.0566e-03],
         ...,
         [ 4.1992e-02,  9

In [None]:
(input_layer_token_embeddings.permute(1,0,2,3) * attention_mask.unsqueeze(-1)).mean().shape

torch.Size([17, 10, 452, 2048])

In [127]:
layer_mean_embeddings = ((torch.stack(model_out.hidden_states)[:, torch.arange(input_ids.size(0))] * attention_mask.unsqueeze(-1)).sum(2) / attention_mask.sum(-1).unsqueeze(-1)).permute(1,0,2)

torch.Size([10, 17, 2048])