[Example] Add an example of running GPT-2 model (#203)
* add example gpt-2
* format
* update readme
* .
* .
1 parent 3142734, commit 735eca2
Showing 16 changed files with 484 additions and 18 deletions. One of the new files is empty.
examples/gpt-2/README.md

# GPT-2 Demo

This example demonstrates how to use Hidet to compile and run a GPT-2 model.

## Requirements

Until the next version (v0.2.4) is released to PyPI, this example requires a nightly build of hidet.
Run the following commands under the `examples/gpt-2` directory to install the required packages:

```console
$ pip install --pre --extra-index-url https://download.hidet.org/whl hidet
$ pip install -r requirements.txt
```

## Usage

```bash
$ python main.py
>>> Alan Turing theorized that computers would one day become
generating: 100%|██████████████████████████████| 30/30 [00:00<00:00, 128.30it/s]
Alan Turing theorized that computers would one day become the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations
```

## Configs

```bash
Usage: main.py [OPTIONS]

Options:
  --max-num-tokens INTEGER        Max number of total tokens to process and
                                  generate  [default: 40]
  --use-fp16                      Use fp16
  --model-size [124M|355M|774M|1558M]
                                  [default: 124M]
  --tune                          Tune the operators for better performance.
                                  May take several minutes.
  --help                          Show this message and exit.
```

## Acknowledgements

We referred to [picoGPT](https://github.com/jaymody/picoGPT)'s clean and simple implementation of GPT-2.
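For programmatic use (rather than the interactive prompt), here is a minimal sketch built on the `GPT2Generator` class defined in `main.py` below; it assumes a CUDA device is available and that the model weights and vocabulary files are already in the hidet cache:

```python
# Minimal sketch, assuming CUDA is available and the weights are cached.
# GPT2Generator comes from examples/gpt-2/main.py, shown later in this commit.
from main import GPT2Generator

generator = GPT2Generator(max_num_tokens=40, model_size="124M", use_fp16=False)
completion = generator("Alan Turing theorized that computers would one day become")
print(completion)  # the generated continuation; the prompt itself is not echoed
```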
examples/gpt-2/encoder.py

"""Byte pair encoding utilities. | ||
Copied from: https://github.com/openai/gpt-2/blob/master/src/encoder.py. | ||
""" | ||
import json | ||
import os | ||
from functools import lru_cache | ||
|
||
import regex as re | ||
|
||
|
||
@lru_cache() | ||
def bytes_to_unicode(): | ||
""" | ||
Returns list of utf-8 byte and a corresponding list of unicode strings. | ||
The reversible bpe codes work on unicode strings. | ||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | ||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | ||
This is a significant percentage of your normal, say, 32K bpe vocab. | ||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | ||
And avoids mapping to whitespace/control characters the bpe code barfs on. | ||
""" | ||
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) | ||
cs = bs[:] | ||
n = 0 | ||
for b in range(2**8): | ||
if b not in bs: | ||
bs.append(b) | ||
cs.append(2**8 + n) | ||
n += 1 | ||
cs = [chr(n) for n in cs] | ||
return dict(zip(bs, cs)) | ||
|
||
|
||
def get_pairs(word): | ||
"""Return set of symbol pairs in a word. | ||
Word is represented as tuple of symbols (symbols being variable-length strings). | ||
""" | ||
pairs = set() | ||
prev_char = word[0] | ||
for char in word[1:]: | ||
pairs.add((prev_char, char)) | ||
prev_char = char | ||
return pairs | ||
|
||
|
||
class Encoder: | ||
def __init__(self, encoder, bpe_merges, errors="replace"): | ||
self.encoder = encoder | ||
self.decoder = {v: k for k, v in self.encoder.items()} | ||
self.errors = errors # how to handle errors in decoding | ||
self.byte_encoder = bytes_to_unicode() | ||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | ||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | ||
self.cache = {} | ||
|
||
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | ||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | ||
|
||
def bpe(self, token): | ||
if token in self.cache: | ||
return self.cache[token] | ||
word = tuple(token) | ||
pairs = get_pairs(word) | ||
|
||
if not pairs: | ||
return token | ||
|
||
while True: | ||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | ||
if bigram not in self.bpe_ranks: | ||
break | ||
first, second = bigram | ||
new_word = [] | ||
i = 0 | ||
while i < len(word): | ||
try: | ||
j = word.index(first, i) | ||
new_word.extend(word[i:j]) | ||
i = j | ||
except: | ||
new_word.extend(word[i:]) | ||
break | ||
|
||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second: | ||
new_word.append(first + second) | ||
i += 2 | ||
else: | ||
new_word.append(word[i]) | ||
i += 1 | ||
new_word = tuple(new_word) | ||
word = new_word | ||
if len(word) == 1: | ||
break | ||
else: | ||
pairs = get_pairs(word) | ||
word = " ".join(word) | ||
self.cache[token] = word | ||
return word | ||
|
||
def encode(self, text): | ||
bpe_tokens = [] | ||
for token in re.findall(self.pat, text): | ||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) | ||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) | ||
return bpe_tokens | ||
|
||
def decode(self, tokens): | ||
text = "".join([self.decoder[token] for token in tokens]) | ||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | ||
return text | ||
|
||
|
||
def get_encoder(model_name="124M"): | ||
import hidet | ||
models_dir = hidet.utils.hidet_cache_dir("./examples/gpt-2") | ||
with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f: | ||
encoder = json.load(f) | ||
with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f: | ||
bpe_data = f.read() | ||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] | ||
return Encoder(encoder=encoder, bpe_merges=bpe_merges) |
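As a quick illustration of the encoder above, a hypothetical round-trip; it assumes `encoder.json` and `vocab.bpe` for the 124M model are already present in the hidet cache directory that `get_encoder` reads from:

```python
from encoder import get_encoder

enc = get_encoder("124M")
ids = enc.encode("Alan Turing theorized that computers would one day become")
print(ids)  # a list of BPE token ids (ints)
# BPE encoding is lossless, so decoding recovers the original text
assert enc.decode(ids) == "Alan Turing theorized that computers would one day become"
```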
examples/gpt-2/gpt_model.py

import os
import numpy as np
import hidet
from hidet import ops
from hidet import FlowGraph
from utils import load_params


def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
    w = hidet.asarray(w).cuda()
    b = hidet.asarray(b).cuda()
    return x @ w + b


def layer_norm(x, g, b, eps: float = 1e-5):
    g = hidet.asarray(g).cuda()
    b = hidet.asarray(b).cuda()
    x = ops.layer_norm(x, epsilon=eps)
    return g * x + b


def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # project up
    a = ops.gelu(linear(x, **c_fc))  # [n_seq, n_embd] -> [n_seq, 4*n_embd]

    # project back down
    x = linear(a, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]

    return x


def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return ops.softmax(q @ ops.transpose(k, [-1, -2]) / float(np.sqrt(q.shape[-1])) + mask, axis=-1) @ v


def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
    n_seq = x.shape[0]
    causal_mask = hidet.asarray((1 - np.tri(n_seq)) * -1e10, dtype=x.dtype).cuda()  # [n_seq, n_seq]
    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, n_embd * 3]

    # [n_seq, n_embd * 3] -> [n_seq, 3, n_head, n_embd // n_head]
    x = ops.reshape(x, [x.shape[0], 3, n_head, x.shape[1] // (3 * n_head)])
    # [n_seq, 3, n_head, n_embd // n_head] -> [3, n_head, n_seq, n_embd // n_head]
    x = ops.transpose(x, [1, 2, 0, 3])
    # split into query, key and value, each of shape [1, n_head, n_seq, n_embd // n_head]
    q, k, v = ops.split(x, 3, axis=0)

    # two equivalent implementations:
    # 1) fused attention operator:
    # o = ops.attention(q, k, v, causal_mask)  # [1, n_head, n_seq, n_embd // n_head]
    # 2) explicit softmax attention (used here):
    o = attention(q, k, v, causal_mask)

    o = ops.rearrange(o, [[2], [0, 1, 3]])  # [1, n_head, n_seq, n_embd // n_head] -> [n_seq, n_embd]
    o = linear(o, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return o


def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
    # multi-head causal self-attention
    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)  # [n_seq, n_embd] -> [n_seq, n_embd]

    # position-wise feed-forward network
    x = x + ffn(layer_norm(x, **ln_2), **mlp)  # [n_seq, n_embd] -> [n_seq, n_embd]

    return x


def gpt2_forward(ids, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
    wte = hidet.asarray(wte).cuda()
    wpe = hidet.asarray(wpe).cuda()

    # token embeddings + positional embeddings: [n_seq] -> [n_seq, n_embd]
    x = ops.take(wte, ids) + ops.take(wpe, ops.arange(ids.shape[0]).cuda())

    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)

    x = layer_norm(x, **ln_f)
    # project back onto the vocabulary: [n_seq, n_embd] -> [n_seq, n_vocab]
    x = ops.matmul(x, ops.transpose(wte))
    return x


def gpt2(model_size: str = "124M", seq_length: int = 1000, use_fp16: bool = False) -> FlowGraph:
    cache_dir = hidet.utils.hidet_cache_dir('./examples/gpt-2/')
    model_name = 'model_{}_seq{}_{}.hf'.format(model_size, seq_length, 'fp16' if use_fp16 else 'fp32')
    hf_path = os.path.join(cache_dir, model_name)
    if os.path.exists(hf_path):
        return hidet.load_graph(hf_path)
    else:
        print("Building hidet graph for GPT-2 ({}) with sequence length {}".format(model_size, seq_length))
        hparams, params = load_params(model_size, models_dir=cache_dir)
        if seq_length > hparams["n_ctx"]:
            raise ValueError(f"seq_length should be less than or equal to {hparams['n_ctx']}")

        ids = hidet.symbol([seq_length], dtype='int32', device='cuda')
        out = gpt2_forward(ids, **params, n_head=hparams["n_head"])
        graph = hidet.trace_from(out, inputs=[ids])
        with hidet.graph.PassContext() as ctx:
            if use_fp16:
                ctx.set_precision('float16')
                ctx.set_mma('mma')
            graph_opt = hidet.graph.optimize(graph)

        hidet.save_graph(graph_opt, hf_path)
        return graph_opt
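The causal mask built in `mha` adds `-1e10` above the diagonal, so the softmax inside `attention` gives each query position (near-)zero weight on all later positions. A small self-contained NumPy sketch of the same computation (illustration only; the names here are hypothetical and independent of hidet):

```python
import numpy as np


def attention_np(q, k, v, mask):
    # softmax(q @ k^T / sqrt(d_k) + mask) @ v, mirroring attention() above
    scores = q @ k.T / np.sqrt(q.shape[-1]) + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # shift for numerical stability
    probs = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return probs @ v


n_seq, d_k = 4, 8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((n_seq, d_k)) for _ in range(3))
mask = (1 - np.tri(n_seq)) * -1e10  # same causal mask as in mha()
out = attention_np(q, k, v, mask)
# row i of the attention weights is (near-)zero for all j > i, so
# out[0] depends only on v[0], out[1] only on v[:2], and so on:
np.testing.assert_allclose(out[0], v[0], rtol=1e-6)
```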
examples/gpt-2/main.py

from typing import List

import click
import torch
from tqdm import tqdm

import hidet
from hidet import FlowGraph
from gpt_model import gpt2
from encoder import get_encoder

hidet.option.search_space(0)


class GPT2Generator:
    def __init__(self, max_num_tokens, model_size, use_fp16):
        import hidet.cuda.graph

        graph: FlowGraph = gpt2(seq_length=max_num_tokens, model_size=model_size, use_fp16=use_fp16)
        self.cuda_graph: hidet.cuda.graph.CudaGraph = graph.cuda_graph()
        self.encoder = get_encoder()
        self.max_num_tokens = max_num_tokens

        # torch views of the two hidet tensors (shared memory, no copies)
        self.input_ids = self.cuda_graph.inputs[0].torch()
        self.logits = self.cuda_graph.outputs[0].torch()

    def __call__(self, text: str) -> str:
        ids: List[int] = self.encoder.encode(text)
        num_init_tokens = len(ids)
        if num_init_tokens > self.max_num_tokens:
            return text

        self.input_ids[:num_init_tokens] = torch.asarray(ids)

        for i in tqdm(range(num_init_tokens, self.max_num_tokens), "generating", ncols=80):
            self.cuda_graph.run()
            next_token: int = torch.argmax(self.logits[i - 1], dim=-1).item()
            self.input_ids[i] = next_token

        output_ids = self.input_ids[num_init_tokens:].cpu().tolist()
        output_text = self.encoder.decode(output_ids)
        return output_text


@click.command()
@click.option("--max-num-tokens", default=40, type=int, help='Max number of total tokens to process and generate',
              show_default=True)
@click.option("--use-fp16", is_flag=True, default=False, help='Use fp16', show_default=True)
@click.option("--model-size", default="124M", type=click.Choice(['124M', '355M', '774M', '1558M']), show_default=True)
@click.option("--tune", is_flag=True, default=False,
              help='Tune the operators for better performance. May take several minutes.', show_default=True)
def main(max_num_tokens: int, use_fp16: bool, model_size: str, tune: bool):
    if tune:
        hidet.option.search_space(2)
    generator = GPT2Generator(max_num_tokens, model_size, use_fp16)
    while True:
        x = click.prompt(">>> ", type=str, prompt_suffix="")
        response = generator(x)
        click.echo(x + response)


if __name__ == "__main__":
    main()
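Decoding in `GPT2Generator.__call__` is greedy: at step `i` it takes the argmax over the logits at position `i - 1`. If stochastic sampling were wanted instead, a hypothetical drop-in replacement for that line (not part of this commit) might look like:

```python
import torch


def sample_next_token(logits_row: torch.Tensor, temperature: float = 0.8) -> int:
    # logits_row: [n_vocab] logits for the most recently generated position
    if temperature == 0.0:
        return int(torch.argmax(logits_row, dim=-1).item())  # greedy, as in main.py
    probs = torch.softmax(logits_row / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())

# inside GPT2Generator.__call__, replacing the argmax line:
# next_token: int = sample_next_token(self.logits[i - 1])
```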
examples/gpt-2/requirements.txt

regex
requests
tensorflow==2.11.0
hidet