
ClashLuke/jaxhelper


JaxHelper

Basic tools and helpers for JAX

Getting Started

Installation

python3 -m pip install jaxhelper

Explanation

This repository contains basic helper functions I use every day.
Here are some highlights:

  • remat: function decorator that rematerializes hidden states during the backward pass ("activation checkpointing"), trading extra recomputation for lower memory use
  • softmax:
    • exponentiation in fp32 and matmul in bf16 (improves convergence and speed)
    • stores fewer intermediates, yet computes the gradient faster
  • attention:
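The README does not show how `remat` is implemented. As a minimal sketch, assuming it is a thin wrapper around JAX's built-in `jax.checkpoint` (also exported as `jax.remat`), a decorator with that behaviour could look like this. The names `remat`, `mlp_block`, and `loss` are illustrative only, not jaxhelper's actual API.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for jaxhelper's `remat`; the real helper may differ.
def remat(fn):
    """Wrap `fn` with jax.checkpoint so its intermediates are recomputed
    during the backward pass instead of being stored."""
    return jax.checkpoint(fn)


@remat
def mlp_block(w1, w2, x):
    # The hidden activation `h` is not saved for the backward pass;
    # it is recomputed from `x` when gradients are taken.
    h = jnp.tanh(x @ w1)
    return h @ w2


def loss(w1, w2, x):
    return jnp.sum(mlp_block(w1, w2, x) ** 2)


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(key1, (16, 32))
w2 = jax.random.normal(key2, (32, 16))
x = jax.random.normal(key3, (8, 16))

# Gradients w.r.t. both weight matrices.
grads = jax.grad(loss, argnums=(0, 1))(w1, w2, x)
```

The gradients are the same as without checkpointing; only the memory/compute trade-off changes, which matters for deep stacks of such blocks.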
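One plausible way to get the softmax properties listed above is a `jax.custom_vjp` that exponentiates in fp32, returns bf16 probabilities for the following matmul, and saves only its output for the backward pass. This works because the softmax vector-Jacobian product depends only on the output y: dx = y * (dy - sum(dy * y)). This is an assumed design, not jaxhelper's code; `softmax` and `weighted_values` are hypothetical names.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def softmax(x):
    """Softmax over the last axis: exp in fp32, result returned in bf16."""
    x32 = x.astype(jnp.float32)
    # Subtract the row max before exponentiating for numerical stability.
    e = jnp.exp(x32 - jnp.max(x32, axis=-1, keepdims=True))
    return (e / jnp.sum(e, axis=-1, keepdims=True)).astype(jnp.bfloat16)


def _softmax_fwd(x):
    y = softmax(x)
    # Residuals: the bf16 output, plus a 0-d array that only records the
    # input dtype. Neither x nor exp(x) is kept for the backward pass.
    return y, (y, jnp.zeros((), x.dtype))


def _softmax_bwd(res, dy):
    y, dtype_witness = res
    y32 = y.astype(jnp.float32)
    dy32 = dy.astype(jnp.float32)
    # Softmax VJP expressed purely in terms of the output y.
    dx = y32 * (dy32 - jnp.sum(dy32 * y32, axis=-1, keepdims=True))
    # The cotangent must match the primal input's dtype.
    return (dx.astype(dtype_witness.dtype),)


softmax.defvjp(_softmax_fwd, _softmax_bwd)


# Hypothetical attention-style use: probabilities are already bf16,
# so the product with the values runs in bf16.
def weighted_values(logits, values):
    return softmax(logits) @ values.astype(jnp.bfloat16)
```

Storing a single bf16 tensor instead of the logits plus the fp32 exponentials is where the memory saving in this sketch comes from.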
