In [3]:
from transformers import pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer


# Text-generation pipeline: OpenELM-270M-Instruct weights paired with the
# Llama-2 tokenizer (a BOS token is prepended to every prompt).
generator = pipeline(
    task="text-generation",
    model="apple/OpenELM-270M-Instruct",
    tokenizer=AutoTokenizer.from_pretrained(
        "NousResearch/Llama-2-7b-hf",
        add_bos_token=True,
    ),
    trust_remote_code=True,
)

#generator("I can't believe you did such a ", do_sample=False)

# -> Somehow this pipeline doesn't work properly for Apple models: the (now commented-out) call above only produced a run of newlines, see the output below.

[{'generated_text': "I can't believe you did such a \n\n\n\n\n\n\n\n\n\n"}]

In [4]:
# From transformers-from-scratch tutorial
import torch
import numpy as np
from torch import Tensor
from typing import List, Tuple
from jaxtyping import Float, Int


class TransformerSampler:
    """Autoregressive text sampler for a causal language model.

    The sampler repeatedly feeds the running token sequence to ``model`` and
    picks the next token with one of several strategies (greedy, temperature,
    frequency penalty, top-k, top-p).

    Expectations on the wrapped objects (as used by the code below):

    * ``model(tokens)`` is called with a ``(1, seq_len)`` tensor of token ids
      and must return an object whose ``.logits`` can be indexed as
      ``[0, -1]`` to obtain the next-token logits; ``model.eval()`` must exist.
    * ``tokenizer`` must provide ``encode(str)``, ``decode(ids)`` and an
      ``eos_token_id`` attribute (may be ``None``).

    Adapted from the excellent transformer-from-scratch course:
    https://arena3-chapter1-transformer-interp.streamlit.app/%5B1.1%5D_Transformer_from_Scratch
    https://github.com/callummcdougall/ARENA_3.0
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
        '''
        Returns a string of autoregressively generated text, starting from the prompt.

        Sampling terminates at max_tokens_generated, or when the model generates an
        end-of-sequence token.

        Args:
            prompt: Text to condition the generation on.
            max_tokens_generated: Upper bound on the number of new tokens.
            verbose: Currently unused; accepted for interface compatibility.
            **kwargs: Passed to sample_next_token, to give detailed instructions
                on how new tokens are chosen. A ``seed`` entry is consumed here
                and applied once, before the first token, so that successive
                tokens are drawn from an advancing random stream instead of the
                RNG being reset to the same state before every token.

        Returns:
            The decoded prompt plus generated tokens (including the EOS token
            if one was produced).

        Raises:
            ValueError: If the tokenizer encodes the prompt to zero tokens.
        '''
        self.model.eval()

        # Seed once per call rather than once per generated token.
        TransformerSampler._seed_rngs(kwargs.pop("seed", None))

        prompt_ids = self.tokenizer.encode(prompt)
        if len(prompt_ids) == 0:
            raise ValueError("prompt was encoded to zero tokens; nothing to condition on")

        # Keep the token tensor on the model's device when the model exposes one.
        device = getattr(self.model, "device", None)
        if not isinstance(device, (torch.device, str)):
            device = None
        tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

        eos_token_id = getattr(self.tokenizer, "eos_token_id", None)

        for _ in range(max_tokens_generated):
            # Logits for the position following the last token of the only batch row.
            logits = self.model(tokens).logits[0, -1]
            next_token = self.sample_next_token(tokens[0], logits, **kwargs)

            # Append the generated token
            appended = torch.tensor([[next_token]], dtype=tokens.dtype, device=tokens.device)
            tokens = torch.cat((tokens, appended), dim=1)

            # Break if EOS token generated
            if next_token == eos_token_id:
                break

        return self.tokenizer.decode(tokens[0])

    @staticmethod
    def _seed_rngs(seed) -> None:
        '''
        Seeds the torch and numpy global RNGs; does nothing when seed is None.
        '''
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

    @staticmethod
    def sample_next_token(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None
    ):
        '''
        Chooses the next token id (a Python int) from next-token logits.

        Args:
            input_ids: 1D tensor of the token ids generated so far.
            logits: 1D tensor of next-token logits over the vocabulary.
            temperature: 0 selects greedy decoding; other values rescale logits.
            top_k: If > 0, sample only among the k most likely tokens.
            top_p: If > 0, sample only among the smallest set of most likely
                tokens whose cumulative probability reaches top_p.
            frequency_penalty: Amount subtracted from a token's logit for every
                previous occurrence of that token in input_ids.
            seed: If given, the torch and numpy RNGs are re-seeded before sampling.

        Greedy decoding returns before the frequency penalty is applied.
        '''
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert logits.ndim == 1, "logits should be a 1D vector over the vocabulary"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

        # Set random seeds for reproducibility
        TransformerSampler._seed_rngs(seed)

        # Apply all the specialized sampling methods
        if temperature == 0:
            return TransformerSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(input_ids, logits, frequency_penalty)
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)


    @staticmethod
    def greedy_search(
        logits: Float[Tensor, "d_vocab"]
    ) -> int:
        '''
        Returns the most likely token (as an int).
        '''
        return int(logits.argmax().item())


    @staticmethod
    def apply_temperature(
        logits: Float[Tensor, "d_vocab"],
        temperature: float
    ) -> Float[Tensor, "d_vocab"]:
        '''
        Applies temperature scaling to the logits.

        Temperature must be strictly positive here; a temperature of zero is
        handled by sample_next_token via greedy_search.
        '''
        assert temperature > 0, "Temperature must be positive when scaling logits"
        return logits / temperature

    @staticmethod
    def apply_frequency_penalty(
        input_ids: Int[Tensor, "seq_len"],
        logits: Float[Tensor, "d_vocab"],
        freq_penalty: float
    ) -> Float[Tensor, "d_vocab"]:
        '''
        Applies a frequency penalty to the logits.

        Each token's logit is reduced by freq_penalty times the number of
        times that token already occurs in input_ids.
        '''
        # minlength pads the counts out to the vocabulary size, so tokens that
        # never occurred (including ids above the largest one seen) count as 0.
        freqs = torch.bincount(input_ids, minlength=logits.shape[-1])
        # Match the logits' dtype/device so the result keeps the logits' dtype.
        return logits - freq_penalty * freqs.to(logits)

    @staticmethod
    def sample_basic(
        logits: Float[Tensor, "d_vocab"]
    ) -> int:
        '''
        Samples from the distribution defined by the logits (returns an int).
        '''
        distribution = torch.distributions.categorical.Categorical(logits=logits)
        return int(distribution.sample().item())

    @staticmethod
    def sample_top_k(
        logits: Float[Tensor, "d_vocab"],
        k: int
    ) -> int:
        '''
        Samples from the top k most likely tokens (returns an int).

        k is clamped to the vocabulary size.
        '''
        k = min(k, logits.shape[-1])
        top_logits, top_indices = torch.topk(logits, k)
        # Sample a position within the top-k set, then map back to a vocab id.
        position = torch.distributions.categorical.Categorical(logits=top_logits).sample()
        return int(top_indices[position].item())

    @staticmethod
    def sample_top_p(
        logits: Float[Tensor, "d_vocab"],
        top_p: float,
        min_tokens_to_keep: int = 1
    ) -> int:
        '''
        Samples from the most likely tokens which make up at least p cumulative
        probability (returns an int).

        The kept set is the shortest prefix of the tokens, sorted by descending
        probability, whose cumulative probability equals or exceeds top_p
        (the token that crosses the threshold is included). At least
        min_tokens_to_keep tokens (and never fewer than one) are kept.
        '''
        # Work in float32 so the cumulative sum is not degraded by half precision.
        probs = logits.float().softmax(dim=-1)

        # Sort the probabilities from largest to smallest
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=0)

        # Number of tokens strictly below the threshold, plus the one crossing it.
        n_keep = int((cumulative < top_p).sum().item()) + 1
        n_keep = max(n_keep, min_tokens_to_keep, 1)
        # Rounding can leave the full cumulative sum just under top_p (e.g. top_p=1.0).
        n_keep = min(n_keep, probs.shape[-1])

        # Selecting by rank (not by probability value) keeps exactly n_keep
        # tokens even when several tokens tie at the cutoff probability.
        # Categorical renormalises the kept probabilities itself.
        kept = torch.distributions.categorical.Categorical(probs=sorted_probs[:n_keep])
        return int(sorted_indices[kept.sample()].item())


In [25]:
from model import Transformer

# Load the base OpenELM-270M checkpoint through the Transformer class imported from model.py.
model = Transformer(
    "apple/OpenELM-270M",
    trust_remote_code=True,
)

In [38]:
# Hand the loaded object's underlying model and tokenizer to the sampler.
sampler = TransformerSampler(model.model, model.tokenizer)

In [48]:
prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Greedy decoding with prompt: {prompt!r}\n")

# temperature=0.0 makes the sampler take its greedy (argmax) branch.
output = sampler.sample(
    prompt,
    max_tokens_generated=64,
    temperature=0.0,
)
print(f"Your model said: {output!r}\n")

Greedy decoding with prompt: 'Jingle bells, jingle bells, jingle all the way'

Your model said: "<s> Jingle bells, jingle bells, jingle all the way!\nThe holiday season is upon us and it's time to get your holiday shopping done.\nIf you're like me, you're probably looking for a new pair of shoes to wear to work or to the gym.\nI'm not a big fan of shoes"

