In [None]:
!pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
!pip install einops ninja
#!pip install --upgrade https://download.pytorch.org/whl/nightly/cu111/torch-1.11.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu111/torchvision-0.12.0.dev20211012%2Bcu111-cp37-cp37m-linux_x86_64.whl
!git clone https://github.com/NVlabs/stylegan3
!git clone https://github.com/openai/CLIP
!pip install -e ./CLIP
!pip install einops ninja

import sys
sys.path.append('./CLIP')
sys.path.append('./stylegan3')

import tensorflow
import io
import os, time
import pickle
import shutil
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import requests
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import clip
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from IPython.display import display
from einops import rearrange

In [None]:
!nvidia-smi

print(torch.__version__)

In [None]:
# Image prompt: the generator's output is optimized so that its CLIP embedding
# moves towards this image's embedding. A single image URL is used (it is loaded
# with embed_img_url); zip archives / averaging over several images are not
# implemented in this notebook.
# NOTE(review): the freepik URL below carries `exp=`/`hmac=` query parameters that
# look like an expiring signature — if the download fails, use a stable image URL.
# image_prompt_file = "https://images.pexels.com/photos/614810/pexels-photo-614810.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1" #@param {type: "string"}
image_prompt_file = "https://img.freepik.com/free-photo/front-view-beautiful-man_23-2148780802.jpg?w=740&t=st=1704837609~exp=1704838209~hmac=a4133bfc49105a9a3e107708666534aefc6fc78b3c55d44526380a192f6f3405" #@param {type: "string"}
# Weight of the image prompt in the loss.
images_prompt_weight =  1 #@param {type: "number"}

# Speed at which the rendered frames follow the optimized latent, from 0 to 100.
# It only sets the EMA `smoothing` below (the optimizer itself is unaffected).
# Too fast seems to give strange results.
speed = 10  #@param {type: "number"}

# How many optimization steps to run. Each step renders one frame.
steps = 100 #@param {type: "number"}

# Change the seed to generate variations of the same prompt.
seed = -1 #@param {type: "number"}

# Which parameters influence this model's output is not fully understood yet;
# changing the learning rate (between 0 and 100) can help. AdamW uses
# lr = learning_rate / 250.
learning_rate = 10

model_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl'

output_path = "/content/output"

# NOTE(review): not referenced anywhere in the code shown here.
social = False

# Fail fast on an out-of-range form value: otherwise the EMA coefficient below
# silently becomes negative (speed > 100) or greater than 1 (speed < 0).
if not 0 <= speed <= 100:
  raise ValueError(f"speed must be between 0 and 100, got {speed}")

# EMA coefficient for the rendered latent: speed=100 -> 0.0 (no smoothing),
# lower speeds -> heavier smoothing.
smoothing = (100.0-speed)/100.0

import sys
import torch
# Target device for the generator and for the tensors created below: first CUDA GPU.
device = torch.device('cuda', 0)
print('Using device:', device, file=sys.stderr)
# Seed torch's global RNG so the sampled inits and cutouts are reproducible.
torch.manual_seed(seed)

# Define necessary functions
import requests
from io import BytesIO

