add enhanced recurrence from the ERNIE-Doc paper, turned on by default
lucidrains committed Oct 4, 2021
1 parent d34aab1 commit cd3c533
Showing 3 changed files with 65 additions and 26 deletions.
46 changes: 28 additions & 18 deletions README.md
@@ -34,7 +34,6 @@ model = CompressiveTransformer(
gru_gated_residual = True, # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
mogrify_gru = False, # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
memory_layers = range(6, 13), # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
-    one_head_kv = True, # share one key/value head for all queries, from Shazeers 'One Write-Head is All You Need'
ff_glu = True # use GLU variant for feedforward
)
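
This commit removes the `one_head_kv` option and adds an `enhanced_recurrence` flag to the `CompressiveTransformer` constructor (see the Python diff below). A minimal usage sketch of the new flag, assuming the package's usual import path and the constructor options visible in this diff — the argument values are illustrative, and the three-value return follows the `forward` in the diffed file:

```python
import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    seq_len = 512,
    depth = 12,
    mem_len = 512,                     # length of the per-layer memory
    cmem_len = 128,                    # length of the compressed memory
    cmem_ratio = 4,                    # compression ratio
    memory_layers = range(6, 13),      # which layers keep long-range memory
    gru_gated_residual = True,
    ff_glu = True,
    enhanced_recurrence = True         # new flag from this commit; defaults to True
)

x = torch.randint(0, 20000, (1, 512))

# first segment: no memories yet
logits, memories, aux_loss = model(x)

# next segment: pass the returned memories back in; with enhanced recurrence
# they are rolled one layer before being consumed
logits, memories, aux_loss = model(x, memories = memories)
```

Setting `enhanced_recurrence = False` skips the memory roll and restores the previous same-layer recurrence.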

@@ -91,37 +90,37 @@ sample = model.generate(prime, 4096)

```bibtex
@misc{rae2019compressive,
-    title={Compressive Transformers for Long-Range Sequence Modelling},
-    author={Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
-    year={2019},
-    eprint={1911.05507},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+    title = {Compressive Transformers for Long-Range Sequence Modelling},
+    author = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
+    year = {2019},
+    eprint = {1911.05507},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
}
```

```bibtex
@misc{parisotto2019stabilizing,
-    title={Stabilizing Transformers for Reinforcement Learning},
-    author={Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
-    year={2019},
-    eprint={1910.06764},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+    title = {Stabilizing Transformers for Reinforcement Learning},
+    author = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
+    year = {2019},
+    eprint = {1910.06764},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
}
```

```bibtex
@inproceedings{rae-razavi-2020-transformers,
title = "Do Transformers Need Deep Long-Range Memory?",
author = "Rae, Jack and
title = "Do Transformers Need Deep Long-Range Memory?",
author = "Rae, Jack and
Razavi, Ali",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2020",
month = jul,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.acl-main.672"
url = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
```

@@ -152,3 +151,14 @@
url = {https://arxiv.org/abs/1909.11942}
}
```

+```bibtex
+@misc{ding2021erniedoc,
+    title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
+    author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
+    year = {2021},
+    eprint = {2012.15688},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL}
+}
+```
36 changes: 30 additions & 6 deletions compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -167,7 +167,7 @@ def forward(self, x, **kwargs):
# attention.

class SelfAttention(nn.Module):
-    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0., reconstruction_attn_dropout = 0., one_kv_head = False):
+    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0., reconstruction_attn_dropout = 0.):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

@@ -182,9 +182,7 @@ def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, a
self.compress_mem_fn = ConvCompress(dim, cmem_ratio)

self.to_q = nn.Linear(dim, dim, bias = False)

-        kv_dim = self.dim_head if one_kv_head else dim
-        self.to_kv = nn.Linear(dim, kv_dim * 2, bias = False)
+        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
self.to_out = nn.Linear(dim, dim)

self.attn_dropout = nn.Dropout(attn_dropout)
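
For context on the lines removed above: with `one_kv_head = True`, keys and values were projected to a single head of size `dim_head` and shared across all query heads (Shazeer's 'One Write-Head is All You Need'); the restored projection goes back to the full `dim`, split across `heads` downstream. A standalone sketch of the two projection shapes — the tensor names are illustrative, not from this file:

```python
import torch
from torch import nn

dim, heads = 512, 8
dim_head = dim // heads                 # 64
x = torch.randn(1, 16, dim)             # (batch, sequence, dim)

# removed option: one shared key/value head for every query head
to_kv_shared = nn.Linear(dim, dim_head * 2, bias = False)
k_shared, v_shared = to_kv_shared(x).chunk(2, dim = -1)    # each (1, 16, 64)

# behavior after this commit: full key/value projection, later split into `heads` heads
to_kv = nn.Linear(dim, dim * 2, bias = False)
k, v = to_kv(x).chunk(2, dim = -1)                         # each (1, 16, 512)
```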
@@ -291,7 +289,28 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
# transformer

class CompressiveTransformer(nn.Module):
-    def __init__(self, num_tokens, dim, seq_len, depth, emb_dim = None, memory_layers = None, mem_len = None, cmem_len = None, cmem_ratio = 4, heads = 8, gru_gated_residual = True, mogrify_gru = False, attn_dropout = 0., ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., reconstruction_attn_dropout = 0., reconstruction_loss_weight = 1., one_kv_head = False):
+    def __init__(
+        self,
+        num_tokens,
+        dim,
+        seq_len,
+        depth,
+        emb_dim = None,
+        memory_layers = None,
+        enhanced_recurrence = True,
+        mem_len = None,
+        cmem_len = None,
+        cmem_ratio = 4,
+        heads = 8,
+        gru_gated_residual = True,
+        mogrify_gru = False,
+        attn_dropout = 0.,
+        ff_glu = False,
+        ff_dropout = 0.,
+        attn_layer_dropout = 0.,
+        reconstruction_attn_dropout = 0.,
+        reconstruction_loss_weight = 1.
+    ):
super().__init__()
emb_dim = default(emb_dim, dim)
mem_len = default(mem_len, seq_len)
@@ -306,6 +325,7 @@ def __init__(self, num_tokens, dim, seq_len, depth, emb_dim = None, memory_layer

self.depth = depth
self.memory_layers = list(memory_layers)
+        self.enhanced_recurrence = enhanced_recurrence

self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)
@@ -320,7 +340,7 @@

wrapper = partial(GRUGating, dim, mogrify = mogrify_gru) if gru_gated_residual else Residual

-        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, reconstruction_attn_dropout = reconstruction_attn_dropout, one_kv_head = one_kv_head))) for _ in range(depth)])
+        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, reconstruction_attn_dropout = reconstruction_attn_dropout))) for _ in range(depth)])
self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim, dropout = ff_dropout, glu = ff_glu))) for _ in range(depth)])

self.reconstruction_loss_weight = reconstruction_loss_weight
@@ -347,6 +367,10 @@ def forward(self, x, memories = None, mask = None):
next_cmem = []
aux_loss = torch.tensor(0., requires_grad = True, **to(x))

+        if self.enhanced_recurrence:
+            mem = torch.roll(mem, -1, 0)
+            cmem = torch.roll(cmem, -1, 0)

mem_iter, cmem_iter = map(iterate_tensor, (mem, cmem))

for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
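The `torch.roll` calls added above implement ERNIE-Doc's enhanced recurrence: the stack of per-layer memories is shifted by one along the layer dimension, so layer i of the current segment reads the memory written by layer i + 1 on the previous segment, and the top layer's memory wraps around to the first memory layer. A tiny self-contained illustration, with the memory layout (layer dimension first) assumed from this file:

```python
import torch

# toy stand-in for the stacked per-layer memories: (num_memory_layers, batch, mem_len, dim),
# with the layer index encoded in the values for readability
layers, batch, mem_len, dim = 4, 1, 3, 2
mem = torch.arange(layers).float().view(layers, 1, 1, 1).expand(layers, batch, mem_len, dim).contiguous()

# enhanced recurrence: layer i now consumes the memory produced by layer i + 1
rolled = torch.roll(mem, -1, 0)

print(mem[:, 0, 0, 0])     # tensor([0., 1., 2., 3.])
print(rolled[:, 0, 0, 0])  # tensor([1., 2., 3., 0.])
```

Because the flag defaults to True, existing training scripts pick up this cross-layer memory routing without any changes.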
9 changes: 7 additions & 2 deletions setup.py
@@ -3,13 +3,18 @@
setup(
name = 'compressive-transformer-pytorch',
packages = find_packages(exclude=['examples']),
-  version = '0.3.21',
+  version = '0.4.0',
license='MIT',
description = 'Implementation of Compressive Transformer in Pytorch',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/compressive-transformer-pytorch',
-  keywords = ['attention', 'artificial intelligence', 'transformer', 'deep learning'],
+  keywords = [
+    'attention',
+    'artificial intelligence',
+    'transformer',
+    'deep learning'
+  ],
install_requires=[
'torch',
'mogrifier'
