v0.1.7
Major Features
New Training Framework (braintools.trainer)
A PyTorch Lightning-style API for training JAX-based neural networks, built from the following components:
LightningModule: Base class for defining training models with training_step(), validation_step(), and configure_optimizers() hooks
Trainer: Orchestration class for managing training loops, epochs, and device placement
TrainOutput/EvalOutput: Structured output types for training and evaluation results
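The hook pattern above can be sketched in plain Python. This is not the braintools.trainer API: `ToyModule`, `ToyTrainer`, `StepOutput`, and their signatures are hypothetical, and the example fits a one-parameter linear model with hand-written SGD so that it runs without JAX.

```python
from dataclasses import dataclass


@dataclass
class StepOutput:
    # Hypothetical stand-in for a structured TrainOutput/EvalOutput.
    loss: float


class ToyModule:
    """Hypothetical module exposing Lightning-style hooks."""

    def __init__(self):
        self.w = 0.0  # single trainable weight

    def configure_optimizers(self):
        # Return a plain learning rate; a real module would return an optimizer.
        return 0.1

    def _loss_and_grad(self, batch):
        # Mean squared error of y ~ w * x and its gradient with respect to w.
        n = len(batch)
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / n
        return loss, grad

    def training_step(self, batch, lr):
        loss, grad = self._loss_and_grad(batch)
        self.w -= lr * grad  # plain SGD update
        return StepOutput(loss=loss)

    def validation_step(self, batch):
        loss, _ = self._loss_and_grad(batch)
        return StepOutput(loss=loss)


class ToyTrainer:
    """Hypothetical trainer that owns the epoch loop and calls the hooks."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, module, train_batches, val_batches):
        lr = module.configure_optimizers()
        history = []  # mean validation loss after each epoch
        for _ in range(self.max_epochs):
            for batch in train_batches:
                module.training_step(batch, lr)
            val = [module.validation_step(b).loss for b in val_batches]
            history.append(sum(val) / len(val))
        return history


data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]  # targets follow y = 3x
module = ToyModule()
history = ToyTrainer(max_epochs=20).fit(module, [data[:2], data[2:]], [data])
```

The split keeps model logic (loss, update rule) in the module and loop control in the trainer, which is the separation a Lightning-style API is organized around.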
Callbacks System
10+ built-in callbacks for customizing training behavior:
ModelCheckpoint: Automatic model saving based on monitored metrics
EarlyStopping: Stop training when metrics plateau
LearningRateMonitor: Track and log learning rate changes
GradientClipCallback: Gradient clipping for training stability
Timer: Track training time
RichProgressBar / TQDMProgressBar: Visual progress indicators
LambdaCallback / PrintCallback: Custom callback utilities
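The plateau test behind a callback like EarlyStopping is usually a patience counter. The sketch below is a hypothetical, framework-free version; the class name, the `patience`/`min_delta` parameters, and the assumption that the monitored metric should decrease are illustrative, not the braintools signature.

```python
class PatienceEarlyStop:
    """Hypothetical early-stopping rule: stop after `patience` epochs without
    an improvement of at least `min_delta` in a metric that should decrease."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0

    def update(self, value):
        """Record one epoch's metric; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value  # new best value: reset the counter
            self.wait = 0
        else:
            self.wait += 1  # no sufficient improvement this epoch
        return self.wait >= self.patience


def run(losses, patience):
    """Return the epoch index at which training would stop, or None."""
    stopper = PatienceEarlyStop(patience=patience)
    for epoch, loss in enumerate(losses):
        if stopper.update(loss):
            return epoch
    return None
```

With `patience=3`, the losses `[1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.9]` stop at epoch 5, because three consecutive epochs fail to beat the best value of 0.7.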
Logging Backends
6 pluggable logging backends:
TensorBoardLogger: TensorBoard integration
WandBLogger: Weights & Biases integration
CSVLogger: Simple CSV file logging
NeptuneLogger: Neptune.ai integration
MLFlowLogger: MLflow integration
CompositeLogger: Combine multiple loggers
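A composite logger is typically a fan-out: it holds several backends behind one interface and forwards every call to each of them. The sketch below assumes a hypothetical `log_metrics(metrics, step)` method; that method name and the two toy backends are illustrative, not the braintools logger interface.

```python
import csv
import io


class ListLogger:
    """Hypothetical in-memory backend that records every call."""

    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step):
        self.records.append((step, dict(metrics)))


class CsvStringLogger:
    """Hypothetical CSV backend writing rows to any text stream."""

    def __init__(self, stream, fieldnames):
        self.writer = csv.DictWriter(stream, fieldnames=["step", *fieldnames])
        self.writer.writeheader()

    def log_metrics(self, metrics, step):
        self.writer.writerow({"step": step, **metrics})


class FanOutLogger:
    """Hypothetical composite: forwards each call to every child backend."""

    def __init__(self, *loggers):
        self.loggers = loggers

    def log_metrics(self, metrics, step):
        for logger in self.loggers:
            logger.log_metrics(metrics, step)


buffer = io.StringIO()
memory = ListLogger()
logger = FanOutLogger(memory, CsvStringLogger(buffer, ["loss"]))
logger.log_metrics({"loss": 0.5}, step=1)
logger.log_metrics({"loss": 0.25}, step=2)
```

Because the training loop only sees one object, adding or removing a backend does not change any call sites.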
Data Loading Utilities
JAX-compatible data loading with distributed support:
DataLoader / DistributedDataLoader: Efficient batch loading
Dataset, ArrayDataset, DictDataset, IterableDataset: Dataset abstractions
Sampler, RandomSampler, SequentialSampler, BatchSampler, DistributedSampler: Sampling strategies
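These sampling strategies compose: a sampler decides the order of dataset indices, a distributed sampler gives each process its own slice, and a batch sampler groups that slice into batches. The functions below are a hypothetical, dependency-free sketch of that pipeline. The strided sharding and the cyclic padding rule follow a common convention (used, for example, by PyTorch's DistributedSampler) and are assumed here, not taken from braintools.

```python
import random


def random_indices(n, seed):
    """Shuffle 0..n-1 deterministically, as a random sampler would."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return indices


def shard_indices(indices, rank, world_size):
    """Give each rank an equal slice by striding over the (padded) indices.

    The list is padded by cycling through it from the start so every rank
    receives the same number of samples; a few samples may appear twice."""
    total = -(-len(indices) // world_size) * world_size  # round up to a multiple
    padded = [indices[i % len(indices)] for i in range(total)]
    return padded[rank::world_size]


def batches(indices, batch_size, drop_last=False):
    """Group indices into consecutive batches, as a batch sampler would."""
    out = [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]
    if drop_last and out and len(out[-1]) < batch_size:
        out.pop()  # discard the incomplete final batch
    return out


order = random_indices(10, seed=0)
shards = [shard_indices(order, r, world_size=4) for r in range(4)]
per_rank_batches = [batches(s, batch_size=2) for s in shards]
```

With 10 samples and 4 ranks, each rank receives 3 indices and every sample is covered by at least one rank.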
Distributed Training
Multi-device and multi-host training strategies:
SingleDeviceStrategy: Single device execution
DataParallelStrategy: Data parallelism across devices
ShardedDataParallelStrategy / FullyShardedDataParallelStrategy: Memory-efficient sharded training
AutoStrategy: Automatic strategy selection
all_reduce, broadcast: Distributed communication primitives
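In data parallelism each device computes gradients on its own shard, and an all_reduce leaves every device holding the same averaged result before the update. The plain-Python simulation below shows why this reproduces the full-batch gradient when shards have equal size. `all_reduce_mean` is a hypothetical stand-in, not the braintools `all_reduce` signature, and real collectives run across devices rather than over a Python list.

```python
def local_grad(w, shard):
    """Gradient of mean squared error for y ~ w * x on one device's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Simulated all_reduce: every device receives the mean of all inputs."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]
num_devices = 2
shards = [data[i::num_devices] for i in range(num_devices)]  # equal-size shards
w = 0.5

per_device = [local_grad(w, s) for s in shards]  # differs per device
synced = all_reduce_mean(per_device)  # identical on every device
full_batch = local_grad(w, data)  # single-device reference
```

Here the two local gradients are -18.0 and -28.0; after the reduction both devices hold -23.0, matching the full-batch gradient.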
Checkpointing
Comprehensive checkpoint management:
CheckpointManager: Manage multiple checkpoints with retention policies
save_checkpoint / load_checkpoint: Save and restore model states
find_checkpoint / list_checkpoints: Checkpoint discovery utilities
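A retention policy such as "keep only the newest N checkpoints" can be enforced right after each save. The sketch below is hypothetical: it stores state as JSON, names files by zero-padded step number, and its class and method names are illustrative rather than the braintools CheckpointManager API.

```python
import json
import tempfile
from pathlib import Path


class KeepLastN:
    """Hypothetical manager: writes JSON checkpoints, keeps the newest `max_to_keep` (>= 1)."""

    def __init__(self, directory, max_to_keep=3):
        self.directory = Path(directory)
        self.max_to_keep = max_to_keep

    def checkpoints(self):
        # Zero-padded step numbers make lexical order equal numeric order.
        return sorted(self.directory.glob("ckpt_*.json"))

    def save(self, step, state):
        path = self.directory / f"ckpt_{step:08d}.json"
        path.write_text(json.dumps({"step": step, "state": state}))
        for old in self.checkpoints()[: -self.max_to_keep]:
            old.unlink()  # enforce the retention policy
        return path

    def latest(self):
        paths = self.checkpoints()
        return json.loads(paths[-1].read_text()) if paths else None


with tempfile.TemporaryDirectory() as tmp:
    manager = KeepLastN(tmp, max_to_keep=2)
    for step in range(1, 6):
        manager.save(step, {"w": step * 0.1})
    kept = [p.name for p in manager.checkpoints()]
    restored = manager.latest()
```

After five saves with `max_to_keep=2`, only the checkpoints for steps 4 and 5 remain, and `latest()` restores step 5.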
Progress Bar System
Multiple progress bar implementations:
SimpleProgressBar: Basic text-based progress bar
TQDMProgressBarWrapper: tqdm-based progress bar
RichProgressBarWrapper: Progress bar based on the Rich library
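A basic text progress bar reduces to formatting a completed fraction and redrawing the same terminal line with a carriage return. This is a hypothetical sketch of that idea, not the braintools SimpleProgressBar implementation; `render_bar` and its `width` parameter are illustrative.

```python
import sys


def render_bar(current, total, width=20):
    """Return a one-line text progress bar such as '[#####.....] 5/10'."""
    filled = width * current // total  # integer number of filled cells
    return f"[{'#' * filled}{'.' * (width - filled)}] {current}/{total}"


# Redraw the bar in place: '\r' returns the cursor to the start of the line.
for step in range(1, 11):
    sys.stdout.write("\r" + render_bar(step, 10, width=10))
    sys.stdout.flush()
sys.stdout.write("\n")
```

Libraries such as tqdm and Rich add rate estimates, nesting, and terminal-width handling on top of this basic redraw loop.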
Improvements
API Documentation
Enhanced module documentation: All public modules now include comprehensive docstrings with examples, parameter descriptions, and usage guidelines directly in their __init__.py files
Reorganized imports: Cleaner and more consistent import structure across all modules
Breaking Changes
Removed braintools.param Module
The entire braintools.param module has been removed, including:
Data containers (Data)
Parameter wrappers (Param, Const)
State containers (ArrayHidden, ArrayParam)
Regularization classes (GaussianReg, L1Reg, L2Reg)
All transform classes (SigmoidT, SoftplusT, AffineT, etc.)
Utility functions (get_param(), get_size())
Users relying on these features should migrate to alternative implementations or pin braintools to version 0.1.6.
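If a model only needs a constrained scalar parameter, a replacement can be written directly with standard functions. The sketch below assumes that transforms such as SigmoidT and SoftplusT mapped an unconstrained value into a bounded or positive range, as their names suggest; it does not reproduce their actual signatures or behavior, and the function names are hypothetical.

```python
import math


def to_bounded(z, lower, upper):
    """Map an unconstrained value z into (lower, upper) with a sigmoid."""
    return lower + (upper - lower) / (1.0 + math.exp(-z))


def from_bounded(x, lower, upper):
    """Inverse map: recover z from a value strictly inside (lower, upper)."""
    p = (x - lower) / (upper - lower)
    return math.log(p / (1.0 - p))  # logit of the normalized position


def softplus(z):
    """Map z to a positive value: log(1 + e^z), written to avoid overflow."""
    return math.log1p(math.exp(-abs(z))) + max(z, 0.0)
```

An optimizer can then update the unconstrained value `z` freely while the model always reads `to_bounded(z, lower, upper)` or `softplus(z)`.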