Skip to content

Commit

Permalink
scala padding update
Browse files Browse the repository at this point in the history
  • Loading branch information
hkvision committed Aug 5, 2019
1 parent 77efcfa commit e4729f6
Showing 1 changed file with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1356,11 +1356,13 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ
epsilon, weightDecay)
}

/**
 * Batches a dataset of samples into mini-batches, padding each feature with a fixed value.
 *
 * NOTE(review): this span is a rendered diff without +/- markers, so pre-change and
 * post-change lines appear together; the inline notes mark the lines that look superseded.
 *
 * @param dataset     samples to batch
 * @param batchSize   forwarded to SampleToMiniBatch as `batchSize`
 * @param featureSize length of the padding array: `featureSize - 1` entries of the -1.0
 *                    tensor followed by one 0.0 tensor (callers visible in this diff pass
 *                    `sampleRDD.first().numFeature()`)
 * @return the dataset transformed by SampleToMiniBatch with the feature padding applied
 */
// NOTE(review): appears to be the pre-change signature (removed side of the diff);
// the visible callers pass three arguments, matching the next line.
def batchingWithPaddingStrategy(dataset: DataSet[JSample[T]], batchSize: Int)
def batchingWithPaddingStrategy(dataset: DataSet[JSample[T]], batchSize: Int, featureSize: Int)
: DataSet[MiniBatch[T]] = {
println("Using Feature Padding Strategy")
// Single-element tensor holding -1.0, used as the padding value for most features.
val paddingTensor = Tensor[T](1).fill(ev.fromType(-1.0))
// NOTE(review): appears to be the pre-change line (hard-coded 13 padding tensors);
// the three lines below build the array from featureSize instead.
val featurePaddingParam = PaddingParam[T](Some(Array.fill[Tensor[T]](13)(paddingTensor)))
// Single-element tensor holding 0.0, used only for the last feature
// (presumably the "wide" input, going by the name -- TODO confirm).
val widePaddingTensor = Tensor[T](1).fill(ev.fromType(0.0))
// paddingTensor is a val, so every one of the first featureSize - 1 slots refers to the
// same tensor instance; the 0.0 tensor is appended as the final slot.
val paddingArray = Array.fill[Tensor[T]](featureSize-1)(paddingTensor) ++ Array(widePaddingTensor)
val featurePaddingParam = PaddingParam[T](Some(paddingArray))
// Only feature padding is configured here; no label padding parameter is passed.
dataset.transform(SampleToMiniBatch(
batchSize = batchSize, featurePaddingParam = Some(featurePaddingParam)))
}
Expand All @@ -1372,10 +1374,11 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ
endTrigger: Trigger,
batchSize: Int): Optimizer[T, MiniBatch[T]] = {
val sampleRDD = toJSample(trainingRdd)
val featureSize = sampleRDD.first().numFeature()

val optimizer = new InternalDistriOptimizer(
_model = model,
_dataset = batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize)
_dataset = batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize)
.asInstanceOf[DistributedDataSet[MiniBatch[T]]],
_criterion = criterion
).asInstanceOf[Optimizer[T, MiniBatch[T]]]
Expand All @@ -1402,7 +1405,9 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ
valRdd: JavaRDD[Sample],
vMethods: JList[ValidationMethod[T]]): Unit = {
val sampleRDD = toJSample(valRdd)
optimizer.setValidation(trigger, batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize),
val featureSize = sampleRDD.first().numFeature()
optimizer.setValidation(trigger,
batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize),
vMethods.asScala.toArray)
}

Expand Down

0 comments on commit e4729f6

Please sign in to comment.