In [1]:
import torch
from torch import nn,optim
import matplotlib.pyplot as plt
%matplotlib auto


Using matplotlib backend: Qt5Agg


In [2]:
def sequence_mask(X,valid_len,value=0):
    maxlen=X.size(1)
    mask=torch.arange((maxlen),dtype=torch.float32,device=X.device)[None,:]<valid_len[:,None]
    X[~mask]=value
    return X

In [3]:
def masked_softmax(X,valid_lens):
    if valid_lens is None:
        return nn.functional.softmax(X,dim=-1)
    else:
        shape=X.shape
        if valid_lens.dim()==1:
            valid_lens=torch.repeat_interleave(valid_lens,shape[1])
        else:
            valid_lens=valid_lens.reshape(-1)
        X=sequence_mask(X.reshape(-1,shape[-1]),valid_lens,value=-1e6)
        return nn.functional.softmax(X.reshape(shape),dim=-1)

In [4]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

tensor([[[0.4361, 0.5639, 0.0000, 0.0000],
         [0.6351, 0.3649, 0.0000, 0.0000]],

        [[0.1783, 0.3847, 0.4370, 0.0000],
         [0.2127, 0.4698, 0.3175, 0.0000]]])

In [5]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
    
    def init_state(self,enc_all_outputs,*args):
        raise NotImplementedError
    
    def forward(self,X,state):
        raise NotImplementedError

In [6]:
class AttentionDecoder(Decoder):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
    
    @property
    def attention_weights(self):
        raise NotImplementedError

In [7]:
class AdditiveAttention(nn.Module):
    def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):
        super().__init__(**kwargs)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=False)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=False)
        self.W_v=nn.Linear(num_hiddens,1,bias=False)
        self.dropout=nn.Dropout(dropout)
        
    def forward(self,queries,keys,values,valid_lens):
        queries,keys=self.W_q(queries),self.W_k(keys)
        features=queries.unsqueeze(2)+keys.unsqueeze(1)
        features=torch.tanh(features)
        scores=self.W_v(features).squeeze(-1)
        self.attention_weights=masked_softmax(scores,valid_lens)
        return torch.bmm(self.dropout(self.attention_weights),values)
        

In [8]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super().__init__(**kwargs)
        self.attention=AdditiveAttention(num_hiddens,num_hiddens,num_hiddens,dropout)
        self.embedding=nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)
        self.dense=nn.Linear(num_hiddens,vocab_size)
        
    def init_state(self,enc_outputs,enc_valid_lens,*args):
        outputs,hidden_state=enc_outputs
        return (outputs.permute(1,0,2),hidden_state,enc_valid_lens)
    
    def forward(self,X,state):
        enc_outputs,hidden_state,enc_valid_lens=state
        X=self.embedding(X).permute(1,0,2)
        outputs,self._attention_weights=[],[]
        for x in X:
            query=torch.unsqueeze(hidden_state[-1],dim=1)
            context=self.attention(query,enc_outputs,enc_outputs,enc_valid_lens)
            x=torch.cat((context,torch.unsqueeze(x,dim=1)),dim=-1)
            out,hidden_state=self.rnn(x.permute(1,0,2),hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs=self.dense(torch.cat(outputs,dim=0))
        return outputs.permute(1,0,2),[enc_outputs,hidden_state,enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

In [9]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X, *args):
        raise NotImplementedError

In [10]:
class Seq2SeqEncoder(Encoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
        super().__init__(**kwargs)
        self.embedding=nn.Embedding(vocab_size,embed_size)
        self.rnn=nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)
#         8,16,2,0

    def forward(self,X ,*args):
        X=self.embedding(X) #4,7,8
        X=X.permute(1,0,2) #7,4,8
        output,state=self.rnn(X) 
        #output:step,batch,num_hiddens=7,4,16
        #state :layers,batch,num_hiddnes=2,4,16
        return output,state

In [11]:
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
encoder.eval()

Seq2SeqEncoder(
  (embedding): Embedding(10, 8)
  (rnn): GRU(8, 16, num_layers=2)
)

In [12]:
decoder=Seq2SeqAttentionDecoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layers=2)
decoder.eval()

