In [1]:
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import torch
import torch.nn as nn


2025-06-23 14:55:42.728100: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Load model and processor
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto").eval()
processor = AutoProcessor.from_pretrained(model_id)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:

class Gemma3Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Access output (tied) embeddings
        self.output_embeddings = model.get_output_embeddings()

    def get_hidden(self, input_ids, attention_mask=None):
        outs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        return outs.hidden_states[-1]  # [batch, seq_len, hidden_dim]

    def forward_from_hidden(self, hidden_states):
        return self.output_embeddings(hidden_states)


In [13]:
# Wrap the model
wrapper = Gemma3Wrapper(model)

# Input
text = "I need a ball because I have to leave for a football"
# inputs = processor(text, return_tensors="pt").to(model.device)
inputs = processor(text=text, return_tensors="pt").to(model.device)

input_ids = inputs["input_ids"]
attention_mask = inputs.get("attention_mask", None)

# Step 1: Get hidden states
with torch.no_grad():
    hidden = wrapper.get_hidden(input_ids, attention_mask)

# Step 2: Modify hidden state (perturb last token)
modified_hidden = hidden.clone()
modified_hidden[:, -1, :] /= 75000
# modified_hidden=hidden.clone()

# Step 3: Get logits from modified hidden
with torch.no_grad():
    logits = wrapper.forward_from_hidden(modified_hidden)

# Step 4: Predict next token
next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
predicted_token = processor.tokenizer.decode(next_token_id)

print("Input Text:", text)
print("Predicted Next Token (after modification):", predicted_token)

Input Text: I need a ball because I have to leave for a football
Predicted Next Token (after modification):  game
