<a href="https://colab.research.google.com/github/hypereikon/modelos_ml/blob/main/VQGANxCLIP_ESP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VQGANxCLIP

Image synthesis, text-to-image.

#### Code by [@aicrumb](https://github.com/aicrumb)
Spanish translation by [@hypereikon](https://www.instagram.com/hypereikon/)


### Setup

In [None]:
#@title Install the required dependencies (only needed once per session)
!git clone https://github.com/openai/CLIP
!git clone https://github.com/CompVis/taming-transformers
!pip install ftfy regex tqdm omegaconf pytorch-lightning
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_1024.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_1024.ckpt
!curl -L 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_imagenet_f16_16384.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_imagenet_f16_16384.ckpt

### Generation

Arguments:

* `model`: the "dictionary" (VQGAN codebook, with 16384 or 1024 entries) to use. Each one gives different results; I recommend trying both with the same prompt.
* `input_url`: link to an initial image, optional.
* `uploaded_img`: Colab path to an initial image, optional. If set, it is used instead of `input_url`.
* `seed`: changes the result. I recommend playing with it while keeping the same prompt.
* `macro`: the prompt the model will aim for. You can mix phrases with commas: "phrase1,phrase2" (the whole string is encoded as a single CLIP prompt). It is listed twice in the prompt list, so it weighs twice as much as `micro`.
* `micro`: additional "details" for the prompt.
* `width`: width in pixels (max 700 px; fewer pixels give more *coherence*). It is rounded down to a multiple of 16.
* `height`: height in pixels (max 700 px; fewer pixels give more *coherence*). It is rounded down to a multiple of 16.
* `penalize`: phrases to subtract. If concepts you don't want show up, you can remove them with this. Turn it on or off with `penalize_text`.
* `step_size`: how much the image changes at each step (the Adam learning rate). Higher values make it change *faster*.
* `steps`: total number of optimization steps (a frame is saved every 10 steps). More steps give more detail.
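The generation cell computes `toksX, toksY = args.size[0] // 16, args.size[1] // 16` because the f16 VQGAN turns each 16×16 pixel patch into one latent token, so the output size is rounded down to a multiple of 16. A minimal sketch of that arithmetic in plain Python; the helper name `output_size` is hypothetical and not part of the notebook:

```python
def output_size(width: int, height: int, factor: int = 16) -> tuple[int, int, int]:
    """Return (out_width, out_height, n_tokens) for a VQGAN with downsampling `factor`.

    Mirrors `toksX, toksY = width // 16, height // 16` in the generation cell.
    """
    toks_x, toks_y = width // factor, height // factor
    return toks_x * factor, toks_y * factor, toks_x * toks_y


print(output_size(256, 256))  # (256, 256, 256)
print(output_size(700, 500))  # (688, 496, 1333)
```

Picking a `width` and `height` that are already multiples of 16 avoids silently losing the extra pixels; anything under 16 px gives zero tokens.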

If you don't use `input_url` or `uploaded_img`, leave them blank.
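Each prompt becomes a `Prompt` module whose loss is the weighted spherical distance between CLIP embeddings, `2 * arcsin(|a - b| / 2) ** 2` on the normalized vectors (averaged over the `cutn` cutouts). `macro` appears twice in `prompts`, and `penalize` is added with its weight negated, so getting closer to it raises the loss. The sketch below reproduces that arithmetic with plain Python lists instead of torch tensors and a single image embedding instead of 64 cutouts. The function names and the toy 2-D vectors are hypothetical (ViT-B/32 embeddings have 512 dimensions):

```python
import math


def unit(v):
    """Scale a vector to unit length, like F.normalize in the notebook."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spherical_dist(a, b):
    """2 * asin(|a_n - b_n| / 2) ** 2: half the squared angle between a and b."""
    a_n, b_n = unit(a), unit(b)
    chord = math.sqrt(sum((x - y) ** 2 for x, y in zip(a_n, b_n)))
    # clamp guards against rounding pushing chord / 2 slightly above 1
    return 2 * math.asin(min(chord / 2, 1.0)) ** 2


def total_loss(image_emb, weighted_texts):
    """Sum of weight * distance; a negative weight rewards moving away from that text."""
    return sum(w * spherical_dist(image_emb, t) for t, w in weighted_texts)


# toy 2-D "embeddings" (hypothetical values)
emb_macro, emb_micro, emb_penalize = [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]
# macro counted twice, as in prompts=[macro, macro, micro]; penalize with negated weight
weighted = [(emb_macro, 1.0), (emb_macro, 1.0), (emb_micro, 1.0), (emb_penalize, -1.0)]

image_like_macro = [1.0, 0.0]
image_like_penalize = [-1.0, 0.0]
print(total_loss(image_like_macro, weighted))     # lower: close to macro, far from penalize
print(total_loss(image_like_penalize, weighted))  # higher
```

The optimizer lowers this total, so the image drifts toward `macro` (and `micro`) and away from `penalize`.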

#### This model accepts prompts in English, Spanish, etc., or a mix of languages. English gives the best results ٩(˘◡˘)۶


In [None]:
import argparse
import math
from pathlib import Path
import sys
model = "16384" #@param ["16384", "1024"] {type:"string"}
sys.path.append('./taming-transformers')

from IPython import display
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

from CLIP import clip

def sinc(x):
    # normalized sinc, sin(pi x) / (pi x), with sinc(0) = 1
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
    # Lanczos kernel sinc(x) * sinc(x / a) on (-a, a), normalized to sum to 1
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
    return out / out.sum()


def ramp(ratio, width):
    # symmetric sample positions spaced by `ratio`, strictly inside (-width, width)
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]


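def resample(input, size, align_corners=True):
    # antialiased resize: Lanczos low-pass filter along each shrinking axis, then bicubic interpolation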
    n, c, h, w = input.shape
    dh, dw = size

    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
    

def replace_grad(fake, real):
    # forward pass returns the value of `fake`; backward pass uses the gradient of `real`
    return fake.detach() - real.detach() + real


class Prompt(nn.Module):
    # weighted spherical distance between the cutout embeddings and one text embedding
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()


def parse_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])


class MakeCutouts(nn.Module):
    # takes `cutn` random square crops and resizes each to CLIP's input resolution
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)).clamp(0, 1))
        return torch.cat(cutouts, dim=0)


def load_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model

input_url = "" #@param {"type":"string"}
uploaded_img = "" #@param {"type":"string"}
input_image = None
if input_url != "":
    # download the image from the URL
    !wget "$input_url" -O "/content/input_img.jpg" -q
    input_image = "/content/input_img.jpg"
if uploaded_img != "":
    # an image already in the Colab file system takes priority over input_url
    input_image = uploaded_img
seed = 0 #@param
macro = "a banana doing standup" #@param {"type":"string"}
micro = "comedy" #@param {"type":"string"}
width = 256 #@param
height = 256 #@param
step_size = 0.06 #@param
args = argparse.Namespace(
    prompts=[macro, macro, micro],  # macro appears twice, so it weighs twice as much as micro
    size=[width, height],
    init_image=input_image,
    init_weight=0.,
    clip_model='ViT-B/32',
    vqgan_config='vqgan_imagenet_f16_{}.yaml'.format(model),
    vqgan_checkpoint='vqgan_imagenet_f16_{}.ckpt'.format(model),
    step_size=step_size,
    cutn=64,
    cut_pow=1.,
    display_freq=10,
    seed=seed
)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

cut_size = perceptor.visual.input_resolution
e_dim = model.quantize.e_dim
make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
n_toks = model.quantize.n_e
toksX, toksY = args.size[0] // 16, args.size[1] // 16  # the f16 VQGAN maps each 16x16 pixel patch to one token

torch.manual_seed(args.seed)

if args.init_image:
    pil_image = Image.open(args.init_image).convert('RGB')
    pil_image = pil_image.resize((toksX * 16, toksY * 16), Image.LANCZOS)
    z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
else:
    # no init image: start from a grid of randomly chosen codebook vectors
    one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
    z = one_hot.matmul(model.quantize.embedding.weight)
    z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
z_orig = z.clone()
z.requires_grad_(True)
opt = optim.Adam([z], lr=args.step_size)

normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

pMs = []

for prompt in args.prompts:
    txt, weight, stop = parse_prompt(prompt)
    embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

penalize_text = True #@param {"type":"boolean"}
penalize = "graffiti, text,sad,confused" #@param {"type":"string"}

if(penalize_text):
    txt, weight, stop = parse_prompt(penalize)
    embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
    pMs.append(Prompt(embed, -weight, stop).to(device))


def synth(z):
    z_q, *_ = model.quantize(z)
    return model.decode(z_q).add(1).div(2).clamp(0, 1)
!mkdir -p frames
@torch.no_grad()
def checkin(i, losses):
    losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
    tqdm.write(f'Step: {i} | Avg: {sum(losses).item():g} | Losses: {losses_str}')
    out = synth(z)
    a = TF.to_pil_image(out[0].cpu())
    a.save('progress.png')
    a.save('frames/{:04d}.png'.format(i // args.display_freq))
    display.display(display.Image('progress.png'))

def ascend_txt():
    out = synth(z)
    # add a little Gaussian noise (on the same device as `out`), then embed the cutouts with CLIP
    iii = perceptor.encode_image(normalize(make_cutouts(out + torch.randn_like(out) * 0.1))).float()

    result = []

    if args.init_weight:
        result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)

    for prompt in pMs:
        result.append(prompt(iii))

    return result

def train(i):
    opt.zero_grad()
    lossAll = ascend_txt()
    if i % args.display_freq == 0:
        checkin(i, lossAll)
    loss = sum(lossAll)
    loss.backward()
    opt.step()
steps = 1000#@param
i = 0
try:
    with tqdm() as pbar:
        while i<steps:
            train(i)
            i += 1
            pbar.update()
except KeyboardInterrupt:
    pass
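Each prompt string (`macro`, `micro`, and `penalize`) goes through `parse_prompt`, which splits off up to two trailing `:`-separated fields, `text:weight:stop`. The weight defaults to 1 and `stop` to `-inf` (disabled). As a result, a colon inside a prompt is read as a field separator, and text after it that is not a number raises an error. The function below is copied from the generation cell so it can be tried on its own; the example strings are hypothetical:

```python
def parse_prompt(prompt):
    """Split 'text:weight:stop' into (text, weight, stop); copied from the notebook."""
    vals = prompt.rsplit(':', 2)
    # pad missing fields with the defaults: weight 1, stop -inf
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])


print(parse_prompt("a banana doing standup"))  # ('a banana doing standup', 1.0, -inf)
print(parse_prompt("comedy:0.5"))              # ('comedy', 0.5, -inf)
print(parse_prompt("graffiti, text,sad:2"))    # ('graffiti, text,sad', 2.0, -inf)

try:
    parse_prompt("poster 16:9 format")
except ValueError as err:
    print("colon read as a weight:", err)
```

For `penalize`, the notebook negates the parsed weight (`Prompt(embed, -weight, stop)`), so `graffiti:2` pushes away from that concept twice as hard.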

### Create the video and delete the images to generate again.

In [None]:
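#@title create video
# -y overwrites an existing video.mp4; -pix_fmt yuv420p keeps the H.264 video playable in browsers
!ffmpeg -y -i "/content/frames/%04d.png" -pix_fmt yuv420p /content/video.mp4
from IPython.display import HTML
from base64 import b64encode
with open('/content/video.mp4', 'rb') as f:
    mp4 = f.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()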
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)
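`checkin` saves a frame on every step where `i % display_freq == 0` (with `display_freq=10`), and ffmpeg's image-sequence input defaults to 25 frames per second when no `-framerate` is given. A small sketch of the resulting frame count and video length; the helper name `video_length` is hypothetical:

```python
def video_length(steps: int, display_freq: int = 10, fps: float = 25.0) -> tuple[int, float]:
    """Return (frames, seconds) for the frames saved during a run of `steps` steps."""
    # frames are saved for i = 0, display_freq, 2 * display_freq, ... while i < steps
    frames = (steps + display_freq - 1) // display_freq
    return frames, frames / fps


print(video_length(1000))          # (100, 4.0)
print(video_length(1000, fps=10))  # (100, 10.0)
```

For a slower video, an input frame rate can be passed before `-i`, for example `-framerate 10`.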

In [None]:
#@title delete frames to generate again
%cd /content/frames/
!rm *.png
%cd /content

Do not use this notebook for commercial purposes (NFTs). (◔◡◔)

If you find this tool useful, you can support Chloe by buying her [NFTs](https://www.hicetnunc.xyz/tz/tz1Ss4GU4bmhZKPxWYCLWJAzaLeySuBveY1N)