<a href="https://colab.research.google.com/github/calder-rh/literate-text-encoding/blob/main/code/encoding_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install bitarray

import sys, os
if "google.colab" in sys.modules:
    repo_url = "https://github.com/calder-rh/literate-text-encoding"
    repo_path = "/content/literate-text-encoding"

    if os.path.exists(repo_path):
        os.chdir(repo_path)
        !git pull origin main
    else:
        !git clone {repo_url}
      
    os.chdir(repo_path)
    code_path = '/content/literate-text-encoding/code'
    if code_path not in sys.path:
        sys.path.append(code_path)



In [10]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn.functional as F
from translator import PredictionModel, BitsToText, TextToBits, TokenProbability

# Load model & tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [13]:
class GPT2Translator(PredictionModel):
    def _predict_next(self, text: str):
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
        last_token_logits = logits[0, -1, :]
        probs = F.softmax(last_token_logits, dim=-1)

        vocab_size = probs.shape[0]
        token_probs = [
            TokenProbability(tokenizer.decode([i]), float(probs[i]))
            for i in range(vocab_size)
        ]
        token_probs.sort(key=lambda x: (x.prob, x.text))

        return token_probs

    def _predict_start(self) -> list[TokenProbability]:
        return self._predict_next('<|endoftext|>')

In [14]:
ttb = TextToBits(GPT2Translator(), 'Hello world, this is a sample!')
print(ttb.translate())

AssertionError: 