In [1]:
# Environment check. Each `!` line is handed to the system shell, so these report
# on whatever `python` / `nvidia-smi` the shell resolves on PATH.
# Path of the `python` executable:
!which python
# Version reported by that executable:
!python --version
# NVIDIA GPU status (the shell prints "command not found" when the tool is absent):
!nvidia-smi

/Users/nadavt/opt/anaconda3/envs/jaymody-sps-env/bin/python
Python 3.9.10
zsh:1: command not found: nvidia-smi


In [2]:
# !python -m torch.utils.collect_env

In [3]:
pip install -q torch transformers

Note: you may need to restart the kernel to use updated packages.


In [13]:
# Load model & Sanity check
# Loads the draft ("gpt2") and target ("gpt2-medium") tokenizers and models, checks
# that both tokenizers encode the prompt identically, and runs one forward pass.
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

# Single source of truth for the prompt so both tokenizers always see the same text.
PROMPT = "Hello, my dog is cute"

draft_tokenizer = AutoTokenizer.from_pretrained("gpt2")
target_tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

draft_model = GPT2LMHeadModel.from_pretrained("gpt2")
target_model = GPT2LMHeadModel.from_pretrained("gpt2-medium")

draft_inputs = draft_tokenizer(PROMPT, return_tensors="pt")
target_inputs = target_tokenizer(PROMPT, return_tensors="pt")

# torch.equal compares shape as well as values, so a length mismatch fails the
# assertion cleanly instead of raising (or broadcasting) inside `==`.
assert torch.equal(draft_inputs.input_ids, target_inputs.input_ids), \
    "draft and target tokenizers produced different input_ids"
assert torch.equal(draft_inputs.attention_mask, target_inputs.attention_mask), \
    "draft and target tokenizers produced different attention masks"

# Inference only: skip autograd bookkeeping for these forward passes.
with torch.no_grad():
    draft_outputs = draft_model(**draft_inputs, labels=draft_inputs["input_ids"])
    target_outputs = target_model(**target_inputs, labels=target_inputs["input_ids"])


loss = draft_outputs.loss
logits = draft_outputs.logits
# Displayed as the cell result.
logits.shape

torch.Size([1, 6, 50257])

In [21]:
# Generate
# Decode 20 new tokens with each model without sampling, then compare the two
# sequences position by position and show the decoded text.

# One shared set of arguments so the draft and target calls cannot drift apart.
generation_kwargs = dict(
    max_new_tokens=20,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=False,
)

# pad_token_id is passed explicitly (using each tokenizer's EOS id) so generate()
# does not have to fall back to a default and log a notice on every call.
draft_outputs = draft_model.generate(
    **draft_inputs, pad_token_id=draft_tokenizer.eos_token_id, **generation_kwargs
)
target_outputs = target_model.generate(
    **target_inputs, pad_token_id=target_tokenizer.eos_token_id, **generation_kwargs
)
print("draft_outputs.sequences == target_outputs.sequences:")
print(draft_outputs.sequences == target_outputs.sequences)

print("draft_outputs.sequences:")
print(draft_tokenizer.batch_decode(draft_outputs.sequences))

print("target_outputs.sequences:")
print(target_tokenizer.batch_decode(target_outputs.sequences))

# print("draft_outputs:")
# print(draft_outputs)
# print("====")
# transition_scores = draft_model.compute_transition_scores(
#     draft_outputs.sequences, draft_outputs.scores, normalize_logits=True
# )
# print("transition_scores:")
# print(transition_scores)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


draft_outputs.sequences == target_outputs.sequences:
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False]])
draft_outputs.sequences:
["Hello, my dog is cute. I'm not sure if she's a puppy or not. I'm not sure if she's"]
target_outputs.sequences:
["Hello, my dog is cute. I'm going to take him to the park today. I'm going to take him to the"]


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False]])

In [None]:
from typing import Any
from torch import Tensor
from collections.abc import Iterable


def print_shapes(obj: Any, level: int = 0):
    """Recursively print the shape of every tensor nested inside ``obj``.

    Each printed line is prefixed with ``"--"`` repeated ``level`` times.

    Handling by type, checked in this order:
      * ``Tensor``: its ``shape`` is printed.
      * any mapping (``dict`` and other ``collections.abc.Mapping`` types such as
        ``collections.UserDict`` subclasses): each key is printed, then its value
        is visited one level deeper.
      * ``str``: reported as unsupported (it is never iterated character by
        character).
      * any other iterable: each element is visited one level deeper.
      * anything else: reported as unsupported.

    Args:
        obj: The object to inspect.
        level: Current nesting depth; callers normally leave this at 0.

    Returns:
        None. All results are written to stdout.
    """
    # Imported locally so the function only needs the names the surrounding
    # cell already imports.
    from collections.abc import Mapping

    padding = "--" * level
    if isinstance(obj, Tensor):
        print(f"{padding}{obj.shape}")
    elif isinstance(obj, Mapping):
        # Mapping (not just dict) so dict-like containers are walked by key/value
        # instead of falling through to the generic branch, which would iterate
        # over their keys only.
        for k, v in obj.items():
            print(f"{padding}{k}")
            print_shapes(v, level + 1)
    elif isinstance(obj, str):
        # Must precede the Iterable check: iterating a str yields more str objects.
        print(f"{padding}type {type(obj)} is not supported")
    elif isinstance(obj, Iterable):
        for x in obj:
            print_shapes(x, level + 1)
    else:
        print(f"{padding}type {type(obj)} is not supported")


# # Test cases
# print_shapes(["test", Tensor([1, 2, 3]), {"key": Tensor([4, 5, 6])}])
# print("====")
# print_shapes("This is a string")
# print("====")
# print_shapes((1, 2, Tensor([7, 8, 9]), {"nested": [Tensor([10, 11, 12]), "string"]}))


In [None]:
import transformers

# Check the tokenizer's return type, using the class as exported at the package top level.
assert isinstance(draft_inputs, transformers.BatchEncoding)

In [None]:
# Walk the generate() result and print the shape of every tensor found inside it.
print_shapes(draft_outputs)

In [None]:
# Container type of the generation scores and how many entries it holds.
generation_scores = draft_outputs.scores
type(generation_scores), len(generation_scores)

In [None]:
# Shape and value range of the last entry in the generation scores.
last_scores = draft_outputs.scores[-1]
print(last_scores.shape)
print(last_scores.min())
print(last_scores.max())

In [None]:
# Decode the draft model's full token sequences (prompt included) back into text.
draft_tokenizer.batch_decode(draft_outputs.sequences)

In [None]:
import numpy as np


# Per-token scores of the tokens the draft model actually generated. They are
# computed in this cell because the only other computation of `transition_scores`
# in the notebook is commented out, so the name would otherwise be undefined.
# normalize_logits=True asks for normalized scores, i.e. log-probabilities, which
# is why np.exp() below turns them into probabilities.
transition_scores = draft_model.compute_transition_scores(
    draft_outputs.sequences, draft_outputs.scores, normalize_logits=True
)

# input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
# encoder-decoder models, like BART or T5.
input_length = 1 if draft_model.config.is_encoder_decoder else draft_inputs.input_ids.shape[1]
# Drop the prompt so only newly generated tokens remain.
generated_tokens = draft_outputs.sequences[:, input_length:]
for tok, score in zip(generated_tokens[0], transition_scores[0]):
    # .item() yields a plain Python float regardless of the tensor's device.
    log_prob = score.item()
    # | token | token string | log probability | probability
    print(f"| {tok.item():5d} | {draft_tokenizer.decode(tok):8s} | {log_prob:.3f} | {np.exp(log_prob):.2%}")