## 환경 설정

In [None]:
import torch
from torchvision import transforms
import torchvision
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import pathlib
import os
import pickle
import warnings
import copy
import numpy as np
import math
# Silence every Python warning for the rest of the session.
# Note that this also hides deprecation notices from the libraries imported above.
warnings.filterwarnings('ignore')

In [None]:
# Mount Google Drive inside the Colab VM at /content/drive (Colab-only step).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Make the project folder on Drive the working directory for the following cells.
%cd /content/drive/MyDrive/StyleGAN2/StyleGAN2-ada_Toonify

/content/drive/MyDrive/StyleGAN2/StyleGAN2-ada_Toonify


In [None]:
!pip install ninja

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ninja
  Downloading ninja-1.11.1-py2.py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (145 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m146.0/146.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.11.1


In [None]:
import ninja

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"device: {device}")

device: cuda


## Network blending

In [None]:
# Collect the pretrained generator pickles.
# Path.glob() yields entries in arbitrary (filesystem-dependent) order, so the
# result is sorted by path: any code that picks a model by list index then sees
# a stable, name-ordered list (ffhq.pkl before metfaces.pkl).
model_list = sorted(
    pathlib.Path('/content/drive/MyDrive/StyleGAN2/StyleGAN2-ada_Toonify/pretrained').glob('*.pkl')
)
model_list

[PosixPath('/content/drive/MyDrive/StyleGAN2/StyleGAN2-ada_Toonify/pretrained/ffhq.pkl'),
 PosixPath('/content/drive/MyDrive/StyleGAN2/StyleGAN2-ada_Toonify/pretrained/metfaces.pkl')]

In [None]:
def _load_generator(pkl_path):
    """Load the 'G_ema' generator from a pickle, freeze it and move it to `device`.

    NOTE(review): pickle.load executes arbitrary code contained in the file, so
    only load checkpoints that come from a trusted source.
    """
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)['G_ema'].requires_grad_(False).to(device)  # torch.nn.Module


# Select the checkpoints by file name instead of by list position, so the
# G1/G2 assignment does not depend on the order in which the files were listed.
# A missing checkpoint surfaces as a KeyError naming the expected file stem.
_models_by_stem = {path.stem: path for path in model_list}

G1 = _load_generator(_models_by_stem['ffhq'])      # FFHQ model (base network)
G2 = _load_generator(_models_by_stem['metfaces'])  # MetFaces model (fine-tuned network)

In [None]:
def get_conv_names(model, max_dim):
    """List the synthesis-network parameters of `model`, ordered from low to high resolution.

    Resolutions 4, 8, ..., up to `max_dim` are visited in order. Inside each
    resolution block the parameters are split into two levels:
    level 0 holds names starting with "conv0" or "const", level 1 holds names
    starting with "conv1" or "torgb". Every (resolution, level) pair gets its own
    consecutive position index, whether or not any parameter matched it.

    Args:
        model: object exposing `named_parameters()` that yields (name, value) pairs.
        max_dim: largest resolution to enumerate (e.g. 1024 gives 9 blocks, 512 gives 8).

    Returns:
        A list of tuples `(parameter_name, "b<res>", level, position)`.
    """
    suffix_groups = (("conv0", "const"), ("conv1", "torgb"))
    block_count = 1 + int(np.log2(max_dim / 4))
    param_names = [pair[0] for pair in model.named_parameters()]

    layer_info = []
    for block_idx in range(block_count):
        block_tag = f"b{4 * 2 ** block_idx}"
        for level, suffixes in enumerate(suffix_groups):
            # Two levels per block, so positions advance by one per (block, level).
            position = 2 * block_idx + level
            for suffix in suffixes:
                prefix = f"synthesis.{block_tag}.{suffix}"
                layer_info += [
                    (param_name, block_tag, level, position)
                    for param_name in param_names
                    if param_name.startswith(prefix)
                ]
    return layer_info

