Code Examples

This section describes the code examples found in objax/examples.

Classification

Image

Example code available at examples/image_classification.

Logistic Regression

Train and evaluate a logistic regression model for binary classification on the horses_or_humans dataset.

# Run command
python3 examples/image_classification/horses_or_humans_logistic.py
Code examples/image_classification/horses_or_humans_logistic.py
Data horses_or_humans from tensorflow_datasets
Network Custom single layer
Loss objax.functional.loss.sigmoid_cross_entropy_logits
Optimizer objax.optimizer.SGD
Accuracy ~77%
Hardware CPU or GPU or TPU
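
The training step for this kind of model follows Objax's standard pattern of a loss function, objax.GradValues, and a jitted update. A minimal sketch, assuming flattened 300x300 RGB inputs and an arbitrary learning rate (not values read from the script):

import objax

model = objax.nn.Linear(3 * 300 * 300, 1)  # single layer, one logit
opt = objax.optimizer.SGD(model.vars())

def loss(x, label):
    # label is 0/1; sigmoid_cross_entropy_logits takes raw logits.
    return objax.functional.loss.sigmoid_cross_entropy_logits(
        model(x).flatten(), label).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, label, lr):
    g, v = gv(x, label)  # gradients and loss value
    opt(lr, g)           # in-place SGD update
    return v

# Compile the step; the variable collection declares the state it mutates.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())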

Digit Classification with Deep Neural Network (DNN)

Train and evaluate a DNNet model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/image_classification/mnist_dnn.py
Code examples/image_classification/mnist_dnn.py

Data MNIST from tensorflow_datasets

Network Deep Neural Net objax.zoo.dnnet.DNNet
Loss objax.functional.loss.cross_entropy_logits
Optimizer objax.optimizer.Adam
Accuracy ~98%
Hardware CPU or GPU or TPU

Techniques

Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.
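
The averaging keeps an exponential shadow copy of the weights and evaluates under it. A minimal sketch; the layer sizes and momentum are illustrative, and ema.replace_vars follows the usage in Objax's own examples:

import objax

model = objax.zoo.dnnet.DNNet((784, 128, 10), objax.functional.relu)  # assumed sizes
ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)

# Inside the training step, refresh the averages right after the optimizer:
#   opt(lr, g)
#   ema()

# At evaluation time, run predictions under the averaged weights.
predict = objax.Jit(
    ema.replace_vars(lambda x: objax.functional.softmax(model(x))),
    model.vars() + ema.vars())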

Digit Classification with Convolutional Neural Network (CNN)

Train and evaluate a simple custom CNN model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/image_classification/mnist_cnn.py
Code examples/image_classification/mnist_cnn.py

Data MNIST from tensorflow_datasets

Network Custom Convolutional Neural Net using objax.nn.Sequential
Loss objax.functional.loss.cross_entropy_logits_sparse
Optimizer objax.optimizer.Adam
Accuracy ~99.5%
Hardware CPU or GPU or TPU

Techniques

  • Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.
  • Regularization using an extra weight-decay term in the loss (see the sketch after this list).
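
For the weight-decay bullet, a sketch of how such a loss can be assembled; `model` and the 0.0005 coefficient are assumptions:

import objax

def loss(x, label, weight_decay=0.0005):
    logit = model(x)
    xe = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
    # Penalize only the weight matrices (variables named '*.w'), not biases.
    l2 = 0.5 * sum((v.value ** 2).sum()
                   for k, v in model.vars().items() if k.endswith('.w'))
    return xe + weight_decay * l2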

Digit Classification using Differential Privacy

Train and evaluate a ConvNet model on the MNIST dataset with differential privacy.

# Run command
python3 examples/image_classification/mnist_dp.py
# See available options with
python3 examples/image_classification/mnist_dp.py --help
Code examples/image_classification/mnist_dp.py

Data MNIST from tensorflow_datasets

Network Custom Convolutional Neural Net using objax.nn.Sequential
Loss objax.functional.loss.cross_entropy_logits

Optimizer objax.optimizer.SGD
Hardware GPU

Techniques

  • Compute differentially private gradients using objax.privacy.dpsgd.PrivateGradValues (see the sketch below).
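
A minimal sketch of the private-gradient computation; the noise multiplier, clipping norm, and microbatch size are illustrative values, and `loss`, `model`, and `opt` are assumed to be defined as in the earlier examples:

import objax

gv = objax.privacy.dpsgd.PrivateGradValues(
    loss, model.vars(),
    noise_multiplier=1.1,  # scale of Gaussian noise added to the summed gradients
    l2_norm_clip=1.0,      # per-microbatch gradient clipping norm
    microbatch=1)          # examples per clipped microbatch

def train_op(x, y, lr):
    g, v = gv(x, y)  # clipped, noised gradients
    opt(lr, g)
    return v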

Image Classification on CIFAR-10 (Simple)

Train and evaluate a Wide ResNet model for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/image_classification/cifar10_simple.py
Code examples/image_classification/cifar10_simple.py

Data CIFAR10 from tf.keras.datasets

Network Wide ResNet using objax.zoo.wide_resnet.WideResNet
Loss objax.functional.loss.cross_entropy_logits_sparse
Optimizer objax.optimizer.Momentum
Accuracy ~91%
Hardware GPU or TPU

Techniques

  • Learning rate schedule.
  • Data augmentation (mirror / pixel shifts) in NumPy (see the sketch after this list).
  • Regularization using an extra weight-decay term in the loss.
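
The augmentation bullet fits in a few lines of NumPy; the 4-pixel shift is common CIFAR-10 practice and an assumption here, not a value read from the script:

import numpy as np

def augment(x, shift=4):  # x: batch of CHW images, shape (n, 3, 32, 32)
    if np.random.rand() < 0.5:
        x = x[:, :, :, ::-1]  # horizontal mirror
    x = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]], 'reflect')
    rx, ry = np.random.randint(0, 2 * shift, size=2)
    return x[:, :, rx:rx + 32, ry:ry + 32]  # random crop back to 32x32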

Image Classification on CIFAR-10 (Advanced)

Train and evaluate configurable ConvNet and Wide ResNet models for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/image_classification/cifar10_advanced.py
# Run with custom settings
python3 examples/image_classification/cifar10_advanced.py --weight_decay=0.0001 --batch=64 --lr=0.03 --epochs=256
# See available options with
python3 examples/image_classification/cifar10_advanced.py --help
Code examples/image_classification/cifar10_advanced.py

Data CIFAR10 from tensorflow_datasets

Network Configurable with --arch="network":

  • wrn28-1, wrn28-2 using objax.zoo.wide_resnet.WideResNet
  • cnn32-3-max, cnn32-3-mean, cnn64-3-max, cnn64-3-mean using objax.zoo.convnet.ConvNet

Loss objax.functional.loss.cross_entropy_logits
Optimizer objax.optimizer.Momentum
Accuracy ~94%
Hardware GPU, Multi-GPU or TPU

Techniques

  • Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.
  • Parallelized on multiple GPUs using objax.Parallel (see the sketch after this list).
  • Data augmentation (mirror / pixel shifts) in TensorFlow.
  • Cosine learning rate decay.
  • Regularization using an extra weight-decay term in the loss.
  • Checkpointing, with automatic resumption from the latest checkpoint if training is interrupted, using objax.io.Checkpoint.
  • Saving of TensorBoard visualization files using objax.jaxboard.SummaryWriter.
  • Multi-loss reporting (cross-entropy, L2).
  • Reusable training loop example.
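
For the objax.Parallel bullet, a minimal sketch; it assumes `train_op`, `gv`, and `opt` from the earlier examples, and `train_batches` and `lr` are hypothetical:

import objax

all_vars = gv.vars() + opt.vars()
train_op = objax.Parallel(train_op, all_vars)

# Variables must be replicated across devices while the parallel op runs.
with all_vars.replicate():
    for x, y in train_batches:           # hypothetical data iterator
        loss_value = train_op(x, y, lr)  # each batch is split across devices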

Image Classification on ImageNet

Train and evaluate a ResNet50 model on the ImageNet dataset. See README for additional information.

Code examples/image_classification/imagenet_resnet50_train.py
Data ImageNet from tensorflow_datasets
Network ResNet50
Loss objax.functional.loss.cross_entropy_logits_sparse

Optimizer objax.optimizer.Momentum
Hardware GPU, Multi-GPU or TPU

Techniques

  • Parallelized on multiple GPUs using objax.Parallel.
  • Data augmentation (distorted bounding box crop) in TensorFlow.
  • Linear warmup followed by multi-step learning rate decay.
  • Regularization using an extra weight-decay term in the loss.
  • Checkpointing, with automatic resumption from the latest checkpoint if training is interrupted, using objax.io.Checkpoint (see the sketch after this list).
  • Saving of TensorBoard visualization files using objax.jaxboard.SummaryWriter.
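
A sketch of the checkpointing technique; the log directory and keep count are arbitrary, and restore() is assumed to return the last saved step (0 when starting fresh):

import objax

ckpt = objax.io.Checkpoint('experiments/resnet50', keep_ckpts=5)
start_epoch, _ = ckpt.restore(model.vars())  # resumes if a checkpoint exists

for epoch in range(start_epoch, num_epochs):  # num_epochs defined elsewhere
    ...  # one epoch of training
    ckpt.save(model.vars(), epoch + 1)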

Image Classification using Pretrained VGG Network

Image classification using an ImageNet-pretrained VGG19 model. See README for additional information.

Code examples/image_classification/imagenet_pretrained_vgg.py
Techniques Load the VGG19 model with pretrained weights and run 1000-way image classification.

Semi-Supervised Learning

Example code available at examples/fixmatch.

Semi-Supervised Learning with FixMatch

Semi-supervised learning of image classification models with FixMatch.

# Run command
python3 examples/fixmatch/fixmatch.py
# Run with custom settings
python3 examples/fixmatch/fixmatch.py --dataset=cifar10.3@1000-0
# See available options with
python3 examples/fixmatch/fixmatch.py --help
Code examples/fixmatch/fixmatch.py
Data CIFAR10, CIFAR100, SVHN, STL10
Network Custom implementation of Wide ResNet.
Loss objax.functional.loss.cross_entropy_logits and objax.functional.loss.cross_entropy_logits_sparse
Optimizer objax.optimizer.Momentum
Accuracy See paper
Hardware GPU, Multi-GPU, TPU

Techniques

  • Load data from multiple data pipelines.
  • Advanced data augmentation such as RandAugment and CTAugment.
  • Stop gradient using objax.functional.stop_gradient (see the sketch after this list).
  • Cosine learning rate decay.
  • Regularization using an extra weight-decay term in the loss.
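
For the stop-gradient bullet, a schematic sketch of the pseudo-labeling step; the 0.95 confidence threshold follows the FixMatch paper, while `model` and its training flag are assumptions:

import objax

def unlabeled_loss(x_weak, x_strong, confidence=0.95):
    # Predict on the weakly augmented view without backpropagating through it.
    probs = objax.functional.stop_gradient(
        objax.functional.softmax(model(x_weak, training=True)))
    pseudo_label = probs.argmax(axis=1)
    mask = probs.max(axis=1) >= confidence  # keep only confident predictions
    xe = objax.functional.loss.cross_entropy_logits_sparse(
        model(x_strong, training=True), pseudo_label)
    return (xe * mask).mean()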

GPT-2

Example code is available at examples/gpt-2.

Generating a Text Sequence using GPT-2

Load the pretrained GPT-2 model (124M parameters) and demonstrate how to use it to generate a text sequence. See README for additional information.

Code examples/gpt-2/gpt2.py
Hardware GPU or TPU

Techniques

  • Define a Transformer model.
  • Load the GPT-2 model with pretrained weights and generate a sequence.

RNN

Example code is available at examples/text_generation.

Train a Vanilla RNN to Predict Characters

Train and evaluate a vanilla RNN model on the Shakespeare corpus dataset. See README for additional information.

# Run command
python3 examples/text_generation/shakespeare_rnn.py
Code examples/text_generation/shakespeare_rnn.py

Data Shakespeare corpus from tensorflow_datasets

Network Custom implementation of vanilla RNN.
Loss objax.functional.loss.cross_entropy_logits
Optimizer objax.optimizer.Adam
Hardware GPU or TPU

Techniques

  • Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.
  • Data pipeline of sequence data for training.
  • Data processing (e.g., tokenization).
  • Gradient clipping (see the sketch after this list).
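
For the gradient-clipping bullet, a sketch of global-norm clipping; the max_norm value is an assumption, and gradients are assumed to arrive as the list of arrays returned by objax.GradValues:

import jax.numpy as jn

def clip_by_global_norm(grads, max_norm=1.0):
    total_norm = jn.sqrt(sum(jn.sum(g ** 2) for g in grads))
    scale = jn.minimum(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads]

# Usage inside a training step:
#   g, v = gv(x, y)
#   opt(lr, clip_by_global_norm(g))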

Optimization

Example code available at examples/maml.

Model Agnostic Meta-Learning (MAML)

An implementation of the meta-learning method MAML, demonstrating how to compute the gradient of a gradient.

# Run command
python3 examples/maml/maml.py
Code examples/maml/maml.py
Data Synthetic data
Network 3-layer DNNet
Hardware CPU or GPU or TPU
Techniques Gradient of a gradient (see the sketch below).
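
Since the point of this example is differentiating through an inner update, here is a toy plain-JAX illustration of a gradient of a gradient (the quadratic loss and step size are arbitrary):

import jax

def inner_loss(w, x):
    return (w * x) ** 2

inner_grad = jax.grad(inner_loss)  # d(inner_loss)/dw

def adapted_loss(w, x, lr=0.1):
    w_adapted = w - lr * inner_grad(w, x)  # one inner SGD step
    return inner_loss(w_adapted, x)

# Differentiating through the inner update yields a gradient of a gradient.
meta_grad = jax.grad(adapted_loss)
print(meta_grad(1.0, 2.0))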

Jaxboard

Example code available at examples/jaxboard.

How to Use Jaxboard

Sample usage of jaxboard. See README for additional information.

# Run command
python3 examples/jaxboard/summary.py
Code examples/jaxboard/summary.py
Hardware CPU

Usage

  • summary scalar
  • summary text
  • summary image
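
Putting the three together, a sketch of a jaxboard session; the method names follow this example's usage, but the exact signatures and log directory are assumptions:

import numpy as np
import objax

writer = objax.jaxboard.SummaryWriter('experiments/demo')
summary = objax.jaxboard.Summary()
summary.scalar('losses/xe', 0.25)                       # summary scalar
summary.text('config', 'lr=0.01 batch=64')              # summary text
summary.image('inputs/sample', np.zeros((3, 32, 32)))   # summary image (CHW)
writer.write(summary, step=1)
writer.close()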