In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import numpy as np

import random
import math
import time


In [3]:
seed = 123


def randomSeed(SEED):
    """Seed every random number generator this notebook relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, torch's
    default generator, and torch's CUDA generator. Finally, cuDNN is switched
    to deterministic algorithms so that GPU runs are repeatable.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(SEED)
    # Give up some cuDNN speed in exchange for repeatable results.
    torch.backends.cudnn.deterministic = True


randomSeed(seed)

In [4]:
# Small spaCy pipelines: German for the source side, English for the target side.
spacy_de, spacy_en = spacy.load("de_core_news_sm"), spacy.load("en_core_web_sm")

In [5]:
def _spacy_token_texts(nlp, text):
    """Return the token strings produced by ``nlp``'s tokenizer for ``text``."""
    return [token.text for token in nlp.tokenizer(text)]


def tokenize_de(text):
    """Tokenize German text into a list of token strings using ``spacy_de``."""
    return _spacy_token_texts(spacy_de, text)


def tokenize_en(text):
    """Tokenize English text into a list of token strings using ``spacy_en``."""
    return _spacy_token_texts(spacy_en, text)

In [9]:
# Source (German) field: spaCy-tokenized, lowercased, with <sos>/<eos> markers.
SRC = Field(
    tokenize=tokenize_de,
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
)

# Target (English) field: same settings, English tokenizer.
TRG = Field(
    tokenize=tokenize_en,
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
)

In [10]:
# German -> English: the i-th extension is paired with the i-th field.
train_data, val_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG),
)

In [11]:
# Use the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
# BucketIterator groups examples of similar length into the same batch to
# reduce padding; all three splits share the batch size and target device.
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_size=128,
    device=device,
)

### Building the Model

In [18]:
class Encoder(nn.Module):
    """Bidirectional-GRU encoder.

    Source token ids are embedded (with dropout) and run through a single-layer
    bidirectional GRU. The top layer's final forward and backward hidden states
    are concatenated, then passed through ``Linear`` + ``tanh`` to map them to
    ``dec_hid_dim`` (presumably used as the decoder's initial hidden state).

    Args:
        input_dim: source vocabulary size (number of embedding rows).
        emb_dim: embedding dimension.
        enc_hid_dim: hidden size of each GRU direction.
        dec_hid_dim: output size of the final projection.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        # Forward and backward final states are concatenated, hence 2 * enc_hid_dim.
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: token ids of shape [src_len, batch_size] (the GRU is not
                ``batch_first``).

        Returns:
            outputs: [src_len, batch_size, enc_hid_dim * 2], the per-step GRU
                outputs with forward and backward features on the last dim.
            hidden: [batch_size, dec_hid_dim], the tanh-projected summary of the
                final forward and backward states.
        """
        embedded = self.dropout(self.embedding(src))
        # embedded = [src_len, batch_size, emb_dim]
        outputs, hidden = self.rnn(embedded)
        # hidden = [num_layers * 2, batch_size, enc_hid_dim]. The last two rows
        # hold the top layer's final forward (-2) and backward (-1) states.
        # torch.cat expects a *sequence* of tensors, not separate positional args.
        final_states = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = torch.tanh(self.fc(final_states))
        return outputs, hidden

### Building the Attention Layer

In [None]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()