In [1]:
import argparse



# Command-line / notebook options.
# parse_known_args() is used instead of parse_args() so that unrecognised
# arguments (presumably the ones the Jupyter kernel injects) are ignored
# rather than aborting the cell.
parser = argparse.ArgumentParser(conflict_handler='resolve')
parser.add = parser.add_argument  # short alias, kept for any code that uses it

# (flag, keyword arguments for add_argument), in registration order.
_OPTION_SPECS = (
    ('--gpu_id', dict(type=int, default=1)),
    ('--image_size', dict(type=int, default=16)),
    ('--num_epoch', dict(type=int, default=200)),
    ('--val_epoch_freq', dict(type=int, default=1)),
    ('--batch_size', dict(type=int, default=64)),
    ('--num_channels', dict(type=int, default=256)),
    ('--adv_loss_type', dict(type=str, default='wgan', help='gan|wgan')),
    ('--num_pred', dict(type=int, default=10, help='1|10')),
    ('--stats_weight', dict(type=float, default=1e-4)),
    ('--data_type', dict(type=str, default='norm')),
    ('--target_type', dict(type=str, default='none')),
    ('--train_path', dict(type=str, default='data/ecalNT_50K_e_10_100.npz')),
    ('--val_path', dict(type=str, default='data/ecalNT_10K_e_10_100.npz')),
    ('--regressor_path', dict(type=str, default='checkpoints/classifier/regressor_none')),
)
for _flag, _kwargs in _OPTION_SPECS:
    parser.add(_flag, **_kwargs)

opt, _ = parser.parse_known_args()

In [2]:
import os

# Expose only the requested GPU. This is set before torch is imported below,
# so it is already in place when CUDA gets initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id)

import torch
import numpy as np

# Fixed seed so NumPy and PyTorch random draws are reproducible across runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

from torch.autograd import Variable
from torch.optim import Adam
from models.dcgan import Generator, Discriminator
from src.data import Dataset
from torch.utils.data import DataLoader
from src.stats import Stats
from src.utils import Loss
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output

  from pandas.core import datetools


# Load and preprocess the data; set up the statistics tools

In [3]:
# Statistics helper; it is handed to both datasets below.
stats = Stats(opt)

NUM_WORKERS = 4  # DataLoader worker processes per loader

# drop_last=True on both loaders: the training/validation loops always draw
# exactly opt.batch_size noise vectors per batch, so every real batch is kept
# full-sized too (presumably the loss expects matching real/fake batch sizes).
train_dataset = Dataset(opt, stats=stats)
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size,
                          shuffle=True, num_workers=NUM_WORKERS,
                          drop_last=True)

# The validation dataset is built with the training dataset, presumably to
# reuse its preprocessing/normalisation parameters -- TODO confirm in src.data.
val_dataset = Dataset(opt, train_dataset, stats)
# shuffle=False: combined with drop_last=True, a shuffled loader would drop a
# different random subset of validation samples every epoch, so validation
# losses would not be measured on the same data from epoch to epoch.
val_loader = DataLoader(val_dataset, batch_size=opt.batch_size,
                        shuffle=False, num_workers=NUM_WORKERS,
                        drop_last=True)

# Reference statistics of the real validation data; the generated samples'
# statistics are computed later with stats.calc_stats(val_data_fake).
# NOTE(review): meaning of the positional True flag lives in src.stats -- confirm.
stats.calc_stats(val_dataset.data, True)

# Set up the networks, the loss criterion and the optimizers

In [4]:
# Generator and discriminator, both moved to the (single visible) GPU.
gen = Generator(opt.num_channels, opt.image_size).cuda()
dis = Discriminator(opt.num_channels, opt.image_size, opt.num_pred).cuda()

# Both networks share the same Adam hyper-parameters.
adam_kwargs = dict(lr=1e-4, betas=(0.5, 0.999))
optim_g = Adam(gen.parameters(), **adam_kwargs)
optim_d = Adam(dis.parameters(), **adam_kwargs)

