In [2]:
from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable, List, Dict, Any, Optional, Tuple, Union

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from oat_evaluation.llms.llm import LLM, ExposedActivationsRequest, LLMResponses, TokenSelectionMethod
from contextlib import contextmanager

model_path = "/workspace/gemma_2_9b_instruct"
model = AutoModelForCausalLM.from_pretrained(
    # float16 was the default for obfuscated-activations too
    model_path, device_map="cuda", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)



  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [01:33<00:00, 23.31s/it]


In [3]:
model_embedding_layer = model.get_input_embeddings()


In [4]:
# Now.... Let's try different ways of:
# Turning embed into chat-template embed
# Feeding embeds into the model

# First, let's establish what "chat-template" embed we actually want.
# i.e. let's embed pre and post chat and compare them...
# To identify each tensor, we'll want to hash it ig

prompt = "How to bake?"

print("BASE PROMPT")
token_ids1 = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
print(token_ids1)
print(token_ids1["input_ids"].shape)
embeddings1 = model_embedding_layer(token_ids1["input_ids"])
print(embeddings1)
print(embeddings1.shape)

print("\nCHAT PROMPT")

messages = [
    [{"role": "user", "content": prompt}]
]
token_ids2 = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    padding=True,
    return_tensors="pt",
    return_dict=True
).to("cuda")
print(token_ids2)
embeddings2 = model_embedding_layer(token_ids2["input_ids"])
print(embeddings2)
print(embeddings2.shape)

#print(embeddings2[0, 0].equal(embeddings1[0, 0]))

insertion_index = -1
for i in range(embeddings2.shape[1]):
    if embeddings2[0, i].equal(embeddings1[0, 0]):
        # Insertion here!
        print(f"Found index {i}!")
        insertion_index = i
        break

embeddings2_intro = embeddings2[0, :insertion_index].unsqueeze(0)
embeddings2_outro = embeddings2[0, insertion_index+embeddings1.shape[1]:].unsqueeze(0)

print(f"SECTION SIZES:")
print(embeddings2_intro.shape)
print(embeddings2_outro.shape)

embedding_chat_function = lambda raw: torch.cat((embeddings2_intro, raw, embeddings2_outro), dim=1)

combo_embedding = embedding_chat_function(embeddings1)
print("COMBINED SIZE:")
print(combo_embedding.shape)

for i in range(embeddings2.shape[1]):
    print(embeddings2[0, i].equal(combo_embedding[0, i]))

# for each item in new embedding...
# if it equals index in base: mark ""