def fetch(url_or_path, timeout=60):
    """Return a readable binary file object for an http(s) URL or a local path.

    URLs are downloaded fully into memory and wrapped in an ``io.BytesIO``;
    anything else is opened with ``open(url_or_path, 'rb')``. The caller owns
    the returned object and should close it.

    Args:
        url_or_path: URL string, or a local path (str or path-like).
        timeout: seconds to wait for the server before giving up (URLs only).

    Raises:
        requests.HTTPError: on a 4xx/5xx response.
        requests.exceptions.Timeout: if the server does not respond in time.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        # Without a timeout, requests can block forever on a stalled server.
        r = requests.get(url_or_path, timeout=timeout)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')

def fetch_model(url_or_path, timeout=60):
    """Return the local filename of a checkpoint, downloading it on first use.

    The file is cached in the current working directory under the URL's
    basename; if that file already exists it is reused without any network
    access.

    The download goes to ``<basename>.part`` and is renamed only after it
    completes. An interrupted download therefore never leaves a truncated file
    under the cache name, which the existence check would otherwise trust
    forever.

    Raises:
        requests.HTTPError: on a 4xx/5xx response.
        requests.exceptions.Timeout: if the server stalls for `timeout` seconds.
    """
    basename = os.path.basename(url_or_path)
    if os.path.exists(basename):
        return basename

    partial = basename + '.part'
    with requests.get(url_or_path, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(partial, 'wb') as out:
            # Stream in 1 MiB chunks so multi-hundred-MB checkpoints are not held in memory.
            for chunk in r.iter_content(chunk_size=1 << 20):
                out.write(chunk)
    os.replace(partial, basename)
    return basename

def norm1(prompt):
    """Scale `prompt` to unit L2 norm along its last dimension (onto the unit sphere).

    No epsilon is added, so an all-zero row yields NaNs.
    """
    length = torch.sqrt(torch.sum(prompt ** 2, dim=-1, keepdim=True))
    return prompt / length

def spherical_dist_loss(x, y):
    """Squared great-circle distance between `x` and `y` after unit-normalizing both.

    Works along the last dimension and broadcasts like ordinary tensor
    arithmetic. The result is ``2 * arcsin(|x_hat - y_hat| / 2) ** 2``, which
    is 0 for identical directions.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = torch.norm(x_unit - y_unit, dim=-1)
    return 2 * torch.arcsin(chord / 2) ** 2