# Loss object; called as crit(real_input, fake_input) for the discriminator
# and crit(fake_input) for the generator, returning (total, adversarial,
# statistics) terms.
crit = Loss(opt, dis, stats)

# Train the networks

In [None]:
# Training loop.
# Each epoch performs one discriminator step per batch and one generator step
# per batch ('gan') or per N_CRITIC batches ('wgan'). Every opt.val_epoch_freq
# epochs the epoch-average training losses are appended to `stats`, losses are
# evaluated on val_loader, a statistics plot is saved and the generator
# weights are checkpointed.
N_CRITIC = 5  # discriminator steps per generator step when adv_loss_type == 'wgan'
FIGURE_DIR = 'checkpoints/generator/figures'
WEIGHTS_DIR = 'checkpoints/generator/weights'
# savefig / torch.save fail if the target directory does not exist yet.
os.makedirs(FIGURE_DIR, exist_ok=True)
os.makedirs(WEIGHTS_DIR, exist_ok=True)


def _scalar(v):
    """Return the single value held by a one-element loss Variable as a float.

    ``v.data.view(-1)[0]`` is a Python number on PyTorch 0.3 and a 0-dim
    tensor on later releases; ``float()`` normalises both. The older
    ``v.data[0]`` idiom fails on the 0-dim losses of newer PyTorch versions.
    """
    return float(v.data.view(-1)[0])


