Skip to content

Commit

Permalink
Add Spark VAE reconstruction error (for VAE + loss function models)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Jan 16, 2017
1 parent e684335 commit c12e989
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 98 deletions.
@@ -1,53 +1,18 @@
/*
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/

package org.deeplearning4j.spark.impl.common.score;

import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
* Function to calculate the reconstruction probability for a variational autoencoder.<br>
* Function to calculate the scores (reconstruction probability or log probability) for a variational autoencoder.<br>
* Note that scoring is batched for computational efficiency.<br>
*
* @param <K> Type of key, associated with each example. Used to keep track of which score belongs to which example
* @author Alex Black
*/
public abstract class BaseVaeReconstructionProbWithKeyFunction<K> implements PairFlatMapFunction<Iterator<Tuple2<K, INDArray>>, K, Double> {
public abstract class BaseVaeReconstructionProbWithKeyFunction<K> extends BaseVaeScoreWithKeyFunction<K> {

protected static Logger log = LoggerFactory.getLogger(BaseVaeReconstructionProbWithKeyFunction.class);

protected final Broadcast<INDArray> params;
protected final Broadcast<String> jsonConfig;
private final int batchSize;
private final boolean useLogProbability;
private final int numSamples;

Expand All @@ -59,69 +24,18 @@ public abstract class BaseVaeReconstructionProbWithKeyFunction<K> implements Pai
* @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)}
*/
public BaseVaeReconstructionProbWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, boolean useLogProbability,
int batchSize, int numSamples) {
this.params = params;
this.jsonConfig = jsonConfig;
int batchSize, int numSamples){
super(params, jsonConfig, batchSize);
this.useLogProbability = useLogProbability;
this.batchSize = batchSize;
this.numSamples = numSamples;
}

public abstract VariationalAutoencoder getVaeLayer();


@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, INDArray>> iterator) throws Exception {
if (!iterator.hasNext()) {
return Collections.emptyList();
}

VariationalAutoencoder vae = getVaeLayer();

List<Tuple2<K, Double>> ret = new ArrayList<>();

List<INDArray> collect = new ArrayList<>(batchSize);
List<K> collectKey = new ArrayList<>(batchSize);
int totalCount = 0;
while (iterator.hasNext()) {
collect.clear();
collectKey.clear();
int nExamples = 0;
while (iterator.hasNext() && nExamples < batchSize) {
Tuple2<K, INDArray> t2 = iterator.next();
INDArray features = t2._2();
int n = features.size(0);
if (n != 1) throw new IllegalStateException("Cannot score examples with one key per data set if "
+ "data set contains more than 1 example (numExamples: " + n + ")");
collect.add(features);
collectKey.add(t2._1());
nExamples += n;
}
totalCount += nExamples;

INDArray toScore = Nd4j.vstack(collect);

INDArray scores;
if(useLogProbability){
scores = vae.reconstructionLogProbability(toScore, numSamples);
} else {
scores = vae.reconstructionProbability(toScore, numSamples);
}

double[] doubleScores = scores.data().asDouble();

for (int i = 0; i < doubleScores.length; i++) {
ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
}
public INDArray computeScore(VariationalAutoencoder vae, INDArray toScore) {
if(useLogProbability){
return vae.reconstructionLogProbability(toScore, numSamples);
} else {
return vae.reconstructionProbability(toScore, numSamples);
}

if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();

if (log.isDebugEnabled()) {
log.debug("Scored {} examples ", totalCount);
}

return ret;
}
}
@@ -0,0 +1,116 @@
/*
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/

package org.deeplearning4j.spark.impl.common.score;

import lombok.extern.slf4j.Slf4j;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * Base function to calculate per-example scores (reconstruction probability, reconstruction error) for a
 * variational autoencoder.<br>
 * Examples are pulled from the partition iterator, stacked into minibatches of at most {@code batchSize} rows,
 * scored via {@link #computeScore(VariationalAutoencoder, INDArray)}, and each score is paired back up with the
 * key of the example it was computed from.<br>
 * Note that scoring is batched for computational efficiency.<br>
 *
 * @param <K> Type of key, associated with each example. Used to keep track of which score belongs to which example
 * @author Alex Black
 */
@Slf4j
public abstract class BaseVaeScoreWithKeyFunction<K> implements PairFlatMapFunction<Iterator<Tuple2<K, INDArray>>, K, Double> {

    protected final Broadcast<INDArray> params;
    protected final Broadcast<String> jsonConfig;
    private final int batchSize;


    /**
     * @param params     Network parameters (broadcast)
     * @param jsonConfig Network configuration, as json (broadcast)
     * @param batchSize  Batch size to use when scoring. Must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public BaseVaeScoreWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        // Fail fast here: a non-positive batch size means the batching loop in call() can never collect an example
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Invalid batch size: must be > 0 (got: " + batchSize + ")");
        }
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    /**
     * @return The variational autoencoder layer that should be used for scoring. Called once per invocation of
     * {@link #call(Iterator)} (only when there is at least one example to score)
     */
    public abstract VariationalAutoencoder getVaeLayer();

    /**
     * Compute the scores for a minibatch of examples.
     *
     * @param vae     Layer to score with, as returned by {@link #getVaeLayer()}
     * @param toScore Examples to score, stacked vertically (one example per row)
     * @return Scores; expected to contain exactly one value per row of {@code toScore}, in row order
     */
    public abstract INDArray computeScore(VariationalAutoencoder vae, INDArray toScore);


    /**
     * Score every example provided by the iterator.
     *
     * @param iterator Pairs of (key, features). Each features array must contain exactly one example
     * @return One (key, score) pair per input example; empty if the iterator has no elements
     * @throws IllegalStateException if a features array does not have exactly 1 row, or if the number of scores
     *                               produced for a batch differs from the number of examples in that batch
     */
    @Override
    public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, INDArray>> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyList();
        }

        VariationalAutoencoder vae = getVaeLayer();

        List<Tuple2<K, Double>> ret = new ArrayList<>();

        List<INDArray> collect = new ArrayList<>(batchSize);
        List<K> collectKey = new ArrayList<>(batchSize);
        int totalCount = 0;
        while (iterator.hasNext()) {
            collect.clear();
            collectKey.clear();
            int nExamples = 0;
            while (iterator.hasNext() && nExamples < batchSize) {
                Tuple2<K, INDArray> t2 = iterator.next();
                INDArray features = t2._2();
                int n = features.size(0);
                if (n != 1) {
                    throw new IllegalStateException("Cannot score examples with one key per data set if "
                            + "data set contains more than 1 example (numExamples: " + n + ")");
                }
                collect.add(features);
                collectKey.add(t2._1());
                nExamples += n;
            }
            totalCount += nExamples;

            INDArray toScore = Nd4j.vstack(collect);
            INDArray scores = computeScore(vae, toScore);

            double[] doubleScores = scores.data().asDouble();

            // Scores are matched to keys by position, so the counts must agree exactly: too many scores would
            // index past the end of the key list, too few would silently drop examples from the output
            if (doubleScores.length != collectKey.size()) {
                throw new IllegalStateException("Number of scores (" + doubleScores.length
                        + ") does not match number of examples scored (" + collectKey.size() + ")");
            }

            for (int i = 0; i < doubleScores.length; i++) {
                ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
            }
        }

        // When running on a grid executioner, block until its queue has been flushed before returning
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }

        log.debug("Scored {} examples ", totalCount);

        return ret;
    }
}
@@ -0,0 +1,54 @@
package org.deeplearning4j.spark.impl.graph.scoring;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
* Function to calculate the reconstruction error for a variational autoencoder, that is the first layer in a
* ComputationGraph.<br>
* Note that the VAE must be using a loss function, not a {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution}<br>
* Also note that scoring is batched for computational efficiency.<br>
*
* @author Alex Black
* @see CGVaeReconstructionProbWithKeyFunction
*/
public class CGVaeReconstructionErrorWithKeyFunction<K> extends BaseVaeScoreWithKeyFunction<K> {


/**
* @param params MultiLayerNetwork parameters
* @param jsonConfig MultiLayerConfiguration, as json
* @param batchSize Batch size to use when scoring
*/
public CGVaeReconstructionErrorWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
super(params, jsonConfig, batchSize);
}

