In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time

In [3]:
SEED = 1234

# Seed every random-number generator the notebook touches (Python, NumPy,
# PyTorch CPU and CUDA) so that a Restart & Run All reproduces the results.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

In [4]:
# Load the small German and English spaCy pipelines (German first, then English);
# only their tokenizers are used by the tokenize_* helpers.
spacy_de, spacy_en = (spacy.load(model_name) for model_name in ("de_core_news_sm", "en_core_web_sm"))

In [5]:
def _spacy_token_texts(nlp, text):
    """Return the surface strings of the tokens produced by ``nlp.tokenizer`` for ``text``."""
    return [token.text for token in nlp.tokenizer(text)]


def tokenize_de(text):
    """Split a German sentence into a list of token strings using spacy_de's tokenizer."""
    return _spacy_token_texts(spacy_de, text)


def tokenize_en(text):
    """Split an English sentence into a list of token strings using spacy_en's tokenizer."""
    return _spacy_token_texts(spacy_en, text)

In [6]:
# Source (German) and target (English) fields share every option except the tokenizer:
# wrap each sentence in <sos>/<eos>, lower-case it, and return lengths with each batch.
# NOTE(review): with include_lengths=True torchtext yields (tensor, lengths) tuples,
# so batch.trg must be unpacked before use as well — confirm that is intended for TRG.
_FIELD_OPTIONS = dict(init_token='<sos>', eos_token="<eos>", lower=True, include_lengths=True)
SRC = Field(tokenize=tokenize_de, **_FIELD_OPTIONS)
TRG = Field(tokenize=tokenize_en, **_FIELD_OPTIONS)

In [7]:
# Multi30k German→English: the '.de' files feed SRC and the '.en' files feed TRG.
train_data, val_data, test_data = Multi30k.splits(
    fields=(SRC, TRG),
    exts=('.de', '.en'),
)

In [8]:
# Build both vocabularies from the training split only; tokens seen fewer than
# twice are left out of the vocabulary.
for _field in (SRC, TRG):
    _field.build_vocab(train_data, min_freq=2)
del _field

In [9]:
# Prefer the GPU when one is visible; batches are created on this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")

# Bucket examples of similar source length together and sort each batch by
# source length, which packing the source sequences in the encoder relies on.
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_size=128,
    device=device,
    sort_key=lambda example: len(example.src),
    sort_within_batch=True,
)

### Build the encoder

In [10]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder that runs over packed (padding-free) source sequences.

    Args:
        input_dim: source vocabulary size (number of embedding rows).
        emb_dim: embedding dimension.
        enc_hid_dim: hidden size of each GRU direction.
        dec_hid_dim: decoder hidden size; the final forward/backward encoder
            states are projected to this size to initialise the decoder.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # Attribute name must match the one read in forward (it was misspelled
        # ``emnbedding``, which made every forward call raise AttributeError).
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):
        """Encode a (seq-first) batch of padded source sequences.

        Args:
            src: [seq_len, batch] token ids.
            src_len: true length of each sequence, as a tensor (on any device)
                or a list of ints. Need not be sorted.

        Returns:
            outputs: [seq_len, batch, enc_hid_dim * 2] top-layer states of both
                directions; positions past a sequence's length are zero.
            hidden: [batch, dec_hid_dim] tanh projection of the last forward
                and last backward hidden states.
        """
        embedded = self.dropout(self.embedding(src))
        # embedded = [seq_len, batch, emb_dim]

        # pack_padded_sequence requires the lengths as a CPU int64 tensor; passing
        # a CUDA tensor (batches live on `device`) raises in modern PyTorch.
        lengths = torch.as_tensor(src_len, dtype=torch.int64, device='cpu')
        # enforce_sorted=False lets PyTorch sort/unsort internally; for already
        # descending-sorted batches the result is identical.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths, enforce_sorted=False)
        packed_outputs, hidden = self.rnn(packed_embedded)

        # Pad back to src's full length so outputs line up with any mask built from src.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, total_length=src.shape[0])
        # outputs = [seq_len, batch, enc_hid_dim * 2]

        # hidden = [2, batch, enc_hid_dim]; [-2] is the forward direction's last
        # state, [-1] the backward direction's.
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
        # hidden = [batch, dec_hid_dim]
        return outputs, hidden

