In [1]:
import torch
import transformers
import datasets
import huggingface_hub
import safetensors

In [2]:
# cache_dir = '/Users/christopher/Documents/unirepsCache'
cache_dir = '/net/scratch2/chriswolfram/hf_cache'

In [3]:
huggingface_hub.login(new_session=False)

In [4]:
model_name = 'meta-llama/Llama-3.2-1B'

In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, device_map='auto', cache_dir=cache_dir)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype='auto', device_map='auto', cache_dir=cache_dir)

In [6]:
# Add padding token if needed
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [7]:
dataset = datasets.load_dataset('stanfordnlp/imdb', cache_dir=cache_dir)
dataset = dataset['test'].take(4096)

In [9]:
model.get_output_embeddings().weight.detach()

tensor([[ 0.0045,  0.0166,  0.0210,  ..., -0.0054, -0.0422, -0.0315],
        [ 0.0215, -0.0238,  0.0211,  ..., -0.0107, -0.0011, -0.0374],
        [ 0.0136,  0.0104,  0.0128,  ...,  0.0081, -0.0122,  0.0051],
        ...,
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066],
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066],
        [ 0.0009,  0.0164, -0.0193,  ..., -0.0003, -0.0030,  0.0066]],
       device='cuda:0', dtype=torch.bfloat16)

In [None]:
def compute_embeddings(examples):
    tokens = tokenizer(examples['text'], padding='longest', return_tensors='pt')
    input_ids = tokens['input_ids'].to(model.device)
    attention_mask = tokens['attention_mask'].to(model.device)
    
    with torch.no_grad():
        model_out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    
    last_token_indices = attention_mask.sum(dim=-1) - 1
    embeddings = model_out.hidden_states[-1][torch.arange(input_ids.size(0)), last_token_indices]
    
    return {'embeddings': embeddings}

train_embeddings = dataset.map(compute_embeddings, batched=True, batch_size=64)
train_embeddings.set_format('torch')

Map:   0%|          | 0/4096 [00:00<?, ? examples/s]

In [12]:
train_embeddings['embeddings'][-1]

tensor([ 2.5312,  3.8438,  0.1943,  ..., -4.7500, -4.9375, -1.0859])

In [32]:
with torch.no_grad():
    model_output = model(input_ids=tokenizer(train_embeddings['text'][-1], return_tensors='pt').input_ids.to(model.device), output_hidden_states=True)

In [33]:
model_output.logits.shape

torch.Size([1, 161, 128256])

In [45]:
model_output.hidden_states[-1][0][-1]

tensor([ 2.5312,  3.9062,  0.2373,  ..., -4.7188, -4.8438, -1.0781],
       device='cuda:0', dtype=torch.bfloat16)

In [34]:
torch.stack(model_output.hidden_states).shape

torch.Size([17, 1, 161, 2048])