BASE PROMPT
{'input_ids': tensor([[  2299,    577,  44528, 235336]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1]], device='cuda:0')}
torch.Size([1, 4])
tensor([[[-0.0134, -0.0056, -0.0208,  ...,  0.0049, -0.0243, -0.0239],
         [-0.0240,  0.0229, -0.0211,  ..., -0.0066, -0.0115, -0.0057],
         [-0.0272,  0.0236, -0.0591,  ..., -0.0204, -0.0371, -0.0564],
         [ 0.0055, -0.0525, -0.0269,  ..., -0.0056, -0.0532, -0.0179]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<EmbeddingBackward0>)
torch.Size([1, 4, 3584])

CHAT PROMPT
{'input_ids': tensor([[     2,    106,   1645,    108,   2299,    577,  44528, 235336,    107,
            108,    106,   2516,    108]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
tensor([[[-0.0112,  0.0026,  0.0081,  ..., -0.0006,  0.0043,  0.0033],
         [-0.0444, -0.0019, -0.0378,  ..., -0.0056, -0.0254, -0.0242],
         [-0.0223, -0.0214, -0.0588,  ..., -0.0151, 

In [5]:
model_embedding_layer = model.get_input_embeddings()

def token_ids_to_embeddings(token_ids):
    """Expects input shape (batch_size, seq_len). Outputs shape (batch_size, seq_len, embedding_size)."""
    return model_embedding_layer(token_ids)


token_id_tensor = torch.tensor(tokenizer.pad_token_id, device='cuda').unsqueeze(0).unsqueeze(0)
pad_embedding = token_ids_to_embeddings(token_id_tensor)
print(pad_embedding.shape)

generated_outputs = model.generate(
        input_embeds=combo_embedding,
        max_new_tokens=50,
        do_sample=True,
        top_p=0.9,
        temperature=0.8
    )


torch.Size([1, 1, 3584])


ValueError: The following `model_kwargs` are not used by the model: ['input_embeds'] (note: typos in the generate arguments will also show up in this list)

In [10]:
# %%
from contextlib import contextmanager

@contextmanager
def inject_prompt_embeddings_via_forward_hook(model, prompt_embeds: torch.Tensor):
    """
    A context manager to override the *output* of the embedding layer
    for the multi-token "prompt" pass. For single-token steps,
    we do nothing (use normal embeddings).
    """
    def embedding_output_hook(module, module_input, module_output):
        """
        module_output is the float embeddings of shape [batch_size, seq_len, hidden_dim].
        We'll check if seq_len > 1 => prompt pass. Then we override.
        """
        bsz, seq_len, hidden_dim = module_output.shape
        if seq_len > 1:
            # Overwrite with our custom prompt embeddings
            #new_embeds = prompt_embeds[:, :seq_len, :]
            print(f"About to replace emebds shape {module_output.shape} with {prompt_embeds.shape}")
            print(f"Specifically: {prompt_embeds}")
            print(f"At module {module}")
            print(f"With module input {module_input}")
            return prompt_embeds
        else:
            # Single-token step -> no changes
            return module_output

    # Register forward hook
    handle = model.get_input_embeddings().register_forward_hook(embedding_output_hook)
    try:
        yield
    finally:
        handle.remove()

# Now let's test it

batch_size, seq_len, hidden_dim = combo_embedding.shape
dummy_input_ids = torch.full(
    (batch_size, seq_len),
    fill_value=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0,
    dtype=torch.long,
    device=combo_embedding.device
)



with inject_prompt_embeddings_via_forward_hook(model, combo_embedding):
    generated_outputs = model.generate(
        input_ids=dummy_input_ids,
        max_new_tokens=50,
        do_sample=True,
        top_p=0.9,
        temperature=0.8
    )

decoded = tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
print("=== Model Output ===")
for i, text in enumerate(decoded):
    print(f"[Sample {i}] {text}")


About to replace emebds shape torch.Size([1, 13, 3584]) with torch.Size([1, 13, 3584])
Specifically: tensor([[[-0.0112,  0.0026,  0.0081,  ..., -0.0006,  0.0043,  0.0033],
         [-0.0444, -0.0019, -0.0378,  ..., -0.0056, -0.0254, -0.0242],
         [-0.0223, -0.0214, -0.0588,  ..., -0.0151, -0.0009, -0.0240],
         ...,
         [-0.0444, -0.0019, -0.0378,  ..., -0.0056, -0.0254, -0.0242],
         [-0.0058,  0.0115,  0.0273,  ..., -0.0038, -0.0253, -0.0239],
         [-0.0139,  0.0752,  0.0123,  ...,  0.0013, -0.0026, -0.0114]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<CatBackward0>)
At module Embedding(256000, 3584, padding_idx=0)
With module input (tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'),)
=== Model Output ===
[Sample 0] LetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLetLet


In [None]:
model



In [18]:
outputs1 = model.forward(inputs_embeds=combo_embedding)
print(outputs1.logits.shape)
print(tokenizer.batch_decode(torch.argmax(outputs1.logits, dim=-1), skip_special_tokens=False))

torch.Size([1, 13, 256000])
['<h1>\n\n\n\n\n\n do make a\n\n\n\n\n\n\n\n\n\nBaking']


In [25]:
from oat_evaluation.llms.autollm import AutoLLM

llm = AutoLLM("/workspace/gemma_2_9b_instruct")

Loading model from /workspace/gemma_2_9b_instruct...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 4/4 [01:31<00:00, 22.87s/it]


Found chat template intro length 4, outro length 5
Loaded model with left-padding token: <pad>


In [27]:
messages = [
    [{"role": "user", "content": prompt}]
    for prompt in ["How to bake?", "LOL"]
]
tokenized_chat = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    padding=True,
    return_tensors="pt",
    return_dict=True
).to(model.device)
print(tokenized_chat)

{'input_ids': tensor([[     2,    106,   1645,    108,   2299,    577,  44528, 235336,    107,
            108,    106,   2516,    108],
        [     0,      0,      0,      2,    106,   1645,    108,  26620,    107,
            108,    106,   2516,    108]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


In [28]:
prompt = "How to bake?"
response = "First, mix the ingredients."
resp1 = llm.generate_responses_forced([prompt], [response])
print(resp1.responses_strings)

prompt_embeds = llm.string_to_embedding(prompt)
resp_embeds = llm.string_to_embedding(response)
resp2 = llm.generate_responses_forced([prompt], [response])
print(resp2.responses_strings)

Outputs: CausalLMOutputWithPast(loss=None, logits=tensor([[[-26.0938,  17.2500,  21.2344,  ..., -10.3750,  -3.5645, -13.8125],
         [-24.4531,   4.1953, -16.5312,  ..., -13.5703, -10.8516, -25.2812],
         [-25.4375,   2.9688, -17.9219,  ..., -14.4766, -11.4844, -26.1250],
         ...,
         [-28.8594,  -0.9131, -22.4062,  ..., -18.6250, -15.1875, -28.7188],
         [-19.0000,  -0.9028, -11.5625,  ...,  -9.2812,  -8.0312, -18.0469],
         [-15.5312,   1.6055,  -2.6191,  ...,  -3.9766,  -3.6816, -13.1875]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>), past_key_values=<transformers.cache_utils.HybridCache object at 0x785e8892fd00>, hidden_states=None, attentions=None)
Error trying to extract start length: 'CausalLMOutputWithPast' object has no attribute 'sequences'


AttributeError: 'CausalLMOutputWithPast' object has no attribute 'sequences'

In [None]:
import torch
import torch.nn.functional as F

def custom_autoregressive_generate(
    model,
    tokenizer,
    initial_input_ids: torch.Tensor,
    max_new_tokens=50,
    temperature=1.0,
    top_p=0.9,
    custom_embedding_fn=None,
    device="cuda",
    eos_token_id=None,
):
    """
    Custom autoregressive loop. Feeds the entire sequence to model.forward each step,
    picks the next token from the logits, appends to sequence, and repeats.
    
    Args:
      model: A HuggingFace CausalLM model or similar
      tokenizer: A HuggingFace tokenizer
      initial_input_ids: shape (batch_size, seq_len); your starting tokens
      max_new_tokens: how many tokens to generate
      temperature, top_p: sampling parameters
      custom_embedding_fn: optional function that overrides token+pos embeddings
      device: "cuda" or "cpu"
      eos_token_id: If not None, generation will stop upon generating this token

    Returns:
      A list of token IDs (per batch) including the newly generated tokens.
    """

    # Move the current tokens to device
    input_ids = initial_input_ids.to(device)
    batch_size = input_ids.size(0)

    # We'll store outputs in a Python list so we can keep appending
    # (though you can also keep it in a single tensor if you prefer).
    generated = [input_ids for _ in range(batch_size)]
    # Actually, to keep it simpler for multi-batch, let's keep it as one tensor:
    generated = input_ids.clone()

    for step in range(max_new_tokens):
        seq_len = generated.size(1)

        if custom_embedding_fn is None:
            # Normal forward pass: pass input_ids
            outputs = model.forward(input_ids=generated)
            logits = outputs.logits  # shape [batch_size, seq_len, vocab_size]
        else:
            # If you have a custom function that transforms your input_ids into embeddings
            # and you do not want the model to do the token-embedding lookup,
            # you can pass `inputs_embeds` directly:
            with torch.no_grad():
                # custom_embedding_fn should produce a float tensor of shape [batch_size, seq_len, hidden_dim]
                # possibly by hooking or by direct construction
                inputs_embeds = custom_embedding_fn(model, generated)
                outputs = model.forward(inputs_embeds=inputs_embeds)
                logits = outputs.logits

        # Take the last-token's logits (shape [batch_size, vocab_size])
        next_logits = logits[:, -1, :]

        # Apply temperature
        if temperature != 1.0:
            next_logits = next_logits / temperature

        # Apply top-p (nucleus) sampling
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
            cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

            # Remove tokens with cumulative probability > top_p
            cutoff = (cumulative_probs > top_p).float().argmax(dim=-1)
            # We'll build a mask to zero out probabilities beyond the cutoff
            # For each batch item, find the cutoff index
            for b in range(batch_size):
                # index up to cutoff[b] inclusive
                sorted_logits[b, cutoff[b]+1 :] = float("-inf")

            # Re-sort back to original positions
            _, original_indices = torch.sort(sorted_indices, descending=False)
            next_logits = sorted_logits.gather(1, original_indices)

        # Convert logits -> probabilities
        probs = F.softmax(next_logits, dim=-1)

        # Sample from the distribution
        next_tokens = torch.multinomial(probs, num_samples=1)  # [batch_size, 1]

        # Append new token to `generated`
        generated = torch.cat([generated, next_tokens], dim=1)  # shape [batch_size, seq_len+1]

        # If we have an EOS token, we can stop if all sequences ended
        if eos_token_id is not None:
            # Check if any of the new tokens are EOS
            is_eos = (next_tokens == eos_token_id).all()
            if is_eos:
                break

    return generated


### Example usage

# Suppose you have a model, tokenizer, and an initial text prompt:
# model = ...
# tokenizer = ...
prompt_text = "The quick brown fox"
initial_inputs = tokenizer(prompt_text, return_tensors="pt")

# No custom embedding override (pure normal next-token generation):
generated_tokens = custom_autoregressive_generate(
    model=model,
    tokenizer=tokenizer,
    initial_input_ids=initial_inputs["input_ids"],
    max_new_tokens=30,
    temperature=0.7,
    top_p=0.9,
)

decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print("Decoded text:", decoded[0])  # single batch, index 0
