Activation.java
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.linalg.activations;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.*;
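
/**
 * Enumeration of common activation functions. Each constant can produce a concrete
 * {@link IActivation} instance via {@link #getActivationFunction()}, and the most common
 * functions can also be applied to a {@link SDVariable} in a {@link SameDiff} graph via
 * {@link #asSameDiff(String, SameDiff, SDVariable)}.
 */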
public enum Activation {
    CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
    RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
    THRESHOLDEDRELU, GELU, MISH;
    /**
     * Creates an instance of the activation function
     *
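     * <p>A minimal usage sketch (the variable name is illustrative):
     * <pre>{@code
     * IActivation relu = Activation.RELU.getActivationFunction();
     * }</pre>
     *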
     * @return an instance of the activation function
     */
    public IActivation getActivationFunction() {
        switch (this) {
            case CUBE:
                return new ActivationCube();
            case ELU:
                return new ActivationELU();
            case HARDSIGMOID:
                return new ActivationHardSigmoid();
            case HARDTANH:
                return new ActivationHardTanH();
            case IDENTITY:
                return new ActivationIdentity();
            case LEAKYRELU:
                return new ActivationLReLU();
            case RATIONALTANH:
                return new ActivationRationalTanh();
            case RECTIFIEDTANH:
                return new ActivationRectifiedTanh();
            case RELU:
                return new ActivationReLU();
            case RELU6:
                return new ActivationReLU6();
            case SELU:
                return new ActivationSELU();
            case SWISH:
                return new ActivationSwish();
            case RRELU:
                return new ActivationRReLU();
            case SIGMOID:
                return new ActivationSigmoid();
            case SOFTMAX:
                return new ActivationSoftmax();
            case SOFTPLUS:
                return new ActivationSoftPlus();
            case SOFTSIGN:
                return new ActivationSoftSign();
            case TANH:
                return new ActivationTanH();
            case THRESHOLDEDRELU:
                return new ActivationThresholdedReLU();
            case GELU:
                return new ActivationGELU();
            case MISH:
                return new ActivationMish();
            default:
                throw new UnsupportedOperationException("Unknown or not supported activation function: " + this);
        }
    }
    /**
     * Returns the activation function enum value
     *
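     * <p>A minimal usage sketch:
     * <pre>{@code
     * Activation a = Activation.fromString("relu"); // yields Activation.RELU
     * }</pre>
     *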
     * @param name the case-insensitive name of the activation function
     * @return the activation function enum value
     * @throws IllegalArgumentException if the name does not match any activation function
     */
    public static Activation fromString(String name) {
        return Activation.valueOf(name.toUpperCase());
    }
    /**
     * Get the Activation as a SameDiff variable
     *
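     * <p>A minimal usage sketch ({@code sd} and {@code input} are assumed to already exist;
     * the output variable name is generated automatically):
     * <pre>{@code
     * SDVariable out = Activation.SIGMOID.asSameDiff(sd, input);
     * }</pre>
     *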
     * @param sd    SameDiff instance
     * @param input Input variable to apply the activation function to
     * @return SDVariable: output after applying the activation function
     * @see #asSameDiff(String, SameDiff, SDVariable)
     */
    public SDVariable asSameDiff(SameDiff sd, SDVariable input) {
        return asSameDiff(null, sd, input);
    }
    /**
     * Get the Activation as a SameDiff variable, using the given name for the output variable
     *
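     * <p>A minimal sketch; the placeholder name, data type and shape are illustrative
     * ({@code DataType} here is {@code org.nd4j.linalg.api.buffer.DataType}):
     * <pre>{@code
     * SameDiff sd = SameDiff.create();
     * SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
     * SDVariable out = Activation.TANH.asSameDiff("out", sd, in);
     * }</pre>
     *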
     * @param variableName Variable name for the output
     * @param sd           SameDiff instance
     * @param input        Input variable to apply the activation function to
     * @return SDVariable: output after applying the activation function
     * @throws UnsupportedOperationException if this activation function is not yet supported in SameDiff
     */
    public SDVariable asSameDiff(String variableName, SameDiff sd, SDVariable input) {
        switch (this) {
            case CUBE:
                return sd.math().pow(variableName, input, 3.0);
            case ELU:
                return sd.nn().elu(variableName, input);
            case HARDTANH:
                return sd.nn().hardTanh(variableName, input);
            case IDENTITY:
                return sd.identity(variableName, input);
            case LEAKYRELU:
                return sd.nn().leakyRelu(variableName, input, 0.0);
            case RELU:
                return sd.nn().relu(variableName, input, 0.0);
            case SIGMOID:
                return sd.nn().sigmoid(variableName, input);
            case SOFTMAX:
                return sd.nn().softmax(variableName, input);
            case SOFTPLUS:
                return sd.nn().softplus(variableName, input);
            case SOFTSIGN:
                return sd.nn().softsign(variableName, input);
            case TANH:
                return sd.math().tanh(variableName, input);
            case GELU:
                return sd.nn().gelu(variableName, input);
            case HARDSIGMOID:
            case RATIONALTANH:
            case RRELU:
            case RECTIFIEDTANH:
            case SELU:
            case SWISH:
            default:
                throw new UnsupportedOperationException("Activation function not yet supported: " + this);
        }
    }
}