Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement modular encoder/decoder class #364

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions data/encdec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
import pickle

import tiktoken

class EncoderDecoder:
    """
    Base class that provides unified encode/decode methods.

    Dispatches to a character-level codec when meta_path points to a
    pickle containing char mappings, otherwise falls back to GPT-2 BPE.
    The concrete implementation is held in self.impl.
    """

    def __init__(self, meta_path):
        """
        Initialize the correct encoding implementation in self.impl
        based on meta_path contents.

        meta_path: path to a dataset's meta.pkl, or None. A None or
        nonexistent path selects the BPE fallback.
        """
        if self._is_char_encoding(meta_path):
            self.impl = CharEncoderDecoder(meta_path)
        else:
            self.impl = BPEEncoderDecoder()

    def encode(self, text):
        """Encode text to a list of token ids via the selected impl."""
        return self.impl.encode(text)

    def decode(self, tokens):
        """Decode a list of token ids back to text via the selected impl."""
        return self.impl.decode(tokens)

    def _is_char_encoding(self, meta_path):
        """Check if meta_path contains a character-level encoding.

        Returns False for a None or nonexistent path — mirroring the old
        sample.py behavior of assuming GPT-2 encodings when no meta.pkl
        is found — instead of raising FileNotFoundError from open().
        """
        if meta_path is None or not os.path.exists(meta_path):
            return False
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        # char-level meta files carry an 'itos' id->char table
        return 'itos' in meta


class CharEncoderDecoder:
    """
    Character-level text codec backed by a meta.pkl mapping file.

    Each character corresponds to a single integer id through the
    pickled stoi (char->id) and itos (id->char) tables.
    """

    def __init__(self, meta_path):
        """Load the stoi/itos lookup tables pickled at meta_path."""
        with open(meta_path, 'rb') as f:
            self.meta = pickle.load(f)

        self.stoi = self.meta['stoi']
        self.itos = self.meta['itos']

    def encode(self, text):
        """Map each character of text to its integer id."""
        to_id = self.stoi
        return [to_id[ch] for ch in text]

    def decode(self, tokens):
        """Map a sequence of integer ids back to a string."""
        to_char = self.itos
        return ''.join(to_char[tok] for tok in tokens)


class BPEEncoderDecoder:
    """GPT-2 byte-pair-encoding codec backed by tiktoken."""

    def __init__(self):
        """Create BPE encoder directly using tiktoken."""
        self.encoder = tiktoken.get_encoding("gpt2")

    def encode(self, text):
        """Encode text to a list of GPT-2 BPE token ids.

        Allow the literal "<|endoftext|>" special token so a prompt
        containing it encodes to the single special id instead of
        raising — matches the sample.py behavior this class replaces.
        """
        return self.encoder.encode(text, allowed_special={"<|endoftext|>"})

    def decode(self, tokens):
        """Decode a list of GPT-2 BPE token ids back to text."""
        return self.encoder.decode(tokens)
29 changes: 6 additions & 23 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken

from data.encdec import EncoderDecoder
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
Expand All @@ -20,6 +20,7 @@
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
checkpoint= None
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

Expand Down Expand Up @@ -53,37 +54,19 @@
if compile:
model = torch.compile(model) # requires PyTorch 2.0 (optional)

# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
load_meta = os.path.exists(meta_path)
if load_meta:
print(f"Loading meta from {meta_path}...")
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
else:
# ok let's assume gpt-2 encodings by default
print("No meta.pkl found, assuming GPT-2 encodings...")
enc = tiktoken.get_encoding("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)
encdec = EncoderDecoder(os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') if checkpoint else None)

# encode the beginning of the prompt
if start.startswith('FILE:'):
with open(start[5:], 'r', encoding='utf-8') as f:
start = f.read()
start_ids = encode(start)
start_ids = encdec.encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
with torch.no_grad():
with ctx:
for k in range(num_samples):
y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print(decode(y[0].tolist()))
print(encdec.decode(y[0].tolist()))
print('---------------')