[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/DeepFloyd-IF-colab/blob/main/DeepFloyd-IF-I-M-v1.0-test.ipynb)

In [None]:
# https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures modified

# Environment setup cell: installs pinned dependencies, then restarts the kernel.
# PyTorch stack pinned to the cu116 builds, pulled from the extra PyTorch wheel index.
!pip install -q torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 torchtext==0.14.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu116 -U
# (disabled) optional xformers / triton install.
# !pip install -q xformers==0.0.16 triton==2.0.0 -U
# DeepFloyd IF reference implementation (provides the deepfloyd_if package used by later cells).
!pip install -q deepfloyd-if==1.0.1 
# CLIP is installed from source with --no-deps — presumably so pip leaves the pinned packages above untouched.
!pip install -q git+https://github.com/openai/CLIP.git --no-deps
# (disabled) clone of an external prompts repository.
# !git clone https://huggingface.co/bakedpotat/prompts
# Hugging Face stack used by the text-encoder cell (diffusers / transformers / 8-bit loading support).
!pip install -q -U diffusers~=0.16 transformers~=4.28 safetensors~=0.3 sentencepiece~=0.1 accelerate~=0.18 bitsandbytes~=0.38 huggingface_hub

# Restart the kernel (restart=True) — presumably so the freshly installed package versions are the ones imported by later cells.
get_ipython().kernel.do_shutdown(True)

In [None]:
# Text-encoding cell: loads the T5 text encoder in 8-bit, encodes one prompt,
# writes the embeddings to disk, then restarts the kernel.
import getpass
import os

import numpy as np
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# Never commit an access token to the notebook: read it from the environment,
# and fall back to an interactive prompt when it is not set.
hf_token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or getpass.getpass("Hugging Face access token: ")
)

text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-L-v1.0",
    load_in_8bit=True,
    subfolder="text_encoder",
    device_map="auto",
    variant="8bit",
    use_auth_token=hf_token,
)

# unet=None: only the text-encoding part of the pipeline is needed in this cell.
pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-L-v1.0",
    text_encoder=text_encoder,
    unet=None,
    device_map="auto",
    safety_checker=None,
    use_auth_token=hf_token,
)

prompt = 'a photograph of an astronaut riding a horse holding a sign that says "Pixel\'s in space"'
prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

# Persist the embeddings as NumPy arrays in the working directory so a later
# cell can reload them after the kernel restart below.
np.save('prompt.npy', prompt_embeds.detach().cpu().numpy())
np.save('negative.npy', negative_embeds.detach().cpu().numpy())

# Restart the kernel (restart=True); everything held in memory by this cell is dropped.
get_ipython().kernel.do_shutdown(True)

In [None]:
# Model-loading cell: instantiates the three DeepFloyd IF stages on one GPU.
import os
# Set before deepfloyd_if is imported — presumably the package reads this flag at import time.
os.environ['FORCE_MEM_EFFICIENT_ATTN'] = "1"
import getpass
import sys
import random

import torch
import numpy as np

from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII

# Never commit an access token to the notebook: read it from the environment,
# and fall back to an interactive prompt when it is not set.
hf_token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or getpass.getpass("Hugging Face access token: ")
)
device = 'cuda:0'
if_I = IFStageI('IF-I-L-v1.0', device=device, hf_token=hf_token)
if_II = IFStageII('IF-II-L-v1.0', device=device, hf_token=hf_token)
if_III = StableStageIII('stable-diffusion-x4-upscaler', device=device)

In [None]:
# Reload the prompt embeddings saved by the text-encoder cell and run Stage-I.
prompts, t5_embs = [], []
prompt = 'a photograph of an astronaut riding a horse holding a sign that says "Pixel\'s in space"'
# Same relative path the text-encoder cell passes to np.save, so the two cells
# agree regardless of the working directory.
t5_numpy = np.load('prompt.npy')
t5_tensor = torch.from_numpy(t5_numpy)
# The cells below slice and concatenate along dim 0 as a batch axis, i.e. they
# expect (N, seq, dim). Only add the batch axis when the saved array lacks one;
# an embedding saved with its batch axis intact would otherwise become 4-D.
if t5_tensor.dim() == 2:
    t5_tensor = t5_tensor.unsqueeze(0)
t5_embs.append(t5_tensor)
prompts.append(prompt)

t5_embs = torch.cat(t5_embs).to(device)
# A bare expression in the middle of a cell is not displayed; print it instead.
print(t5_embs.shape)

# Stage-I: 64px

seed = 42

stageI_generations, _meta = if_I.embeddings_to_image(
    t5_embs, seed=seed, batch_repeat=1,
    dynamic_thresholding_p=0.95,
    dynamic_thresholding_c=1.5,
    guidance_scale=7.0,
    sample_loop='ddpm',
    sample_timestep_respacing='smart50',
    image_size=64,
    aspect_ratio="1:1",
    progress=True,
    disable_watermark=True,
)
pil_images_I = if_I.to_images(stageI_generations, disable_watermark=True)
if_I.show(pil_images_I)

In [None]:
# Stage-II: upscale the Stage-I output (64px --> 256 px)

# Sampling options for the second stage, collected in one place.
if_II_kwargs = {
    'seed': seed,
    'batch_repeat': 1,
    'dynamic_thresholding_p': 0.95,
    'dynamic_thresholding_c': 1.0,
    'aug_level': 0.25,
    'guidance_scale': 4.0,
    'image_scale': 4.0,
    'sample_loop': 'ddpm',
    'sample_timestep_respacing': '50',
    'progress': True,
}

# Positional arguments: the Stage-I generations, then the text embeddings.
stageII_generations, _meta = if_II.embeddings_to_image(
    stageI_generations, t5_embs, **if_II_kwargs
)

# Convert to PIL images (watermark disabled) and display them.
pil_images_II = if_II.to_images(stageII_generations, disable_watermark=True)
if_II.show(pil_images_II)

In [None]:
# Stage-III: upscale the Stage-II output (256px --> 1024px)

# Run the upscaler one sample at a time, then stitch the results back together.
upscaled_batches = []
for idx in range(len(stageII_generations)):
    one = slice(idx, idx + 1)  # length-1 slice keeps the leading axis
    upscaled, _meta = if_III.embeddings_to_image(
        prompt=prompts[one],
        low_res=stageII_generations[one],
        seed=seed,
        t5_embs=t5_embs[one],
    )
    upscaled_batches.append(upscaled)

stageIII_generations = torch.cat(upscaled_batches, 0)
pil_images_III = if_III.to_images(stageIII_generations, disable_watermark=True)

# Save each final image as '<index>.png', display it, and print its prompt.
for idx, prompt in enumerate(prompts):
    pil_img = pil_images_III[idx]
    pil_img.save(f'{idx}.png')
    if_I.show([pil_img], size=14)
    print(prompt, '\n'*3)