Autoregressive (AR) Generation with Gemma

In [None]:
import torch
import torch.nn.functional as F
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.gemma import modeling_gemma

In [None]:
import torch

In [None]:
# Load Gemma model and tokenizer.
# NOTE: Gemma checkpoints on the Hugging Face Hub are gated; you may need to accept
# the license and authenticate (e.g. `huggingface-cli login`) before this cell runs.
model_name = "google/gemma-2b"  # or "google/gemma-2b-it" for the instruction-tuned variant
# Do not swap in a Gemma 2 checkpoint (e.g. "google/gemma-2-2b-it") without reviewing
# generate_autoregressive below: it calls model.lm_head directly, which bypasses the
# final-logit soft-capping that Gemma 2's causal-LM forward applies.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # bf16 on GPU to save memory; plain fp32 on CPU.
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    device_map=device,
)

model.eval()  # inference only (disables dropout)
print(f"Model loaded on {device}")

In [None]:
def make_att_2d_masks(pad_masks, att_masks):
    """
    Build a boolean [B, N, N] attention mask from per-token padding and AR masks.

    Query token i may attend to key token j when both are real (non-padding)
    tokens and the running total of ``att_masks`` at j does not exceed the
    running total at i. The choice of ``att_masks`` selects the pattern:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-LM attention. The first 3 tokens attend to each
        other bidirectionally; the last 3 tokens attend causally.

      [[0 0 0 0 0 0]]: every valid token attends to every valid token
        (fully bidirectional).

    Args:
        pad_masks: bool[B, N], True for real input tokens, False for padding.
        att_masks: int[B, N], 1 where a token opens a new block that earlier
            tokens cannot see, 0 where it shares the attention pattern of the
            previous token.

    Returns:
        bool[B, N, N] mask; entry [b, i, j] is True when query i may attend
        to key j.

    Raises:
        ValueError: if either mask is not 2-D (att_masks is checked first).
    """
    for name, mask in (("att_masks", att_masks), ("pad_masks", pad_masks)):
        if mask.ndim != 2:
            raise ValueError(f"{name} must be 2D, got {mask.ndim}D")

    block_ids = att_masks.cumsum(dim=1)   # (B, N) running block index per token
    query_ids = block_ids.unsqueeze(2)    # (B, N, 1)
    key_ids = block_ids.unsqueeze(1)      # (B, 1, N)
    visible = key_ids <= query_ids        # (B, N, N)

    # Both the query and the key must be real tokens.
    valid_pairs = pad_masks.unsqueeze(1) * pad_masks.unsqueeze(2)  # (B, N, N)
    return visible & valid_pairs


def prepare_attention_masks_4d(att_2d_masks):
    """
    Turn a boolean [B, Q, K] attention mask (e.g. from make_att_2d_masks) into
    the 4-D float form expected by the transformer.

    Allowed positions (True) become 0.0 and blocked positions (False) become
    -1e9, the 0 / large-negative convention of an additive attention bias.
    A singleton head axis is inserted at dim 1, giving shape [B, 1, Q, K].
    """
    OPENPI_ATTENTION_MASK_VALUE = -1e9
    additive = torch.where(att_2d_masks, 0.0, OPENPI_ATTENTION_MASK_VALUE)
    return additive[:, None, :, :]

