# Start

In [None]:
! git clone https://github.com/williamyang1991/DualStyleGAN.git
%cd DualStyleGAN

/content/drive/MyDrive/nowar/DualStyleGAN


In [None]:
! pip install wandb
! pip install lmdb

Collecting wandb
  Downloading wandb-0.16.0-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.35.0-py2.py3-none-any.whl (248 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m248.6/248.6 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->w

In [None]:
# !wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

Archive:  ninja-linux.zip
  inflating: /usr/local/bin/ninja    
update-alternatives: using /usr/local/bin/ninja to provide /usr/bin/ninja (ninja) in auto mode


In [None]:
# Build the LMDB dataset under ./data/nowar_soldier/lmdb/ from the images in
# ./data/nowar_soldier/images/ (--size 1024, --n_worker 4).
# NOTE(review): only the soldier dataset is built here; the 'nowar_victim' LMDB selected
# by a later cell must already exist on disk — confirm before training that style.
! python ./model/stylegan/prepare_data.py --out ./data/nowar_soldier/lmdb/ --n_worker 4 --size 1024 ./data/nowar_soldier/images/

Make dataset of image sizes: 1024
170it [00:12, 13.61it/s]


# 1. Fine-tuning

In [None]:
# Authenticate with Weights & Biases. The command prompts for an API key interactively
# (see the recorded output below), so this cell cannot run unattended.
! wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
import argparse
import math
import random
import os

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from util import data_sampler, requires_grad, accumulate, sample_data, d_logistic_loss, d_r1_loss, g_nonsaturating_loss, g_path_regularize, make_noise, mixing_noise, set_grad_none

In [None]:
# wandb is an optional dependency: start from "not available" and only bind the
# module when the import succeeds. Later cells test `wandb` for truthiness.
wandb = None
try:
    import wandb
except ImportError:
    pass

In [None]:
from model.stylegan.dataset import MultiResolutionDataset
from model.stylegan.distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)
from model.stylegan.non_leaking import augment, AdaptiveAugment
from model.stylegan.model import Generator, Discriminator

In [None]:
# Style option: soldier. `style` names the log/checkpoint sub-folders and `path`
# points at the LMDB dataset for that style.
# NOTE: run only one of the style-selection cells; whichever runs last wins.
style = 'nowar_soldier'
path = './data/%s/lmdb/' % style

In [None]:
# Style option: victim. `style` names the log/checkpoint sub-folders and `path`
# points at the LMDB dataset for that style.
# NOTE: when the notebook is run top to bottom, this cell overrides the soldier
# selection made in the previous cell.
style = 'nowar_victim'
path = './data/%s/lmdb/' % style

In [None]:
# --- Paths -----------------------------------------------------------------
model_path = './checkpoint/'                            # root folder for saved checkpoints
ckpt_path = './checkpoint/stylegan2-ffhq-config-f.pt'  # pretrained weights to start from

# --- Model / optimizer settings ---------------------------------------------
size = 1024
channel_multiplier = 2
lr = 0.002
g_reg_every = 4    # generator path regularization runs every g_reg_every iterations
d_reg_every = 16   # discriminator R1 regularization runs every d_reg_every iterations

# --- Run settings -----------------------------------------------------------
iter = 400   # number of training iterations; NOTE: shadows the builtin iter(), kept
             # because later cells read this name
batch = 4
local_rank = 0

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
# (Original author's note: this only works in a CUDA environment.)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Create the folder that receives the sample images written during training.
log_dir = "log/%s/"%(style)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
    print("created log folder")

# Create the folder that receives the model checkpoints (checkpoint/<style>/).
ckpt_dir = "%s/%s/"%(model_path, style)
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
    print("created ckpt folder")

# Number of GPUs/processes: taken from WORLD_SIZE when a launcher sets it, else 1.
n_gpu = int(os.environ.get("WORLD_SIZE", 1))
# More than one process means distributed training.
distributed = n_gpu > 1

if distributed:
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    synchronize()

In [None]:
latent = 512
n_mlp = 8

start_iter = 0

# Only the StyleGAN2 architecture (model.stylegan.model) is used in this notebook.


def _scaled_adam(parameters, reg_ratio):
    """Return an Adam optimizer whose lr and betas are rescaled by ``reg_ratio``."""
    return optim.Adam(
        parameters,
        lr=lr * reg_ratio,
        betas=(0 ** reg_ratio, 0.99 ** reg_ratio),
    )


