## Notebook used for generating sample data for FID comparison

In [1]:
import os
from pathlib import Path
import numpy as np
import cv2

In [2]:
# Presumably used to compute FID between the sample folders generated below.
# TODO(review): pin a version (pytorch-fid==X.Y.Z) so results are reproducible.
%pip install pytorch-fid

In [1]:
def copy_subset(src, dst, size, count=None, seed=None):
    """Copy (a random subset of) the images in `src` to `dst`, resized to `size`.

    Args:
        src: Directory holding the source images. Only regular files are
            considered; subdirectories are ignored.
        dst: Output directory, created (with parents) if missing. Files keep
            their original names.
        size: Target size passed to ``cv2.resize``, i.e. ``(width, height)``.
        count: If given, copy only `count` files sampled without replacement.
        seed: Optional seed for that sample. When None, NumPy's global RNG is
            used (so the subset depends on ``np.random.seed`` having been
            called first); when set, a dedicated ``np.random.default_rng(seed)``
            makes the subset reproducible regardless of global RNG state.

    Raises:
        ValueError: if `count` exceeds the number of files (raised by NumPy),
            or a file cannot be decoded as an image by OpenCV.
        OSError: if OpenCV fails to write an output image.
    """
    src = Path(src)
    dst = Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    # Skip directories: cv2.imread would return None for them.
    names = sorted(p.name for p in src.iterdir() if p.is_file())
    if count is not None:
        if seed is None:
            names = np.random.choice(names, count, replace=False)
        else:
            names = np.random.default_rng(seed).choice(names, count, replace=False)
    for name in names:
        img = cv2.imread(str(src / name))
        if img is None:
            # imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image: {src / name}")
        img = cv2.resize(img, size)
        # imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(str(dst / name), img):
            raise OSError(f"Could not write image: {dst / name}")

In [2]:
def resize_all(folder, width, height):
    """Resize every image in `folder` to ``width`` x ``height`` pixels, in place.

    Warning: the original files are overwritten. Subdirectories are ignored.

    Args:
        folder: Directory containing the images.
        width: Target width in pixels.
        height: Target height in pixels.

    Raises:
        ValueError: if a file cannot be decoded as an image by OpenCV.
        OSError: if OpenCV fails to write a resized image back.
    """
    folder_path = Path(folder)
    for path in sorted(p for p in folder_path.iterdir() if p.is_file()):
        img = cv2.imread(str(path))
        if img is None:
            # imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image: {path}")
        img = cv2.resize(img, (width, height))
        if not cv2.imwrite(str(path), img):
            raise OSError(f"Could not write image: {path}")

In [5]:
# resize_all('../data/CelebA/val/faces/', 256, 256)

### Uncomment below to generate a random sample set of ground truth faces

In [6]:
# copy_subset('../data/CelebA/train/faces/', '../data/Samples/ground_truth_faces', (256, 256), count=16384)

### Uncomment below to generate two random ground-truth subsets of paprika style

In [7]:
# copy_subset('../data/CelebA/train/paprika', '../data/Samples/ground_truth_paprika1', (256, 256), count=16384)

In [8]:
# copy_subset('../data/CelebA/train/paprika', '../data/Samples/ground_truth_paprika2', (256, 256), count=16384)

### Uncomment below to generate two random subsets of webtoon style

In [9]:
# copy_subset('../data/CelebA/train/webtoon', '../data/Samples/ground_truth_webtoon1', (256, 256), count=16384)

In [10]:
# copy_subset('../data/CelebA/train/webtoon', '../data/Samples/ground_truth_webtoon2', (256, 256), count=16384)

### Uncomment below to generate two random subsets of face v2 style

In [11]:
# copy_subset('../data/CelebA/train/face_v2', '../data/Samples/ground_truth_face_v2_1', (256, 256), count=16384)

In [12]:
# copy_subset('../data/CelebA/train/face_v2', '../data/Samples/ground_truth_face_v2_2', (256, 256), count=16384)

