allow for specifying scaling of residual branch for encoder and decoder, to pave the way for https://arxiv.org/abs/2203.00555
lucidrains committed Mar 15, 2022
1 parent 0b01736 commit 1c57013
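
The linked paper (DeepNet) scales the identity/skip path of each residual block by a depth-dependent constant alpha. As a hypothetical usage sketch (not part of this commit or the repository docs), the new enc_scale_residual / dec_scale_residual arguments could be filled with DeepNet-style values. The formula constants below follow my reading of the paper and should be verified against it; all constructor arguments other than the two new scale parameters are illustrative placeholders based on the repository's README and may not match the interface at this exact commit.

from retro_pytorch import RETRO

# DeepNet-style scales for an N-layer encoder and M-layer decoder; the skip
# path x is multiplied by this value (see the Residual class in the diff below).
# Assumed constants: encoder alpha = 0.81 * (N**4 * M) ** (1/16),
#                    decoder alpha = (3 * M) ** (1/4)
N, M = 2, 12
enc_scale = 0.81 * (N ** 4 * M) ** (1 / 16)   # ~1.13 for N = 2, M = 12
dec_scale = (3 * M) ** 0.25                   # ~2.45 for M = 12

# dimensions, depths and heads below are illustrative placeholders
retro = RETRO(
    chunk_size = 64,
    max_seq_len = 2048,
    enc_dim = 512,
    enc_depth = N,
    dec_dim = 512,
    dec_depth = M,
    heads = 8,
    dim_head = 64,
    enc_scale_residual = enc_scale,   # new in this commit
    dec_scale_residual = dec_scale    # new in this commit
)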
Showing 2 changed files with 42 additions and 18 deletions.
58 changes: 41 additions & 17 deletions retro_pytorch/retro_pytorch.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
@@ -23,6 +25,17 @@ def divisible_by(val, divisor):
 def cast_tuple(val, num = 1):
     return val if isinstance(val, tuple) else ((val,) * num)
 
+# helper functions
+
+class Residual(nn.Module):
+    def __init__(self, fn, scale_residual = 1.):
+        super().__init__()
+        self.fn = fn
+        self.scale_residual = scale_residual
+
+    def forward(self, x, *args, **kwargs):
+        return self.fn(x, *args, **kwargs) + x * self.scale_residual
+
 # normalization
 
 class RMSNorm(nn.Module):
@@ -260,21 +273,24 @@ def __init__(
         ff_mult = 4,
         ff_dropout = 0.,
         final_norm = True,
-        cross_attn_layers = None
+        cross_attn_layers = None,
+        scale_residual = 1.
     ):
         super().__init__()
         self.layers = nn.ModuleList([])
 
         rotary_emb_dim = max(dim_head // 2, MIN_DIM_HEAD)
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)
 
+        residual_wrapper = partial(Residual, scale_residual = scale_residual)
+
         for layer_num in range(1, depth + 1):
             has_cross_attn = not exists(cross_attn_layers) or layer_num in cross_attn_layers
 
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal),
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) if has_cross_attn else None,
-                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
+                residual_wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal)),
+                residual_wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if has_cross_attn else None,
+                residual_wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
             ]))
 
         self.norm_out = RMSNorm(dim) if final_norm else nn.Identity()
@@ -286,12 +302,12 @@ def forward(self, x, *, mask = None, chunked_seq):
         k_pos_emb = self.rotary_pos_emb(seq_len, device = device)
 
         for attn, cross_attn, ff in self.layers:
-            x = attn(x, mask = mask, pos_emb = q_pos_emb) + x
+            x = attn(x, mask = mask, pos_emb = q_pos_emb)
 
             if exists(cross_attn):
-                x = cross_attn(x, context = chunked_seq, pos_emb = (q_pos_emb, k_pos_emb)) + x
+                x = cross_attn(x, context = chunked_seq, pos_emb = (q_pos_emb, k_pos_emb))
 
-            x = ff(x) + x
+            x = ff(x)
 
         return self.norm_out(x)

@@ -308,22 +324,26 @@ def __init__(
         ff_dropout = 0.,
         final_norm = True,
         cross_attn_layers = None,
-        chunk_size = 64
+        chunk_size = 64,
+        scale_residual = 1.
     ):
         super().__init__()
         self.layers = nn.ModuleList([])
 
         rotary_emb_dim = max(dim_head // 2, MIN_DIM_HEAD)
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)
+
+        residual_wrapper = partial(Residual, scale_residual = scale_residual)
+
         self.chunk_size = chunk_size
 
         for layer_num in range(1, depth + 1):
             has_cross_attn = not exists(cross_attn_layers) or layer_num in cross_attn_layers
 
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = True),
-                ChunkedCrossAttention(chunk_size = chunk_size, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) if has_cross_attn else None,
-                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
+                residual_wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = True)),
+                residual_wrapper(ChunkedCrossAttention(chunk_size = chunk_size, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if has_cross_attn else None,
+                residual_wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
             ]))
 
         self.norm_out = RMSNorm(dim) if final_norm else nn.Identity()
@@ -341,17 +361,17 @@ def forward(self, x, *, context_mask = None, retrieved = None):
         cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)
 
         for attn, cross_attn, ff in self.layers:
-            x = attn(x, pos_emb = self_attn_pos_emb) + x
+            x = attn(x, pos_emb = self_attn_pos_emb)
 
             if exists(cross_attn) and exists(retrieved):
                 x = cross_attn(
                     x,
                     context = retrieved,
                     context_mask = context_mask,
                     pos_emb = cross_attn_pos_emb
-                ) + x
+                )
 
-            x = ff(x) + x
+            x = ff(x)
 
         return self.norm_out(x)

@@ -376,7 +396,9 @@ def __init__(
         dec_attn_dropout = 0.,
         dec_ff_dropout = 0.,
         chunk_size = 64,
-        pad_id = 0
+        pad_id = 0,
+        enc_scale_residual = 1.,
+        dec_scale_residual = 1.
     ):
         super().__init__()
         assert dim_head >= MIN_DIM_HEAD, f'dimension per head must be greater than {MIN_DIM_HEAD}'
@@ -396,7 +418,8 @@ def __init__(
             depth = enc_depth,
             attn_dropout = enc_attn_dropout,
             ff_dropout = enc_ff_dropout,
-            cross_attn_layers = enc_cross_attn_layers
+            cross_attn_layers = enc_cross_attn_layers,
+            scale_residual = enc_scale_residual
         )
 
         self.decoder = Decoder(
@@ -405,7 +428,8 @@ def __init__(
             attn_dropout = dec_attn_dropout,
             ff_dropout = dec_ff_dropout,
             cross_attn_layers = dec_cross_attn_layers,
-            chunk_size = chunk_size
+            chunk_size = chunk_size,
+            scale_residual = dec_scale_residual
         )
 
         self.to_logits = nn.Linear(dec_dim, num_tokens)
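
For reference, a minimal standalone check of the new wrapper's arithmetic (hypothetical, not from the repository's tests; the import path is assumed from the file changed above). With the default scale_residual = 1. the wrapper reproduces the previous `out + x` behavior exactly, so existing configurations are unaffected.

import torch
from torch import nn

from retro_pytorch.retro_pytorch import Residual  # import path assumed from the changed file

# fn(x) + x * scale_residual: the skip connection is scaled, the branch output is not
block = Residual(nn.Identity(), scale_residual = 0.5)
x = torch.ones(2, 4)

out = block(x)  # identity branch gives 1.0, scaled skip adds 1.0 * 0.5
assert torch.allclose(out, torch.full((2, 4), 1.5))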
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'retro-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.10',
+  version = '0.2.0',
   license='MIT',
   description = 'RETRO - Retrieval Enhanced Transformer - Pytorch',
   author = 'Phil Wang',