# Build the StyleGAN2 networks: the trainable generator, the discriminator, and a
# second generator (g_ema) that is kept in eval mode.
_generator_args = (size, latent, n_mlp)
_shared_kwargs = {"channel_multiplier": channel_multiplier}

generator = Generator(*_generator_args, **_shared_kwargs).to(device)
discriminator = Discriminator(size, **_shared_kwargs).to(device)
g_ema = Generator(*_generator_args, **_shared_kwargs).to(device)
g_ema.eval()
# NOTE(review): presumably a decay of 0 copies the generator weights into g_ema —
# confirm against util.accumulate.
accumulate(g_ema, generator, 0)

# Ratios used to rescale the optimizer hyper-parameters for each network.
g_reg_ratio = g_reg_every / (g_reg_every + 1)
d_reg_ratio = d_reg_every / (d_reg_every + 1)

g_optim = _scaled_adam(generator.parameters(), g_reg_ratio)
d_optim = _scaled_adam(discriminator.parameters(), d_reg_ratio)

In [None]:
# Load the StyleGAN weights pretrained on FFHQ as the starting point for fine-tuning.
if ckpt_path is not None:
    print("load model:", ckpt_path)

    # Keep every tensor on the storage it was deserialized to.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

    # A checkpoint whose file name (without extension) is an integer resumes from that
    # iteration; for any other name start_iter keeps its current value.
    ckpt_stem, _ = os.path.splitext(os.path.basename(ckpt_path))
    try:
        start_iter = int(ckpt_stem)
    except ValueError:
        pass

    for _net, _key in ((generator, "g"), (discriminator, "d"), (g_ema, "g_ema")):
        _net.load_state_dict(ckpt[_key])

    # Optimizer states are optional in the checkpoint.
    for _opt, _key in ((g_optim, "g_optim"), (d_optim, "d_optim")):
        if _key in ckpt:
            _opt.load_state_dict(ckpt[_key])

load model: ./checkpoint/stylegan2-ffhq-config-f.pt


In [None]:
# Wrap both networks for multi-process training when running distributed.
if distributed:
    _ddp_kwargs = dict(
        device_ids=[local_rank],
        output_device=local_rank,
        broadcast_buffers=False,
    )
    generator = nn.parallel.DistributedDataParallel(generator, **_ddp_kwargs)
    discriminator = nn.parallel.DistributedDataParallel(discriminator, **_ddp_kwargs)

# Prepare the dataset: random horizontal flip, conversion to a tensor, then
# normalization with mean 0.5 and std 0.5 per channel.
_channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_channel_stats, _channel_stats, inplace=True),
])

dataset = MultiResolutionDataset(path, transform, size)
_sampler = data_sampler(dataset, shuffle=True, distributed=distributed)
# drop_last=True discards a final incomplete batch.
loader = data.DataLoader(dataset, batch_size=batch, sampler=_sampler, drop_last=True)

In [None]:
# Start a Weights & Biases run on rank 0 only, and only when wandb was importable.
if get_rank() == 0 and wandb:
    wandb.init(project="stylegan 2")

