diff --git a/README.md b/README.md
index 79e54a9..cf7b817 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,12 @@ images = parti.generate(texts = [
 - 🤗 Huggingface for the transformers library and the ease for encoding text with T5 language model
 
+## Todo
+
+- [ ] get vit vqgan-vae trainer code working, as the discriminator needs to be trained
+- [ ] pre-encoding of text with the designated t5
+- [ ] training code for parti
+- [ ] inference caching
 
 ## Citations
 
diff --git a/parti_pytorch/t5.py b/parti_pytorch/t5.py
new file mode 100644
index 0000000..7ad0689
--- /dev/null
+++ b/parti_pytorch/t5.py
@@ -0,0 +1,80 @@
+import torch
+import transformers
+from transformers import T5Tokenizer, T5EncoderModel, T5Config
+
+transformers.logging.set_verbosity_error()
+
+def exists(val):
+    return val is not None
+
+# config
+
+MAX_LENGTH = 256
+
+DEFAULT_T5_NAME = 'google/t5-v1_1-base'
+
+T5_CONFIGS = {}
+
+# singleton globals
+
+def get_tokenizer(name):
+    tokenizer = T5Tokenizer.from_pretrained(name)
+    return tokenizer
+
+def get_model(name):
+    model = T5EncoderModel.from_pretrained(name)
+    return model
+
+def get_model_and_tokenizer(name):
+    global T5_CONFIGS
+
+    if name not in T5_CONFIGS:
+        T5_CONFIGS[name] = dict()
+    if "model" not in T5_CONFIGS[name]:
+        T5_CONFIGS[name]["model"] = get_model(name)
+    if "tokenizer" not in T5_CONFIGS[name]:
+        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
+
+    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
+
+def get_encoded_dim(name):
+    if name not in T5_CONFIGS:
+        # avoids loading the model if we only want to get the dim
+        config = T5Config.from_pretrained(name)
+        T5_CONFIGS[name] = dict(config=config)
+    elif "config" in T5_CONFIGS[name]:
+        config = T5_CONFIGS[name]["config"]
+    elif "model" in T5_CONFIGS[name]:
+        config = T5_CONFIGS[name]["model"].config
+    else:
+        assert False
+    return config.d_model
+
+# encoding text
+
+def t5_encode_text(texts, name = DEFAULT_T5_NAME):
+    t5, tokenizer = get_model_and_tokenizer(name)
+
+    if torch.cuda.is_available():
+        t5 = t5.cuda()
+
+    device = next(t5.parameters()).device
+
+    encoded = tokenizer.batch_encode_plus(
+        texts,
+        return_tensors = "pt",
+        padding = 'longest',
+        max_length = MAX_LENGTH,
+        truncation = True
+    )
+
+    input_ids = encoded.input_ids.to(device)
+    attn_mask = encoded.attention_mask.to(device)
+
+    t5.eval()
+
+    with torch.no_grad():
+        output = t5(input_ids = input_ids, attention_mask = attn_mask)
+        encoded_text = output.last_hidden_state.detach()
+
+    return encoded_text, attn_mask.bool()
diff --git a/setup.py b/setup.py
index 9ae34d6..2c3500d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'parti-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Parti - Pathways Autoregressive Text-to-Image Model - Pytorch',
   author = 'Phil Wang',
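
For context, a minimal usage sketch of the new parti_pytorch/t5.py module introduced above (an editor's illustration, not part of the commit). The import path and the (embeddings, boolean mask) return convention follow the code in the diff; the example prompts are made up for illustration.

    from parti_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

    # reads the encoder width from the T5 config alone; per get_encoded_dim
    # above, this avoids instantiating the full model
    dim = get_encoded_dim(DEFAULT_T5_NAME)  # 768 for google/t5-v1_1-base

    # first call loads the model and tokenizer and caches them in the
    # module-level T5_CONFIGS dict; later calls with the same name reuse them
    text_embeds, text_mask = t5_encode_text([
        'a photograph of an astronaut riding a horse',  # hypothetical prompts
        'a small red cube on top of a large blue sphere'
    ])

    print(text_embeds.shape)  # (2, seq_len, 768), padded to the longest prompt
    print(text_mask.shape)    # (2, seq_len) boolean mask, False over padding

The singleton cache means repeated encodes pay the model-loading cost once per T5 variant, and get_encoded_dim lets downstream code size its text-conditioning layers before any weights are fetched.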