<a href="https://colab.research.google.com/github/ak9250/pytorch-small-dataset-image-generation/blob/master/SmallGan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the fork (branch `test/conditional`) containing train.py and the models/ and dataloaders/ packages used below.
!git clone -b test/conditional https://github.com/khongtrunght/pytorch-small-dataset-image-generation.git

In [None]:
cd pytorch-small-dataset-image-generation/

In [None]:
# Download the pretrained BigGAN weights archive from Google Drive (presumably BigGAN_ch96_bs256x8_138k.zip, which the next cell unzips).
!gdown https://drive.google.com/uc?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW

In [None]:
!unzip BigGAN_ch96_bs256x8_138k.zip

In [None]:
# Move the EMA generator weights into data/, where the training cell reads them via --pretrained ./data/G_ema.pth.
!mv 138k/G_ema.pth data/

Training: fine-tune the pretrained BigGAN generator (`data/G_ema.pth`) on the small dataset.

In [None]:
# Fine-tune starting from the pretrained generator weights placed in data/.
# NOTE(review): this trains with --dataset animal, while the testing cells load the
# "anime" dataloader and an experiments/train_dataset-anime_... checkpoint — confirm
# which dataset is intended. --gpu / --iters presumably select the GPU index and the
# number of training iterations; check train.py.
!python train.py --dataset animal --gpu 0 --pretrained ./data/G_ema.pth  --iters 50000

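Testing: load a trained checkpoint, then save and display reconstructions, interpolations, and random samples.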

In [None]:
import glob
import os
import matplotlib
from PIL import Image
import numpy as np
import json

%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torchvision
from models.setup_model import setup_model
from dataloaders.setup_dataloader_smallgan import setup_dataloader


def reconstruct(model,out_path,indices):
    """Render the generator's images for selected learned embeddings and save them as a grid.

    Looks up ``indices`` in ``model.embeddings``, runs the resulting vectors through
    ``model`` and writes all images into one roughly square grid.

    Args:
        model: generator module with an ``embeddings`` layer (called with an index
            tensor) whose forward pass maps those embeddings to image tensors.
        out_path: destination file passed to ``torchvision.utils.save_image``.
        indices: integer indices into ``model.embeddings``; a ``torch.Tensor`` or
            anything ``torch.as_tensor`` accepts (e.g. a list of ints).

    Raises:
        ValueError: if ``indices`` is empty.

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    device = next(model.parameters()).device
    # Tensors are moved to the model's device; lists/ranges are converted first.
    indices = torch.as_tensor(indices, device=device)
    if indices.numel() == 0:
        raise ValueError("indices must contain at least one embedding index")
    # Inference only: run the lookup and forward pass under no_grad so no autograd
    # graph (and its activation memory) is built. Previously only saving was guarded.
    with torch.no_grad():
        embeddings = model.embeddings(indices)
        batch_size = embeddings.size(0)
        image_tensors = model(embeddings)
        torchvision.utils.save_image(
            image_tensors,
            out_path,
            # Roughly square grid; never fewer than one image per row.
            nrow=max(1, int(batch_size ** 0.5)),
            normalize=True,
        )
        
#see https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L37-L53
def interpolate(model,out_path,source,dist,trncate=0.4,num=5):
    """Save a single row of images linearly interpolating between two learned embeddings.

    Frame ``i`` uses ``(1 - t_i) * e_source + t_i * e_dist`` with ``t`` evenly spaced
    from 0 to 1, so the row starts at ``source`` and ends at ``dist``.

    Args:
        model: generator module with an ``embeddings`` layer (called with an index
            tensor) whose forward pass maps embeddings to image tensors.
        out_path: destination file passed to ``torchvision.utils.save_image``.
        source: index of the start embedding.
        dist: index of the end embedding.
        trncate: unused; kept only so existing callers keep working.
        num: number of frames including both endpoints (must be >= 1).

    Raises:
        ValueError: if ``num`` < 1.

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    if num < 1:
        raise ValueError(f"num must be >= 1, got {num}")
    model.eval()
    device = next(model.parameters()).device
    # Inference only: keep the embedding blend and forward pass out of autograd.
    with torch.no_grad():
        indices = torch.tensor([source, dist], device=device)
        embeddings = model.embeddings(indices)
        # Per-frame blend weights of shape (num, 1), broadcast against the
        # (1, dim) start/end embeddings to give (num, dim).
        weight_source = torch.linspace(1, 0, num, device=device)[:, None]
        weight_dist = torch.linspace(0, 1, num, device=device)[:, None]
        embeddings = embeddings[[0]] * weight_source + embeddings[[1]] * weight_dist
        batch_size = embeddings.size(0)
        image_tensors = model(embeddings)
        torchvision.utils.save_image(
            image_tensors,
            out_path,
            nrow=batch_size,  # all frames on one row
            normalize=True,
        )

#from https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L37-L53        
def random(model,out_path,tmp=0.4, n=9, truncate=False, seed=None):
    """Save a grid of images generated from randomly sampled latent vectors.

    Latents have ``model.embeddings.weight.size(1)`` dimensions. They are drawn either
    from a standard normal truncated to ``[-tmp, tmp]`` (``truncate=True``) or from a
    normal with mean 0 and standard deviation ``tmp`` (``truncate=False``).

    Args:
        model: generator module with an ``embeddings`` layer whose forward pass maps
            latent vectors to image tensors.
        out_path: destination file passed to ``torchvision.utils.save_image``.
        tmp: truncation bound (``truncate=True``) or standard deviation (otherwise).
        n: number of images to generate (must be >= 1).
        truncate: sample from the truncated normal instead of a plain normal.
        seed: optional int for reproducible sampling. ``None`` (default) keeps the
            original behaviour of drawing from NumPy's global random state.

    Raises:
        ValueError: if ``n`` < 1.

    Note:
        Switches ``model`` to eval mode and leaves it there. This function name
        shadows the stdlib ``random`` module in the notebook namespace.
    """
    from scipy.stats import truncnorm
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    model.eval()
    device = next(model.parameters()).device
    dim_z = model.embeddings.weight.size(1)
    # None -> NumPy's legacy global RNG (unchanged behaviour); int -> private, reproducible stream.
    rng = np.random.RandomState(seed) if seed is not None else None
    if truncate:
        latents = truncnorm(-tmp, tmp).rvs(size=n * dim_z, random_state=rng)
    else:
        latents = (np.random if rng is None else rng).normal(0, tmp, size=(n, dim_z))
    latents = latents.astype("float32").reshape(n, dim_z)
    # Inference only: keep the forward pass out of autograd.
    with torch.no_grad():
        embeddings = torch.from_numpy(latents).to(device)
        image_tensors = model(embeddings)
        torchvision.utils.save_image(
            image_tensors,
            out_path,
            # Roughly square grid; never fewer than one image per row.
            nrow=max(1, int(n ** 0.5)),
            normalize=True,
        )

In [None]:
# Rebuild the generator from a saved experiment so it can be sampled below.
dataloader = setup_dataloader("anime",batch_size=2)
# setup_model sizes the model for `dataset_size` (presumably one learned embedding
# per training image — confirm); use the measured size instead of a hardcoded 50
# so it matches the data the checkpoint was trained on.
dataset_size = len(dataloader.dataset)
# NOTE(review): run directory with a hardcoded timestamp from an earlier training run;
# point this at the experiment produced by the training cell if it differs.
exp_dir = "./experiments/train_dataset-anime_model-biggan128-ada_2019-04-26-19-11-39/"
with open(os.path.join(exp_dir, "args.json")) as args_file:
    print(json.load(args_file))
model = setup_model(
    "biggan128-ada",
    dataset_size=dataset_size,
    resume=os.path.join(exp_dir, "checkpoint_iter500.pth.tar"),
)
# Same as .cuda() when a GPU is present, but still runs on CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Generate the three kinds of samples, then display each saved grid.
SAMPLE_PATHS = {
    "reconstruct": "./samples/anime_reconstruct.jpg",
    "interpolate": "./samples/anime_interpolate.jpg",
    "random": "./samples/anime_random.jpg",
}
# save_image does not create missing folders, so make sure ./samples exists first.
os.makedirs("./samples", exist_ok=True)

reconstruct(model,out_path=SAMPLE_PATHS["reconstruct"],indices=torch.arange(9))
interpolate(model,out_path=SAMPLE_PATHS["interpolate"],source=1,dist=2)
random(model,out_path=SAMPLE_PATHS["random"],tmp=0.2, n=9, truncate=True)

for title, path in SAMPLE_PATHS.items():
    fig, ax = plt.subplots(figsize=(10, 10))
    # imshow copies the pixel data, so the file can be closed right away.
    with Image.open(path) as im:
        ax.imshow(im)
    ax.set_title(title)
    ax.axis("off")
    plt.show()