In [None]:
!git clone https://github.com/rosinality/style-based-gan-pytorch.git

In [None]:
# Output directory for the LMDB dataset built below (output_path) and read by train.py.
!mkdir -p ./data

In [None]:
# Directory presumably used by train.py for generated samples — confirm against train.py.
!mkdir -p ./sample

In [None]:
# Directory listed after training; presumably where train.py saves model checkpoints.
!mkdir -p ./checkpoint

In [None]:
import argparse
from io import BytesIO
import multiprocessing
from functools import partial

from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn


def resize_and_convert(img, size, quality=100, square_crop_ratio=0.1):
    """Resize ``img``, cut a ``size`` x ``size`` square from it and return it JPEG-encoded.

    The image is first resized so its shorter edge is ``size * (1 + square_crop_ratio)``
    (torchvision's ``resize`` with an int target keeps the aspect ratio). A square is then
    cut starting at the top row and ``size * square_crop_ratio / 2`` pixels from the left.
    When the width is the shorter edge (e.g. portrait covers), this trims roughly equal
    margins from the left and right. The crop is always anchored to the top edge.

    Args:
        img: PIL image (``resize_worker`` converts it to RGB before calling this).
        size: side length, in pixels, of the square output.
        quality: JPEG quality passed to ``Image.save``.
        square_crop_ratio: fraction of ``size`` added to the shorter edge before
            cropping; half of that margin is skipped on the left.

    Returns:
        bytes: the JPEG-encoded ``size`` x ``size`` image.
    """
    img = trans_fn.resize(img, round(size * (1 + square_crop_ratio)), Image.LANCZOS)
    # Round the offset once, here. torchvision's crop expects integer coordinates.
    # If a float is passed, PIL rounds the left and right box edges separately, which
    # can yield a crop one pixel wider than `size` (odd sizes with a .5 offset).
    left = round(size * square_crop_ratio / 2)
    img = trans_fn.crop(img, top=0, left=left, height=size, width=size)
    buffer = BytesIO()
    img.save(buffer, format='jpeg', quality=quality)
    return buffer.getvalue()


def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100):
    """Encode ``img`` once per resolution in ``sizes`` via ``resize_and_convert``.

    Returns:
        list[bytes]: JPEG bytes, one entry per size, in the same order as ``sizes``.
    """
    return [resize_and_convert(img, target, quality) for target in sizes]


def resize_worker(img_file, sizes):
    """Load one image file and encode it at every resolution in ``sizes``.

    Used as the pool task in ``prepare``.

    Args:
        img_file: ``(index, path)`` pair; the index is returned unchanged so results
            produced out of order can still be matched to their image.
        sizes: resolutions forwarded to ``resize_multiple``.

    Returns:
        tuple: ``(index, list_of_jpeg_bytes)``.
    """
    index, path = img_file
    # convert() returns a new, fully loaded image, so the source file can be closed
    # right away. Pillow only closes it automatically after load() for single-frame
    # images; the context manager closes it deterministically for every format.
    with Image.open(path) as src:
        img = src.convert('RGB')
    return index, resize_multiple(img, sizes=sizes)


def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)):
    """Encode every image of ``dataset`` at each size in ``sizes`` and store it via ``transaction``.

    Each image gets an index from its position after sorting ``dataset.imgs`` by path.
    Every ``(size, index)`` pair is written under the key ``f'{size}-{str(index).zfill(5)}'``.
    A final ``b'length'`` key holds the number of images written.

    Args:
        transaction: write transaction with a ``put(key, value)`` method (the notebook
            passes ``env.begin(write=True)`` from an LMDB environment).
        dataset: object whose ``imgs`` is a list of ``(path, label)`` pairs, such as a
            ``torchvision.datasets.ImageFolder``.
        n_worker: number of processes in the ``multiprocessing.Pool``.
        sizes: resolutions to encode, forwarded to ``resize_worker``.
    """
    resize_fn = partial(resize_worker, sizes=sizes)

    # Sort by path so indices are stable regardless of directory listing order.
    files = sorted(dataset.imgs, key=lambda item: item[0])
    files = [(i, path) for i, (path, _label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        # imap_unordered yields results as workers finish; each result carries its own
        # index, so the keys are still correct. total= gives tqdm a bounded progress bar.
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files), total=len(files)):
            for size, img in zip(sizes, imgs):
                key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
                transaction.put(key, img)

            total += 1

    transaction.put('length'.encode('utf-8'), str(total).encode('utf-8'))

In [None]:
import matplotlib.pyplot as plt

# Preview one sample exactly as it will be stored in LMDB. This reuses
# resize_and_convert (instead of repeating its resize/crop steps) and decodes the JPEG
# bytes it returns, so the preview cannot drift from the real pipeline.
size = 1024
square_crop_ratio = 0.1
preview_path = "../input/naruto-english/naruto/cover-19.jpeg"

with Image.open(preview_path) as src:
    preview_jpeg = resize_and_convert(src.convert('RGB'), size, square_crop_ratio=square_crop_ratio)
img = Image.open(BytesIO(preview_jpeg))

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title(f'{size}x{size} crop (square_crop_ratio={square_crop_ratio})')
ax.axis('off')
plt.show()

In [None]:
# Build the multi-resolution LMDB dataset from every image found by ImageFolder.
n_worker = 8
input_path = "../input/naruto-english"
output_path = "./data"
imgset = datasets.ImageFolder(input_path)

# A single write transaction covers the whole build; map_size (1 TiB) is the upper
# bound LMDB may grow the database file to.
with lmdb.open(output_path, map_size=1024 ** 4, readahead=False) as env, env.begin(write=True) as txn:
    prepare(txn, imgset, n_worker)

In [None]:
# Train on the LMDB dataset written to ./data above, up to 512 px resolution.
# --mixing presumably enables style-mixing regularisation and --phase presumably sets
# the number of images per progressive-growing phase — confirm against train.py's argparse.
# NOTE(review): the dataset also holds 1024 px entries, which --max_size 512 leaves unused.
!python ./style-based-gan-pytorch/train.py --mixing ./data --phase 50000 --max_size 512

In [None]:
# Inspect the checkpoint directory created earlier (presumably populated by train.py).
!ls ./checkpoint