In [3]:
import os
import random
from pathlib import Path
from urllib.request import urlretrieve, urlcleanup
from zipfile import ZipFile

import albumentations as albm
import cv2
import matplotlib.pyplot as plt
import numpy as np
from pycocotools.coco import COCO
import torch
import torch.nn as nn
import torchvision
from tqdm.notebook import tqdm
from PIL import Image


import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

%config InlineBackend.figure_format = 'retina'

RANDOM_SEED = 1337
# Seed every RNG the notebook has imported: Python's `random`, NumPy's global
# RNG (used by np.random.choice when sampling subsets) and torch (CPU + CUDA).
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

# Speed over reproducibility: set deterministic=True and benchmark=False for
# more deterministic results at the cost of slower training.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
# CUDA is deliberately not re-seeded here: the seed cell already seeds it with
# RANDOM_SEED, and re-seeding with a different constant would override that.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def translate_subset(model, src, dst, size, count=None):
    """Translate the images in `src` with `model` and save the results to `dst`.

    Each image is read with OpenCV, converted from BGR to RGB and resized to
    `size` (``(width, height)`` for ``cv2.resize``). It is turned into a
    1-image batch with ``ToTensor`` (values scaled to [0, 1]) and run through
    `model` on ``DEVICE``. The output is scaled to 0-255, rounded, clipped,
    converted back to BGR and written under the same filename in `dst`.

    The model is moved to ``DEVICE`` (and stays there). It is put in eval mode
    with gradients disabled for the duration of the call, then restored to its
    previous train/eval mode.

    Args:
        model: Module mapping a ``1 x 3 x H x W`` tensor to a ``1 x 3 x H x W``
            image tensor whose values are assumed to lie in [0, 1] (true for
            AnimeNet, which ends in a sigmoid).
        src: Directory of input images; only regular files are used.
        dst: Output directory, created (with parents) if missing.
        size: Size each input is resized to before translation.
        count: If given, translate only `count` files sampled without
            replacement using NumPy's global RNG.

    Raises:
        ValueError: if an input file cannot be decoded by OpenCV, or `count`
            exceeds the number of files (raised by NumPy).
        OSError: if an output image cannot be written.
    """
    model = model.to(DEVICE)
    src = Path(src)
    dst = Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    names = sorted(p.name for p in src.iterdir() if p.is_file())
    if count is not None:
        names = np.random.choice(names, count, replace=False)
    to_tensor = torchvision.transforms.ToTensor()

    # In train mode BatchNorm would normalise each single image by its own
    # statistics and overwrite the running averages, so run in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for name in tqdm(names):
                bgr = cv2.imread(str(src / name))
                if bgr is None:
                    raise ValueError(f'Could not read image: {src / name}')
                rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), size)
                batch = to_tensor(Image.fromarray(rgb)).unsqueeze(0).to(DEVICE)
                pred = model(batch).squeeze(0).permute(1, 2, 0).cpu().numpy()
                # Round (not truncate) when mapping [0, 1] floats to 8-bit.
                pred = np.clip(np.rint(pred * 255.0), 0, 255).astype(np.uint8)
                out_path = dst / name
                if not cv2.imwrite(str(out_path), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)):
                    raise OSError(f'Could not write image: {out_path}')
    finally:
        model.train(was_training)

In [6]:
def replace_relu6_with_relu(model):
    """Recursively swap every ``nn.ReLU6`` inside `model` for a plain ``nn.ReLU``.

    MobileNet's ReLU6 activations are replaced because the difference in
    quality is not noticeable, while plain ReLU allows CoreML to be used for
    mobile inference on iOS devices.

    The module tree is modified in place, and the same `model` object is
    returned so the call can be used as a statement or in an assignment.
    """
    for child_name, child in list(model._modules.items()):
        if any(True for _ in child.children()):
            # Container: descend; the recursive call mutates `child` in place.
            replace_relu6_with_relu(child)
        if isinstance(child, nn.ReLU6):
            model._modules[child_name] = nn.ReLU()
    return model


