Skip to content

mcbal/spin-model-transformers

Repository files navigation

Spin-model transformers

Install

pip install -e .[dev]
pre-commit install
pre-commit run --all-files

Examples

import jax
from spin_model_transformers import SpinTransformer


key = jax.random.PRNGKey(2666)
x_key, mod_key = jax.random.split(key)

x = jax.random.normal(x_key, shape=(1, 256, 512))
transformer = SpinTransformer(depth=6, dim=512, num_heads=1, beta=1.0, key=mod_key)

out = jax.vmap(transformer)(x)  # (1, 256, 512)