[Example] Add an example of running GPT-2 model #203

Merged 5 commits on Apr 30, 2023

Empty file added examples/README.md
Empty file.
41 changes: 41 additions & 0 deletions examples/gpt-2/README.md
@@ -0,0 +1,41 @@
# GPT-2 Demo

This example demonstrates how to use Hidet to compile and run a GPT-2 model.
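
At a high level, the example builds the GPT-2 forward pass with hidet operators, traces it into a `FlowGraph`, optimizes the graph, and runs it on the GPU (optionally through a CUDA graph). A minimal sketch of that flow, with a toy embedding lookup standing in for the real forward pass, looks like:

```python
import hidet
import numpy as np

# symbolic int32 token ids on the GPU (hypothetical sequence length of 8)
ids = hidet.symbol([8], dtype='int32', device='cuda')

# a toy "model": a single embedding lookup in place of the full GPT-2 forward pass
wte = hidet.asarray(np.random.rand(50257, 768).astype('float32')).cuda()
logits = hidet.ops.take(wte, ids)

graph = hidet.trace_from(logits, inputs=[ids])   # trace into a FlowGraph
graph_opt = hidet.graph.optimize(graph)          # apply graph-level optimizations
```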

## Requirements

This example requires a nightly build of hidet until the next version (v0.2.4) is released to PyPI.
Run the following commands under the `examples/gpt-2` directory to install the required packages:
```console
$ pip install --pre --extra-index-url https://download.hidet.org/whl hidet
$ pip install -r requirements.txt
```

## Usage

```bash
$ python main.py
>>> Alan Turing theorized that computers would one day become
generating: 100%|██████████████████████████████| 30/30 [00:00<00:00, 128.30it/s]
Alan Turing theorized that computers would one day become the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations
```

## Configs
```bash
Usage: main.py [OPTIONS]

Options:
--max-num-tokens INTEGER Max number of total tokens to process and
generate [default: 40]
--use-fp16 Use fp16
--model-size [124M|355M|774M|1558M]
[default: 124M]
--tune Tune the operators for better performance.
May take several minutes.
--help Show this message and exit.
```
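
For instance, to generate up to 60 tokens from the 355M checkpoint in fp16 (a hypothetical combination of the options above):

```bash
$ python main.py --model-size 355M --use-fp16 --max-num-tokens 60
```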

## Acknowledgements
We referred to [picoGPT](https://github.com/jaymody/picoGPT)'s clean and simple implementation of GPT-2.
122 changes: 122 additions & 0 deletions examples/gpt-2/encoder.py
@@ -0,0 +1,122 @@
"""Byte pair encoding utilities.

Copied from: https://github.com/openai/gpt-2/blob/master/src/encoder.py.
"""
import json
import os
from functools import lru_cache

import regex as re


@lru_cache()
def bytes_to_unicode():
"""
Returns a mapping between utf-8 bytes and corresponding unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings that avoid mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
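
# Illustration (not part of the upstream file): printable bytes map to themselves,
# while bytes outside the printable range are shifted into unused code points, e.g.
#   bytes_to_unicode()[ord("A")] == "A"
#   bytes_to_unicode()[ord(" ")] == "Ġ"   # the familiar GPT-2 space marker (U+0120)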


def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
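
# Illustration (not part of the upstream file):
#   get_pairs(("l", "o", "w")) == {("l", "o"), ("o", "w")}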


class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}

# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)

if not pairs:
return token

while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:  # `first` no longer occurs in the remainder of the word
new_word.extend(word[i:])
break

if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word

def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens

def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text


def get_encoder(model_name="124M"):
import hidet
models_dir = hidet.utils.hidet_cache_dir("./examples/gpt-2")
with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
encoder = json.load(f)
with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
return Encoder(encoder=encoder, bpe_merges=bpe_merges)
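
# Usage sketch (assumes the encoder.json / vocab.bpe files for the chosen model are
# already present in the hidet cache directory, e.g. after the weights were downloaded):
#   enc = get_encoder("124M")
#   ids = enc.encode("Alan Turing theorized")
#   assert enc.decode(ids) == "Alan Turing theorized"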
109 changes: 109 additions & 0 deletions examples/gpt-2/gpt_model.py
@@ -0,0 +1,109 @@
from typing import Optional
import os
import numpy as np
import hidet
from hidet import ops
from hidet import FlowGraph
from utils import load_params


def linear(x, w, b): # [m, in], [in, out], [out] -> [m, out]
w = hidet.asarray(w).cuda()
b = hidet.asarray(b).cuda()
return x @ w + b


def layer_norm(x, g, b, eps: float = 1e-5):
g = hidet.asarray(g).cuda()
b = hidet.asarray(b).cuda()
x = ops.layer_norm(x, epsilon=eps)
return g * x + b


def ffn(x, c_fc, c_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
# project up
a = ops.gelu(linear(x, **c_fc)) # [n_seq, n_embd] -> [n_seq, 4*n_embd]

# project back down
x = linear(a, **c_proj) # [n_seq, 4*n_embd] -> [n_seq, n_embd]

return x


def attention(q, k, v, mask): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
return ops.softmax(q @ ops.transpose(k, [-1, -2]) / float(np.sqrt(q.shape[-1])) + mask, axis=-1) @ v
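
# Note (added for clarity): attention() above is standard scaled dot-product attention,
# softmax(Q K^T / sqrt(d_k) + mask) V; the additive causal mask holds -1e10 above the
# diagonal so that future positions receive (near-)zero attention weight.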


def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
n_seq = x.shape[0]
causal_mask = hidet.asarray((1 - np.tri(x.shape[0])) * -1e10, dtype=x.dtype).cuda() # [n_seq, n_seq]
x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, n_embd * 3]

# [n_seq, n_embd * 3] -> [n_seq, 3, n_head, n_embd // n_head]
x = ops.reshape(x, [x.shape[0], 3, n_head, x.shape[1] // (3 * n_head)])
# [n_seq, 3, n_head, n_embd // n_head] -> [3, n_head, n_seq, n_embd // n_head]
x = ops.transpose(x, [1, 2, 0, 3])
# [3, n_head, n_seq, n_embd // n_head] -> 3 x [1, n_head, n_seq, n_embd // n_head]
q, k, v = [t for t in ops.split(x, 3, axis=0)]

# impl 1:
# o = ops.attention(q, k, v, causal_mask) # [1, n_head, n_seq, n_embed // n_head]

# impl 2:
o = attention(q, k, v, causal_mask)

o = ops.rearrange(o, [[2], [0, 1, 3]])
o = linear(o, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]

return o


def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
# multi-head causal self attention
x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]

# position-wise feed forward network
x = x + ffn(layer_norm(x, **ln_2), **mlp) # [n_seq, n_embd] -> [n_seq, n_embd]

return x


def gpt2_forward(ids, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
wte = hidet.asarray(wte).cuda()
wpe = hidet.asarray(wpe).cuda()

# [n_seq] -> [n_seq, n_embd]
x = hidet.ops.take(wte, ids) + hidet.ops.take(wpe, hidet.ops.arange(ids.shape[0]).cuda())

for block in blocks:
x = transformer_block(x, **block, n_head=n_head)

x = layer_norm(x, **ln_f)
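# weight tying: reuse the token embedding matrix wte as the output projection,
# mapping [n_seq, n_embd] -> [n_seq, n_vocab] logits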
x = ops.matmul(x, ops.transpose(wte))
return x


def gpt2(model_size: str = "124M", seq_length: Optional[int] = 1000, use_fp16=False) -> FlowGraph:
cache_dir = hidet.utils.hidet_cache_dir('./examples/gpt-2/')
model_name = 'model_{}_seq{}_{}.hf'.format(model_size, seq_length, 'fp16' if use_fp16 else 'fp32')
hf_path = os.path.join(cache_dir, model_name)
if os.path.exists(hf_path):
return hidet.load_graph(hf_path)
else:
print("Building hidet graph for GPT-2 ({}) with sequence length {}".format(model_size, seq_length))
hparams, params = load_params(model_size, models_dir=cache_dir)
if seq_length > hparams["n_ctx"]:
raise ValueError(f"seq_length should be less than or equal to {hparams['n_ctx']}")

ids = hidet.symbol([seq_length], dtype='int32', device='cuda')
out = gpt2_forward(ids, **params, n_head=hparams["n_head"])
graph = hidet.trace_from(out, inputs=[ids])
with hidet.graph.PassContext() as ctx:
if use_fp16:
ctx.set_precision('float16')
ctx.set_mma('mma')
graph_opt = hidet.graph.optimize(graph)

hidet.save_graph(graph_opt, hf_path)
return graph_opt
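
# Usage sketch: gpt2() builds (or loads from cache) an optimized FlowGraph whose single
# input is a [seq_length] int32 tensor of token ids and whose output is the
# [seq_length, n_vocab] logits; main.py wraps it in a CUDA graph via graph.cuda_graph().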
66 changes: 66 additions & 0 deletions examples/gpt-2/main.py
@@ -0,0 +1,66 @@
from typing import List

import click

import hidet
from tqdm import tqdm
import torch
from hidet import FlowGraph
from gpt_model import gpt2
from encoder import get_encoder

hidet.option.search_space(0)

bucket_size = 50


class GPT2Generator:
def __init__(self, max_num_tokens, model_size, use_fp16):
import hidet.cuda.graph

graph: FlowGraph = gpt2(seq_length=max_num_tokens, model_size=model_size, use_fp16=use_fp16)
self.cuda_graph: hidet.cuda.graph.CudaGraph = graph.cuda_graph()
self.encoder = get_encoder()
self.max_num_tokens = max_num_tokens

# get the torch view for the two hidet tensors
self.input_ids = self.cuda_graph.inputs[0].torch()
self.logits = self.cuda_graph.outputs[0].torch()

def __call__(self, text: str) -> str:
ids: List[int] = self.encoder.encode(text)
num_init_tokens = len(ids)
if num_init_tokens > self.max_num_tokens:
return text

self.input_ids[:num_init_tokens] = torch.asarray(ids)

for i in tqdm(range(num_init_tokens, self.max_num_tokens), "generating", ncols=80):
self.cuda_graph.run()
next_token: int = torch.argmax(self.logits[i - 1], dim=-1).item()
self.input_ids[i] = next_token

output_ids = self.input_ids[num_init_tokens:].cpu().tolist()
output_text = self.encoder.decode(output_ids)
return output_text


@click.command()
@click.option("--max-num-tokens", default=40, type=int, help='Max number of total tokens to process and generate',
show_default=True)
@click.option("--use-fp16", is_flag=True, default=False, help='Use fp16', show_default=True)
@click.option("--model-size", default="124M", type=click.Choice(['124M', '355M', '774M', '1558M']), show_default=True)
@click.option("--tune", is_flag=True, default=False,
help='Tune the operators for better performance. May take several minutes.', show_default=True)
def main(max_num_tokens: int, use_fp16: bool, model_size: str, tune: bool):
if tune:
hidet.option.search_space(2)
generator = GPT2Generator(max_num_tokens, model_size, use_fp16)
while True:
x = click.prompt(">>> ", type=str, prompt_suffix="")
response = generator(x)
click.echo(x + response)


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions examples/gpt-2/requirements.txt
@@ -0,0 +1,4 @@
regex
requests
tensorflow==2.11.0
hidet