<a href="https://colab.research.google.com/github/gabrielmuller/sentientpaint/blob/main/Sentient_Paint_text_to_image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sentient Paint: text to image

Generate images from text using CLIP and feature visualization techniques (Fourier basis, color correlation, transformation robustness).

[CLIP: Connecting Text and Images](https://openai.com/blog/clip/)

[Feature Visualization](https://distill.pub/2017/feature-visualization/)

The generated image can be larger than the model's 224x224 input. Every epoch, random square crops of the image are taken and resized to 224x224. CLIP computes a similarity value for each crop, and these values are summed into the loss.
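To make the crop sizes concrete, here is a small sketch of how much of the original image one 224x224 crop covers for a given resize factor `scale`. The helper `crop_side_in_pixels` is hypothetical, 480x854 is the notebook's default `SIZE`, and the sketch ignores the small rotation and extra 1.1x zoom that `RandomAffine` applies in *Transformation robustness*.

```python
CLIP_SIZE = 224  # width and height of the CLIP ViT-B/32 input


def crop_side_in_pixels(scale):
  # The image is resized by `scale` before a CLIP_SIZE square is cut out,
  # so one crop spans CLIP_SIZE / scale pixels of the original image.
  return CLIP_SIZE / scale


height, width = 480, 854
# Smallest scale: the short side shrinks to exactly CLIP_SIZE
smallest_scale = CLIP_SIZE / min(height, width)

print(crop_side_in_pixels(smallest_scale))  # spans the full 480-pixel height
print(crop_side_in_pixels(1.0))             # a 224-pixel detail crop
```

Scales between these two extremes give crops of intermediate size, which is how a single optimization sees both the overall composition and fine details.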

# How to use

*   Go to *Parameters > Prompt* section
*   Replace the prompt text with your own
*   Run everything (Ctrl + F9)

The first run will install dependencies. This takes a few minutes.

After that, images will appear in *Optimize image* as they are computed.


# Advanced usage

## High resolution images

You can generate higher resolution images by increasing `SIZE` in *Image resolution and smoothness*. Fine details are added by optimizing small crops of the image, so a large image can take a long time to generate and may look repetitive.
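For a rough sense of the cost, the sketch below repeats the sample-count arithmetic from *Optimize image* (about one CLIP crop per 8000 pixels, in batches of 80). The function name `samples_per_epoch` is hypothetical.

```python
def samples_per_epoch(height, width, samples_per_batch=80, pixels_per_sample=8000):
  # Same arithmetic as TARGET_N_SAMPLES and N_BATCHES in Optimize image
  target = int(height * width / pixels_per_sample)
  n_batches = max(int(target / samples_per_batch), 1)
  return n_batches * samples_per_batch


for size in [(480, 854), (1080, 1920)]:
  print(size, samples_per_epoch(*size))
```

The default 480x854 size needs a single batch of 80 crops per epoch, while 1080x1920 needs three batches (240 crops), so each epoch takes roughly three times as long.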

For better results, try the following steps:

*   Start optimizing with a lower resolution
*   Wait until the image is stable, then stop optimization
*   Run *Image resolution and smoothness* with increased resolution and less smoothness
*   Run *Rescale image*
*   Run *Optimize image* again



## Video

An image is optimized iteratively. This means we can save the image at every step and join them as frames of a video.

Furthermore, we can apply a small transformation every frame, such as zoom, to create an infinite zoom effect.

*   Go to *New video*
*   Set `VIDEO_NAME` to a name of your choice and run the cell
*   Run *Optimize image*
*   Run *Download video*
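The zoom effect shrinks a center crop by `zoom_factor` every frame and resizes it back to full size. This sketch repeats the crop-size arithmetic from `frame_transform` and estimates the nominal total zoom; the helper names are hypothetical, and `int()` truncation makes the real per-frame zoom slightly larger than `zoom_factor`.

```python
def crop_size(size, zoom_factor=1.01):
  # Same arithmetic as frame_transform: a center crop shrunk by zoom_factor,
  # which the notebook then resizes back to `size`
  return int(size[0] / zoom_factor), int(size[1] / zoom_factor)


def total_zoom(n_frames, zoom_factor=1.01):
  # Nominal zoom after n_frames; int() truncation makes the real zoom a bit larger
  return zoom_factor ** n_frames


print(crop_size((480, 854)))     # (475, 845)
print(round(total_zoom(120), 2))
```

With the notebook's defaults (120 epochs per run of *Optimize image*, encoded at 30 frames per second by *Download video*), one run yields 4 seconds of video that zooms in roughly 3.3x.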

## Style transfer

You can apply the prompt's style to an existing image by loading it as the initial image and then optimizing it.



*   Click the file icon on the left sidebar
*   Click the upload file button to choose your image
*   Right click the uploaded file and choose "Copy path"
*   Go to the *Load image from file* section
*   Paste the path and run the cell
*   Run *Rescale image*
*   Run *Optimize image*



# Setup

## Install torch

In [None]:
import subprocess
from re import findall

# Detect the CUDA toolkit version to pick a matching torch wheel
version_out = subprocess.check_output(['nvcc', '--version']).decode()
version = findall(r'release (\d+\.\d+),', version_out)[0]

print('CUDA version', version)
version_map = {
    '10.0': '+cu100',
    '10.1': '+cu101',
    '10.2': '+cu102',
}

# Fall back to the CUDA 11.0 wheel for any other version
suffix = version_map.get(version, '+cu110')

!pip install torch==1.7.1{suffix} torchvision==0.8.2{suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex
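The version detection can be checked without a GPU by feeding it text. This sketch is a standalone re-expression of that logic; `wheel_suffix` is a hypothetical name, and the sample string only mimics the line format that `nvcc --version` prints.

```python
import re

# Wheel suffixes taken from version_map in the install cell
VERSION_MAP = {'10.0': '+cu100', '10.1': '+cu101', '10.2': '+cu102'}


def wheel_suffix(nvcc_output):
  # Pull "X.Y" out of a line like "Cuda compilation tools, release X.Y, VX.Y.Z"
  match = re.search(r'release (\d+\.\d+),', nvcc_output)
  if match is None:
    raise ValueError('No CUDA release found in nvcc output')
  # Any version not in the map falls back to the CUDA 11.0 wheel
  return VERSION_MAP.get(match.group(1), '+cu110')


sample = 'Cuda compilation tools, release 10.1, V10.1.243'
print(wheel_suffix(sample))  # +cu101
```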

## Install CLIP

In [None]:
!pip install git+https://github.com/openai/CLIP.git

## Imports

In [None]:
import os
import random
import torch
import torch.fft as fft

from torchvision import transforms
from PIL import Image
from IPython.display import display

## Load CLIP model

In [None]:
import clip

DEVICE = 'cuda'

model, preprocess = clip.load('ViT-B/32', device=DEVICE)

# Width and height of model input
CLIP_SIZE = 224

# Feature visualization

## Transformation robustness

In [None]:
# Redefine CLIP normalization function so it can be backpropagated
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))

def image_transform(image_size, big_picture_focus=1.8):
  # Crop scale is random, so CLIP optimizes for both the big picture and
  # fine details.
  crop_ratio = CLIP_SIZE / min(image_size)
  scale = crop_ratio + random.random()**big_picture_focus * (1-crop_ratio)
  resize = lambda image: torch.nn.functional.interpolate(image,
                                                         scale_factor=scale,
                                                         mode='bilinear')
  return transforms.Compose([
      transforms.Lambda(clamp_image),
      transforms.Lambda(resize),
      transforms.RandomAffine(5, scale=(1.1, 1.1), resample=Image.BILINEAR),
      transforms.RandomCrop(CLIP_SIZE),
      normalize,
  ])
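The crop scale `scale = crop_ratio + u ** big_picture_focus * (1 - crop_ratio)`, with `u` uniform in [0, 1), is drawn once per call to `image_transform`, i.e. once per epoch. Because `u ** a` has mean `1 / (a + 1)`, a larger `big_picture_focus` pulls the average scale toward `crop_ratio` and so favors large, big-picture crops. A minimal pure-Python check of that claim (helper names are hypothetical):

```python
import random


def sample_scale(crop_ratio, big_picture_focus=1.8, rng=random):
  # Same formula as image_transform; u ** focus skews the scale toward crop_ratio
  return crop_ratio + rng.random() ** big_picture_focus * (1 - crop_ratio)


def mean_scale(crop_ratio, big_picture_focus=1.8):
  # For u ~ Uniform(0, 1), E[u ** a] = 1 / (a + 1)
  return crop_ratio + (1 - crop_ratio) / (big_picture_focus + 1)


crop_ratio = 224 / 480  # CLIP_SIZE / min(SIZE) for SIZE = (480, 854)
rng = random.Random(0)
n = 100_000
empirical = sum(sample_scale(crop_ratio, rng=rng) for _ in range(n)) / n

print(round(mean_scale(crop_ratio), 3), round(empirical, 3))
print(mean_scale(crop_ratio, big_picture_focus=1.0))  # less focus, larger mean scale
```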

## Frequency weights

In [None]:
def generate_frequency_weights(image_size, smoothness):
  # `fftfreq` is not available in torch 1.7, so let's use numpy's instead
  import numpy as np

  # Build a frequency weight matrix that defines how much of each frequency
  # the gradient should keep. The weight is inversely proportional to the
  # frequency raised to the power `smoothness`.
  vertical_freqs = torch.tensor(np.fft.fftfreq(image_size[0]),
                                device=DEVICE,
                                dtype=torch.float32).unsqueeze(1)
  horizontal_freqs = torch.tensor(np.fft.rfftfreq(image_size[1]),
                                  device=DEVICE,
                                  dtype=torch.float32)

  # Use vertical and horizontal frequency to calculate euclidean distance
  shape = vertical_freqs.shape[0], horizontal_freqs.shape[0]
  weights = torch.square(vertical_freqs).expand(shape) + \
                torch.square(horizontal_freqs).expand(shape)
  weights = torch.sqrt(weights)

  weights = torch.pow(weights, smoothness)

  # Avoid division by zero
  weights[0, 0] = weights[0, 1]

  # Scale with the mean to compensate for the low pass
  weights = weights.mean() / weights

  print('Frequency weights computed.')
  print(weights.shape)
  return weights


def balance_frequencies(grad, shape, weights):
  # Use Fourier transform to get gradient's frequencies
  grad_fft = fft.rfftn(grad, shape)

  # Apply desired weights to frequencies
  grad_fft *= weights

  # Inverse transform back into gradient
  return fft.irfftn(grad_fft, shape)
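A NumPy re-expression of the two functions above, for experimenting without a GPU; it assumes NumPy's `rfft2`/`irfft2` over the last two axes stand in for the `torch.fft` calls, and the `_np` names are hypothetical. Applied to a constant 8x8 gradient, only the zero-frequency bin is nonzero, so the output is the input scaled by that bin's weight.

```python
import numpy as np


def frequency_weights_np(height, width, smoothness):
  # Radial frequency of every rfft2 bin, shape (height, width // 2 + 1)
  fy = np.fft.fftfreq(height)[:, None]
  fx = np.fft.rfftfreq(width)[None, :]
  weights = np.sqrt(fy ** 2 + fx ** 2) ** smoothness
  weights[0, 0] = weights[0, 1]     # avoid dividing by zero at the DC bin
  return weights.mean() / weights   # proportional to 1 / f ** smoothness


def balance_frequencies_np(grad, weights):
  # Scale every frequency of a 2D gradient, then transform back
  spectrum = np.fft.rfft2(grad) * weights
  return np.fft.irfft2(spectrum, s=grad.shape)


w = frequency_weights_np(8, 8, smoothness=1.7)
print(w.shape)  # (8, 5)

flat = np.full((8, 8), 2.0)  # constant gradient: only the DC bin is nonzero
print(np.allclose(balance_frequencies_np(flat, w), 2.0 * w[0, 0]))  # True
```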

## Color correlation

In [None]:
# Color correlation matrix precomputed using Cholesky decomposition on some
# images from the CIFAR dataset
COLOR_CORRELATION = torch.Tensor([[1.0000, 0.0000, 0.0000],
                                  [0.9224, 0.3862, 0.0000],
                                  [0.8182, 0.4037, 0.4094]]).to(DEVICE)


def correlate_colors(image):
  height = image.shape[2]
  width = image.shape[3]

  # Change shape to an N x 3 matrix (one row per pixel)
  image = torch.transpose(image, 1, 3)
  image = image.reshape(-1, 3)

  # Normalize each channel to zero mean and unit standard deviation
  mean = image.mean(0, keepdim=True)
  std = image.std(0, keepdim=True)
  image = (image - mean) / std

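  # Apply the correlation matrix. Pixels are rows, so multiplying by the
  # transposed Cholesky factor L gives the target covariance L @ L.T
  image = image @ COLOR_CORRELATION.T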

  # Restore mean and deviation
  image = image * std + mean

  # Restore shape
  image = image.reshape(1, width, height, 3)
  image = torch.transpose(image, 1, 3)

  return image
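The idea behind the color correlation step, sketched in NumPy: if `L` is the lower-triangular Cholesky factor of a target correlation matrix `C = L @ L.T`, then uncorrelated, unit-variance pixels stored as rows `x` become correlated as `x @ L.T`. The matrix is copied from `COLOR_CORRELATION`; the samples are synthetic, and the notebook's function additionally standardizes each channel first and restores its mean and standard deviation afterwards.

```python
import numpy as np

# Lower-triangular Cholesky factor, copied from COLOR_CORRELATION
L = np.array([[1.0000, 0.0000, 0.0000],
              [0.9224, 0.3862, 0.0000],
              [0.8182, 0.4037, 0.4094]])

# Target correlation matrix; its diagonal is (very nearly) all ones
target = L @ L.T

# Synthetic uncorrelated, unit-variance "pixels", one RGB row per pixel
rng = np.random.default_rng(0)
white = rng.standard_normal((100_000, 3))

# With pixels as rows, multiplying by L.T gives covariance L @ L.T
correlated = white @ L.T
corr = np.corrcoef(correlated, rowvar=False)

print(np.round(target, 3))
print(np.round(corr, 3))
```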

## Other functions

In [None]:
# Keep image inside valid range
def clamp_image(image):
  return torch.clamp(image, 0.0, 1.0)


def display_image(image):
  with torch.no_grad():
    pil_image = transforms.ToPILImage()(clamp_image(image[0]))
    display(pil_image)


def save(image, path, index):
  with torch.no_grad():
    pil_image = transforms.ToPILImage()(clamp_image(image[0]))
    pil_image.save(path + str(index).zfill(4) + '.png')
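`save` zero-pads the frame index so that alphabetical order matches frame order, which is presumably what keeps frames in sequence when *Download video* expands its glob pattern. A tiny sketch with a hypothetical `frame_path` helper and an example folder:

```python
def frame_path(path, index):
  # Same naming as save(): zero-pad the index to four digits
  return path + str(index).zfill(4) + '.png'


names = [frame_path('/content/hello/', i) for i in (10, 2, 1)]
print(sorted(names))
# Alphabetical order now matches frame order: 0001, 0002, 0010
```

Past 9999 frames the padding no longer preserves order (`10000.png` sorts before `9999.png`).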

# Parameters

## Image resolution and smoothness

Note: to create high resolution images, I recommend optimizing at a lower resolution, upscaling the image (see the *Rescale image* section), and then continuing to optimize.

In [None]:
# Replace this with your desired image resolution (height, width).
SIZE = 480, 854

# How much should lower frequencies dominate? Use a value in the [1.0, 2.0]
# range for the first pass, or less than 1.0 for a finer detail pass.
# A lower value means more detail but also more noise.
# A higher value looks more natural but also blurrier.
SMOOTHNESS = 1.7


freq_weights = generate_frequency_weights(SIZE, SMOOTHNESS)
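To get a feel for `SMOOTHNESS`: the weights fall off as `1 / f ** SMOOTHNESS` (the mean rescaling does not change ratios), so a frequency twice as high receives `2 ** -SMOOTHNESS` times the weight. A one-line sketch with a hypothetical helper:

```python
def relative_weight(frequency_ratio, smoothness):
  # Weights scale like 1 / f ** smoothness, so multiplying f by `frequency_ratio`
  # multiplies the weight by frequency_ratio ** -smoothness
  return frequency_ratio ** -smoothness


for smoothness in (0.8, 1.7):
  print(smoothness, round(relative_weight(2.0, smoothness), 3))
```

With the default 1.7, doubling the frequency keeps only about 31% of the weight; with 0.8 it keeps about 57%, which is why lower values add more fine detail.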

## Prompt

In [None]:
# ENTER YOUR PROMPT HERE
OPTIMIZE = '''Sentient Paint, a colorful surreal painting'''

# Avoid undesirable features
DEOPTIMIZE = '''text, writing, signature'''


text = clip.tokenize([OPTIMIZE, DEOPTIMIZE]).to(DEVICE)
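Each crop gets one CLIP similarity per prompt. The sketch below re-expresses the loss from `train` in plain Python for made-up similarity values: the DEOPTIMIZE similarity is weighted by `deoptimization_weight` (0.2 by default) and the OPTIMIZE similarity by the remainder, with the sign flipped so that minimizing the loss raises the prompt's similarity. The helper name is hypothetical.

```python
def prompt_loss(similarities, deoptimization_weight=0.2):
  # `similarities` holds one (good, bad) pair per crop:
  # good = similarity to OPTIMIZE, bad = similarity to DEOPTIMIZE
  total = 0.0
  for good, bad in similarities:
    total += deoptimization_weight * bad - (1 - deoptimization_weight) * good
  # Average over crops; lower is better for the optimizer
  return total / len(similarities)


# Made-up similarity values for two crops
print(prompt_loss([(30.0, 20.0), (28.0, 22.0)]))
```

This prints a value of about -19. Raising either crop's OPTIMIZE similarity, or lowering its DEOPTIMIZE similarity, lowers the loss.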

# Run

## New video

Choose a video name to save the image at every epoch as a video frame. Run *Download video* to encode the saved frames into a video.

In [None]:
# Enter a name for the video e.g. VIDEO_NAME = 'hello'
VIDEO_NAME = ''

if VIDEO_NAME:
  VIDEO_FILENAME = VIDEO_NAME + '.mp4'
  # Frames are saved as PNG files in a folder named after the video
  PATH = '/content/' + VIDEO_NAME + '/'

  if not os.path.exists(PATH):
    os.makedirs(PATH)

# Transformation applied for every frame in the video
def frame_transform(image, zoom_factor=1.01):
  crop_size = int(SIZE[0] / zoom_factor), int(SIZE[1] / zoom_factor)

  zoomed = transforms.CenterCrop(crop_size)(image)
  zoomed = torch.nn.functional.interpolate(zoomed, size=SIZE, mode='bilinear')
  return zoomed

## New image

In [None]:
def generate_initial_image(freq_shape, image_shape, high_freq_amount=0.2):
  # Generate random complex frequency coefficients, weighted so that low
  # frequencies dominate
  shape = 1, 3, *freq_shape
  random_fft = torch.rand(shape, device=DEVICE, dtype=torch.cfloat)
  random_fft = random_fft * freq_weights

  # Generate image from frequencies
  image = fft.irfftn(random_fft, image_shape)

  # Add some per-pixel (high frequency) noise too. This seems to work well
  high_freq_noise = torch.rand(1, 3, *image_shape, device=DEVICE) * high_freq_amount
  image += high_freq_noise

  # Reduce the image's range. Make it closer to gray
  image = (image + 1.5) / 4.0

  image = clamp_image(correlate_colors(image))
  image.requires_grad = True
  return image
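A NumPy sketch of the same recipe, omitting the color correlation step and using simple `1 / f` weights (smoothness 1) instead of `freq_weights`; the function name and the 32x48 size are hypothetical. It draws random real and imaginary parts in [0, 1), scales them by the weights, inverts the FFT, adds uniform per-pixel noise, and shifts the result toward gray.

```python
import numpy as np


def initial_image(height, width, weights, high_freq_amount=0.2, seed=0):
  rng = np.random.default_rng(seed)
  spec_shape = (3, height, width // 2 + 1)
  # Random real and imaginary parts in [0, 1), scaled by the frequency weights
  spectrum = (rng.random(spec_shape) + 1j * rng.random(spec_shape)) * weights
  # Inverse real FFT over the last two axes gives a 3 x height x width image
  image = np.fft.irfft2(spectrum, s=(height, width))
  # Uniform per-pixel noise adds some high frequencies back
  image += rng.random((3, height, width)) * high_freq_amount
  # Shift toward mid-gray, then clip to the valid [0, 1] range
  return np.clip((image + 1.5) / 4.0, 0.0, 1.0)


h, w = 32, 48
fy = np.fft.fftfreq(h)[:, None]
fx = np.fft.rfftfreq(w)[None, :]
f = np.sqrt(fy ** 2 + fx ** 2)
f[0, 0] = f[0, 1]          # avoid dividing by zero at the DC bin
weights = f.mean() / f     # simple 1 / f weights (smoothness 1)

img = initial_image(h, w, weights)
print(img.shape, img.min(), img.max())
```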

image = generate_initial_image(freq_weights.shape, SIZE)

epoch = 0

print('Initial image:')
display_image(image)
print(image)
print(image.shape)

## Optimize image

In [None]:
# A higher learning rate will generate an image quicker, but the image may look
# oversaturated and weird
optimizer = torch.optim.Adam([image], lr=0.03)

# How many samples (random crops) of the image we should take per batch.
# This is limited by VRAM.
N_SAMPLES_PER_BATCH = 80

# Larger images need more samples.
TARGET_N_SAMPLES = int(SIZE[0] * SIZE[1] / 8000)

N_BATCHES = max(int(TARGET_N_SAMPLES / N_SAMPLES_PER_BATCH), 1)
print(f'Optimizing {N_BATCHES * N_SAMPLES_PER_BATCH} samples per epoch')

def train(image,
          n_samples_per_batch,
          n_batches,
          big_picture_focus=1.8,
          deoptimization_weight=0.2):
  transform_image = image_transform(SIZE, big_picture_focus=big_picture_focus)

  # Gradient accumulator, a running total of each batch's gradient
  grad_acc = torch.zeros(1, 3, *SIZE, device=DEVICE)

  for _ in range(n_batches):
    # Random crops of the image; each call returns a 1 x 3 x 224 x 224 tensor
    samples = torch.cat([transform_image(image)
                         for _ in range(n_samples_per_batch)])

    # CLIP computes the similarity between each sample and each prompt.
    # Each row of `similarities` is (OPTIMIZE, DEOPTIMIZE) for one sample.
    similarities, _ = model(samples, text)
    good, bad = similarities[:, 0], similarities[:, 1]

    # The optimizer minimizes loss: reward the prompt, penalize the unwanted
    # features, and average over the samples in this batch
    loss = (deoptimization_weight * bad
            - (1 - deoptimization_weight) * good).mean()

    optimizer.zero_grad()
    loss.backward()
    grad_acc += image.grad

  # The loss is already averaged per batch, so only average over batches
  grad_acc /= n_batches
  grad_acc = balance_frequencies(grad_acc, SIZE, freq_weights)
  grad_acc = correlate_colors(grad_acc)
  image.grad = grad_acc

  optimizer.step()

  return loss.item()


for i in range(120):
  loss = train(image, N_SAMPLES_PER_BATCH, N_BATCHES)
  print('Epoch', epoch, 'Loss', loss)

  if (epoch + 1) % 20 == 0:
    display_image(image)

  if VIDEO_NAME:
    with torch.no_grad():
      image.copy_(frame_transform(image))
    save(image, PATH, epoch)
  
  epoch += 1
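Why accumulating per-batch gradients works: with equal-sized batches, the average of the batch means equals the mean over all samples, and because the gradient is linear in the loss, the same holds for the accumulated gradient. (Adam is also largely insensitive to a constant rescaling of the gradient.) A pure-Python check with a hypothetical helper:

```python
def mean_of_batch_means(values, batch_size):
  # Split into equal-sized batches, average each batch, then average the
  # batch means (dividing by the number of batches only)
  batches = [values[i:i + batch_size] for i in range(0, len(values), batch_size)]
  batch_means = [sum(batch) / len(batch) for batch in batches]
  return sum(batch_means) / len(batch_means)


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(mean_of_batch_means(values, 3), sum(values) / len(values))  # 3.5 3.5
```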

# Misc

## Show current image

In [None]:
display_image(image)

## Download video

In [None]:
GLOB = PATH + '*.png'
OUT = PATH + VIDEO_FILENAME

!ffmpeg -framerate 30 -y -pattern_type glob -i "{GLOB}" -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" -c:v libx264 -pix_fmt yuv420p "{OUT}"

from google.colab import files
files.download(OUT)
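The `pad` filter in the ffmpeg command rounds each dimension up to an even number, since libx264 cannot encode `yuv420p` video with odd width or height. A sketch of that arithmetic and of the clip length (hypothetical helper; 120 frames is one run of *Optimize image*, 30 is the `-framerate` value):

```python
import math


def even_padded(width, height):
  # pad=ceil(iw/2)*2:ceil(ih/2)*2 rounds each side up to an even number
  return math.ceil(width / 2) * 2, math.ceil(height / 2) * 2


print(even_padded(854, 480))  # (854, 480): already even, unchanged
print(even_padded(845, 475))  # (846, 476): one extra pixel in each dimension

n_frames, fps = 120, 30       # one run of Optimize image, -framerate 30
print(n_frames / fps)         # 4.0 seconds of video
```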

## Load image from file

In [None]:
# Enter path here e.g. '/content/hello.png'
path = ''

if path:
  image = transforms.ToTensor()(Image.open(path).convert('RGB'))
  image = image.unsqueeze(0).to(DEVICE)
  image.requires_grad = True
  print(image.shape)
  epoch = 0
else:
  print('No file specified')

If the loaded image's resolution differs from `SIZE`, run *Rescale image*.

## Rescale image

In [None]:
print('Before', image.shape)

with torch.no_grad():
  image = torch.nn.functional.interpolate(image,
                                          size=SIZE,
                                          mode='bicubic').to(DEVICE)
  image.requires_grad = True

print('After', image.shape)


## Check GPU model

In [None]:
!nvidia-smi

## Free memory

Try this if you run out of memory. If it doesn't work, reduce `N_SAMPLES_PER_BATCH` in *Optimize image*.

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

# LICENSE

The MIT License

Copyright 2021 Gabriel Müller

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

