Commit 31a24f3: release initial pass of parti

lucidrains committed Jun 23, 2022
1 parent b74dc4f
Showing 4 changed files with 333 additions and 4 deletions.
81 changes: 81 additions & 0 deletions README.md
@@ -4,6 +4,87 @@

Implementation of <a href="https://parti.research.google/">Parti</a>, Google's pure attention-based text-to-image neural network, in PyTorch

## Install

```bash
$ pip install parti-pytorch
```

## Usage

```python
import torch
from parti_pytorch import Parti
from parti_pytorch.vit_vqgan import VitVQGanVAE

# first instantiate your ViT VQGan VAE
# a VQGan VAE made of transformers

vit_vae = VitVQGanVAE(
    dim = 512,          # dimensions
    image_size = 256,   # target image size
    num_layers = 4      # number of layers
).cuda()

images = torch.randn(4, 3, 256, 256).cuda()

loss = vit_vae(images, return_loss = True)
loss.backward()

# do the above with as many images as possible (a fuller training-loop sketch follows after this snippet)
# then plug your trained ViT VQGan VAE into Parti

parti = Parti(
    vae = vit_vae,          # vit vqgan vae
    dim = 512,              # model dimension
    depth = 8,              # depth
    dim_head = 64,          # attention head dimension
    heads = 8,              # attention heads
    dropout = 0.,           # dropout
    ff_mult = 4,            # feedforward expansion factor
    t5_name = 't5-large',   # name of your T5
)

# ready your training text and images

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed it into your parti instance, with return_loss set to True

loss = parti(
    texts = texts,
    images = images,
    return_loss = True
)

loss.backward()

# do this for a long time on a lot of data
# then...

images = parti.generate(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

# (3, 3, 256, 256) <-- save your images
```
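
The snippet above calls `loss.backward()` without an optimizer or data pipeline. As a minimal sketch of what training the ViT VQGan VAE could look like, assuming a local folder of images; the path, optimizer choice, learning rate, and batch size are illustrative placeholders, not values prescribed by this repository:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from parti_pytorch.vit_vqgan import VitVQGanVAE

vit_vae = VitVQGanVAE(
    dim = 512,
    image_size = 256,
    num_layers = 4
).cuda()

opt = torch.optim.Adam(vit_vae.parameters(), lr = 3e-4) # placeholder hyperparameters

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder('./path/to/images', transform = transform) # placeholder path
loader = DataLoader(dataset, batch_size = 4, shuffle = True)

for images, _ in loader:
    images = images.cuda()

    loss = vit_vae(images, return_loss = True)

    opt.zero_grad()
    loss.backward()
    opt.step()
```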

## Appreciation

- <a href="https://stability.ai/">StabilityAI</a> for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

- <a href="https://huggingface.co/">🤗 Huggingface</a> for the transformers library and the ease for encoding text with T5 language model


## Citations

250 changes: 246 additions & 4 deletions parti_pytorch/parti_pytorch.py
@@ -1,10 +1,16 @@
from typing import List
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn, einsum
import torchvision.transforms as T

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from parti_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

def exists(val):
@@ -13,6 +19,45 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
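
# note: eval_decorator is used on `generate` below - it switches the model to
# eval mode (disabling dropout) for sampling, then restores whatever
# training / eval state the model was in before the call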

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t + eps)

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
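
# note: top_k keeps only the top (1 - thres) fraction of logits (so thres = 0.9
# keeps the top 10%) and fills the rest with -inf. gumbel_sample then uses the
# gumbel-max trick: adding gumbel noise to the temperature-scaled logits and
# taking the argmax is equivalent to sampling from their softmax distribution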

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
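
# note: returns a boolean mask that is True with probability `prob` - during
# training, Parti uses this to drop the text conditioning for a random subset
# of each batch, which is what later enables classifier free guidance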

# normalization

class LayerNorm(nn.Module):
@@ -41,8 +86,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
@@ -128,9 +173,206 @@ class Parti(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        ff_mult = 4,
        vae = None,
        vae_image_size = None,
        vae_codebook_size = None,
        t5_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        cond_drop_prob = 0.25
    ):
        super().__init__()

        # text conditioning

        text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name))
        self.encode_texts = partial(t5_encode_text, name = t5_name)

        assert cond_drop_prob > 0.
        self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb

        # vae and image handling

        assert exists(vae) ^ exists(vae_codebook_size)
        self.vae = vae

        # prefer the explicitly passed sizes, so the vae attributes are not touched when they are given
        codebook_size = vae_codebook_size if exists(vae_codebook_size) else vae.codebook_size
        image_size = vae_image_size if exists(vae_image_size) else vae.image_size

        self.image_token_embed = nn.Embedding(codebook_size + 1, dim) # + 1 for start token (or padding)

        self.image_encoded_dim = vae.get_encoded_fmap_size(image_size)

        # attention layers

        self.init_norm = LayerNorm(dim)

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, causal = True, dim_head = dim_head, heads = heads, dropout = dropout),
                Attention(dim, context_dim = text_embed_dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim, mult = ff_mult, dropout = dropout)
            ]))

        self.final_norm = LayerNorm(dim)

        # project to logits, weight tied to the image token embedding
        # (the tied weight has codebook_size + 1 rows, so the logits also cover the start / padding token)

        self.to_logits = nn.Linear(dim, codebook_size, bias = False)
        self.to_logits.weight = self.image_token_embed.weight

        # default device

        if exists(vae):
            self.to(next(vae.parameters()).device)

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        texts,
        *,
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        return_pil_images = False
    ):
        device = next(self.parameters()).device

        text_token_embeds, text_mask = self.encode_texts(texts)
        text_token_embeds = text_token_embeds.to(device)
        text_mask = text_mask.to(device)

        batch = text_token_embeds.shape[0]

        image_seq_len = self.image_encoded_dim ** 2

        image_tokens = torch.empty((batch, 0), device = device, dtype = torch.long)

        for _ in range(image_seq_len + 1):
            logits = self.forward_with_cond_scale(
                text_token_embeds = text_token_embeds,
                text_mask = text_mask,
                image_token_ids = image_tokens,
                cond_scale = cond_scale
            )[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')

            # shift sampled ids down by 1 to undo the start / padding offset, clamping a sampled start token to codebook index 0
            sampled = sampled - 1
            sampled = sampled.masked_fill(sampled == -1, 0)
            image_tokens = torch.cat((image_tokens, sampled), dim = -1)

        image_tokens = image_tokens[:, 1:] # trim to image_seq_len tokens (the first sample is discarded)

        image_tokens = rearrange(image_tokens, 'b (h w) -> b h w', h = self.image_encoded_dim)

        if not exists(self.vae):
            return image_tokens

        with torch.no_grad():
            fmap = self.vae.get_fmap_from_codebook(image_tokens)
            images = self.vae.decode(fmap)

        if not return_pil_images:
            return images

        pil_images = list(map(T.ToPILImage(), images.unbind(dim = 0)))
        return pil_images

    def forward_with_cond_scale(self, *args, cond_scale = 3, **kwargs):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
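
    # note: the method above is classifier free guidance - logits are computed
    # once with text conditioning and once with it fully dropped
    # (cond_drop_prob = 1.), then extrapolated past the conditional logits by
    # cond_scale. cond_scale = 1 recovers plain conditional sampling; larger
    # values (e.g. the default of 3) trade diversity for adherence to the text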

    def forward(
        self,
        texts: List[str] = None,
        text_token_embeds = None,
        text_mask = None,
        images = None,
        image_token_ids = None,
        cond_drop_prob = None,
        return_loss = False
    ):
        assert exists(texts) ^ exists(text_token_embeds)
        assert exists(images) ^ exists(image_token_ids)
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # encoding images

        if not exists(image_token_ids):
            assert exists(self.vae), 'vae must be given if you want to encode the image live'

            with torch.no_grad():
                _, image_token_ids, _ = self.vae.encode(images, return_indices_and_loss = True)

            image_token_ids = rearrange(image_token_ids, 'b ... -> b (...)')

        if exists(image_token_ids):
            image_token_ids = image_token_ids + 1 # shift all codebook ids up by 1, reserving 0 for the start / padding token

            image_token_ids = F.pad(image_token_ids, (1, 0), value = 0) # add start token [0]

        if return_loss:
            assert image_token_ids.shape[-1] > 1, 'not enough image tokens given to return a loss'
            image_token_ids, labels = image_token_ids[:, :-1], image_token_ids[:, 1:]

        image_token_emb = self.image_token_embed(image_token_ids)

        batch, device = image_token_emb.shape[0], image_token_emb.device

        # text

        if not exists(text_token_embeds):
            with torch.no_grad():
                text_token_embeds, text_mask = self.encode_texts(texts)

        if not exists(text_mask):
            text_mask = torch.ones(text_token_embeds.shape[:2], dtype = torch.bool, device = device)

        text_token_embeds = text_token_embeds.to(device)
        text_mask = text_mask.to(device)

        # classifier free guidance conditional dropout

        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
            text_mask &= rearrange(keep_mask, 'b -> b 1')

        # attend

        x = image_token_emb
        x = self.init_norm(x)

        for self_attn, cross_attn, ff in self.layers:
            x = self_attn(x) + x
            x = cross_attn(x, context = text_token_embeds, context_mask = text_mask) + x
            x = ff(x) + x

        x = self.final_norm(x)

        # to logits

        logits = self.to_logits(x)

        if not return_loss:
            return logits

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = 0 # 0 is the start / padding token id, ignored as a target
        )

        return loss
5 changes: 5 additions & 0 deletions parti_pytorch/vit_vqgan.py
@@ -468,6 +468,11 @@ def load_state_dict(self, *args, **kwargs):
    def codebook(self):
        return self.vq.codebook

    def get_fmap_from_codebook(self, indices):
        codes = self.codebook[indices]
        fmap = self.vq.project_out(codes)
        return rearrange(fmap, 'b h w c -> b c h w')
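
    # note: used at the end of Parti.generate - takes (batch, height, width)
    # codebook indices, looks up their codes, projects back to the feature map
    # dimension, and rearranges to (batch, channels, height, width) for decode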

    def encode(self, fmap, return_indices_and_loss = True):
        fmap = self.enc_dec.encode(fmap)

1 change: 1 addition & 0 deletions setup.py
@@ -23,6 +23,7 @@
    'ema-pytorch',
    'torch>=1.6',
    'torchvision',
    'transformers',
    'vector-quantize-pytorch'
  ],
  classifiers=[
