From f9ccaa53f34e566f0da65b4c9fe79966db28cf6b Mon Sep 17 00:00:00 2001
From: "Yu, Chong"
Date: Tue, 31 Oct 2017 20:31:23 +0800
Subject: [PATCH] Add the APIs, scripts and models to support DCGAN.

---
 data/Celeb-A/celebA.txt                          |  10 +
 data/Celeb-A/crop_celebA.py                      |  59 +++
 include/caffe/net.hpp                            |   3 +
 include/caffe/solver.hpp                         |   1 +
 .../dcgan/data.prototxt                          |  17 +
 .../dcgan/discriminator.prototxt                 | 394 ++++++++++++++++++
 models/intel_optimized_models/dcgan/generate.py  |  36 ++
 .../dcgan/generator.prototxt                     | 390 +++++++++++++++++
 .../dcgan/solver_template.prototxt               |  19 +
 models/intel_optimized_models/dcgan/train.py     | 141 +++++++
 python/caffe/_caffe.cpp                          |  12 +-
 python/caffe/pycaffe.py                          |   7 +
 12 files changed, 1087 insertions(+), 2 deletions(-)
 create mode 100644 data/Celeb-A/celebA.txt
 create mode 100644 data/Celeb-A/crop_celebA.py
 create mode 100644 models/intel_optimized_models/dcgan/data.prototxt
 create mode 100644 models/intel_optimized_models/dcgan/discriminator.prototxt
 create mode 100644 models/intel_optimized_models/dcgan/generate.py
 create mode 100644 models/intel_optimized_models/dcgan/generator.prototxt
 create mode 100644 models/intel_optimized_models/dcgan/solver_template.prototxt
 create mode 100644 models/intel_optimized_models/dcgan/train.py

diff --git a/data/Celeb-A/celebA.txt b/data/Celeb-A/celebA.txt
new file mode 100644
index 000000000..ee543e2c4
--- /dev/null
+++ b/data/Celeb-A/celebA.txt
@@ -0,0 +1,10 @@
+/Celeb-A_Cropped/000001.jpg 1
+/Celeb-A_Cropped/000002.jpg 1
+/Celeb-A_Cropped/000003.jpg 1
+/Celeb-A_Cropped/000004.jpg 1
+/Celeb-A_Cropped/000005.jpg 1
+/Celeb-A_Cropped/000006.jpg 1
+/Celeb-A_Cropped/000007.jpg 1
+/Celeb-A_Cropped/000008.jpg 1
+/Celeb-A_Cropped/000009.jpg 1
+/Celeb-A_Cropped/000010.jpg 1
\ No newline at end of file
diff --git a/data/Celeb-A/crop_celebA.py b/data/Celeb-A/crop_celebA.py
new file mode 100644
index 000000000..96959b96c
--- /dev/null
+++ b/data/Celeb-A/crop_celebA.py
@@ -0,0 +1,59 @@
+from PIL import Image
+import os
+import sys
+
+print ""
+print "Prepare Celeb-A Dataset! (1. Crop the images. 2. Generate a train list file.)"
+print ""
+print "-------------------------------------------------------------------------------"
+
+current_path = os.getcwd()
+celebA_path = ""
+celebA_cropped_path = ""
+print "The current path containing this python file is: " + current_path
+if len(sys.argv) == 1:
+    print "Please give the path of original Celeb-A dataset!"
+    exit(0)
+elif len(sys.argv) > 1:
+    print "The path of original Celeb-A dataset is: " + str(sys.argv[1])
+    celebA_path = sys.argv[1]
+    celebA_cropped_path = os.path.dirname(celebA_path) + os.sep + "Cropped" #To avoid cropping the generated images again if this parameter is not provided
+    if len(sys.argv) > 2:
+        print "The path of cropped Celeb-A dataset will be: " + str(sys.argv[2])
+        celebA_cropped_path = sys.argv[2]
+    else:
+        print "The path of cropped Celeb-A dataset will be default, set as: " + celebA_cropped_path
+
+if os.path.exists(celebA_cropped_path):
+    print "The path of cropped Celeb-A dataset exists."
+else:
+    print "The path of cropped Celeb-A dataset doesn't exist! I will create it now!"
+    os.makedirs(celebA_cropped_path)
+print "-------------------------------------------------------------------------------"
+
+training_list_file = os.path.join(celebA_cropped_path, "celebA.txt")
+list_file = open(training_list_file, 'w')
+total_image_num = 0
+x1, y1 = 30, 40
+cropped_box = (x1, y1, x1 + 138, y1 + 138)
+
+for parent, dirnames, filenames in os.walk(celebA_path):
+    for filename in filenames:
+        if filename.endswith(".jpg"):
+            total_image_num += 1
+            #print "parent is:" + parent
+            #print "filename is:" + filename
+            image_path_and_name = os.path.join(parent, filename)
+            print "the full name of the file is: " + image_path_and_name
+            input_image = Image.open(image_path_and_name)
+            #input_image.show()
+            cropped_image = input_image.crop(cropped_box)
+            #cropped_image.show()
+            scaled_cropped_image = cropped_image.resize((64, 64))
+            #scaled_cropped_image.show()
+            save_result_image_path_and_name = os.path.join(celebA_cropped_path, filename)
+            scaled_cropped_image.save(save_result_image_path_and_name, 'jpeg')
+            list_file.writelines(save_result_image_path_and_name)
+            list_file.writelines(" 1" + "\n") #Must add label to list file
+print "There are " + str(total_image_num) + " images finished with cropping and scaling operations!"
+list_file.close()
\ No newline at end of file
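As a quick sanity check of the preprocessing above, the same crop-then-resize can be run on a single image (a minimal sketch; the sample filename is hypothetical):

    from PIL import Image

    # 138x138 box anchored at (30, 40), as in crop_celebA.py,
    # then resized to the 64x64 input expected by the DCGAN nets.
    img = Image.open("000001.jpg")  # hypothetical sample from the Celeb-A folder
    face = img.crop((30, 40, 30 + 138, 40 + 138)).resize((64, 64))
    face.save("000001_cropped.jpg", 'jpeg')
    print face.size  # (64, 64)
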
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index 362b28de7..4b6fa0090 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -326,6 +326,9 @@ class Net {
   /// @brief return whether NetState state meets NetStateRule rule
   static bool StateMeetsRule(const NetState& state, const NetStateRule& rule,
       const string& layer_name);
+  inline const map<string, int>& blob_names_index() const {
+    return blob_names_index_;
+  }
 
  protected:
   // Helpers for Init.
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 9b97c3c0b..a9ebdabc7 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -113,6 +113,7 @@ class Solver {
   }
   int iter() { return iter_; }
   void set_iter(int value) { iter_ = value; }
+  void increment_iter() { iter_++; }
 
   // Invoked at specific points during an iteration
   class Callback {
diff --git a/models/intel_optimized_models/dcgan/data.prototxt b/models/intel_optimized_models/dcgan/data.prototxt
new file mode 100644
index 000000000..d81132b70
--- /dev/null
+++ b/models/intel_optimized_models/dcgan/data.prototxt
@@ -0,0 +1,17 @@
+
+layer {
+  name: "data"
+  type: "ImageData"
+  top: "data"
+  top: "label"
+  transform_param {
+    mirror: true
+    mean_value: 104
+    mean_value: 117
+    mean_value: 123
+  }
+  image_data_param {
+    source: "data/Celeb-A/celebA.txt"
+    batch_size: 64
+  }
+}
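The two header additions above exist so Python can look up blobs by name and advance a solver manually. A minimal sketch of the lookup (mirroring how train.py below reads the loss weight; net is assumed to be an already-constructed caffe.Net):

    # _blob_names_index is the MapStringInt binding added by this patch;
    # it maps a blob name to its index in the net's blob vector.
    idx = net._blob_names_index['discr_loss']
    loss_weight = net._blob_loss_weights[idx]  # weight applied to that loss blob
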
diff --git a/models/intel_optimized_models/dcgan/discriminator.prototxt b/models/intel_optimized_models/dcgan/discriminator.prototxt
new file mode 100644
index 000000000..638962153
--- /dev/null
+++ b/models/intel_optimized_models/dcgan/discriminator.prototxt
@@ -0,0 +1,394 @@
+force_backward: true
+input: "data"
+input_shape {
+  dim: 64
+  dim: 3
+  dim: 64
+  dim: 64
+}
+
+input: "label"
+input_shape {
+  dim: 64
+  dim: 1
+}
+
+# Discriminator net
+
+layer {
+  name: "Dconv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "Dconv1"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 64
+    kernel_size: 4
+    stride: 2
+    pad: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+#    engine: CAFFE
+  }
+}
+layer {
+  name: "Drelu1"
+  type: "ReLU"
+  bottom: "Dconv1"
+  top: "Dconv1"
+  relu_param {
+    negative_slope: 0.2
+  }
+}
+
+layer {
+  name: "Dconv2"
+  type: "Convolution"
+  bottom: "Dconv1"
+  top: "Dconv2"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 128
+    kernel_size: 4
+    stride: 2
+    pad: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+#    engine: CAFFE
+  }
+}
+layer {
+  name: "Dconv2_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "Dconv2"
+  top: "Dconv2_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "Dconv2_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "Dconv2"
+  top: "Dconv2_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "Drelu2"
+  type: "ReLU"
+  bottom: "Dconv2_BN"
+  top: "Dconv2_BN"
+  relu_param {
+    negative_slope: 0.2
+  }
+}
+layer {
+  name: "Dconv3"
+  type: "Convolution"
+  bottom: "Dconv2_BN"
+  top: "Dconv3"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 256
+    kernel_size: 4
+    stride: 2
+    pad: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+#    engine: CAFFE
+  }
+}
+layer {
+  name: "Dconv3_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "Dconv3"
+  top: "Dconv3_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "Dconv3_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "Dconv3"
+  top: "Dconv3_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "Drelu3"
+  type: "ReLU"
+  bottom: "Dconv3_BN"
+  top: "Dconv3_BN"
+  relu_param {
+    negative_slope: 0.2
+  }
+}
+layer {
+  name: "Dconv4"
+  type: "Convolution"
+  bottom: "Dconv3_BN"
+  top: "Dconv4"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 512
+    kernel_size: 4
+    stride: 2
+    pad: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+#    engine: CAFFE
+  }
+}
+layer {
+  name: "Dconv4_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "Dconv4"
+  top: "Dconv4_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "Dconv4_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "Dconv4"
+  top: "Dconv4_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "Drelu4"
+  type: "ReLU"
+  bottom: "Dconv4_BN"
+  top: "Dconv4_BN"
+  relu_param {
+    negative_slope: 0.2
+  }
+}
+layer {
+  name: "Dconv5"
+  type: "Convolution"
+  bottom: "Dconv4_BN"
+  top: "Dconv5"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 512
+    kernel_size: 4
+    stride: 1
+    pad: 0
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+#    engine: CAFFE
+  }
+}
+#layer {
+#  name: "relu5"
+#  type: "ReLU"
+#  bottom: "Dconv5"
+#  top: "Dconv5"
+#}
+#
+#layer {
+#  name: "Dpool5"
+#  type: "Pooling"
+#  bottom: "Dconv5"
+#  top: "Dpool5"
+#  pooling_param {
+#    pool: AVE
+#    kernel_size: 11
+#    stride: 11
+#  }
+#}
+
+layer {
+  name: "Dfc7"
+  type: "InnerProduct"
+  bottom: "Dconv5"
+  top: "Dfc7"
+  param {
+    lr_mult: 1
+    decay_mult: 1
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  inner_product_param {
+    num_output: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+
+layer {
+  name: "discr_loss"
+  type: "SigmoidCrossEntropyLoss"
+  bottom: "Dfc7"
+  bottom: "label"
+  top: "discr_loss"
+  loss_weight: 1
+}
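For reference, each stride-2 4x4 convolution above halves the spatial size (Caffe computes output = (in + 2*pad - kernel) / stride + 1), so Dconv5's 4x4 kernel with stride 1 and no padding reduces the final 4x4 map to 1x1, which is why Dfc7 can be a single-output InnerProduct. A quick check:

    def conv_out(n, k, s, p):
        # Caffe convolution output-size formula
        return (n + 2 * p - k) // s + 1

    n = 64
    for (k, s, p) in [(4, 2, 1)] * 4 + [(4, 1, 0)]:  # Dconv1..Dconv4, then Dconv5
        n = conv_out(n, k, s, p)
        print n  # 32, 16, 8, 4, 1
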
bottom: "Dconv4_BN" + top: "Dconv5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 0 + decay_mult: 0 + } + convolution_param { + num_output: 512 + kernel_size: 4 + stride: 1 + pad: 0 + weight_filler { + type: "gaussian" + std: 0.02 + } + bias_filler { + type: "constant" + value: 0 + } +# engine: CAFFE + } +} +#layer { +# name: "relu5" +# type: "ReLU" +# bottom: "Dconv5" +# top: "Dconv5" +#} +# +#layer { +# name: "Dpool5" +# type: "Pooling" +# bottom: "Dconv5" +# top: "Dpool5" +# pooling_param { +# pool: AVE +# kernel_size: 11 +# stride: 11 +# } +#} + +layer { + name: "Dfc7" + type: "InnerProduct" + bottom: "Dconv5" + top: "Dfc7" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 0 + decay_mult: 0 + } + inner_product_param { + num_output: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} + + +layer { + name: "discr_loss" + type: "SigmoidCrossEntropyLoss" + bottom: "Dfc7" + bottom: "label" + top: "discr_loss" + loss_weight: 1 +} + diff --git a/models/intel_optimized_models/dcgan/generate.py b/models/intel_optimized_models/dcgan/generate.py new file mode 100644 index 000000000..2a3a47fcc --- /dev/null +++ b/models/intel_optimized_models/dcgan/generate.py @@ -0,0 +1,36 @@ +import caffe +import numpy as np +import sys +import cv2 +import scipy.io +import scipy.misc + +nz = 100 +img_size = 64 +batch_size = 64 + +gen_net = caffe.Net(sys.argv[1], sys.argv[2], caffe.TEST) + +# Fix the seed to debug +np.random.seed(0) +gen_net.blobs['feat'].data[...] = np.random.normal(0, 1, (batch_size, nz)).astype(np.float32) + +gen_net.forward_simple() + +generated_img = gen_net.blobs['generated'].data + +print generated_img.shape + +print generated_img[0].transpose(1,2,0) +max_val, min_val = np.max(generated_img[0]), np.min(generated_img[0]) + +# Concat all images into a big 8*8 image +flatten_img = ((generated_img.transpose((0,2,3,1)))[:] - min_val) / (max_val-min_val) +print flatten_img.shape +#print flatten_img.reshape(2, 2, 64, 64, 3).shape +scipy.misc.imsave('test1.png', flatten_img.reshape(8,8,img_size,img_size,3).swapaxes(1,2).reshape(8*img_size,8*img_size, 3)) +#cv2.imshow('test1', flatten_img.reshape(8,8,img_size,img_size,3).swapaxes(1,2).reshape(8*img_size,8*img_size, 3)) +#cv2.waitKey() + +#cv2.imshow('test', ((generated_img.transpose((0,2,3,1)))[2] - min_val) / (max_val-min_val)) +#cv2.waitKey() diff --git a/models/intel_optimized_models/dcgan/generator.prototxt b/models/intel_optimized_models/dcgan/generator.prototxt new file mode 100644 index 000000000..92f1dc3de --- /dev/null +++ b/models/intel_optimized_models/dcgan/generator.prototxt @@ -0,0 +1,390 @@ +force_backward: true +input: "feat" +input_shape { + dim: 64 + dim: 100 +} +layer { + name: "reshape" + type: "Reshape" + bottom: "feat" + top: "reshape_defc5" + reshape_param { + shape { + dim: 64 + dim: 100 + dim: 1 + dim: 1 + } + } +} +layer { + name: "deconv5" + type: "Deconvolution" + bottom: "reshape_defc5" + top: "deconv5" + param { + lr_mult: 1 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + convolution_param { + num_output: 512 + pad: 0 + kernel_size: 4 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.02 + } + bias_filler { + type: "constant" + } +# engine: CAFFE + } +} + +layer { + name: "deconv5_BN" + type: "BatchNorm" include { phase: TRAIN} + bottom: "deconv5" + top: "deconv5_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + 
diff --git a/models/intel_optimized_models/dcgan/generator.prototxt b/models/intel_optimized_models/dcgan/generator.prototxt
new file mode 100644
index 000000000..92f1dc3de
--- /dev/null
+++ b/models/intel_optimized_models/dcgan/generator.prototxt
@@ -0,0 +1,390 @@
+force_backward: true
+input: "feat"
+input_shape {
+  dim: 64
+  dim: 100
+}
+layer {
+  name: "reshape"
+  type: "Reshape"
+  bottom: "feat"
+  top: "reshape_defc5"
+  reshape_param {
+    shape {
+      dim: 64
+      dim: 100
+      dim: 1
+      dim: 1
+    }
+  }
+}
+layer {
+  name: "deconv5"
+  type: "Deconvolution"
+  bottom: "reshape_defc5"
+  top: "deconv5"
+  param {
+    lr_mult: 1
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 512
+    pad: 0
+    kernel_size: 4
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+    }
+#    engine: CAFFE
+  }
+}
+
+layer {
+  name: "deconv5_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "deconv5"
+  top: "deconv5_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "deconv5_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "deconv5"
+  top: "deconv5_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "relu_deconv5"
+  type: "ReLU"
+  bottom: "deconv5_BN"
+  top: "deconv5_BN"
+}
+layer {
+  name: "conv5_1"
+  type: "Deconvolution"
+  bottom: "deconv5_BN"
+  top: "conv5_1"
+  param {
+    lr_mult: 1
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 256
+    pad: 1
+    kernel_size: 4
+    stride: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+    }
+#    engine: CAFFE
+  }
+}
+layer {
+  name: "deconv5_1_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "conv5_1"
+  top: "deconv5_1_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "deconv5_1_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "conv5_1"
+  top: "deconv5_1_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "relu_conv5_1"
+  type: "ReLU"
+  bottom: "deconv5_1_BN"
+  top: "deconv5_1_BN"
+}
+layer {
+  name: "deconv4"
+  type: "Deconvolution"
+  bottom: "deconv5_1_BN"
+  top: "deconv4"
+  param {
+    lr_mult: 1
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 128
+    pad: 1
+    kernel_size: 4
+    stride: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+    }
+#    engine: CAFFE
+  }
+}
+layer {
+  name: "deconv4_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "deconv4"
+  top: "deconv4_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "deconv4_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "deconv4"
+  top: "deconv4_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "relu_deconv4"
+  type: "ReLU"
+  bottom: "deconv4_BN"
+  top: "deconv4_BN"
+}
+layer {
+  name: "conv4_1"
+  type: "Deconvolution"
+  bottom: "deconv4_BN"
+  top: "conv4_1"
+  param {
+    lr_mult: 1
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 64
+    pad: 1
+    kernel_size: 4
+    stride: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+    }
+#    engine: CAFFE
+  }
+}
+
+layer {
+  name: "deconv4_1_BN"
+  type: "BatchNorm"
+  include { phase: TRAIN }
+  bottom: "conv4_1"
+  top: "deconv4_1_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: false
+    moving_average_fraction: 0.95
+  }
+}
+layer {
+  name: "deconv4_1_BN"
+  type: "BatchNorm"
+  include { phase: TEST }
+  bottom: "conv4_1"
+  top: "deconv4_1_BN"
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  batch_norm_param {
+    use_global_stats: true
+    moving_average_fraction: 0.95
+  }
+}
+
+layer {
+  name: "relu_conv4_1"
+  type: "ReLU"
+  bottom: "deconv4_1_BN"
+  top: "deconv4_1_BN"
+}
+layer {
+  name: "generated"
+  type: "Deconvolution"
+  bottom: "deconv4_1_BN"
+  top: "generated"
+  param {
+    lr_mult: 1
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 0
+    decay_mult: 0
+  }
+  convolution_param {
+    num_output: 3
+    pad: 1
+    kernel_size: 4
+    stride: 2
+    weight_filler {
+      type: "gaussian"
+      std: 0.02
+    }
+    bias_filler {
+      type: "constant"
+    }
+#    engine: CAFFE
+  }
+}
+
+#layer {
+#  name: "deconv3_relu"
+#  type: "TanH"
+#  bottom: "deconv3"
+#  top: "deconv3_relu"
+#}
+
+#layer {
+#  name: "deconv0_crop"
+#  type: "CropSimple"
+#  bottom: "deconv3"
+#  top: "deconv0_crop"
+#  crop_param {
+#    crop_height: 64
+#    crop_width: 64
+#  }
+#}
+#layer {
+#  name: "generated"
+#  type: "Eltwise"
+#  bottom: "deconv3"
+#  top: "generated"
+#}
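Mirroring the discriminator check above, the generator's Deconvolution stack can be verified with the transposed formula output = (in - 1)*stride - 2*pad + kernel: the 1x1 latent map grows back to the 64x64 image size.

    def deconv_out(n, k, s, p):
        # Caffe Deconvolution output-size formula
        return (n - 1) * s - 2 * p + k

    n = 1
    for (k, s, p) in [(4, 1, 0)] + [(4, 2, 1)] * 4:  # deconv5, then four stride-2 stages
        n = deconv_out(n, k, s, p)
        print n  # 4, 8, 16, 32, 64
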
"conv4_1" + top: "deconv4_1_BN" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + batch_norm_param { + use_global_stats: true + moving_average_fraction: 0.95 + } +} + +layer { + name: "relu_conv4_1" + type: "ReLU" + bottom: "deconv4_1_BN" + top: "deconv4_1_BN" +} +layer { + name: "generated" + type: "Deconvolution" + bottom: "deconv4_1_BN" + top: "generated" + param { + lr_mult: 1 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + convolution_param { + num_output: 3 + pad: 1 + kernel_size: 4 + stride: 2 + weight_filler { + type: "gaussian" + std: 0.02 + } + bias_filler { + type: "constant" + } +# engine: CAFFE + } +} + +#layer { +# name: "deconv3_relu" +# type: "TanH" +# bottom: "deconv3" +# top: "deconv3_relu" +#} + +#layer { +# name: "deconv0_crop" +# type: "CropSimple" +# bottom: "deconv3" +# top: "deconv0_crop" +# crop_param { +# crop_height: 64 +# crop_width: 64 +# } +#} +#layer { +# name: "generated" +# type: "Eltwise" +# bottom: "deconv3" +# top: "generated" +#} diff --git a/models/intel_optimized_models/dcgan/solver_template.prototxt b/models/intel_optimized_models/dcgan/solver_template.prototxt new file mode 100644 index 000000000..b81d44f50 --- /dev/null +++ b/models/intel_optimized_models/dcgan/solver_template.prototxt @@ -0,0 +1,19 @@ +net: "@NET@.prototxt" +display: 0 +# time_per_iter: 1 +base_lr: 0.0002 +momentum: 0.5 +momentum2: 0.999 +weight_decay: 0.0004 +lr_policy: "multistep" +gamma: 0.5 +stepvalue: 6000 +stepvalue: 10000 +stepvalue: 140000 +stepvalue: 180000 +stepvalue: 220000 +max_iter: 250000 +solver_mode: CPU +solver_type: ADAM +device_id: 0 + diff --git a/models/intel_optimized_models/dcgan/train.py b/models/intel_optimized_models/dcgan/train.py new file mode 100644 index 000000000..7f2a07460 --- /dev/null +++ b/models/intel_optimized_models/dcgan/train.py @@ -0,0 +1,141 @@ +import caffe +import numpy as np +import time +import os +import sys + +if len(sys.argv) == 1: + start_snapshot = 0 + +nz = 100 # latent vector dimension +image_size = 64 # image size +max_iter = int(1e6) # maximum number of iterations +display_every = 20 # show losses every so many iterations +snapshot_every = 1000 # snapshot every so many iterations +snapshot_folder = 'snapshots_test' # where to save the snapshots (and load from) + +feat_shape = (nz,) +im_size = (3,image_size,image_size) +batch_size = 64 +snapshot_at_iter = -1 +snapshot_at_iter_file = 'snapshot_at_iter.txt' + +sub_nets = ('generator', 'discriminator', 'data') + +if not os.path.exists(snapshot_folder): + os.makedirs(snapshot_folder) + +#make solvers +with open ("solver_template.prototxt", "r") as myfile: + solver_template=myfile.read() + +for curr_net in sub_nets: + with open("solver_%s.prototxt" % curr_net, "w") as myfile: + myfile.write(solver_template.replace('@NET@', curr_net)) + +#initialize the nets +generator = caffe.AdamSolver('solver_generator.prototxt') +discriminator = caffe.AdamSolver('solver_discriminator.prototxt') +data_reader = caffe.AdamSolver('solver_data.prototxt') + + +#load from snapshot +if start_snapshot: + curr_snapshot_folder = snapshot_folder +'/' + str(start_snapshot) + print >> sys.stderr, '\n === Starting from snapshot ' + curr_snapshot_folder + ' ===\n' + generator_caffemodel = curr_snapshot_folder +'/' + 'generator.caffemodel' + if os.path.isfile(generator_caffemodel): + generator.net.copy_from(generator_caffemodel) + else: + raise Exception('File %s does not exist' % generator_caffemodel) + 
diff --git a/models/intel_optimized_models/dcgan/train.py b/models/intel_optimized_models/dcgan/train.py
new file mode 100644
index 000000000..7f2a07460
--- /dev/null
+++ b/models/intel_optimized_models/dcgan/train.py
@@ -0,0 +1,141 @@
+import caffe
+import numpy as np
+import time
+import os
+import sys
+
+if len(sys.argv) == 1:
+    start_snapshot = 0
+else:
+    start_snapshot = int(sys.argv[1])  # resume iteration (assumed semantics; start_snapshot was otherwise undefined when an argument was given)
+
+nz = 100  # latent vector dimension
+image_size = 64  # image size
+max_iter = int(1e6)  # maximum number of iterations
+display_every = 20  # show losses every so many iterations
+snapshot_every = 1000  # snapshot every so many iterations
+snapshot_folder = 'snapshots_test'  # where to save the snapshots (and load from)
+
+feat_shape = (nz,)
+im_size = (3, image_size, image_size)
+batch_size = 64
+snapshot_at_iter = -1
+snapshot_at_iter_file = 'snapshot_at_iter.txt'
+
+sub_nets = ('generator', 'discriminator', 'data')
+
+if not os.path.exists(snapshot_folder):
+    os.makedirs(snapshot_folder)
+
+#make solvers
+with open("solver_template.prototxt", "r") as myfile:
+    solver_template = myfile.read()
+
+for curr_net in sub_nets:
+    with open("solver_%s.prototxt" % curr_net, "w") as myfile:
+        myfile.write(solver_template.replace('@NET@', curr_net))
+
+#initialize the nets
+generator = caffe.AdamSolver('solver_generator.prototxt')
+discriminator = caffe.AdamSolver('solver_discriminator.prototxt')
+data_reader = caffe.AdamSolver('solver_data.prototxt')
+
+#load from snapshot
+if start_snapshot:
+    curr_snapshot_folder = snapshot_folder + '/' + str(start_snapshot)
+    print >> sys.stderr, '\n === Starting from snapshot ' + curr_snapshot_folder + ' ===\n'
+    generator_caffemodel = curr_snapshot_folder + '/' + 'generator.caffemodel'
+    if os.path.isfile(generator_caffemodel):
+        generator.net.copy_from(generator_caffemodel)
+    else:
+        raise Exception('File %s does not exist' % generator_caffemodel)
+    discriminator_caffemodel = curr_snapshot_folder + '/' + 'discriminator.caffemodel'
+    if os.path.isfile(discriminator_caffemodel):
+        discriminator.net.copy_from(discriminator_caffemodel)
+    else:
+        raise Exception('File %s does not exist' % discriminator_caffemodel)
+
+#read weights of losses
+discr_loss_weight = discriminator.net._blob_loss_weights[discriminator.net._blob_names_index['discr_loss']]
+
+train_discr = True
+train_gen = True
+
+#do training
+start = time.time()
+for it in range(start_snapshot, max_iter):
+    # read the data
+    data_reader.net.forward_simple()
+    # feed the data to the generator and run it
+    generator.net.blobs['feat'].data[...] = np.random.normal(0, 1, (64, nz)).astype(np.float32)
+    generator.net.forward_simple()
+    generated_img = generator.net.blobs['generated'].data
+    # run the discriminator on real data
+    discriminator.net.blobs['data'].data[...] = data_reader.net.blobs['data'].data
+    discriminator.net.blobs['label'].data[...] = np.ones((batch_size, 1), dtype='float32')
+    discriminator.net.forward_simple()
+    discr_real_loss = np.copy(discriminator.net.blobs['discr_loss'].data)
+    if train_discr:
+        discriminator.increment_iter()
+        discriminator.net.clear_param_diffs()
+        discriminator.net.backward_simple()
+
+    # run the discriminator on generated data
+    discriminator.net.blobs['data'].data[...] = generated_img
+    discriminator.net.blobs['label'].data[...] = np.zeros((batch_size, 1), dtype='float32')
+    discriminator.net.forward_simple()
+    discr_fake_loss = np.copy(discriminator.net.blobs['discr_loss'].data)
+    if train_discr:
+        discriminator.net.backward_simple()
+        discriminator.apply_update()
+
+    # run the discriminator on generated data with opposite labels, to get the gradient for the generator
+    discriminator.net.blobs['label'].data[...] = np.ones((batch_size, 1), dtype='float32')
+    discriminator.net.forward_simple()
+    discr_fake_for_generator_loss = np.copy(discriminator.net.blobs['discr_loss'].data)
+    if train_gen:
+        generator.increment_iter()
+        generator.net.clear_param_diffs()
+        discriminator.net.backward_simple()
+        generator.net.blobs['generated'].diff[...] = discriminator.net.blobs['data'].diff
+        generator.net.backward_simple()
+        generator.apply_update()
+
+    #display
+    if it % display_every == 0:
+        print >> sys.stderr, "[%s] Iteration %d: %f seconds" % (time.strftime("%c"), it, time.time() - start)
+        print >> sys.stderr, "  discr real loss: %e * %e = %f" % (discr_real_loss, discr_loss_weight, discr_real_loss * discr_loss_weight)
+        print >> sys.stderr, "  discr fake loss: %e * %e = %f" % (discr_fake_loss, discr_loss_weight, discr_fake_loss * discr_loss_weight)
+        print >> sys.stderr, "  discr fake loss for generator: %e * %e = %f" % (discr_fake_for_generator_loss, discr_loss_weight, discr_fake_for_generator_loss * discr_loss_weight)
+        start = time.time()
+        if os.path.isfile(snapshot_at_iter_file):
+            with open(snapshot_at_iter_file, "r") as myfile:
+                snapshot_at_iter = int(myfile.read())
+
+    #snapshot
+    if it % snapshot_every == 0 or it == snapshot_at_iter:
+        curr_snapshot_folder = snapshot_folder + '/' + str(it)
+        print >> sys.stderr, '\n === Saving snapshot to ' + curr_snapshot_folder + ' ===\n'
+        if not os.path.exists(curr_snapshot_folder):
+            os.makedirs(curr_snapshot_folder)
+        generator_caffemodel = curr_snapshot_folder + '/' + 'generator.caffemodel'
+        generator.net.save(generator_caffemodel)
+        discriminator_caffemodel = curr_snapshot_folder + '/' + 'discriminator.caffemodel'
+        discriminator.net.save(discriminator_caffemodel)
+
+    #switch optimizing discriminator and generator, so that neither of them overfits too much
+    discr_loss_ratio = (discr_real_loss + discr_fake_loss) / discr_fake_for_generator_loss
+    if discr_loss_ratio < 1e-1 and train_discr:
+        train_discr = False
+        train_gen = True
+        print >> sys.stderr, "<<< real_loss=%e, fake_loss=%e, fake_loss_for_generator=%e, train_discr=%d, train_gen=%d >>>" % (discr_real_loss, discr_fake_loss, discr_fake_for_generator_loss, train_discr, train_gen)
+    if discr_loss_ratio > 5e-1 and not train_discr:
+        train_discr = True
+        train_gen = True
+        print >> sys.stderr, " <<< real_loss=%e, fake_loss=%e, fake_loss_for_generator=%e, train_discr=%d, train_gen=%d >>>" % (discr_real_loss, discr_fake_loss, discr_fake_for_generator_loss, train_discr, train_gen)
+    if discr_loss_ratio > 1e1 and train_gen:
+        train_gen = False
+        train_discr = True
+        print >> sys.stderr, "<<< real_loss=%e, fake_loss=%e, fake_loss_for_generator=%e, train_discr=%d, train_gen=%d >>>" % (discr_real_loss, discr_fake_loss, discr_fake_for_generator_loss, train_discr, train_gen)
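One non-obvious hook in the loop above: writing an iteration number into snapshot_at_iter.txt next to the script requests an extra snapshot at that iteration without stopping training, e.g.:

    # Ask the running training loop to also snapshot at iteration 12345.
    with open('snapshot_at_iter.txt', 'w') as f:
        f.write('12345')
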
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 3b02f509b..afafe71d8 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -44,6 +44,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include <boost/python.hpp>
 #include <boost/python/raw_function.hpp>
 #include <boost/python/suite/indexing/vector_indexing_suite.hpp>
+#include <boost/python/suite/indexing/map_indexing_suite.hpp>
 #include <numpy/arrayobject.h>
 
 // these need to be included after boost on OS X
@@ -397,7 +398,9 @@ BOOST_PYTHON_MODULE(_caffe) {
         bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
     .def("save", &Net_Save)
     .def("save_hdf5", &Net_SaveHDF5)
-    .def("load_hdf5", &Net_LoadHDF5);
+    .def("load_hdf5", &Net_LoadHDF5)
+    .add_property("_blob_names_index", bp::make_function(&Net<Dtype>::blob_names_index,
+        bp::return_value_policy<bp::copy_const_reference>()));
   BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>);
 
   bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
@@ -431,6 +434,7 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::class_<LayerParameter>("LayerParameter", bp::no_init);
 
+  void (Solver<Dtype>::*apply_update_function_pointer)(void) = &Solver<Dtype>::ApplyUpdate;
   bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
     "Solver", bp::no_init)
     .add_property("net", &Solver<Dtype>::net)
@@ -442,7 +446,9 @@ BOOST_PYTHON_MODULE(_caffe) {
           &Solver<Dtype>::Solve), SolveOverloads())
     .def("step", &Solver<Dtype>::Step)
     .def("restore", &Solver<Dtype>::Restore)
-    .def("snapshot", &Solver<Dtype>::Snapshot);
+    .def("snapshot", &Solver<Dtype>::Snapshot)
+    .def("apply_update", apply_update_function_pointer)
+    .def("increment_iter", &Solver<Dtype>::increment_iter);
   BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);
 
   bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
@@ -466,6 +472,8 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::def("get_solver", &GetSolverFromFile,
       bp::return_value_policy<bp::manage_new_object>());
 
+  bp::class_<map<string, int> >("MapStringInt")
+    .def(bp::map_indexing_suite<map<string, int> >());
+
   // vector wrappers for all the vector types we use
   bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
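Together, apply_update and increment_iter let Python drive a solver one phase at a time instead of calling step(), which is exactly what the alternating GAN loop needs. A condensed sketch of one manual update (solver is assumed to be a caffe.AdamSolver, as in train.py):

    solver.increment_iter()         # advance iter_ so the lr policy sees the new step
    solver.net.clear_param_diffs()  # zero the accumulated gradients
    solver.net.forward_simple()     # full forward pass (added in pycaffe.py below)
    solver.net.backward_simple()    # full backward pass
    solver.apply_update()           # Adam update at the current learning rate
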
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index bc606148d..9587446cf 100755
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -217,6 +217,11 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
     # Unpack diffs to extract
     return {out: self.blobs[out].diff for out in outputs}
 
+def _Net_forward_simple(self):
+    self._forward(0, len(self.layers) - 1)
+
+def _Net_backward_simple(self):
+    self._backward(len(self.layers) - 1, 0)
 
 def _Net_forward_all(self, blobs=None, **kwargs):
     """
@@ -371,6 +376,8 @@ def get_id_name(self):
 Net.params = _Net_params
 Net.forward = _Net_forward
 Net.backward = _Net_backward
+Net.forward_simple = _Net_forward_simple
+Net.backward_simple = _Net_backward_simple
 Net.forward_all = _Net_forward_all
 Net.forward_backward_all = _Net_forward_backward_all
 Net.set_input_arrays = _Net_set_input_arrays
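forward_simple and backward_simple run every layer in order without the blob-dict packing of Net.forward/Net.backward, which suits a loop that reads and writes blobs directly. Minimal usage (a sketch; the prototxt path is a placeholder and the net is unweighted):

    import caffe
    import numpy as np

    net = caffe.Net('generator.prototxt', caffe.TEST)
    net.blobs['feat'].data[...] = np.random.normal(0, 1, net.blobs['feat'].data.shape)
    net.forward_simple()  # runs layers 0 .. len(net.layers)-1
    print net.blobs['generated'].data.shape  # (64, 3, 64, 64)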