This section describes the code examples found in objax/examples.
Example code is available at examples/image_classification.
Train and evaluate a logistic regression model for binary classification on the horses_or_humans dataset.
# Run command
python3 examples/image_classification/horses_or_humans_logistic.py
Code | examples/image_classification/horses_or_humans_logistic.py |
Data | horses_or_humans from tensorflow_datasets |
Network | Custom single layer |
Loss | objax.functional.loss.sigmoid_cross_entropy_logits |
Optimizer | objax.optimizer.SGD |
Accuracy | ~77% |
Hardware | CPU or GPU or TPU |
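The training pattern in this example is a single linear layer whose one logit is trained with a sigmoid cross-entropy loss and plain SGD. A minimal sketch of that pattern, using a random batch in place of the horses_or_humans input pipeline (the input size 3*100*100 is illustrative, not the script's exact value):

import objax
import jax.numpy as jn

# Single linear layer producing one logit per flattened image.
model = objax.nn.Linear(3 * 100 * 100, 1)
opt = objax.optimizer.SGD(model.vars())

def loss(x, y):
    # y holds 0/1 labels; sigmoid cross-entropy is applied to the single logit.
    logits = model(x)[:, 0]
    return objax.functional.loss.sigmoid_cross_entropy_logits(logits, y).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, y, lr):
    g, v = gv(x, y)   # gradients and loss value
    opt(lr, g)        # SGD weight update
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())

# Illustrative random batch in place of the real dataset.
x = objax.random.normal((4, 3 * 100 * 100))
y = jn.array([0., 1., 0., 1.])
print(train_op(x, y, 0.01))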
Train and evaluate a DNNet model for multiclass classification on the MNIST dataset.
# Run command
python3 examples/image_classification/mnist_dnn.py
Code | examples/image_classification/mnist_dnn.py |
Data | MNIST from tensorflow_datasets |
Network | Deep Neural Net (objax.zoo.dnnet.DNNet) |
Loss | objax.functional.loss.cross_entropy_logits |
Optimizer | objax.optimizer.Adam |
Accuracy | ~98% |
Hardware | CPU or GPU or TPU |
Techniques | Model weight averaging for improved accuracy. |
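A minimal sketch of this example's network/loss/optimizer combination, assuming the DNNet constructor takes a tuple of layer sizes plus an activation function (the layer sizes here are illustrative), and omitting the weight-averaging step:

import objax

# Fully connected net: 784 inputs -> 128 hidden -> 10 classes.
model = objax.zoo.dnnet.DNNet((784, 128, 10), objax.functional.leaky_relu)
opt = objax.optimizer.Adam(model.vars())

def loss(x, labels):
    # labels are one-hot (shape [batch, 10]) to match cross_entropy_logits.
    logits = model(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, labels, lr):
    g, v = gv(x, labels)
    opt(lr, g)   # Adam update
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())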
Train and evaluate a simple custom CNN model for multiclass classification on the MNIST dataset.
# Run command
python3 examples/image_classification/mnist_cnn.py
Code | examples/image_classification/mnist_cnn.py |
Data | MNIST from tensorflow_datasets |
Network | Custom Convolutional Neural Net using objax.nn.Sequential |
Loss | objax.functional.loss.cross_entropy_logits_sparse |
Optimizer | objax.optimizer.Adam |
Accuracy | ~99.5% |
Hardware | CPU or GPU or TPU |
Techniques | |
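A minimal sketch of a custom CNN built with objax.nn.Sequential in the spirit of this example (the layer sizes are illustrative; the script's exact architecture may differ):

import objax
from objax.functional import flatten, max_pool_2d, relu

# Small convolutional net for NCHW inputs of shape (batch, 1, 28, 28).
model = objax.nn.Sequential([
    objax.nn.Conv2D(1, 32, 3), relu,
    lambda x: max_pool_2d(x, 2),
    objax.nn.Conv2D(32, 64, 3), relu,
    lambda x: max_pool_2d(x, 2),
    flatten,
    objax.nn.Linear(64 * 7 * 7, 10),
])
opt = objax.optimizer.Adam(model.vars())

def loss(x, labels):
    # labels are integer class indices, matching cross_entropy_logits_sparse.
    logits = model(x)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, labels).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, labels, lr):
    g, v = gv(x, labels)
    opt(lr, g)
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())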
Train and evaluate a ConvNet model on the MNIST dataset with differential privacy.
# Run command
python3 examples/image_classification/mnist_dp.py
# See available options with
python3 examples/image_classification/mnist_dp.py --help
Code | examples/image_classification/mnist_dp.py |
Data | MNIST from tensorflow_datasets |
Network | Custom Convolutional Neural Net using objax.nn.Sequential |
Loss | objax.functional.loss.cross_entropy_logits |
Optimizer | |
Accuracy | |
Hardware | GPU |
Techniques | Compute differentially private gradients using objax.privacy.dpsgd.PrivateGradValues. |
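The differentially private part of the example replaces objax.GradValues with objax.privacy.dpsgd.PrivateGradValues, which clips per-example gradients and adds Gaussian noise. A minimal sketch of that substitution; the model architecture and the privacy parameter values below are hypothetical, not the script's settings:

import objax

# Small illustrative model.
model = objax.nn.Sequential([objax.nn.Linear(784, 128), objax.functional.relu,
                             objax.nn.Linear(128, 10)])
opt = objax.optimizer.SGD(model.vars())

def loss(x, labels):
    # labels are one-hot to match cross_entropy_logits.
    return objax.functional.loss.cross_entropy_logits(model(x), labels).mean()

# Per-example gradient clipping to l2_norm_clip plus noise scaled by noise_multiplier.
gv = objax.privacy.dpsgd.PrivateGradValues(loss, model.vars(),
                                            noise_multiplier=1.1,
                                            l2_norm_clip=1.0,
                                            microbatch=1)

def train_op(x, labels, lr):
    g, v = gv(x, labels)   # differentially private gradients
    opt(lr, g)
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())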
Train and evaluate a Wide ResNet model for multiclass classification on the CIFAR10 dataset.
# Run command
python3 examples/image_classification/cifar10_simple.py
Code | examples/image_classification/cifar10_simple.py |
Data | CIFAR10 from tf.keras.datasets |
Network | Wide ResNet using objax.zoo.wide_resnet.WideResNet |
Loss | objax.functional.loss.cross_entropy_logits_sparse |
Optimizer | objax.optimizer.Momentum |
Accuracy | ~91% |
Hardware | GPU or TPU |
Techniques | |
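A minimal sketch of this example's model/loss/optimizer combination (the depth, width and momentum values are illustrative; the script's settings may differ). Because the Wide ResNet contains batch normalization, the model is called with a training flag:

import objax

# Wide ResNet for 3-channel 32x32 inputs and 10 classes.
model = objax.zoo.wide_resnet.WideResNet(nin=3, nclass=10, depth=28, width=2)
opt = objax.optimizer.Momentum(model.vars(), momentum=0.9, nesterov=True)

def loss(x, labels):
    # Integer labels with the sparse cross-entropy loss.
    logits = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, labels).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, labels, lr):
    g, v = gv(x, labels)
    opt(lr, g)
    return v

train_op = objax.Jit(train_op, gv.vars() + opt.vars())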
Train and evaluate ConvNet models for multiclass classification on the CIFAR10 dataset.
# Run command
python3 examples/image_classification/cifar10_advanced.py
# Run with custom settings
python3 examples/image_classification/cifar10_advanced.py --weight_decay=0.0001 --batch=64 --lr=0.03 --epochs=256
# See available options with
python3 examples/image_classification/cifar10_advanced.py --help
Code | examples/image_classification/cifar10_advanced.py |
Data | CIFAR10 |
Network | Configurable via command-line options (see --help) |
Loss | objax.functional.loss.cross_entropy_logits |
Optimizer | objax.optimizer.Momentum |
Accuracy | ~94% |
Hardware | GPU, Multi-GPU or TPU |
Techniques | |
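This example can run on multiple GPUs. The usual Objax pattern for that is objax.Parallel, with gradients averaged across devices and variables replicated while the parallel op runs. A minimal sketch of that pattern (not the script's exact code; the model, learning rate and reduce choice are illustrative):

import objax

model = objax.zoo.wide_resnet.WideResNet(nin=3, nclass=10, depth=28, width=2)
opt = objax.optimizer.Momentum(model.vars(), momentum=0.9, nesterov=True)
lr = 0.03  # illustrative learning rate

def loss(x, labels):
    # labels are one-hot to match cross_entropy_logits.
    logits = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gv = objax.GradValues(loss, model.vars())

def train_op(x, labels):
    g, v = gv(x, labels)
    # Average gradients and loss across devices so all replicas stay in sync.
    opt(lr, objax.functional.parallel.pmean(g))
    return objax.functional.parallel.pmean(v)

# objax.Parallel splits the batch across devices; after pmean every device holds
# the same loss, so keep the first copy.
train_op = objax.Parallel(train_op, gv.vars() + opt.vars(), reduce=lambda x: x[0])

# At call time, variables must be replicated across devices:
#   with (gv.vars() + opt.vars()).replicate():
#       loss_value = train_op(image_batch, onehot_label_batch)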
Train and evaluate a ResNet50 model on the ImageNet dataset. See README for additional information.
Code | examples/image_classification/imagenet_resnet50_train.py |
Data | ImageNet from tensorflow_datasets |
Network | ResNet50 |
Loss | objax.functional.loss.cross_entropy_logits_sparse |
Optimizer | |
Accuracy | |
Hardware | GPU, Multi-GPU or TPU |
Techniques | |
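A minimal sketch of the ResNet50 setup with the sparse cross-entropy loss; the constructor arguments below are an assumption about objax.zoo.resnet_v2, not necessarily the script's exact call (see the README for the real training setup):

import objax
from objax.zoo import resnet_v2

# Assumed constructor: 3 input channels, 1000 ImageNet classes.
model = resnet_v2.ResNet50(in_channels=3, num_classes=1000)

def loss(x, labels):
    # Integer ImageNet labels with the sparse cross-entropy loss.
    logits = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logits, labels).mean()

gv = objax.GradValues(loss, model.vars())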
Image classification using an ImageNet-pretrained VGG19 model. See README for additional information.
Code | examples/image_classification/imagenet_pretrained_vgg.py |
Techniques | Load VGG-19 model with pretrained weights and run 1000-way image classification. |
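A minimal sketch of pretrained 1000-way classification; the VGG19 constructor name and its pretrained flag are assumptions here, and the README describes the required weight download:

import objax
import jax.numpy as jn
from objax.zoo import vgg

# Assumed constructor; pretrained ImageNet weights must be downloaded first (see README).
model = vgg.VGG19(pretrained=True)

def predict(x):
    # x: a preprocessed image batch; returns probabilities over the 1000 ImageNet classes.
    return objax.functional.softmax(model(x))

x = objax.random.normal((1, 3, 224, 224))        # placeholder for a real preprocessed image
top1_class = int(jn.argmax(predict(x), axis=1)[0])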
Example code is available at examples/fixmatch.
Semi-supervised learning of image classification models with FixMatch.
# Run command
python3 examples/fixmatch/fixmatch.py
# Run with custom settings
python3 examples/fixmatch/fixmatch.py --dataset=cifar10.3@1000-0
# See available options with
python3 examples/fixmatch/fixmatch.py --help
Code | examples/fixmatch/fixmatch.py |
Data | CIFAR10, CIFAR100, SVHN, STL10 |
Network | Custom implementation of Wide ResNet. |
Loss | objax.functional.loss.cross_entropy_logits and objax.functional.loss.cross_entropy_logits_sparse |
Optimizer | objax.optimizer.Momentum |
Accuracy | See paper |
Hardware | GPU, Multi-GPU, TPU |
Techniques | |
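FixMatch combines a supervised loss on labeled data with a pseudo-labeling loss on unlabeled data: the prediction on a weakly augmented image becomes the target for its strongly augmented version, but only when the model's confidence exceeds a threshold. A minimal sketch of that loss combination (the threshold and weighting are illustrative, and the sparse cross-entropy form is used throughout for brevity; see the paper and script for exact details):

import jax
import jax.numpy as jn
import objax

def fixmatch_loss(model, x_labeled, labels, x_weak, x_strong,
                  confidence=0.95, unlabeled_weight=1.0):
    # Supervised term on the labeled batch (integer labels).
    loss_l = objax.functional.loss.cross_entropy_logits_sparse(
        model(x_labeled, training=True), labels).mean()

    # Pseudo-labels come from weakly augmented images and are treated as constants.
    probs_weak = objax.functional.softmax(model(x_weak, training=True))
    probs_weak = jax.lax.stop_gradient(probs_weak)
    pseudo_labels = jn.argmax(probs_weak, axis=1)
    mask = (jn.max(probs_weak, axis=1) >= confidence)  # keep only confident predictions

    # Unsupervised term: strongly augmented images must match the pseudo-labels.
    loss_u = objax.functional.loss.cross_entropy_logits_sparse(
        model(x_strong, training=True), pseudo_labels)
    loss_u = (loss_u * mask).mean()

    return loss_l + unlabeled_weight * loss_u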
Example code is available at examples/gpt-2.
Load a pretrained GPT-2 model (124M parameters) and demonstrate how to use it to generate a text sequence. See README for additional information.
Code | examples/gpt-2/gpt2.py |
Hardware | GPU or TPU |
Techniques | |
Example code is available at examples/text_generation.
Train and evaluate a vanilla RNN model on the Shakespeare corpus dataset. See README for additional information.
# Run command
python3 examples/text_generation/shakespeare_rnn.py
Code | examples/text_generation/shakespeare_rnn.py |
Data | Shakespeare corpus |
Network | Custom implementation of vanilla RNN. |
Loss | objax.functional.loss.cross_entropy_logits |
Optimizer | objax.optimizer.Adam |
Hardware | GPU or TPU |
Techniques | |
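The example implements a vanilla RNN as a custom Objax module. A minimal sketch of such a cell, with state update h_t = tanh(W_x x_t + W_h h_{t-1}) followed by a linear readout over the vocabulary (sizes and layout are illustrative, not the script's exact implementation):

import objax
import jax.numpy as jn

class VanillaRNN(objax.Module):
    """Character-level vanilla RNN: h_t = tanh(W_x x_t + W_h h_{t-1}), logits = W_o h_t."""

    def __init__(self, vocab_size: int, state_size: int):
        self.state_size = state_size
        self.w_x = objax.nn.Linear(vocab_size, state_size, use_bias=False)
        self.w_h = objax.nn.Linear(state_size, state_size)
        self.w_o = objax.nn.Linear(state_size, vocab_size)

    def __call__(self, x_seq):
        # x_seq: (time, batch, vocab_size) one-hot characters.
        h = jn.zeros((x_seq.shape[1], self.state_size))
        logits = []
        for t in range(x_seq.shape[0]):
            h = jn.tanh(self.w_x(x_seq[t]) + self.w_h(h))
            logits.append(self.w_o(h))
        return jn.stack(logits)  # (time, batch, vocab_size) next-character logits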
Example code is available at examples/maml.
Implementation of the meta-learning method MAML, demonstrating how to compute the gradient of a gradient.
# Run command
python3 examples/maml/maml.py
Code | examples/maml/maml.py |
Data | Synthetic data |
Network | 3-layer DNNet |
Hardware | CPU or GPU or TPU |
Techniques | Gradient of gradient. |
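The key mechanism is differentiating through a gradient step: the inner update theta' = theta - alpha * grad(L_train)(theta) is itself differentiated when computing the outer (meta) gradient. A minimal pure-JAX sketch of that second-order pattern on a toy one-parameter model (the script applies the same idea to a small DNNet on synthetic data):

import jax
import jax.numpy as jn

def inner_loss(theta, x, y):
    # Toy regression loss for a single scalar parameter.
    return jn.mean((theta * x - y) ** 2)

def meta_loss(theta, x_train, y_train, x_val, y_val, alpha=0.1):
    # One inner gradient step...
    theta_adapted = theta - alpha * jax.grad(inner_loss)(theta, x_train, y_train)
    # ...then evaluate the adapted parameter on validation data. Differentiating
    # meta_loss w.r.t. theta requires the gradient of a gradient.
    return inner_loss(theta_adapted, x_val, y_val)

meta_grad = jax.grad(meta_loss)
print(meta_grad(1.0, jn.array([1., 2.]), jn.array([2., 4.]),
                jn.array([3.]), jn.array([6.])))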
Example code is available at examples/jaxboard.
Sample usage of jaxboard. See README for additional information.
# Run command
python3 examples/jaxboard/summary.py
Code | examples/jaxboard/summary.py |
Hardware | CPU |
Usages | |
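A minimal usage sketch, assuming the Summary/SummaryWriter interface exposed by objax.jaxboard (a Summary accumulates entries, a SummaryWriter flushes them per step); the log directory and metric names are hypothetical, and summary.py shows the full set of supported summary types:

from objax.jaxboard import Summary, SummaryWriter

with SummaryWriter('logs/demo') as writer:
    for step in range(3):
        summary = Summary()
        summary.scalar('train/loss', 1.0 / (step + 1))  # hypothetical metric
        writer.write(summary, step=step)

TensorBoard can then be pointed at the log directory to view the recorded summaries.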