forked from borisdayma/dalle-mini
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
179 lines (156 loc) · 6.02 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import tempfile
import random
from typing import Optional, List
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard, shard_prng_key
from functools import partial
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
from tqdm.notebook import trange
from cog import BasePredictor, Path, Input, BaseModel
from src.dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Output schema for one generated image returned by Predictor.predict.
class ModelOutput(BaseModel):
    # CLIP similarity score for (prompt, image); only populated when the
    # caller requests ranking (show_clip_score=True), otherwise left unset.
    clip_score: Optional[float]
    # Path to the PNG image written to a temporary directory.
    image: Path
class Predictor(BasePredictor):
    """Cog predictor that generates images from a text prompt using DALL·E
    Mini (mega, fp16), decodes them with VQGAN, and optionally ranks the
    results by CLIP score."""

    def setup(self):
        """Load DALL·E Mini, VQGAN, and CLIP weights from local checkpoints.

        Runs once when the Cog container starts; stores models and parameter
        trees on `self` for reuse across predictions.
        """
        # Load dalle-mini (text -> image-token generator)
        DALLE_MODEL = "checkpoints/dalle_mini_mega-1-fp16"
        DALLE_COMMIT_ID = None
        self.model, self.params = DalleBart.from_pretrained(
            DALLE_MODEL,
            revision=DALLE_COMMIT_ID,
            dtype=jnp.float16,
            _do_init=False,
            ignore_mismatched_sizes=True,
        )
        self.processor = DalleBartProcessor.from_pretrained(
            DALLE_MODEL, revision=DALLE_COMMIT_ID
        )
        # Load VQGAN (decodes image-token sequences back to RGB pixels)
        VQGAN_REPO = "checkpoints/vqgan_imagenet_f16_16384"
        VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
        self.vqgan, self.vqgan_params = VQModel.from_pretrained(
            VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
        )
        # Load CLIP (scores prompt/image similarity for optional ranking)
        CLIP_REPO = "checkpoints/openai_clip-vit-base-patch32"
        CLIP_COMMIT_ID = None
        self.clip, self.clip_params = FlaxCLIPModel.from_pretrained(
            CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
        )
        self.clip_processor = CLIPProcessor.from_pretrained(
            CLIP_REPO, revision=CLIP_COMMIT_ID
        )

    def predict(
        self,
        prompt: str = Input(
            default="sunset over a lake in the mountains",
            description="Prompt for generating image.",
        ),
        show_clip_score: bool = Input(
            default=False,
            description="CLIP score will be displayed for each generated image if set to True.",
        ),
        n_predictions: int = Input(
            default=8, description="Number of images to generate.", ge=1, le=8
        ),
    ) -> List[ModelOutput]:
        """Generate `n_predictions` images for `prompt`.

        Returns one ModelOutput per image. When `show_clip_score` is True the
        images are scored against the prompt with CLIP and returned in
        descending score order, with `clip_score` populated; otherwise they
        are returned in generation order with `clip_score` unset.
        """
        model, params = self.model, self.params
        vqgan, vqgan_params = self.vqgan, self.vqgan_params
        clip, clip_params = self.clip, self.clip_params
        # BUGFIX: the original called replicate(params, device) with an
        # undefined name `device`, raising NameError on every prediction.
        # flax.jax_utils.replicate defaults to all local devices, which is
        # the intended behavior for pmap below.
        params = replicate(params)
        vqgan_params = replicate(vqgan_params)
        clip_params = replicate(clip_params)

        print("Tokenizing prompt...")
        tokenized_prompt = self.processor([prompt])
        tokenized_prompt = replicate(tokenized_prompt)

        # model inference: pmapped generate; sampling knobs are static args
        # so each distinct setting compiles once.
        @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
        def p_generate(
            tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
        ):
            return model.generate(
                **tokenized_prompt,
                prng_key=key,
                params=params,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                condition_scale=condition_scale,
            )

        # decode image tokens to pixels
        @partial(jax.pmap, axis_name="batch")
        def p_decode(indices, params):
            return vqgan.decode_code(indices, params=params)

        # score images against the prompt
        @partial(jax.pmap, axis_name="batch")
        def p_clip(inputs, params):
            logits = clip(params=params, **inputs).logits_per_image
            return logits

        # create a random key (fresh seed per request so outputs vary)
        seed = random.randint(0, 2 ** 32 - 1)
        key = jax.random.PRNGKey(seed)
        # generation parameters: None = model defaults; cond_scale > 1
        # strengthens prompt conditioning (classifier-free guidance)
        gen_top_k = None
        gen_top_p = None
        temperature = None
        cond_scale = 3.0
        images = []
        final_output = []
        # generate images
        print("Generating images...")
        for _ in range(n_predictions):
            # split so each prediction uses an independent PRNG stream
            key, subkey = jax.random.split(key)
            encoded_images = p_generate(
                tokenized_prompt,
                shard_prng_key(subkey),
                params,
                gen_top_k,
                gen_top_p,
                temperature,
                cond_scale,
            )
            # drop the BOS token before decoding
            encoded_images = encoded_images.sequences[..., 1:]
            decoded_images = p_decode(encoded_images, vqgan_params)
            decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
            for decoded_img in decoded_images:
                img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
                images.append(img)

        if not show_clip_score:
            for i, img in enumerate(images):
                out_path = Path(tempfile.mkdtemp()) / f"output_{i}.png"
                img.save(str(out_path))
                final_output.append(ModelOutput(image=out_path))
        else:
            print("Ranking images by CLIP score...")
            # get clip scores for all images against the single prompt
            clip_inputs = self.clip_processor(
                text=[prompt],
                images=images,
                return_tensors="np",
                padding="max_length",
                max_length=77,
                truncation=True,
            ).data
            logits = p_clip(shard(clip_inputs), clip_params)
            logits = logits.squeeze().flatten()
            # highest-scoring images first
            rank_list = list(logits.argsort()[::-1])
            for i, idx in enumerate(rank_list):
                # int()/float() conversions: idx and the score are jnp
                # scalars; use plain Python types for filenames and pydantic.
                idx = int(idx)
                out_path = Path(tempfile.mkdtemp()) / f"output_{idx}.png"
                images[idx].save(str(out_path))
                clip_score = float(logits[idx])
                final_output.append(ModelOutput(clip_score=clip_score, image=out_path))
        return final_output