## 1. Import libraries

In [1]:
import torch
from torchvision import transforms
from torch.autograd import Variable
from dataset import DatasetFromFolder
from model import Generator
import utils
import argparse
import os

## 2. Set test options and output directories

In [2]:
# Test-time options. parse_args([]) parses an empty argument list so the notebook
# kernel's own command-line flags are not mistaken for script arguments; change the
# defaults below (or pass an explicit list to parse_args) to alter settings.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='white2yellow_tiger', help='input dataset')
parser.add_argument('--batch_size', type=int, default=1, help='test batch size')
# ngf is forwarded to model.Generator; presumably the generator's base filter count — confirm in model.py.
parser.add_argument('--ngf', type=int, default=32)
parser.add_argument('--num_resnet', type=int, default=6, help='number of resnet blocks in generator')
parser.add_argument('--input_size', type=int, default=256, help='input size')
params = parser.parse_args([])
print(params)

# Directories. Each keeps a trailing '/' because paths are later built by plain string
# concatenation (e.g. save_dir + 'AtoB/', model_dir + 'generator_A_param.pkl').
data_dir = 'data/' + params.dataset + '/'
save_dir = params.dataset + '_test_results/'
model_dir = params.dataset + '_model/'

# makedirs(exist_ok=True) also creates missing parent directories (os.mkdir failed when
# --dataset contained a '/') and avoids the exists()-then-mkdir() check-then-act race.
os.makedirs(save_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

Namespace(batch_size=1, dataset='white2yellow_tiger', input_size=256, ngf=32, num_resnet=6)


## 3. Load Dataset
### 3.1 Preprocessing

In [3]:
# Preprocessing applied to every test image:
#   1. Resize to (input_size, input_size). transforms.Resize replaces transforms.Scale,
#      which is deprecated (it emitted "please use transforms.Resize instead.").
#   2. ToTensor converts the image to a float tensor (uint8 inputs are scaled to [0, 1]).
#   3. Normalize each of the 3 channels with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5,
#      which maps [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((params.input_size, params.input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

  "please use transforms.Resize instead.")


### 3.2 Test data

In [4]:
def _build_test_split(subfolder):
    """Return (dataset, loader) for one test split.

    The dataset is a DatasetFromFolder over data_dir with the given subfolder
    (presumably images under data_dir/<subfolder> — confirm in dataset.py), using the
    shared `transform`. The loader batches it with params.batch_size and never
    shuffles, so output indices line up with the on-disk order across runs.
    """
    split_dataset = DatasetFromFolder(data_dir, subfolder=subfolder, transform=transform)
    split_loader = torch.utils.data.DataLoader(
        dataset=split_dataset,
        batch_size=params.batch_size,
        shuffle=False,
    )
    return split_dataset, split_loader


# Domain A and domain B test splits.
test_data_A, test_data_loader_A = _build_test_split('testA')
test_data_B, test_data_loader_B = _build_test_split('testB')

## 4. Load the trained generators

In [5]:
# Use the GPU when present; otherwise fall back to CPU instead of crashing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generator args are presumably (input_channels, ngf, output_channels, num_resnet) —
# confirm against model.Generator. They must match the architecture the weights were trained with.
# eval() puts the modules in inference mode (affects layers like BatchNorm/Dropout, if any);
# .to()/.eval() return the module itself, so they can be chained.
G_A = Generator(3, params.ngf, 3, params.num_resnet).to(device).eval()
G_B = Generator(3, params.ngf, 3, params.num_resnet).to(device).eval()

# map_location lets checkpoints saved from a GPU run be loaded onto whichever device is in use.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
G_A.load_state_dict(torch.load(model_dir + 'generator_A_param.pkl', map_location=device))
G_B.load_state_dict(torch.load(model_dir + 'generator_B_param.pkl', map_location=device))

<All keys matched successfully>

## 5. Translate the test sets (A → B → A and B → A → B) and save the results

In [6]:
def translate_and_save(loader, G_forward, G_backward, tag):
    """Translate every test batch through G_forward and back through G_backward, saving results.

    For batch k from `loader`:
      * fake  = G_forward(real)   -- the translated image
      * recon = G_backward(fake)  -- the cycle reconstruction
    utils.plot_test_result(real, fake, recon, k, ...) is saved under save_dir + tag + '/',
    and utils.save_singleimages(fake, k, ...) under save_dir + tag + '_one/'.

    Args:
        loader: DataLoader yielding image tensors (presumably (batch, 3, H, W) — confirm
            against DatasetFromFolder).
        G_forward: generator applied to the real images.
        G_backward: generator applied to the translated images.
        tag: output sub-folder prefix, e.g. 'AtoB'.
    """
    # Run on whichever device the generator's weights live on instead of hard-coding .cuda().
    device = next(G_forward.parameters()).device
    num_images = 0
    # Inference only: no_grad() stops autograd from recording every forward pass.
    # (The old Variable wrapper has been a no-op since PyTorch 0.4.)
    with torch.no_grad():
        for k, real in enumerate(loader):
            real = real.to(device)
            fake = G_forward(real)
            recon = G_backward(fake)

            # Comparison of real / translated / reconstructed images.
            utils.plot_test_result(real, fake, recon, k, save=True, save_dir=save_dir + tag + '/')
            # The translated image on its own.
            utils.save_singleimages(fake, k, save=True, save_dir=save_dir + tag + '_one/')

            # Count images, not batches, so the message stays correct when batch_size > 1.
            num_images += real.size(0)
            print('%d images are generated.' % num_images)


# A --> B --> A
translate_and_save(test_data_loader_A, G_A, G_B, 'AtoB')
# B --> A --> B
translate_and_save(test_data_loader_B, G_B, G_A, 'BtoA')

1 images are generated.
2 images are generated.
3 images are generated.
4 images are generated.
5 images are generated.
6 images are generated.
7 images are generated.
8 images are generated.
9 images are generated.
10 images are generated.
1 images are generated.
2 images are generated.
3 images are generated.
4 images are generated.
5 images are generated.
6 images are generated.
7 images are generated.
8 images are generated.
9 images are generated.
10 images are generated.
