Skip to content

Commit

Permalink
Merge pull request #661 from RobAltena/master
Browse files Browse the repository at this point in the history
Fixes #659 cleans up all compiler warnings.
  • Loading branch information
Adam Gibson committed Apr 17, 2018
2 parents ecf2f2e + e5c5956 commit 70975a5
Showing 1 changed file with 22 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -85,8 +84,8 @@ public class NeuralStyleTransfer {
* https://harishnarayanan.org/writing/artistic-style-transfer/
* Other values (5, 100): http://www.chioka.in/tensorflow-implementation-neural-algorithm-of-artistic-style
*/
public static final double ALPHA = 0.025;
public static final double BETA = 5.0;
private static final double ALPHA = 0.025;
private static final double BETA = 5.0;

private static final double LEARNING_RATE = 2;
private static final double NOISE_RATION = 0.1;
Expand All @@ -97,43 +96,36 @@ public class NeuralStyleTransfer {
private static final int SAVE_IMAGE_CHECKPOINT = 5;
private static final String OUTPUT_PATH = "/styletransfer/out/";

public static final int HEIGHT = 224;
public static final int WIDTH = 224;
public static final int CHANNELS = 3;
private static final int HEIGHT = 224;
private static final int WIDTH = 224;
private static final int CHANNELS = 3;
private static final DataNormalization IMAGE_PRE_PROCESSOR = new VGG16ImagePreProcessor();
private static final NativeImageLoader LOADER = new NativeImageLoader(HEIGHT, WIDTH, CHANNELS);


public static void main(String[] args) throws IOException, URISyntaxException {
public static void main(String[] args) throws IOException {
new NeuralStyleTransfer().transferStyle();
}

public void transferStyle() throws IOException, URISyntaxException {
private void transferStyle() throws IOException {

ComputationGraph vgg16FineTune = loadModel();

INDArray content = loadImage(CONTENT_FILE);

INDArray style = loadImage(STYLE_FILE);

INDArray combination = createCombinationImage();

Map<String, INDArray> activationsContentMap = vgg16FineTune.feedForward(content, true);

Map<String, INDArray> activationsStyleMap = vgg16FineTune.feedForward(style, true);
HashMap<String, INDArray> activationsStyleGramMap = buildStyleGramValues(activationsStyleMap);

AdamUpdater adamUpdater = createADAMUpdater();
for (int iteration = 0; iteration < ITERATIONS; iteration++) {
log.info("iteration " + iteration);
Map<String, INDArray> activationsCombMap = vgg16FineTune.feedForward(combination, true);

INDArray styleBackProb = backPropagateStyles(vgg16FineTune, activationsStyleGramMap, activationsCombMap);
INDArray[] input = new INDArray[] { combination };
Map<String, INDArray> activationsCombMap = vgg16FineTune.feedForward(input, true, false);

INDArray styleBackProb = backPropagateStyles(vgg16FineTune, activationsStyleGramMap, activationsCombMap);
INDArray backPropContent = backPropagateContent(vgg16FineTune, activationsContentMap, activationsCombMap);

INDArray backPropAllValues = backPropContent.muli(ALPHA).addi(styleBackProb.muli(BETA));

adamUpdater.applyUpdater(backPropAllValues, iteration, 0);
combination.subi(backPropAllValues);

Expand All @@ -147,7 +139,7 @@ public void transferStyle() throws IOException, URISyntaxException {
}

private INDArray backPropagateStyles(ComputationGraph vgg16FineTune, HashMap<String, INDArray> activationsStyleGramMap, Map<String, INDArray> activationsCombMap) {
INDArray styleBackProb = Nd4j.zeros(new int[]{1, CHANNELS, HEIGHT, WIDTH});
INDArray styleBackProb = Nd4j.zeros(1, CHANNELS, HEIGHT, WIDTH);
for (String styleLayer : STYLE_LAYERS) {
String[] split = styleLayer.split(",");
String styleLayerName = split[0];
Expand Down Expand Up @@ -199,11 +191,8 @@ private INDArray loadImage(String contentFile) throws IOException {
return content;
}

/**
/*
* Since the style activations do not change between iterations, we save computation by calculating the style Gram matrices only once
*
* @param activationsStyle
* @return
*/
private HashMap<String, INDArray> buildStyleGramValues(Map<String, INDArray> activationsStyle) {
HashMap<String, INDArray> styleGramValuesMap = new HashMap<>();
Expand All @@ -227,7 +216,7 @@ private int findLayerIndex(String styleLayerName) {
return index;
}

public double totalLoss(Map<String, INDArray> activationsStyleMap, Map<String, INDArray> activationsCombMap, Map<String, INDArray> activationsContentMap) {
private double totalLoss(Map<String, INDArray> activationsStyleMap, Map<String, INDArray> activationsCombMap, Map<String, INDArray> activationsContentMap) {
Double stylesLoss = allStyleLayersLoss(activationsStyleMap, activationsCombMap);
return ALPHA * contentLoss(activationsCombMap.get(CONTENT_LAYER_NAME).dup(), activationsContentMap.get(CONTENT_LAYER_NAME).dup()) + BETA * stylesLoss;
}
Expand All @@ -253,7 +242,7 @@ private Double allStyleLayersLoss(Map<String, INDArray> activationsStyleMap, Map
* @return Weighted content loss component
*/

public double contentLoss(INDArray combActivations, INDArray contentActivations) {
private double contentLoss(INDArray combActivations, INDArray contentActivations) {
return sumOfSquaredErrors(contentActivations, combActivations) / (4.0 * (CHANNELS) * (WIDTH) * (HEIGHT));
}

Expand All @@ -268,7 +257,7 @@ public double contentLoss(INDArray combActivations, INDArray contentActivations)
* @param combination Activations from intermediate layer of CNN for combination image input
* @return Loss contribution from this comparison
*/
public double styleLoss(INDArray style, INDArray combination) {
private double styleLoss(INDArray style, INDArray combination) {
INDArray s = gramMatrix(style);
INDArray c = gramMatrix(combination);
int[] shape = style.shape();
Expand Down Expand Up @@ -296,7 +285,7 @@ private INDArray backPropagate(ComputationGraph vgg16FineTune, INDArray dLdANext
* @param b Another tensor
* @return Sum of squared errors: scalar
*/
public double sumOfSquaredErrors(INDArray a, INDArray b) {
private double sumOfSquaredErrors(INDArray a, INDArray b) {
INDArray diff = a.sub(b); // difference
INDArray squares = Transforms.pow(diff, 2); // element-wise squaring
return squares.sumNumber().doubleValue();
Expand All @@ -311,7 +300,7 @@ public double sumOfSquaredErrors(INDArray a, INDArray b) {
* @param combActivations Features at same layer from current combo image
* @return Derivatives of content loss w.r.t. combo features
*/
public INDArray derivativeLossContentInLayer(INDArray contentActivations, INDArray combActivations) {
private INDArray derivativeLossContentInLayer(INDArray contentActivations, INDArray combActivations) {

combActivations = combActivations.dup();
contentActivations = contentActivations.dup();
Expand All @@ -338,10 +327,9 @@ public INDArray derivativeLossContentInLayer(INDArray contentActivations, INDArr
* @param x Tensor to get Gram matrix of
* @return Resulting Gram matrix
*/
public INDArray gramMatrix(INDArray x) {
private INDArray gramMatrix(INDArray x) {
INDArray flattened = flatten(x);
INDArray gram = flattened.mmul(flattened.transpose());
return gram;
return flattened.mmul(flattened.transpose());
}

private INDArray flatten(INDArray x) {
Expand All @@ -359,7 +347,7 @@ private INDArray flatten(INDArray x) {
* @param comboFeatures Intermediate activations of one layer for combo image input
* @return Derivative of style error matrix for the layer w.r.t. combo image
*/
public INDArray derivativeLossStyleInLayer(INDArray styleGramFeatures, INDArray comboFeatures) {
private INDArray derivativeLossStyleInLayer(INDArray styleGramFeatures, INDArray comboFeatures) {

comboFeatures = comboFeatures.dup();
double N = comboFeatures.shape()[0];
Expand Down Expand Up @@ -393,14 +381,13 @@ private ComputationGraph loadModel() throws IOException {
return vgg16;
}

private BufferedImage saveImage(INDArray combination, int iteration) throws IOException {
private void saveImage(INDArray combination, int iteration) throws IOException {
IMAGE_PRE_PROCESSOR.revertFeatures(combination);

BufferedImage output = imageFromINDArray(combination);
URL resource = getClass().getResource(OUTPUT_PATH);
File file = new File(resource.getPath() + "/iteration" + iteration + ".jpg");
ImageIO.write(output, "jpg", file);
return output;
}

/**
Expand All @@ -412,7 +399,7 @@ private BufferedImage saveImage(INDArray combination, int iteration) throws IOEx
* @param array INDArray containing an image
* @return BufferedImage
*/
public BufferedImage imageFromINDArray(INDArray array) {
private BufferedImage imageFromINDArray(INDArray array) {
int[] shape = array.shape();

int height = shape[2];
Expand Down

0 comments on commit 70975a5

Please sign in to comment.