In [1]:
#!/usr/bin/env python
import sys

import torch
import torch.nn
import torch.optim
from torch.nn.functional import avg_pool2d, interpolate
from torch.autograd import Variable
import numpy as np
import tqdm

import config as c
import opts


In [2]:
# Header of the printed configuration summary. Command-line overrides
# (opts.parse(sys.argv)) are not applied when running this notebook.
config_str = "".join(["===" * 30, "\n", "Config options:\n\n"])


In [3]:
# Append one "  name<TAB>value" row per public attribute of the config module
# (dir() already returns names in sorted order), then close the banner and print.
# getattr performs the same attribute lookup as the previous eval('c.%s' % v),
# without executing a string as code.
for option_name in dir(c):
    if option_name.startswith('_'):
        continue
    config_str += "  {:25}\t{}\n".format(option_name, getattr(c, option_name))

config_str += "===" * 30 + "\n"

print(config_str)


Config options:

  add_image_noise          	0.15
  batch_size               	512
  betas                    	(0.9, 0.999)
  checkpoint_on_error      	True
  checkpoint_save_interval 	360
  checkpoint_save_overwrite	True
  clamping                 	1.9
  colorize                 	False
  cond_net_file            	
  cond_width               	64
  data_mean                	0.0
  data_std                 	1.0
  decay_by                 	0.01
  do_fwd                   	True
  do_rev                   	False
  fc_dropout               	0.0
  filename                 	output/saving_mnist_cinn.pt
  img_dims                 	(28, 28)
  init_scale               	0.03
  internal_width           	256
  internal_width_conv      	64
  live_visualization       	False
  load_file                	
  loss_names               	['L', 'L_rev']
  lr                       	0.0001
  n_blocks                 	7
  n_blocks_conv            	3
  n_epochs                 	1440
  n_its_per_epoch          	65536


In [4]:
# Colorization switch: True selects the colorization variant, False the
# class-conditional one (see the `if c.colorize` branches further down).
# NOTE(review): to switch modes, assign c.colorize = True in this cell, before
# `import model`; model may read the flag at import time — confirm.
getattr(c, 'colorize')

False

In [5]:
import model
import data
import viz
import losses

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Epoch		L		L_rev


In [6]:
# cond_net is only needed by the colorization variant: the training cell calls
# cond_net.model.features(masks) to build per-batch conditions. The
# class-conditional setup never uses it, so it is imported only when needed.
if c.colorize:
    import cond_net

class dummy_loss:
    """Placeholder used in place of a disabled loss term.

    The training loop only calls ``.item()`` on its losses. This stub
    implements just that method and always reports 1.0, so the logged loss
    history keeps both columns even when a term is switched off.
    """

    def item(self):
        return 1.0


In [7]:

# Optionally restore a saved state before training (presumably a file written
# by model.save — confirm). With an empty c.load_file, as in the config printed
# above, nothing is loaded.
if c.load_file:
    model.load(c.load_file)


In [8]:
def sample_outputs(sigma, batch_size=None, output_dim=None, device='cuda'):
    """Draw a batch of latent vectors from an isotropic Gaussian N(0, sigma^2 I).

    Args:
        sigma: standard deviation of the Gaussian (the training cell passes
            c.latent_noise or c.sampling_temperature).
        batch_size: number of rows; defaults to ``c.batch_size``.
        output_dim: latent dimensionality; defaults to ``c.output_dim``.
        device: device of the returned tensor; defaults to ``'cuda'`` as before.

    Returns:
        A float32 tensor of shape (batch_size, output_dim).
    """
    # Defaults are resolved at call time so the config can be edited between cells.
    if batch_size is None:
        batch_size = c.batch_size
    if output_dim is None:
        output_dim = c.output_dim
    # torch.randn replaces the deprecated legacy torch.cuda.FloatTensor(...).normal_().
    return sigma * torch.randn(batch_size, output_dim, device=device, dtype=torch.float32)



In [9]:
# Re-display the active mode right before the conditioning/training cell.
getattr(c, 'colorize')

False

In [10]:
# Conditioning input for the cINN, in one of two modes:
#   * colorize: per-pixel condition = [mask channel, cond_net features broadcast over the image]
#   * class-conditional: one-hot label vector of width model.cond_size
# Every make_cond call overwrites the same pre-allocated cond_tensor, so a
# condition that must survive later calls has to be .clone()d.
if c.colorize:
    cond_tensor = torch.zeros(c.batch_size, model.cond_size, *c.img_dims).cuda()

    def make_cond(mask, cond_features):
        """Fill cond_tensor with the mask (channel 0) and spatially broadcast features."""
        cond_tensor[:, 0] = mask[:, 0]
        cond_tensor[:, 1:] = cond_features.view(c.batch_size, -1, 1, 1).expand(-1, -1, *c.img_dims)
        return cond_tensor

else:
    cond_tensor = torch.zeros(c.batch_size, model.cond_size).cuda()

    def make_cond(labels):
        """Overwrite cond_tensor with the one-hot encoding of `labels` and return it."""
        cond_tensor.zero_()
        cond_tensor.scatter_(1, labels.view(-1, 1), 1.)
        return cond_tensor

    # Fixed conditions for the per-epoch sample grid: labels 0..9 repeated to
    # fill one batch. clone() decouples it from the shared cond_tensor.
    test_labels = (torch.arange(c.batch_size) % 10).cuda()
    test_cond = make_cond(test_labels).clone()

