In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "AstroMLab/AstroSage-8b",
    device_map="auto",
    output_hidden_states=True  
)
tokenizer = AutoTokenizer.from_pretrained("AstroMLab/AstroSage-8b")
tokenizer.pad_token = tokenizer.eos_token  # Use EOS token for padding
tokenizer.padding_side = "left"  # Optimal for causal models


def get_embeddings(text, pool_method="mean"):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Corrected access pattern for causal models
    last_layer_hidden_states = outputs.hidden_states[-1]  # Shape: [batch_size, seq_len, hidden_size]
    
    if pool_method == "mean":
        return last_layer_hidden_states.mean(dim=1).squeeze()
    elif pool_method == "last":
        return last_layer_hidden_states[:, -1, :].squeeze()
    else:
        raise ValueError("Supported methods: 'mean' or 'last'")




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:

text = "The Milky Way is a barred spiral galaxy."
embedding = get_embeddings([text], pool_method="mean")
print("Embedding shape:", embedding.shape)

Embedding shape: torch.Size([4096])


In [7]:
import numpy as np

np.linalg.norm(embedding[0].cpu().numpy())

0.08212645