# Text-Guided Editing of Images (Using CLIP and StyleGAN)

In [1]:
#@title Setup (may take a few minutes)
!git clone https://github.com/khalilacheche/StyleCLIP.git

import os
os.chdir(f'./StyleCLIP')

!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Sign the Colab user in, then hand the resulting application-default
# credentials to PyDrive so that `drive` can read files from Google Drive.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Google Drive file ids of StyleGAN's weights and the facial recognition
# network weights. Each file is saved in the current working directory under
# the title stored in its Drive metadata.
ids = [
    '1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT',
    '1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL',
]
for file_id in ids:
  downloaded = drive.CreateFile({'id': file_id})
  # The metadata must be fetched first: it provides the 'title' used as file name.
  downloaded.FetchMetadata(fetch_all=True)
  downloaded.GetContentFile(downloaded.metadata['title'])

Cloning into 'StyleCLIP'...
remote: Enumerating objects: 978, done.[K
remote: Counting objects: 100% (360/360), done.[K
remote: Compressing objects: 100% (186/186), done.[K
remote: Total 978 (delta 224), reused 253 (delta 165), pack-reused 618[K
Receiving objects: 100% (978/978), 258.22 MiB | 10.38 MiB/s, done.
Resolving deltas: 100% (369/369), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-i8pdpi0y
  Running command git clone --filter=blob:none --quiet https://

In [27]:
#@title Optimization
import ipywidgets as widgets
from IPython.display import display,clear_output
import torch
from optimization.run_optimization import main
from argparse import Namespace
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage

# One FloatText per lambda, built from (label, default) pairs. The values are
# read by `button_action` when the "Generate Image" button is clicked.
(clip_lambda_widget,
 l2_lambda_widget,
 loc_lambda_widget,
 id_lambda_widget) = [
    widgets.FloatText(value=default, description=label, disabled=False)
    for label, default in [
        ('CLIP lambda:', 1),
        ('L2 lambda:', 0.0004),
        ('Loc lambda:', 0.00001),
        ('ID lambda:', 0.005),
    ]
]

# Toggle forwarded to the optimization as `work_in_stylespace`.
use_stylespace_widget = widgets.Checkbox(
    description='Stylespace', value=True, disabled=False
)

# Number of optimization steps, forwarded as `step`.
optimization_steps_widget = widgets.IntText(
    description='Optimization Steps:', value=40, disabled=False
)

# Order here is the top-to-bottom order inside the accordion section:
# lambdas, use stylespace, optimization steps.
optimization_parameters_widgets = [
    clip_lambda_widget,
    l2_lambda_widget,
    loc_lambda_widget,
    id_lambda_widget,
    use_stylespace_widget,
    optimization_steps_widget,
]
optimization_parameters_box = widgets.VBox(optimization_parameters_widgets)




# Free-text description of the desired output image (forwarded as `description`).
editing_text_widget = widgets.Text(
    description='Output image desc:',
    placeholder='Type a description of the output image',
    value='A person with black hair',
    disabled=False,
)

# Selectable face regions; the chosen subset is forwarded as `semantic_parts`.
semantic_parts_options = [
    "mouth",
    "skin",
    "eyes",
    "nose",
    "ears",
    "eye_brows",
    "hat",
    "hair",
    "neck",
]
semantic_parts_widget = widgets.SelectMultiple(
    description='Semantic Parts',
    options=semantic_parts_options,
    value=["hair"],
    disabled=False,
)

# text, semantic part
editing_parameters_widgets = [editing_text_widget, semantic_parts_widget]
editing_parameters_box = widgets.VBox(editing_parameters_widgets)


# When checked, `button_action` seeds torch with the value of `seed_widget`
# and ignores the latent path below.
use_seed_widget = widgets.Checkbox(
    description='Use seed?', value=True, disabled=False
)
seed_widget = widgets.IntText(description='Seed:', value=1, disabled=False)

# Path to a saved latent vector; only used when "Use seed?" is unchecked.
latent_path_widget = widgets.Text(
    description='Latent vector path',
    placeholder='latens/example.pth',
    value=None,
    disabled=False,
)

# use seed? seed, latent path
input_image_parameters_widgets = [
    use_seed_widget,
    seed_widget,
    latent_path_widget,
]
input_image_parameters_box = widgets.VBox(input_image_parameters_widgets)



# Two unchecked-by-default toggles. `button_action` saves an intermediate
# image every step when "Create video?" is checked (every 20 steps otherwise)
# and forwards the second toggle as `export_segmentation_image`.
create_video_widget, export_segmentation_out_widget = [
    widgets.Checkbox(value=False, description=label, disabled=False)
    for label in ('Create video?', 'Export segmentation output?')
]

# create_video? export_segmentation_image?
output_parameters_widgets = [
    create_video_widget,
    export_segmentation_out_widget,
]
output_parameters_box = widgets.VBox(output_parameters_widgets)

# Group the four parameter boxes into collapsible sections. The titles below
# are assigned by position, so they must follow the order of `children`.
accordion = widgets.Accordion(
    children=[
        editing_parameters_box,
        optimization_parameters_box,
        input_image_parameters_box,
        output_parameters_box,
    ]
)
accordion.set_title(0, 'Editing Parameters')
accordion.set_title(1, 'Optimization Parameters')
accordion.set_title(2, 'Input Image Parameters')
accordion.set_title(3, 'Output Parameters')

# Button that starts the optimization. `button_style` accepts
# 'success', 'info', 'warning', 'danger' or ''.
generate_button_widget = widgets.Button(
    icon='check',
    tooltip='Click to generate',
    description='Generate Image',
    button_style='',
    disabled=False,
)

# Output area that receives the progress text and the resulting image.
out = widgets.Output()


# Value returned by the most recent call to `main`; stays None until an
# optimization has completed. Read by the scaling cell further down.
results = None

def button_action(b):
  """Run the optimization with the current widget values and show the result.

  Click handler for `generate_button_widget`. The widget values are packed
  into an argparse Namespace and passed to `main`; the return value is stored
  in the global `results`, and `results["orig_image"]` / `results["gen_image"]`
  are rendered as one half-size image grid inside the `out` widget.

  Args:
    b: the clicked button (unused).
  """
  global results
  out.clear_output()
  with out:
    print("Started optimizing...")

    use_seed = use_seed_widget.value
    # An empty (or whitespace-only) path box means "no latent file": pass None
    # instead of '' so the optimization never receives an empty file name.
    latent_path = None if use_seed else (latent_path_widget.value.strip() or None)

    args = {
        "description": editing_text_widget.value,
        "ckpt": "stylegan2-ffhq-config-f.pt",
        "stylegan_size": 1024,
        "lr_rampup": 0.05,
        "lr": 0.1,
        "step": optimization_steps_widget.value,
        "clip_lambda": clip_lambda_widget.value,
        "l2_lambda": l2_lambda_widget.value,
        "id_lambda": id_lambda_widget.value,
        "loc_lambda": loc_lambda_widget.value,
        'work_in_stylespace': use_stylespace_widget.value,
        "latent_path": latent_path,
        "truncation": 0.7,
        # Save every step when a video is requested, every 20 steps otherwise.
        "save_intermediate_image_every": 1 if create_video_widget.value else 20,
        "results_dir": "results",
        "ir_se50_weights": "model_ir_se50.pth",
        "segmentation_model": "face_segmentation",
        "semantic_parts": semantic_parts_widget.value,
        "export_segmentation_image": export_segmentation_out_widget.value,
    }
    if use_seed:
      torch.manual_seed(seed_widget.value)

    # run the optimization
    results = main(Namespace(**args))

    orig_res_imgs = torch.cat([results["orig_image"], results["gen_image"]])
    # `value_range` replaces the deprecated `range` keyword of make_grid, which
    # current torchvision releases no longer accept.
    grid = make_grid(
        orig_res_imgs.detach().cpu(),
        normalize=True,
        scale_each=True,
        value_range=(-1, 1),
        padding=0,
    )
    result_image = ToPILImage()(grid)
    width, height = result_image.size  # PIL reports (width, height)
    display(result_image.resize((width // 2, height // 2)))

generate_button_widget.on_click(button_action)


In [28]:
# Render the parameter accordion, the "Generate Image" button, and the output
# area that `button_action` writes its progress text and result image into.
display(accordion,generate_button_widget,out)

Accordion(children=(VBox(children=(Text(value='A person with black hair', description='Output image desc:', pl…

Button(description='Generate Image', icon='check', style=ButtonStyle(), tooltip='Click to generate')

Output()

In [36]:
#@title Scaling the latent vector
import numpy as np

def prepare_images_for_display(imgs):
  """Combine image tensors into a single half-size PIL image.

  Args:
    imgs: sequence of image tensors that `torch.cat` can concatenate.
      # assumes each tensor is a batch of images with values in [-1, 1] — TODO confirm

  Returns:
    A PIL image holding the grid of all images, resized to half its width
    and height.
  """
  concatenated = torch.cat(imgs)
  # `value_range` replaces the deprecated `range` keyword of make_grid, which
  # current torchvision releases no longer accept.
  grid = make_grid(
      concatenated.detach().cpu(),
      normalize=True,
      scale_each=True,
      value_range=(-1, 1),
      padding=0,
  )
  result_image = ToPILImage()(grid)
  width, height = result_image.size  # PIL reports (width, height)
  return result_image.resize((width // 2, height // 2))
def get_images_scaled(generator,latent_init,latent_dir,min_scale,max_scale,num_scales):
  """Generate one image per scale by moving `latent_init` along `latent_dir`.

  `num_scales` evenly spaced scales between `min_scale` and `max_scale`
  (both included) are applied; for each scale the generator is called with
  `latent_init + scale * latent_dir`. When the "Stylespace" checkbox is
  ticked, the latents are combined entry by entry and the generator is told
  its input is in stylespace.

  Args:
    generator: callable invoked as
      `generator([latent], input_is_latent=True, randomize_noise=False,
      input_is_stylespace=...)`; the first element of its return value is
      indexed with "image".
    latent_init: starting latent (indexable sequence in stylespace mode).
    latent_dir: edit direction, same structure as `latent_init`.
    min_scale: smallest scale applied to `latent_dir`.
    max_scale: largest scale applied to `latent_dir`.
    num_scales: number of scales, and therefore of images, to produce.

  Returns:
    List of generated images, one per scale, in increasing scale order.
  """
  # NOTE(review): the mode is read from the checkbox at call time; it should
  # match the setting used when `latent_init`/`latent_dir` were produced.
  is_stylespace = use_stylespace_widget.value
  imgs = []
  # Inference only: without no_grad every kept image would also keep its
  # autograd graph alive, which wastes GPU memory for each extra scale.
  with torch.no_grad():
    for scale in np.linspace(min_scale, max_scale, num_scales):
      if is_stylespace:
        new_latent = [
            latent_init[c] + scale * latent_dir[c]
            for c in range(len(latent_init))
        ]
      else:
        new_latent = latent_init + scale * latent_dir
      result_gen, _ = generator(
          [new_latent],
          input_is_latent=True,
          randomize_noise=False,
          input_is_stylespace=is_stylespace,
      )
      imgs.append(result_gen["image"])
  return imgs


def create_image_output_editing_form(results):
  """Display an interactive form that re-renders the edit at several strengths.

  The form holds a slider for the number of images, a range slider for the
  scaling interval and a button. Clicking the button, and the initial call at
  the end of this function, renders the images returned by `get_images_scaled`
  into an output area below the controls.

  Args:
    results: value returned by `main` in the optimization cell, read through
      the keys "generator", "orig_latent" and "latent_dir". May be None when
      no optimization has run yet; a hint is then shown instead of images.
      The value is bound when the form is created, so the form must be
      re-created to pick up a newer optimization result.
  """
  num_scaling_steps_widget = widgets.IntSlider(
      value=3,
      min=3,
      max=9,
      step=2,
      description='Number of images',
      disabled=False,
      continuous_update=False,
      orientation='horizontal',
      readout=True,
      readout_format='d'
  )

  scaling_range_value_widget = widgets.FloatRangeSlider(
      value=[-1, 1],
      min=-2.0,
      max=2.0,
      step=0.001,
      description='Scaling range:',
      disabled=False,
      continuous_update=True,
      orientation='horizontal',
      readout=True,
      readout_format='.2f',
  )

  scale_button_widget = widgets.Button(
      description='Generate Image',
      disabled=False,
      button_style='', # 'success', 'info', 'warning', 'danger' or ''
      tooltip='Click to generate',
      icon='check'
  )

  v_box = widgets.VBox([num_scaling_steps_widget,scaling_range_value_widget,scale_button_widget])
  out_scale = widgets.Output()

  def scale_button_action(b):
    """Render the scaled images for the current slider values into `out_scale`."""
    out_scale.clear_output(True)
    with out_scale:
      if results is None:
        # Tell the user why nothing is rendered instead of silently returning.
        print("No optimization results available. Generate an image first, "
              "then call create_image_output_editing_form(results) again.")
        return
      min_scale, max_scale = scaling_range_value_widget.value
      num_scales = num_scaling_steps_widget.value
      imgs = get_images_scaled(results["generator"],results["orig_latent"],results["latent_dir"],min_scale,max_scale,num_scales)
      display(prepare_images_for_display(imgs))
      torch.cuda.empty_cache()

  scale_button_widget.on_click(scale_button_action)
  display(v_box,out_scale)
  # Initial render with the default slider values; the handler already
  # redirects its own output into `out_scale`.
  scale_button_action(None)



  

In [37]:
# Build the scaling form for the current optimization output. `results` is
# bound when this call runs, so re-run this cell after generating a new image.
create_image_output_editing_form(results)

VBox(children=(IntSlider(value=3, continuous_update=False, description='Number of images', max=9, min=3, step=…

Output()

In [7]:
#@title Create and Download Video

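# Uncomment the lines below to assemble the numbered frames in results/ into
# out.mp4 and download it from Colab.
# NOTE: presumably this needs an optimization run with "Create video?" checked,
# so that an intermediate image is saved at every step — confirm.
#!ffmpeg -r 15 -i results/%05d.jpg -c:v libx264 -vf fps=25 -pix_fmt yuv420p out.mp4
#from google.colab import files
#files.download('out.mp4')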