
Code Examples

This section describes the code examples found in objax/examples.

Classification

Image

Example code is available at examples/image_classification.

Logistic Regression

Train and evaluate a logistic regression model for binary classification on the horses_or_humans dataset.

# Run command
python3 examples/image_classification/horses_or_humans_logistic.py
Code examples/image_classification/horses_or_humans_logistic.py
Data horses_or_humans from tensorflow_datasets
Network Custom single layer
Loss :py:func:`objax.functional.loss.sigmoid_cross_entropy_logits`
Optimizer :py:class:`objax.optimizer.SGD`
Accuracy ~77%
Hardware CPU or GPU or TPU
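The sigmoid cross-entropy loss listed above can be computed directly from logits in a numerically stable way. The NumPy sketch below shows the standard formulation; it illustrates the math rather than the objax implementation, and the function name `stable_sigmoid_xent` is hypothetical.

```python
import numpy as np

def stable_sigmoid_xent(logits, labels):
    """Elementwise binary cross-entropy computed from logits (hypothetical helper).

    Uses the identity
        -y*log(sigmoid(z)) - (1-y)*log(1-sigmoid(z))
            = max(z, 0) - z*y + log(1 + exp(-|z|)),
    which never exponentiates a large positive number.
    """
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))
```

The training loss for a batch would then be the mean of these per-example values.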

Digit Classification with Deep Neural Network (DNN)

Train and evaluate a DNNet model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/image_classification/mnist_dnn.py
Code examples/image_classification/mnist_dnn.py
Data MNIST from tensorflow_datasets
Network Deep Neural Net :py:class:`objax.zoo.dnnet.DNNet`
Loss :py:func:`objax.functional.loss.cross_entropy_logits`
Optimizer :py:class:`objax.optimizer.Adam`
Accuracy ~98%
Hardware CPU or GPU or TPU
Techniques Model weight averaging for improved accuracy using :py:class:`objax.optimizer.ExponentialMovingAverage`.
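Weight averaging keeps an exponential moving average (EMA) of the model weights during training and uses the averaged copy for evaluation. The update rule can be sketched in NumPy as follows; the class name `WeightEMA` and the momentum value are hypothetical and do not describe the internals of objax.optimizer.ExponentialMovingAverage.

```python
import numpy as np

class WeightEMA:
    """Exponential moving average of a list of weight arrays (hypothetical helper).

    After each training step: ema = momentum * ema + (1 - momentum) * weights.
    """

    def __init__(self, weights, momentum=0.999):  # momentum value is an assumption
        self.momentum = momentum
        # Start the average from a copy of the initial weights.
        self.averages = [np.array(w, dtype=np.float64) for w in weights]

    def update(self, weights):
        m = self.momentum
        for avg, w in zip(self.averages, weights):
            # In-place update so the stored arrays keep their identity.
            avg *= m
            avg += (1.0 - m) * np.asarray(w)
```

Calling `update` once per optimizer step and evaluating with `averages` smooths out step-to-step noise in the weights.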

Digit Classification with Convolutional Neural Network (CNN)

Train and evaluate a simple custom CNN model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/image_classification/mnist_cnn.py
Code examples/image_classification/mnist_cnn.py
Data MNIST from tensorflow_datasets
Network Custom Convolutional Neural Net using :py:class:`objax.nn.Sequential`
Loss :py:func:`objax.functional.loss.cross_entropy_logits_sparse`
Optimizer :py:class:`objax.optimizer.Adam`
Accuracy ~99.5%
Hardware CPU or GPU or TPU
Techniques
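The `_sparse` variant of the cross-entropy loss takes integer class indices instead of one-hot (or soft) label vectors. The NumPy sketch below shows that both forms give the same per-example loss; the function names are hypothetical and are not the objax functions.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def dense_cross_entropy(logits, labels):
    """Per-example cross-entropy with label vectors of shape (N, K)."""
    return -(labels * log_softmax(logits)).sum(axis=-1)

def sparse_cross_entropy(logits, labels):
    """Per-example cross-entropy with integer class indices of shape (N,)."""
    logp = log_softmax(logits)
    # Pick the log-probability of the true class in each row.
    return -logp[np.arange(len(labels)), labels]
```

The sparse form avoids materializing one-hot vectors, which is convenient when labels come from the dataset as integers.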

Digit Classification using Differential Privacy

Train and evaluate a convolutional network on the MNIST dataset with differential privacy.

# Run command
python3 examples/image_classification/mnist_dp.py
# See available options with
python3 examples/image_classification/mnist_dp.py --help
Code examples/image_classification/mnist_dp.py
Data MNIST from tensorflow_datasets
Network Custom Convolutional Neural Net using :py:class:`objax.nn.Sequential`
Loss :py:func:`objax.functional.loss.cross_entropy_logits`
Optimizer :py:class:`objax.optimizer.SGD`
Accuracy  
Hardware GPU
Techniques
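The core step of differentially private SGD (DP-SGD) is to clip every per-example gradient to a fixed L2 norm and add Gaussian noise before averaging. The NumPy sketch below shows that step on precomputed, flattened per-example gradients; the function and parameter names (`privatize_gradients`, `l2_norm_clip`, `noise_multiplier`) are hypothetical, and the sketch omits gradient computation and privacy accounting.

```python
import numpy as np

def privatize_gradients(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clip each example's gradient, sum, add Gaussian noise, then average.

    per_example_grads: array of shape (batch, num_params).
    """
    batch = per_example_grads.shape[0]
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    # Scale down any gradient whose L2 norm exceeds the clipping bound.
    factors = np.minimum(1.0, l2_norm_clip / np.maximum(norms, 1e-12))
    clipped = per_example_grads * factors
    # Noise standard deviation is proportional to the clipping bound.
    noise = rng.normal(0.0, noise_multiplier * l2_norm_clip,
                       size=per_example_grads.shape[1])
    return (clipped.sum(axis=0) + noise) / batch
```

Clipping bounds how much any single example can change the update, which is what lets the added noise mask each example's contribution.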

Image Classification on CIFAR-10 (Simple)

Train and evaluate a Wide ResNet model for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/image_classification/cifar10_simple.py
Code examples/image_classification/cifar10_simple.py
Data CIFAR10 from tf.keras.datasets
Network Wide ResNet using :py:class:`objax.zoo.wide_resnet.WideResNet`
Loss :py:func:`objax.functional.loss.cross_entropy_logits_sparse`
Optimizer :py:class:`objax.optimizer.Momentum`
Accuracy ~91%
Hardware GPU or TPU
Techniques
  • Learning rate schedule.
  • Data augmentation (mirroring and pixel shifts) in NumPy.
  • Regularization using extra weight decay term in loss.
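The mirror and pixel-shift augmentation listed above can be sketched in NumPy as follows. This sketch assumes images in NCHW layout, a flip probability of 0.5 and shifts of up to 4 pixels implemented with reflect padding and a random crop; these values and the function name are assumptions, not taken from cifar10_simple.py.

```python
import numpy as np

def augment_mirror_shift(images, shift=4, rng=None):
    """Randomly mirror and shift a batch of images of shape (N, C, H, W)."""
    rng = np.random.default_rng() if rng is None else rng
    n, c, h, w = images.shape
    # Flip each image left-right with probability 0.5.
    flip = rng.random(n) < 0.5
    out = np.where(flip[:, None, None, None], images[..., ::-1], images)
    if shift == 0:
        return out
    # Reflect-pad the spatial dimensions, then crop a random H x W window.
    padded = np.pad(out, ((0, 0), (0, 0), (shift, shift), (shift, shift)), mode='reflect')
    dy = rng.integers(0, 2 * shift + 1, size=n)
    dx = rng.integers(0, 2 * shift + 1, size=n)
    return np.stack([padded[i, :, dy[i]:dy[i] + h, dx[i]:dx[i] + w] for i in range(n)])
```

Running augmentation in NumPy on the host keeps the jitted training step focused on the model computation.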

Image Classification on CIFAR-10 (Advanced)

Train and evaluate Wide ResNet and ConvNet models for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/image_classification/cifar10_advanced.py
# Run with custom settings
python3 examples/image_classification/cifar10_advanced.py --weight_decay=0.0001 --batch=64 --lr=0.03 --epochs=256
# See available options with
python3 examples/image_classification/cifar10_advanced.py --help
Code examples/image_classification/cifar10_advanced.py
Data CIFAR10 from tensorflow_datasets
Network Configurable with --arch="network", where network is one of:
  • wrn28-1, wrn28-2 using :py:class:`objax.zoo.wide_resnet.WideResNet`
  • cnn32-3-max, cnn32-3-mean, cnn64-3-max, cnn64-3-mean using :py:class:`objax.zoo.convnet.ConvNet`
Loss :py:func:`objax.functional.loss.cross_entropy_logits`
Optimizer :py:class:`objax.optimizer.Momentum`
Accuracy ~94%
Hardware GPU, Multi-GPU or TPU
Techniques

Image Classification on ImageNet

Train and evaluate a ResNet50 model on the ImageNet dataset. See the README for additional information.

Code examples/image_classification/imagenet_resnet50_train.py
Data ImageNet from tensorflow_datasets
Network ResNet50
Loss :py:func:`objax.functional.loss.cross_entropy_logits_sparse`
Optimizer :py:class:`objax.optimizer.Momentum`
Accuracy  
Hardware GPU, Multi-GPU or TPU
Techniques

Image Classification using Pretrained VGG Network

Image classification using an ImageNet-pretrained VGG19 model. See the README for additional information.

Code examples/image_classification/imagenet_pretrained_vgg.py
Techniques Load a VGG19 model with pretrained weights and run 1000-way image classification.

Semi-Supervised Learning

Example code is available at examples/fixmatch.

Semi-Supervised Learning with FixMatch

Semi-supervised learning of image classification models with FixMatch.

# Run command
python3 examples/fixmatch/fixmatch.py
# Run with custom settings
python3 examples/fixmatch/fixmatch.py --dataset=cifar10.3@1000-0
# See available options with
python3 examples/fixmatch/fixmatch.py --help
Code examples/fixmatch/fixmatch.py
Data CIFAR10, CIFAR100, SVHN, STL10
Network Custom implementation of Wide ResNet.
Loss :py:func:`objax.functional.loss.cross_entropy_logits` and :py:func:`objax.functional.loss.cross_entropy_logits_sparse`
Optimizer :py:class:`objax.optimizer.Momentum`
Accuracy See paper
Hardware GPU, Multi-GPU or TPU
Techniques
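FixMatch combines a supervised loss on labeled data with a pseudo-label loss on unlabeled data: the prediction on a weakly augmented image becomes a hard label for a strongly augmented view of the same image, but only when that prediction is confident. The NumPy sketch below computes this unlabeled term from given logits; the function name and the 0.95 threshold default are assumptions, and in a real training step the weak-view predictions are treated as constants (no gradient flows through them).

```python
import numpy as np

def fixmatch_unlabeled_loss(weak_logits, strong_logits, threshold=0.95):
    """Confidence-masked pseudo-label loss for a batch of unlabeled images."""
    # Softmax over the weak-view logits.
    e = np.exp(weak_logits - weak_logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    pseudo_labels = probs.argmax(axis=-1)   # hard labels from the weak view
    mask = probs.max(axis=-1) >= threshold  # keep only confident predictions
    # Sparse cross-entropy of the strong view against the pseudo-labels.
    shifted = strong_logits - strong_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_example = -log_probs[np.arange(len(pseudo_labels)), pseudo_labels]
    # Average over the whole batch, with masked examples contributing zero.
    return (per_example * mask).mean()
```

This pairs the sparse cross-entropy (pseudo-labels are integer indices) with the dense one that can be used elsewhere in the objective.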

GPT-2

Example code is available at examples/gpt-2.

Generating a Text Sequence using GPT-2

Load the pretrained GPT-2 model (124M parameters) and demonstrate how to use it to generate a text sequence. See the README for additional information.

Code examples/gpt-2/gpt2.py
Hardware GPU or TPU
Techniques
  • Define a Transformer model.
  • Load the GPT-2 model with pretrained weights and generate a sequence.
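Generating a sequence with a language model such as GPT-2 is an autoregressive loop: compute logits for the next token, pick a token, append it and repeat. The sketch below shows that loop with temperature sampling; `next_token_logits` stands in for the model and is hypothetical, and the decoding strategy used in gpt2.py may differ.

```python
import numpy as np

def generate(next_token_logits, prompt, num_tokens, temperature=1.0, rng=None):
    """Extend `prompt` (a list of token ids) by `num_tokens` tokens.

    next_token_logits: hypothetical callable mapping the current token list
    to a 1-D array of vocabulary logits for the next position.
    temperature == 0 selects greedy (argmax) decoding.
    """
    rng = np.random.default_rng() if rng is None else rng
    tokens = list(prompt)
    for _ in range(num_tokens):
        logits = np.asarray(next_token_logits(tokens), dtype=np.float64)
        if temperature == 0:
            next_id = int(logits.argmax())
        else:
            # Sample from softmax(logits / temperature).
            scaled = logits / temperature
            probs = np.exp(scaled - scaled.max())
            probs /= probs.sum()
            next_id = int(rng.choice(len(probs), p=probs))
        tokens.append(next_id)
    return tokens
```

Lower temperatures make the output more deterministic; higher ones make it more varied.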

RNN

Example code is available at examples/text_generation.

Train a Vanilla RNN to Predict Characters

Train and evaluate a vanilla RNN model on the Shakespeare corpus. See the README for additional information.

# Run command
python3 examples/text_generation/shakespeare_rnn.py
Code examples/text_generation/shakespeare_rnn.py
Data Shakespeare corpus from tensorflow_datasets
Network Custom implementation of vanilla RNN.
Loss :py:func:`objax.functional.loss.cross_entropy_logits`
Optimizer :py:class:`objax.optimizer.Adam`
Hardware GPU or TPU
Techniques
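A vanilla RNN updates a hidden state at every step from the current input and the previous state, then maps that state to logits over the next character. A minimal NumPy forward pass, assuming one-hot character inputs and a tanh activation (the function and parameter names are hypothetical, not those used in shakespeare_rnn.py):

```python
import numpy as np

def rnn_forward(inputs, h0, w_xh, w_hh, b_h, w_hy, b_y):
    """Run a vanilla (Elman) RNN over a sequence.

    inputs: (T, input_dim), e.g. one-hot encoded characters.
    Returns per-step output logits (T, output_dim) and the final hidden state.
    """
    h = h0
    outputs = []
    for x_t in inputs:
        # The new hidden state mixes the current input with the previous state.
        h = np.tanh(x_t @ w_xh + h @ w_hh + b_h)
        # Logits over the next character.
        outputs.append(h @ w_hy + b_y)
    return np.stack(outputs), h
```

Training would apply a cross-entropy loss between each step's logits and the following character in the text.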

Optimization

Example code is available at examples/maml.

Model Agnostic Meta-Learning (MAML)

An implementation of the MAML meta-learning method that demonstrates how to compute the gradient of a gradient.

# Run command
python3 examples/maml/maml.py
Code examples/maml/maml.py
Data Synthetic data
Network 3-layer DNNet
Hardware CPU or GPU or TPU
Techniques Gradient of gradient.
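MAML optimizes the loss obtained after an inner gradient step, so its meta-gradient includes the derivative of that inner gradient, i.e. a second derivative. The pure-Python sketch below makes this explicit for a one-dimensional quadratic task loss, where the chain rule can be written by hand; the task, names and learning rate are hypothetical. An autodiff framework such as JAX computes the same term automatically by differentiating through the inner update.

```python
def task_loss(theta, a, c):
    """Quadratic task loss 0.5 * a * (theta - c)**2."""
    return 0.5 * a * (theta - c) ** 2

def task_grad(theta, a, c):
    return a * (theta - c)

def maml_objective(theta, a, c_support, c_query, inner_lr):
    """Query loss after one inner gradient step on the support loss."""
    adapted = theta - inner_lr * task_grad(theta, a, c_support)
    return task_loss(adapted, a, c_query)

def maml_meta_grad(theta, a, c_support, c_query, inner_lr):
    """d/dtheta of maml_objective via the chain rule.

    d adapted / d theta = 1 - inner_lr * (second derivative of support loss)
                        = 1 - inner_lr * a,
    which is where the gradient of a gradient enters.
    """
    adapted = theta - inner_lr * task_grad(theta, a, c_support)
    return task_grad(adapted, a, c_query) * (1.0 - inner_lr * a)
```

Dropping the `(1.0 - inner_lr * a)` factor gives the cheaper first-order approximation, which differs from the true meta-gradient whenever the loss has curvature.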

Jaxboard

Example code is available at examples/jaxboard.

How to Use Jaxboard

Sample usage of Jaxboard. See the README for additional information.

# Run command
python3 examples/jaxboard/summary.py
Code examples/jaxboard/summary.py
Hardware CPU
Usages
  • summary scalar
  • summary text
  • summary image