@Override
public VariationalAutoencoder getVaeLayer() {
ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String)jsonConfig.getValue()));
network.init();
INDArray val = ((INDArray)params.value()).unsafeDuplication();
if (val.length() != network.numParams(false))
throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
network.setParams(val);

Layer l = network.getLayer(0);
if (!(l instanceof VariationalAutoencoder)) {
throw new RuntimeException("Cannot use CGVaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE "
+ "layer as layer 0. Layer type: " + l.getClass());
}
return (VariationalAutoencoder)l;
}

@Override
public INDArray computeScore(VariationalAutoencoder vae, INDArray toScore) {
return vae.reconstructionError(toScore);
}
}
Expand Up @@ -3,11 +3,10 @@
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
Expand Down
@@ -0,0 +1,53 @@
package org.deeplearning4j.spark.impl.multilayer.scoring;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
* Function to calculate the reconstruction error for a variational autoencoder, that is the first layer in a
* MultiLayerNetwork.<br>
* Note that the VAE must be using a loss function, not a {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution}<br>
* Also note that scoring is batched for computational efficiency.<br>
*
* @author Alex Black
* @see VaeReconstructionProbWithKeyFunction
*/
public class VaeReconstructionErrorWithKeyFunction<K> extends BaseVaeScoreWithKeyFunction<K> {

/**
* @param params MultiLayerNetwork parameters
* @param jsonConfig MultiLayerConfiguration, as json
* @param batchSize Batch size to use when scoring
*/
public VaeReconstructionErrorWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
super(params, jsonConfig, batchSize);
}

@Override
public VariationalAutoencoder getVaeLayer() {
MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)jsonConfig.getValue()));
network.init();
INDArray val = ((INDArray)params.value()).unsafeDuplication();
if (val.length() != network.numParams(false))
throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
network.setParameters(val);

Layer l = network.getLayer(0);
if (!(l instanceof VariationalAutoencoder)) {
throw new RuntimeException("Cannot use VaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE "
+ "layer as layer 0. Layer type: " + l.getClass());
}
return (VariationalAutoencoder)l;
}

@Override
public INDArray computeScore(VariationalAutoencoder vae, INDArray toScore) {
return vae.reconstructionError(toScore);
}
}
Expand Up @@ -6,6 +6,7 @@
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
Expand Down

0 comments on commit c12e989

Please sign in to comment.