Skip to content

Commit

Permalink
Merge pull request BVLC#686 from jeffdonahue/loss-generalization
Browse files Browse the repository at this point in the history
Loss generalization
  • Loading branch information
jeffdonahue committed Aug 13, 2014
2 parents 6ebd2eb + e7d97f5 commit 450c3d0
Show file tree
Hide file tree
Showing 87 changed files with 1,086 additions and 521 deletions.
1 change: 0 additions & 1 deletion examples/mnist/lenet_consolidated_solver.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ net_param {
bottom: "ip2"
bottom: "label"
top: "accuracy"
include: { phase: TEST }
}
layers {
name: "loss"
Expand Down
40 changes: 20 additions & 20 deletions include/caffe/common_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ArgMaxLayer : public Layer<Dtype> {
public:
explicit ArgMaxLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -39,7 +39,7 @@ class ArgMaxLayer : public Layer<Dtype> {
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
Expand All @@ -58,7 +58,7 @@ class ConcatLayer : public Layer<Dtype> {
public:
explicit ConcatLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -68,9 +68,9 @@ class ConcatLayer : public Layer<Dtype> {
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand All @@ -93,7 +93,7 @@ class FlattenLayer : public Layer<Dtype> {
public:
explicit FlattenLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -103,9 +103,9 @@ class FlattenLayer : public Layer<Dtype> {
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand All @@ -122,7 +122,7 @@ class MVNLayer : public Layer<Dtype> {
public:
explicit MVNLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -132,9 +132,9 @@ class MVNLayer : public Layer<Dtype> {
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand All @@ -154,7 +154,7 @@ class SoftmaxLayer : public Layer<Dtype> {
public:
explicit SoftmaxLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -164,9 +164,9 @@ class SoftmaxLayer : public Layer<Dtype> {
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand All @@ -186,7 +186,7 @@ class SplitLayer : public Layer<Dtype> {
public:
explicit SplitLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -196,9 +196,9 @@ class SplitLayer : public Layer<Dtype> {
virtual inline int MinTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand All @@ -217,7 +217,7 @@ class SliceLayer : public Layer<Dtype> {
public:
explicit SliceLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -227,9 +227,9 @@ class SliceLayer : public Layer<Dtype> {
virtual inline int MinTopBlobs() const { return 2; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand Down
38 changes: 19 additions & 19 deletions include/caffe/data_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DataLayer : public Layer<Dtype>, public InternalThread {
explicit DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~DataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -42,9 +42,9 @@ class DataLayer : public Layer<Dtype>, public InternalThread {
virtual inline int MaxTopBlobs() const { return 2; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
Expand Down Expand Up @@ -85,7 +85,7 @@ class DummyDataLayer : public Layer<Dtype> {
public:
explicit DummyDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -95,7 +95,7 @@ class DummyDataLayer : public Layer<Dtype> {
virtual inline int MinTopBlobs() const { return 1; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
Expand All @@ -112,7 +112,7 @@ class HDF5DataLayer : public Layer<Dtype> {
explicit HDF5DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~HDF5DataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -122,9 +122,9 @@ class HDF5DataLayer : public Layer<Dtype> {
virtual inline int ExactNumTopBlobs() const { return 2; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
Expand All @@ -145,7 +145,7 @@ class HDF5OutputLayer : public Layer<Dtype> {
public:
explicit HDF5OutputLayer(const LayerParameter& param);
virtual ~HDF5OutputLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {}

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -158,9 +158,9 @@ class HDF5OutputLayer : public Layer<Dtype> {
inline std::string file_name() const { return file_name_; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
Expand All @@ -180,7 +180,7 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {
explicit ImageDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~ImageDataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -190,9 +190,9 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {
virtual inline int ExactNumTopBlobs() const { return 2; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
Expand Down Expand Up @@ -226,7 +226,7 @@ class MemoryDataLayer : public Layer<Dtype> {
public:
explicit MemoryDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -244,7 +244,7 @@ class MemoryDataLayer : public Layer<Dtype> {
int batch_size() { return batch_size_; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
Expand All @@ -268,7 +268,7 @@ class WindowDataLayer : public Layer<Dtype>, public InternalThread {
explicit WindowDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual ~WindowDataLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual inline LayerParameter_LayerType type() const {
Expand All @@ -278,9 +278,9 @@ class WindowDataLayer : public Layer<Dtype>, public InternalThread {
virtual inline int ExactNumTopBlobs() const { return 2; }

protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
Expand Down

0 comments on commit 450c3d0

Please sign in to comment.