This repository contains the code, in both PyTorch and Jax, for our paper:
Improving Transformers with Dynamically Composable Multi-Head Attention
Da Xiao, Qingye Meng, Shengping Li, Xingyuan Yuan
ICML 2024 (oral)
We propose Dynamically Composable Multi-Head Attention (DCMHA), a parameter- and computation-efficient attention architecture that tackles the shortcomings of Multi-Head Attention (MHA) and increases the expressive power of the model by dynamically composing attention heads. At the core of DCMHA is a Compose function that transforms the attention score and weight matrices in an input-dependent way. DCMHA can be used as a drop-in replacement for MHA in any transformer architecture to obtain the corresponding DCFormer.
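To make the idea of composing heads concrete, here is a minimal NumPy sketch. It is a simplified illustration, not the implementation in jax/ or pytorch/, and the real DCMHA Compose function may differ in how its dynamic weights are computed. The decomposition into a static head-mixing matrix, a query-driven low-rank term and query-driven per-head gating, as well as every name below (compose, dcmha_attention, identity_params, random_params, w_static, w_q1, w_q2, w_qg), are assumptions made for this example. NumPy is used only so the sketch runs without a TPU or GPU.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax; each row needs at least one finite entry.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def compose(a, xq, w_static, w_q1, w_q2, w_qg):
    """Mix attention maps across heads, partly as a function of the query input.

    Hypothetical parameterization for illustration only.
    a:        [H, T, S] attention scores (pre-softmax) or weights (post-softmax)
    xq:       [T, D]    query-side inputs that drive the dynamic terms
    w_static: [H, H]    input-independent head-mixing matrix
    w_q1:     [D, H*R]  maps xq[t] to a per-query H x R down-projection
    w_q2:     [D, R*H]  maps xq[t] to a per-query R x H up-projection
    w_qg:     [D, H]    maps xq[t] to per-head gates
    """
    H, T, _ = a.shape
    R = w_q1.shape[1] // H
    # Static composition: each output head is a fixed mix of the input heads.
    static = np.einsum("hts,hg->gts", a, w_static)
    # Dynamic low-rank composition: the mixing matrix changes with query position t.
    u = (xq @ w_q1).reshape(T, H, R)
    v = (xq @ w_q2).reshape(T, R, H)
    low = np.einsum("hts,thr->rts", a, u)
    dynamic = np.einsum("rts,trg->gts", low, v)
    # Dynamic gating: rescale each head by a query-dependent factor in (-1, 1).
    gate = np.tanh(xq @ w_qg)                 # [T, H]
    gated = a * gate.T[:, :, None]            # broadcast over key positions
    return static + dynamic + gated


def dcmha_attention(x, wq, wk, wv, score_params, weight_params):
    """Causal self-attention over one sequence with compose() on scores and weights.

    x: [T, D]; wq, wk, wv: [H, D, Dh]; returns [H, T, Dh].
    """
    T = x.shape[0]
    q = np.einsum("td,hde->hte", x, wq)
    k = np.einsum("td,hde->hte", x, wk)
    v = np.einsum("td,hde->hte", x, wv)
    scores = np.einsum("hte,hse->hts", q, k) / np.sqrt(q.shape[-1])
    scores = compose(scores, x, *score_params)       # compose pre-softmax scores
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -np.inf)       # mask after composing
    weights = softmax(scores, axis=-1)
    weights = compose(weights, x, *weight_params)    # compose post-softmax weights
    return np.einsum("hts,hse->hte", weights, v)


def identity_params(H, D, R):
    # With these values compose() returns its input unchanged.
    return np.eye(H), np.zeros((D, H * R)), np.zeros((D, R * H)), np.zeros((D, H))


def random_params(rng, H, D, R, scale=0.1):
    # Small random perturbation around the identity composition.
    return (np.eye(H) + scale * rng.normal(size=(H, H)),
            scale * rng.normal(size=(D, H * R)),
            scale * rng.normal(size=(D, R * H)),
            scale * rng.normal(size=(D, H)))


rng = np.random.default_rng(0)
T, D, H, Dh, R = 5, 8, 4, 2, 2
x = rng.normal(size=(T, D))
wq, wk, wv = (rng.normal(size=(H, D, Dh)) for _ in range(3))
out = dcmha_attention(x, wq, wk, wv, random_params(rng, H, D, R), random_params(rng, H, D, R))
print(out.shape)  # (4, 5, 2)
```

In this sketch, setting w_static to the identity and all dynamic weights to zero makes compose() return its input unchanged, so the layer reduces to ordinary causal MHA. This is one way to see why a Compose-based layer can slot in where MHA was. Because the masked weights are exactly zero after the softmax and compose() only mixes heads at the same (query, key) position, causality is preserved.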
In practice, we train DCFormer on TPU for efficiency and run inference on GPU for convenience, so we open-source the Jax training code and the PyTorch inference code in two separate folders.
- The source code is in the jax/ folder and supports training DCFormer on TPU or GPU with google/MaxText.
- Please refer to jax/README.md for details.
- The source code is in the pytorch/ folder and supports accelerated inference with torch.compile.
- We have also uploaded the pretrained models DCFormer-2.8B (DCFormer++2.8B in the paper) and DCPythia-6.9B to Hugging Face🤗.
- Please refer to pytorch/README.md for details.
The synthetic task dataset used in the paper is located at data/synthetic_dataset.jsonl. Each line in the file contains an in-context learning sample, where the words in brackets [ ] are compared to calculate accuracy. For example:
< Elizabeth has jeans. Steven has strawberries. Steven has a fox. >. Steven does not have a kind of clothing
< Karen has a taxi. Karen has a pig. Paul has a cocktail. >. Karen does not have a kind of drink
< Ruth has a sweater. Steven has a taxi. Ruth has a handgun. >. Ruth does not have a kind of [ vehicle ]
. . .
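Scoring a sample can be sketched as follows, assuming each sample is a text string in which the target words appear in brackets as in the example above, that the prompt given to the model is everything before the bracket, and that a prediction counts as correct on an exact, whitespace-stripped match. These are assumptions, not the evaluation code used for the paper. The helper names (split_sample, accuracy, load_samples) are hypothetical, and so is the `field` argument of load_samples: the actual JSON key in synthetic_dataset.jsonl is not stated here and should be checked in the file.

```python
import json
import re

# First bracketed span, e.g. "[ vehicle ]"; surrounding spaces are not captured.
TARGET_RE = re.compile(r"\[\s*(.*?)\s*\]")


def split_sample(text):
    """Split a sample into (prompt, target) at its bracketed target words."""
    match = TARGET_RE.search(text)
    if match is None:
        raise ValueError("sample has no bracketed target")
    return text[: match.start()].rstrip(), match.group(1)


def accuracy(samples, predict):
    """Fraction of samples whose predicted words exactly match the bracketed target.

    `predict` is any callable mapping a prompt string to the model's continuation.
    """
    if not samples:
        raise ValueError("no samples")
    correct = sum(predict(prompt).strip() == target
                  for prompt, target in map(split_sample, samples))
    return correct / len(samples)


def load_samples(path, field):
    # Hypothetical loader: `field` is a placeholder for the real key name in the JSONL file.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line)[field] for line in f if line.strip()]


sample = ("< Karen has a taxi. Karen has a pig. Paul has a cocktail. >. Karen does not have a kind of drink\n"
          "< Ruth has a sweater. Steven has a taxi. Ruth has a handgun. >. Ruth does not have a kind of [ vehicle ]")
prompt, target = split_sample(sample)
print(target)                                     # vehicle
print(accuracy([sample], lambda p: " vehicle"))   # 1.0
```

In a real evaluation, `predict` would wrap greedy decoding from one of the pretrained models; the lambda above is only a stand-in.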