Skip to content

Commit

Permalink
force modules to be products so that their fields can be traversed
Browse files Browse the repository at this point in the history
  • Loading branch information
ctongfei committed Jun 10, 2019
1 parent 1885f1a commit ce4dc08
Show file tree
Hide file tree
Showing 24 changed files with 184 additions and 110 deletions.
5 changes: 5 additions & 0 deletions diff/src/main/scala/nexus/diff/AnyOp.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package nexus.diff

import shapeless.Nat

/**
* Basic trait for all operators, regardless of its arity.
*
* @tparam Y Output type
* @author Tongfei Chen
*/
trait AnyOp[Y] {

type Arity <: Nat

/** The arity of this operator. */
def arity: Int

Expand Down
8 changes: 4 additions & 4 deletions diff/src/main/scala/nexus/diff/Assignment.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package nexus.diff
import shapeless._

/**
* Represents an assignment to a symbolic expression, which takes the form `D[X] := X`, where `D` is a differentiable box.
* Represents an assignment to a symbolic expression, which takes the form `F[X] := X`, where `F` is a computation box.
* @author Tongfei Chen
* @since 0.1.0
*/
trait Assignment[D[_]] extends BoxValuePair[D, Id] {
trait Assignment[F[_]] extends BoxValuePair[F, Id] {

override def toString = s"$expr := $value"

Expand All @@ -16,14 +16,14 @@ trait Assignment[D[_]] extends BoxValuePair[D, Id] {
object Assignment {

/** Creates an assignment for an symbolic expression. */
def apply[D[_], X](x: D[X], v: X): Assignment[D] = new Assignment[D] {
def apply[F[_], X](x: F[X], v: X): Assignment[F] = new Assignment[F] {
type Data = X
val expr = x
val value = v
}

// Use of dependent types in the return type
def unapply[D[_]](a: Assignment[D]): Option[(D[a.Data], a.Data)] =
def unapply[F[_]](a: Assignment[F]): Option[(F[a.Data], a.Data)] =
Some((a.expr, a.value))

}
1 change: 1 addition & 0 deletions diff/src/main/scala/nexus/diff/Batch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ trait BatchTensor[T[_], E, U] extends Batch[T[U]] {
implicit def T: IsTensorK[T, E]

def batchSize = T.sizeOfDim(underlying, 0)

override def apply(i: Int) = {
val v: T[Uh] = T.sliceAlong(underlying, BatchDim, i)
v.asInstanceOf[T[U]] // this cast is safe
Expand Down
24 changes: 0 additions & 24 deletions diff/src/main/scala/nexus/diff/CompactTensorSequence.scala

This file was deleted.

11 changes: 10 additions & 1 deletion diff/src/main/scala/nexus/diff/Func.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@ package nexus.diff

import nexus._
import cats.arrow._
import shapeless.{HList, HNil}

trait Func0[Y] {
trait HFunc[Xs <: HList, Y] {
def apply[F[_]: Algebra](xs: F[Xs]): F[Y]
}

trait Func0[Y] { self =>
def apply[F[_]: Algebra](): F[Y]
def asHFunc: HFunc[HNil, Y] = new HFunc[HNil, Y] {
def apply[F[_] : Algebra](xs: F[HNil]) = self.apply()
}
}

trait Func1[X, Y] extends (X => Y) {
Expand All @@ -25,6 +33,7 @@ trait Func3[X1, X2, X3, Y] extends ((X1, X2, X3) => Y) {

object Func1 {

/** `Func1` forms a category. */
implicit object Category extends Category[Func1] {
def id[A] = new Func1[A, A] {
def apply[F[_]: Algebra](x: F[A]) = x
Expand Down
9 changes: 9 additions & 0 deletions diff/src/main/scala/nexus/diff/HasParameters.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package nexus.diff

import nexus.diff.io.Serializer

/**
* Base trait for anything that can be serialized (as a map of keys to parameters).
* These include modules and optimizers.
*
* @author Tongfei Chen
* @since 0.1.0
*/
Expand All @@ -13,4 +16,10 @@ trait HasParameters {

def loadFromParameterMap(m: Map[String, Param[_]]): Unit

def save(filename: String, format: Serializer): Unit =
format.save(parameterMap, filename)

def load(filename: String, format: Serializer): Unit =
this.loadFromParameterMap(format.load(filename))

}
51 changes: 32 additions & 19 deletions diff/src/main/scala/nexus/diff/Module.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
package nexus.diff


trait AnyModule extends HasParameters {

def parameters: Set[Param[_]]

def parameterMap =
import shapeless.{HList, Nat}

/**
* Base trait for all modules.
* Modules serve as basic building blocks of neural network that may contain parameters.
*
* A module should always be implemented as an instance of the trait `Product`, preferably as a '''case class''':
* This enables recursive traversal over its components to get all the parameters.
*/
trait AnyModule extends HasParameters { self: Product =>

/**
* Returns the set of all parameters in this module.
*/
def parameters: Set[Param[_]] = this.productIterator.flatMap {
case p: Param[_] => Set(p)
case p: AnyModule => p.parameters
case _ => Set()
}.toSet

def parameterMap: Map[String, Param[_]] =
parameters.view.map(p => p.name -> p).toMap

def loadFromParameterMap(m: Map[String, Param[_]]): Unit =
Expand All @@ -14,23 +29,21 @@ trait AnyModule extends HasParameters {

}

trait Module0[Y] extends Func0[Y] with AnyModule
trait Module0[Y] extends Func0[Y] with AnyModule { self: Product => }

trait Module1[X, Y] extends Func1[X, Y] with AnyModule { self =>

def >>[Z](that: Module1[Y, Z]): Module1[X, Z] = new Module1[X, Z] {

def parameters = self.parameters union (that match {
case that: Module1[Y, Z] => that.parameters
case _ => Set()
})
def apply[F[_]: Algebra](x: F[X]): F[Z] = that(self(x))
}
trait Module1[X, Y] extends Func1[X, Y] with AnyModule { self: Product =>

def >>[Z](that: Module1[Y, Z]): Module1[X, Z] = Module1.Composed(this, that)
def >>[Z](that: PolyModule1)(implicit p: that.P[Y, Z]): Module1[X, Z] = self >> that.ground(p)

}

trait Module2[X1, X2, Y] extends Func2[X1, X2, Y] with AnyModule
object Module1 {
case class Composed[X, Y, Z](f: Module1[X, Y], g: Module1[Y, Z]) extends Module1[X, Z] {
def apply[F[_]: Algebra](x: F[X]): F[Z] = g(f(x))
}
}

trait Module2[X1, X2, Y] extends Func2[X1, X2, Y] with AnyModule { self: Product => }

trait Module3[X1, X2, X3, Y] extends Func3[X1, X2, X3, Y] with AnyModule
trait Module3[X1, X2, X3, Y] extends Func3[X1, X2, X3, Y] with AnyModule { self: Product => }
6 changes: 6 additions & 0 deletions diff/src/main/scala/nexus/diff/Op.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package nexus.diff

import nexus.diff.exception._
import shapeless.HList
import shapeless.Nat._

trait Op0[Y] extends Func0[Y] with AnyOp[Y] {
type Arity = _0
final def arity = 0
def forward(): Y
def apply[F[_]]()(implicit F: Algebra[F]): F[Y] = F.app0(this)
Expand All @@ -15,6 +18,7 @@ trait Op0[Y] extends Func0[Y] with AnyOp[Y] {
*/
trait Op1[X, Y] extends Func1[X, Y] with AnyOp[Y] { op =>

type Arity = _1
final def arity = 1

/** Applies this operation to a symbolic expression. */
Expand Down Expand Up @@ -52,6 +56,7 @@ object Op1 {
*/
trait Op2[X1, X2, Y] extends Func2[X1, X2, Y] with AnyOp[Y] {

type Arity = _2
final def arity = 2

/** Applies this operation to two symbolic expressions. */
Expand Down Expand Up @@ -109,6 +114,7 @@ object Op2 {
*/
trait Op3[X1, X2, X3, Y] extends Func3[X1, X2, X3, Y] with AnyOp[Y] {

type Arity = _3
final def arity = 3

def forward(x1: X1, x2: X2, x3: X3): Y
Expand Down
6 changes: 3 additions & 3 deletions diff/src/main/scala/nexus/diff/PolyModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package nexus.diff
*/
trait PolyModule1 extends PolyFunc1 {

trait P[X, Y] extends Module1[X, Y]
trait P[X, Y] extends Module1[X, Y] with Product

def ground[X, Y](implicit p: P[X, Y]) = p
}
Expand All @@ -17,7 +17,7 @@ trait PolyModule1 extends PolyFunc1 {
*/
trait PolyModule2 extends PolyFunc2 {

trait P[X1, X2, Y] extends Module2[X1, X2, Y]
trait P[X1, X2, Y] extends Module2[X1, X2, Y] with Product

def ground[X1, X2, Y](implicit p: P[X1, X2, Y]) = p

Expand All @@ -30,7 +30,7 @@ trait PolyModule2 extends PolyFunc2 {
*/
trait PolyModule3 extends PolyFunc3 {

trait P[X1, X2, X3, Y] extends Module3[X1, X2, X3, Y]
trait P[X1, X2, X3, Y] extends Module3[X1, X2, X3, Y] with Product

def ground[X1, X2, X3, Y](implicit p: P[X1, X2, X3, Y]) = p

Expand Down
19 changes: 9 additions & 10 deletions diff/src/main/scala/nexus/diff/Unroll.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@ import cats._
/**
* Typeclass witnessing a structure `S[_]` can be unrolled under the computation box `F[_]`.
* For example, a traced sequence `Traced[Seq[A]]` can be unrolled to `Seq[Traced[A]]`.
*
* This is a more specific abstraction than [[cats.Traverse]].
* @author Tongfei Chen
*/
trait Unroll[S[_], F[_]] {
trait Unroll[F[_], S[_]] {

def unroll[X](ds: F[S[X]]): S[F[X]]
def traverse[A, B](fa: F[A])(f: A *=> S[B]): S[F[B]]

def unroll[A](fsa: F[S[A]]): S[F[A]]

}

object Unroll {

implicit def UnrollAnyId[S[_]]: Unroll[S, Id] = new Unroll[S, Id] {
def unroll[X](fsx: S[X]): S[X] = fsx
}

implicit object UnrollIndexedSeqTraced extends Unroll[IndexedSeq, Traced] {
def unroll[X](fsx: Traced[IndexedSeq[X]]) = {
IndexedSeq.tabulate(fsx.value.length)(i => ???)
}
implicit def unrollAnyId[S[_]]: Unroll[Id, S] = new Unroll[Id, S] {
def traverse[A, B](a: A)(f: A *=> S[B]) = f(a)
def unroll[A](sa: S[A]) = sa
}

}
1 change: 1 addition & 0 deletions diff/src/main/scala/nexus/diff/execution/Forward.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ trait Forward[F[_]] extends (F ~> Id) {
def values: BoxMap[F, Id]

def backward: Backward[F]
// TODO: def backward[G[_]](implicit G: Backward[G]): Backward[F, G]

}
4 changes: 2 additions & 2 deletions diff/src/main/scala/nexus/diff/io/Serializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import nexus.diff._
*/
trait Serializer {

def save(module: HasParameters, f: String): Unit
def load(module: HasParameters, f: String): Unit
def save(map: Map[String, Param[_]], f: String): Unit
def load(f: String): Map[String, Param[_]]

}
8 changes: 3 additions & 5 deletions diff/src/main/scala/nexus/diff/modules/Affine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@ import nexus.diff.util._
* @author Tongfei Chen
* @since 0.1.0
*/
class Affine[T[_], R, X <: Dim, Y <: Dim] private(
val W: Param[T[(Y, X)]],
val b: Param[T[Y]]
case class Affine[T[_], R, X <: Dim, Y <: Dim] private(
W: Param[T[(Y, X)]],
b: Param[T[Y]]
)
(implicit T: IsRealTensorK[T, R])
extends Module1[T[X], T[Y]]
{

def parameters = Set(W, b)

/** The linear transformation matrix of this layer. */
def weight = W

Expand Down
11 changes: 4 additions & 7 deletions diff/src/main/scala/nexus/diff/modules/Convolution1D.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@ import nexus.diff.util._
* @author Tongfei Chen
* @since 0.1.0
*/
class Convolution1D[T[_], R, W <: Dim, X <: Dim, Y <: Dim] private(
val window: Int,
val stride: Int,
val kernel: Param[T[(Y, X)]]
case class Convolution1D[T[_], R, W <: Dim, X <: Dim, Y <: Dim] private(
window: Int,
stride: Int,
kernel: Param[T[(Y, X)]]
)
(implicit T: IsRealTensorK[T, R])
extends Module1[T[(W, X)], T[(W, Y)]]
{

def parameters = Set(kernel)

def apply[F[_] : Algebra](x: F[T[(W, X)]]) = ???
}

Expand Down
11 changes: 6 additions & 5 deletions diff/src/main/scala/nexus/diff/modules/CosineSimilarity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import nexus.diff.syntax._
* @author Tongfei Chen
*/
object CosineSimilarity extends PolyModule2 {
implicit def cosineSimilarityF[T[_], R, I](implicit T: IsRealTensorK[T, R]): P[T[I], T[I], R] =
new P[T[I], T[I], R] {
def apply[F[_] : Algebra](x1: F[T[I]], x2: F[T[I]]) = Div(Dot(x1, x2), Mul(L2Norm(x1), L2Norm(x2)))
def parameters = Set()
}

case class CosineSimilarityF[T[_], R, I]()(implicit T: IsRealTensorK[T, R]) extends P[T[I], T[I], R] {
def apply[F[_] : Algebra](x1: F[T[I]], x2: F[T[I]]) = Div(Dot(x1, x2), Mul(L2Norm(x1), L2Norm(x2)))
}

implicit def cosineSimilarityF[T[_], R, I](implicit T: IsRealTensorK[T, R]): P[T[I], T[I], R] = CosineSimilarityF()
}
7 changes: 2 additions & 5 deletions diff/src/main/scala/nexus/diff/modules/Linear.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@ import nexus.diff.util._
* @author Tongfei Chen
* @since 0.1.0
*/
class Linear[T[_], R, X <: Dim, Y <: Dim] private(
val weight: Param[T[(Y, X)]]
)(implicit T: IsRealTensorK[T, R])
case class Linear[T[_], R, X <: Dim, Y <: Dim] private(weight: Param[T[(Y, X)]])(implicit T: IsRealTensorK[T, R])
extends Module1[T[X], T[Y]]
{

type Input = X

type Output = Y

def parameters = Set(weight)

def apply[F[_] : Algebra](x: F[T[X]]) = MVMul(weight.as, x)
}

Expand All @@ -47,4 +43,5 @@ object Linear {
from[T, R, X, Y](weight)
}


}
Loading

0 comments on commit ce4dc08

Please sign in to comment.