class AnimeNet(nn.Module):
    """Image-to-image translation network with a MobileNetV2 (width_mult=0.5) encoder.

    The encoder is ``mobilenet_v2(width_mult=0.5).features`` without its last
    two stages, with every ReLU6 swapped for ReLU. The decoder applies a 1x1
    conv block and then five rounds of 2x nearest-neighbour upsampling. The
    outputs of encoder stages 1, 3, 6 and 13 are added back in as skip
    connections. A final sigmoid keeps the 3-channel output in [0, 1].

    NOTE(review): the skip additions presumably only line up when H and W are
    divisible by 32 (e.g. the 256x256 inputs used in this notebook). Confirm
    before using other sizes.

    Args:
        pretrained: If True (the default, matching previous behaviour),
            initialise the encoder from the ImageNet-pretrained width-1.0
            MobileNetV2, cropped to the narrower width. Pass False when a
            trained checkpoint is loaded right after construction. This skips
            downloading weights that would be overwritten anyway.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        mobilenet = torchvision.models.mobilenet_v2(width_mult=0.5)

        if pretrained:
            # Reuse the state dict of MobileNetV2 with width_mult == 1.0,
            # cropping every tensor along its first two dimensions to the
            # narrower model's shape. This is not the optimal way to use
            # pretrained models, but it gives a good initialization for faster
            # convergence.
            state_dict = torchvision.models.mobilenet_v2(pretrained=True).state_dict()
            for key, target in mobilenet.state_dict().items():
                if target.dim() == 0:
                    # 0-dim tensors are loaded from the pretrained model unchanged.
                    continue
                cropped = state_dict[key][:target.size(0)]
                if cropped.dim() > 1:
                    cropped = cropped[:, :target.size(1)]
                state_dict[key] = cropped
            mobilenet.load_state_dict(state_dict)

        mobilenet = replace_relu6_with_relu(mobilenet)

        # Encoder: all MobileNetV2 feature stages except the last two.
        self.features = mobilenet.features[:-2]
        # Decoder: channels shrink 80 -> 48 -> 16 -> 16 -> 8 -> 4 -> 3.
        self.upscale0 = nn.Sequential(
            nn.Conv2d(80, 48, 1, 1, 0, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU()
        )
        self.upscale1 = nn.Sequential(
            nn.Conv2d(48, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.upscale2 = nn.Sequential(
            nn.Conv2d(16, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.upscale3 = nn.Sequential(
            nn.Conv2d(16, 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU()
        )
        self.upscale4 = nn.Sequential(
            nn.Conv2d(8, 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(4),
            nn.ReLU()
        )
        self.upscale5 = nn.Conv2d(4, 3, 3, 1, 1, bias=True)

    def forward(self, x):
        """Translate a batch `x` (presumably N x 3 x H x W in [0, 1]; confirm) to a 3-channel image."""
        out = x
        skip_outs = []
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in {1, 3, 6, 13}:
                skip_outs.append(out)
        # Upsample 2x at each step and add the matching encoder output,
        # deepest first, before each decoder block.
        out = self.upscale0(out)
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale1(out + skip_outs[3])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale2(out + skip_outs[2])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale3(out + skip_outs[1])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale4(out + skip_outs[0])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale5(out)
        return torch.sigmoid(out)

In [None]:
class WideAnimeNet(nn.Module):
    """Wider variant of the translation network with a MobileNetV2 (width_mult=0.75) encoder.

    The encoder is ``mobilenet_v2(width_mult=0.75).features`` without its last
    two stages, with every ReLU6 swapped for ReLU. The decoder applies a 1x1
    conv block and then five rounds of 2x nearest-neighbour upsampling. The
    outputs of encoder stages 1, 3, 6 and 13 are added back in as skip
    connections. A final sigmoid keeps the 3-channel output in [0, 1].

    NOTE(review): the skip additions presumably only line up when H and W are
    divisible by 32. Confirm before using sizes other than 256x256.

    Args:
        pretrained: If True (the default, matching previous behaviour),
            initialise the encoder from the ImageNet-pretrained width-1.0
            MobileNetV2, cropped to the narrower width. Pass False when a
            trained checkpoint is loaded right after construction. This skips
            downloading weights that would be overwritten anyway.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        mobilenet = torchvision.models.mobilenet_v2(width_mult=0.75)

        if pretrained:
            # Reuse the state dict of MobileNetV2 with width_mult == 1.0,
            # cropping every tensor along its first two dimensions to the
            # narrower model's shape. This is not the optimal way to use
            # pretrained models, but it gives a good initialization for faster
            # convergence.
            state_dict = torchvision.models.mobilenet_v2(pretrained=True).state_dict()
            for key, target in mobilenet.state_dict().items():
                if target.dim() == 0:
                    # 0-dim tensors are loaded from the pretrained model unchanged.
                    continue
                cropped = state_dict[key][:target.size(0)]
                if cropped.dim() > 1:
                    cropped = cropped[:, :target.size(1)]
                state_dict[key] = cropped
            mobilenet.load_state_dict(state_dict)

        mobilenet = replace_relu6_with_relu(mobilenet)

        # Encoder: all MobileNetV2 feature stages except the last two.
        self.features = mobilenet.features[:-2]
        # Decoder: channels shrink 120 -> 72 -> 24 -> 24 -> 16 -> 4 -> 3.
        self.upscale0 = nn.Sequential(
            nn.Conv2d(120, 72, 1, 1, 0, bias=False),
            nn.BatchNorm2d(72),
            nn.ReLU()
        )
        self.upscale1 = nn.Sequential(
            nn.Conv2d(72, 24, 3, 1, 1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )
        self.upscale2 = nn.Sequential(
            nn.Conv2d(24, 24, 3, 1, 1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )
        self.upscale3 = nn.Sequential(
            nn.Conv2d(24, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.upscale4 = nn.Sequential(
            nn.Conv2d(16, 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(4),
            nn.ReLU()
        )
        self.upscale5 = nn.Conv2d(4, 3, 3, 1, 1, bias=True)

    def forward(self, x):
        """Translate a batch `x` (presumably N x 3 x H x W in [0, 1]; confirm) to a 3-channel image."""
        out = x
        skip_outs = []
        for i, layer in enumerate(self.features):
            out = layer(out)
            if i in {1, 3, 6, 13}:
                skip_outs.append(out)
        # Upsample 2x at each step and add the matching encoder output,
        # deepest first, before each decoder block.
        out = self.upscale0(out)
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale1(out + skip_outs[3])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale2(out + skip_outs[2])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale3(out + skip_outs[1])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale4(out + skip_outs[0])
        out = nn.functional.interpolate(out, scale_factor=2, mode='nearest')
        out = self.upscale5(out)
        return torch.sigmoid(out)


In [7]:
# Build the network. Its weights are replaced by a trained checkpoint via
# load_state_dict in a later cell.
model = AnimeNet()

In [18]:
# Checkpoint filenames, one per trained style.
PAPRIKA, FACE_V2, WEBTOON = 'paprika.pth', 'face_v2.pth', 'webtoon.pth'

In [9]:
# Load the trained weights. map_location='cpu' lets a checkpoint saved from a
# GPU be loaded on a CPU-only machine; translate_subset moves the model to
# DEVICE before inference.
model.load_state_dict(torch.load('face_v2_2.pth', map_location='cpu'))

<All keys matched successfully>

Generate a dataset with our trained model, to be evaluated against the ground-truth faces using FID.

In [10]:
# Translate every ground-truth face with the loaded model; the output folder
# is what gets evaluated against ground truth with FID.
translate_subset(
    model,
    '../data/Samples/ground_truth_faces',
    '../data/Samples/face_v2_3',
    size=(256, 256),
)