Commit 65cae00: "some cleaning up"
lmzintgraf committed Jul 9, 2019
1 parent 7102d53 commit 65cae00
Showing 3 changed files with 99 additions and 104 deletions.
regression/cavia.py (2 additions & 3 deletions)

@@ -166,8 +166,8 @@ def run(args, log_interval=5000, rerun=False):

            # visualise results
            if args.task == 'celeba':
-                tasks_celebA.visualise(task_family_train, task_family_test, copy.deepcopy(logger.best_valid_model),
-                                       args, i_iter)
+                task_family_train.visualise(task_family_train, task_family_test, copy.deepcopy(logger.best_valid_model),
+                                            args, i_iter)

            # print current results
            logger.print_info(i_iter, start_time)
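Note that this call site (and the matching one in regression/maml.py below) passes a copy of logger.best_valid_model rather than the model itself, presumably so the inner-loop adaptation inside visualise cannot mutate the stored best model. A minimal sketch of that calling pattern, with the hypothetical name model_for_plotting:

    import copy

    # Adapt a throwaway copy so logger.best_valid_model keeps its
    # meta-trained parameters for later evaluation.
    model_for_plotting = copy.deepcopy(logger.best_valid_model)
    task_family_train.visualise(task_family_train, task_family_test,
                                model_for_plotting, args, i_iter)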
@@ -177,7 +177,6 @@ def run(args, log_interval=5000, rerun=False):


def eval_cavia(args, model, task_family, num_updates, n_tasks=100, return_gradnorm=False):
-
    # get the task family
    input_range = task_family.get_input_range().to(args.device)
regression/maml.py (16 additions & 17 deletions)

@@ -42,9 +42,9 @@ def run(args, log_interval=5000, rerun=False):
        task_family_valid = tasks_sine.RegressionTasksSinusoidal()
        task_family_test = tasks_sine.RegressionTasksSinusoidal()
    elif args.task == 'celeba':
-        task_family_train = tasks_celebA.CelebADataset('train')
-        task_family_valid = tasks_celebA.CelebADataset('valid')
-        task_family_test = tasks_celebA.CelebADataset('test')
+        task_family_train = tasks_celebA.CelebADataset('train', args.device)
+        task_family_valid = tasks_celebA.CelebADataset('valid', args.device)
+        task_family_test = tasks_celebA.CelebADataset('test', args.device)
    else:
        raise NotImplementedError
@@ -90,11 +90,8 @@ def run(args, log_interval=5000, rerun=False):

        for _ in range(args.num_inner_updates):

-            # copy the current model (we use this to compute the inner-loop update)
-            model_outer = model_inner
-
-            # forward through network
-            outputs = model_outer(train_inputs)
+            # make prediction using the current model
+            outputs = model_inner(train_inputs)

            # get targets
            targets = target_functions[t](train_inputs)
@@ -104,23 +101,25 @@ def run(args, log_interval=5000, rerun=False):
            # compute loss for current task
            loss_task = F.mse_loss(outputs, targets)

-            # update private parts of network and keep correct computation graph
-            params = [w for w in model_outer.weights] + [b for b in model_outer.biases] + [model_outer.task_context]
+            # compute the gradient wrt current model
+            params = [w for w in model_inner.weights] + [b for b in model_inner.biases] + [model_inner.task_context]
            grads = torch.autograd.grad(loss_task, params, create_graph=True, retain_graph=True)

+            # make an update on the inner model using the current model (to build up computation graph)
            for i in range(len(model_inner.weights)):
                if not args.first_order:
-                    model_inner.weights[i] = model_outer.weights[i] - args.lr_inner * grads[i]
+                    model_inner.weights[i] = model_inner.weights[i] - args.lr_inner * grads[i]
                else:
-                    model_inner.weights[i] = model_outer.weights[i] - args.lr_inner * grads[i].detach()
+                    model_inner.weights[i] = model_inner.weights[i] - args.lr_inner * grads[i].detach()
            for j in range(len(model_inner.biases)):
                if not args.first_order:
-                    model_inner.biases[j] = model_outer.biases[j] - args.lr_inner * grads[i + j + 1]
+                    model_inner.biases[j] = model_inner.biases[j] - args.lr_inner * grads[i + j + 1]
                else:
-                    model_inner.biases[j] = model_outer.biases[j] - args.lr_inner * grads[i + j + 1].detach()
+                    model_inner.biases[j] = model_inner.biases[j] - args.lr_inner * grads[i + j + 1].detach()
            if not args.first_order:
-                model_inner.task_context = model_outer.task_context - args.lr_inner * grads[i + j + 2]
+                model_inner.task_context = model_inner.task_context - args.lr_inner * grads[i + j + 2]
            else:
-                model_inner.task_context = model_outer.task_context - args.lr_inner * grads[i + j + 2].detach()
+                model_inner.task_context = model_inner.task_context - args.lr_inner * grads[i + j + 2].detach()

            # ------------ compute meta-gradient on test loss of current task ------------

@@ -192,7 +191,7 @@ def run(args, log_interval=5000, rerun=False):

            # visualise results
            if args.task == 'celeba':
-                tasks_celebA.visualise(task_family_train, task_family_test, copy.copy(logger.best_valid_model),
+                task_family_train.visualise(task_family_train, task_family_test, copy.copy(logger.best_valid_model),
                                       args, i_iter)

            # print current results
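The two hunks above drop the model_outer alias and update model_inner directly, keeping the update functional: each new parameter tensor is built from the previous one, so the meta-gradient can flow back through every inner step unless args.first_order detaches it. A condensed sketch of this pattern (hypothetical helper, not part of the commit):

    import torch

    def inner_update(params, loss, lr_inner, first_order=False):
        # create_graph=True retains the backward graph of the update itself,
        # so the outer (meta) gradient can differentiate through it;
        # first-order MAML drops that second-order term by detaching.
        grads = torch.autograd.grad(loss, params, create_graph=not first_order)
        if first_order:
            grads = [g.detach() for g in grads]
        # Build new tensors instead of mutating in place, so successive
        # inner steps stay part of one differentiable computation graph.
        return [p - lr_inner * g for p, g in zip(params, grads)]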
regression/tasks_celebA.py (81 additions & 84 deletions)

@@ -29,8 +29,8 @@ def __init__(self, mode, device):

        self.device = device

-        if os.path.isdir('./data/CelebA/'):
-            data_root = './data/CelebA/'
+        if os.path.isdir('/home/scratch/luiraf/work/data/celeba/'):
+            data_root = '/home/scratch/luiraf/work/data/celeba/'
        else:
            raise FileNotFoundError('Can\'t find celebrity faces.')
@@ -135,85 +135,82 @@ def get_labels(self):
                test_imgs.append(row[0])
        return train_imgs, valid_imgs, test_imgs

-
-def visualise(task_family_train, task_family_test, model, config, i_iter):
-    plt.figure(figsize=(14, 14))
-
-    for i, img_path in enumerate(task_family_train.image_files[:6] + task_family_test.image_files[:6]):
-
-        # randomly pick image
-        img = task_family_train.get_image(img_path)
-        # get target function
-        target_func = task_family_train.get_target_function(img)
-        # pick data points for training
-        pixel_inputs = task_family_train.sample_inputs(config['k_shot_eval'], config['order_pixels'])
-        pixel_targets = target_func(pixel_inputs)
-        # update model
-        if config['method'] == 'cavia':
-            model.reset_context_params()
-            for _ in range(config['num_inner_updates']):
-                pixel_pred = model(pixel_inputs)
-                loss = F.mse_loss(pixel_pred, pixel_targets)
-                grad = torch.autograd.grad(loss, model.context_params, create_graph=not config['first_order'])[0]
-                model.context_params = model.context_params - config['lr_inner'] * grad
-        elif config['method'] == 'maml':
-            for _ in range(config['num_inner_updates']):
-                pixel_pred = model(pixel_inputs)
-                loss = F.mse_loss(pixel_pred, pixel_targets)
-                params = [w for w in model.weights] + [b for b in model.biases] + [model.task_context]
-                grads = torch.autograd.grad(loss, params)
-                for k in range(len(model.weights)):
-                    model.weights[k] = model.weights[k] - config['lr_inner'] * grads[k].detach()
-                for j in range(len(model.biases)):
-                    model.biases[j] = model.biases[j] - config['lr_inner'] * grads[k + j + 1].detach()
-                model.task_context = model.task_context - config['lr_inner'] * grads[k + j + 2].detach()
-        else:
-            raise NotImplementedError('Wtf you doin')
-
-        # plot context
-        plt.subplot(6, 6, (i % 6) * 6 + 1 + int(i > 5) * 3)
-        # img = (img + 1) / 2
-        plt.imshow(img)
-        plt.xticks([])
-        plt.yticks([])
-
-        # context
-        plt.subplot(6, 6, (i % 6) * 6 + 2 + int(i > 5) * 3)
-        img_copy = copy.copy(img) * 0
-        # de-normalise coordinates
-        pixel_inputs *= 32
-        pixel_inputs = pixel_inputs.long()
-        img_copy[pixel_inputs[:, 0], pixel_inputs[:, 1]] = img[pixel_inputs[:, 0], pixel_inputs[:, 1]]
-        plt.imshow(img_copy)
-        plt.xticks([])
-        plt.yticks([])
-
-        if i == 0:
-            plt.title('TRAIN', fontsize=20)
-        if i == 6:
-            plt.title('TEST', fontsize=20)
-
-        # predict
-        plt.subplot(6, 6, (i % 6) * 6 + 3 + int(i > 5) * 3)
-        input_range = task_family_train.get_input_range()
-        img_pred = model(input_range).view(task_family_train.img_size).cpu().detach().numpy()
-        # img_pred = (img_pred + 1) / 2
-        img_pred[img_pred < 0] = 0
-        img_pred[img_pred > 1] = 1
-        plt.imshow(img_pred)
-        plt.xticks([])
-        plt.yticks([])
-
-    if not os.path.isdir('{}/celeba_result_plots/'.format(self.code_root)):
-        os.mkdir('{}/celeba_result_plots/'.format(self.code_root))
-
-    plt.tight_layout()
-    plt.savefig('{}/celeba_result_plots/{}_c{}_k{}_o{}_u{}_lr{}_{}'.format(self.code_root,
-                                                                           config['method'],
-                                                                           config['num_context_params'],
-                                                                           config['k_meta_train'],
-                                                                           config['order_pixels'],
-                                                                           config['num_inner_updates'],
-                                                                           int(10 * config['lr_inner']),
-                                                                           i_iter))
-    plt.close()
+    def visualise(self, task_family_train, task_family_test, model, args, i_iter):
+        plt.figure(figsize=(14, 14))
+
+        for i, img_path in enumerate(task_family_train.image_files[:6] + task_family_test.image_files[:6]):
+
+            # randomly pick image
+            img = task_family_train.get_image(img_path)
+            # get target function
+            target_func = task_family_train.get_target_function(img)
+            # pick data points for training
+            pixel_inputs = task_family_train.sample_inputs(args.k_shot_eval, args.use_ordered_pixels)
+            pixel_targets = target_func(pixel_inputs)
+            # update model
+            if not args.maml:
+                model.reset_context_params()
+                for _ in range(args.num_inner_updates):
+                    pixel_pred = model(pixel_inputs)
+                    loss = F.mse_loss(pixel_pred, pixel_targets)
+                    grad = torch.autograd.grad(loss, model.context_params, create_graph=not args.first_order)[0]
+                    model.context_params = model.context_params - args.lr_inner * grad
+            else:
+                for _ in range(args.num_inner_updates):
+                    pixel_pred = model(pixel_inputs)
+                    loss = F.mse_loss(pixel_pred, pixel_targets)
+                    params = [w for w in model.weights] + [b for b in model.biases] + [model.task_context]
+                    grads = torch.autograd.grad(loss, params)
+                    for k in range(len(model.weights)):
+                        model.weights[k] = model.weights[k] - args.lr_inner * grads[k].detach()
+                    for j in range(len(model.biases)):
+                        model.biases[j] = model.biases[j] - args.lr_inner * grads[k + j + 1].detach()
+                    model.task_context = model.task_context - args.lr_inner * grads[k + j + 2].detach()
+
+            # plot context
+            plt.subplot(6, 6, (i % 6) * 6 + 1 + int(i > 5) * 3)
+            # img = (img + 1) / 2
+            plt.imshow(img)
+            plt.xticks([])
+            plt.yticks([])
+
+            # context
+            plt.subplot(6, 6, (i % 6) * 6 + 2 + int(i > 5) * 3)
+            img_copy = copy.copy(img) * 0
+            # de-normalise coordinates
+            pixel_inputs *= 32
+            pixel_inputs = pixel_inputs.long()
+            img_copy[pixel_inputs[:, 0], pixel_inputs[:, 1]] = img[pixel_inputs[:, 0], pixel_inputs[:, 1]]
+            plt.imshow(img_copy)
+            plt.xticks([])
+            plt.yticks([])
+
+            if i == 0:
+                plt.title('TRAIN', fontsize=20)
+            if i == 6:
+                plt.title('TEST', fontsize=20)
+
+            # predict
+            plt.subplot(6, 6, (i % 6) * 6 + 3 + int(i > 5) * 3)
+            input_range = task_family_train.get_input_range()
+            img_pred = model(input_range).view(task_family_train.img_size).cpu().detach().numpy()
+            # img_pred = (img_pred + 1) / 2
+            img_pred[img_pred < 0] = 0
+            img_pred[img_pred > 1] = 1
+            plt.imshow(img_pred)
+            plt.xticks([])
+            plt.yticks([])
+
+        if not os.path.isdir('{}/celeba_result_plots/'.format(self.code_root)):
+            os.mkdir('{}/celeba_result_plots/'.format(self.code_root))
+
+        plt.tight_layout()
+        plt.savefig('{}/celeba_result_plots/{}_c{}_k{}_o{}_u{}_lr{}_{}'.format(self.code_root,
+                                                                               int(args.maml),
+                                                                               args.num_context_params,
+                                                                               args.k_meta_train,
+                                                                               args.use_ordered_pixels,
+                                                                               args.num_inner_updates,
+                                                                               int(10 * args.lr_inner),
+                                                                               i_iter))
+        plt.close()
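For contrast with the MAML branch above, the CAVIA branch of the new visualise method adapts only the context parameters while the shared weights and biases stay fixed. A self-contained sketch of that adaptation loop, with hypothetical names, mirroring the branch taken when args.maml is false:

    import torch
    import torch.nn.functional as F

    def adapt_context(model, inputs, targets, lr_inner, num_inner_updates,
                      first_order=True):
        # CAVIA: reset the per-task context, then take gradient steps on it
        # alone; model.weights and model.biases are never updated here.
        model.reset_context_params()
        for _ in range(num_inner_updates):
            loss = F.mse_loss(model(inputs), targets)
            grad = torch.autograd.grad(loss, model.context_params,
                                       create_graph=not first_order)[0]
            model.context_params = model.context_params - lr_inner * grad
        return model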