Seq2SeqAttentionDecoder(
  (attention): AdditiveAttention(
    (W_k): Linear(in_features=16, out_features=16, bias=False)
    (W_q): Linear(in_features=16, out_features=16, bias=False)
    (W_v): Linear(in_features=16, out_features=1, bias=False)
    (dropout): Dropout(p=0, inplace=False)
  )
  (embedding): Embedding(10, 8)
  (rnn): GRU(24, 16, num_layers=2)
  (dense): Linear(in_features=16, out_features=10, bias=True)
)

In [13]:
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    def forward(self,pred,label,valid_len):
        weights=torch.ones_like(label)
        weights=sequence_mask(weights,valid_len)
        self.reduction='none'
        unweighted_loss=super().forward(pred.permute(0,2,1),label)
        weighted_loss=(unweighted_loss*weights).mean(dim=1)
        return weighted_loss
        

In [14]:
loss=MaskedSoftmaxCELoss()
loss(torch.ones(3,4,10),torch.ones((3,4),dtype=torch.long),torch.tensor([4,2,0]))

tensor([2.3026, 1.1513, 0.0000])

In [15]:
def grad_clipping(net,theta):
    if isinstance(net,nn.Module):
        params=[p for p in net.parameters() if p.requires_grad]
    else:
        params=net.params
    norm=torch.sqrt(sum(torch.sum((p.grad**2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta/norm

In [16]:
def train_seq2seq(net,data_iter,lr,num_epochs,tgt_vocab,device):
    def xavier_init_weights(m):
        if type(m)==nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m)==nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
        
    net.apply(xavier_init_weights)
    net.to(device)
    optimizer=torch.optim.Adam(net.parameters(),lr=lr)
    loss=MaskedSoftmaxCELoss()
    net.train()
    train_l,train_num=0,0
    for epoch in range(num_epochs):
        for batch in data_iter:
            optimizer.zero_grad()
            X,X_valid_len,Y,Y_valid_len=[x.to(device) for x in batch]
            bos=torch.tensor([tgt_vocab['<bos>']]*Y.shape[0],device=device).reshape(-1,1)
            dec_input=torch.cat([bos,Y[:,:-1]],1)
            Y_hat,_=net(X,dec_input,X_valid_len)
            l=loss(Y_hat,Y,Y_valid_len)
            l.sum().backward()
            grad_clipping(net,1)
            num_tokens=Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                train_l+=l.sum()
                train_num+=num_tokens
        if(epoch+1)%10==0:
            print("epoch : ",epoch+1," train loss : ",(train_l.item()/train_num.item()))   

In [17]:
X = torch.zeros((4,7),dtype=torch.long)
state=decoder.init_state(encoder(X),None)
output,state=decoder(X,state)

In [18]:
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

In [19]:
import collections

In [20]:
class Vocab:
    def __init__(self,tokens=None,min_freq=0,reserved_tokens=None):
        if tokens is None:
            tokens=[]
        if reserved_tokens is None:
            reserved_tokens=[]
        counter=count_corpus(tokens)
        self._token_freqs=sorted(counter.items(),key=lambda x:x[1],reverse=True)
        self.idx_to_token=['<unk>']+reserved_tokens
        self.token_to_idx={token:idx for idx,token in enumerate(self.idx_to_token)}
        self.idx_to_token,self.token_to_idx=[],dict()
        for token,freq in self._token_freqs:
            if freq<min_freq:
                break;
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token]=len(self.idx_to_token)-1
    
    def __len__(self):
        return len(self.idx_to_token)
    
    def __getitem__(self,tokens):
        if not isinstance(tokens,(list,tuple)):
            return self.token_to_idx.get(tokens,self.unk)
        return [self.__getitem__(token) for token in tokens]
    
    def to_tokens(self,indices):
        if not isinstance(indices,(list,tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]
    
    @property
    def unk(self):
        return 0;
    
    @property
    def token_freqs(self):
        return self._token_freqs;
    

In [21]:
def count_corpus(tokens):
    if len(tokens)==0 or isinstance(tokens[0],list):
        tokens=[token for line in tokens for token in line]
    return collections.Counter(tokens)

In [22]:
def preprocess_nmt(text):
    def no_space(char,prev_char):
        return char in set(',.!?') and prev_char !=' '
    
    text=text.replace('\u202f', ' ').replace('\xa0',' ').lower()
    out=[' '+char if i>0 and no_space(char,text[i-1]) else char for i,char in enumerate(text) ]
    return ''.join(out)

In [23]:
def tokenize_nmt(text,num_examples=None):
    source,target=[],[]
    for i ,line in enumerate(text.split('\n')):
        if num_examples and i > num_examples:
            break
        parts=line.split('\t')
        if len(parts)==2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source,target 

In [24]:
def build_array_nmt(lines, vocab, num_steps):
#     """将机器翻译的⽂本序列转换成⼩批量"""
    lines = [vocab[l] for l in lines]
#     print("lines : ",lines)
    lines = [l + [vocab['<eos>']] for l in lines]
    array = torch.tensor([truncate_pad(l, num_steps, vocab['<pad>']) for l in lines])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
    return array, valid_len

In [25]:
def load_data_nmt(batch_size, num_steps, num_examples=600):
#     """返回翻译数据集的迭代器和词表"""
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
#     print("source : ",source)
#     print("target : ",target)
    src_vocab = Vocab(source, min_freq=2,reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = Vocab(target, min_freq=2,reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

In [26]:
import os

In [27]:
def read_data_nmt():
    data_dir=r'F:\study\ml\DataSet\fra-eng'
    with open(os.path.join(data_dir,'fra.txt'),'r',encoding='utf-8') as f:
        return f.read()

In [28]:
def truncate_pad(line,num_steps,padding_token):
    if len(line)>num_steps:
        return line[:num_steps]
    return line + [padding_token]*(num_steps-len(line))

In [29]:
def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator.

    Defined in :numref:`sec_utils`"""
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)

In [30]:
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

#### train

In [33]:
embed_size,num_hiddens,num_layers,dropout=32,32,2,0.1
batch_size,num_steps=64,10
lr,num_epochs,device=0.005,250,'cpu'

In [34]:
train_iter,src_vocab,tgt_vocab=load_data_nmt(batch_size,num_steps)

In [35]:
encoder=Seq2SeqEncoder(len(src_vocab),embed_size,num_hiddens,num_layers,dropout)
decoder=Seq2SeqAttentionDecoder(len(tgt_vocab),embed_size,num_hiddens,num_layers,dropout)

In [36]:
net=EncoderDecoder(encoder,decoder)

In [37]:
train_seq2seq(net,train_iter,lr,num_epochs,tgt_vocab,device)

epoch :  10  train loss :  0.36707769402875584
epoch :  20  train loss :  0.3036410383961757
epoch :  30  train loss :  0.25744711687369587
epoch :  40  train loss :  0.22361015594434663
epoch :  50  train loss :  0.19862289404831768
epoch :  60  train loss :  0.17908439496935316
epoch :  70  train loss :  0.16351200737410573
epoch :  80  train loss :  0.15064960891651996
epoch :  90  train loss :  0.13996249633215962
epoch :  100  train loss :  0.13089516700899845
epoch :  110  train loss :  0.12312196847435623
epoch :  120  train loss :  0.11641426066526474
epoch :  130  train loss :  0.11057046908856988
epoch :  140  train loss :  0.10540969340347642
epoch :  150  train loss :  0.10082947191249347
epoch :  160  train loss :  0.09678022835362872
epoch :  170  train loss :  0.0931576797385621
epoch :  180  train loss :  0.08987865391616676
epoch :  190  train loss :  0.08693181452567746
epoch :  200  train loss :  0.0842547712490219
epoch :  210  train loss :  0.08180369834749236
epoc

In [None]:
def predict_seq2seq(net,src_sentence,src_vocab,tgt_vocab,num_steps,device,save_attention_weights=True):
    