In [1]:
from abc import ABC, abstractmethod
from typing import List, Union, Tuple, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer, PretrainedConfig
import abc
import torch
from typing import Callable

# Element is the base class for all elements in the prompt cache.
# It defines the basic interface for all elements.
class Element(ABC):
    """Abstract piece of a prompt that occupies a contiguous range of positions.

    ``offset`` is the first position of the element; the range it covers is
    ``[offset, offset + len(element))``. Concrete subclasses report their
    length and expose the token ids and position ids they contribute.
    """

    name: Union[None, str]
    offset: int

    def __init__(self, offset: int, name: Optional[str] = None):
        self.offset = offset
        self.name = name

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of this element."""
        raise NotImplementedError

    def __repr__(self):
        start = self.offset
        stop = start + len(self)
        return f"[{start}:{stop}]"

    @abstractmethod
    def token_ids(self) -> List[int]:
        """Return the token ids contributed by this element."""
        raise NotImplementedError

    @abstractmethod
    def position_ids(self) -> List[int]:
        """Return the position ids assigned to this element's tokens."""
        raise NotImplementedError


class LanguageModel(ABC):
    """Wrapper bundling a Hugging Face model with its tokenizer.

    Subclasses must implement ``get_formatter``. The ``store_*_hook`` and
    ``read_*_hook`` methods are identity functions here; presumably subclasses
    override them to transform key/value tensors on their way into / out of a
    cache — TODO confirm against the cache implementation.
    """

    name: str
    hf_tokenizer: PreTrainedTokenizer
    hf_model: PreTrainedModel
    stop_token_ids: List[int]
    stop_str: List[str]
    use_full_position_ids: bool = False

    def __init__(self, name: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
                 stop_token_ids: Optional[List[int]] = None, stop_str: Optional[List[str]] = None):
        self.name = name
        self.hf_tokenizer = tokenizer
        self.hf_model = model
        # Default stop condition is the tokenizer's EOS id; the tokenizer must be
        # attached first because the eos_token_id property reads it.
        if stop_token_ids is None:
            stop_token_ids = [self.eos_token_id]
        self.stop_token_ids = stop_token_ids
        if stop_str is None:
            stop_str = []
        self.stop_str = stop_str

    @abc.abstractmethod
    def get_formatter(self) -> Callable[[str], str]:
        """Return a callable that formats a prompt string for this model."""
        pass

    def get_cache_shape(self) -> Tuple[int, int, int]:
        """Return ``(num_hidden_layers, num_attention_heads, head_dim)`` from the config."""
        cfg = self.config
        heads = cfg.num_attention_heads
        dim_per_head = cfg.hidden_size // heads
        return cfg.num_hidden_layers, heads, dim_per_head

    def store_k_hook(self, k_cache: torch.Tensor) -> torch.Tensor:
        """Hook applied to key tensors being stored; identity by default."""
        return k_cache

    def store_v_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
        """Hook applied to value tensors being stored; identity by default."""
        return v_cache

    def read_k_hook(self, k_cache: torch.Tensor) -> torch.Tensor:
        """Hook applied to key tensors being read; identity by default."""
        return k_cache

    def read_v_hook(self, v_cache: torch.Tensor) -> torch.Tensor:
        """Hook applied to value tensors being read; identity by default."""
        return v_cache

    def __call__(self, **kwargs):
        """Forward all keyword arguments straight to the underlying HF model."""
        return self.hf_model(**kwargs)

    def encode(self, text: str) -> List[int]:
        """Tokenize ``text`` without adding special tokens.

        ``add_special_tokens=False`` keeps the tokenizer from inserting markers
        such as BOS, so the result is only the text's own tokens.
        """
        return self.hf_tokenizer.encode(text, add_special_tokens=False)

    def decode(self, token_ids: List[int]) -> str:
        """Turn ``token_ids`` back into text, keeping any special tokens verbatim."""
        return self.hf_tokenizer.decode(
            token_ids,
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
        )

    @property
    def unk_token(self) -> str:
        """The tokenizer's unknown-token string."""
        return self.hf_tokenizer.unk_token

    @property
    def unk_token_id(self) -> int:
        """The tokenizer's unknown-token id."""
        return self.hf_tokenizer.unk_token_id

    @property
    def eos_token(self) -> str:
        """The tokenizer's end-of-sequence token string."""
        return self.hf_tokenizer.eos_token

    @property
    def eos_token_id(self) -> int:
        """The tokenizer's end-of-sequence token id."""
        return self.hf_tokenizer.eos_token_id

    @property
    def device(self) -> torch.device:
        """Device reported by the wrapped model."""
        return self.hf_model.device

    @property
    def config(self) -> PretrainedConfig:
        """Configuration object of the wrapped model."""
        return self.hf_model.config

class TokenSequence(Element):
    """A run of text tokenized by ``lm`` and placed at consecutive positions.

    Args:
        offset: Position id of the first token.
        text: Raw text, tokenized with ``lm.encode``.
        lm: Language model whose tokenizer produces the token ids.
        max_tokens: If given and the text encodes to more than this many
            tokens, keep the first ``max_tokens // 2`` and the last
            ``max_tokens - max_tokens // 2`` tokens and drop the middle.
            Sequences that already fit are left unchanged.
        name: Optional element name, forwarded to ``Element``.

    Raises:
        ValueError: If ``max_tokens`` is negative.
    """

    text: str
    # Stored under underscored names: the plain names are the accessor methods
    # required by Element, and an instance attribute would shadow them.
    _token_ids: List[int]
    _position_ids: List[int]

    def __init__(self, offset: int, text: str, lm: LanguageModel, max_tokens: Optional[int] = None,
                 name: Optional[str] = None):
        super().__init__(offset, name)
        self.text = text
        token_ids = lm.encode(text)

        if max_tokens is not None:
            if max_tokens < 0:
                raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")
            # Only truncate when needed; otherwise head + tail would duplicate tokens.
            if len(token_ids) > max_tokens:
                head = max_tokens // 2
                tail = max_tokens - head
                # Explicit start index: token_ids[-0:] would return the whole list.
                token_ids = token_ids[:head] + token_ids[len(token_ids) - tail:]

        self._token_ids = token_ids
        self._position_ids = list(range(self.offset, self.offset + len(token_ids)))

    def __len__(self) -> int:
        """Number of tokens kept after any truncation."""
        return len(self._token_ids)

    def __repr__(self):
        return f"TokenSequence({self.text})"

    def token_ids(self) -> List[int]:
        """Return the (possibly truncated) token ids."""
        return self._token_ids

    def position_ids(self) -> List[int]:
        """Return consecutive position ids starting at ``offset``."""
        return self._position_ids

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Smoke test: build a TokenSequence with the GPT-2 tokenizer.
# LanguageModel is abstract (get_formatter), so it cannot be instantiated directly,
# and its __init__ reads eos_token_id from the tokenizer, so a real tokenizer is needed.
# TokenSequence only calls lm.encode, so no model is loaded (model=None);
# device/config/__call__ would fail on this demo instance.
from transformers import AutoTokenizer


class _DemoLM(LanguageModel):
    """Minimal concrete LanguageModel whose formatter returns text unchanged."""

    def get_formatter(self) -> Callable[[str], str]:
        return lambda text: text


_demo_lm = _DemoLM("gpt2", None, AutoTokenizer.from_pretrained("gpt2"))
TokenSequence(0, "Hello, world!", _demo_lm)