In [None]:
import os
!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
# Model references
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
import random

# dalle-mega
# The reference uses the ":latest" alias and no revision is pinned
# (DALLE_COMMIT_ID is None), so the loaded checkpoint may change between runs.
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
DALLE_COMMIT_ID = None

# VQGAN model, pinned to a fixed revision
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Bare expression: the value is only displayed when this is the last
# statement of a notebook cell; otherwise it is discarded.
jax.local_device_count()

# Load dalle-mini
# With _do_init=False the call returns the model object and its parameters
# as a pair; the parameters are requested in float16.
model, params = DalleBart.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,
    _do_init=False,
)

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    _do_init=False,
)

# Refuse to continue when more than one GPU is visible.
# NOTE(review): jax.device_count('gpu') may itself raise on hosts that have no
# GPU backend at all -- confirm whether CPU/TPU-only hosts must be supported.
n_gpus = jax.device_count('gpu')
if n_gpus > 1:
    raise RuntimeError(
        f"More than one GPU detected ({n_gpus}); "
        "this notebook is set up to run on a single GPU."
    )

# Replicate both parameter sets so they can be passed to the pmapped
# generate/decode functions.
params = replicate(params)
vqgan_params = replicate(vqgan_params)

from functools import partial

# model inference
@partial(
    jax.pmap,
    axis_name="batch",
    static_broadcasted_argnums=(3, 4, 5, 6),
)
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
    """Run `model.generate` under `jax.pmap` (mapped axis named "batch").

    Args:
        tokenized_prompt: mapping that is unpacked into keyword arguments of
            `model.generate`.
        key: PRNG key forwarded as `prng_key`.
        params: model parameters forwarded as `params`.
        top_k, top_p, temperature, condition_scale: sampling settings.
            These four positional arguments (indices 3-6) are declared as
            static broadcasted arguments of the pmap, so they are not mapped
            over devices.

    Returns:
        Whatever `model.generate` returns for the given inputs.
    """
    sampling_options = dict(
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        **sampling_options,
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode VQGAN code indices under `jax.pmap` (mapped axis named "batch").

    Args:
        indices: code indices handed to `vqgan.decode_code`.
        params: VQGAN parameters forwarded as `params`.

    Returns:
        The result of `vqgan.decode_code` for the given indices.
    """
    decoded = vqgan.decode_code(indices, params=params)
    return decoded

# Draw a random 32-bit seed and derive the initial JAX PRNG key from it.
# The seed is kept in its own variable so it stays available after the key
# has been split.
seed = random.randrange(2 ** 32)
key = jax.random.PRNGKey(seed)

# Prompt processor loaded from the same model reference and revision as the
# DALL-E checkpoint.
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)

In [7]:
# Prompt subjects. Each subject is generated on its own and also combined
# with every entry of `prompts_post` as "<subject> <suffix>".
prompts_pre = [
    "Microsoft",
    "A golden retriever",
    "New York City",
    "A moogle"
]
# Style / modifier suffixes appended (after a single space) to each subject.
prompts_post = [
    "in the style of league of legends splash art, trending on ArtStation HQ, Digital art",
    "in a snowglobe",
    "zoomed out",
    "detailed miniature replica",
    "(Source: Getty Images)",
    "infinitely detailed",
    "blender render",
    "god rays sun rays extremely hd studio footage",
    "award winning photograph",
    "dramatic lighting, Sigma 85mm f_1.4, 8k",
    "after effects",
    "professional photograph",
    "3d render",
    "cartoon",
    "pencil sketch",
    "photoshopped",
    "anime",
    "in the style of anime",
    "Trending on artstation",
    "Illustration",
    "watercolor",
    "painted in the style of Zdzisław Beksiński",
    "in the style of starwars",
    "impressionist style",
    "drawing",
    "black and white painting, creepy",
    "security camera footage",
    "pixel art",
    "sculpture",
    "baroque artwork",
    "on an alien planet",
    "screenshot",
    "monster"
]

# number of predictions per prompt
n_predictions = 4

# Generation parameters can be customized
# (see https://huggingface.co/blog/how-to-generate).
#
# The min_* / base_* / max_* values are reference bounds for experiments that
# vary the sampling settings per prediction (for example drawing them at
# random, or sweeping linearly from min to max across the predictions).
# Only `base_cond` is used by the active configuration below.
# NOTE(review): top_p is normally a probability mass in (0, 1]; base_p = 5 and
# max_p = 5000 look out of range -- confirm before using them.
min_k = 1
min_p = .01
min_temp = .01
min_cond = 1

base_k = 10
base_p = 5
base_temp = .5
base_cond = 10

max_k = 4999
max_p = 5000
max_temp = .98
max_cond = 100

# One entry per prediction; entry i is handed to the generate call for
# prediction i. top_k, top_p and temperature are left as None (passed through
# unchanged to the generate call); the condition scale is fixed at `base_cond`
# for every prediction.
gen_top_k = [None] * n_predictions
gen_top_p = [None] * n_predictions
temperature = [None] * n_predictions
cond_scale = [base_cond] * n_predictions

In [8]:
from flax.training.common_utils import shard_prng_key
import numpy as np
import os
from PIL import Image
from tqdm.notebook import trange
import wandb

# Weights & Biases run and the table that collects one row per generated image.
project = 'dalle-mini-tables-prompt-tweak'
run = wandb.init(project=project)
columns = ["captions", "img", "k", "p", "t", "c", "seed"]
gen_table = wandb.Table(columns=columns)

# Full prompt list: every subject on its own, then with each suffix appended.
allprompts = []
for prompt in prompts_pre:
    allprompts.append(prompt)
    allprompts.extend(f"{prompt} {post}" for post in prompts_post)

# Number of prompts sent through the model in one generate call.
PROMPTS_PER_BATCH = 8

# Images are written to ../pictures relative to the working directory.
# Create it up front so the first img.save() does not fail on a fresh checkout.
output_dir = os.path.join(os.getcwd(), "..", "pictures")
os.makedirs(output_dir, exist_ok=True)

# Each generate call yields one prediction per device, so the number of calls
# per batch is n_predictions divided by the device count -- but at least one
# call when any predictions are requested (previously `max(..., 0)`, which
# silently produced nothing when n_predictions < device count).
n_devices = jax.device_count()
n_rounds = max(n_predictions // n_devices, 1) if n_predictions > 0 else 0

try:
    while allprompts:
        print(f"Remaining: {len(allprompts)}")
        # Take up to PROMPTS_PER_BATCH prompts from the end of the list.
        batch_size = min(PROMPTS_PER_BATCH, len(allprompts))
        prompts = [allprompts.pop() for _ in range(batch_size)]

        tokenized_prompts = processor(prompts)
        tokenized_prompt = replicate(tokenized_prompts)

        print(f"Prompts: {prompts}\n")
        # generate images
        images = []
        for i in trange(n_rounds):
            # get a new key
            key, subkey = jax.random.split(key)
            # generate images with the sampling settings of round i
            encoded_images = p_generate(
                tokenized_prompt,
                shard_prng_key(subkey),
                params,
                gen_top_k[i],
                gen_top_p[i],
                temperature[i],
                cond_scale[i],
            )
            # remove BOS
            encoded_images = encoded_images.sequences[..., 1:]
            # decode images
            decoded_images = p_decode(encoded_images, vqgan_params)
            # Clamp to [0, 1] and flatten all leading axes into one image axis.
            decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
            for j, decoded_img in enumerate(decoded_images):
                # Assumes the flattened order is device-major (all prompts for
                # device 0, then device 1, ...). On a single device this is
                # simply prompt j / prediction i, as before.
                caption = prompts[j % len(prompts)]
                prediction_idx = i * n_devices + j // len(prompts)

                img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
                images.append(img)

                # File name: the prompt with commas and spaces removed,
                # followed by the prediction index.
                file_stem = caption.replace(",", "").replace(" ", "")
                path = os.path.join(output_dir, f"{file_stem}{prediction_idx}.jpg")
                img.save(path)

                gen_table.add_data(
                    caption,
                    wandb.Image(img),
                    gen_top_k[i],
                    gen_top_p[i],
                    temperature[i],
                    cond_scale[i],
                    seed,
                )
finally:
    # Log the Table to W&B dashboard and close the run even if generation
    # failed part-way, so completed rows are kept and the run is not left open.
    wandb.log({"Generated Images": gen_table})
    # Close the W&B run.
    run.finish()


/home/cp/code/dalle-mini/tools/../pictures/Microsoftonanalienplanet2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftbaroqueartwork2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftsculpture2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftpixelart2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftsecuritycamerafootage2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftblackandwhitepaintingcreepy2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftdrawing2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftimpressioniststyle2.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftonanalienplanet3.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftbaroqueartwork3.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftsculpture3.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftpixelart3.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftsecuritycamerafootage3.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftblackand

  0%|          | 0/4 [00:00<?, ?it/s]

/home/cp/code/dalle-mini/tools/../pictures/Microsoftinthestyleofstarwars0.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftpaintedinthestyleofZdzisławBeksiński0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftwatercolor0.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftIllustration0.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftTrendingonartstation0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftinthestyleofanime0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftanime0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftphotoshopped0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftinthestyleofstarwars1.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftpaintedinthestyleofZdzisławBeksiński1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftwatercolor1.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftIllustration1.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftTrendingonartstation1.jpg
/home/cp/code/da

  0%|          | 0/4 [00:00<?, ?it/s]

/home/cp/code/dalle-mini/tools/../pictures/Microsoftpencilsketch0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftcartoon0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoft3drender0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftprofessionalphotograph0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftaftereffects0.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftdramaticlightingSigma85mmf_1.48k0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftawardwinningphotograph0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftgodrayssunraysextremelyhdstudiofootage0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftpencilsketch1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftcartoon1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoft3drender1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftprofessionalphotograph1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftaftereffects1.jpg
/home/cp/code/dalle-mini/tools/.

  0%|          | 0/4 [00:00<?, ?it/s]

/home/cp/code/dalle-mini/tools/../pictures/Microsoftblenderrender0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftinfinitelydetailed0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoft(Source:GettyImages)0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftdetailedminiaturereplica0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftzoomedout0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftinasnowglobe0.jpg
/home/cp/code/dalle-mini/tools/../pictures/MicrosoftinthestyleofleagueoflegendssplasharttrendingonArtStationHQDigitalart0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoft0.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftblenderrender1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftinfinitelydetailed1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoft(Source:GettyImages)1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftdetailedminiaturereplica1.jpg
/home/cp/code/dalle-mini/tools/../pictures/Microsoftzoomedout1.jp

VBox(children=(Label(value='43.681 MB of 43.683 MB uploaded (29.099 MB deduped)\r'), FloatProgress(value=0.999…