# Switch Transformers

Implementation of Switch Transformers from the paper ["Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity"](https://arxiv.org/abs/2101.03961) in PyTorch, Einops, and Zeta.
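
At its core, the paper replaces the Transformer's dense feed-forward layer with a set of expert FFNs and routes each token to exactly one of them (top-1, or "switch", routing), scaling the chosen expert's output by the router probability and adding a load-balancing loss. Below is a minimal sketch of that routing in plain PyTorch; `SwitchFFN` and its parameters are illustrative names, not this library's API, and capacity limits / token dropping from the paper are omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwitchFFN(nn.Module):
    """Top-1 (switch) routed mixture of feed-forward experts. Sketch only."""

    def __init__(self, dim: int, ff_dim: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, dim)
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor):
        b, n, d = x.shape
        tokens = x.reshape(-1, d)                       # flatten batch and sequence
        probs = F.softmax(self.router(tokens), dim=-1)  # router probabilities
        gate, idx = probs.max(dim=-1)                   # top-1 expert per token

        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            mask = idx == e                             # tokens routed to expert e
            if mask.any():
                out[mask] = expert(tokens[mask])
        out = out * gate.unsqueeze(-1)                  # scale by router probability

        # Load-balancing auxiliary loss from the paper:
        # num_experts * sum_i (fraction of tokens sent to expert i)
        #                     * (mean router probability for expert i)
        frac = F.one_hot(idx, len(self.experts)).float().mean(dim=0)
        mean_prob = probs.mean(dim=0)
        aux_loss = len(self.experts) * (frac * mean_prob).sum()

        return out.reshape(b, n, d), aux_loss


if __name__ == "__main__":
    layer = SwitchFFN(dim=512, ff_dim=2048, num_experts=8)
    y, aux = layer(torch.randn(1, 10, 512))
    print(y.shape, aux.item())  # torch.Size([1, 10, 512]) and a scalar loss
```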

## Installation

```bash
pip install switch-transformers
```

## Usage

```python
import torch
from switch_transformers import SwitchTransformer

# Generate a batch of 10 random token IDs in [0, 100)
x = torch.randint(0, 100, (1, 10))

# Create an instance of the SwitchTransformer model
# num_tokens: vocabulary size (number of distinct token IDs)
# dim: the model (embedding) dimension
# heads: the number of attention heads
# dim_head: the dimension of each attention head
model = SwitchTransformer(
    num_tokens=100, dim=512, heads=8, dim_head=64
)

# Pass the input tensor through the model
out = model(x)

# Print the shape of the output tensor
print(out.shape)
```
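
With these settings, the printed shape should be `torch.Size([1, 10, 100])`: per-position logits over the 100-token vocabulary (assuming the model ends in a language-modeling head, as the `num_tokens` argument suggests).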

## Citation

```bibtex
@misc{fedus2022switch,
    title={Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity},
    author={William Fedus and Barret Zoph and Noam Shazeer},
    year={2022},
    eprint={2101.03961},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```

## License

MIT
