In [23]:
from model import Generator
## Import package
import torch
import torchvision
from torch import optim
import math
import clip
from tqdm import tqdm
import os


In [24]:
g_ema = Generator(1024, 512, 8)
g_ema.load_state_dict(torch.load('/home/zhaoxiang/CLIP_AD/StyleCLIP/pretrained_models/stylegan2-ffhq-config-f.pt')["g_ema"], strict=False)

print(g_ema)

Generator(
  (style): Sequential(
    (0): PixelNorm()
    (1): EqualLinear(512, 512)
    (2): EqualLinear(512, 512)
    (3): EqualLinear(512, 512)
    (4): EqualLinear(512, 512)
    (5): EqualLinear(512, 512)
    (6): EqualLinear(512, 512)
    (7): EqualLinear(512, 512)
    (8): EqualLinear(512, 512)
  )
  (input): ConstantInput()
  (conv1): StyledConv(
    (conv): ModulatedConv2d(512, 512, 3, upsample=False, downsample=False)
    (noise): NoiseInjection()
    (activate): FusedLeakyReLU()
  )
  (to_rgb1): ToRGB(
    (conv): ModulatedConv2d(512, 3, 1, upsample=False, downsample=False)
  )
  (convs): ModuleList(
    (0): StyledConv(
      (conv): ModulatedConv2d(512, 512, 3, upsample=True, downsample=False)
      (noise): NoiseInjection()
      (activate): FusedLeakyReLU()
    )
    (1): StyledConv(
      (conv): ModulatedConv2d(512, 512, 3, upsample=False, downsample=False)
      (noise): NoiseInjection()
      (activate): FusedLeakyReLU()
    )
    (2): StyledConv(
      (conv): Modul

In [25]:
# Clip loss
class CLIPLoss(torch.nn.Module):

    def __init__(self):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=1024 // 32)          # stylegan size = 1024

    def forward(self, image, text):
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity

In [26]:
# Parameteres
step = 300
lr = 0.1
l2_lambda = 0.008

work_in_stylespace = False

In [27]:
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp

In [28]:
description = 'A person with blue hair'
text_inputs = torch.cat([clip.tokenize(description)]).cuda()

## Initialize the latend code
# g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
g_ema.eval()
g_ema = g_ema.cuda()
mean_latent = g_ema.mean_latent(4096)
latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)
latent = latent_code_init.detach().clone()
latent.requires_grad = True

latent_code_init_not_trunc = torch.randn(1, 512).cuda()
with torch.no_grad():
    _, latent_code_init, _ = g_ema([latent_code_init_not_trunc], return_latents=True,
                                truncation=0.7, truncation_latent=mean_latent)
    
    img_orig, _ = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False)


In [29]:
# Loss
clip_loss = CLIPLoss()
optimizer = optim.Adam([latent], lr=lr)

In [30]:
# Image Editting loop
pbar = tqdm(range(step))

for i in pbar:
    t = i / step
    lr = get_lr(t, lr)
    optimizer.param_groups[0]["lr"] = lr

    img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=work_in_stylespace)

    c_loss = clip_loss(img_gen, text_inputs)
    
    l2_loss = ((latent_code_init - latent) ** 2).sum()
    loss = c_loss + l2_lambda * l2_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    pbar.set_description(
        (
            f"loss: {loss.item():.4f};"
        )
    )
    with torch.no_grad():
        img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=work_in_stylespace)

    torchvision.utils.save_image(img_gen, f"/home/zhaoxiang/CLIP_AD/output/styleGAN/{str(i).zfill(5)}.jpg", normalize=True, range=(-1, 1))

final_result = torch.cat([img_orig, img_gen])

torchvision.utils.save_image(final_result.detach().cpu(), os.path.join('/home/zhaoxiang/CLIP_AD/output/styleGAN' ,"final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))

loss: 8.9297;:   2%|▏         | 7/300 [00:19<11:15,  2.31s/it]