
Commit

Add rotary
bratao committed May 2, 2021
1 parent 7c84205 commit 1c614c3
Showing 1 changed file with 68 additions and 8 deletions.
76 changes: 68 additions & 8 deletions sru/modules.py
@@ -8,12 +8,14 @@
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence
from einops import rearrange, repeat, reduce

from sru.ops import (elementwise_recurrence_cpu,
elementwise_recurrence_gpu,
elementwise_recurrence_naive)



class SRUCell(nn.Module):
"""
A single SRU layer as per `LSTMCell`, `GRUCell` in Pytorch.
@@ -686,6 +688,34 @@ def forward(self,
return output


class RotaryEmbedding(nn.Module):

def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.seq_len_cached = None
self.emb_cached = None

def forward(self, x, seq_dim=1):
seq_len = x.shape[seq_dim]
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.emb_cached = emb[None, :, :]
return self.emb_cached
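
The cached table is rebuilt only when the sequence length changes; repeated calls at the same length return the same tensor. A small sketch of that behavior, assuming torch is imported and the class above is in scope (shapes are illustrative):

rot = RotaryEmbedding(dim=32)
a = rot(torch.zeros(2, 7, 32), seq_dim=1)   # builds and caches a (1, 7, 32) table
b = rot(torch.zeros(4, 7, 32), seq_dim=1)   # same length: the cached table is reused
assert a is b and a.shape == (1, 7, 32)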

def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(q, k, freqs):
q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
return q, k
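
A minimal usage sketch of these helpers, assuming torch and einops are installed and the definitions above are in scope; the shapes (2 heads, length 5, head dimension 32) are illustrative only:

rot = RotaryEmbedding(dim=32)
q = torch.randn(2, 5, 32)           # (heads, seq_len, head_dim)
k = torch.randn(2, 5, 32)
freqs = rot(q, seq_dim=1)           # (1, 5, 32) position-dependent angles
q, k = apply_rotary_pos_emb(q, k, freqs)
assert q.shape == (2, 5, 32) and k.shape == (2, 5, 32)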

class SRUppAttention(nn.Module):
"""
Self-attention module used in SRU++ module.
@@ -746,6 +776,8 @@ def __init__(self,
self.alpha = nn.Parameter(torch.Tensor([float(rezero_init_alpha)])) # type: ignore
self.normalize_after = normalize_after
self.layer_norm: Optional[nn.Module] = None


if layer_norm:
self.layer_norm = nn.LayerNorm(proj_features)

@@ -774,7 +806,8 @@ def forward(self,
mask_pad: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
- memory_mask_pad: Optional[Tensor] = None) -> Tensor:
+ memory_mask_pad: Optional[Tensor] = None,
+ rotary_pos_emb=None) -> Tensor:
"""The forward method of SRU++ attention.
"""

@@ -786,6 +819,8 @@ def forward(self,
head_dim = proj_dim // num_heads
scaling = float(head_dim) ** -0.5



# concat memory and input as the key-value block when provided
if memory is not None:
if memory.dim() != 3 or list(memory.size()[-2:]) != [bsz, in_dim]:
@@ -812,12 +847,20 @@ def forward(self,
z = layer_norm(z)
q = z


# query, key, value
k, v = self.linear2(z).chunk(2, dim=-1)
q = q.contiguous().view(tgt_len, -1, head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, -1, head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, -1, head_dim).transpose(0, 1)

if rotary_pos_emb is not None:
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k))
ql, kl = apply_rotary_pos_emb(ql, kl, rotary_pos_emb)
q = torch.cat((ql, qr), dim=-1)
k = torch.cat((kl, kr), dim=-1)

# (bsz * num_heads, tgt_len, src_len)
q = q * scaling
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
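
Only the first `l` channels of each head are rotated above; any remaining channels pass through untouched, so the rotary width can be smaller than the head dimension. The payoff of the rotation is that the resulting dot-product scores depend on relative position only, which the following sketch checks numerically; it assumes the helpers defined earlier are in scope and uses identical query/key content at every position:

rot = RotaryEmbedding(dim=32)
q = torch.randn(1, 1, 32).repeat(1, 8, 1)   # same query content at every position
k = torch.randn(1, 1, 32).repeat(1, 8, 1)   # same key content at every position
freqs = rot(q, seq_dim=1)
q, k = apply_rotary_pos_emb(q, k, freqs)
s_0_2 = (q[0, 0] * k[0, 2]).sum()           # positions 0 and 2: offset 2
s_3_5 = (q[0, 3] * k[0, 5]).sum()           # positions 3 and 5: offset 2 again
assert torch.allclose(s_0_2, s_3_5, atol=1e-4)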
@@ -881,7 +924,9 @@ def forward(self,
mask_pad: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
- memory_mask_pad: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+ memory_mask_pad: Optional[Tensor] = None,
+ rotary_pos_emb=None
+ ) -> Tuple[Tensor, Tensor]:
"""The forward method of the SRU++ layer.
"""

@@ -918,10 +963,17 @@ def forward(self,
# compute U
# U is (length, batch_size, output_size * num_matrices)
transform_module = self.transform_module
- U = transform_module(input, mask_pad=mask_pad,
-                      attn_mask=attn_mask,
-                      memory=memory,
-                      memory_mask_pad=memory_mask_pad)
+ if isinstance(transform_module, SRUppAttention):
+     U = transform_module(input, mask_pad=mask_pad,
+                          attn_mask=attn_mask,
+                          memory=memory,
+                          memory_mask_pad=memory_mask_pad,
+                          rotary_pos_emb=rotary_pos_emb)
+ else:
+     U = transform_module(input, mask_pad=mask_pad,
+                          attn_mask=attn_mask,
+                          memory=memory,
+                          memory_mask_pad=memory_mask_pad)
V = self.weight_c

# apply elementwise recurrence to get hidden states h and c
@@ -1027,6 +1079,9 @@ def __init__(self,
else:
first_layer_input_size = input_size

rotary_emb_dim = 32
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim)

for i in range(num_layers):
# create the i-th SRU layer
in_features = first_layer_input_size if i == 0 else self.output_size
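
The rotary width is fixed at 32 here, while each attention head has `proj_features // num_heads` channels and `SRUppAttention` rotates only the first 32 of them. That implies a configuration constraint along these lines (the values below are illustrative, not from the commit):

rotary_emb_dim = 32
proj_features, num_heads = 64, 2             # example projection size and head count
head_dim = proj_features // num_heads        # per-head width used in SRUppAttention
assert head_dim >= rotary_emb_dim            # otherwise the [..., :l] slice cannot broadcast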
@@ -1156,7 +1211,7 @@ def forward(self,

if input_size != self.input_size:
raise ValueError("Input has size (*, *, {}) but expect a last dimension of {}".format(
- input_size, self.input_size
+ input_size, self.input_to_hidden
))

if c0 is None:
@@ -1195,14 +1250,19 @@ def forward(self,
lstc = []
i = 0
x = x.contiguous()

rotary_pos_emb = self.rotary_pos_emb(input, 0)
#rotary_pos_emb = None

for rnn in self.rnn_lst:
prev_inputs.append(x)
memory_i = memory[i] if memory is not None else None
h, c = rnn(x, c0_[i],
mask_pad=mask_pad,
attn_mask=attn_mask,
memory=memory_i,
- memory_mask_pad=memory_mask_pad)
+ memory_mask_pad=memory_mask_pad,
+ rotary_pos_emb=rotary_pos_emb)
x = h
lstc.append(c)
i += 1
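
Taken together, the position table is built once per forward pass from the length-first input and shared by every layer's attention. A rough end-to-end sketch, assuming this branch of sru is installed, that `SRUpp` is importable from the package root with the usual `(input_size, hidden_size, proj_size, ...)` signature and `(output, states)` return, and with sizes chosen so each head is at least 32 channels wide:

import torch
from sru import SRUpp

model = SRUpp(128, 128, 64, num_layers=2, num_heads=2)   # head width: 64 // 2 = 32
x = torch.randn(50, 4, 128)                              # (length, batch, input_size)
h, c = model(x)
print(h.shape)                                           # expected: (50, 4, 128)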