try:
    # Epochs with a negative index form a warm-up phase at a reduced learning rate.
    for i_epoch in range(-c.pre_low_lr, c.n_epochs):

        loss_history = []
        data_iter = iter(data.train_loader)

        if i_epoch < 0:
            for param_group in model.optim.param_groups:
                param_group['lr'] = c.lr * 2e-2
        elif i_epoch == 0 and c.pre_low_lr > 0:
            # End of warm-up: restore the configured rate. Otherwise the reduced
            # value set above stays in the param groups, and chainable LR
            # schedulers only rescale whatever value is stored there.
            # NOTE(review): assumes model.weight_scheduler does not itself reset
            # the rate from c.lr — confirm.
            for param_group in model.optim.param_groups:
                param_group['lr'] = c.lr

        for i_batch, data_tuple in tqdm.tqdm(enumerate(data_iter),
                                              total=min(len(data.train_loader), c.n_its_per_epoch),
                                              leave=False,
                                              mininterval=1.,
                                              disable=(not c.progress_bar),
                                              ncols=83):

            if c.colorize:
                # Shapes observed with batch_size 512: x (512, 3, 28, 28), labels (512,),
                # masks (512, 1, 28, 28), cond_tensor (512, 65).
                x, labels, masks = data_tuple
                x, labels, masks = x.cuda(), labels.cuda(), masks.cuda()
                # Add Gaussian noise of scale c.add_image_noise to the inputs.
                x += c.add_image_noise * torch.randn_like(x)
                with torch.no_grad():
                    cond_features = cond_net.model.features(masks)
                    cond = make_cond(masks, cond_features)

            else:
                x, labels = data_tuple
                x, labels = x.cuda(), labels.cuda()
                x += c.add_image_noise * torch.randn_like(x)

                cond = make_cond(labels)

            output = model.model(x, cond)

            if c.do_fwd:
                # Negative log-likelihood under a standard-normal latent (up to a
                # constant): 0.5 * ||z||^2 - log|det J|, averaged over the batch.
                zz = torch.sum(output**2, dim=1)
                jac = model.model.log_jacobian(run_forward=False)

                neg_log_likeli = 0.5 * zz - jac

                l = torch.mean(neg_log_likeli)
                l.backward(retain_graph=c.do_rev)
            else:
                l = dummy_loss()

            if c.do_rev:
                # Perturb the detached latent codes and require the inverse pass to
                # reconstruct the input. The condition must be passed here as well,
                # just as in the forward pass and in the sampling pass below.
                samples_noisy = sample_outputs(c.latent_noise) + output.detach()

                x_rec = model.model(samples_noisy, cond, rev=True)
                l_rev = torch.mean((x - x_rec)**2)
                l_rev.backward()
            else:
                l_rev = dummy_loss()

            model.optim_step()
            loss_history.append([l.item(), l_rev.item()])

            if i_batch + 1 >= c.n_its_per_epoch:
                # The loader workers are not shut down automatically when the loop
                # exits early. _shutdown_workers is private API, so this is
                # best-effort, but only ordinary errors are ignored: a
                # KeyboardInterrupt must still propagate.
                try:
                    data_iter._shutdown_workers()
                except Exception:
                    pass

                break

        model.weight_scheduler.step()

        epoch_losses = np.mean(np.array(loss_history), axis=0)
        # NOTE(review): positive NLL values are clamped to 0 before display, which is
        # why early epochs in the log read 0.0000 — confirm this is intended.
        epoch_losses[0] = min(epoch_losses[0], 0)

        if i_epoch > 1 - c.pre_low_lr:
            viz.show_loss(epoch_losses, logscale=False)
            # Only the values are needed for the histogram, not the autograd graph.
            output_orig = output.detach().cpu()
            viz.show_hist(output_orig)

        with torch.no_grad():
            samples = sample_outputs(c.sampling_temperature)

            if not c.colorize:
                cond = test_cond

            rev_imgs = model.model(samples, cond, rev=True)
            ims = [rev_imgs]

        viz.show_imgs(*(data.unnormalize(im) for im in ims))

        model.model.zero_grad()

        if (i_epoch % c.checkpoint_save_interval) == 0:
            # With checkpoint_save_overwrite the suffix is always 0000, so a single
            # checkpoint file is overwritten; otherwise the epoch number is used.
            model.save(c.filename + '_checkpoint_%.4i' % (i_epoch * (1 - c.checkpoint_save_overwrite)))

    model.save(c.filename)

except BaseException:
    # Catches KeyboardInterrupt as well, so an interrupted run is also saved
    # before the exception is re-raised.
    if c.checkpoint_on_error:
        model.save(c.filename + '_ABORT')

    raise


  0%|                                                      | 0/117 [00:00<?, ?it/s]

000		0.0000		1.0000                                                              


  0%|                                                      | 0/117 [00:00<?, ?it/s]

001		0.0000		1.0000                                                              


  0%|                                                      | 0/117 [00:00<?, ?it/s]

002		-29.5028		1.0000                                                            


                                                                                   

003		-141.5633		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

004		-312.4865		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

005		-400.3436		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

006		-460.5282		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

007		-504.2522		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

008		-539.1761		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

009		-567.5555		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

010		-591.5207		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

011		-611.9165		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

012		-629.7160		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

013		-644.7872		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

014		-657.7227		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

015		-668.8634		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

016		-678.0444		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

017		-686.0782		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

018		-692.1334		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

019		-697.7762		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

020		-702.3009		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

021		-706.0868		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

022		-709.5644		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

023		-712.4119		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

024		-714.9066		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

025		-717.0569		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

026		-719.2235		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

027		-721.0599		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

028		-723.0547		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

029		-724.7388		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

030		-726.5212		1.0000                                                           


  0%|                                                      | 0/117 [00:00<?, ?it/s]

031		-728.4769		1.0000                                                           


                                                                                   

032		-730.3992		1.0000                                                           


 85%|█████████████████████████████████████▌      | 100/117 [00:07<00:01, 15.33it/s]