Dropout.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nn.conf.dropout;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.layers.HelperUtils;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
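
/**
 * Standard inverted dropout: at training time, each input activation is retained with probability p (the
 * "activation retain probability") and, when retained, scaled by 1/p, so that no rescaling of activations
 * is required at test time. The retain probability may be fixed, or supplied as an {@link ISchedule} to
 * vary it by iteration or epoch. Where available, a CuDNN helper implementation is used, with optional
 * fallback to the built-in implementation (see {@link #helperAllowFallback(boolean)}).
 */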
@Data
@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
@Slf4j
public class Dropout implements IDropout {
/**
 * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementation be allowed?
 * If set to false, an exception in CuDNN will be propagated back to the user; if set to true, the built-in
 * (non-CuDNN) implementation for Dropout will be used instead.
 */
@Getter
@Setter
protected boolean helperAllowFallback = true;
private double p;
private ISchedule pSchedule;
private transient INDArray mask;
private transient DropoutHelper helper;
private boolean initializedHelper = false;
private int helperCountFail = 0;
public static final String CUDNN_DROPOUT_HELPER_CLASS_NAME = "org.deeplearning4j.cuda.dropout.CudnnDropoutHelper";
/**
* @param activationRetainProbability Probability of retaining an activation - see {@link Dropout} javadoc
*/
public Dropout(double activationRetainProbability) {
this(activationRetainProbability, null);
if(activationRetainProbability < 0.0){
throw new IllegalArgumentException("Activation retain probability must be > 0. Got: " + activationRetainProbability);
}
if(activationRetainProbability == 0.0){
throw new IllegalArgumentException("Invalid probability value: Dropout with 0.0 probability of retaining "
+ "activations is not supported");
}
}
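
//Usage sketch (illustrative only - assumes the standard DL4J layer builder API, which accepts an
//IDropout instance): retain each activation with probability 0.8 during training.
//
//  DenseLayer layer = new DenseLayer.Builder()
//          .nIn(100).nOut(50)
//          .dropOut(new Dropout(0.8))  //retained activations are scaled by 1/0.8 = 1.25
//          .build();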
/**
* @param activationRetainProbabilitySchedule Schedule for probability of retaining an activation - see {@link Dropout} javadoc
*/
public Dropout(ISchedule activationRetainProbabilitySchedule){
this(Double.NaN, activationRetainProbabilitySchedule);
}
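
//Usage sketch (illustrative only - assumes MapSchedule/ScheduleType from org.nd4j.linalg.schedule):
//disable dropout for the first 5 epochs, then retain 80% of activations from epoch 5 onwards.
//
//  ISchedule schedule = new MapSchedule.Builder(ScheduleType.EPOCH)
//          .add(0, 1.0)    //retain probability 1.0 -> dropout effectively disabled
//          .add(5, 0.8)    //retain probability 0.8 from epoch 5
//          .build();
//  Dropout dropout = new Dropout(schedule);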
/**
 * When using a helper (CuDNN or MKLDNN in some cases) and an error is encountered, should fallback to the
 * non-helper implementation be allowed? If set to false, an exception in the helper will be propagated back
 * to the user; if set to true, the built-in (non-helper) implementation for Dropout will be used instead.
 *
 * @param allowFallback Whether fallback to the non-helper implementation should be allowed
 */
public Dropout helperAllowFallback(boolean allowFallback) {
this.setHelperAllowFallback(allowFallback);
return this;
}
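
//Usage sketch: disable fallback so that a CuDNN helper failure is propagated to the caller rather
//than silently switching to the built-in implementation (useful when debugging GPU issues).
//
//  Dropout dropout = new Dropout(0.5).helperAllowFallback(false);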
protected Dropout(@JsonProperty("p") double activationRetainProbability, @JsonProperty("pSchedule") ISchedule activationRetainProbabilitySchedule) {
this.p = activationRetainProbability;
this.pSchedule = activationRetainProbabilitySchedule;
}
/**
* Initialize the CuDNN dropout helper, if possible
*/
protected void initializeHelper(DataType dataType){
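//Note: the empty string passed below appears to be a placeholder for a non-CuDNN (e.g. MKL-DNN)
//helper class name; only the CuDNN helper class is specified for dropout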
helper = HelperUtils.createHelper(CUDNN_DROPOUT_HELPER_CLASS_NAME,
"", DropoutHelper.class, "dropout-helper", dataType
);
initializedHelper = helper != null;
}
@Override
public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
Preconditions.checkState(output.dataType().isFPType(), "Output array must be a floating point type, got %s for array of shape %ndShape",
output.dataType(), output);
double currP;
if(pSchedule != null){
currP = pSchedule.valueAt(iteration, epoch);
} else {
currP = p;
}
if(!initializedHelper){
initializeHelper(output.dataType());
}
if(helper != null && (helperCountFail == 0 || !isHelperAllowFallback())){
boolean helperWorked = false;
try {
helper.applyDropout(inputActivations, output, currP); //use the schedule value (currP); the raw field p is NaN when only a schedule was supplied
helperWorked = true;
}catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Exception e){
if(e.getMessage() != null && e.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
throw e;
}
if(isHelperAllowFallback()){
helperCountFail++;
log.warn("CuDNN execution failed - falling back on built-in implementation",e);
} else {
throw new RuntimeException("Error during Dropout CuDNN helper forward pass - helperAllowFallback() is set to false", e);
}
}
if(helperWorked)
return output;
}
INDArray inputCast = inputActivations;
if(inputCast != output && inputCast.dataType() != output.dataType()){
inputCast = inputCast.castTo(output.dataType());
}
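//Built-in implementation: fill the mask with ones, then apply inverted dropout to it in-place so each
//element becomes 0 (dropped, probability 1-currP) or 1/currP (kept, pre-scaled), then multiply the
//input by the mask to produce the output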
mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0);
Nd4j.getExecutioner().exec(new DropOutInverted(mask, mask, currP));
Nd4j.getExecutioner().exec(new MulOp(inputCast, mask, output));
return output;
}
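
//Worked example of the inverted dropout arithmetic above: with currP = 0.5, each mask element is 0
//(probability 0.5) or 1/0.5 = 2.0, so an activation x becomes 0 or 2x. The expected output is
//0.5*0 + 0.5*2x = x, which is why no rescaling is needed at test time.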
@Override
public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
if(helper != null && (helperCountFail == 0 || !isHelperAllowFallback())){
boolean helperWorked = false;
try {
helper.backprop(gradAtOutput, gradAtInput);
helperWorked = true;
}catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Exception e){
if(e.getMessage() != null && e.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
throw e;
}
if(isHelperAllowFallback()){
helperCountFail++;
log.warn("CuDNN execution failed - falling back on built-in implementation",e);
} else {
throw new RuntimeException("Error during Dropout CuDNN helper backprop - helperAllowFallback() is set to false", e);
}
}
if(helperWorked)
return gradAtInput;
}
Preconditions.checkState(mask != null, "Cannot perform backprop: Dropout mask array is absent (already cleared?)");
//dL/dx = dL/dz * dz/dx, with z=0 or x/p
//Mask already contains either 0 or 1/p, so just muli
INDArray m = mask;
if(m.dataType() != gradAtInput.dataType()){
m = m.castTo(gradAtInput.dataType());
}
Nd4j.getExecutioner().exec(new MulOp(gradAtOutput, m, gradAtInput));
mask = null;
return gradAtInput;
}
@Override
public void clear() {
mask = null;
}
@Override
public Dropout clone() {
Dropout ret = new Dropout(p, pSchedule == null ? null : pSchedule.clone());
//Preserve the fallback setting, which the constructor does not copy
ret.setHelperAllowFallback(helperAllowFallback);
return ret;
}
}