
ethanluoyc/jam


Jam - JAX models

Jam is a collection of ML models (mostly vision models for now) implemented in Flax and Haiku. It includes model implementations as well as pretrained weights converted from other sources.

Jam currently focuses on giving easy access to pretrained models whose original checkpoints are published for PyTorch. These models can be used for transfer learning or as feature extractors in vision-based RL tasks, among other purposes. There are preliminary examples for training some of these models from scratch, but they are not yet fully tested or benchmarked.
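
Converting a PyTorch checkpoint largely comes down to renaming parameters and rearranging weight layouts. As a rough illustration (not Jam's own conversion code), the sketch below handles only the layout part for two common layers, assuming the weights have already been loaded as NumPy arrays. PyTorch stores `Conv2d` weights as (out, in, kH, kW) and `Linear` weights as (out, in), whereas Flax (`nn.Conv`, `nn.Dense`) and Haiku (`hk.Conv2D`, `hk.Linear`) expect (kH, kW, in, out) and (in, out). The helper names `torch_conv_to_flax` and `torch_linear_to_flax` are hypothetical.

```python
import numpy as np


def torch_conv_to_flax(weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: PyTorch Conv2d weight (O, I, kH, kW) -> Flax/Haiku kernel (kH, kW, I, O)."""
    if weight.ndim != 4:
        raise ValueError(f"expected a 4-D conv weight, got shape {weight.shape}")
    # New axis order: kH (old 2), kW (old 3), in (old 1), out (old 0).
    return np.transpose(weight, (2, 3, 1, 0))


def torch_linear_to_flax(weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: PyTorch Linear weight (out, in) -> Flax Dense / Haiku Linear kernel (in, out)."""
    if weight.ndim != 2:
        raise ValueError(f"expected a 2-D linear weight, got shape {weight.shape}")
    return weight.T


# Example: a ResNet-style stem convolution, Conv2d(3, 64, kernel_size=7).
torch_w = np.random.randn(64, 3, 7, 7).astype(np.float32)
flax_k = torch_conv_to_flax(torch_w)
print(flax_k.shape)  # (7, 7, 3, 64)
```

A full conversion also has to map parameter names, move BatchNorm running statistics into the model's state collection, and may need to account for padding differences between the frameworks; the transposes above only cover weight layouts.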

Supported pretrained models

  1. ConvNeXt (via torchvision): Flax
  2. ResNet (via torchvision): Haiku and Flax
  3. MVP (via https://github.com/ir413/mvp/): Flax
  4. NFNet (via https://github.com/google-deepmind/deepmind-research/blob/master/nfnets): Haiku and Flax
  5. R3M (via https://github.com/facebookresearch/r3m/tree/main): Haiku and Flax

Examples

See the examples directory.
