In [1]:
import argparse



# Run configuration for the notebook.
# - conflict_handler='resolve' lets a later add() of an already-defined flag
#   replace the earlier definition instead of raising.
# - parse_known_args() (rather than parse_args()) returns unrecognised arguments
#   as its second result instead of exiting with an error. Presumably this is so
#   the arguments the Jupyter kernel was launched with (e.g. `-f <connection-file>`)
#   are tolerated.
parser = argparse.ArgumentParser(conflict_handler='resolve')
parser.add = parser.add_argument  # short alias, kept on the parser object

# (flag, type, default) for every option, in the order they appear in --help.
for _flag, _kind, _default in (
    ('--gpu_id', int, 0),
    ('--image_size', int, 16),
    ('--num_epoch', int, 200),
    ('--val_epoch_freq', int, 5),
    ('--batch_size', int, 64),
    ('--num_channels', int, 256),
    ('--data_type', str, 'norm'),
    ('--target_type', str, 'none'),
    ('--train_path', str, 'data/ecalNT_50K_e_10_100.npz'),
    ('--val_path', str, 'data/ecalNT_10K_e_10_100.npz'),
    ('--regressor_path', str, 'checkpoints/classifier/regressor_none'),
):
    parser.add(_flag, type=_kind, default=_default)
del _flag, _kind, _default  # keep the loop temporaries out of the notebook namespace

opt, _ = parser.parse_known_args()

In [2]:
import os

# Restrict this process to the GPU selected with --gpu_id. The variable is set
# before `import torch` and before any `.cuda()` call, so it is already in place
# when CUDA enumerates devices. From then on the chosen GPU is the only visible
# device, which is why the later `.cuda()` calls need no device index.
os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % opt.gpu_id

import torch
# NOTE(review): `Variable` (and `volatile=` / `.data[0]` further down) is the
# legacy pre-0.4 PyTorch autograd API.
from torch.autograd import Variable
from torch.optim import Adam
# Presumably project-local packages (models/, src/) next to this notebook.
from models.dcgan import Generator, Discriminator
from src.data import Dataset
from torch.utils.data import DataLoader
from src.stats import Stats
from src.utils import calc_gradient_penalty
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output

  from pandas.core import datetools


# Load and preprocess data; set up the statistics tools

In [3]:
# Seed the RNGs before anything stochastic runs, so Restart & Run All is
# repeatable. The torch seed covers DataLoader shuffling, weight init and the
# latent noise drawn in the training loop; the numpy seed covers any numpy
# randomness inside the project helpers.
# NOTE(review): bit-exact reproducibility on GPU may additionally require
# deterministic cuDNN settings.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

NUM_WORKERS = 4  # DataLoader worker processes per loader

train_dataset = Dataset(opt)
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True,
                          num_workers=NUM_WORKERS, drop_last=True)
# The validation dataset receives the training dataset. Presumably this reuses
# its preprocessing/normalisation; confirm in src/data.py.
val_dataset = Dataset(opt, train_dataset)
# No shuffling for validation. With drop_last=True a shuffled loader would
# evaluate a different subset of samples on every validation pass.
val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, shuffle=False,
                        num_workers=NUM_WORKERS, drop_last=True)
# drop_last=True yields zero batches when a split is smaller than one batch.
# Fail here rather than deep inside the training loop.
assert len(train_loader) > 0, 'training set is smaller than one batch'
assert len(val_loader) > 0, 'validation set is smaller than one batch'

stats = Stats(opt)
# Presumably computes reference statistics of the real validation data (flag
# True) for later comparison with generated samples; confirm in src/stats.py.
stats.calc_stats(val_dataset.data, True)

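# Set up the networks and optimizers (the WGAN losses are computed inline in the training loop, so no separate criterion objects are created)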

In [4]:
# Generator and critic share the same constructor arguments and are moved to the
# (single visible) GPU. They are built in order: Generator first, then Discriminator.
gen, dis = (
    net_cls(opt.num_channels, opt.image_size).cuda()
    for net_cls in (Generator, Discriminator)
)
# One Adam optimiser per network, with identical hyper-parameters.
optim_g, optim_d = (
    Adam(net.parameters(), lr=1e-4, betas=(0.5, 0.999))
    for net in (gen, dis)
)

# Train networks

In [None]:
# WGAN training loop.
# - Each batch: one critic (`dis`) update on pred_fake - pred_real plus the term
#   from calc_gradient_penalty.
# - Every N_CRITIC-th batch: one generator update.
# - Every `opt.val_epoch_freq` epochs: evaluate on the validation set, plot the
#   statistics, and save the figure and the generator weights.
# NOTE(review): legacy PyTorch API (`Variable`, `volatile=`, `.data[0]`). On
# PyTorch >= 0.4 use `torch.no_grad()` and `.item()` instead.
N_CRITIC = 5  # critic updates per generator update
FIGURE_DIR = 'checkpoints/generator/figures'
WEIGHT_DIR = 'checkpoints/generator/weights'
# Create the output directories up front. Otherwise the first savefig/torch.save,
# reached only after `val_epoch_freq` epochs of training, would fail.
os.makedirs(FIGURE_DIR, exist_ok=True)
os.makedirs(WEIGHT_DIR, exist_ok=True)

for e in range(opt.num_epoch):
    # ---- train for one epoch ----
    gen.train()
    dis.train()
    train_wgan_loss = 0
    num_train_batches = 0
    for i, (input_data, input_target) in enumerate(tqdm(train_loader)):
        # Critic step: re-enable critic gradients (disabled during generator steps).
        for p in dis.parameters():
            p.requires_grad = True
        optim_d.zero_grad()
        noise = Variable(torch.randn(opt.batch_size, opt.num_channels)).cuda()
        fake = gen(noise)
        real = Variable(input_data).cuda()
        pred_real = dis(real).mean()
        # detach(): the critic update must not send gradients into the generator.
        pred_fake = dis(fake.detach()).mean()
        loss_dis = pred_fake - pred_real + calc_gradient_penalty(dis, real.data, fake.data)
        loss_dis.backward()
        optim_d.step()
        train_wgan_loss += (pred_fake - pred_real).data[0]
        num_train_batches += 1
        # Generator step, once every N_CRITIC critic steps.
        if i % N_CRITIC != N_CRITIC - 1:
            continue
        # Freeze the critic so only the generator accumulates gradients.
        for p in dis.parameters():
            p.requires_grad = False
        optim_g.zero_grad()
        # Reuse this batch's `fake`. The critic step above only used
        # fake.detach() / fake.data, so its graph through `gen` is still intact.
        pred_fake = dis(fake).mean()
        loss_gen = -pred_fake
        loss_gen.backward()
        optim_g.step()
    # Average over the batches actually seen (guarded against an empty loader).
    train_wgan_loss /= max(num_train_batches, 1)

    # ---- validate every val_epoch_freq epochs ----
    if e % opt.val_epoch_freq != opt.val_epoch_freq - 1:
        continue
    gen.eval()
    dis.eval()
    stats.train_loss.append(train_wgan_loss)
    val_wgan_loss = 0
    num_val_batches = 0
    val_data_fake = []
    for input_data, input_target in tqdm(val_loader):
        # volatile=True (legacy inference mode): no autograd graph is recorded.
        noise = Variable(torch.randn(opt.batch_size, opt.num_channels), volatile=True).cuda()
        fake = gen(noise)
        # Keep channel 0 of every generated sample. Presumably the generator's only
        # image channel — TODO confirm against models/dcgan.py.
        val_data_fake.append(fake.data.cpu().numpy()[:, 0])
        real = Variable(input_data, volatile=True).cuda()
        pred_real = dis(real).mean()
        pred_fake = dis(fake).mean()
        val_wgan_loss += (pred_fake - pred_real).data[0]
        num_val_batches += 1
    # Own counter: the original divided by the training loop's stale `i` whenever
    # the validation loader yielded no batches.
    val_wgan_loss /= max(num_val_batches, 1)
    stats.val_loss.append(val_wgan_loss)
    val_data_fake = np.concatenate(val_data_fake, 0)
    val_data_fake = val_dataset.get_output(val_data_fake, opt.data_type)
    clear_output()
    stats.calc_stats(val_data_fake)
    f = stats.get_plot()
    f.savefig(os.path.join(FIGURE_DIR, '%d.pdf' % (e + 1)))
    # Release the figure. Otherwise one figure stays open per validation pass,
    # and matplotlib warns once more than 20 are open.
    # NOTE(review): if get_plot() does not display the figure itself, add
    # display(f) before closing to keep the live plot.
    plt.close(f)
    torch.save(gen.state_dict(), os.path.join(WEIGHT_DIR, 'gen_%d.pkl' % (e + 1)))

781it [00:13, 57.79it/s]
781it [00:13, 57.89it/s]
781it [00:13, 58.20it/s]
781it [00:13, 58.35it/s]
60it [00:01, 57.73it/s]