In [None]:
def blend(G1, G2, resolution, level, network_size=1024, blend_width=None, verbose=True):
    """Create a copy of G1 whose synthesis weights are blended with those of G2.

    Every synthesis parameter returned by `get_conv_names` receives a weight `y`
    (the share taken from G2) that depends on its position `x` relative to the
    chosen midpoint layer: `new = G2 * y + G1 * (1 - y)`. Layers positioned after
    the midpoint (higher resolution) lean towards G2, layers before it towards G1.

    Args:
        G1: base generator; supplies all parameters that are not blended.
        G2: second generator; must expose the same synthesis parameter names as G1.
        resolution: resolution of the midpoint block (e.g. 64 selects "b64").
        level: level of the midpoint inside that block (0 or 1, as produced by
            `get_conv_names`).
        network_size: largest resolution to enumerate. Layers are always
            enumerated up to at least 1024 (the previously hard-coded value), so
            existing calls behave as before; larger values extend the range.
        blend_width: falsy (None/0) for a hard switch, where y is 1 only for
            layers more than one position past the midpoint and 0 otherwise.
            A float gives a logistic transition `y = 1 / (1 + exp(-x / blend_width))`;
            smaller values make the transition sharper.
        verbose: print the blend weight applied to every parameter.

    Returns:
        A deep copy of G1 carrying the blended weights. G1 and G2 are not modified.

    Raises:
        ValueError: if no parameter exists for the requested resolution/level.
        AssertionError: if G1 and G2 do not list the same synthesis parameters.
    """
    # Enumerating resolutions a model does not have only yields no matches and
    # leaves earlier positions untouched, so taking the max is backward-compatible.
    max_dim = max(network_size, 1024)
    model1_names = get_conv_names(G1, max_dim)
    model2_names = get_conv_names(G2, max_dim)

    assert model1_names == model2_names

    midpoint_key = (f'b{resolution}', level)
    short_names = [entry[1:3] for entry in model1_names]
    try:
        mid_point_idx = short_names.index(midpoint_key)
    except ValueError:
        raise ValueError(
            f"no synthesis layer found for resolution={resolution}, level={level}"
        ) from None
    mid_point_pos = model1_names[mid_point_idx][3]

    # state_dict() builds a fresh mapping on every call; fetch each one only once.
    g1_state = G1.state_dict()
    g2_state = G2.state_dict()

    output_model = copy.deepcopy(G1)
    new_model_state_dict = output_model.state_dict()

    for name, _, _, position in model1_names:
        x = position - mid_point_pos
        if blend_width:
            # Soft (logistic) blend. For layers far below the midpoint and a very
            # small blend_width the exponential exceeds the float range; the
            # mathematical limit of the weight in that case is 0.
            try:
                y = 1 / (1 + math.exp(-x / blend_width))
            except OverflowError:
                y = 0.0
        else:
            # Hard blend: only layers more than one position past the midpoint use G2.
            y = 1 if x > 1 else 0
        if verbose:
            print(f"Blending {name} by {y}")
        new_model_state_dict[name] = g2_state[name] * y + g1_state[name] * (1 - y)

    output_model.load_state_dict(new_model_state_dict)

    return output_model

In [None]:
def get_image(model, z=None, label=0, truncation_psi=0.7, noise_mode="const", w=None, is_w=False):
    """Run `model` once and return the first generated image as a PIL RGB image.

    Args:
        model: generator; called directly with `z`, or through `model.synthesis`
            when `is_w` is set.
        z: latent passed to `model(...)` when `is_w` is False.
        label: second positional argument passed to `model(...)`.
        truncation_psi: forwarded to `model(...)` (not used in the `is_w` path).
        noise_mode: forwarded to whichever call is made.
        w: latent without a batch dimension; a leading dimension is added before
            it is passed to `model.synthesis`. Required when `is_w` is True.
        is_w: choose the `model.synthesis(w)` path instead of `model(z, label)`.

    Returns:
        PIL.Image.Image in "RGB" mode built from the first item of the batch.
    """
    if not is_w:
        raw = model(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    else:
        raw = model.synthesis(w.unsqueeze(0), noise_mode=noise_mode)

    # Move channels last, then map to 8-bit pixel values
    # (presumably the network output is roughly in [-1, 1] - TODO confirm).
    channels_last = raw.permute(0, 2, 3, 1)
    pixels = (channels_last * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    first_image = pixels[0].cpu().numpy()

    return PIL.Image.fromarray(first_image, "RGB")

In [None]:
# Midpoint resolutions to build a blended generator for.
resolutions = [64, 128]

# One blended model per midpoint, keyed "b<resolution>"; soft blend around level 0.
blended_models = {
    f'b{res}': blend(G1, G2, res, level=0, network_size=512, blend_width=0.7, verbose=False)
    for res in resolutions
}

In [None]:
# Half-open index range [start, end) of the image pairs to generate.
file_num_start, file_num_end = 0, 300

In [None]:
c = None  # no class label is passed to the mapping network
m = blended_models['b128']

# Side-by-side pairs (G1 output | blended output) are written here.
# Create the folder up front so the first save does not fail on a fresh Drive.
pair_dir = '/content/drive/MyDrive/StyleGAN2/StyleGAN2-ada_Toonify/pair_dataset/metface'
os.makedirs(pair_dir, exist_ok=True)

for num in range(file_num_start, file_num_end):
    # Seed with the file index so that pair <num>.jpg is reproducible across
    # runs and independent of the chosen start/end range.
    torch.manual_seed(num)
    z = torch.randn([1, G1.z_dim], device=device)
    ws = G1.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
    # The same latent is fed to both generators, so their layouts must agree.
    assert ws.shape[1:] == (m.num_ws, m.w_dim)

    # Source image from G1. Convert with the same fixed mapping that get_image()
    # applies to the blended output; per-image min-max stretching would change
    # contrast from sample to sample and make the two halves inconsistent.
    img_ori = G1.synthesis(ws, noise_mode='const', force_fp32=True)
    img_ori = (img_ori.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img_ori = Image.fromarray(img_ori[0].cpu().numpy(), 'RGB')
    img_ori = img_ori.resize((512, 512))

    # Translated image from the blended generator, using the same latent.
    img_trans = get_image(model=m, w=ws[0], is_w=True)
    img_trans = img_trans.resize((512, 512))

    # Left half: source image, right half: translated image.
    concatenated_image = Image.new('RGB', (1024, 512))
    concatenated_image.paste(img_ori, (0, 0))
    concatenated_image.paste(img_trans, (512, 0))

    concatenated_image.save(os.path.join(pair_dir, f'{num}.jpg'))

    plt.imshow(concatenated_image)
    plt.show()

Output hidden; open in https://colab.research.google.com to view.