In [1]:
import os

import datasets
import huggingface_hub
import safetensors
import torch
import transformers

In [2]:
# Cache directory for Hugging Face downloads made in this notebook.
# Override with the UNIREPS_CACHE_DIR environment variable; the default is
# ~/Documents/unirepsCache, which avoids hard-coding one user's absolute home path.
cache_dir = os.environ.get('UNIREPS_CACHE_DIR', os.path.expanduser('~/Documents/unirepsCache'))

In [3]:
# Hugging Face Hub repo id used for both the tokenizer and the model loaded below.
model_name = "meta-llama/Llama-3.2-1B"

In [4]:
# Authenticate with the Hugging Face Hub (presumably needed to download the
# meta-llama checkpoint — confirm). new_session=False avoids re-prompting when a
# token is already saved locally.
huggingface_hub.login(new_session=False)

In [5]:
# Tokenizers run on CPU and have no device placement, so only the model gets device_map.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
# torch_dtype='auto' keeps the checkpoint's stored dtype (the weight dump further down
# shows bfloat16); device_map='auto' lets accelerate pick the device (MPS here).
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto', device_map='auto', cache_dir=cache_dir)

In [6]:
# Batched tokenization below pads sequences to a common length, which requires a
# pad token; fall back to EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Downstream code locates each example's last real token from the attention mask.
# Pin right padding so pad positions are always trailing, instead of relying on
# whatever default the tokenizer config happens to ship with.
tokenizer.padding_side = 'right'

In [13]:
# Keep the full DatasetDict under its own name so re-running this cell does not try
# to index 'test' on an already-sliced Dataset. Use the same cache_dir as the model.
imdb = datasets.load_dataset('stanfordnlp/imdb', cache_dir=cache_dir)

# The IMDB splits are stored grouped by label (negative reviews first), so a plain
# prefix of 4096 rows would be (almost) single-class — check with dataset['label'].
# Shuffle with a fixed seed first so the subset is mixed and reproducible.
dataset = imdb['test'].shuffle(seed=0).take(4096)

In [8]:
# Peek at the output-embedding (LM head) weight matrix, detached from autograd.
lm_head = model.get_output_embeddings()
lm_head.weight.detach()

tensor([[ 0.0045,  0.0166,  0.0210,  ..., -0.0054, -0.0422, -0.0315],
        [ 0.0215, -0.0238,  0.0211,  ..., -0.0107, -0.0011, -0.0374],
        [ 0.0136,  0.0104,  0.0128,  ...,  0.0081, -0.0122,  0.0051],
        ...,
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066],
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066],
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066]],
       device='mps:0', dtype=torch.bfloat16)

In [9]:
# Display the module tree of the loaded model (layer count and projection sizes).
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [26]:
def tokenize_function(examples):
    """Tokenize a batch of reviews for ``Dataset.map(batched=True)``.

    Args:
        examples: batch dict from ``Dataset.map``; ``examples['text']`` is a list of strings.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``) as Python lists.
        Each map batch is padded to the longest review *in that batch*, so rows from
        different map batches can end up with different lengths.
    """
    # Plain lists are enough: Dataset.map stores the result in Arrow anyway, so
    # building PyTorch tensors here would only add a conversion round-trip.
    return tokenizer(examples['text'], padding='longest')

tokenized_dataset = dataset.map(tokenize_function, batched=True)

# The raw 'text' column is kept alongside the token ids.
tokenized_dataset.set_format('torch')

Map:   0%|          | 0/4096 [00:00<?, ? examples/s]

In [45]:
def compute_embeddings(batch):
    """Embed each example as the final-layer hidden state of its last real token.

    Meant for ``Dataset.map(batched=True)`` over ``tokenized_dataset`` (torch format).

    Args:
        batch: dict with ``input_ids`` and ``attention_mask`` for several examples,
            either as stacked 2-D tensors or, for ragged batches, as lists of 1-D tensors.

    Returns:
        ``{'embeddings': Tensor}`` with one float32 CPU row per example.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # With set_format('torch'), a batch whose rows differ in length comes back as a
    # list of tensors rather than one stacked tensor. That happens when a batch spans
    # two tokenization map batches that were padded to different lengths. Re-pad it.
    if isinstance(input_ids, list):
        padded = tokenizer.pad(
            {'input_ids': input_ids, 'attention_mask': attention_mask},
            return_tensors='pt',
        )
        input_ids = padded['input_ids']
        attention_mask = padded['attention_mask']

    # Drop columns that are padding in every row so the model does not run over them.
    keep = attention_mask.bool().any(dim=0)
    input_ids = input_ids[:, keep]
    attention_mask = attention_mask[:, keep]

    # Index of the last non-padding token per row, valid for right *or* left padding
    # (argmax returns the first maximum, i.e. the last 1 before flipping).
    # Computed on CPU before anything is moved to the model device.
    seq_len = attention_mask.size(1)
    last_token_indices = seq_len - 1 - attention_mask.flip(dims=[-1]).argmax(dim=-1)

    with torch.no_grad():
        # base_model skips the LM head, so no (batch, seq, vocab) logits are built.
        # For this Llama model its last_hidden_state is the same tensor as
        # hidden_states[-1] of the full causal-LM forward pass.
        hidden = model.base_model(
            input_ids=input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
        ).last_hidden_state

    # Keep both index tensors on the same device as the tensor being indexed.
    rows = torch.arange(hidden.size(0), device=hidden.device)
    embeddings = hidden[rows, last_token_indices.to(hidden.device)]

    # Hand datasets a float32 CPU tensor rather than a bfloat16 tensor on the accelerator.
    return {'embeddings': embeddings.float().cpu()}

# NOTE: despite the name, tokenized_dataset is built from the IMDB 'test' split.
train_embeddings = tokenized_dataset.map(compute_embeddings, batched=True, batch_size=6)

Map:   0%|          | 0/4096 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [41]:
# Sanity-check forward pass on one short prompt, keeping every layer's hidden states.
probe_ids = tokenizer('Hi there! This is a test', return_tensors='pt').input_ids.to(model.device)
with torch.no_grad():
    model_output = model(input_ids=probe_ids, output_hidden_states=True)

In [37]:
# Logits are (batch, seq_len, vocab): 1 prompt of 8 tokens; 128256 matches the
# row count of embed_tokens in the model summary above.
model_output.logits.shape

torch.Size([1, 8, 128256])

In [44]:
# Stack the per-layer hidden states to read off (n_states, batch, seq_len, hidden_size).
# n_states is one more than the 16 decoder layers: per the transformers docs the tuple
# holds the embedding output plus one entry per layer.
torch.stack(model_output.hidden_states).shape

torch.Size([17, 1, 8, 2048])