
Commit

WIP on self-attention
honnibal committed Oct 24, 2018
1 parent a91e868 commit f40f963
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions thinc/neural/_classes/attention.py
@@ -5,13 +5,25 @@
from .model import Model


@describe.attributes(
nK=Dimension("Key width"),
nO=Dimension("Values width"),
nI=Dimension("Input width"),
nL=Dimension("Left context width"),
nR=Dimension("Right context width"),
W=Synapses("Input weights",
lambda obj: (obj.nK+obj.nK+obj.nO, obj.nI),
lambda W, ops: ops.xavier_uniform_init(W)),
d_W=Gradient("W"),
)
class SelfAttention(Model):
def __init__(self, nK=None, nO=None, nL=5, nR=5, **kwargs):
def __init__(self, nK=None, nO=None, nI=None, nL=5, nR=5, **kwargs):
Model.__init__(self, **kwargs)
self.nK = nK
self.nO = nO
self.nL = nL
self.nR = nR
Model.__init__(self, **kwargs)
self.nI = nI

def begin_update(self, X_lengths):
X, lengths = X_lengths
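
The Synapses declaration above gives W the shape (nK+nK+nO, nI), i.e. a single matrix that packs the query, key and value projections. The body of begin_update is not shown in this hunk, so the snippet below is only a rough NumPy sketch of how such a packed projection is usually split; the sizes, the X @ W.T application and the slicing order are assumptions for illustration, not taken from the commit.

import numpy as np

# Hypothetical sizes matching the Dimension declarations: key, values, input width.
nK, nO, nI = 16, 32, 48
n_tokens = 10

rng = np.random.RandomState(0)
W = rng.uniform(-0.1, 0.1, (nK + nK + nO, nI))  # same shape as the Synapses declaration
X = rng.uniform(-1.0, 1.0, (n_tokens, nI))      # one row per token

projected = X @ W.T                  # (n_tokens, nK + nK + nO)
queries = projected[:, :nK]          # (n_tokens, nK)
keys = projected[:, nK:2 * nK]       # (n_tokens, nK)
values = projected[:, 2 * nK:]       # (n_tokens, nO)
print(queries.shape, keys.shape, values.shape)
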
@@ -60,7 +72,7 @@ def compare(self, queries, keys, lengths):
(sum(lengths), window_size)
'''
(dotprod, dotprod_lengths), backprop_rwd = _ragged_window_dot(
self.ops, queries, keys, lengths, self.nW)
self.ops, queries, keys, lengths, self.nL, self.nR)
dotprod /= self.ops.xp.sqrt(self.nK)
attention = self.ops.softmax_sequences(dotprod, dotprod_lengths)

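The two lines above scale the window dot products by sqrt(nK) and normalise them with softmax_sequences, which applies a softmax to each variable-length window of the ragged score array separately. The following is a rough NumPy stand-in for that step, not thinc's implementation; the segment sizes and scores are made up for illustration.

import numpy as np

def softmax_per_segment(flat, lengths):
    # Apply a softmax to each variable-length segment of a flat (ragged) array,
    # roughly what ops.softmax_sequences does to the window scores.
    out = np.zeros_like(flat)
    start = 0
    for length in lengths:
        seg = np.exp(flat[start:start + length] - flat[start:start + length].max())
        out[start:start + length] = seg / seg.sum()
        start += length
    return out

nK = 16
dotprod = np.array([0.5, 1.0, -0.2, 2.0, 0.1], dtype="f")  # concatenated window scores
dotprod_lengths = [2, 3]                                    # window sizes per position
attention = softmax_per_segment(dotprod / np.sqrt(nK), dotprod_lengths)
print(attention[:2].sum(), attention[2:].sum())             # each window sums to 1.0
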
@@ -74,7 +86,7 @@ def backprop_attention(d_attention):

return attention, backprop_attention

def rescale(self, V, A, lengths, nW=None):
def rescale(self, V, A, lengths, nL, nR):
'''Perform a weighted sum of values with the attention.
Values is a ragged array of sequences, unpacked it would be
@@ -93,7 +105,7 @@ def rescale(self, V, A, lengths, nW=None):
for i, length in enumerate(lengths):
V_ = V[vidx : vidx + length]
for j in range(length):
values = V_[max(0, j-nW) : j+nW]
values = V_[max(0, j-nL) : j+nR]
attention = A[aidx : aidx + values.shape[0]]
# set row of d from ((w, d) * (w, d)).sum()
output[aidx] = (values * attention).sum(axis=0)
@@ -109,17 +121,19 @@ def backprop_rescale(d_output):
V_ = V[vidx : vidx + length]
dV_ = dV[vidx : vidx + length]
for j in range(length):
values = V_[max(0, j-nW) : j+nW]
values = V_[max(0, j-nL) : j+nR]
attention = A[aidx : aidx + values.shape[0]]

dV_[max(0, j-nL) : j+nR] += attention * d_output[aidx]
dA[aidx : aidx + values.shape[0]] += values * d_output[aidx]
aidx += 1
vidx += length
return dV, dA

return output, backprop_rescale


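The forward loop in rescale walks each sequence and, for every position j, takes the attention-weighted sum of the value vectors in the window V_[max(0, j-nL) : j+nR]; backprop_rescale runs the same loop to scatter gradients back onto the values and the attention weights. Below is a rough single-sequence NumPy sketch of just the forward pass; the uniform attention weights and the per-position list layout are simplifications for illustration, not the ragged layout the class uses.

import numpy as np

def rescale_one_sequence(V, window_attn, nL, nR):
    # Forward pass for one sequence: each output row is the attention-weighted
    # sum of the value vectors in that position's local window.
    output = np.zeros_like(V)
    for j in range(len(V)):
        values = V[max(0, j - nL): j + nR]       # same window slice as in the diff
        attn = window_attn[j]                    # (window,) weights for position j
        output[j] = (values * attn[:, None]).sum(axis=0)
    return output

V = np.arange(12, dtype="f").reshape(4, 3)       # 4 positions, 3-dim values
nL = nR = 2
# Uniform attention over each position's window, purely for illustration.
window_attn = []
for j in range(len(V)):
    width = len(V[max(0, j - nL): j + nR])
    window_attn.append(np.full(width, 1.0 / width, dtype="f"))
print(rescale_one_sequence(V, window_attn, nL, nR))
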
def _ragged_window_dot(ops, X, Y, lengths, nW):
def _ragged_window_dot(ops, X, Y, lengths, nL, nR):
'''Multiply X against context windows of Y, where X and Y are both ragged
matrices, representing concatenated sequences. We output a ragged array
where each entry is a vector with the dot product of X[i] against the
@@ -149,7 +163,7 @@ def _ragged_window_dot(ops, X, Y, lengths, nW):
X_ = X[start : start+length]
Y_ = Y[start : start+length]
for j in range(length):
dots, backprop = _window_dot(X_, Y_, j, nW)
dots, backprop = _window_dot(X_, Y_, j, nL, nR)
output.append(dots)
backprops.append(backprop)
start += length
@@ -169,14 +183,14 @@ def backprop_rwd(d_output):
return ops.flatten(output), out_lengths


def _window_dot(X, Y, i, nW):
start = max(0, i-nW)
end = i + nW
output = einsum('d,nwd->nw', X[i], Y[start : end])
def _window_dot(X, Y, i, nL, nR):
start = max(0, i-nL)
end = i + nR
output = einsum('d,wd->w', X[i], Y[start : end])

def backprop_window_dot(d_output):
dXi = einsum('nw,nwd->d', d_output, Y[start : end])
d_winY = einsum('nw,d->nwd', d_output, X[i])
dXi = einsum('w,wd->d', d_output, Y[start : end])
d_winY = einsum('w,d->wd', d_output, X[i])
return dXi, d_winY

return output, backprop_window_dot
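
The rewritten _window_dot drops the extra batch axis from the einsum strings: the forward pass dots the single query vector X[i] against each row of its context window in Y, and the backward pass is the corresponding pair of transposed contractions. A small NumPy check of what those einsums compute, with arbitrary values and window sizes chosen for illustration:

import numpy as np

rng = np.random.RandomState(0)
d, n, i, nL, nR = 8, 6, 3, 2, 2
X = rng.uniform(-1, 1, (n, d))
Y = rng.uniform(-1, 1, (n, d))

start, end = max(0, i - nL), i + nR
window = Y[start:end]                                   # (w, d) context rows
forward = np.einsum("d,wd->w", X[i], window)            # the diff's forward contraction
assert np.allclose(forward, window @ X[i])              # equivalent matrix-vector product

d_output = rng.uniform(-1, 1, forward.shape)            # pretend upstream gradient
dXi = np.einsum("w,wd->d", d_output, window)            # gradient w.r.t. X[i]
d_winY = np.einsum("w,d->wd", d_output, X[i])           # gradient w.r.t. the window of Y
print(forward.shape, dXi.shape, d_winY.shape)
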
