Skip to content

Commit

Permalink
Metal changes
Browse files Browse the repository at this point in the history
After several changes, the Metal functions are now fully set up. They
live in a separate class that provides a series of parallel mathematical
compute functions. Some minor tests (can’t write proper unit tests for
Metal?) were done to validate the functionality, and they appear to
pass. More testing is needed.
  • Loading branch information
jordenhill committed Mar 3, 2016
1 parent 63c0189 commit ef3cac5
Show file tree
Hide file tree
Showing 15 changed files with 379 additions and 1,728 deletions.
59 changes: 13 additions & 46 deletions Birdbrain.xcodeproj/project.pbxproj
Expand Up @@ -7,19 +7,16 @@
objects = {

/* Begin PBXBuildFile section */
BC00903E1BBB22700086F758 /* Birdbrain.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = BC0090331BBB22700086F758 /* Birdbrain.framework */; };
BC0090431BBB22710086F758 /* FeedforwardNeuralNetworldTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = BC0090421BBB22710086F758 /* FeedforwardNeuralNetworldTests.swift */; };
BC00904E1BBB23390086F758 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = BC00904D1BBB23390086F758 /* Accelerate.framework */; };
BC0090501BBB238F0086F758 /* FeedforwardNeuralNetwork.swift in Sources */ = {isa = PBXBuildFile; fileRef = BC00904F1BBB238F0086F758 /* FeedforwardNeuralNetwork.swift */; };
BC0090521BBB24170086F758 /* Math.swift in Sources */ = {isa = PBXBuildFile; fileRef = BC0090511BBB24170086F758 /* Math.swift */; };
BC6C9E8A1C8158C000E83E50 /* word2vec.c in Sources */ = {isa = PBXBuildFile; fileRef = BC6C9E891C8158C000E83E50 /* word2vec.c */; };
BC0621F61C882E0600AB7D5E /* Birdbrain.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = BC0090331BBB22700086F758 /* Birdbrain.framework */; };
BC0621F91C8841AB00AB7D5E /* MetalFunctions.metal in Sources */ = {isa = PBXBuildFile; fileRef = BC0621F71C8841AB00AB7D5E /* MetalFunctions.metal */; };
BC0621FA1C8841AB00AB7D5E /* MetalDevice.swift in Sources */ = {isa = PBXBuildFile; fileRef = BC0621F81C8841AB00AB7D5E /* MetalDevice.swift */; };
BC886A2B1C14B19300A28840 /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = BC32C8F21BC79EE600D92D94 /* Metal.framework */; };
BCD993761C7EE29B005A93ED /* Birdbrain-Bridging-Header.h in Headers */ = {isa = PBXBuildFile; fileRef = BCD993751C7EE29B005A93ED /* Birdbrain-Bridging-Header.h */; };
BCD993871C7EE8F9005A93ED /* Word2Vec.h in Headers */ = {isa = PBXBuildFile; fileRef = BCD993861C7EE8F9005A93ED /* Word2Vec.h */; };
BCDFE1501C068C6200BA1A65 /* BirdbrainMathTest.swift in Sources */ = {isa = PBXBuildFile; fileRef = BCDFE14F1C068C6200BA1A65 /* BirdbrainMathTest.swift */; };
BCED1B1B1C87C74700C7DD86 /* MetalDevice.swift in Sources */ = {isa = PBXBuildFile; fileRef = BCED1B1A1C87C74700C7DD86 /* MetalDevice.swift */; };
BCED1B1D1C87C75700C7DD86 /* MetalFunctions.metal in Sources */ = {isa = PBXBuildFile; fileRef = BCED1B1C1C87C75700C7DD86 /* MetalFunctions.metal */; };
BCEDA4411C7CC5CC0086EF9A /* FileReader.swift in Sources */ = {isa = PBXBuildFile; fileRef = BCEDA4401C7CC5CC0086EF9A /* FileReader.swift */; };
BCF7F4DE1C0EA3F500AAFBE0 /* RecurrentNeuralNetwork.swift in Sources */ = {isa = PBXBuildFile; fileRef = BCF7F4DD1C0EA3F500AAFBE0 /* RecurrentNeuralNetwork.swift */; };
BCF7F4E01C10340800AAFBE0 /* LSTMNetwork.swift in Sources */ = {isa = PBXBuildFile; fileRef = BCF7F4DF1C10340800AAFBE0 /* LSTMNetwork.swift */; };
/* End PBXBuildFile section */
Expand All @@ -43,19 +40,12 @@
BC00904D1BBB23390086F758 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
BC00904F1BBB238F0086F758 /* FeedforwardNeuralNetwork.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FeedforwardNeuralNetwork.swift; sourceTree = "<group>"; };
BC0090511BBB24170086F758 /* Math.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Math.swift; sourceTree = "<group>"; };
BC0621F71C8841AB00AB7D5E /* MetalFunctions.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = MetalFunctions.metal; sourceTree = "<group>"; };
BC0621F81C8841AB00AB7D5E /* MetalDevice.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MetalDevice.swift; sourceTree = "<group>"; };
BC32C8F21BC79EE600D92D94 /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
BC6C9E821C81534C00E83E50 /* word-analogy.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "word-analogy.c"; sourceTree = "<group>"; };
BC6C9E871C81579400E83E50 /* distance.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; path = distance.c; sourceTree = "<group>"; };
BC6C9E881C8157BE00E83E50 /* compute-accuracy.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; path = "compute-accuracy.c"; sourceTree = "<group>"; };
BC6C9E891C8158C000E83E50 /* word2vec.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = word2vec.c; sourceTree = "<group>"; };
BC9744BE1C80CA030034473F /* word2phrase.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = word2phrase.c; sourceTree = "<group>"; };
BCD993751C7EE29B005A93ED /* Birdbrain-Bridging-Header.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "Birdbrain-Bridging-Header.h"; sourceTree = "<group>"; };
BCD993861C7EE8F9005A93ED /* Word2Vec.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Word2Vec.h; sourceTree = "<group>"; };
BCDFE14F1C068C6200BA1A65 /* BirdbrainMathTest.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = BirdbrainMathTest.swift; sourceTree = "<group>"; };
BCED1B1A1C87C74700C7DD86 /* MetalDevice.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MetalDevice.swift; sourceTree = "<group>"; };
BCED1B1C1C87C75700C7DD86 /* MetalFunctions.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = MetalFunctions.metal; sourceTree = "<group>"; };
BCEDA43F1C7CB4AA0086EF9A /* MyPlayground.playground */ = {isa = PBXFileReference; lastKnownFileType = file.playground; path = MyPlayground.playground; sourceTree = "<group>"; };
BCEDA4401C7CC5CC0086EF9A /* FileReader.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FileReader.swift; sourceTree = "<group>"; };
BCF7F4DD1C0EA3F500AAFBE0 /* RecurrentNeuralNetwork.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = RecurrentNeuralNetwork.swift; sourceTree = "<group>"; };
BCF7F4DF1C10340800AAFBE0 /* LSTMNetwork.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LSTMNetwork.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
Expand All @@ -67,14 +57,14 @@
files = (
BC886A2B1C14B19300A28840 /* Metal.framework in Frameworks */,
BC00904E1BBB23390086F758 /* Accelerate.framework in Frameworks */,
BC0621F61C882E0600AB7D5E /* Birdbrain.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
BC00903A1BBB22700086F758 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
BC00903E1BBB22700086F758 /* Birdbrain.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -104,8 +94,6 @@
BC0090351BBB22700086F758 /* Birdbrain */ = {
isa = PBXGroup;
children = (
BC6C9E861C81560F00E83E50 /* word-analogy.c */,
BCD993791C7EE454005A93ED /* Word2Vec */,
BCEDA43F1C7CB4AA0086EF9A /* MyPlayground.playground */,
BCF7F4E71C1142F300AAFBE0 /* Other */,
BCF7F4E61C113F9500AAFBE0 /* Metal */,
Expand All @@ -126,26 +114,6 @@
path = BirdbrainTests;
sourceTree = "<group>";
};
BC6C9E861C81560F00E83E50 /* word-analogy.c */ = {
isa = PBXGroup;
children = (
);
name = "word-analogy.c";
sourceTree = "<group>";
};
BCD993791C7EE454005A93ED /* Word2Vec */ = {
isa = PBXGroup;
children = (
BC6C9E881C8157BE00E83E50 /* compute-accuracy.c */,
BC6C9E871C81579400E83E50 /* distance.c */,
BC9744BE1C80CA030034473F /* word2phrase.c */,
BC6C9E821C81534C00E83E50 /* word-analogy.c */,
BC6C9E891C8158C000E83E50 /* word2vec.c */,
BCD993861C7EE8F9005A93ED /* Word2Vec.h */,
);
name = Word2Vec;
sourceTree = "<group>";
};
BCF7F4E41C113F6200AAFBE0 /* Math */ = {
isa = PBXGroup;
children = (
Expand All @@ -167,8 +135,8 @@
BCF7F4E61C113F9500AAFBE0 /* Metal */ = {
isa = PBXGroup;
children = (
BCED1B1A1C87C74700C7DD86 /* MetalDevice.swift */,
BCED1B1C1C87C75700C7DD86 /* MetalFunctions.metal */,
BC0621F71C8841AB00AB7D5E /* MetalFunctions.metal */,
BC0621F81C8841AB00AB7D5E /* MetalDevice.swift */,
);
name = Metal;
sourceTree = "<group>";
Expand All @@ -177,7 +145,6 @@
isa = PBXGroup;
children = (
BC0090381BBB22700086F758 /* Info.plist */,
BCEDA4401C7CC5CC0086EF9A /* FileReader.swift */,
);
name = Other;
sourceTree = "<group>";
Expand All @@ -189,7 +156,6 @@
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
BCD993871C7EE8F9005A93ED /* Word2Vec.h in Headers */,
BCD993761C7EE29B005A93ED /* Birdbrain-Bridging-Header.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down Expand Up @@ -239,6 +205,9 @@
BC00902A1BBB22700086F758 /* Project object */ = {
isa = PBXProject;
attributes = {
KnownAssetTags = (
New,
);
LastSwiftUpdateCheck = 0730;
LastUpgradeCheck = 0710;
ORGANIZATIONNAME = "Jorden Hill";
Expand Down Expand Up @@ -292,13 +261,11 @@
buildActionMask = 2147483647;
files = (
BCF7F4E01C10340800AAFBE0 /* LSTMNetwork.swift in Sources */,
BC0621F91C8841AB00AB7D5E /* MetalFunctions.metal in Sources */,
BC0090501BBB238F0086F758 /* FeedforwardNeuralNetwork.swift in Sources */,
BCED1B1B1C87C74700C7DD86 /* MetalDevice.swift in Sources */,
BCF7F4DE1C0EA3F500AAFBE0 /* RecurrentNeuralNetwork.swift in Sources */,
BCEDA4411C7CC5CC0086EF9A /* FileReader.swift in Sources */,
BCED1B1D1C87C75700C7DD86 /* MetalFunctions.metal in Sources */,
BC0090521BBB24170086F758 /* Math.swift in Sources */,
BC6C9E8A1C8158C000E83E50 /* word2vec.c in Sources */,
BC0621FA1C8841AB00AB7D5E /* MetalDevice.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
Binary file not shown.
4 changes: 1 addition & 3 deletions Birdbrain/Birdbrain-Bridging-Header.h
@@ -1,3 +1 @@
//Bridging header used to connect the swift code to the C-based Word2Vec program.

#import "Word2Vec.h"
//Bridging header used to connect the swift code to the C-based Word2Vec program.
94 changes: 0 additions & 94 deletions Birdbrain/FileReader.swift

This file was deleted.

52 changes: 41 additions & 11 deletions Birdbrain/LSTMNetwork.swift
Expand Up @@ -19,10 +19,13 @@ public class LSTMNetwork {
var woh: [Float]
var inputDim: Int
var memCellCount: Int
var GPU: MetalDevice!
var useMetal: Bool

public init(inputDim: Int, memCellCount: Int) {
public init(inputDim: Int, useMetal: Bool, memCellCount: Int) {
self.inputDim = inputDim
self.memCellCount = memCellCount
self.useMetal = useMetal
wgx = (1...inputDim * memCellCount).map{_ in initRand(inputDim)}
wix = (1...inputDim * memCellCount).map{_ in initRand(inputDim)}
wfx = (1...inputDim * memCellCount).map{_ in initRand(inputDim)}
Expand All @@ -31,9 +34,13 @@ public class LSTMNetwork {
wfh = (1...memCellCount * memCellCount).map{_ in initRand(inputDim)}
wih = (1...memCellCount * memCellCount).map{_ in initRand(inputDim)}
woh = (1...memCellCount * memCellCount).map{_ in initRand(inputDim)}

if (useMetal) {
GPU = MetalDevice()
}
}

public func feedforward(input: [[Float]], useMetal: Bool)
public func feedforward(input: [[Float]])
-> ([[Float]], [[Float]], [[Float]]) {

if (useMetal) {
Expand All @@ -44,7 +51,6 @@ public class LSTMNetwork {
}

private func GPUCompute(input: [[Float]]) -> ([[Float]], [[Float]], [[Float]]) {
let GPU = MetalDevice()
let T = input.count
let start: [Float] = (1...memCellCount).map{_ in 0.0}
var g: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
Expand All @@ -53,7 +59,7 @@ public class LSTMNetwork {
var o: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var s: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var h: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var y: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var p: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }

g[0] = GPU.tanh(GPU.add(GPU.mvMul(wgx, m: memCellCount, n: inputDim, vector: input[0]),
y: GPU.mvMul(wgh, m: memCellCount, n: memCellCount, vector: start)))
Expand All @@ -65,7 +71,7 @@ public class LSTMNetwork {
y: GPU.mvMul(woh, m: memCellCount, n: memCellCount, vector: start)))
s[0] = add(GPU.mul(g[0], y: i[0]), y: mul(s[0], y: f[0]))
h[0] = mul(s[0], y: o[0])
y[0] = softmax(s[0])
p[0] = softmax(s[0])

for t in 1..<T {
g[t] = GPU.tanh(GPU.add(GPU.mvMul(wgx, m: memCellCount, n: inputDim, vector: input[t]),
Expand All @@ -78,9 +84,9 @@ public class LSTMNetwork {
y: GPU.mvMul(woh, m: memCellCount, n: memCellCount, vector: h[t - 1])))
s[t] = GPU.add(GPU.mul(g[t], y: i[t]), y: mul(s[t - 1], y: f[t]))
h[t] = GPU.mul(s[t], y: o[t])
y[t] = softmax(s[t])
p[t] = softmax(s[t])
}
return (s, y, h)
return (s, h, p)
}

private func CPUCompute(input: [[Float]]) -> ([[Float]], [[Float]], [[Float]]) {
Expand All @@ -92,7 +98,7 @@ public class LSTMNetwork {
var o: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var s: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var h: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var y: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }
var p: [[Float]] = (1...T).map{_ in (1...memCellCount).map{_ in 0.0} }

g[0] = tanh(add(mvMul(wgx, m: memCellCount, n: inputDim, x: input[0]),
y: mvMul(wgh, m: memCellCount, n: memCellCount, x: start)))
Expand All @@ -104,7 +110,7 @@ public class LSTMNetwork {
y: mvMul(woh, m: memCellCount, n: memCellCount, x: start)))
s[0] = add(mul(g[0], y: i[0]), y: mul(s[0], y: f[0]))
h[0] = mul(s[0], y: o[0])
y[0] = softmax(s[0])
p[0] = softmax(s[0])

for t in 1..<T {
g[t] = tanh(add(mvMul(wgx, m: memCellCount, n: inputDim, x: input[t]),
Expand All @@ -117,8 +123,32 @@ public class LSTMNetwork {
y: mvMul(woh, m: memCellCount, n: memCellCount, x: h[t - 1])))
s[t] = add(mul(g[t], y: i[t]), y: mul(s[t - 1], y: f[t]))
h[t] = mul(s[t], y: o[t])
y[t] = softmax(s[t])
p[t] = softmax(s[t])
}
return (s, h, p)
}

/**Calculates and returns the average cross-entropy loss of the LSTM network.

The input is run through `feedforward`, each row of `target` is expanded into a vector that is
1.0 at the expected indexes and 0.0 elsewhere, and the negative log of the predicted
probabilities at those indexes is summed over all time steps and divided by `numExamples`.

- Parameter input: The input to the network, one vector per time step.
- Parameter target: The target output, given as an array containing arrays of expected indexes
  (one array of indexes per time step).
- Parameter numExamples: The number of examples used for training; the summed loss is divided
  by this value.
- Returns: The loss of the network given the outputs and inputs, or 0 when `input` is empty.
*/
public func calculateLoss(input: [[Float]], target: [[Int]], numExamples: Int) -> Float {
  // The forward pass indexes input[0] unconditionally, so an empty sequence cannot be scored.
  guard !input.isEmpty else { return 0.0 }

  let (_, _, p) = feedforward(input)

  // Every target row needs a matching prediction row; fail with a clear message instead of an
  // anonymous out-of-bounds trap.
  precondition(target.count <= p.count,
    "calculateLoss: target has more time steps than the network produced")

  var stepLosses = [Float]()

  for i in 0..<target.count {
    // Expected-output vector for this time step: 1.0 at every expected index, 0.0 elsewhere.
    var expected = [Float](count: p[i].count, repeatedValue: 0.0)
    for k in target[i] {
      expected[k] = 1.0
    }

    // Log-likelihood of the expected indexes at this time step (always <= 0).
    stepLosses.append(sum(mul(expected, y: log(p[i]))))
  }

  // Cross-entropy is the NEGATIVE log-likelihood, so flip the sign to report a loss >= 0.
  return -sum(stepLosses) / Float(numExamples)
}
}

0 comments on commit ef3cac5

Please sign in to comment.