# In [1]:
import argparse
import os
import time
import numpy as np

import torch
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as tforms
from torchvision.utils import save_image

import torch.utils.data as data
from torch.utils.data import Dataset

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300

from PIL import Image
import os.path
import errno
import codecs

import lib.layers as layers
import lib.utils as utils
import lib.multiscale_parallel as multiscale_parallel
import lib.modules as modules
import lib.thops as thops

from train_misc import standard_normal_logprob
from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time, count_nfe_gate
from train_misc import add_spectral_norm, spectral_norm_power_iteration
from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log

# In [None]:
# Turn on the cuDNN autotuner: PyTorch benchmarks candidate convolution
# algorithms and reuses the fastest one (see torch.backends.cudnn docs).
torch.backends.cudnn.benchmark = True

# ODE solver names accepted by --solver and --test_solver below.
SOLVERS = [
    "dopri5",
    "bdf",
    "rk4",
    "midpoint",
    "adams",
    "explicit_adams",
]
# Gate variants accepted by --gate below.
GATES = [
    "cnn1",
    "cnn2",
    "rnn",
]

class Options:
    def __init__(self):
        self.data = "cifar10"
        self.dims = "64,64,64"
        self.strides = "1,1,1,1"
        self.num_blocks = 2
        
        self.conv = True
        --layer_type concat --multiscale True --rademacher True --batch_size 900 --test_batch_size 500
        self.layer_type = "concat"
        self.seed = 0
        
        
        
        
        
        self.divergence_fn = "approximate"
        self.nonlinearity = "softplus"
        self.solver = "dopri5"
        self.atol = 10e-5
        
        


# Command-line interface. Each flag becomes an attribute of ``args``; argparse
# converts dashes to underscores (``--rl-weight`` -> ``args.rl_weight``).
#
# NOTE(review): the boolean-style flags use ``type=eval``, which executes the
# raw command-line string as Python before ``choices`` is checked. Acceptable
# for a trusted local CLI; never feed untrusted input to this parser.
parser = argparse.ArgumentParser("Continuous Normalizing Flow")
parser.add_argument("--data", choices=["mnist", "colormnist", "svhn", "cifar10", 'lsun_church'], type=str, default="mnist")
# Comma-separated lists, parsed later by the training code.
parser.add_argument("--dims", type=str, default="8,32,32,8")
parser.add_argument("--strides", type=str, default="2,2,1,-2,-2")
parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.')

parser.add_argument("--conv", type=eval, default=True, choices=[True, False])
parser.add_argument(
    "--layer_type", type=str, default="ignore",
    choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"]
)
parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"])
parser.add_argument(
    "--nonlinearity", type=str, default="softplus", choices=["tanh", "relu", "softplus", "elu", "swish"]
)

parser.add_argument("--seed", type=int, default=0)

# ODE solver settings.
parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
parser.add_argument('--atol', type=float, default=1e-5)
parser.add_argument('--rtol', type=float, default=1e-5)
parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.")

# Gating / reward settings.
parser.add_argument('--gate', type=str, default='cnn1', choices=GATES)
parser.add_argument('--scale', type=float, default=1.0)
parser.add_argument('--scale_fac', type=float, default=1.0)
parser.add_argument('--scale_std', type=float, default=1.0)
parser.add_argument('--eta', default=0.1, type=float,
                        help='tuning parameter that allows us to trade-off the competing goals of '
                                'minimizing the prediction loss and maximizing the gate rewards')
parser.add_argument('--rl-weight', default=0.01, type=float,
                        help='rl weight')

parser.add_argument('--gamma', default=0.99, type=float,
                        help='discount factor, default: (0.99)')

# Optional overrides used at evaluation time (None = not set on the CLI).
parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None])
parser.add_argument('--test_atol', type=float, default=None)
parser.add_argument('--test_rtol', type=float, default=None)

parser.add_argument("--imagesize", type=int, default=None)
parser.add_argument("--alpha", type=float, default=1e-6)
parser.add_argument('--time_length', type=float, default=1.0)
parser.add_argument('--train_T', type=eval, default=True)

# Optimisation / batching.
parser.add_argument("--num_epochs", type=int, default=500)
parser.add_argument("--batch_size", type=int, default=200)
parser.add_argument(
    "--batch_size_schedule", type=str, default="", help="Increases the batchsize at every given epoch, dash separated."
)
parser.add_argument("--test_batch_size", type=int, default=200)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--warmup_iters", type=float, default=1000)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--spectral_norm_niter", type=int, default=10)
parser.add_argument("--weight_y", type=float, default=0.5)
parser.add_argument("--annealing_std", type=eval, default=False, choices=[True, False])
parser.add_argument("--y_class", type=int, default=10)
parser.add_argument("--y_color", type=int, default=10)

# Model / data toggles.
parser.add_argument("--add_noise", type=eval, default=True, choices=[True, False])
parser.add_argument("--batch_norm", type=eval, default=False, choices=[True, False])
parser.add_argument('--residual', type=eval, default=False, choices=[True, False])
parser.add_argument('--autoencode', type=eval, default=False, choices=[True, False])
parser.add_argument('--rademacher', type=eval, default=True, choices=[True, False])
parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False])
parser.add_argument('--multiscale', type=eval, default=False, choices=[True, False])
parser.add_argument('--parallel', type=eval, default=False, choices=[True, False])
parser.add_argument('--conditional', type=eval, default=False, choices=[True, False])
parser.add_argument('--controlled_tol', type=eval, default=False, choices=[True, False])
parser.add_argument("--train_mode", choices=["semisup", "sup", "unsup"], type=str, default="semisup")
parser.add_argument("--condition_ratio", type=float, default=0.5)
parser.add_argument("--dropout_rate", type=float, default=0.0)


# Regularizations
parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1")
parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2")
parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2")
parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F")
parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F")
parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F")

parser.add_argument("--time_penalty", type=float, default=0, help="Regularization on the end_time.")
parser.add_argument(
    "--max_grad_norm", type=float, default=1e10,
    help="Max norm of gradients (default is just stupidly high to avoid any clipping)"
)

# Checkpointing / logging.
parser.add_argument("--begin_epoch", type=int, default=1)
parser.add_argument("--resume", type=str, default=None)
parser.add_argument("--save", type=str, default="experiments/cnf")
parser.add_argument("--val_freq", type=int, default=1)
parser.add_argument("--log_freq", type=int, default=1)

# NOTE(review): inside a Jupyter kernel, parse_args() reads the kernel's own
# sys.argv (e.g. "-f <connection file>") and will likely exit with an error;
# consider parse_args(args=[...]) when running interactively.
args = parser.parse_args()