<a href="https://colab.research.google.com/github/jaskooner/voganclip-eleutherAI/blob/main/Semantic_Style_Transfer_with_CLIP%2BVQGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Zero-shot semantic style transfer
Equal contribution:


By 

Katherine Crowson (https://twitter.com/RiversHaveWings)

Louis Castricato (https://twitter.com/lcastricato)

Nev (https://twitter.com/apeoffire)

Jbustter (https://twitter.com/jbusted1) 

Theodore (https://twitter.com/TheodoreGalanos)

... and all of our friends at EleutherAI!

The "Business End" section below is the area of interest. Masking is performed via a logit-lens technique; optimization is performed via a spherical-geodesic + reweighting technique. Current limitations are mostly due to the restrictions imposed by our small CLIP model, as well as various tweaks needed for masking. Masking is a bit finicky. We will update this notebook in due time.

When performing interactive editing, you'll need to keep re-uploading the output to imgur. We wanted interactive editing to be stateless, so the latent z values are not preserved between runs. This in turn allows for a more interactive experience than, say, StyleCLIP.

We also have support for custom GANs. 

VQGAN Wikiart 16k: http://eaidata.bmk.sh/data/Wikiart_16384/ which was trained specifically for this project


# Version Number: 1.1

Added: support for custom GANs; dynamically scaled masking.

Coming soon: improved masking and a better default GAN.

Coming "soon": whitepaper.

Coming less "soon": fatter CLIP text encoder. Currently WIP

In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
!nvidia-smi

In [None]:
!git clone https://github.com/openai/CLIP
!git clone https://github.com/CompVis/taming-transformers
!pip install ftfy regex tqdm omegaconf pytorch-lightning kornia madgrad einops

In [None]:
!curl -L 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' > vqgan_gumbel_f8_8192.yaml
!curl -L 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' > vqgan_gumbel_f8_8192.ckpt

!curl -L 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' > config.yaml
!curl -L 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' > wikiart.ckpt

In [None]:
import math
import sys

from IPython import display
from omegaconf import OmegaConf
from PIL import Image
import requests
import torch
from torch import nn, optim
import madgrad
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

sys.path.append('./taming-transformers')

from CLIP import clip
from taming.models import cond_transformer, vqgan
import kornia.augmentation as K

# Masking

In [None]:

class BoxCropper:
    """Draw random square boxes from a PIL-style image, together with a pixel mask.

    ``w`` and ``h`` are the maximum crop size as fractions of the image width and
    height. Callers may change them between draws (``sample`` anneals them).
    Crops are always square; the side length is derived from ``w``. ``h`` is kept
    for interface compatibility.
    """

    def __init__(self, w=0.3, h=0.3):
        self.w, self.h = w, h

    def sample(self, source):
        """Return ``(crop, mask)`` for one random square box inside ``source``.

        The side is drawn uniformly from ``[max_w // 2, max_w]``, where
        ``max_w = int(source.width * self.w)``. It is then clamped so the square
        fits inside the image and is at least one pixel. Without the clamp, wide
        images produce boxes taller than the image and ``randint`` gets a negative
        range.

        Args:
            source: image exposing ``width``, ``height``, ``size`` (w, h) and
                ``crop(box)`` (e.g. a ``PIL.Image``).

        Returns:
            ``(crop, mask)``. ``crop`` is ``source.crop(box)``. ``mask`` is a
            ``(height, width)`` float tensor: 1 inside the box, 0 elsewhere.
        """
        max_w = int(source.width * self.w)
        side = torch.randint(max_w // 2, max_w + 1, []).item()
        side = max(1, min(side, source.width, source.height))

        x1 = torch.randint(0, source.width - side + 1, []).item()
        y1 = torch.randint(0, source.height - side + 1, []).item()
        x2, y2 = x1 + side, y1 + side

        mask = torch.zeros([source.size[1], source.size[0]])
        mask[y1:y2, x1:x2] = 1.
        return source.crop((x1, y1, x2, y2)), mask

def sample(source, sampler, model, preprocess, n=64000, batch_size=128):
    """Yield ``(embedding, mask)`` pairs for ``n`` random crops of ``source``.

    Crops are drawn by ``sampler`` in batches of ``batch_size`` (the last batch
    may be smaller). Before each batch, the sampler's ``w``/``h`` crop fraction
    is annealed linearly from 0.4 (first batch) toward 0.1, so early crops are
    coarse and later ones fine. Each batch is preprocessed and encoded with
    ``model.encode_image``. Gradients are disabled only around that forward pass,
    not across ``yield``, so the consumer's autograd mode is left untouched.
    ``model`` is switched to eval mode.

    Args:
        source: image passed to ``sampler.sample``.
        sampler: object with mutable ``w``/``h`` attributes and a
            ``sample(source) -> (crop, mask)`` method (e.g. ``BoxCropper``).
        model: image encoder with ``encode_image`` and ``parameters()``.
        preprocess: transform turning a crop into a CHW tensor.
        n: total number of samples to yield (<= 0 yields nothing).
        batch_size: crops encoded per forward pass.

    Yields:
        ``(embedding, mask)``: one CPU embedding row and the crop's mask.
    """
    n_batches = -(-n // batch_size)  # ceil(n / batch_size)
    device = next(model.parameters()).device
    remaining = n

    model.eval()
    for step in tqdm(range(n_batches)):
        progress = float(step) / float(n_batches)
        crop_frac = 0.4 * (1 - progress) + 0.1 * progress
        sampler.w = crop_frac
        sampler.h = crop_frac

        current = min(batch_size, remaining)
        remaining -= current

        crops, masks = [], []
        for _ in range(current):
            crop, mask = sampler.sample(source)
            crops.append(preprocess(crop).unsqueeze(0))
            masks.append(mask)

        with torch.no_grad():
            batch = torch.cat(crops, dim=0).to(device)
            embeddings = model.encode_image(batch).cpu().detach()
        # Yield outside no_grad so grad mode does not leak into the caller.
        yield from zip(embeddings, masks)

In [None]:
def aggregate(samples, labels, model):
    """Build one soft segmentation mask per text label from crop embeddings.

    Each ``(embedding, mask)`` pair is scored against every label as
    ``exp(logit_scale * cos(image, text))`` and added onto the pixels its mask
    covers. A pixel's value is the mean score of the crops covering it. Both
    accumulators start at 1, so uncovered pixels do not divide by zero. The map
    is then min-max normalised to [0, 1] and squared to sharpen contrast.

    ``samples`` is iterated exactly once and all labels are scored in that pass,
    so a one-shot generator (such as the one ``sample`` returns) works with any
    number of labels.

    Args:
        samples: iterable of ``(embedding, mask)`` pairs.
        labels: list of text prompts.
        model: CLIP model providing ``encode_text`` and ``logit_scale``.

    Returns:
        List of ``(mask, label)`` tuples in ``labels`` order.

    Raises:
        ValueError: if ``samples`` yields nothing.
    """
    texts = clip.tokenize(labels).to(device)
    with torch.no_grad():
        text_embeddings = model.encode_text(texts).cpu().float()
        logit_scale = model.logit_scale.exp().cpu().float()
    text_features = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    pixel_sums = None
    samples_per_pixel = None
    for embedding, mask in samples:
        if pixel_sums is None:
            pixel_sums = [torch.ones_like(mask) for _ in labels]
            samples_per_pixel = torch.ones_like(mask)
        # Compute in float32: half-precision ops may be unavailable on CPU, and
        # exp() of CLIP logits overflows fp16.
        image_features = embedding.float()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        scores = (logit_scale * image_features @ text_features.t()).exp()
        for pixel_sum, score in zip(pixel_sums, scores.tolist()):
            pixel_sum += mask * score
        samples_per_pixel += mask

    if pixel_sums is None:
        raise ValueError('aggregate() needs at least one sample')

    masks = []
    for label, pixel_sum in zip(labels, pixel_sums):
        img = pixel_sum / samples_per_pixel
        low, high = img.min(), img.max()
        span = high - low
        # True min-max normalisation; a flat map becomes all zeros.
        img = (img - low) / span if span > 0 else torch.zeros_like(img)
        masks.append((img ** 2, label))
    return masks

In [None]:
def visualise(source, masks):
    """Display ``source`` weighted by each mask in ``masks``.

    ``masks`` holds ``(mask, label)`` pairs. Each masked image is written to
    ``mask_temp.png`` and shown inline, one per pair.
    """
    pixels = TF.to_tensor(source)
    for weights, _label in masks:
        weighted = pixels * weights[None]
        TF.to_pil_image(weighted).save('mask_temp.png')
        display.display(display.Image('mask_temp.png'))

def save(masks):
    """Return the first mask in ``masks`` as a ``1×H×W`` tensor.

    ``masks`` is a list of ``(mask, label)`` pairs. Only the first entry is used,
    and nothing is written to disk despite the name.

    Raises:
        IndexError: if ``masks`` is empty.
    """
    img, _label = masks[0]
    # Canvas of ones shaped like the mask; img[None] adds a leading channel axis.
    return torch.ones_like(img) * img[None]


# Generator

In [None]:
def sinc(x):
    """Normalised sinc, ``sin(pi*x) / (pi*x)``, with the removable singularity at 0 set to 1."""
    scaled = x * math.pi
    ratio = torch.sin(scaled) / scaled
    return torch.where(x == 0, x.new_ones([]), ratio)


def lanczos(x, a):
    """Evaluate the Lanczos-``a`` window at positions ``x`` and normalise it to sum to 1.

    Positions outside the open interval ``(-a, a)`` get weight 0.
    """
    inside = torch.logical_and(x > -a, x < a)
    window = sinc(x) * sinc(x / a)
    kernel = torch.where(inside, window, x.new_zeros([]))
    return kernel / kernel.sum()


def ramp(ratio, width):
    """Return the offsets ``k * ratio`` for ``|k| <= count - 2``, in ascending order.

    Here ``count = ceil(width / ratio + 1)``. The non-negative offsets are built
    by repeated addition, mirrored around 0 (without duplicating 0), and the
    outermost value on each side is dropped.
    """
    count = math.ceil(width / ratio + 1)
    offsets = []
    position = 0
    while len(offsets) < count:
        offsets.append(position)
        position += ratio
    half = torch.tensor(offsets, dtype=torch.get_default_dtype())
    mirrored = torch.cat([-half[1:].flip([0]), half])
    return mirrored[1:-1]


def resample(input, size, align_corners=True):
    """Resize an ``(n, c, h, w)`` batch to ``size`` = ``(h', w')`` with anti-aliasing.

    Along each axis that is being shrunk, the image is first low-pass filtered
    with a separable Lanczos-2 kernel (reflect padding). Bicubic interpolation to
    the target size follows.
    """
    batch, channels, height, width = input.shape
    target_h, target_w = size

    # Treat every channel as an independent single-channel image.
    planes = input.view([batch * channels, 1, height, width])

    if target_h < height:
        taps = lanczos(ramp(target_h / height, 2), 2).to(planes.device, planes.dtype)
        margin = (taps.shape[0] - 1) // 2
        planes = F.pad(planes, (0, 0, margin, margin), 'reflect')
        planes = F.conv2d(planes, taps[None, None, :, None])

    if target_w < width:
        taps = lanczos(ramp(target_w / width, 2), 2).to(planes.device, planes.dtype)
        margin = (taps.shape[0] - 1) // 2
        planes = F.pad(planes, (margin, margin, 0, 0), 'reflect')
        planes = F.conv2d(planes, taps[None, None, None, :])

    # Padding by half the kernel width keeps the spatial size; this view relies on it.
    planes = planes.view([batch, channels, height, width])
    return F.interpolate(planes, size, mode='bicubic', align_corners=align_corners)


class ReplaceGrad(torch.autograd.Function):
    """Use one tensor's value in the forward pass and send the gradient to another.

    ``replace_grad(x_forward, x_backward)`` evaluates to ``x_forward``. In the
    backward pass, the incoming gradient is summed down to ``x_backward``'s shape
    and passed entirely to ``x_backward``; ``x_forward`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.target_shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        routed = grad_in.sum_to_size(ctx.target_shape)
        return None, routed


replace_grad = ReplaceGrad.apply


class ClampWithGrad(torch.autograd.Function):
    """Clamp to ``[min, max]`` in the forward pass, with a gradient that can pull values back into range.

    The incoming gradient passes through unchanged, except where the input is
    out of range and a descent step along that gradient would push it further
    out. There the gradient is zeroed.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved_input,) = ctx.saved_tensors
        overshoot = saved_input - saved_input.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply


def spherical_dist(x, y, noise=False, noise_coeff=0.1):
    """Return half the squared great-circle angle between L2-normalised ``x`` and ``y``.

    Both inputs are normalised along the last dimension. The result is
    ``2 * arcsin(||x̂ - ŷ|| / 2) ** 2``, i.e. ``θ² / 2`` for angle ``θ``,
    broadcast over the leading dimensions.

    If ``noise`` is true, Gaussian jitter with standard deviation
    ``0.0422 * noise_coeff`` is added to each normalised vector and the vectors
    are re-normalised. The jitter is sampled without autograd, so gradients flow
    as if it were a constant offset. Each jitter tensor matches its own input's
    shape, device and dtype.
    """
    x_normed = F.normalize(x, dim=-1)
    y_normed = F.normalize(y, dim=-1)
    if noise:
        scale = 0.0422 * noise_coeff
        with torch.no_grad():
            noise1 = torch.randn_like(x_normed) * scale
            noise2 = torch.randn_like(y_normed) * scale
        # Out-of-place add: avoids mutating tensors that are part of the autograd graph.
        x_normed = F.normalize(x_normed + noise1, dim=-1)
        y_normed = F.normalize(y_normed + noise2, dim=-1)

    return x_normed.sub(y_normed).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
    
def bdot(a, b):
    """Row-wise dot products of ``a`` (B×S) with ``b`` broadcast to B rows.

    ``b`` may be a single row (``S`` or ``1×S``) or already ``B×S``. Returns a
    1-D tensor of length B.
    """
    rows, width = a.shape[0], a.shape[1]
    b_rows = b.expand(rows, -1)
    lhs = a.view(rows, 1, width)
    rhs = b_rows.view(rows, width, 1)
    return torch.bmm(lhs, rhs).reshape(-1)

def inner_dist(x,y):
    """Cosine similarity between each row of ``x`` and ``y`` (broadcast via ``bdot``)."""
    return bdot(F.normalize(x, dim=-1), F.normalize(y, dim=-1))

class MakeCutouts(nn.Module):
    """Cut random square crops from an image batch and resize them for CLIP.

    Each call draws ``cutn`` crops. The side length is
    ``rand() ** cut_pow * (max_size - min_size) + min_size``, so ``cut_pow < 1``
    biases toward large crops. Crop positions are drawn before any augmentation.
    Callers that reseed the RNG identically therefore get the same crop boxes
    for different inputs of the same size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            K.RandomGrayscale(p=0.1),
        )

    def set_cut_pow(self, cut_pow):
        """Set the default size exponent used when ``forward`` gets no ``cut_pow``."""
        self.cut_pow = cut_pow

    def forward(self, input, cut_pow=None, augs=True, grads=True):
        """Return the resized cutouts of ``input`` (``n×c×H×W``) concatenated on dim 0.

        Args:
            input: image batch; cutouts of all ``n`` images are stacked.
            cut_pow: size exponent for this call (defaults to ``self.cut_pow``).
            augs: apply the random augmentations and per-cutout Gaussian noise
                (std drawn uniformly from ``[0, noise_fac]``).
            grads: clamp to [0, 1] with ``clamp_with_grad``.
        """
        if cut_pow is None:
            cut_pow = self.cut_pow

        side_y, side_x = input.shape[2:4]
        max_size = min(side_x, side_y)
        min_size = min(side_x, side_y, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([]) ** cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, side_x - size + 1, ())
            offsety = torch.randint(0, side_y - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = torch.cat(cutouts, dim=0)
        if grads:
            batch = clamp_with_grad(batch, 0, 1)
        if augs:
            batch = self.augs(batch)
            if self.noise_fac:
                # One noise level per row: the batch holds cutn * n rows, not just cutn.
                facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
                batch = batch + facs * torch.randn_like(batch)
        return batch



def load_vqgan_model(config_path, checkpoint_path):
    """Build a ``GumbelVQ`` model from a taming-transformers YAML config and load its weights.

    The model is returned in eval mode with ``requires_grad`` disabled. Its
    ``loss`` sub-module is deleted; it is not used anywhere in this notebook
    (presumably training-only).

    Args:
        config_path: OmegaConf YAML config whose ``model.params`` are passed to
            ``GumbelVQ``.
        checkpoint_path: checkpoint passed to ``init_from_ckpt``.

    Raises:
        ValueError: if the config declares a ``model.target`` other than
            ``GumbelVQ``. The rest of the notebook (e.g. ``model.quantize.proj``
            in ``synth``) needs the Gumbel quantizer.
    """
    config = OmegaConf.load(config_path)
    target = config.model.get('target')
    if target is not None and not str(target).endswith('GumbelVQ'):
        raise ValueError(f'{config_path}: expected a GumbelVQ config, got {target!r}')
    model = vqgan.GumbelVQ(**config.model.params)
    model.init_from_ckpt(checkpoint_path)
    model.eval().requires_grad_(False)
    del model.loss
    return model


def size_to_fit(size, max_dim, scale_up=False):
    """Scale ``(w, h)`` so the longer side equals ``max_dim``, keeping the aspect ratio.

    Sizes already within ``max_dim`` are returned unchanged unless ``scale_up``
    is true. The shorter side is rounded to the nearest integer.
    """
    width, height = size
    if max(height, width) <= max_dim and not scale_up:
        return width, height
    if height > width:
        return round(max_dim * width / height), max_dim
    return max_dim, round(max_dim * height / width)


def fetch(path_or_url, timeout=60):
    """Open ``path_or_url`` for binary reading.

    Local paths are opened with ``open(..., 'rb')``. ``http://`` and
    ``https://`` URLs are streamed with ``requests``. The HTTP status is checked,
    so an error page is not handed on as image data, and Content-Encoding
    (gzip/deflate) is decoded on the raw stream.

    Args:
        path_or_url: local file path or HTTP(S) URL.
        timeout: seconds to wait when connecting to or reading from the server.

    Returns:
        A binary file-like object; the caller is responsible for closing it.

    Raises:
        OSError: if a local file cannot be opened.
        requests.HTTPError: if the server responds with an error status.
    """
    if not path_or_url.startswith(('http://', 'https://')):
        return open(path_or_url, 'rb')
    response = requests.get(path_or_url, stream=True, timeout=timeout)
    response.raise_for_status()
    # `.raw` is the undecoded urllib3 stream; ask it to undo any Content-Encoding.
    response.raw.decode_content = True
    return response.raw

## Business End


`from_image_path` is the image you want to apply style transfer to (upload it to imgur and paste the link). `from_text` is the subject within the image you want to transform (the thing the model should semantically segment), and `to_text` is the effect you want to apply.

For example, setting `from_text` to "house" and `to_text` to "carnival" would segment out a house within the image and redraw it as a carnival.

In [None]:
#@title Parameters

from_image_path = 'https://i.imgur.com/7k0YQWK.png'  #@param {type:"string"}
image_size = 640  #@param {type:"integer"}
from_text = 'gates'  #@param {type:"string"}
to_text = 'polar bears'  #@param {type:"string"}
scale_dir_by =   1.25#@param {type:"number"}
clip_model = 'ViT-B/32'  #@param ["ViT-B/32", "RN50", "RN101", "RN50x4"]
use_mask = True  #@param {type:"boolean"}
invert_mask = True  #@param {type:"boolean"}
use_wiki_art = True  #@param {type:"boolean"}
cut_pow_start = 0.3 #@param {type:"number"}
cut_pow_end =  1.0#@param {type:"number"}
cut_pow_length =  400#@param {type:"integer"}
mask_samples =  16000#@param {type:"integer"}
# Field notes:
#   from_image_path: URL (or local path) of the image to edit.
#   image_size: the longest image side is resized to this many pixels.
#   from_text / to_text: concept to find in the image / concept to turn it into.
#   scale_dir_by: weight of the text-direction loss relative to the
#     image-preservation loss.
#   use_mask / invert_mask: restrict edits to the thresholded mask region /
#     flip which side of the threshold is optimised.
#   use_wiki_art: try the WikiArt VQGAN first; falls back to the Gumbel f8 model.
#   cut_pow_*: the cutout size exponent is interpolated from start to end over
#     cut_pow_length optimisation steps.
#   mask_samples: number of random crops used to build the mask.

In [None]:
#Reset tqdm
#tqdm._instances.clear()
from google.colab import drive
import os

# Set up model and devices.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
torch.cuda.empty_cache()
sdbo = scale_dir_by  # initial direction scale (only referenced by commented-out annealing)
t = 0  # annealing progress in [0, 1]; updated by train()
cut_out_num = 64  # cutouts per optimisation step

# Load the VQGAN. When requested, the WikiArt checkpoint is tried first. Any
# failure (missing file or incompatible config) falls back to the default
# Gumbel f8 model. Ctrl-C is no longer swallowed, and the real error is printed.
model = None
if use_wiki_art:
    try:
        # Assumes the WikiArt GAN downloaded above; edit the paths to use your own.
        model = load_vqgan_model('config.yaml', 'wikiart.ckpt').to(device)
    except Exception as err:
        print(f"Custom model could not be loaded ({err!r}); using default.")
if model is None:
    model = load_vqgan_model('vqgan_gumbel_f8_8192.yaml', 'vqgan_gumbel_f8_8192.ckpt').to(device)

perceptor, preprocess = clip.load(clip_model, jit=False)
perceptor.eval().requires_grad_(False).to(device)

torch.manual_seed(0)

# CLIP's input normalisation constants.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

cut_size = perceptor.visual.input_resolution
# Pixels per latent token along each axis (used to size the latent grid).
f = 2**(model.decoder.num_resolutions - 1)
make_cutouts = MakeCutouts(cut_size, cut_out_num, cut_pow=cut_pow_start)

pil_image = Image.open(fetch(from_image_path)).convert('RGB')

source = pil_image
labels = [from_text]
texts = clip.tokenize(labels).to(device)

In [None]:
#Visualize the semantic segment we're going to use
# Score `mask_samples` random crops of the source against `from_text` with CLIP
# and turn the scores into a per-pixel relevance map (see sample()/aggregate()).
agg = aggregate(sample(source, BoxCropper(), perceptor, preprocess, n = mask_samples), labels, perceptor)
visualise(source, agg)

#Save this as mask.png
# agg[0][0] is the map for the first (only) label. It is written as an image
# because the next cell reloads it from "mask.png".
new_p = TF.to_pil_image(agg[0][0])
new_p.save('mask.png')

In [None]:
#Properly rescale image
# Fit the longest side to `image_size` (scaling up if needed). Then round both
# sides down to a multiple of `f` so the latent grid is exactly toksX x toksY.
sideX, sideY = size_to_fit(pil_image.size, image_size, True)
toksX, toksY = sideX // f, sideY // f
sideX, sideY = toksX * f, toksY * f

# CLIP text embeddings of the source and target concepts; ascend_txt() uses
# both in the text-direction loss.
from_embed = perceptor.encode_text(clip.tokenize(from_text).to(device)).float()
to_embed = perceptor.encode_text(clip.tokenize(to_text).to(device)).float()

# Source image as a 1x3xH x W tensor on `device`.
image = TF.to_tensor(pil_image.resize((sideX, sideY), Image.LANCZOS)).to(device).unsqueeze(0)
mask_dist = None
mask_total = 0.

pil_mask = Image.open("mask.png")
#Are we using the mask we just generated?
if use_mask:
    # Prefer an alpha channel; otherwise use the single grayscale channel.
    if 'A' in pil_mask.getbands():
        pil_mask = pil_mask.getchannel('A')
    elif 'L' in pil_mask.getbands():
        pil_mask = pil_mask.getchannel('L')
    else:
        raise RuntimeError('Mask must have an alpha channel or be one channel')
    # `mask` is at latent resolution and gates which tokens of z receive gradient
    # (replace_grad in ascend_txt). `mask_dist` is at image resolution and
    # weights the cutouts in ascend_txt.
    mask = TF.to_tensor(pil_mask.resize((toksX, toksY), Image.BILINEAR))
    mask = mask.to(device).unsqueeze(0)
    mask_dist = TF.to_tensor(pil_mask.resize((sideX, sideY), Image.BILINEAR)).to(device).unsqueeze(0)

    #Threshold on the average of the mask
    # Statistics are computed over non-zero pixels only.
    std, mean = torch.std_mean(mask_dist.view(-1)[torch.nonzero(mask_dist.view(-1))])
    std = std.item()
    mean = mean.item()
    # NOTE(review): this prints mean + 0.5*std, but the threshold applied below
    # is plain `mean` — confirm which one is intended.
    print(mean + (0.5) * std)
    # 1 where the latent-resolution mask is below the mean, 0 elsewhere.
    mask = mask.lt(mean).float()

    # invert_mask keeps the tokens at or above the mean instead.
    if invert_mask:
        mask = 1 - mask
    mask_total = mask_dist.view(-1).sum()
else:
    # No mask: every latent token receives gradient.
    mask = torch.ones([], device=device)

# Encode the source image (mapped from [0, 1] to [-1, 1]) to the pre-quantisation
# latent z. z is the only tensor the optimiser updates.
z = model.quant_conv(model.encoder(image * 2 - 1))
z.requires_grad_()
opt = optim.Adam([z], lr=0.15)

#Draw picture
def synth(z, sample=False):
    """Decode latent ``z`` into an image clamped to [0, 1].

    The quantizer's ``proj`` maps ``z`` to per-position code logits. With
    ``sample=True`` a hard Gumbel-softmax one-hot is drawn (differentiable);
    otherwise the argmax code is used. The one-hot codes select rows of
    ``model.quantize.embed.weight``. The result is decoded and mapped with
    ``(x + 1) / 2``.
    """
    logits = model.quantize.proj(z)
    n_codes = logits.shape[1]
    if sample:
        codes = F.gumbel_softmax(logits, tau=1, hard=True, dim=1)
    else:
        codes = F.one_hot(logits.argmax(1), n_codes).movedim(3, 1).to(logits.dtype)
    z_q = torch.einsum('nchw,cd->ndhw', codes, model.quantize.embed.weight)
    decoded = model.decode(z_q)
    return clamp_with_grad(decoded.add(1).div(2), 0, 1)

#Draw picture + print status + save picture
@torch.no_grad()
def checkin(i, losses):
    """Log the loss terms for step ``i`` and display the current (argmax) decode of ``z``."""
    per_term = ', '.join(f'{term.item():g}' for term in losses)
    total = sum(losses).item()
    tqdm.write(f'i: {i}, loss: {total:g}, losses: {per_term}')
    image_out = synth(z)
    TF.to_pil_image(image_out[0].cpu()).save('progress.png')
    display.display(display.Image('progress.png'))


#Optimize for prompt
def ascend_txt():
    """Compute the loss terms for one optimisation step on the global latent ``z``.

    ``z`` is decoded with Gumbel sampling. Through ``replace_grad``, its gradient
    is that of ``z * mask``, so masked-out tokens are not edited. The decoded
    image, the source image and (if present) the full-resolution mask are cut
    after reseeding the RNG with the same seed, so their cutouts line up.

    Returns:
        A list of two scalar tensors:

        - Image-preservation loss: mean spherical distance between output and
          source cutouts. Each cutout is weighted by ``1 - mask_score``, where
          ``mask_score`` grows with how much mask mass falls in that cutout.
          Without a mask, every weight is 1.
        - Text-direction loss: mean of ``dist(out, to) - dist(out, from)``
          times ``scale_dir_by``.
    """
    out = synth(replace_grad(z, z * mask), sample=True)
    seed = torch.randint(2**63 - 1, [])

    # Passed through as noise_coeff; noise itself is disabled below (noise=False).
    noise_val = (1 - t) * 0.1

    #Random crops
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        out_embeds = perceptor.encode_image(normalize(make_cutouts(out))).float()

    with torch.random.fork_rng():
        torch.manual_seed(seed)
        image_embeds = perceptor.encode_image(normalize(make_cutouts(image))).float()

    if mask_dist is not None:
        with torch.random.fork_rng():
            torch.manual_seed(seed)
            mask_scores = make_cutouts(mask_dist, augs=False, grads=False).view(cut_out_num, -1).sum(dim=-1) / mask_total
    else:
        # use_mask=False: mask_dist is None, so tie every cutout fully to the source.
        mask_scores = out_embeds.new_zeros([out_embeds.shape[0]])

    result = []
    #Compare the image we started with to crops of the current image
    image_analogy = spherical_dist(out_embeds, image_embeds) * (1 - mask_scores)
    result.append(image_analogy.mean())
    #Move over a spherical geodesic that connects the "from state" to the "to state"
    word_analogy = (spherical_dist(out_embeds, to_embed, noise=False, noise_coeff=noise_val)
                    - spherical_dist(out_embeds, from_embed, noise=False, noise_coeff=noise_val))
    result.append(word_analogy.mean() * scale_dir_by)

    return result


def train(i):
    """Run optimisation step ``i``.

    Sets the global progress ``t`` (0 to 1 over ``cut_pow_length`` steps) and
    interpolates the cutout power from ``cut_pow_start`` to ``cut_pow_end``.
    It then takes one optimiser step on the sum of the ``ascend_txt()`` loss
    terms. Every 50th step is also logged and displayed via ``checkin``.
    """
    global t
    t = min(float(i) / float(cut_pow_length), 1.0)
    make_cutouts.set_cut_pow((1 - t) * cut_pow_start + t * cut_pow_end)

    opt.zero_grad()
    loss_terms = ascend_txt()
    if not i % 50:
        checkin(i, loss_terms)
    total_loss = sum(loss_terms)
    total_loss.backward()
    opt.step()

# Optimise until interrupted; `i` ends up holding the number of completed steps.
i = 0
try:
    with tqdm() as pbar:
        while True:
            train(i)
            i += 1
            pbar.update()
except KeyboardInterrupt:
    # Interrupting the cell is the intended way to stop the loop.
    pass
