In [1]:
import torch
from torch import nn,optim
import matplotlib.pyplot as plt
%matplotlib auto


Using matplotlib backend: Qt5Agg


In [47]:
def sequence_mask(X, valid_len, value=0):
    """Overwrite the entries of ``X`` that lie beyond each row's valid length.

    For row ``i``, every position ``j`` along axis 1 that does not satisfy
    ``j < valid_len[i]`` is set to ``value``.

    Args:
        X: Tensor with at least two axes; axis 1 is the position axis.
            It is modified in place.
        valid_len: Tensor holding one length per row of ``X``.
        value: Fill value written into the masked positions.

    Returns:
        The same tensor object ``X`` that was passed in, after masking.
    """
    num_positions = X.size(1)
    position_ids = torch.arange(num_positions, dtype=torch.float32, device=X.device)
    # Row vector of positions against a column vector of lengths -> 2-D mask.
    keep = position_ids.unsqueeze(0) < valid_len[:, None]
    X[torch.logical_not(keep)] = value
    return X

In [48]:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X`` that ignores positions past a valid length.

    Masked positions are replaced by a large negative number (-1e6) before the
    softmax, so their resulting weight is effectively zero.

    The input tensor is never modified: masking is done out of place. (Writing
    into ``X.reshape(...)`` would overwrite the caller's tensor whenever
    ``reshape`` happens to return a view, and not when it returns a copy.)

    Args:
        X: Score tensor; the code indexes ``shape[1]`` and ``shape[-1]``, so it
            is expected to be 3-D, e.g. (batch, num_queries, num_keys).
        valid_lens: ``None`` for an ordinary softmax, a 1-D tensor with one
            length per batch element (shared by all queries of that element),
            or a tensor with one length per (batch, query) pair.

    Returns:
        Tensor of the same shape as ``X`` whose last axis sums to 1.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it for every query row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)

    # Flatten to (rows, num_keys) so each row is compared with its own length.
    flat_scores = X.reshape(-1, shape[-1])
    positions = torch.arange(shape[-1], dtype=torch.float32, device=X.device)[None, :]
    keep = positions < valid_lens[:, None]
    # masked_fill returns a new tensor, leaving the caller's X untouched.
    masked_scores = flat_scores.masked_fill(~keep, -1e6)
    return nn.functional.softmax(masked_scores.reshape(shape), dim=-1)

In [49]:
# Demo: random scores of shape (2, 2, 4) with valid lengths 2 and 3 for the two
# batch elements; positions past the valid length receive zero weight.
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

tensor([[[0.4692, 0.5308, 0.0000, 0.0000],
         [0.5610, 0.4390, 0.0000, 0.0000]],

        [[0.3405, 0.3220, 0.3375, 0.0000],
         [0.3044, 0.4306, 0.2650, 0.0000]]])

In [2]:
class Decoder(nn.Module):
    """Abstract base class for decoders.

    Subclasses are expected to override ``init_state`` and ``forward``; the
    versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self):
        super(Decoder, self).__init__()

    def init_state(self, enc_all_outputs, *args):
        """Build the initial decoder state from the encoder outputs (abstract)."""
        raise NotImplementedError()

    def forward(self, X, state):
        """Run the decoder on ``X`` starting from ``state`` (abstract)."""
        raise NotImplementedError()

In [3]:
class AttentionDecoder(Decoder):
    """Abstract decoder that additionally exposes its attention weights.

    Any keyword arguments are forwarded unchanged to ``Decoder.__init__``.
    """

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        """Attention weights of the decoder; subclasses must override this."""
        raise NotImplementedError()

In [50]:
class AdditiveAttention(nn.Module):
    """Additive attention: score = W_v . tanh(W_q q + W_k k).

    Queries and keys are projected to a common ``num_hiddens`` size by
    bias-free linear layers, summed for every (query, key) pair, passed through
    ``tanh`` and reduced to a scalar score by a third bias-free linear layer.

    Args:
        key_size: Feature size of the keys.
        query_size: Feature size of the queries.
        num_hiddens: Size of the shared hidden projection.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        """Return the attention-weighted combination of ``values``.

        The final ``torch.bmm`` requires 3-D operands, so the inputs are
        batched 3-D tensors (presumably (batch, num_queries, query_size),
        (batch, num_kv, key_size) and (batch, num_kv, value_size)).
        ``valid_lens`` is passed straight to ``masked_softmax``.

        Side effect: the weights computed for this call are stored on
        ``self.attention_weights``.
        """
        projected_queries = self.W_q(queries)
        projected_keys = self.W_k(keys)
        # Insert singleton axes so the sum broadcasts over every
        # (query, key) pair: (.., Q, 1, H) + (.., 1, K, H).
        pairwise = projected_queries[:, :, None] + projected_keys[:, None]
        activated = torch.tanh(pairwise)
        # W_v maps the hidden axis to size 1; drop that trailing axis.
        scores = self.W_v(activated).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        dropped_weights = self.dropout(self.attention_weights)
        return torch.bmm(dropped_weights, values)
        

