In [1]:
import torch
from torch import nn
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

In [2]:
from pathlib import Path
from urllib.request import urlopen
import linecache
from itertools import count

class En2DeDataset(torch.utils.data.Dataset):
    """Line-aligned English/German sentence pairs from the Stanford WMT'14 en-de release.

    Item ``i`` is the pair ``(english_line, german_line)`` taken from line
    ``i`` (0-based) of the two parallel files of the selected split:
    ``train.en``/``train.de`` when ``train=True``, otherwise
    ``test.en``/``test.de`` (downloaded from newstest2015). Lines are returned
    exactly as read, including the trailing newline, unless ``transform`` is
    given, in which case it is applied to each side separately.

    Args:
        folder_path: Directory that holds (or will receive) the text files.
        transform: Optional callable applied to both the source and target line.
        download: If True, (re)download the split's two files into ``folder_path``.
        train: Select the training split (True) or the newstest2015 test split.

    Note:
        Lines are served through ``linecache``, which loads a whole file into
        memory on first access (in every DataLoader worker process).
    """

    def __init__(self, folder_path, transform=None, download=False, train=True):
        self.path = Path(folder_path)
        self.train = train
        self.transform = transform
        self.train_en_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en'
        self.train_de_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de'

        self.test_en_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.en'
        self.test_de_url = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.de'

        # Exact number of lines (= sentence pairs) in each file of the split.
        # __download verifies the downloaded files against this count.
        self.length = 4_468_840 if train else 2_169
        if download:
            self.__download()

    def _split_files(self):
        """Return ``(english_filename, german_filename)`` for the selected split."""
        return ('train.en', 'train.de') if self.train else ('test.en', 'test.de')

    def __download(self):
        """Stream both files of the split into ``self.path``.

        Each file is written to a ``<name>.part`` sibling and only renamed into
        place after it has been read to EOF and its line count equals
        ``self.length``. An interrupted or short download therefore never
        leaves a truncated file behind for ``__getitem__`` to read from.

        Raises:
            RuntimeError: if a downloaded file does not have ``self.length`` lines.
        """
        self.path.mkdir(parents=True, exist_ok=True)
        if self.train:
            urls = (self.train_en_url, self.train_de_url)
        else:
            urls = (self.test_en_url, self.test_de_url)

        for name, url in zip(self._split_files(), urls):
            local_path = self.path / name
            part_path = local_path.with_name(name + '.part')
            n_lines = 0
            with urlopen(url) as webfile, part_path.open('wb') as localfile:
                # Read until EOF instead of a fixed count, so both short and
                # long files are caught by the check below.
                for raw_line in tqdm(iter(webfile.readline, b''), total=self.length, desc=name):
                    localfile.write(raw_line)
                    n_lines += 1
            if n_lines != self.length:
                part_path.unlink()
                raise RuntimeError(f'{url}: expected {self.length} lines, got {n_lines}')
            # Atomically swap the complete file into place (overwrites any old copy).
            part_path.replace(local_path)
            # linecache may still hold the previous contents of this file.
            linecache.checkcache(str(local_path.absolute()))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(english, german)`` pair.

        ``idx`` is 0-based; negative values count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            FileNotFoundError: if the split's files are not present on disk.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f'index out of range for dataset of length {self.length}')

        en_name, de_name = self._split_files()
        line_path = self.path / en_name
        label_path = self.path / de_name
        if not line_path.exists() or not label_path.exists():
            raise FileNotFoundError('Set download to True to download the dataset')

        # linecache numbers lines from 1 (getline(..., 0) returns ''), whereas
        # Dataset indices are 0-based.
        lineno = idx + 1
        line = linecache.getline(str(line_path.absolute()), lineno)
        label = linecache.getline(str(label_path.absolute()), lineno)

        if self.transform:
            line = self.transform(line)
            label = self.transform(label)
        return line, label


# Both splits read from ./downloads. These files are not fetched here: on a
# fresh machine, construct the datasets once with download=True.
train_dataset = En2DeDataset('./downloads', train=True)
test_dataset = En2DeDataset('./downloads', train=False)

train_dtld = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=16)
# Evaluation does not benefit from random order; a fixed order keeps test
# outputs reproducible and aligned with the lines of test.en / test.de.
test_dtld = torch.utils.data.DataLoader(test_dataset, shuffle=False)

In [3]:
from transformers.model import MultiHeadAttention, Transformer
# Standalone attention layer built with its default constructor arguments;
# in this notebook it is only used for the forward-pass shape check below.
model = MultiHeadAttention()

In [4]:
BS = 16
# Model width; matches the 512-feature Linear layers in the encoder printout.
D_MODEL = 512

# NOTE(review): inputs are (batch, d_model) with no sequence axis — confirm
# MultiHeadAttention is meant to accept this rather than (batch, seq, d_model).
q = torch.ones(BS, D_MODEL)
k = torch.ones(BS, D_MODEL)
v = torch.ones(BS, D_MODEL)

# Shape check only, so skip building an autograd graph.
# NOTE(review): the extra "torch.Size(...)" lines in this cell's output
# presumably come from debug prints inside the model module.
with torch.no_grad():
    attn_out = model(q, k, v)
assert attn_out.shape == q.shape, attn_out.shape
attn_out.shape

torch.Size([16, 8, 64]) torch.Size([16, 64, 8])


torch.Size([16, 512])

In [5]:
# Build a Transformer with default arguments and display its encoder stack:
# the repr lists the ResidualConnectionLayer-wrapped attention and
# feed-forward sublayers.
Transformer().encoder

Sequential(
  (0): ResidualConnectionLayer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (sublayer): MultiHeadAttention(
      (attention): ScaledDotProductAttention()
      (lq): Linear(in_features=512, out_features=512, bias=True)
      (lk): Linear(in_features=512, out_features=512, bias=True)
      (lv): Linear(in_features=512, out_features=512, bias=True)
      (l): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (1): ResidualConnectionLayer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (sublayer): FeedForward(
      (l1): Linear(in_features=512, out_features=2048, bias=True)
      (l2): Linear(in_features=2048, out_features=512, bias=True)
    )
  )
  (2): ResidualConnectionLayer(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (sublayer): MultiHeadAttention(
      (attention): ScaledDotProductAttention()
      (lq): Linear(in_features=512, out_features=512, bias=True)
      (lk): Linear(in_f