In [None]:
import matplotlib.pyplot as plt
# %matplotlib inline 
import numpy as np
import torch

import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from torch.utils import data
import gc

import wandb

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

In [None]:
from src.data import DefaultDataset
from src.data import LoaderSampler

from src.resnet import ResNet_D
from src.unet import UNet

from src.tools import fig2data, fig2img

from src.tools import freeze, unfreeze, weights_init_D, plot_images, plot_random_images

In [None]:
# Later cells move batches to 'cuda' (see the samplers); fail fast with a clear message
# instead of a bare AssertionError when no GPU is visible to this kernel.
assert torch.cuda.is_available(), 'CUDA is not available: this notebook requires a GPU'
# Make cuda:0 the current device. A plain string suffices (the former f-prefix had no placeholders).
torch.cuda.set_device('cuda:0')

In [None]:
# ! bash download.sh celeba-hq-dataset

In [None]:
# ! kaggle datasets download reitanaka/alignedanimefaces
# ! unzip /home/sudakovcom/Desktop/diffusion/NOT/NeuralOptimalTransport/datasets/alignedanimefaces.zip

## Preprocessing of aligned anime faces
Crop and align the anime faces so that their framing matches the (rescaled) CelebA-HQ faces.

In [None]:
from PIL import Image
import os
from tqdm import tqdm_notebook

In [None]:
def center_crop(im, size):
    """Crop a ``size`` x ``size`` square from the exact centre of ``im``.

    Equivalent to ``noncenter_crop(im, size)`` with no shift; delegating keeps the
    two crop helpers from drifting apart.
    """
    return noncenter_crop(im, size)


def noncenter_crop(im, size, shift=(0, 0)):
    """Crop a ``size`` x ``size`` square centred on ``im`` and then offset by ``shift``.

    Args:
        im: image-like object exposing ``.size`` as ``(width, height)`` and a
            ``.crop((left, upper, right, lower))`` method (e.g. a PIL image).
        size: side length of the square crop, in pixels.
        shift: ``(dx, dy)`` pixel offset added to the centred box. In PIL's
            top-left-origin coordinates a negative ``dy`` moves the box up.

    Returns:
        Whatever ``im.crop`` returns for the computed box.
    """
    width, height = im.size[0], im.size[1]
    # int() truncates toward zero; kept so the box matches the original arithmetic exactly.
    left = int(width / 2 - size / 2) + shift[0]
    upper = int(height / 2 - size / 2) + shift[1]
    return im.crop((left, upper, left + size, upper + size))

In [None]:
# path = '/home/sudakovcom/safebooru_jpeg'
# files = os.listdir(path)

In [None]:
def preprocess_anime_face(path_in_out, load_size=512, crop_size=256, shift=(0, -14), out_size=128):
    """Resize, crop and save one anime face so its framing matches the rescaled CelebA faces.

    Takes a single ``(in_path, out_path)`` tuple rather than two arguments, so it can be
    passed directly to ``multiprocessing.Pool.map`` over ``zip(in_paths, out_names)``.

    Args:
        path_in_out: ``(in_path, out_path)`` pair of source and destination file paths.
        load_size: side length the source image is resized to before cropping.
        crop_size: side length of the square crop, measured at ``load_size`` resolution.
        shift: ``(dx, dy)`` offset of the crop box from the centre; the default moves it 14 px up.
        out_size: side length of the saved image.
    """
    in_path, out_path = path_in_out
    # Open the source file in a context manager so its handle is released deterministically,
    # even if resize raises or the file is multi-frame (Pillow auto-closes only single-frame
    # images after load). resize() returns an independent image, so it outlives the block.
    with Image.open(in_path) as src:
        im = src.resize((load_size, load_size))
    im = noncenter_crop(im, crop_size, shift).resize((out_size, out_size))
    im.save(out_path)

In [None]:
# in_paths = [os.path.join(path, file) for file in files]

# out_path = '/home/sudakovcom/Desktop/diffusion/NOT/NeuralOptimalTransport/datasets/anime_faces'
# out_names = [os.path.join(out_path, f'{i}.png') for i in range(len(files))]

# if not os.path.exists(out_path):
#     os.makedirs(out_path)

In [None]:
# from multiprocessing import Pool
# import time

# start = time.time()
# with Pool(64) as p:
#     p.map(preprocess_anime_face, list(zip(in_paths, out_names)))
# end = time.time()
# print(end-start)

# Compute reference dataset statistics for the FID metric

In [None]:
# ---- Optimisation ----
# Learning rates for the f and T networks (an earlier run used 1e-4 for both).
f_LR = 1e-3
T_LR = 1e-3
T_ITERS = 10        # NOTE(review): presumably T steps per f step; confirm in the training loop
COST = 'mse'        # Mean Squared Error
MAX_STEPS = 100001

# ---- Data ----
IMG_SIZE = 128      # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 64

# ---- Logging / checkpointing (in steps) ----
PLOT_INTERVAL = 20
CPKT_INTERVAL = 1000

# ---- Reproducibility ----
SEED = 0            # not applied in this cell; seed torch/numpy where training starts

In [None]:
# Resize to IMG_SIZE x IMG_SIZE, convert to a [0, 1] tensor, then map each channel to [-1, 1].
transform = Compose([Resize((IMG_SIZE, IMG_SIZE)), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Single place to change when running on another machine.
DATASETS_ROOT = '/home/sudakovcom/Desktop/diffusion/NOT/datasets'

dataset_celeba = DefaultDataset(os.path.join(DATASETS_ROOT, 'celeba_hq', 'train', 'female'), transform=transform)
dataset_anime = DefaultDataset(os.path.join(DATASETS_ROOT, 'anime_faces'), transform=transform)


def _make_loader(dataset, name):
    """Build the DataLoader used for both domains, rejecting datasets smaller than one batch.

    With ``drop_last=True`` a dataset holding fewer than BATCH_SIZE items yields no batches
    at all, which would otherwise fail silently or obscurely further down the notebook.
    """
    n_items = len(dataset)
    if n_items < BATCH_SIZE:
        raise ValueError(f'{name} dataset has {n_items} images, fewer than BATCH_SIZE={BATCH_SIZE}')
    return data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, drop_last=True)


dataloader_celeba = _make_loader(dataset_celeba, 'celeba')
dataloader_anime = _make_loader(dataset_anime, 'anime')

sampler_celeba = LoaderSampler(dataloader_celeba, device='cuda')
sampler_anime = LoaderSampler(dataloader_anime, device='cuda')

print(len(dataset_celeba), len(dataset_anime))

In [None]:
from src.tools import get_loader_stats, calculate_frechet_distance, get_pushed_loader_stats

In [None]:
# Output directory for the per-domain (mu, sigma) statistics returned by get_loader_stats;
# presumably consumed later by calculate_frechet_distance (confirm in the evaluation cells).
STATS_DIR = '/home/sudakovcom/Desktop/diffusion/NOT/stats'
# np.save does not create parent directories. Create the directory before computing the
# statistics, so a missing folder does not raise only after the work is already done.
os.makedirs(STATS_DIR, exist_ok=True)

mu_celeba, sigma_celeba = get_loader_stats(sampler_celeba.loader)
mu_anime, sigma_anime = get_loader_stats(sampler_anime.loader)

np.save(os.path.join(STATS_DIR, 'mu_celeba.npy'), mu_celeba)
np.save(os.path.join(STATS_DIR, 'sigma_celeba.npy'), sigma_celeba)
np.save(os.path.join(STATS_DIR, 'mu_anime.npy'), mu_anime)
np.save(os.path.join(STATS_DIR, 'sigma_anime.npy'), sigma_anime)