-
Notifications
You must be signed in to change notification settings - Fork 95
/
LinearRegression.scala
73 lines (63 loc) · 2.6 KB
/
LinearRegression.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.platanios.tensorflow.examples
import org.platanios.tensorflow.api._
import com.typesafe.scalalogging.Logger
import org.slf4j.LoggerFactory
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
/** Minimal linear regression example.
  *
  * Generates synthetic data `y = weight * x`, where both `x` and the hidden `weight` are drawn with
  * `Random.nextFloat()`. It then fits a single `1 x 1` weight matrix (no bias term) by minimizing the
  * summed squared error with AdaGrad, logging the loss and current weight as training progresses.
  *
  * @author Emmanouil Antonios Platanios
  */
object LinearRegression {
  private val logger = Logger(LoggerFactory.getLogger("Examples / Linear Regression"))
  private val random = new Random()

  /** Hidden weight the model is trying to recover; drawn once when this object is initialized. */
  private val weight = random.nextFloat()

  /** Number of training iterations (iteration indices `0` through `50`). */
  private val numIterations = 51

  /** Number of synthetic samples drawn for every training step. */
  private val trainBatchSize = 10000

  /** Learning rate passed to the AdaGrad optimizer. */
  private val learningRate = 1.0f

  /** Progress is logged every `logFrequency` iterations; `1` logs every iteration. */
  private val logFrequency = 1

  def main(args: Array[String]): Unit = {
    logger.info("Building linear regression model.")
    val inputs = tf.placeholder[Float](Shape(-1, 1))
    val outputs = tf.placeholder[Float](Shape(-1, 1))
    val weights = tf.variable[Float]("weights", Shape(1, 1), tf.ZerosInitializer)
    val predictions = tf.matmul(inputs, weights)
    val loss = tf.sum(tf.square(tf.subtract(predictions, outputs)))
    val trainOp = tf.train.AdaGrad(learningRate).minimize(loss)

    logger.info("Training the linear regression model.")
    val session = Session()
    // The session holds native resources, so release it even if training throws.
    try {
      session.run(targets = tf.globalVariablesInitializer())
      for (i <- 0 until numIterations) {
        val (batchInputs, batchOutputs) = batch(trainBatchSize)
        val feeds = Map(inputs -> batchInputs, outputs -> batchOutputs)
        val trainLoss = session.run(feeds = feeds, fetches = loss, targets = trainOp)
        if (i % logFrequency == 0)
          logger.info(s"Train loss at iteration $i = ${trainLoss.scalar} " +
              s"(weight = ${session.run(fetches = weights.value).scalar})")
      }
      logger.info(s"Trained weight value: ${session.run(fetches = weights.value).scalar}")
      logger.info(s"True weight value: $weight")
    } finally {
      session.close()
    }
  }

  /** Draws a synthetic batch of `(x, weight * x)` pairs, with each `x` taken from `random.nextFloat()`.
    *
    * @param batchSize Number of samples to draw; a non-positive value yields empty tensors.
    * @return Tuple of (inputs, targets), each reshaped to `Shape(-1, 1)` (one column, `batchSize` rows).
    */
  def batch(batchSize: Int): (Tensor[Float], Tensor[Float]) = {
    val inputs = Seq.fill(batchSize)(random.nextFloat())
    val outputs = inputs.map(weight * _)
    (Tensor[Float](inputs).reshape(Shape(-1, 1)),
        Tensor[Float](outputs).reshape(Shape(-1, 1)))
  }
}