MLP GPT - Jax

A GPT made entirely of MLPs, in Jax. The specific MLPs used are gMLPs with Spatial Gating Units.

Working PyTorch implementation

Install

$ pip install mlp-gpt-jax

Usage

from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)
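The logits above would typically feed a next-token cross-entropy loss for training. A minimal standalone sketch (not part of this library) of that loss, shown with small toy shapes standing in for seq_len = 1024 and num_tokens = 20000:

```python
import jax
import jax.numpy as jnp

def next_token_loss(logits, seq):
    # position t predicts token t + 1, so drop the last logit row
    log_probs = jax.nn.log_softmax(logits[:-1])          # (seq_len - 1, num_tokens)
    targets = seq[1:]                                    # (seq_len - 1,)
    # pick out the log-probability assigned to each target token
    picked = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -jnp.mean(picked)

# toy stand-ins: 8 positions, vocabulary of 10
logits = jnp.zeros((8, 10))           # uniform logits
seq = jnp.arange(8) % 10
loss = next_token_loss(logits, seq)   # uniform logits give loss = log(10)
```

In practice this would be wrapped so that `model.apply` produces the logits inside the loss function, letting `jax.value_and_grad` differentiate through the whole model.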

To use the tiny attention module (also made autoregressive with a causal mask), just set attn_dim to the head dimension you'd like to use. The paper recommends 64.

from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    attn_dim = 64     # head dimension for the tiny attention
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)
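Both examples above operate on a single unbatched sequence. Since JAX functions compose with jax.vmap, a batch dimension can be mapped over without changing the model code. A hypothetical sketch using a stand-in apply function (the real call would be model.apply with params and an rng key per example):

```python
import jax
import jax.numpy as jnp

# stand-in for a per-sequence apply: (seq_len,) ints -> (seq_len, num_tokens) logits
def apply_fn(seq):
    return jnp.eye(10)[seq]  # fake one-hot "logits", vocab of 10

# vmap adds a leading batch dimension to both input and output
batched_apply = jax.vmap(apply_fn)

seqs = jnp.stack([jnp.arange(8) % 10, (jnp.arange(8) + 1) % 10])  # (2, 8)
logits = batched_apply(seqs)                                      # (2, 8, 10)
```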

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}