class MakeCutouts(torch.nn.Module):
    """Produce `cutn` random square crops per image, each pooled to `cut_size` x `cut_size`.

    The crop side length is ``rand() ** cut_pow * (max_size - min_size) + min_size``,
    so ``cut_pow < 1`` biases towards larger crops. The output stacks the crops
    cutout-major: shape ``(cutn * N, C, cut_size, cut_size)`` for an
    ``(N, C, H, W)`` input.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow

    def forward(self, input):
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        crops = []
        # RNG draws per crop (size, then x, then y) keep the same order so seeded runs are reproducible.
        for _ in range(self.cutn):
            frac = torch.rand([]) ** self.cut_pow
            side = int(frac * (largest - smallest) + smallest)
            x0 = torch.randint(0, width - side + 1, ())
            y0 = torch.randint(0, height - side + 1, ())
            crop = input[:, :, y0:y0 + side, x0:x0 + side]
            crops.append(F.adaptive_avg_pool2d(crop, self.cut_size))
        return torch.cat(crops)

# 224 px crops (the ViT-B/32 input size), 32 per image, biased towards larger crops.
make_cutouts = MakeCutouts(cut_size=224, cutn=32, cut_pow=0.5)

def embed_image(image):
  """Embed random cutouts of an image batch with CLIP.

  The crops come from the module-level `make_cutouts`, which concatenates them
  cutout-major. The flat embeddings are therefore regrouped to
  (cutouts, batch, embed_dim).
  """
  batch_size = image.shape[0]
  flat = clip_model.embed_cutout(make_cutouts(image))
  return rearrange(flat, '(cc n) c -> cc n c', n=batch_size)


def embed_img_file(path):
  """Open a local image, convert it to RGB and return its cutout-averaged CLIP embedding."""
  rgb = Image.open(path).convert('RGB')
  batch = TF.to_tensor(rgb).to(device).unsqueeze(0)
  per_cutout = embed_image(batch)
  # Average over cutouts, then drop the singleton batch axis.
  return per_cutout.mean(0).squeeze(0)

def embed_img_url(url, timeout=60):
  """Download an image and return its cutout-averaged CLIP embedding.

  The image is converted to RGB and turned into a single-image batch on
  `device`. It is embedded via `embed_image`, and the per-cutout embeddings are
  averaged.

  Returns None (after printing the status code) when the server does not
  answer with HTTP 200, so callers must check for None.
  """
  # A timeout keeps a stalled image host from hanging the notebook indefinitely.
  response = requests.get(url, timeout=timeout)

  if response.status_code != 200:
    print(f"Failed to download image. Status code: {response.status_code}")
    return None

  image = Image.open(BytesIO(response.content)).convert('RGB')
  return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)

def embed_url(url):
  """Return the cutout-averaged CLIP embedding of an image given by URL or local path.

  Loading goes through `fetch`, so HTTP errors raise instead of returning None.
  """
  source = fetch(url)
  rgb = Image.open(source).convert('RGB')
  batch = TF.to_tensor(rgb).to(device).unsqueeze(0)
  return embed_image(batch).mean(0).squeeze(0)

class CLIP(object):
  """Frozen OpenAI CLIP model exposing unit-norm text and image embeddings."""

  def __init__(self, model_name="ViT-B/32"):
    """Load `model_name` (any name accepted by clip.load) with gradients disabled."""
    # CLIP's own preprocess transform (second return value) is not used;
    # normalization is applied manually in embed_cutout instead.
    self.model, _ = clip.load(model_name)
    # CLIP's weights stay frozen; only the StyleGAN latent is optimized.
    self.model = self.model.requires_grad_(False)
    # Mean/std used by CLIP's preprocessing. Callers pass images mapped from the
    # generator's [-1, 1] range via add(1).div(2).
    self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                          std=[0.26862954, 0.26130258, 0.27577711])

  @torch.no_grad()
  def embed_text(self, prompt):
    """Normalized CLIP text embedding (float32) of `prompt`, tokenized with clip.tokenize."""
    return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())

  def embed_cutout(self, image):
    """Normalized CLIP image embedding of a batch of cutouts.

    Deliberately not wrapped in no_grad: run() backpropagates the loss through
    this call into the latent being optimized.
    """
    return norm1(self.model.encode_image(self.normalize(image)))

clip_model = CLIP()

# Load the StyleGAN3 generator.
# SECURITY: pickle.load can execute arbitrary code — only load checkpoints from
# a trusted source (here the NVIDIA NGC URL in `model_url`).

network_url = model_url

with open(fetch_model(network_url), 'rb') as fp:
  # Only the EMA generator is used. Freeze it explicitly: the optimization loop
  # updates the latent `q` only, so G's weights must never accumulate gradients.
  G = pickle.load(fp)['G_ema'].to(device).requires_grad_(False)

# None of the setup below needs autograd; no_grad avoids building graphs for it.
with torch.no_grad():
  # Fix the coordinate grid to w_avg: fold the input layer's affine response to
  # w_avg into its bias, then zero the weight so the grid no longer depends on
  # the latent being optimized.
  # NOTE(review): `add_` keeps the existing bias, so (assuming a plain linear
  # layer with unit bias gain) the fixed output is 2*bias + W @ w_avg rather than
  # affine(w_avg). This matches the widely used recipe; confirm before changing
  # it to copy_.
  shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))
  G.synthesis.input.affine.bias.data.add_(shift.squeeze(0))
  G.synthesis.input.affine.weight.data.zero_()

  # Per-dimension std of W over 10k random z; run() optimizes latents in the
  # normalized space q = (w - w_avg) / w_stds.
  zs = torch.randn([10000, G.mapping.z_dim], device=device)
  w_stds = G.mapping(zs, None).std(0)

# Run Settings
from glob import glob

# For a local image file instead of a URL: target = embed_img_file(image_prompt_file)
target = embed_img_url(image_prompt_file)
if target is None:
  # embed_img_url returns None on a failed download. Stop here with a clear
  # message instead of failing much later inside the loss computation.
  raise RuntimeError(f"Could not download the image prompt: {image_prompt_file}")
prompts = [(images_prompt_weight, target)]

# Actually do the run
from IPython.display import display


# Preview transform: resize to 224 px, then map the generator's [-1, 1] output
# into [0, 1] and clip it.
tf = Compose([
  Resize(224),
  lambda t: ((t + 1) / 2).clamp(min=0, max=1),
])

def _prompt_loss(embeds):
  """Sum over `prompts` of weight * (spherical distance averaged over cutouts).

  `embeds` is the (cutouts, batch, dim) output of embed_image; the result has
  one entry per batch element.
  """
  loss = 0
  for weight, prompt_embed in prompts:
    loss += weight * spherical_dist_loss(embeds, prompt_embed).mean(0)
  return loss


def _choose_initial_q(n_batches=8, batch_size=4):
  """Return the best of n_batches * batch_size truncated samples as a trainable latent.

  Each candidate is expressed in normalized space q = (w - w_avg) / w_stds and
  scored with _prompt_loss. The winner is returned with shape (1, ...) and
  requires_grad set, ready for the optimizer.
  """
  with torch.no_grad():
    best_qs, best_losses = [], []
    for _ in range(n_batches):
      z = torch.randn([batch_size, G.mapping.z_dim], device=device)
      q = (G.mapping(z, None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
      images = G.synthesis(q * w_stds + G.mapping.w_avg)
      loss = _prompt_loss(embed_image(images.add(1).div(2)))
      i = torch.argmin(loss)
      best_qs.append(q[i])
      best_losses.append(loss[i])
    best_losses = torch.stack(best_losses)
    print(best_losses)
    best = torch.argmin(best_losses)
    return torch.stack(best_qs)[best].unsqueeze(0).requires_grad_()


def run():
  """Optimize a StyleGAN3 latent towards `prompts`, rendering one frame per step.

  Starts from the best of 32 truncated random samples (see _choose_initial_q),
  then runs `steps` AdamW updates on the normalized latent q. After each step an
  EMA of q (weighted by `smoothing`) is rendered:
  - every frame is written to /tmp/ffmpeg for the video;
  - every 5th frame is also saved to `output_path` and displayed.

  Returns:
    The EMA-smoothed W latent of the last rendered frame, without autograd
    history (the value for the initial latent when steps == 0).
  """
  torch.manual_seed(seed)
  # The output directories are loop-invariant, so create them once up front.
  os.makedirs(output_path, exist_ok=True)
  os.makedirs("/tmp/ffmpeg", exist_ok=True)

  q = _choose_initial_q()

  # q_ema is a grad-free alias of q's storage. Until the first EMA update it
  # sees opt.step()'s in-place changes exactly like the previous `q_ema = q`,
  # but neither the EMA chain nor the preview renders build autograd graphs.
  q_ema = q.detach()
  with torch.no_grad():
    latent = q_ema * w_stds + G.mapping.w_avg  # defined even if steps == 0

  opt = torch.optim.AdamW([q], lr=learning_rate/250.0, betas=(0.0, 0.999))
  loop = tqdm(range(steps))
  for i in loop:
    opt.zero_grad()
    image = G.synthesis(q * w_stds + G.mapping.w_avg, noise_mode='const')
    loss = _prompt_loss(embed_image(image.add(1).div(2)))
    loss.backward()
    opt.step()
    loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())

    # Rendering is for output only; no gradients are needed here.
    with torch.no_grad():
      q_ema = q_ema * smoothing + q * (1 - smoothing)
      latent = q_ema * w_stds + G.mapping.w_avg
      image = G.synthesis(latent, noise_mode='const')

    pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0, 1))
    if i % 5 == 0:
      pil_image.save(f'{output_path}/output_{i:04}.jpg')
      display(pil_image)
    pil_image.save(f'/tmp/ffmpeg/output_{i:04}.jpg')
  return latent

try:
  latent = run()
  torch.save(latent, f"{output_path}/latent.pt")
except KeyboardInterrupt:
  # Stopping early is allowed: fall through and encode the frames rendered so
  # far. The final latent is only saved when the run completes.
  print("Run interrupted; latent.pt was not saved. Encoding the frames rendered so far.")

out_file=output_path+"/video.mp4"


last_frame=!ls -t /tmp/ffmpeg/*.jpg | head -1
last_frame = last_frame[0]

# Copy last frame to start and duplicate at end so it sticks around longer
!cp -v $last_frame /tmp/ffmpeg/0000.jpg



encoding_options = "-c:v libx264 -crf 20 -preset slow -vf format=yuv420p -c:a aac -movflags +faststart"

!ffmpeg  -r 10 -i /tmp/ffmpeg/%*.jpg -y {encoding_options} /tmp/vid_no_audio.mp4
!ffmpeg -i /tmp/vid_no_audio.mp4 -f lavfi -i anullsrc -c:v copy -c:a aac -shortest -y "$out_file"

print("Written", out_file)
!sleep 2
!rm -r /tmp/ffmpeg

import os.path
if not os.path.exists(out_file):
  raise Exception("Expected output file does not exist.")