Commit c9d03b1: Some fixes and additional files

root committed Mar 21, 2019
1 parent 9b54b8f commit c9d03b1
Showing 8 changed files with 70 additions and 105 deletions.
2 changes: 1 addition & 1 deletion ais.py
@@ -14,7 +14,7 @@

flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
- flags.DEFINE_string('logdir', '/mnt/nfs/yilundu/ebm_code_release/cachedir', 'location where log of experiments will be stored')
+ flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
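Note: the same substitution recurs across every file in this commit: a machine-specific NFS mount gives way to the repo-relative default 'cachedir', so output resolves against the current working directory. A minimal sketch of the effect (paths illustrative):

# With the new default, logs land under ./cachedir/<exp> wherever the
# script is launched, rather than on a hardcoded NFS mount.
import os.path as osp
logdir = osp.join('cachedir', 'default')  # FLAGS.logdir, FLAGS.exp
print(osp.abspath(logdir))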
38 changes: 2 additions & 36 deletions ebm_combine.py
@@ -15,8 +15,8 @@

flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
- flags.DEFINE_string('logdir', '/mnt/nfs/yilundu/ebm_code_release/cachedir', 'directory for logging')
- flags.DEFINE_string('savedir', '/mnt/nfs/yilundu/ebm_code_release/cachedir', 'location where log of experiments will be stored')
+ flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
+ flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_float('step_lr', 500, 'size of gradient descent size')
flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
@@ -53,40 +53,6 @@

FLAGS = flags.FLAGS

- if FLAGS.task == 'gentest' or FLAGS.task == 'genbaseline':
-     # For conditioning on large shapes
-     # FLAGS.exp_pos = "dpos_only_1_12_new"
-     # FLAGS.resume_pos = 3000
-
-     # For conditioning on small shapes that are orientated
-     # FLAGS.exp_pos = "dpsrites_dpos_only_107"
-     # FLAGS.resume_pos = 4000
-     # FLAGS.exp_size = 'dpsrites_dsize_only_107'
-     # FLAGS.resume_size = 1000
-
-     # For conditioning on small shapes that are not orientated
-     FLAGS.exp_pos = "dpos_only_1_16_replay_batch_2"
-     FLAGS.resume_pos = 4000
-
-     # Super strong model
-     # FLAGS.exp_size = "dscale_only_1_16_replay_batch"
-     # FLAGS.resume_size = 10000
-
-     FLAGS.exp_size = "dscale_only_1_17_replay_batch_2"
-     FLAGS.resume_size = 4000
-
-     # FLAGS.exp_size = "dscale_only_1_16_replay_batch"
-     # FLAGS.resume_size = 10000
-
-     FLAGS.exp_shape = "dshape_only_1_19_replay_batch"
-     FLAGS.resume_shape = 20000
-     FLAGS.exp_rot = "drot_only_1_16_replay_batch"
-     FLAGS.resume_rot = 7000
-     FLAGS.step_lr = 500
-     FLAGS.cond_shape = True
-     FLAGS.cond_rot = True

class DSpritesGen(Dataset):
def __init__(self, data, latents, frac=0.0):

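Note: the block deleted above pinned gentest/genbaseline experiment names and resume iterations in source. The same configuration can presumably be supplied at launch instead; a hypothetical sketch (experiment names and iteration counts are illustrative, not shipped checkpoints):

# Hypothetical stand-in for the removed hardcoded overrides: set the flags
# programmatically, or pass --exp_pos=... --resume_pos=... on the command line.
FLAGS.exp_pos, FLAGS.resume_pos = 'my_dpos_exp', 4000
FLAGS.exp_size, FLAGS.resume_size = 'my_dscale_exp', 4000
FLAGS.exp_shape, FLAGS.resume_shape = 'my_dshape_exp', 20000
FLAGS.exp_rot, FLAGS.resume_rot = 'my_drot_exp', 7000
FLAGS.cond_shape = FLAGS.cond_rot = True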
98 changes: 40 additions & 58 deletions ebm_sandbox.py
@@ -1,7 +1,6 @@
import tensorflow as tf
import math
from tqdm import tqdm
- from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
import torch
@@ -10,29 +9,28 @@
from utils import optimistic_restore, set_seed
import os.path as osp
import numpy as np
- from rl_algs.logger import TensorBoardOutputFormat
+ from baselines.logger import TensorBoardOutputFormat
from scipy.misc import imsave
import os
import sklearn.metrics as sk
from baselines.common.tf_util import initialize
from scipy.linalg import eig
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

- set_seed(1)
+ # set_seed(1)

flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'omniglot or imagenet or omniglotfull or cifar10 or mnist or dsprites')
- flags.DEFINE_string('logdir', '/mnt/nfs/yilundu/ebm_code_release/sandbox_cachedir', 'location where log of experiments will be stored')
- flags.DEFINE_string('task', 'label', 'the task to execute (label: training on the label, anticorrupt: restore salt and pepper noise), boxcorrupt: restore empty portion of image'
+ flags.DEFINE_string('logdir', 'sandbox_cachedir', 'location where log of experiments will be stored')
+ flags.DEFINE_string('task', 'label', 'using conditional energy based models for classification'
+                     'anticorrupt: restore salt and pepper noise),'
+                     ' boxcorrupt: restore empty portion of image'
                     'or crossclass: change images from one class to another'
                     'or cycleclass: view image change across a label'
                     'or nearestneighbor which returns the nearest images in the test set'
                     'or labelfinetune to train a model accuracy'
-                     'or latent to traverse the latent through energy')
+                     'or latent to traverse the latent space of an EBM through eigenvectors of the hessian (dsprites only)'
+                     'or mixenergy to evaluate out of distribution generalization compared to other datasets')
+ flags.DEFINE_bool('hessian', True, 'Whether to use the hessian or the Jacobian for latent traversals')

flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 32, 'Size of inputs')
@@ -45,22 +43,25 @@
flags.DEFINE_bool('single', False, 'whether to use one sample to debug')
flags.DEFINE_bool('cclass', True, 'whether to use a conditional model (required for task label)')
flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
- flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run')
- flags.DEFINE_integer('lnorm', -1, 'lnorm infinity is -1, ')
flags.DEFINE_float('step_lr', 10.0, 'step size for updates on label')
flags.DEFINE_float('proj_norm', 0.0, 'Maximum change of input images')
flags.DEFINE_bool('large_model', False, 'Whether to use a large model')
- flags.DEFINE_bool('larger_model', False, 'Whether to use a large model')
- flags.DEFINE_bool('wider_model', False, 'Whether to use a large model')
+ flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model')
+ flags.DEFINE_bool('wider_model', False, 'Whether to use a wider model')
flags.DEFINE_bool('svhn', False, 'Whether to test on SVHN')

+ # Conditions for mixenergy (outlier detection)
flags.DEFINE_bool('svhnmix', False, 'Whether to test mix on SVHN')
flags.DEFINE_bool('cifar100mix', False, 'Whether to test mix on CIFAR100')
- flags.DEFINE_bool('texturemix', False, 'Whether to test mix on CIFAR100')
- flags.DEFINE_bool('randommix', False, 'Whether to test mix on CIFAR100')
- flags.DEFINE_bool('groupsort', False, 'Whether to test mix on CIFAR100')
- flags.DEFINE_bool('hmc', False, 'Use HMC for cross class sampling')
+ flags.DEFINE_bool('texturemix', False, 'Whether to test mix on Textures dataset')
+ flags.DEFINE_bool('randommix', False, 'Whether to test mix on random dataset')

+ # Conditions for label task (adversarial classification)
+ flags.DEFINE_integer('lival', 8, 'Value of constraint for li')
+ flags.DEFINE_integer('l2val', 40, 'Value of constraint for l2')
+ flags.DEFINE_integer('pgd', 0, 'number of steps project gradient descent to run')
+ flags.DEFINE_integer('lnorm', -1, 'linfinity is -1, l2 norm is 2')
+ flags.DEFINE_bool('labelgrid', False, 'Make a grid of labels')
flags.DEFINE_bool('proj_cclass', False, 'Projection conditional')

# Conditions on which models to use
flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
@@ -342,15 +343,23 @@ def boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir,


def crossclass(dataloader, weights, model, target_vars, logdir, sess):
-     X, Y_GT, X_mods = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods']
+     X, Y_GT, X_mods, X_final = target_vars['X'], target_vars['Y_GT'], target_vars['X_mods'], target_vars['X_final']
    for data_corrupt, data, label_gt in tqdm(dataloader):
        data, label_gt = data.numpy(), label_gt.numpy()
        data_corrupt = data.copy()
        data_corrupt[1:] = data_corrupt[0:-1]
        data_corrupt[0] = data[-1]

-         feed_dict = {X: data_corrupt, Y_GT: label_gt}
-         data_mods = sess.run(X_mods, feed_dict)
+         data_mods = []
+         data_mod = data_corrupt
+
+         for i in range(10):
+             data_mods.append(data_mod)
+
+             feed_dict = {X: data_mod, Y_GT: label_gt}
+             data_mod = sess.run(X_final, feed_dict)

        data_corrupt, data = rescale_im(data_corrupt), rescale_im(data)
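Note: the rewrite trades one fetch of every intermediate (X_mods) for ten sequential runs of the sampler's final output, feeding each result back as the next input, so only one step's tensors are in flight at a time. The pattern in isolation (names follow the surrounding code):

# Iterated refinement: rerun the sampler graph on its own output, recording
# one snapshot per outer iteration.
snapshots, x = [], data_corrupt
for _ in range(10):
    snapshots.append(x)
    x = sess.run(X_final, {X: x, Y_GT: label_gt})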

@@ -798,33 +807,6 @@ def construct_steps(weights, X, Y_GT, model, target_vars):
    target_vars['X_mods'] = X_mods


- def construct_hmc_steps(weights, X, Y_GT, model, target_vars):
-     n = 50
-     scale_fac = 1.0
-
-     # if FLAGS.task == 'cycleclass':
-     #     scale_fac = 10.0
-
-     X_mods = []
-     X = tf.identity(X)
-
-     for i in range(FLAGS.num_steps):
-         p = tf.random_normal(tf.shape(X), mean=0.0, stddev=0.0001)
-         for j in range(10):
-             energy_noise = model.forward(X, weights, label=Y_GT, reuse=True)
-             x_grad = tf.gradients(energy_noise, [X])[0]
-             p = p - FLAGS.step_lr * x_grad / 2
-             X = X - FLAGS.step_lr * p
-
-         if i % n == (n-1):
-             X_mods.append(X)
-
-         print("Constructing step {}".format(i))
-
-     target_vars['X_final'] = X
-     target_vars['X_mods'] = X_mods
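Note: with construct_hmc_steps gone, every sampling task routes through construct_steps. For reference, the core update such a sampler presumably builds is plain Langevin dynamics on the energy; a minimal sketch, with step size and noise scale illustrative:

# One assumed Langevin step: descend the energy gradient, plus Gaussian noise.
def langevin_step(X, weights, model, Y_GT, step_lr):
    energy = model.forward(X, weights, label=Y_GT, reuse=True)
    x_grad = tf.gradients(energy, [X])[0]
    noise = tf.random_normal(tf.shape(X), mean=0.0, stddev=0.005)
    return X - step_lr * x_grad + noise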


def nearest_neighbor(dataset, sess, target_vars, logdir):
    X = target_vars['X']
    Y_GT = target_vars['Y_GT']
@@ -886,7 +868,10 @@ def main():
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
-         model = ResNet32Wider(num_filters=196, train=False)
+         if FLAGS.dataset == 'imagenet':
+             model = ResNet32Wider(num_filters=196, train=False)
+         else:
+             model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

@@ -925,20 +910,17 @@ def main():
    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
-         construct_finetune_label(weights, X, Y, Y_GT, model, target_vars)
+         construct_finetune_label(weights, X, Y, Y_GT, model, target_vars, )
    elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor':
-         if FLAGS.hmc:
-             construct_hmc_steps(weights, X, Y_GT, model, target_vars)
-         else:
-             construct_steps(weights, X, Y_GT, model, target_vars)
+         construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.global_variables_initializer())
    saver = loader = tf.train.Saver(max_to_keep=10)
-     savedir = osp.join('/mnt/nfs/yilundu/ebm_code_release/cachedir', FLAGS.exp)
+     savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)
@@ -948,7 +930,7 @@
    model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

-     if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval":
+     if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy":
        optimistic_restore(sess, model_file)
        # saver.restore(sess, model_file)
    else:
@@ -971,7 +953,7 @@
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
-         labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver)
+         labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
4 changes: 2 additions & 2 deletions hmc.py
@@ -97,8 +97,8 @@ def hmc(initial_x,
Step-size in Hamiltonian simulation
num_steps : int
Number of steps to take in Hamiltonian simulation
-     log_posterior : str
-         Log posterior (unnormalized) for the target distribution
+     neg_log_posterior : str
+         Negative log posterior (unnormalized) for the target distribution
Returns
-------
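Note: only the docstring changes, but the rename matters to callers: hmc now documents that it takes the negative log posterior, i.e. an energy, so an EBM plugs in without a sign flip. A hypothetical call (the full signature is not shown in this hunk; argument names assumed):

# Hypothetical usage under the renamed parameter.
samples = hmc(initial_x=x_init,
              step_size=0.01,
              num_steps=10,
              neg_log_posterior=lambda x: model.forward(x, weights, label=y))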
7 changes: 4 additions & 3 deletions imagenet_demo.py
@@ -6,7 +6,7 @@
import imageio


- flags.DEFINE_string('logdir', '/mnt/nfs/yilundu/ebm_code_release/cachedir', 'location where log of experiments will be stored')
+ flags.DEFINE_string('logdir', '../cachedir', 'location where log of experiments will be stored')
flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
flags.DEFINE_float('step_lr', 170., 'number of steps to run')
flags.DEFINE_integer('batch_size', 16, 'number of steps to run')
@@ -63,9 +63,10 @@ def rescale_im(im):
model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
saver.restore(sess, model_file)

+ # lx = list(range(16))
lx = [0, 14, 37, 145, 108, 963, 242, 624, 238, 323, 527, 748, 985, 973, 974, 979]
ls = np.random.permutation(1000)[:16]
ims = []

+ # What to initialize sampling with.
x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
labels = np.eye(1000)[lx]
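Note: the lx list pins 16 ImageNet class indices for the demo, and np.eye(1000)[lx] converts them to one-hot conditioning vectors, one row per sample:

# Row i of the 1000x1000 identity is the one-hot vector for class i, so
# indexing by lx yields a (len(lx), 1000) label batch.
import numpy as np
lx = [0, 14, 37, 145]      # subset shown for brevity
labels = np.eye(1000)[lx]  # shape (4, 1000)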

16 changes: 16 additions & 0 deletions requirements.txt
@@ -0,0 +1,16 @@
+ scipy==0.19.1
+ horovod==0.16.0
+ torch==0.3.1
+ scikit_image==0.13.0
+ tensorflow==1.12.0
+ torchvision==0.2.0
+ six==1.11.0
+ imageio==2.4.1
+ tqdm==4.20.0
+ matplotlib==1.5.3
+ mpi4py==3.0.0
+ numpy==1.14.0
+ Pillow==5.4.1
+ baselines==0.1.5
+ skimage==0.0
+ scikit_learn==0.20.3
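Note: one pin deserves caution: skimage on PyPI is a placeholder distribution (0.0 is its only release); the actual functionality comes from the scikit_image pin above. A quick post-install sanity check, assuming the pins resolved cleanly:

# Sanity check: the pinned scientific stack should import at these versions.
import tensorflow, torch, numpy
print(tensorflow.__version__)  # expect 1.12.0
print(torch.__version__)       # expect 0.3.1
print(numpy.__version__)       # expect 1.14.0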
4 changes: 2 additions & 2 deletions test_inception.py
@@ -18,7 +18,7 @@
from inception import get_inception_score
from fid import get_fid_score

- flags.DEFINE_string('logdir', '/mnt/nfs/yilundu/ebm_code_release/cachedir', 'location where log of experiments will be stored')
+ flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_bool('cclass', False, 'whether to condition on class')

@@ -216,7 +216,7 @@ def compute_inception(sess, target_vars):

    images.extend(list(ims))

-     saveim = osp.join('/mnt/nfs/yilundu/ebm_code_release/sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
+     saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))

    ims = ims[:100]

6 changes: 3 additions & 3 deletions train.py
@@ -48,12 +48,12 @@
'Number of different data workers to load data in parallel')

# General Experiment Settings
- flags.DEFINE_string('logdir', '/mnt/nfs/yilundu/ebm_code_release/cachedir',
+ flags.DEFINE_string('logdir', 'cachedir',
                     'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('log_interval', 10, 'log outputs every so many batches')
- flags.DEFINE_integer('save_interval', 2000, 'save outputs every so many batches')
- flags.DEFINE_integer('test_interval', 2000, 'evaluate outputs every so many batches')
+ flags.DEFINE_integer('save_interval', 1000, 'save outputs every so many batches')
+ flags.DEFINE_integer('test_interval', 1000, 'evaluate outputs every so many batches')
flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
flags.DEFINE_bool('train', True, 'whether to train or test')
flags.DEFINE_integer('epoch_num', 10000, 'Number of Epochs to train on')
