In [2]:
# Mount Google Drive so files stored there are visible under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Copy unets.py into the working directory so that `import unets` works in
# the next cell. Update the source path to wherever your unets.py lives.
!cp drive/MyDrive/Martin/unets.py .

Mounted at /content/drive


In [3]:
import unets
import torch
# Run on the GPU when one is available, otherwise fall back to the CPU.
# This global is read by the model/config cells further down.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [25]:
# Number of images the simple diffusion model draws per call; it becomes the
# default of Configs.n_samples. The Gradio handler passes exactly two digit
# labels, and Configs.sample falls back to random labels whenever the number
# of labels differs from n_samples, so change both together.
general_n_samples = 2

In [9]:
# gdown is used below to fetch the model checkpoints from Google Drive links.
!pip install gdown
import gdown

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Diffusion Model

In [21]:
# Download the trained weights of the simple diffusion model into the working
# directory. fuzzy=True lets gdown extract the file id from a Drive share link.
url_diffusion = "https://drive.google.com/file/d/1underEMIGRU08LFQaRW0wIYyBGe1KL8c/view?usp=sharing"
output = "diffusion_model.pt"
gdown.download(url=url_diffusion, output=output, quiet=False, fuzzy=True)

Downloading...
From: https://drive.google.com/uc?id=1underEMIGRU08LFQaRW0wIYyBGe1KL8c
To: /content/diffusion_model.pt
100%|██████████| 57.9M/57.9M [00:02<00:00, 26.4MB/s]


'diffusion_model.pt'

In [22]:
def gather(consts: torch.Tensor, t: torch.Tensor):
    """Pick the entries of `consts` at the indices in `t`.

    The selected values are returned reshaped to (-1, 1, 1, 1), i.e. one
    scalar per index with three trailing singleton axes, so they broadcast
    against 4-D feature maps.
    """
    return torch.gather(consts, -1, t).reshape(-1, 1, 1, 1)

In [23]:
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

