Does this really have auto diff? #1

Closed
doofin opened this issue Dec 29, 2018 · 6 comments
doofin commented Dec 29, 2018

I saw that in the following code:

//============================================
// Loss functions
case class SoftmaxLoss(actual: Variable, target: Variable) extends Function {
  val x: Tensor = actual.data
  val y: Tensor = target.data.T

  val shiftedLogits: Tensor = x - ns.max(x, axis = 1)
  val z: Tensor = ns.sum(ns.exp(shiftedLogits), axis = 1)
  val logProbs: Tensor = shiftedLogits - ns.log(z)
  val n: Int = x.shape.head
  val loss: Double = -ns.sum(logProbs(ns.arange(n), y)) / n

  override def forward(): Variable = Variable(Tensor(loss), Some(this))

  override def backward(gradOutput: Variable /* not used */ ): Unit = {
    val dx = ns.exp(logProbs)
    dx(ns.arange(n), y) -= 1
    dx /= n

    actual.backward(Variable(dx))
  }
}

the backward pass is implemented manually, when it should have been derived automatically.

koen-dejonghe (Member) commented Dec 29, 2018

There are a couple of reasons for this:

  • not all functions on Tensors are implemented on Variables.
  • the backward pass is usually faster when you write the derivation explicitly

For some functions I have implemented both, because I find autodiff more elegant than explicitly writing it. For example: Dropout.
https://github.com/botkop/scorch/blob/master/src/main/scala/scorch/nn/Dropout.scala
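
For contrast with the explicit SoftmaxLoss above, a loss composed entirely from Variable operations gets its backward pass from autodiff. A minimal sketch (mine, not scorch source), assuming only the Variable operations that appear later in this thread (-, ** and scorch.mean):

    // hypothetical mean-squared-error loss built from Variable ops only;
    // no backward method is written: autodiff derives it from the graph
    def mseLoss(pred: Variable, target: Variable): Variable =
      scorch.mean((pred - target) ** 2)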

doofin (Author) commented Dec 29, 2018

I tried a simple linear regression, but the loss printed out never changes. Do I have to implement the backward pass myself?

  def lrTest = {
    val nf1        = 2
    val nf2        = 2
    val numClasses = 2

    val fc1 = Linear(1, 1) // an affine operation: y = Wx + b
    val f: Tensor => Tensor = { t: Tensor =>
      t * 3 + ns.array(1d)
    }

    val optimizer = Adam(Seq(fc1) flatMap (_.parameters), lr = 0.0001)

    (10 to 10) foreach { i =>
      (1 to 10000) foreach { _ =>
        val xx    = ns.array(i)
        val ypred = fc1.forward(Variable(xx))
        val y     = f(xx)
        val loss = Variable(ns.mean(ns.square(ypred.data - y)))
        println(s"los  : ${loss.data}")
        println(s" y : $y ,, y pred : ${ypred.data}")
        optimizer.zeroGrad()
        loss.backward()
        optimizer.step()
      }
    }

  }

koen-dejonghe (Member) commented Dec 29, 2018

This is more or less what you want, I think. In your version the loss is built from raw Tensors (ypred.data), so it is a fresh Variable that is not connected to the graph, and backward() has nothing to propagate into fc1. Build the loss from Variable operations instead:

    val fc1 = Linear(1, 1) // an affine operation: y = Wx + b
    val f: Tensor => Tensor = { t: Tensor =>
      t * 3 + 1.0
    }

    val optimizer = SGD(fc1.parameters, lr = 0.01)

    (1 to 10) foreach { i =>
      val x = ns.array(i)
      val vx = Variable(x)
      val y = Variable(f(x))

      val ypred = fc1(vx)

      val loss = scorch.mean((ypred - y) ** 2)

      println(s"loss: ${loss.data}, x: $x, y: ${y.data}, ypred: ${ypred.data}")

      optimizer.zeroGrad()
      loss.backward()
      optimizer.step()
    }

koen-dejonghe (Member) commented Dec 29, 2018

or simply:

    val fc1 = Linear(1, 1) // an affine operation: y = Wx + b
    def f(v: Variable): Variable = v * 3 + 1
    val optimizer = SGD(fc1.parameters, lr = 0.01)

    (1 to 10) foreach { i =>
      val x = Variable(i)
      val y = f(x)
      val ypred = fc1(x)
      val loss = scorch.mean((ypred - y) ** 2)
      println(s"loss: ${loss.data}, x: ${x.data}, y: ${y.data}, ypred: ${ypred.data}")
      optimizer.zeroGrad()
      loss.backward()
      optimizer.step()
    }

If you want to backpropagate through a function, then you must use the functions defined in scorch, or write your own. In the latter case, you will have to define the backward method yourself.
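
A minimal sketch of what "write your own" could look like, following the Function skeleton from the SoftmaxLoss example at the top of this issue (Cube and its gradient are purely illustrative, not part of scorch):

    // hypothetical element-wise cube, modelled on the SoftmaxLoss skeleton above
    case class Cube(v: Variable) extends Function {
      val x: Tensor = v.data

      // forward registers this Function as the creator of the output Variable
      override def forward(): Variable = Variable(x * x * x, Some(this))

      // chain rule: d(v^3)/dv = 3 * v^2, scaled by the incoming gradient
      override def backward(gradOutput: Variable): Unit = {
        val dv = x * x * 3 * gradOutput.data
        v.backward(Variable(dv))
      }
    }

    // usage sketch: val y = Cube(x).forward(); y.backward()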

doofin (Author) commented Dec 31, 2018

Thanks, it works!
Is auto diff in this project implemented with a tape-based mechanism?

koen-dejonghe (Member) commented Dec 31, 2018

Implicitly, I think so, yes. The README has a detailed description of how auto diff works. It's all about Variables and Functions.
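
A rough sketch (not scorch's actual source, just the shape implied by the SoftmaxLoss code above, where forward returns Variable(..., Some(this)) and backward calls actual.backward(Variable(dx))): each Variable remembers the Function that produced it, and backward walks that recorded chain in reverse, which is effectively the tape.

    // conceptual sketch only; the real Variable also accumulates gradients
    case class Variable(data: Tensor, gradFn: Option[Function] = None) {
      def backward(gradOutput: Variable): Unit =
        // recurse into the graph recorded during the forward pass
        gradFn.foreach(_.backward(gradOutput))
    }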
