In [1]:
import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
from PIL import Image
import glob

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

from models import *
from kis_dataset import *

import torch.nn as nn
import torch.nn.functional as F
import torch

In [15]:
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='gray2color', help='gray2color | color2color ')
parser.add_argument("--epoch", type=int, default=0, help="トレーニングを開始するためのエポック")
parser.add_argument("--n_epochs", type=int, default=120, help="トレーニングエポック数")
parser.add_argument("--dataset_name", type=str, default="test", help="データセットの名前")
parser.add_argument("--saved_name", type=str ,default='40', help="epoch to start training from, or last")
parser.add_argument("--batch_size", type=int, default=16, help="バッチサイズ")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: 学習速度")
parser.add_argument("--b1", type=float, default=0.5, help="adam: 勾配の減衰")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=4, help="CPUスレッド数")
parser.add_argument("--img_height", type=int, default=256, help="画像サイズの高さ")
parser.add_argument("--img_width", type=int, default=256, help="画像サイズの横")
parser.add_argument("--channels", type=int, default=3, help="画像チャネル数")
parser.add_argument("--latent_dim", type=int, default=8, help="潜伏次元数（ノイズベクトル？）")
parser.add_argument("--sample_interval", type=int, default=500, help="interval between saving generator samples")
parser.add_argument("--checkpoint_interval", type=int, default=60, help="interval between model checkpoints")
parser.add_argument("--lambda_pixel", type=float, default=10, help="pixelwise loss weight")
parser.add_argument("--lambda_latent", type=float, default=0.5, help="latent loss weight")
parser.add_argument("--lambda_kl", type=float, default=0.01, help="kullback-leibler loss weight")
opt = parser.parse_args(args=[])
print(opt)

#存在しなければフォルダを作成する
# os.makedirs("pix_system_ver.2/results/%s"%opt.dataset_name, exist_ok=True)

cuda = True if torch.cuda.is_available() else False

input_shape = (opt.channels, opt.img_height, opt.img_width)

transforms_ = [
    transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
]

generator = Generator(opt.latent_dim, input_shape)
encoder = Encoder(opt.latent_dim, input_shape)


Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# #################cudaが使用可能なら変換する#########################

if cuda:
    print("CUDAが使えます")
    generator = generator.cuda()

# #########################################################


generator.load_state_dict(torch.load("pix_system_ver.1/model/%s/generator_%s.pth" % (opt.dataset_name,opt.saved_name), map_location='cpu'))
# print("generatorを生成しました")
# encoder.load_state_dict(torch.load("pix_system_ver.1/model/%s/encoder_%s.pth" % (opt.dataset_name,opt.saved_name), map_location='cpu'))
# # print("encoderを生成しました")
    


# ###############画像の読み込み###############################

input_image = KIS_make_datapath_list(opt)

dataloader = DataLoader(
    KIS_ImageDataset(input_image, opt, transforms_=transforms_),
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu
)

# print(dataloader)
# ################################################################

# for i in range(2):
#     sampled_z = Variable(Tensor(np.random.normal(0, 1, (opt.latent_dim, opt.latent_dim))))
#     print(sampled_z)



for i, batch in enumerate(dataloader):
        generator.eval()
        print(batch.size())
        real_A = Variable(batch.type(Tensor))
        # real_A.size(0)
        sampled_z = Variable(Tensor(np.random.normal(0, 1, (real_A.size(0), opt.latent_dim))))
        print(sampled_z.size())
        fake_B = generator(real_A,sampled_z)
        save_image(fake_B, "pix_system_ver.1/results/No"+str(i+1)+".png".format(opt.dataset_name, i), nrow=3)
print("生成しました")

Namespace(mode='gray2color', epoch=0, n_epochs=120, dataset_name='test', saved_name='40', batch_size=16, lr=0.0002, b1=0.5, b2=0.999, n_cpu=4, img_height=256, img_width=256, channels=3, latent_dim=8, sample_interval=500, checkpoint_interval=60, lambda_pixel=10, lambda_latent=0.5, lambda_kl=0.01)
CUDAが使えます
torch.Size([1, 3, 256, 256])
torch.Size([1, 8])
生成しました
