Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 57fa328
Showing
8 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
bin/ | ||
.settings/ | ||
.cache | ||
.classpath |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<projectDescription> | ||
<name>scala-nnetworks</name> | ||
<comment></comment> | ||
<projects> | ||
</projects> | ||
<buildSpec> | ||
<buildCommand> | ||
<name>org.scala-ide.sdt.core.scalabuilder</name> | ||
<arguments> | ||
</arguments> | ||
</buildCommand> | ||
</buildSpec> | ||
<natures> | ||
<nature>org.scala-ide.sdt.core.scalanature</nature> | ||
<nature>org.eclipse.jdt.core.javanature</nature> | ||
</natures> | ||
</projectDescription> |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+2.91 MB
lib_managed/jars/org.scalatest/scalatest_2.9.2/scalatest_2.9.2-1.8.jar
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package nnetworks | ||
|
||
object Functions { | ||
def roundedSigmoid(v: Double): Double = math.round(1.0 / (1.0 + math.exp(-v))) | ||
} | ||
|
||
/** | ||
* layer is a two dimensional array of doubles - weights | ||
*/ | ||
class Layer(weights: Array[Array[Double]]) { | ||
|
||
/** | ||
* Fastest implementation based on http://blog.scala4java.com/2011/12/matrix-multiplication-in-scala-single.html | ||
*/ | ||
def multiThreadedIdiomatic(m1: Array[Double], m2: Array[Array[Double]]) = { | ||
val res = Array.fill[Double](m2(0).length)(0.0) | ||
for ( | ||
col <- (0 until m2(0).length).par; | ||
i <- 0 until m1.par.length | ||
) { | ||
res(col) += m1(i) * m2(i)(col) | ||
} | ||
res | ||
} | ||
|
||
def applyInputs(input: Array[Double], fun: Double => Double): Array[Double] = { | ||
multiThreadedIdiomatic(input, weights) map (fun(_)) | ||
} | ||
|
||
def applyInputs(input: Array[Double]): Array[Double] = { | ||
multiThreadedIdiomatic(input, weights) | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package nnetworks | ||
|
||
import org.junit.runner.RunWith | ||
import org.scalatest.junit.JUnitRunner | ||
import org.scalatest.FunSuite | ||
|
||
/** | ||
* This class is a test suite for the methods in object FunSets. To run | ||
* the test suite, you can either: | ||
* - run the "test" command in the SBT console | ||
* - right-click the file in eclipse and chose "Run As" - "JUnit Test" | ||
*/ | ||
@RunWith(classOf[JUnitRunner]) | ||
class NetworkTestSuite extends FunSuite { | ||
|
||
test("Test AND network implemented") { | ||
val weights: Array[Array[Double]] = Array(Array(-30.0), Array(20), Array(20)) | ||
val layer = new Layer(weights) | ||
|
||
assert(layer.applyInputs(Array(1., 0., 0.)) === Array(-30.)) | ||
assert(layer.applyInputs(Array(1., 0., 1.)) === Array(-10.)) | ||
assert(layer.applyInputs(Array(1., 1., 0.)) === Array(-10.)) | ||
assert(layer.applyInputs(Array(1., 1., 1.)) === Array(10.)) | ||
|
||
} | ||
|
||
|
||
test("Test AND network implemented with activation function") { | ||
val weights: Array[Array[Double]] = Array(Array(-30.0), Array(20), Array(20)) | ||
val layer = new Layer(weights) | ||
|
||
assert(layer.applyInputs(Array(1., 0., 0.),Functions.roundedSigmoid ) === Array(0)) | ||
assert(layer.applyInputs(Array(1., 0., 1.), Functions.roundedSigmoid) === Array(0)) | ||
assert(layer.applyInputs(Array(1., 1., 0.), Functions.roundedSigmoid) === Array(0)) | ||
assert(layer.applyInputs(Array(1., 1., 1.), Functions.roundedSigmoid) === Array(1)) | ||
|
||
} | ||
|
||
test("Test XOR network implemented") { | ||
val weights: Array[Array[Double]] = Array(Array(-30.0, 10.), Array(20, -20.), Array(20, -10.)) | ||
val layer = new Layer(weights) | ||
|
||
assert(layer.applyInputs(Array(1., 0., 0.), Functions.roundedSigmoid) === Array(0, 1)) | ||
assert(layer.applyInputs(Array(1., 0., 1.), Functions.roundedSigmoid) === Array(0, 1)) | ||
assert(layer.applyInputs(Array(1., 1., 0.),Functions.roundedSigmoid) === Array(0, 0)) | ||
assert(layer.applyInputs(Array(1., 1., 1.), Functions.roundedSigmoid) === Array(1, 0)) | ||
|
||
} | ||
|
||
test("Test sigmoid funciton"){ | ||
assert(Functions.roundedSigmoid(-4) === 0) | ||
assert(Functions.roundedSigmoid(-1) === 0) | ||
assert(Functions.roundedSigmoid(0) === 1) | ||
assert(Functions.roundedSigmoid(1) === 1) | ||
assert(Functions.roundedSigmoid(4) === 1) | ||
} | ||
|
||
} |