diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/MemorizeSequence.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/MemorizeSequence.java index b6977c064f..28d3e39cee 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/MemorizeSequence.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/MemorizeSequence.java @@ -17,6 +17,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + // MemorizeSequence.java + // A simple RNN example where the network learns to memorize and reproduce a short sequence. +// Demonstrates basic RNN training and backpropagation-through-time. + package org.deeplearning4j.examples.quickstart.modeling.recurrent; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; @@ -87,6 +91,8 @@ public static void main(String[] args) { hiddenLayerBuilder.activation(Activation.TANH); listBuilder.layer(i, hiddenLayerBuilder.build()); } + + // Build a simple RNN with one recurrent layer to memorize the sequence // we need to use RnnOutputLayer for our RNN RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT); @@ -96,7 +102,8 @@ public static void main(String[] args) { outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH); outputLayerBuilder.nOut(LEARNSTRING_CHARS.size()); listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build()); - + + // create network MultiLayerConfiguration conf = listBuilder.build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -122,6 +129,9 @@ public static void main(String[] args) { labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1); samplePos++; } + + // Train the RNN to output the same sequence it receives as input + DataSet trainingData = new DataSet(input, labels); // some epochs diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/README.md b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/README.md new file mode 100644 index 0000000000..3359615375 --- /dev/null +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/README.md @@ -0,0 +1,67 @@ +# Recurrent Neural Network (RNN) Examples – DeepLearning4J + +This folder contains simple recurrent neural network (RNN) examples using LSTM and embedding layers. +These examples demonstrate how to work with sequential data such as signals, sequences, and text-like inputs. + +--- + +## πŸ” MemorizeSequence.java +A minimal RNN example where the network learns to memorize and reproduce a fixed sequence. + +### What this example teaches +- How RNNs store information across time steps +- How backpropagation-through-time works +- How sequence learning differs from feedforward networks + +### Expected Behavior +The network eventually outputs the same sequence it was trained on. + +--- + +## πŸ”‘ RNNEmbedding.java +Demonstrates the use of an **EmbeddingLayer** followed by RNN layers. + +### Key Concepts +- Turning integer-encoded inputs into dense vectors +- Word/token embedding +- Passing embedded sequences into RNN layers + +This is a useful template for NLP-style models. + +--- + +## πŸ“Š UCISequenceClassification.java +Sequence classification on a dataset from the UCI machine learning repository. + +### What this example shows +- Loading sequential datasets +- Recurrent classification (predict a label for a whole sequence) +- Time-series preprocessing and normalization + +--- + +## βœ” How to Run Any Example + +Use the following command: + +mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.recurrent." + + +Example: + + + +mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.recurrent.MemorizeSequence" + + +--- + +## πŸ™Œ Why This README Helps + +This folder previously had no documentation. +This README explains: +- What each RNN example does +- What concepts it teaches +- How to run each file + +This improves clarity for beginners working with sequential models in DL4J. \ No newline at end of file diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/RNNEmbedding.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/RNNEmbedding.java index 13ed7711fa..7b6e0615df 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/RNNEmbedding.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/RNNEmbedding.java @@ -17,6 +17,11 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + + // RNNEmbedding.java + // Demonstrates how to use an EmbeddingLayer + RNN layers for sequence modeling. +// Useful for NLP-style integer token inputs. + package org.deeplearning4j.examples.quickstart.modeling.recurrent; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; @@ -45,6 +50,9 @@ * * @author Alex Black */ + + // Convert integer token IDs into dense embedding vectors + public class RNNEmbedding { public static void main(String[] args) { @@ -64,6 +72,10 @@ public static void main(String[] args) { } } + + // Feed embedded vectors into an LSTM/RNN to capture sequence structure + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .activation(Activation.RELU) .list() diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/UCISequenceClassification.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/UCISequenceClassification.java index ec79e8c893..b5cb626b2d 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/UCISequenceClassification.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/UCISequenceClassification.java @@ -17,6 +17,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + + // UCISequenceClassification.java +// Demonstrates sequence classification using an RNN on UCI dataset sequences. + package org.deeplearning4j.examples.quickstart.modeling.recurrent; import org.apache.commons.io.FileUtils; diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/README.md b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/README.md new file mode 100644 index 0000000000..1c6ef385ce --- /dev/null +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/README.md @@ -0,0 +1,58 @@ +# Variational Autoencoder (VAE) Examples – DeepLearning4J + +This folder contains two examples demonstrating how to use Variational Autoencoders (VAEs) in DeepLearning4J. +VAEs are generative models that learn latent representations of data and can be used for visualization, sampling, and anomaly detection. + +--- + +## 🧩 VaeMNIST2dPlots.java +Trains a VAE on the MNIST digit dataset and visualizes the **2-dimensional latent space**. + +### What this example shows +- How to build a VAE in DL4J +- Encoding MNIST images into a 2D latent space +- Plotting how digits cluster in latent space +- How VAEs learn smooth and continuous representations + +### Why it’s useful +A 2D latent space allows easy visualization of how the model separates digits. + +--- + +## ⚠️ VaeMNISTAnomaly.java +Uses a trained VAE for **anomaly detection** on MNIST. + +### Key Concepts +- VAEs reconstruct normal data well +- They reconstruct anomalies poorly +- Reconstruction error can be used as an anomaly score + +### What this example demonstrates +- Reconstructing MNIST digits +- Computing reconstruction probability +- Detecting out-of-distribution or corrupted samples + +--- + +## βœ” How to Run Any Example + +mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder." + + +Example: + + + +mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder.VaeMNIST2dPlots" + + +--- + +## πŸ™Œ Why This README Helps +The VAE folder previously had no documentation. +This README explains: +- Purpose of each example +- The ML concepts involved +- How to run and understand the results + +This improves clarity for beginners working with generative models in DL4J. \ No newline at end of file diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNIST2dPlots.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNIST2dPlots.java index 496b2df909..1344b50ce4 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNIST2dPlots.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNIST2dPlots.java @@ -17,6 +17,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + // VaeMNIST2dPlots.java + // Trains a Variational Autoencoder (VAE) on MNIST and visualizes the 2D latent space. +// Shows how VAEs compress images into smooth, continuous latent variables. + package org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; @@ -98,10 +102,13 @@ public static void main(String[] args) throws IOException { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); + // Build a VAE with a 2-dimensional latent space for visualization + //Get the variational autoencoder layer org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae = (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0); + //Test data for plotting DataSet testdata = new MnistDataSetIterator(10000, false, rngSeed).next(); @@ -121,12 +128,15 @@ public static void main(String[] args) throws IOException { // (b) collect the reconstructions at each point in the grid net.setListeners(new PlottingListener(100, testFeatures, latentSpaceGrid, latentSpaceVsEpoch, digitsGrid)); - //Perform training + //Perform training + // Train the VAE to encode and reconstruct MNIST digits for (int i = 0; i < nEpochs; i++) { log.info("Starting epoch {} of {}",(i+1),nEpochs); net.pretrain(trainIter); //Note use of .pretrain(DataSetIterator) not fit(DataSetIterator) for unsupervised training } + // Visualize how different digit classes cluster in the latent space + //plot by default if (visualize) { //Plot MNIST test set - latent space vs. iteration (every 100 minibatches by default) diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNISTAnomaly.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNISTAnomaly.java index c5b033d301..04b966001a 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNISTAnomaly.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNISTAnomaly.java @@ -17,6 +17,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + // VaeMNISTAnomaly.java + // Demonstrates anomaly detection using a Variational Autoencoder (VAE). +// Normal digits reconstruct well, anomalies reconstruct poorly. + package org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; @@ -94,6 +98,9 @@ public static void main(String[] args) throws IOException { .build()) .build(); + // Load trained VAE model and MNIST test data + + MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -151,6 +158,10 @@ public int compare(Pair o1, Pair o2) { Collections.sort(list, c); } + // Compute reconstruction probability for each test image + // Low probability indicates an anomaly + + //Select the 5 best and 5 worst numbers (by reconstruction probability) for each digit List best = new ArrayList<>(50); List worst = new ArrayList<>(50);