In [72]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """GRU decoder that attends over the encoder outputs at every time step.

    At each step the top-layer hidden state is used as the attention query
    over the encoder outputs; the resulting context vector is concatenated
    with the current token embedding and fed to the GRU.

    Args:
        vocab_size: Number of tokens (embedding rows and output logits).
        embed_size: Size of the token embeddings.
        num_hiddens: Hidden size of the GRU; also used as key, query and
            hidden size of the additive attention.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability for both the attention and the GRU.
        **kwargs: Forwarded to ``AttentionDecoder.__init__``.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # GRU input is the context vector concatenated with the embedding.
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Initialised here so that `attention_weights` is readable (as an
        # empty list) before the first forward pass instead of raising a
        # confusing AttributeError.
        self._attention_weights = []

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the decoder state from the encoder's ``(outputs, hidden_state)`` pair.

        Returns:
            Tuple ``(outputs, hidden_state, enc_valid_lens)`` where the first
            two axes of ``outputs`` have been swapped (time-major to
            batch-major, given the encoder's time-major GRU output).
        """
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """Decode the token ids ``X`` one time step at a time.

        Args:
            X: Token-id tensor of shape (batch, num_steps).
            state: ``(enc_outputs, hidden_state, enc_valid_lens)`` as produced
                by ``init_state`` or by a previous call to ``forward``.

        Returns:
            ``(logits, new_state)`` where ``logits`` has shape
            (batch, num_steps, vocab_size) and ``new_state`` is the list
            ``[enc_outputs, hidden_state, enc_valid_lens]`` with the updated
            GRU hidden state.
        """
        enc_outputs, hidden_state, enc_valid_lens = state
        # Embed and switch to time-major so the loop iterates over time steps.
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # Query is the last layer's hidden state: (batch, 1, num_hiddens).
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate context and embedding along the feature axis.
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # GRU expects (steps=1, batch, features).
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # Stack the per-step outputs along time and project to vocabulary logits.
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """Attention weights of the most recent forward pass, one entry per step.

        Empty list if ``forward`` has not been called yet.
        """
        return self._attention_weights

In [73]:
class Encoder(nn.Module):
    """Abstract base class for encoders.

    Subclasses are expected to override ``forward``; the version defined here
    only raises ``NotImplementedError``.
    """

    def __init__(self):
        super(Encoder, self).__init__()

    def forward(self, X, *args):
        """Encode the input ``X`` (abstract)."""
        raise NotImplementedError()

In [74]:
class Seq2SeqEncoder(Encoder):
    """Embedding followed by a multi-layer GRU.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Size of each token embedding (GRU input size).
        num_hiddens: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability passed to the GRU.
        **kwargs: Forwarded to ``Encoder.__init__``.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        """Encode a batch of token ids.

        Shapes in the comments below use the example from this notebook
        (batch=4, steps=7, embed_size=8, num_hiddens=16, num_layers=2).

        Returns:
            The ``(output, state)`` pair produced by the GRU:
            ``output`` is (steps, batch, num_hiddens) and ``state`` is
            (num_layers, batch, num_hiddens).
        """
        embedded = self.embedding(X)            # (4, 7, 8)
        # The GRU is not batch_first, so move the time axis to the front.
        time_major = embedded.permute(1, 0, 2)  # (7, 4, 8)
        # output: (7, 4, 16), state: (2, 4, 16)
        return self.rnn(time_major)

In [75]:
# Small encoder used for the shape check below; eval() puts the module in
# evaluation mode.
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
encoder.eval()

Seq2SeqEncoder(
  (embedding): Embedding(10, 8)
  (rnn): GRU(8, 16, num_layers=2)
)

In [76]:
# Decoder with the same vocabulary, embedding and hidden sizes as the encoder
# above, also switched to evaluation mode.
decoder=Seq2SeqAttentionDecoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layers=2)
decoder.eval()

Seq2SeqAttentionDecoder(
  (attention): AdditiveAttention(
    (W_k): Linear(in_features=16, out_features=16, bias=False)
    (W_q): Linear(in_features=16, out_features=16, bias=False)
    (W_v): Linear(in_features=16, out_features=1, bias=False)
    (dropout): Dropout(p=0, inplace=False)
  )
  (embedding): Embedding(10, 8)
  (rnn): GRU(24, 16, num_layers=2)
  (dense): Linear(in_features=16, out_features=10, bias=True)
)

In [77]:
# Dummy batch of token ids: 4 sequences of 7 steps, every id equal to 0.
X = torch.zeros((4,7),dtype=torch.long)
# Valid lengths are None, so attention is an unmasked softmax over all encoder steps.
state=decoder.init_state(encoder(X),None)
output,state=decoder(X,state)

In [78]:
# Inspect the result: logits shape, number of state entries, encoder-output
# shape, number of GRU layers in the hidden state, and per-layer hidden shape.
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))