for e in range(opt.num_epoch):
    # ---- train for one epoch ----
    gen.train()
    dis.train()
    train_dis_adv_loss = 0
    train_gen_adv_loss = 0
    train_dis_stats_loss = 0
    train_gen_stats_loss = 0
    num_dis_steps = 0
    num_gen_steps = 0
    # tqdm wraps the loader (not enumerate) so the bar knows len(train_loader).
    for i, (real_data, real_stats) in enumerate(tqdm(train_loader)):
        # -- discriminator step --
        for p in dis.parameters():
            p.requires_grad = True
        optim_d.zero_grad()
        real = Variable(real_data.cuda())  # real data
        noise = Variable(torch.randn(opt.batch_size, opt.num_channels)).cuda()
        fake = gen(noise)  # fake data
        real_input = [real]
        fake_input = [fake]
        if opt.num_pred > 1:
            # Extra statistics inputs: precomputed ones for the real batch and
            # ones computed from the fake batch (via .data, so no gradient
            # flows through them).
            real_input += [Variable(real_stats.cuda())]
            fake_input += [stats.calc_stats_torch(fake[:, 0].data)]
        loss, loss_adv, loss_stats = crit(real_input, fake_input)
        train_dis_adv_loss += _scalar(loss_adv)
        train_dis_stats_loss += _scalar(loss_stats)
        num_dis_steps += 1
        loss.backward()
        optim_d.step()

        # -- generator step (every N_CRITIC-th batch for wgan) --
        if opt.adv_loss_type == 'wgan' and i % N_CRITIC != N_CRITIC - 1:
            continue
        # Freeze the discriminator so the generator loss only updates `gen`.
        for p in dis.parameters():
            p.requires_grad = False
        optim_g.zero_grad()
        # Reuses the fake batch from the discriminator step.
        # NOTE(review): this second backward through `gen` only works if crit
        # does not backpropagate the discriminator loss into gen's graph
        # (e.g. it detaches `fake` internally) -- confirm in src.utils.Loss.
        loss, loss_adv, loss_stats = crit(fake_input)
        train_gen_adv_loss += _scalar(loss_adv)
        train_gen_stats_loss += _scalar(loss_stats)
        num_gen_steps += 1
        loss.backward()
        optim_g.step()

    # Average over the steps that actually contributed. For 'wgan' the
    # generator only ran on every N_CRITIC-th batch, so dividing its losses
    # by the number of batches would understate them by ~N_CRITIC times.
    train_dis_adv_loss /= max(num_dis_steps, 1)
    train_dis_stats_loss /= max(num_dis_steps, 1)
    train_gen_adv_loss /= max(num_gen_steps, 1)
    train_gen_stats_loss /= max(num_gen_steps, 1)

    # ---- validate after every val_epoch_freq epochs ----
    if e % opt.val_epoch_freq != opt.val_epoch_freq - 1:
        continue
    gen.eval()
    dis.eval()
    stats.train_dis_adv_loss.append(train_dis_adv_loss)
    if opt.adv_loss_type != 'wgan':
        stats.train_gen_adv_loss.append(train_gen_adv_loss)
    if opt.num_pred > 1:
        stats.train_dis_stats_loss.append(train_dis_stats_loss)
        stats.train_gen_stats_loss.append(train_gen_stats_loss)
    val_dis_adv_loss = 0
    val_gen_adv_loss = 0
    val_dis_stats_loss = 0
    val_gen_stats_loss = 0
    num_val_batches = 0
    val_data_fake = []
    for real_data, real_stats in tqdm(val_loader):
        real = Variable(real_data.cuda())
        noise = Variable(torch.randn(opt.batch_size, opt.num_channels), volatile=True).cuda()
        fake = gen(noise)
        real_input = [real]
        fake_input = [fake]
        if opt.num_pred > 1:
            real_input += [Variable(real_stats.cuda())]
            fake_input += [stats.calc_stats_torch(fake[:, 0].data)]
        loss, loss_adv, loss_stats = crit(real_input, fake_input)
        val_dis_adv_loss += _scalar(loss_adv)
        val_dis_stats_loss += _scalar(loss_stats)
        loss, loss_adv, loss_stats = crit(fake_input)
        val_gen_adv_loss += _scalar(loss_adv)
        val_gen_stats_loss += _scalar(loss_stats)
        # Keep channel 0 of each generated sample for the statistics plot.
        val_data_fake.append(fake.data.cpu().numpy()[:, 0])
        num_val_batches += 1
    if num_val_batches == 0:
        # With drop_last=True a validation set smaller than one batch yields
        # nothing; fail clearly instead of averaging over a stale counter.
        raise RuntimeError('val_loader yielded no batches; the validation set '
                           'may be smaller than batch_size (drop_last=True)')
    val_dis_adv_loss /= num_val_batches
    val_gen_adv_loss /= num_val_batches
    val_dis_stats_loss /= num_val_batches
    val_gen_stats_loss /= num_val_batches
    if opt.adv_loss_type == 'wgan':
        stats.val_dis_adv_loss.append(val_dis_adv_loss)
    if opt.num_pred > 1:
        stats.val_dis_stats_loss.append(val_dis_stats_loss)
        stats.val_gen_stats_loss.append(val_gen_stats_loss)
    val_data_fake = np.concatenate(val_data_fake, 0)
    val_data_fake = val_dataset.get_output(val_data_fake, opt.data_type)
    clear_output()
    stats.calc_stats(val_data_fake)
    # NOTE(review): if get_plot() does not show/close its figure, one figure
    # per validation accumulates in pyplot; consider plt.close(f) after saving.
    f = stats.get_plot()
    f.savefig(os.path.join(FIGURE_DIR, '%d.png' % (e + 1)))
    torch.save(gen.state_dict(), os.path.join(WEIGHTS_DIR, 'gen_%d.pkl' % (e + 1)))

781it [00:18, 42.60it/s]
0it [00:00, ?it/s]Process Process-869:
Process Process-870:
Process Process-872:
KeyboardInterrupt
Process Process-871:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ezakharov/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ezakharov/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
KeyboardInterrupt
  File "/home/ezakharov/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ezakharov/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/home/ezakharov/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/ezakharov/miniconda3/lib/python3.6/multiprocessing/proces

KeyboardInterrupt: 