class DenoiseDiffusion:
    """Denoising diffusion (DDPM) forward and reverse process.

    `eps_model` is the noise-prediction network ϵθ(xt, t); it is called as
    `eps_model(xt, t, y=labels)`. `n_steps` is the number of diffusion steps T.
    All schedule tensors are created on `device`.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        self.eps_model = eps_model

        # β1,…,βT: linearly increasing variance schedule (diffusion rate)
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

        # αt = 1 − βt
        self.alpha = 1. - self.beta
        # ᾱt = ∏ αs (cumulative product up to t)
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # T
        self.n_steps = n_steps

        # σt² = βt is used as the reverse-process variance
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the mean and variance of q(x_t | x_0).

        mean = sqrt(ᾱt) * x0, var = 1 − ᾱt, with ᾱt looked up per sample by `t`.
        """
        alpha_bar = gather(self.alpha_bar, t)
        mean = alpha_bar ** 0.5 * x0
        var = 1 - alpha_bar
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """Sample x_t from q(x_t | x_0).

        `eps` is the noise to use; when omitted, fresh ϵ ∼ N(0, I) of the same
        shape as `x0` is drawn.
        """
        if eps is None:
            eps = torch.randn_like(x0)

        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, labels: torch.Tensor):
        """Sample x_{t−1} from pθ(x_{t−1} | x_t), conditioned on `labels`.

        `t` holds one (0-based) step index per sample. For samples at t == 0
        the mean is returned without added noise: that step produces the final
        image, and DDPM sampling uses z = 0 there.
        """
        # ϵθ(xt, t)
        eps_theta = self.eps_model(xt, t, y=labels)
        # ᾱt and αt for each sample
        alpha_bar = gather(self.alpha_bar, t)
        alpha = gather(self.alpha, t)

        # (1 − αt) / sqrt(1 − ᾱt), where αt = 1 − βt
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

        # σt² for each sample
        var = gather(self.sigma2, t)

        # ϵ ∼ N(0, I); drawn unconditionally so RNG consumption does not depend on t
        eps = torch.randn(xt.shape, device=xt.device)
        # 1 where t > 0, 0 where t == 0, broadcastable over the feature map
        add_noise = (t != 0).to(xt.dtype).reshape(-1, 1, 1, 1)
        return mean + add_noise * (var ** .5) * eps

    def loss(self, x0: torch.Tensor, labels: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """Simplified DDPM loss: MSE between the true and the predicted noise.

        A random step t is drawn per sample in the batch, x_t is sampled from
        q(x_t | x_0) with `noise` (fresh N(0, I) noise when omitted), and the
        model's prediction for (x_t, t, labels) is compared with that noise.
        """
        batch_size = x0.shape[0]
        # Random t for each sample in the batch
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

        if noise is None:
            noise = torch.randn_like(x0)

        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_model(xt, t, y=labels)

        return F.mse_loss(noise, eps_theta)

In [26]:
from typing import List
import numpy as np
import torch
import torch.utils.data
import torchvision
from PIL import Image


class Configs():
    """Settings and sampling loop for the simple class-conditional diffusion model.

    Call `init()` once (after the global `UNet_model` exists) to build the
    DDPM wrapper and the optimizer, then `sample()` to generate images.
    """

    device: torch.device = device

    # U-Net model for ϵθ(xt,t)
    eps_model: torch.nn.Module
    # Diffusion algorithm
    diffusion: DenoiseDiffusion

    # Number of channels in the image
    image_channels: int = 1
    # Image size
    image_size: int = 28
    # Number of class labels random labels are drawn from
    n_classes: int = 10

    # Number of time steps T
    n_steps: int = 1000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = general_n_samples
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 5

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Wire the global `UNet_model` into a DDPM sampler and an Adam optimizer."""
        # U-Net model for ϵθ(xt,t)
        self.eps_model = UNet_model

        # Create DDPM
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

    def sample(self, items=None):
        """Generate `n_samples` images and return them as a NumPy array.

        items: optional sequence of class labels, one per sample. When it is
            empty/None or its length differs from `n_samples`, random labels
            in [0, n_classes) are used instead.

        Returns an array of shape (n_samples, image_channels, image_size, image_size).
        """
        shape = [self.n_samples, self.image_channels, self.image_size, self.image_size]

        with torch.no_grad():
            # xT∼p(xT)=N(xT;0,I)
            x = torch.randn(shape, device=self.device)

            if not items or len(items) != self.n_samples:
                labels = torch.randint(self.n_classes, (len(x),), dtype=torch.int64).to(self.device)
            else:
                labels = torch.tensor(items).to(self.device)

            # Remove noise step by step, from t = T-1 down to t = 0
            for t in reversed(range(self.n_steps)):
                # Sample from pθ(xt−1∣xt)
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long), labels)

        return x.cpu().view(*shape).numpy()

# Init Diffusion Model

In [27]:
UNet_model = unets.UNet(
        image_size=28,
        in_channels=1,
        out_channels=1,
        num_classes=10
    ).to(device)

# This notebook only samples from the network, so switch it to inference mode
# (affects dropout / batch-norm layers, if the U-Net has any), as is done for
# the GLIDE and CLIP-guided models.
UNet_model.eval()

# Fill architecture with the trained weights. Change with your path to the model.
# map_location makes a checkpoint saved on a GPU loadable on a CPU-only runtime.
UNet_model.load_state_dict(torch.load("diffusion_model.pt", map_location=device))

<All keys matched successfully>

In [28]:
# Build the sampler for the simple diffusion model. Configs.init() reads the
# global UNet_model, so the cell that loads the U-Net must have run first.
configs = Configs()
configs.init()

# GLIDE Classifier-free guidance

In [45]:
# Install OpenAI's GLIDE package (provides the glide_text2im module imported below).
!pip install git+https://github.com/openai/glide-text2im

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/openai/glide-text2im
  Cloning https://github.com/openai/glide-text2im to /tmp/pip-req-build-4j2f9gij
  Running command git clone -q https://github.com/openai/glide-text2im /tmp/pip-req-build-4j2f9gij
Building wheels for collected packages: glide-text2im
  Building wheel for glide-text2im (setup.py) ... [?25l[?25hdone
  Created wheel for glide-text2im: filename=glide_text2im-0.0.0-py3-none-any.whl size=1953663 sha256=04ecca4921f1c7dce17b56724650842a0e78dd2f5e5c9678b72ff9a3572f1a03
  Stored in directory: /tmp/pip-ephem-wheel-cache-u3sa5xce/wheels/b4/36/07/46711fd6462da277046c6720504e61546b6e32adc0293abc96
Successfully built glide-text2im
Installing collected packages: glide-text2im
Successfully installed glide-text2im-0.0.0


In [46]:
# Download the fine-tuned GLIDE checkpoint (about 1.5 GB) into the working
# directory; ConfigsGLIDE.init() loads this file name by default.
url_glide = "https://drive.google.com/file/d/1KoLcgRm_kgAfE9BzKJGrPwVZQ-920Cdw/view?usp=sharing"
output = "glide_fine-tuning.pt"
gdown.download(url=url_glide, output=output, quiet=False, fuzzy=True)

Downloading...
From: https://drive.google.com/uc?id=1KoLcgRm_kgAfE9BzKJGrPwVZQ-920Cdw
To: /content/glide_fine-tuning.pt
100%|██████████| 1.54G/1.54G [00:36<00:00, 42.2MB/s]


'glide_fine-tuning.pt'

In [47]:
from PIL import Image
from IPython.display import display
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_gaussian_diffusion,
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

In [48]:
class ConfigsGLIDE():
  """Fine-tuned GLIDE base model plus the GLIDE upsampler, with their sampling loops.

  Call `init()` once to build and load both models, `sample(caption)` to draw
  base-resolution images with classifier-free guidance, and
  `upsample(image, caption)` to upsample them.
  """
  batch_size = 1
  guidance_scale = 3.0
  model = None
  # Tune this parameter to control the sharpness of 256x256 images.
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
  upsample_temp = 0.997

  def init(self, path="glide_fine-tuning.pt"):
    """Create the base and upsampler models and load their weights.

    path: checkpoint file with the fine-tuned base-model weights.

    NOTE: this uses `create_model_and_diffusion` / `model_and_diffusion_defaults`
    from glide_text2im. A later cell of this notebook re-imports the same names
    from guided_diffusion, so init() must be called before that cell runs.
    """
    has_cuda = th.cuda.is_available()

    # Create base model. fp16 is used exactly when CUDA is available.
    self.options = model_and_diffusion_defaults()
    self.options['use_fp16'] = has_cuda
    self.options['timestep_respacing'] = '27' # use 27 diffusion steps for fast sampling
    self.model, self.diffusion = create_model_and_diffusion(**self.options)
    self.model.eval()
    if has_cuda:
        self.model.convert_to_fp16()
    self.model.to(device)

    # map_location makes a checkpoint saved on a GPU loadable on a CPU-only runtime.
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm this is intended for the fine-tuned checkpoint.
    self.model.load_state_dict(th.load(path, map_location=device), strict=False)
    print('total base parameters', sum(x.numel() for x in self.model.parameters()))

    # Create upsampler model.
    self.options_up = model_and_diffusion_defaults_upsampler()
    self.options_up['use_fp16'] = has_cuda
    self.options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
    self.model_up, self.diffusion_up = create_model_and_diffusion(**self.options_up)
    self.model_up.eval()
    if has_cuda:
        self.model_up.convert_to_fp16()
    self.model_up.to(device)
    self.model_up.load_state_dict(load_checkpoint('upsample', device))
    print('total upsampler parameters', sum(x.numel() for x in self.model_up.parameters()))

  def model_fn(self, x_t, ts, **kwargs):
    """Classifier-free guidance wrapper around the base model.

    The batch is expected to hold the conditional samples in its first half
    and the unconditional ones in its second half. The first half of `x_t` is
    fed through the model for both halves, the first three output channels
    (eps) are combined as uncond + guidance_scale * (cond - uncond), and the
    guided eps is returned for both halves together with the remaining
    output channels unchanged.
    """
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = self.model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + self.guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)

  def sample(self, caption):
    """Sample `batch_size` images for `caption` from the base model.

    Returns the first `batch_size` entries of the sampled batch (the
    conditional half), as produced by the diffusion's p_sample_loop.
    """
    # Create the text tokens to feed to the model.
    tokens = self.model.tokenizer.encode(caption)
    tokens, mask = self.model.tokenizer.padded_tokens_and_mask(
        tokens, self.options['text_ctx']
    )

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = self.batch_size * 2
    uncond_tokens, uncond_mask = self.model.tokenizer.padded_tokens_and_mask(
        [], self.options['text_ctx']
    )

    # Pack the tokens together into model kwargs:
    # conditional entries first, unconditional entries second.
    model_kwargs = dict(
        tokens=th.tensor(
            [tokens] * self.batch_size + [uncond_tokens] * self.batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * self.batch_size + [uncond_mask] * self.batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the base model.
    self.model.del_cache()
    samples = self.diffusion.p_sample_loop(
        self.model_fn,
        (full_batch_size, 3, self.options["image_size"], self.options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:self.batch_size]
    self.model.del_cache()

    return samples

  def upsample(self, image, caption):
    """Upsample base-model samples `image` with the upsampler, conditioned on `caption`.

    `image` is quantised to 8-bit levels (values assumed to lie in [-1, 1] —
    TODO confirm against the sampler output) before being passed as the
    low-resolution conditioning. Returns the upsampled batch.
    """
    tokens = self.model_up.tokenizer.encode(caption)
    tokens, mask = self.model_up.tokenizer.padded_tokens_and_mask(
        tokens, self.options_up['text_ctx']
    )

    # Create the model conditioning dict.
    model_kwargs = dict(
        # Low-res image to upsample.
        low_res=((image+1)*127.5).round()/127.5 - 1,

        # Text tokens
        tokens=th.tensor(
            [tokens] * self.batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * self.batch_size,
            dtype=th.bool,
            device=device,
        ),
    )

    # Sample from the upsampler; the initial noise is scaled by upsample_temp.
    self.model_up.del_cache()
    up_shape = (self.batch_size, 3, self.options_up["image_size"], self.options_up["image_size"])
    up_samples = self.diffusion_up.ddim_sample_loop(
        self.model_up,
        up_shape,
        noise=th.randn(up_shape, device=device) * self.upsample_temp,
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:self.batch_size]
    self.model_up.del_cache()

    return up_samples


# Init GLIDE Fine-Tuned

In [49]:
configs_glide = ConfigsGLIDE()
# init() loads "glide_fine-tuning.pt" from the working directory by default;
# pass path="..." to load the fine-tuned checkpoint from another location.
configs_glide.init()

total base parameters 385030726


  0%|          | 0.00/1.59G [00:00<?, ?iB/s]

total upsampler parameters 398361286


# DALL·E Mini & Mega

In [None]:
# Install required libraries: dalle-mini (DalleBart model and processor) and
# vqgan-jax (VQGAN image decoder).
!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

[K     |████████████████████████████████| 197 kB 8.6 MB/s 
[K     |████████████████████████████████| 1.8 MB 91.4 MB/s 
[K     |████████████████████████████████| 175 kB 85.5 MB/s 
[K     |████████████████████████████████| 4.4 MB 58.8 MB/s 
[K     |████████████████████████████████| 235 kB 93.6 MB/s 
[K     |████████████████████████████████| 596 kB 77.3 MB/s 
[K     |████████████████████████████████| 145 kB 88.5 MB/s 
[K     |████████████████████████████████| 217 kB 81.2 MB/s 
[K     |████████████████████████████████| 51 kB 8.6 MB/s 
[K     |████████████████████████████████| 72 kB 705 kB/s 
[K     |████████████████████████████████| 6.6 MB 35.9 MB/s 
[K     |████████████████████████████████| 101 kB 13.3 MB/s 
[K     |████████████████████████████████| 181 kB 71.7 MB/s 
[K     |████████████████████████████████| 147 kB 55.6 MB/s 
[K     |████████████████████████████████| 63 kB 1.8 MB/s 
[?25h  Building wheel for emoji (setup.py) ... [?25l[?25hdone
  Building wheel for pathto

In [None]:
import numpy as np
from PIL import Image
from tqdm.notebook import trange

In [None]:
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel
from dalle_mini import DalleBartProcessor
import jax
from functools import partial
import jax.numpy as jnp
from functools import partial
from flax.training.common_utils import shard_prng_key
from flax.jax_utils import replicate
import random

In [None]:
# Number of JAX devices visible to this process; the pmap-ed functions below
# are mapped over these devices.
jax.local_device_count()

1

In [None]:
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6, 7))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale, model
):
    """pmap-ed wrapper around `model.generate`.

    `tokenized_prompt` is unpacked into keyword arguments of `generate`;
    `key` is passed as `prng_key`. The sampling options (`top_k`, `top_p`,
    `temperature`, `condition_scale`) and `model` are static, broadcast
    arguments of the pmap.
    """
    sampling_options = dict(
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
    return model.generate(**tokenized_prompt, **sampling_options)


# decode image
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(2,))
def p_decode(indices, params, vqgan):
    """pmap-ed wrapper around `vqgan.decode_code`; `vqgan` is a static argument."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded

In [None]:
class ConfigsDALLE():
  """DALL·E Mini/Mega model, its text processor and the VQGAN decoder.

  Call `init(model_type)` once to download and load the models, then
  `sample(n_predictions, captions)` to generate PIL images.
  """

  processor = None
  # Condition scale passed to the generator by sample();
  # change this attribute to tune it.
  cond_scale = 10.0

  model = None
  params = None

  vqgan = None
  vqgan_params = None

  def init(self, model_type = "dalle-mini/dalle-mini/mini-1:v0"):
    """Load the DALL·E model named by `model_type` and the VQGAN decoder.

    model_type: model reference (can be a wandb artifact, a 🤗 Hub id, a local
        folder or a google bucket), e.g.
        "dalle-mini/dalle-mini/mega-1-fp16:latest" for DALL·E Mega.
    """
    # Model references
    DALLE_MODEL = model_type
    DALLE_COMMIT_ID = None
    # VQGAN model
    VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
    VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

    self.processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

    # Load the DALL·E model; the parameters are returned separately (_do_init=False)
    self.model, params = DalleBart.from_pretrained(
        DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
    )
    # Load VQGAN
    self.vqgan, vqgan_params = VQModel.from_pretrained(
        VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
    )
    # Replicate the parameters for the pmap-ed inference functions
    self.params = replicate(params)
    self.vqgan_params = replicate(vqgan_params)

  def sample(self, n_predictions, captions):
    """Generate images for `captions` and return them as a list of PIL images.

    The generation loop runs max(n_predictions // jax.device_count(), 1) times;
    each pass generates for all captions, using a freshly seeded random key.
    """
    tokenized_prompt = replicate(self.processor(captions))

    # We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
    gen_top_k = None
    gen_top_p = None
    temperature = None

    # generate images
    images = []
    for i in trange(max(n_predictions // jax.device_count(), 1)):
        # create a new random key for this pass
        seed = random.randint(0, 2**32 - 1)
        key = jax.random.PRNGKey(seed)
        key, subkey = jax.random.split(key)
        # generate images
        encoded_images = p_generate(
            tokenized_prompt,
            shard_prng_key(subkey),
            self.params,
            gen_top_k,
            gen_top_p,
            temperature,
            self.cond_scale,
            self.model,
        )
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images
        decoded_images = p_decode(encoded_images, self.vqgan_params, self.vqgan)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for decoded_img in decoded_images:
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)

    return images

# Init DALL·E Mini and Mega

In [None]:
configs_dalle_mega = ConfigsDALLE()
# When Weights & Biases prompts for an API key, paste your own key.
# Never write the key into the notebook: anything stored here is exposed to
# every reader, and a key that has been shared this way should be revoked.
# dalle-mega
configs_dalle_mega.init('dalle-mini/dalle-mini/mega-1-fp16:latest')

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:41.3


Downloading:   0%|          | 0.00/34.2M [00:00<?, ?B/s]

[34m[1mwandb[0m: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:20.5
Some of the weights of DalleBart were initialized in float16 precision from the model checkpoint at /tmp/tmpi7ulvrjb:
[('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'k_proj', 'kernel')

Downloading:   0%|          | 0.00/434 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/290M [00:00<?, ?B/s]

In [None]:
configs_dalle_mini = ConfigsDALLE()
# Uses the default model reference of init() (DALL·E Mini). If Weights & Biases
# asks for an API key, enter your own at the prompt; do not store it in the notebook.
configs_dalle_mini.init()

[34m[1mwandb[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:18.7
[34m[1mwandb[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:10.3


# CLIP Guided Diffusion

In [40]:
# Install dependencies: CLIP and guided-diffusion are cloned and installed in
# editable mode, lpips comes from PyPI, and the unconditional 256x256
# diffusion checkpoint is downloaded into the working directory.
# Re-running this cell makes `git clone` report that the folders already exist.

!git clone https://github.com/openai/CLIP
!git clone https://github.com/crowsonkb/guided-diffusion
!pip install -e ./CLIP
!pip install -e ./guided-diffusion
!pip install lpips
!curl -OL https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt

fatal: destination path 'CLIP' already exists and is not an empty directory.
fatal: destination path 'guided-diffusion' already exists and is not an empty directory.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/CLIP
Installing collected packages: clip
  Attempting uninstall: clip
    Found existing installation: clip 1.0
    Can't uninstall 'clip'. No files were found to uninstall.
  Running setup.py develop for clip
Successfully installed clip-1.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/guided-diffusion
Installing collected packages: guided-diffusion
  Attempting uninstall: guided-diffusion
    Found existing installation: guided-diffusion 0.0.0
    Can't uninstall 'guided-diffusion'. No files were found to uninstall.
  Running setup.py develop for guided-diffusion
Successfully installed guided-diffusion-0.0.0
Looking in i

In [54]:
# Imports

import gc
import io
import math
import sys

from IPython import display
import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')

import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults

In [55]:
class MakeCutouts(nn.Module):
    """Cut `cutn` random square patches out of an image batch.

    Every patch is resized to `cut_size` x `cut_size` by adaptive average
    pooling, and all patches are concatenated along the batch axis.
    `cut_pow` is the exponent applied to the uniform random number that
    picks each patch size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)

        def random_cutout():
            # Draw order matters for reproducibility: size, then x, then y.
            side = int(torch.rand([])**self.cut_pow * (largest - smallest) + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            patch = input[:, :, top:top + side, left:left + side]
            return F.adaptive_avg_pool2d(patch, self.cut_size)

        return torch.cat([random_cutout() for _ in range(self.cutn)])

In [56]:
class ConfigsCLIPGuidedDiffusion():
  """CLIP-guided diffusion: an unconditional diffusion model steered by CLIP.

  Call `init()` once to load the diffusion, CLIP and LPIPS models, then
  `sample(caption)` to generate an image for a text prompt.
  """

  # Model settings

  model_config = model_and_diffusion_defaults()
  model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 100,
        'rescale_timesteps': True,
        'timestep_respacing': '100',  # Modify this value to decrease the number of
                                      # timesteps.
        'image_size': 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_checkpoint': False,
        'use_fp16': True,
        'use_scale_shift_norm': True,
  })
  prompts = []
  image_prompts = []
  batch_size = 1
  clip_guidance_scale = 1000  # Controls how much the image should look like the prompt.
  tv_scale = 150              # Controls the smoothness of the final output.
  range_scale = 50            # Controls how far out of range RGB values are allowed to be.
  cutn = 16
  n_batches = 1
  init_image = None   # This can be an URL or Colab local path and must be in quotes.
  skip_timesteps = 0  # This needs to be between approx. 200 and 500 when using an init image.
                      # Higher values make the output look more like the init.
  init_scale = 0      # This enhances the effect of the init image, a good value is 1000.
  seed = 0

  def init(self, path="256x256_diffusion_uncond.pt"):
    """Load the diffusion model from `path`, plus CLIP (ViT-B/16) and LPIPS (VGG).

    All models are moved to `self.device` (first CUDA device when available,
    otherwise the CPU).
    """
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', self.device)

    self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
    self.model.load_state_dict(torch.load(path, map_location='cpu'))
    self.model.requires_grad_(False).eval().to(self.device)
    if self.model_config['use_fp16']:
        self.model.convert_to_fp16()

    self.clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(self.device)
    self.clip_size = self.clip_model.visual.input_resolution
    self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                    std=[0.26862954, 0.26130258, 0.27577711])
    self.lpips_model = lpips.LPIPS(net='vgg').to(self.device)

  def fetch(self, url_or_path):
    """Return a binary file object for a local path or an http(s) URL.

    For URLs the content is downloaded fully into memory; HTTP errors raise.
    For local paths the caller owns (and should close) the returned file.
    """
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')

  def parse_prompt(self, prompt):
    """Split a prompt of the form "text:weight" into (text, weight).

    The weight defaults to 1 when absent. For http(s) URLs the colon of the
    scheme is not treated as a separator. When the part after the last colon
    is not a number (e.g. "dune: part two"), the whole prompt is returned
    with weight 1 instead of raising, because captions come from free-form
    user input. A numeric suffix is still read as a weight ("ratio 16:9"
    parses as ("ratio 16", 9.0)).
    """
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', '1'][len(vals):]
    try:
        return vals[0], float(vals[1])
    except ValueError:
        return prompt, 1.0

  def tv_loss(self, input):
    """L2 total variation loss, as in Mahendran et al."""
    input = F.pad(input, (0, 1, 0, 1), 'replicate')
    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
    return (x_diff**2 + y_diff**2).mean([1, 2, 3])

  def range_loss(self, input):
    """Mean squared distance of `input` from its clamp to [-1, 1], per sample."""
    return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])

  def spherical_dist_loss(self, x, y):
    """Squared spherical distance between the L2-normalised `x` and `y` (last axis)."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

  def cond_fn(self, x, t, out, y=None):
    """Guidance gradient passed to the diffusion sampler.

    Combines the weighted CLIP distance between cutouts of the current
    estimate and the target embeddings with the total-variation and range
    losses, and returns the negative gradient of that loss w.r.t. `x`.
    Relies on `self.cur_t`, `self.make_cutouts`, `self.target_embeds` and
    `self.weights`, which `sample()` sets up.
    """
    n = x.shape[0]
    fac = self.diffusion.sqrt_one_minus_alphas_cumprod[self.cur_t]
    # Blend the predicted clean image with the current noisy sample
    x_in = out['pred_xstart'] * fac + x * (1 - fac)
    clip_in = self.normalize(self.make_cutouts(x_in.add(1).div(2)))
    image_embeds = self.clip_model.encode_image(clip_in).float()
    dists = self.spherical_dist_loss(image_embeds.unsqueeze(1), self.target_embeds.unsqueeze(0))
    dists = dists.view([self.cutn, n, -1])
    losses = dists.mul(self.weights).sum(2).mean(0)
    tv_losses = self.tv_loss(x_in)
    range_losses = self.range_loss(out['pred_xstart'])
    loss = losses.sum() * self.clip_guidance_scale + tv_losses.sum() * self.tv_scale + range_losses.sum() * self.range_scale
    return -torch.autograd.grad(loss, x)[0]

  def sample(self, caption='a house'):
    """Generate an image for `caption` and return it as a tensor with values in [0, 1].

    The returned tensor is the last image of the last batch (`pred_xstart`
    mapped from [-1, 1] to [0, 1] and clamped). With the default
    `seed = 0` the torch RNG is re-seeded on every call.
    Raises RuntimeError when the prompt weights sum to (almost) zero.
    """
    self.prompts = [caption]
    if self.seed is not None:
        torch.manual_seed(self.seed)
    self.make_cutouts = MakeCutouts(self.clip_size, self.cutn)
    side_x = side_y = self.model_config['image_size']

    self.target_embeds, self.weights = [], []

    # Text prompts: one embedding and one weight each
    for prompt in self.prompts:
        txt, weight = self.parse_prompt(prompt)
        self.target_embeds.append(self.clip_model.encode_text(clip.tokenize(txt).to(self.device)).float())
        self.weights.append(weight)
    # Image prompts: one embedding per cutout, the weight is spread over the cutouts
    for prompt in self.image_prompts:
        path, weight = self.parse_prompt(prompt)
        img = Image.open(self.fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
        batch = self.make_cutouts(TF.to_tensor(img).unsqueeze(0).to(self.device))
        embed = self.clip_model.encode_image(self.normalize(batch)).float()
        self.target_embeds.append(embed)
        self.weights.extend([weight / self.cutn] * self.cutn)

    self.target_embeds = torch.cat(self.target_embeds)
    self.weights = torch.tensor(self.weights, device=self.device)
    if self.weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    self.weights /= self.weights.sum().abs()

    init = None
    if self.init_image is not None:
        init = Image.open(self.fetch(self.init_image)).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        init = TF.to_tensor(init).to(self.device).unsqueeze(0).mul(2).sub(1)

    self.cur_t = None

    if self.model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = self.diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = self.diffusion.p_sample_loop_progressive

    result = None
    for i in range(self.n_batches):
        # cond_fn reads cur_t; it is decremented once per yielded step
        self.cur_t = self.diffusion.num_timesteps - self.skip_timesteps - 1

        samples = sample_fn(
            self.model,
            (self.batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=self.cond_fn,
            progress=True,
            skip_timesteps=self.skip_timesteps,
            init_image=init,
            randomize_class=True,
            cond_fn_with_grad=True,
        )

        for j, sample in enumerate(samples):
            self.cur_t -= 1
            # Keep the estimate every 100 steps and at the final step
            if j % 100 == 0 or self.cur_t == -1:
                for image in sample['pred_xstart']:
                    result = image.add(1).div(2).clamp(0, 1)

    return result

# Init CLIP-Guided Diffusion

In [57]:
configs_clip_guided = ConfigsCLIPGuidedDiffusion()
# init() loads "256x256_diffusion_uncond.pt" from the working directory by
# default; pass path="..." to load the diffusion checkpoint from elsewhere.
configs_clip_guided.init()

Using device: cuda
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


Loading model from: /usr/local/lib/python3.7/dist-packages/lpips/weights/v0.1/vgg.pth


# Gradio

In [50]:
# Gradio provides the web UI built in the next cell.
!pip install gradio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gradio
  Downloading gradio-3.1.1-py3-none-any.whl (5.6 MB)
[K     |████████████████████████████████| 5.6 MB 17.0 MB/s 
Collecting ffmpy
  Downloading ffmpy-0.3.0.tar.gz (4.8 kB)
Collecting httpx
  Downloading httpx-0.23.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 2.3 MB/s 
[?25hCollecting python-multipart
  Downloading python-multipart-0.0.5.tar.gz (32 kB)
Collecting orjson
  Downloading orjson-3.7.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (272 kB)
[K     |████████████████████████████████| 272 kB 88.4 MB/s 
Collecting analytics-python
  Downloading analytics_python-1.4.0-py2.py3-none-any.whl (15 kB)
Collecting paramiko
  Downloading paramiko-2.11.0-py2.py3-none-any.whl (212 kB)
[K     |████████████████████████████████| 212 kB 68.3 MB/s 
[?25hCollecting pycryptodome
  Downloading pycryptodome-3.15.0-cp35-abi3-manylinux2010_

In [58]:
import gradio as gr
import numpy as np
import torch as th
# Heading and explanatory text shown at the top of the Gradio interface.
title = "Image generator from caption"
description = "Use one of the models for image generation from a text caption."


def fn(model_choice, text, text2):
  """Gradio handler: generate images with the chosen model.

  model_choice: one of the model names offered in the dropdown.
  text, text2: the two captions. For the simple diffusion model they are
      digit labels (0-9); anything else makes that model use random labels.

  Returns a list of images for the gallery.
  Raises ValueError when `model_choice` is not a known model name
  (e.g. nothing was selected in the dropdown).
  """

  def diffusionmodel(captions):
    labels = []
    first, second = captions[0].strip(), captions[1].strip()
    # isdecimal() only accepts characters that int() can parse
    if first.isdecimal() and second.isdecimal():
      n1, n2 = int(first), int(second)
      if 0 <= n1 < 10 and 0 <= n2 < 10:
        labels = [n1, n2]
    # An empty label list makes Configs.sample pick random labels
    samples = configs.sample(labels)
    images = []
    for sample in samples:
      lo, hi = np.min(sample), np.max(sample)
      span = hi - lo
      # Min-max normalise to [0, 1]; a constant image maps to all zeros
      scaled = (sample - lo) / span if span > 0 else np.zeros_like(sample)
      images.append(scaled[0])
    return images

  def glide(captions):
    images = []
    for caption in captions:
      image = configs_glide.sample(caption)
      image_upsampled = configs_glide.upsample(image, caption)
      scaled = ((image_upsampled + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
      # Lay the batch out side by side as one (H, batch*W, 3) image
      reshaped = scaled.permute(2, 0, 3, 1).reshape([image_upsampled.shape[2], -1, 3])
      images.append(Image.fromarray(reshaped.numpy()))
    return images

  def dalle_mini(captions):
    return configs_dalle_mini.sample(1, captions)

  def dalle_mega(captions):
    return configs_dalle_mega.sample(1, captions)

  def clip_diff(captions):
    images = []
    for caption in captions:
      image = configs_clip_guided.sample(caption)
      images.append(TF.to_pil_image(image.cpu()))
    return images

  generators = {
    'Simple Denoising Diffusion Model': diffusionmodel,
    'GLIDE Fine-tuned': glide,
    'DALL·E Mini': dalle_mini,
    'DALL·E Mega': dalle_mega,
    'CLIP Diffusion Model': clip_diff
    }
  if model_choice not in generators:
    raise ValueError(
        f"Unknown model {model_choice!r}; choose one of: {', '.join(generators)}"
    )
  return generators[model_choice]([text, text2])
# CLIP guided diffusion; add the 2 DALL·E versions
# Build and launch the UI. gr.Dropdown replaces the deprecated
# gr.inputs.Dropdown; the choices must match the model names handled by fn.
gr.Interface(
    fn,
    [
        gr.Dropdown(["Simple Denoising Diffusion Model", "GLIDE Fine-tuned", "DALL·E Mini", "DALL·E Mega", "CLIP Diffusion Model"]),
        "text",
        "text",
    ],
    outputs=[gr.Gallery(label="Generated Images")],
    title=title,
    description=description,
).launch(debug = True)

  "Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components",


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://57098.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces


  0%|          | 0/100 [00:00<?, ?it/s]





  0%|          | 0/100 [00:00<?, ?it/s]



Keyboard interruption in main thread... closing server.


(<gradio.routes.App at 0x7fa06cfe7b50>,
 'http://127.0.0.1:7860/',
 'https://57098.gradio.app')