-
Notifications
You must be signed in to change notification settings - Fork 200
/
BackpropWeights.cpp
145 lines (125 loc) · 4.88 KB
/
BackpropWeights.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright Hugh Perkins 2014,2015 hughperkins at gmail
//
// This Source Code Form is subject to the terms of the Mozilla Public License,
// v. 2.0. If a copy of the MPL was not distributed with this file, You can
// obtain one at http://mozilla.org/MPL/2.0/.
#include <algorithm>
#include "util/StatefulTimer.h"
#include "util/stringhelper.h"
#include "BackpropWeights.h"
#include "BackpropWeightsCpu.h"
#include "BackpropWeightsNaive.h"
#include "BackpropWeightsScratch.h"
#include "BackpropWeightsScratchLarge.h"
#include "BackpropWeightsIm2Col.h"
#include "BackpropWeightsFsword73.h"
#include "BackpropWeightsFsword73_BatchSize.h"
#include "BackpropWeightsAuto.h"
using namespace std;
// STATIC and VIRTUAL are defined as empty macros: they expand to nothing and
// serve only as annotations in front of the out-of-class member definitions
// below. Each is #undef'd first so any earlier definition is replaced.
// NOTE(review): presumably they mirror the static/virtual qualifiers used in
// BackpropWeights.h - confirm against the header.
#undef STATIC
#define STATIC
#undef VIRTUAL
#define VIRTUAL
// Constructs the base object: keeps the EasyCL pointer and a copy of the
// layer dimensions, and starts with the debug flag switched off.
BackpropWeights::BackpropWeights(EasyCL *cl, LayerDimensions layerDimensions)
    : cl(cl), dim(layerDimensions), debug(false)
{
}
// Default factory. Always returns a freshly new-allocated BackpropWeightsAuto
// built from the given EasyCL pointer and layer dimensions; the result is a
// raw pointer allocated with new.
//
// Previously tried selection logic, currently disabled and kept for reference:
//   return new BackpropWeightsFsword73(cl, dim);
//
//   if(dim.inputSize - dim.filterSize < 4) {
//       return new BackpropWeightsNaive(cl, dim);
//   }
//   if(square(dim.filterSize) <= cl->getMaxWorkgroupSize()
//           && dim.inputSize <= 32) { // if inputimagesize too big, we run out of local memory
//       return new BackpropWeightsScratch(cl, dim);
//   } else if(square(dim.filterSize) <= cl->getMaxWorkgroupSize()) {
//       return new BackpropWeightsScratchLarge(cl, dim);
//   } else {
//       return new BackpropWeightsNaive(cl, dim);
//   }
STATIC BackpropWeights *BackpropWeights::instance(EasyCL *cl, LayerDimensions dim) {
    BackpropWeights *autoImpl = new BackpropWeightsAuto(cl, dim);
    return autoImpl;
}
// Reports how many indexed implementations exist. The value matches the
// index range 0..6 handled by instanceSpecific.
STATIC int BackpropWeights::getNumImplementations() {
    const int implementationCount = 7;
    return implementationCount;
}
// Tells whether the implementation at `index` is a candidate worth trying.
// Index 0 (the one instanceSpecific maps to BackpropWeightsNaive) and any
// index of 7 or above are rejected; every other value, including negative
// ones, is accepted. batchSize and dim are not consulted.
STATIC bool BackpropWeights::plausiblyOptimal(int index, int batchSize, LayerDimensions dim) {
    const bool isIndexZero = (index == 0);
    const bool isPastLastIndex = (index >= 7);
    return !(isIndexZero || isPastLastIndex);
}
// Factory used by tests: always returns a new-allocated
// BackpropWeightsScratchLarge, regardless of the layer dimensions.
STATIC BackpropWeights *BackpropWeights::instanceForTest(EasyCL *cl, LayerDimensions layerDimensions) {
    BackpropWeights *fixedImpl = new BackpropWeightsScratchLarge(cl, layerDimensions);
    return fixedImpl;
}
// Creates the implementation selected by `idx`:
//   -1 Auto, 0 Naive, 1 Scratch, 2 ScratchLarge, 3 Im2Col,
//    4 Fsword73, 5 Fsword73_BatchSize, 6 Cpu.
// Indices 4 and 5 also print a one-line notice to stdout.
// Throws std::runtime_error for any other index.
// The returned object is allocated with new.
STATIC BackpropWeights *BackpropWeights::instanceSpecific(int idx, EasyCL *cl, LayerDimensions layerDimensions) {
    switch(idx) {
        case -1:
            return new BackpropWeightsAuto(cl, layerDimensions);
        case 0:
            return new BackpropWeightsNaive(cl, layerDimensions);
        case 1:
            return new BackpropWeightsScratch(cl, layerDimensions);
        case 2:
            return new BackpropWeightsScratchLarge(cl, layerDimensions);
        case 3:
            return new BackpropWeightsIm2Col(cl, layerDimensions);
        case 4:
            cout << "fword73 kernel" << endl;
            return new BackpropWeightsFsword73(cl, layerDimensions);
        case 5:
            cout << "fword73 kernel BatchSize" << endl;
            return new BackpropWeightsFsword73_BatchSize(cl, layerDimensions);
        case 6:
            return new BackpropWeightsCpu(cl, layerDimensions);
        default:
            break;
    }
    // Only reached when idx matched none of the known implementations.
    throw std::runtime_error("BackpropWeights::instanceSpecific doesnt handle idx " + toString(idx));
}
// Host-pointer convenience overload of calcGradWeights.
//
// Wraps each host buffer in a CLWrapper, copies it to the device, delegates to
// the CLWrapper-based calcGradWeights overload, then copies gradWeights (and
// gradBias when dim.biased is set) back to the host.
//
// Parameters:
//   batchSize   - multiplied by dim.outputCubeSize / dim.inputCubeSize to size
//                 the gradOutput / inputs wrappers
//   gradOutput  - host buffer wrapped with batchSize * dim.outputCubeSize elements
//   inputs      - host buffer wrapped with batchSize * dim.inputCubeSize elements
//   gradWeights - host buffer wrapped with dim.filtersSize elements; when the
//                 debug flag is set the wrapper spans max(10000, dim.filtersSize)
//                 elements instead
//                 NOTE(review): presumably the caller's buffer must then be at
//                 least that large - confirm
//   gradBias    - host buffer wrapped with dim.numFilters elements; only
//                 touched when dim.biased is set
//
// The wrappers are released on every exit path: if any step throws, they are
// deleted before the exception is rethrown to the caller.
VIRTUAL void BackpropWeights::calcGradWeights(int batchSize, float *gradOutput, float *inputs, float *gradWeights, float *gradBias) {
    StatefulTimer::timeCheck("BackpropWeights::backprop begin");

    // Start every wrapper as null so the cleanup code can delete all four
    // unconditionally (deleting a null pointer is a no-op).
    CLWrapper *gradOutputWrapper = 0;
    CLWrapper *inputDataWrapper = 0;
    CLWrapper *gradWeightsWrapper = 0;
    CLWrapper *gradBiasWrapper = 0;

    try {
        const int outputNumElements = batchSize * dim.outputCubeSize;
        gradOutputWrapper = cl->wrap(outputNumElements, gradOutput);
        gradOutputWrapper->copyToDevice();

        const int inputNumElements = batchSize * dim.inputCubeSize;
        inputDataWrapper = cl->wrap(inputNumElements, inputs);
        inputDataWrapper->copyToDevice();

        // In debug mode the weights wrapper is padded to at least 10000 elements.
        const int gradWeightsSize = debug ? std::max(10000, dim.filtersSize) : dim.filtersSize;
        gradWeightsWrapper = cl->wrap(gradWeightsSize, gradWeights);
        gradWeightsWrapper->copyToDevice();

        if(dim.biased) {
            gradBiasWrapper = cl->wrap(dim.numFilters, gradBias);
            gradBiasWrapper->copyToDevice();
        }
        StatefulTimer::timeCheck("BackpropWeights::backprop after copied to device");

        // Delegate to the wrapper-based overload; gradBiasWrapper stays null
        // when the layer has no bias.
        calcGradWeights(batchSize, gradOutputWrapper, inputDataWrapper, gradWeightsWrapper, gradBiasWrapper);
        StatefulTimer::timeCheck("BackpropWeights::backprop after call backprop");

        gradWeightsWrapper->copyToHost();
        if(dim.biased) {
            gradBiasWrapper->copyToHost();
        }
        StatefulTimer::timeCheck("BackpropWeights::backprop after copytohost");
    } catch(...) {
        // Release whatever was allocated so far, then let the caller see the
        // original exception.
        delete gradOutputWrapper;
        delete inputDataWrapper;
        delete gradWeightsWrapper;
        delete gradBiasWrapper;
        throw;
    }

    delete gradOutputWrapper;
    delete inputDataWrapper;
    delete gradWeightsWrapper;
    delete gradBiasWrapper;
}
// Returns the constant multiplier 1.0f; batchSize is not used.
// Earlier scaling variants are kept below, commented out, for reference.
float BackpropWeights::learningRateToMultiplier(int batchSize) {
    //    float multiplier = rate / batchSize / sqrt(dim.outputSize);
    //    float multiplier = rate;
    //    std::cout << "rate " << rate << " multiplier " << multiplier << std::endl;
    const float unitMultiplier = 1.0f;
    return unitMultiplier;
}