Skip to content

Commit

Permalink
add mish activation
Browse files — browse the repository at this point in the history
  • Loading branch information
mikowals committed Oct 14, 2019
1 parent 1bc52ce commit 13fa56a
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 22 deletions.
11 changes: 9 additions & 2 deletions Sources/PreactResNet/PreactResNet.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import TensorFlow
import LayersDataFormat

/// Mish activation: `x * tanh(softplus(x))`.
/// Differentiable so it can be used inside `@differentiable` model code.
@differentiable
func mish<Scalar: TensorFlowFloatingPoint>(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
    // Gate the input by the tanh of its softplus.
    let gate = tanh(softplus(input))
    return input * gate
}

extension Tensor where Scalar: TensorFlowFloatingPoint {
@differentiable
func l2Loss() -> Tensor<Scalar> {
Expand Down Expand Up @@ -30,6 +35,7 @@ func makeStrides(stride: Int, dataFormat: Raw.DataFormat) -> (Int, Int, Int, Int
}
return strides
}

struct WeightNormConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
var filter: Tensor<Scalar> {
didSet { filter = filter.weightNormalized() }
Expand Down Expand Up @@ -137,7 +143,7 @@ struct PreactConv2D<Scalar: TensorFlowFloatingPoint>: Layer {

@differentiable
func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
let tmp = relu(input + bias1) + bias2
let tmp = mish(input + bias1) + bias2
return tmp.convolved2DDF(withFilter: filter * g,
strides: makeStrides(stride: stride, dataFormat: dataFormat),
padding: .same,
Expand Down Expand Up @@ -344,7 +350,7 @@ public struct PreactResNet<Scalar: TensorFlowFloatingPoint>: Layer {
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar>{
var tmp = conv1(input) * multiplier1 + bias1
tmp = blocks.differentiableReduce(tmp) {last, layer in layer(last)}
tmp = relu(tmp * multiplier2 + bias2)
tmp = mish(tmp * multiplier2 + bias2)
let squeezingAxes = dataFormat == .nchw ? [2, 3] : [1, 2]
tmp = tmp.mean(squeezingAxes: squeezingAxes)
tmp = dense1(tmp)
Expand All @@ -363,6 +369,7 @@ public struct PreactResNet<Scalar: TensorFlowFloatingPoint>: Layer {
bias2 = newValue.bias2
dense1.replaceParameters(newValue.dense1)
}
// No longer used as filters and weights use didSet to maintain normalized weights.
public mutating func projectUnitNorm() {
conv1.filter = conv1.filter.weightNormalized()
for ii in 0 ..< blocks.count {
Expand Down
84 changes: 64 additions & 20 deletions Tests/PreactResNetTests/PreactResNetTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,85 @@ import TensorFlow
@testable import PreactResNet

final class PreactResNetTests: XCTestCase {
/*func testGetSetParameters() {
func testGetSetParameters() {
var model = PreactResNet<Float>(dataFormat: Raw.DataFormat.nhwc)
var parameters = model.differentiableVectorView
let originalParameters = parameters
var count = 0
for kp in parameters.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
count += parameters[keyPath: kp].shape.contiguousSize
parameters[keyPath: kp] = Tensor<Float>(repeating: 0.02, shape: parameters[keyPath: kp].shape)
parameters[keyPath: kp] = Tensor<Float>(repeating: 0.02,
shape: parameters[keyPath: kp].shape)
}
XCTAssertEqual(count, 7352745)
model.replaceParameters(parameters)
parameters = model.differentiableVectorView
var sum = Tensor<Float>(0)
for kp in parameters.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
sum += parameters[keyPath: kp].sum()
XCTAssertNotEqual(parameters[keyPath: kp], originalParameters[keyPath: kp])
}
// This is an example of a functional test case.
// Use XCTAssert and related functions to verify your tests produce the correct
// results.
XCTAssertEqual(count, 7352745)
XCTAssertEqual(sum, Tensor<Float>(7352745 * 2))
}*/
}

func testWeightNormalization(){
var conv = WeightNormConv2D<Float>(filter: Tensor(orthogonal: [3, 3, 2, 1]),
g: Tensor(zeros: [2]))
print ("norm: ", conv.filter.l2Norm(alongAxes: [0,1,2]))
XCTAssertTrue((conv.filter.l2Norm(alongAxes: [0, 1, 2]) .== 1).all(),
"initial weights normalized")
func testWeightNormConv2D(){
var conv = WeightNormConv2D<Float>(filter: Tensor(orthogonal: [3, 3, 32, 64]),
g: Tensor(zeros: [64]))
XCTAssertTrue(
conv.filter.l2Norm(alongAxes: [0, 1, 2]).isAlmostEqual(to: Tensor(ones: [1, 1, 1, 64]),
tolerance: 1e-6),
"initial weights normalized"
)
var direction = conv.differentiableVectorView
direction.filter = Tensor<Float>(repeating: 0.2, shape: [3, 3, 2, 1])
direction.filter = Tensor<Float>(repeating: 0.2, shape: [3, 3, 32, 64])
conv.move(along: direction)
XCTAssertTrue((conv.filter.l2Norm(alongAxes: [0, 1, 2]) .== 1).all(),
"updated weights normalized")
XCTAssertTrue(
conv.filter.l2Norm(alongAxes: [0, 1, 2]).isAlmostEqual(to: Tensor(ones: [1, 1, 1, 64]),
tolerance: 1e-6),
"updated weights normalized"
)
}

func testWeightNormDense(){
    // WeightNormDense should keep each weight column at unit L2 norm
    // (norm taken along axis 0), both at initialization and after an
    // update applied via move(along:).
    var layer = WeightNormDense<Float>(weight: Tensor(orthogonal: [64, 10]),
                                       bias: Tensor(zeros: [10]),
                                       g: Tensor(zeros: [10]))
    let unitNorms = Tensor<Float>(ones: [1, 10])
    XCTAssertTrue(
        layer.weight.l2Norm(alongAxes: [0]).isAlmostEqual(to: unitNorms, tolerance: 1e-6),
        "initial weights normalized"
    )
    // Perturb only the weight; bias and g entries of the step are the
    // zero tensors taken from the current parameter view.
    var step = layer.differentiableVectorView
    step.weight = Tensor<Float>(repeating: 0.2, shape: [64, 10])
    layer.move(along: step)
    XCTAssertTrue(
        layer.weight.l2Norm(alongAxes: [0]).isAlmostEqual(to: unitNorms, tolerance: 1e-6),
        "updated weights normalized"
    )
}

func testWeightPreactConv2D(){
    // PreactConv2D should keep its filter unit-normalized along axes
    // [0, 1, 2] (one norm per [1, 1, 1, 64] slice), both at
    // initialization and after an update applied via move(along:).
    var layer = PreactConv2D<Float>(filter: Tensor(orthogonal: [3, 3, 32, 64]),
                                    bias1: Tensor(zeros: [64]),
                                    bias2: Tensor(zeros: [64]),
                                    g: Tensor(zeros: [64]))
    let unitNorms = Tensor<Float>(ones: [1, 1, 1, 64])
    XCTAssertTrue(
        layer.filter.l2Norm(alongAxes: [0, 1, 2]).isAlmostEqual(to: unitNorms,
                                                                tolerance: 1e-6),
        "initial weights normalized"
    )
    // Perturb only the filter; the bias and g entries of the step are
    // the zero tensors taken from the current parameter view.
    var step = layer.differentiableVectorView
    step.filter = Tensor<Float>(repeating: 0.2, shape: [3, 3, 32, 64])
    layer.move(along: step)
    XCTAssertTrue(
        layer.filter.l2Norm(alongAxes: [0, 1, 2]).isAlmostEqual(to: unitNorms,
                                                                tolerance: 1e-6),
        "updated weights normalized"
    )
}

static var allTests = [
("testWeightNormalization", testWeightNormalization),
("testGetSetParameters", testGetSetParameters),
("testWeightNormConv2D", testWeightNormConv2D),
("testWeightPreactConv2D", testWeightPreactConv2D),
("testWeightNormDense", testWeightNormDense),
]
}

0 comments on commit 13fa56a

Please sign in to comment.