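# Text-Guided Editing of Images (Using CLIP and StyleGAN)

This notebook uses StyleCLIP to recolor one part of a face (hair by default) once for each color in `colors`, scores every edited image with CLIP against the prompts for all colors, and logs the images and losses to Weights & Biases.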

In [1]:
#@title Setup (may take a few minutes)
!git clone https://github.com/khalilacheche/StyleCLIP.git

import os
os.chdir(f'./StyleCLIP')

!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

!pip install wandb

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# downloads StyleGAN's weights and facial recognition network weights
ids = ['1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT', '1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL']
for file_id in ids:
  downloaded = drive.CreateFile({'id':file_id})
  downloaded.FetchMetadata(fetch_all=True)
  downloaded.GetContentFile(downloaded.metadata['title'])

Cloning into 'StyleCLIP'...
remote: Enumerating objects: 885, done.[K
remote: Counting objects: 100% (267/267), done.[K
remote: Compressing objects: 100% (119/119), done.[K
remote: Total 885 (delta 169), reused 196 (delta 140), pack-reused 618[K
Receiving objects: 100% (885/885), 241.88 MiB | 20.84 MiB/s, done.
Resolving deltas: 100% (314/314), done.
Updating files: 100% (260/260), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-j0x2su8n
  Running command git clo

In [None]:
# Log in to Weights & Biases so the wandb.init / wandb.log calls in the
# experiment cell can upload to the "clip_colors" project.
!wandb login

In [81]:
# Experiment form parameters (Colab #@param fields). Every value here is
# forwarded to the optimizer through `args` in the Additional Arguments cell.

# Passed as args["mode"].
experiment_type = 'edit' #@param ['edit', 'free_generation']

# Passed as args["semantic_part"]; presumably selects the facial region the
# edit is localized to (see loc_lambda) — confirm in run_optimization.
semantic_part = "hair" #@param ["mouth","skin","eyes","nose","ears","eye_brows","hat","hair","neck"]

# Prompt template: '###' is replaced by each entry of `colors` in the
# experiment loop.
description = 'A person with ### hair' #@param {type:"string"}

# Passed as args["latent_path"]; None presumably means the optimizer starts
# from a freshly sampled latent — confirm.
latent_path = None #@param {type:"string"}

# Number of optimization steps (args["step"]).
optimization_steps = 40 #@param {type:"number"}

# Loss weights, passed as args["clip_lambda"], args["l2_lambda"],
# args["loc_lambda"] and args["id_lambda"]; presumably they weight the CLIP,
# latent-L2, localization and identity terms — confirm in run_optimization.
clip_lambda = 1 #@param {type:"number"}

l2_lambda = 0.004 #@param {type:"number"}

loc_lambda = 0.00001 #@param {type:"number"}

id_lambda = 0.005 #@param {type:"number"}

# Passed as args["work_in_stylespace"].
stylespace = False #@param {type:"boolean"}

# Sets args["save_intermediate_image_every"]: 1 (every step) when True,
# otherwise 20.
create_video = True #@param {type:"boolean"}

# Passed as args["export_segmentation_image"].
export_segmentation_image = True #@param {type:"boolean"}


In [None]:
# Colors substituted for the '###' placeholder in `description`: each one is
# an edit target once, and an evaluation prompt in every run.
colors = 'black blue red blonde gray purple brown'.split()

In [82]:
# use_seed: intended to toggle reseeding torch with `seed` before each
# optimization run — confirm the experiment cell honours it.
use_seed = True #@param {type:"boolean"}

# Seed forwarded to the optimizer as args["seed"] and used for
# torch.manual_seed in the experiment loop.
seed = 1 #@param {type: "number"}

In [83]:
#@title Additional Arguments
# Optimizer configuration: the form values from the cells above plus fixed
# settings (checkpoint files, learning rate, truncation, output folder).
# Keyword order fixes the key order, which is also the order logged to W&B.
args = dict(
    description=description,
    ckpt="stylegan2-ffhq-config-f.pt",
    stylegan_size=1024,
    lr_rampup=0.05,
    lr=0.1,
    step=optimization_steps,
    mode=experiment_type,
    clip_lambda=clip_lambda,
    l2_lambda=l2_lambda,
    id_lambda=id_lambda,
    loc_lambda=loc_lambda,
    work_in_stylespace=stylespace,
    latent_path=latent_path,
    truncation=0.7,
    # Save a frame every step when a video is requested, otherwise every 20.
    save_intermediate_image_every=(1 if create_video else 20),
    results_dir="results",
    ir_se50_weights="model_ir_se50.pth",
    semantic_part=semantic_part,
    export_segmentation_image=export_segmentation_image,
    seed=seed,
)

In [88]:
from optimization.run_optimization import main
from argparse import Namespace
from criteria.clip_loss import CLIPLoss
import clip
from PIL import Image
import torch
import torchvision.transforms as transforms
import wandb

# Run one edit per color, then score each edited image with CLIP against the
# prompt of every color; images and per-image loss bar charts go to W&B.

wandb.init(
    # set the wandb project where this run will be logged
    project="clip_colors",

    # track hyperparameters and run metadata
    config=args
)

# The optimizer appears to save frames as results/<step, 5-digit zero-padded>.jpg
# (00039.jpg is the last one for 40 steps), so derive the final frame's path
# from optimization_steps instead of hard-coding it.
# NOTE(review): with create_video=False frames are saved only every 20 steps,
# so the last step may not be written — confirm in run_optimization.
final_frame_path = f"results/{optimization_steps - 1:05d}.jpg"

# Loop-invariant evaluation objects are built once instead of once per color.
# CLIPLoss presumably loads the CLIP model in its constructor and does not
# depend on the prompt text — confirm in criteria/clip_loss.py.
transform = transforms.ToTensor()
clip_loss = CLIPLoss(Namespace(**args))
prompts = {c: description.replace('###', c) for c in colors}
prompt_tokens = {c: clip.tokenize(p).cuda() for c, p in prompts.items()}

clip_loss_evals = {}  # edit color -> {prompt color -> CLIP loss}
for color in colors:
  # Honour the use_seed option (seed is also forwarded to main via args).
  if use_seed:
    torch.manual_seed(seed)
  # Per-run copy, so the shared `args` dict is not mutated by this loop.
  run_args = {**args, 'description': prompts[color]}
  result = main(Namespace(**run_args))

  wandb.log({color + ' image': wandb.Image(final_frame_path)})

  # NOTE(review): ToTensor yields values in [0, 1]; confirm CLIPLoss expects
  # that range rather than the generator's output range.
  with Image.open(final_frame_path) as image:
    img_tensor = transform(image).unsqueeze(0).cuda()

  # Evaluation only: no autograd graph is needed for the loss values.
  with torch.no_grad():
    clip_losses = {
        eval_color: clip_loss(img_tensor, tokens).item()
        for eval_color, tokens in prompt_tokens.items()
    }
  clip_loss_evals[color] = clip_losses

  data = [[key, val] for (key, val) in clip_losses.items()]
  table = wandb.Table(data=data, columns=["color", "loss"])
  wandb.log({color + "_clip_loss": wandb.plot.bar(table, "color",
                               "loss", title="Clip losses for image " + color)})

# Close the W&B run explicitly once every color has been logged.
wandb.finish()

Loading ResNet ArcFace
Loading Segmentation Models


loss: 0.9204, loc_loss: 7124.3711;: 100%|██████████| 40/40 [00:33<00:00,  1.18it/s]


In [None]:
#@title Visualize Result
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage

# `result` is the tensor returned by `main` for the last color of the
# experiment loop. normalize=True with value_range=(-1, 1) rescales it to
# [0, 1] for display; `value_range` replaces the old `range` keyword, which
# newer torchvision releases no longer accept.
result_image = ToPILImage()(make_grid(result.detach().cpu(), normalize=True, scale_each=True, value_range=(-1, 1), padding=0))
# PIL's .size is (width, height), the same order resize() expects.
width, height = result_image.size
result_image.resize((width // 2, height // 2))