Skip to content

Commit

Permalink
tftraininghelper evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
hkvision committed Aug 16, 2019
1 parent 6c2ea24 commit 08befe2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
10 changes: 9 additions & 1 deletion pyzoo/zoo/pipeline/api/net/tf_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from bigdl.nn.criterion import Criterion
from bigdl.nn.layer import Layer
from bigdl.util.common import to_list, JavaValue
from bigdl.util.common import to_list, JavaValue, callBigDlFunc
from bigdl.optim.optimizer import EveryEpoch, MaxEpoch, SeveralIteration
from zoo.pipeline.api.keras.engine.topology import to_bigdl_metric
from zoo.pipeline.api.keras.optimizers import DistriOptimizer
Expand Down Expand Up @@ -54,6 +54,12 @@ def __init__(self, path, configProto):
byte_arr = None
super(TFTrainingHelper, self).__init__(None, "float", path, byte_arr)

def evaluate(self, dataset, batch_size, val_methods):
return callBigDlFunc(self.bigdl_type,
"tfEvaluate",
self.value,
dataset, batch_size, val_methods)


class TFOptimizer:
def __init__(self, loss, optim_method, sess=None, dataset=None, inputs=None,
Expand Down Expand Up @@ -182,10 +188,12 @@ def to_floats(vs):

if val_outputs is not None and val_labels is not None:
val_rdd = self.dataset.get_validation_data()
self.val_rdd = val_rdd
if val_rdd is not None:
val_method = [TFValidationMethod(m, len(val_outputs), len(val_labels))
for m in to_list(val_method)]
training_rdd = sample_rdd
self.val_method = val_method

elif val_split != 0.0:
training_rdd, val_rdd = sample_rdd.randomSplit([1 - val_split, val_split])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1414,4 +1414,23 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ
def createEpochStep(stepSize: Int, gamma: Double): SGD.EpochStep = {
SGD.EpochStep(stepSize, gamma)
}


def tfEvaluate(model: AbstractModule[Activity, Activity, T],
valRDD: JavaRDD[Sample],
batchSize: Int,
valMethods: JList[ValidationMethod[T]])
: JList[EvaluatedResult] = {
val sampleRDD = toJSample(valRDD)
val featureSize = sampleRDD.first().numFeature()
val dataSet = batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize)
val rdd = dataSet.toDistributed().data(train = false)
val resultArray = model.evaluate(rdd,
valMethods.asScala.toArray)
val testResultArray = resultArray.map { result =>
EvaluatedResult(result._1.result()._1, result._1.result()._2,
result._2.toString())
}
testResultArray.toList.asJava
}
}

0 comments on commit 08befe2

Please sign in to comment.