
Commit e4e1bc2

todo

lucidrains committed Jun 23, 2022
1 parent 31a24f3 commit e4e1bc2
Showing 3 changed files with 87 additions and 1 deletion.
6 changes: 6 additions & 0 deletions README.md
@@ -84,6 +84,12 @@ images = parti.generate(texts = [

- <a href="https://huggingface.co/">🤗 Huggingface</a> for the transformers library and the ease of encoding text with the T5 language model

## Todo

- [ ] get the ViT VQGan-VAE trainer code working, as the discriminator needs to be trained
- [ ] pre-encoding of text with the designated T5 (a sketch follows this file's diff)
- [ ] training code for Parti
- [ ] inference caching

## Citations

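A minimal sketch of how the pre-encoding todo item above could look once implemented (hypothetical; the flow and the `preencoded_text.pt` filename are illustrative, only `t5_encode_text` comes from this commit):

# hypothetical pre-encoding sketch, assuming the t5 module added below
import torch
from parti_pytorch.t5 import t5_encode_text

captions = ['a photograph of an astronaut riding a horse']

# run the designated T5 once and persist the embeddings and mask,
# so training never has to re-run the text encoder
text_embeds, text_mask = t5_encode_text(captions)
torch.save(dict(embeds = text_embeds.cpu(), mask = text_mask.cpu()), 'preencoded_text.pt')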
80 changes: 80 additions & 0 deletions parti_pytorch/t5.py
@@ -0,0 +1,80 @@
import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

transformers.logging.set_verbosity_error()

def exists(val):
    return val is not None

# config

MAX_LENGTH = 256

DEFAULT_T5_NAME = 'google/t5-v1_1-base'

T5_CONFIGS = {}

# singleton globals

def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# encoding text

def t5_encode_text(texts, name = DEFAULT_T5_NAME):
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    return encoded_text, attn_mask.bool()
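For reference, a minimal usage sketch of the module above (assumptions: parti-pytorch is importable and the default `google/t5-v1_1-base` weights can be downloaded on first use; later calls reuse the cached singleton):

from parti_pytorch.t5 import t5_encode_text, get_encoded_dim

# embedding dimension of the designated T5, without loading full model weights
dim = get_encoded_dim('google/t5-v1_1-base')  # 768 for t5-v1_1-base

# encode a batch of captions into embeddings plus a boolean attention mask
text_embeds, text_mask = t5_encode_text([
    'a photograph of an astronaut riding a horse',
    'a corgi wearing sunglasses'
])

# text_embeds: (batch, seq, dim) float tensor, text_mask: (batch, seq) bool tensor
print(text_embeds.shape, text_mask.shape)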
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'parti-pytorch',
  packages = find_packages(exclude=[]),
- version = '0.0.1',
+ version = '0.0.2',
  license='MIT',
  description = 'Parti - Pathways Autoregressive Text-to-Image Model - Pytorch',
  author = 'Phil Wang',