In [None]:
def _sample_next_token(logits, temperature):
    """Pick one token id per row from ``logits`` (B, vocab); returns a (B, 1) long tensor.

    temperature > 0 samples from softmax(logits / temperature); otherwise greedy argmax.
    """
    if temperature > 0.0:
        # Compute probabilities in float32: bf16 softmax outputs are coarse.
        probs = F.softmax(logits.float() / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    return torch.argmax(logits, dim=-1, keepdim=True)


@torch.no_grad()
def generate_autoregressive(
    model,
    tokenizer,
    prompt: str,
    max_decoding_steps: int = 100,
    eos_token_id: Optional[int] = None,
    temperature: float = 0.7,
    device: str = "cuda",
) -> str:
    """
    Generate tokens autoregressively from a Gemma model using a KV cache.

    Follows the pattern of the sample_subtask method in modeling_pi05ki.py
    (explicit embeddings, custom 4D attention masks, explicit position ids),
    but uses a causal mask over the prompt because Gemma is a causal LM.

    Args:
        model: Gemma causal-LM instance (uses ``model.model`` and ``model.lm_head``).
        tokenizer: Tokenizer instance.
        prompt: Input text prompt.
        max_decoding_steps: Maximum number of tokens to generate.
        eos_token_id: End-of-sequence token ID (defaults to tokenizer.eos_token_id,
            or 1 if the tokenizer has none).
        temperature: Sampling temperature (0.0 for greedy, >0 for sampling).
        device: Device to run on.

    Returns:
        Generated text for the first sequence, with special tokens removed.

    Side effects:
        Sets ``model.config._attn_implementation = "eager"`` on the passed model.
    """
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 1

    # Tokenize prompt: (B, P); B is 1 for a single string prompt.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    batch_size = input_ids.shape[0]

    # Every prompt token is a real token (no padding).
    pad_masks = torch.ones_like(input_ids, dtype=torch.bool)
    # All ones => each token opens a new block => pure causal attention.
    # (All zeros would make the prompt fully bidirectional.)
    att_masks = torch.ones_like(input_ids, dtype=torch.long)
    att_2d_masks = make_att_2d_masks(pad_masks, att_masks)  # (B, P, P)

    # Prompt positions 0 .. P-1.
    position_ids = torch.cumsum(pad_masks, dim=1) - 1

    att_2d_masks_4d = prepare_attention_masks_4d(att_2d_masks)  # (B, 1, P, P)

    # Custom float 4D masks are fed to the eager attention path.
    model.config._attn_implementation = "eager"

    # Prefill: run the whole prompt once and keep its KV cache.
    inputs_embeds = model.model.embed_tokens(input_ids)
    outputs = model.model(
        inputs_embeds=inputs_embeds,
        attention_mask=att_2d_masks_4d,
        position_ids=position_ids,
        past_key_values=None,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    # Logits for the token following the prompt: (B, vocab_size).
    last_logits = model.lm_head(outputs.last_hidden_state[:, -1, :])

    # Number of valid prompt tokens per row: (B,).
    prefix_valid_length = torch.sum(pad_masks, dim=1)

    output_tokens = torch.zeros((batch_size, max_decoding_steps),
                                dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    num_generated = 0

    # Key-validity mask over prompt + generated tokens; grows by one per step.
    running_attention_mask = pad_masks.clone()

    for step in range(max_decoding_steps):
        token = _sample_next_token(last_logits, temperature)  # (B, 1)
        # Rows that already emitted EOS keep emitting EOS.
        token = torch.where(finished[:, None], torch.full_like(token, eos_token_id), token)

        output_tokens[:, step] = token.squeeze(-1)
        num_generated = step + 1

        finished |= token.squeeze(-1) == eos_token_id
        if finished.all():
            break

        next_token_embeds = model.model.embed_tokens(token)  # (B, 1, embd_dim)

        # Prompt occupies positions 0..P-1, so the token fed at `step` sits at P + step.
        position_ids = prefix_valid_length[:, None] + step

        new_mask = torch.ones(
            (batch_size, 1),
            dtype=running_attention_mask.dtype,
            device=device,
        )
        running_attention_mask = torch.cat([running_attention_mask, new_mask], dim=1)

        # The query is only the single new token, so the mask needs one query row:
        # (B, 1, L). As the last position of a causal sequence, it may attend to
        # every valid key, including itself.
        step_att_2d_masks = running_attention_mask[:, None, :]
        step_att_2d_masks_4d = prepare_attention_masks_4d(step_att_2d_masks)  # (B, 1, 1, L)

        outputs = model.model(
            inputs_embeds=next_token_embeds,
            attention_mask=step_att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        last_logits = model.lm_head(outputs.last_hidden_state[:, -1, :])

    # Decode only the tokens actually generated (not the unused zero slots).
    generated_text = tokenizer.decode(
        output_tokens[0, :num_generated].tolist(),
        skip_special_tokens=True,
    )
    return generated_text

In [None]:
# Example usage
# Seed the RNG so temperature sampling gives the same output on re-runs
# (on the same hardware/software stack).
torch.manual_seed(0)

prompt = "The capital of France is"
print(f"Prompt: {prompt}")

generated_text = generate_autoregressive(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_decoding_steps=50,
    temperature=0.7,
    device=device,
)

print(f"\nGenerated: {generated_text}")