diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNIST.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNIST.java index 49e36ba073..3b0858d79a 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNIST.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNIST.java @@ -17,6 +17,12 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + // LeNetMNIST.java + // Classic LeNet-style CNN for MNIST digit classification. + // Expected accuracy: 98–99%. + // MNIST images: 28x28 grayscale (1 channel). + + package org.deeplearning4j.examples.quickstart.modeling.convolution; import org.apache.commons.io.FilenameUtils; @@ -64,6 +70,10 @@ public static void main(String[] args) throws Exception { */ log.info("Build model...."); + // Build LeNet CNN architecture: + // Conv → Pool → Conv → Pool → Dense → Output + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .l2(0.0005) @@ -116,7 +126,9 @@ row vector format (i.e., 1x784 vectors), hence the "convolutionalFlat" input typ MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - + + // Evaluate the model at the end of each epoch using EvaluativeListener + log.info("Train model..."); model.setListeners(new ScoreIterationListener(10), new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END)); //Print score every 10 iterations and evaluate on test set every epoch model.fit(mnistTrain, nEpochs); diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java index 42990301ab..2667e3b0d8 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java @@ -17,6 +17,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + // LeNetMNISTReLu.java + // Same as LeNetMNIST but uses ReLU activation instead of Tanh. + + package org.deeplearning4j.examples.quickstart.modeling.convolution; import org.datavec.api.io.labels.ParentPathLabelGenerator; diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/README.md b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/README.md new file mode 100644 index 0000000000..8db4a7631e --- /dev/null +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/README.md @@ -0,0 +1,57 @@ +# Convolutional Neural Network (CNN) Examples – DeepLearning4J + +This folder contains convolutional neural network (CNN) examples implemented using DeepLearning4J. + +## 🧠 LeNetMNIST.java +Trains a LeNet-style CNN on the MNIST handwritten digit dataset. + +### ✔ What this example demonstrates +- Loading the MNIST dataset +- Building a classic LeNet CNN architecture +- Training the network +- Evaluating accuracy + +### ✔ Expected Accuracy +**98%–99%** after 1–2 epochs. + +--- + +## 📦 How to Run + +Build the project: + +mvn clean package + +Run the example: + +mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.convolution.LeNetMNIST" + +Run the ReLU variant: + +mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.convolution.LeNetMNISTReLu" + +--- + +## 🧱 LeNet Architecture Breakdown +- Convolution layer (20 filters) +- Subsampling (2×2) +- Convolution layer (50 filters) +- Subsampling +- Dense layer +- Output layer (Softmax, 10 classes) + +--- + +## 📂 Other Files +| File | Description | +|------|-------------| +| **LeNetMNISTReLu.java** | LeNet variant with ReLU activation | +| **CenterLossLeNetMNIST.java** | LeNet with center loss | +| **CIFARClassifier.java** | CIFAR-10 image classifier | +| **Conv1DUCISequenceClassifier.java** | 1D CNN example for sequences | + +--- + +## 🙌 Why This Documentation Helps +These CNN examples previously had no explanation, run instructions, or architecture summary. +This README improves clarity for new users and first-time contributors.