In [1]:
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-sm==3.4.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.0/en_core_web_sm-3.4.0-py3-none-any.whl (12.8 MB)
[K     |████████████████████████████████| 12.8 MB 5.0 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting de-core-news-sm==3.4.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.4.0/de_core_news_sm-3.4.0-py3-none-any.whl (14.6 MB)
[K     |████████████████████████████████| 14.6 MB 5.3 MB/s 
[38;5;2m✔ Do

In [2]:
# Switch the working directory to the project folder on Google Drive.
# NOTE(review): hard-coded Colab path; presumably Drive must already be mounted
# under /content/drive for this to succeed — confirm before running elsewhere.
%cd /content/drive/MyDrive/DL論文/Attention Is All You Need

/content/drive/MyDrive/DL論文/Attention Is All You Need


In [3]:
import warnings
# NOTE(review): no category or module filter is given, so this hides every
# Python warning raised after this cell, including deprecation notices.
warnings.filterwarnings('ignore')

## Data Sourcing and Processing

In [4]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

In [5]:
# Replace the download locations torchtext uses for the Multi30k train and
# validation archives with copies hosted on GitHub.
multi30k.URL.update(
    train="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz",
    valid="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz",
)

In [6]:
# Translation direction: German ('de') is the source, English ('en') the target.
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'

In [7]:
# Empty per-language registries, filled in by later cells:
#   token_transform[lang] -> tokenizer callable
#   vocab_transform[lang] -> vocabulary built from the training split
token_transform, vocab_transform = {}, {}

In [8]:
# Register one spaCy-backed tokenizer per language (German model for the
# source side, English model for the target side).
token_transform.update({
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
})

In [9]:
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one side of every sample in ``data_iter``.

    This is a generator: it yields one tokenized sentence per sample rather
    than returning a single list.

    Args:
        data_iter: Iterable of indexable samples in which position 0 holds the
            ``SRC_LANGUAGE`` text and position 1 the ``TGT_LANGUAGE`` text.
        language: ``SRC_LANGUAGE`` or ``TGT_LANGUAGE``; selects both the
            position read from each sample and the tokenizer taken from
            ``token_transform``.

    Yields:
        The tokenizer's output for each sample (a list of token strings for
        the tokenizers registered in this notebook).

    Raises:
        KeyError: When iteration starts, if ``language`` is not one of the two
            configured languages or has no tokenizer in ``token_transform``.
    """
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    # Both lookups are loop-invariant, so resolve them once up front.
    sample_position = language_index[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[sample_position])

In [10]:
# Reserved vocabulary entries. The *_IDX constants are the positions of the
# matching strings in `special_symbols`, so the two must stay in the same order.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))

In [11]:
# Build one vocabulary per language from the training split.
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    # A fresh training iterator is created for each language before it is
    # handed to the token generator.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # The special symbols are placed first in the resulting Vocab object.
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )

In [12]:
# Make UNK_IDX the fallback index returned for tokens missing from a vocabulary.
# Without a default index, looking up an unknown token raises a RuntimeError.
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    vocab_transform[ln].set_default_index(UNK_IDX)

## Create Dataset

In [13]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [14]:
# Number of sentence pairs per DataLoader batch.
BATCH_SIZE = 128
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global random number generator so that random draws made
# later in the notebook start from a known state on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

In [15]:
def sequential_transformers(*transformers):
    """Chain several callables into a single left-to-right pipeline.

    Args:
        *transformers: Callables taking one argument each. The output of one
            becomes the input of the next.

    Returns:
        A function that feeds its argument through every transformer in the
        order given and returns the final result. With no transformers the
        argument is returned unchanged.
    """
    def func(txt_input):
        result = txt_input
        for step in transformers:
            result = step(result)
        return result

    return func

In [16]:
def tensor_transform(token_ids):
    """Wrap a sequence of token ids in begin/end-of-sentence markers.

    Args:
        token_ids: Sequence of integer vocabulary indices. May be empty.

    Returns:
        A 1-D ``torch.long`` tensor laid out as
        ``[BOS_IDX, *token_ids, EOS_IDX]``.
    """
    # Pin the dtype on every piece. For an empty `token_ids`,
    # torch.tensor([]) would otherwise default to a floating-point tensor,
    # which does not match the integer BOS/EOS pieces.
    return torch.cat((
        torch.tensor([BOS_IDX], dtype=torch.long),
        torch.tensor(token_ids, dtype=torch.long),
        torch.tensor([EOS_IDX], dtype=torch.long),
    ))

In [17]:
# Per-language pipeline from a raw string to a tensor:
# tokenize -> map tokens to vocabulary indices -> add BOS/EOS and convert.
text_transform = {
    ln: sequential_transformers(
        token_transform[ln],
        vocab_transform[ln],
        tensor_transform,
    )
    for ln in (SRC_LANGUAGE, TGT_LANGUAGE)
}

In [18]:
def collate_fn(batch):
    """Turn a list of raw (source, target) string pairs into padded tensors.

    Trailing newlines are stripped from each string before it goes through
    the language's ``text_transform`` pipeline. The resulting sequences are
    padded with ``PAD_IDX`` using ``pad_sequence``'s default layout.

    Args:
        batch: Iterable of ``(src_sample, tgt_sample)`` string pairs.

    Returns:
        Tuple ``(src_batch, tgt_batch)`` of padded tensors.
    """
    src_tensors, tgt_tensors = [], []

    for raw_src, raw_tgt in batch:
        src_line = raw_src.rstrip("\n")
        tgt_line = raw_tgt.rstrip("\n")
        src_tensors.append(text_transform[SRC_LANGUAGE](src_line))
        tgt_tensors.append(text_transform[TGT_LANGUAGE](tgt_line))

    return (
        pad_sequence(src_tensors, padding_value=PAD_IDX),
        pad_sequence(tgt_tensors, padding_value=PAD_IDX),
    )

In [19]:
# Training split: raw sentence pairs, batched and tensorized by collate_fn.
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dataloader = DataLoader(
    train_iter,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

# Validation split, prepared the same way.
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_dataloader = DataLoader(
    val_iter,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

## Model Setup

In [20]:
def generate_square_subsequent_mask(sz):
    """Build an additive ``(sz, sz)`` mask that hides later positions.

    Entry ``[row, col]`` is ``0.0`` when ``col <= row`` and ``-inf`` when
    ``col > row``.

    Args:
        sz: Side length of the square mask.

    Returns:
        Float tensor of shape ``(sz, sz)`` allocated on ``DEVICE``.
    """
    # True on and below the main diagonal: the positions that stay visible.
    visible = torch.tril(torch.ones((sz, sz), device=DEVICE)).bool()
    additive = torch.zeros((sz, sz), dtype=torch.float, device=DEVICE)

    return additive.masked_fill(~visible, float('-inf'))

In [21]:
def create_mask(src, tgt):
    """Create the four masks passed to the model for one batch.

    The sequence length is read from dimension 0 of each input, and the
    padding masks are transposed so that dimension comes second.

    Args:
        src: Source token-id tensor.
        tgt: Target token-id tensor (the decoder input).

    Returns:
        Tuple ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
        an all-False boolean square mask for the source, the additive mask
        from ``generate_square_subsequent_mask`` for the target, and two
        boolean masks that are True wherever a token equals ``PAD_IDX``.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # All False: no source position is hidden from any other.
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=DEVICE)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [22]:
from model.seq2seq import Seq2SeqTransformer

In [23]:
# Architecture hyper-parameters passed to Seq2SeqTransformer.
NUM_LAYERS = 3
NHEAD = 8
EMB_SIZE = 512
FFN_HID_DIM = 512

# Vocabulary sizes are taken from the vocabularies built on the training split.
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE = (
    len(vocab_transform[lang]) for lang in (SRC_LANGUAGE, TGT_LANGUAGE)
)

In [24]:
# Instantiate the project's sequence-to-sequence model from the
# hyper-parameters defined in the previous cell.
transformer = Seq2SeqTransformer(
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    emb_size=EMB_SIZE,
    nhead=NHEAD,
    dim_ff=FFN_HID_DIM,
    num_layers=NUM_LAYERS,
)

# Hand the model to the selected device, then report that setup is complete.
transformer.to(DEVICE)
print('Setup finish')

Setup finish


In [25]:
# Cross-entropy over the target vocabulary. Targets equal to PAD_IDX are
# ignored, so padding positions do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [26]:
# Adam with a fixed learning rate of 1e-4. The betas and eps values match the
# Adam settings reported in "Attention Is All You Need" (0.9, 0.98, 1e-9);
# no learning-rate schedule is attached here.
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

## Training

In [27]:
def train_epoch(model, optimizer, train_dataloader):
    """Run one optimisation pass over ``train_dataloader``.

    For every batch the model receives the target without its final time step
    and is scored, via ``loss_fn``, against the target without its first time
    step. Gradients are reset, back-propagated and applied once per batch.

    Args:
        model: Model to train; it is switched to training mode first.
        optimizer: Optimizer that updates the model's parameters.
        train_dataloader: Iterable yielding ``(src, tgt)`` tensor pairs.

    Returns:
        Mean of the per-batch loss values.

    Raises:
        ZeroDivisionError: If ``train_dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0
    seen_batches = 0

    for seen_batches, (src, tgt) in enumerate(train_dataloader, start=1):
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        # Decoder input drops the last step; the expected output drops the first.
        decoder_in, expected = tgt[:-1, :], tgt[1:, :]

        masks = create_mask(src, decoder_in)
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = masks

        # The source padding mask is passed a second time as the last argument.
        logits = model(
            src, decoder_in,
            src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, src_padding_mask,
        )

        optimizer.zero_grad()

        # Flatten so that every target position is scored as one sample.
        batch_loss = loss_fn(logits.reshape(-1, logits.shape[-1]), expected.reshape(-1))
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / seen_batches

In [28]:
def evaluate(model, val_dataloader):
    """Compute the mean loss of ``model`` over ``val_dataloader``.

    The batches are processed exactly as in training (the model sees the
    target without its final time step and is scored against the target
    without its first), but no parameters are updated and gradient tracking
    is switched off.

    Args:
        model: Model to evaluate; it is switched to evaluation mode first.
        val_dataloader: Iterable yielding ``(src, tgt)`` tensor pairs.

    Returns:
        Mean of the per-batch loss values.

    Raises:
        ZeroDivisionError: If ``val_dataloader`` yields no batches.
    """
    model.eval()
    losses = 0
    batch_num = 0

    # Nothing is back-propagated here, so skip building the autograd graph.
    # This avoids holding activations in memory during validation.
    with torch.no_grad():
        for src, tgt in val_dataloader:
            batch_num += 1
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            # The source padding mask is passed a second time as the last argument.
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            tgt_out = tgt[1:, :]

            # Flatten so that every target position is scored as one sample.
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()

    return losses / batch_num

In [29]:
import time
# Number of full passes over the training data made by the loop below.
NUM_EPOCH = 10

In [30]:
# Train for NUM_EPOCH epochs, reporting the mean training and validation loss
# after each one. Epochs are numbered from 1 in the printed log.
for epoch in range(1, NUM_EPOCH+1):
    # Only the training pass is timed: end_time is taken before validation
    # runs, so the printed "Epoch time" excludes the evaluate() call.
    start_time = time.time()
    train_loss = train_epoch(transformer, optimizer, train_dataloader)
    end_time = time.time()
    val_loss = evaluate(transformer, val_dataloader)

    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Epoch time: {(end_time -start_time):.3f}")

Epoch: 1, Train loss: 5.517, Val loss: 4.532, Epoch time: 41.133
Epoch: 2, Train loss: 4.295, Val loss: 4.038, Epoch time: 40.248
Epoch: 3, Train loss: 3.925, Val loss: 3.810, Epoch time: 39.749
Epoch: 4, Train loss: 3.701, Val loss: 3.655, Epoch time: 40.033
Epoch: 5, Train loss: 3.534, Val loss: 3.538, Epoch time: 39.917
Epoch: 6, Train loss: 3.399, Val loss: 3.454, Epoch time: 39.955
Epoch: 7, Train loss: 3.285, Val loss: 3.372, Epoch time: 39.927
Epoch: 8, Train loss: 3.182, Val loss: 3.302, Epoch time: 39.888
Epoch: 9, Train loss: 3.092, Val loss: 3.244, Epoch time: 39.939
Epoch: 10, Train loss: 3.011, Val loss: 3.201, Epoch time: 39.911