[34m[1mwandb[0m: Currently logged in as: [33meunai9[0m. Use [1m`wandb login --relogin`[0m to force relogin


## train 함수

In [None]:
# Reference only (not executed): a command-line invocation of the repository's
# finetune_stylegan.py script through torch.distributed.launch.
# NOTE(review): the cells below appear to run the same fine-tuning inline, with their own
# settings (e.g. 400 iterations rather than 600) — confirm which one is authoritative.
# !python -m torch.distributed.launch --nproc_per_node=1 --master_port=8765 finetune_stylegan.py --iter 600 \
# --batch 4 --ckpt ./checkpoint/stylegan2-ffhq-config-f.pt --style nowar --augment ./data/nowar/lmdb/

In [None]:
# Augmentation: augment_p == 0 makes the training cells use the adaptive augmentation
# schedule configured by ada_target / ada_length.
augment_p = 0
ada_target = 0.6
ada_length = 500 * 1000

# Sampling and style mixing.
n_sample = 9     # number of fixed latent vectors used for the preview grid
mixing = 0.9     # passed to mixing_noise

# Regularization weights.
r1 = 10
path_regularize = 2
path_batch_shrink = 2   # the path-regularization batch is batch // path_batch_shrink

# Checkpoint interval, in iterations.
save_every = 100

In [None]:
# Iteration source for the training loop: rank 0 gets a progress bar, every other
# rank iterates over the bare range.
_steps = range(iter)
pbar = (
    tqdm(_steps, initial=start_iter, ncols=140, dynamic_ncols=False, smoothing=0.01)
    if get_rank() == 0
    else _steps
)

# Running statistics, initialised so they can be logged before their first update.
mean_path_length = 0
mean_path_length_avg = 0
d_loss_val = 0
g_loss_val = 0
r1_loss, path_loss, path_lengths = (
    torch.tensor(0.0, device=device) for _ in range(3)
)
loss_dict = {}

# Unwrap the DistributedDataParallel containers to reach the underlying networks.
g_module = generator.module if distributed else generator
d_module = discriminator.module if distributed else discriminator

# Decay factor passed to accumulate() when updating g_ema.
accum = 0.5 ** (32 / (10 * 1000))

if augment_p > 0:
    ada_aug_p = augment_p
else:
    ada_aug_p = 0.0
r_t_stat = 0

# NOTE: `augment` is the function imported from model.stylegan.non_leaking, so it is
# always truthy here; the adaptive schedule is therefore created whenever augment_p == 0.
if augment and augment_p == 0:
    ada_augment = AdaptiveAugment(ada_target, ada_length, 8, device)

# Fixed latent vectors reused for every preview grid.
sample_z = torch.randn(n_sample, latent, device=device)

  0%|                                                                                                               | 0/400 [00:00<?, ?it/s]

In [None]:
def _endless_batches(data_loader):
    """Yield batches from ``data_loader`` forever, starting a new pass whenever one ends."""
    while True:
        for batch_items in data_loader:
            yield batch_items


# Fail fast instead of spinning forever when the loader cannot produce a single batch
# (for example when drop_last discards the only, incomplete, batch).
if len(loader) == 0:
    raise ValueError("loader yields no batches; check the dataset size against the batch size")

# One persistent iterator for the whole run. The previous `next(enumerate(loader))[1]`
# built a brand-new DataLoader iterator on every step, so only the first batch of a
# freshly shuffled pass was ever consumed and the rest of each pass was thrown away.
batch_stream = _endless_batches(loader)

# Fine-tuning loop. Each iteration performs, in order:
#   1. a discriminator update on one real batch and one generated batch,
#   2. every d_reg_every iterations, an R1 regularization step on the discriminator,
#   3. a generator update,
#   4. every g_reg_every iterations, a path-length regularization step on the generator,
#   5. an update of g_ema, followed by logging, preview images and checkpoints (rank 0).
#
# NOTE: `augment` is the function imported from model.stylegan.non_leaking (unless it was
# rebound elsewhere), so every `if augment:` below is true and the `else` branches do not run.
for idx in pbar:
    i = idx + start_iter

    if i > iter:
        print("Done!")

        break

    real_img = next(batch_stream)
    real_img = real_img.to(device)

    # ---- Discriminator step ------------------------------------------------------
    requires_grad(generator, False)
    requires_grad(discriminator, True)

    noise = mixing_noise(batch, latent, mixing, device)  # latent vector(s)
    fake_img, _ = generator(noise)  # generate fake images

    if augment:
        real_img_aug, _ = augment(real_img, ada_aug_p)
        fake_img, _ = augment(fake_img, ada_aug_p)

    else:
        real_img_aug = real_img

    fake_pred = discriminator(fake_img)  # score the fake images
    real_pred = discriminator(real_img_aug)  # score the real (style dataset) images
    d_loss = d_logistic_loss(real_pred, fake_pred)

    loss_dict["d"] = d_loss
    loss_dict["real_score"] = real_pred.mean()
    loss_dict["fake_score"] = fake_pred.mean()

    discriminator.zero_grad()
    d_loss.backward()
    d_optim.step()  # train the discriminator

    # Adaptive augmentation: retune the augmentation probability from the real scores.
    if augment and augment_p == 0:
        ada_aug_p = ada_augment.tune(real_pred)
        r_t_stat = ada_augment.r_t_stat

    # ---- Discriminator R1 regularization (every d_reg_every iterations) ----------
    d_regularize = i % d_reg_every == 0

    if d_regularize:
        # d_r1_loss is given both the prediction and the input images, so the images
        # must track gradients.
        real_img.requires_grad = True

        if augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)

        else:
            real_img_aug = real_img

        real_pred = discriminator(real_img_aug)
        r1_loss = d_r1_loss(real_pred, real_img)

        discriminator.zero_grad()
        # The `0 * real_pred[0]` term adds nothing to the value; it keeps the
        # discriminator output in the graph that is back-propagated.
        (r1 / 2 * r1_loss * d_reg_every + 0 * real_pred[0]).backward()

        d_optim.step()

    loss_dict["r1"] = r1_loss

    # ---- Generator step -----------------------------------------------------------
    requires_grad(generator, True)
    requires_grad(discriminator, False)

    noise = mixing_noise(batch, latent, mixing, device)  # latent vector(s)
    fake_img, _ = generator(noise)  # generate fake images

    if augment:
        fake_img, _ = augment(fake_img, ada_aug_p)

    fake_pred = discriminator(fake_img)  # score the fake images
    g_loss = g_nonsaturating_loss(fake_pred)  # generator loss

    loss_dict["g"] = g_loss

    generator.zero_grad()
    g_loss.backward()
    g_optim.step()  # train the generator

    # ---- Generator path-length regularization (every g_reg_every iterations) ------
    g_regularize = i % g_reg_every == 0

    if g_regularize:
        path_batch_size = max(1, batch // path_batch_shrink)
        noise = mixing_noise(path_batch_size, latent, mixing, device)
        fake_img, latents = generator(noise, return_latents=True)

        path_loss, mean_path_length, path_lengths = g_path_regularize(
            fake_img, latents, mean_path_length
        )

        generator.zero_grad()
        weighted_path_loss = path_regularize * g_reg_every * path_loss

        if path_batch_shrink:
            # Zero-valued term that keeps the generated image in the graph.
            weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

        weighted_path_loss.backward()

        g_optim.step()

        mean_path_length_avg = (
            reduce_sum(mean_path_length).item() / get_world_size()
        )

    loss_dict["path"] = path_loss
    loss_dict["path_length"] = path_lengths.mean()

    # Update g_ema from the current generator weights using the decay `accum`.
    accumulate(g_ema, g_module, accum)

    loss_reduced = reduce_loss_dict(loss_dict)

    d_loss_val = loss_reduced["d"].mean().item()
    g_loss_val = loss_reduced["g"].mean().item()
    r1_val = loss_reduced["r1"].mean().item()
    path_loss_val = loss_reduced["path"].mean().item()
    real_score_val = loss_reduced["real_score"].mean().item()
    fake_score_val = loss_reduced["fake_score"].mean().item()
    path_length_val = loss_reduced["path_length"].mean().item()

    # ---- Logging, previews and checkpoints (rank 0 only) -----------------------------
    if get_rank() == 0:
        pbar.set_description(
            (
                f"iter: {i:05d}; d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"
            )
        )

        if wandb:
            wandb.log(
                {
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                }
            )

        # Save a preview grid every 100 iterations and on the last iteration.
        if i % 100 == 0 or (i+1) == iter:
            with torch.no_grad():
                g_ema.eval()
                sample, _ = g_ema([sample_z])
                sample = F.interpolate(sample,256)  # shrink to 256 px for the preview
                utils.save_image(
                    sample,
                    f"log/%s/finetune-%06d.jpg"%(style, i),
                    nrow=int(n_sample ** 0.5),
                    normalize=True,
                    #range=(-1, 1),
                )

        # Save a checkpoint every save_every iterations and on the last iteration.
        # Only the g_ema weights are stored; the other entries stay disabled.
        if (i+1) % save_every == 0 or (i+1) == iter:
            torch.save(
                {
                    #"g": g_module.state_dict(),
                    #"d": d_module.state_dict(),
                    "g_ema": g_ema.state_dict(),
                    #"g_optim": g_optim.state_dict(),
                    #"d_optim": d_optim.state_dict(),
                    #"args": args,
                    #"ada_aug_p": ada_aug_p,
                },
                f"%s/%s/finetune-%06d.pt"%(model_path, style, i+1),
            )

iter: 00399; d: 3.5694; g: 0.3300; r1: 0.0182; path: 0.0050; mean path: 0.1428; augment: 0.0026: 100%|████| 400/400 [27:03<00:00,  4.06s/it]