In [11]:
class Attention(nn.Module):
    """Additive (concat) attention over encoder outputs, with a padding mask.

    The score for each source position is ``v(tanh(atten([hidden; enc_output])))``,
    and masked positions get zero weight after the softmax.

    Args:
        enc_hid_dim: per-direction encoder hidden size (encoder outputs have
            enc_hid_dim * 2 features).
        dec_hid_dim: decoder hidden size.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.atten = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """Compute attention weights over the source positions.

        Args:
            hidden: [batch, dec_hid_dim] current decoder hidden state.
            encoder_outputs: [src_len, batch, enc_hid_dim * 2].
            mask: [batch, src_len]; positions where ``mask == 0`` (padding)
                receive (numerically) zero weight.

        Returns:
            [batch, src_len] weights; each row sums to 1.
        """
        # The parameter was misspelled ``encoder_outuputs`` while the body used
        # ``encoder_outputs``, so every call raised NameError.
        src_len = encoder_outputs.shape[0]

        # Repeat the decoder state for every source position: [batch, src_len, dec_hid_dim]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [src_len, batch, 2*enc] -> [batch, src_len, 2*enc]
        energy = torch.tanh(self.atten(torch.cat((hidden, encoder_outputs), dim=2)))
        # energy = [batch, src_len, dec_hid_dim]
        attention = self.v(energy).squeeze(2)
        # attention = [batch, src_len, 1] -> [batch, src_len]

        # A very large negative score makes softmax give padded positions ~0 weight.
        attention = attention.masked_fill(mask == 0, -1e10)

        return F.softmax(attention, dim=1)

### Build the decoder (with attention)

In [12]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    Previously this class was (mis)named ``Encoder``, which silently replaced the
    real encoder class defined earlier in the notebook.

    Args:
        output_dim: target vocabulary size (embedding rows and prediction logits).
        emb_dim: target embedding dimension.
        enc_hid_dim: per-direction encoder hidden size (encoder outputs have
            enc_hid_dim * 2 features).
        dec_hid_dim: decoder GRU hidden size.
        dropout: dropout probability applied to the target embeddings.
        attention: module called as ``attention(hidden, encoder_outputs, mask)``
            that returns [batch, src_len] weights.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # The GRU consumes cat(embedded, weighted) -> emb_dim + enc_hid_dim * 2 features
        # (it was sized enc_hid_dim * 2 + dec_hid_dim, which only worked if emb_dim == dec_hid_dim).
        self.rnn = nn.GRU(enc_hid_dim * 2 + emb_dim, dec_hid_dim)
        self.fc = nn.Linear(enc_hid_dim * 2 + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask):
        """Run one decoding step.

        Args:
            input: [batch] token ids for the current target step.
            hidden: [batch, dec_hid_dim] previous decoder hidden state.
            encoder_outputs: [src_len, batch, enc_hid_dim * 2].
            mask: [batch, src_len] padding mask forwarded to the attention module.

        Returns:
            prediction: [batch, output_dim] unnormalised scores for the next token.
            hidden: [batch, dec_hid_dim] new decoder hidden state.
            attention weights: [batch, src_len].
        """
        input = input.unsqueeze(0)
        # input = [1, batch]
        embedded = self.dropout(self.embedding(input))  # was self.embedded (AttributeError)
        # embedded = [1, batch, emb_dim]

        a = self.attention(hidden, encoder_outputs, mask).unsqueeze(1)
        # a = [batch, 1, src_len]

        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # encoder_outputs = [batch, src_len, enc_hid_dim * 2]
        weighted = torch.bmm(a, encoder_outputs).permute(1, 0, 2)
        # weighted = [batch, 1, enc_hid_dim * 2] -> [1, batch, enc_hid_dim * 2]

        rnn_input = torch.cat((embedded, weighted), dim=2)
        # rnn_input = [1, batch, emb_dim + enc_hid_dim * 2]
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc(torch.cat((output, embedded, weighted), dim=1))
        return prediction, hidden.squeeze(0), a.squeeze(1)
        
        

In [None]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.device = device
        
    def create _mask(self, src):
        mask = (src != self.src_pad_idx).permute(1, 0)
        return mask
    
    def forward(self, src, src_len, trg, teacher_forcing_ratio=0.5):
        batchs_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        encoder_output, hidden = self.encoder(src, src_len)
        
        