In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchinfo
import torchtext.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchlake.common.schemas import NlpContext
from torchlake.common.utils.text import build_vocab
from torchlake.sequence_data.models import (Seq2Seq, Seq2SeqDecoder, Seq2SeqAttentionEncoder,
                                            Seq2SeqEncoder)
from torchlake.sequence_data.models.seq2seq.network import GlobalAttention, LocalAttention
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import Multi30k
from tqdm import tqdm

# setting

In [2]:
data_path = Path("../../data/multi30k")
artifacts_path = Path("../../artifacts/seq2seq")

In [3]:
BATCH_SIZE = 32
context = NlpContext()

In [4]:
device = torch.device(context.device)

# data

In [5]:
SRC_LANGUAGE = 'de'
TRG_LANGUAGE = 'en'

In [6]:
tokenizers = {
    SRC_LANGUAGE: get_tokenizer('spacy', language=SRC_LANGUAGE), # de_core_news_sm
    TRG_LANGUAGE: get_tokenizer('spacy', language=TRG_LANGUAGE)  # en_web_core_sm
}



In [7]:
train_iter, val_iter, test_iter = Multi30k(
    data_path.as_posix(),
    language_pair=(SRC_LANGUAGE, TRG_LANGUAGE),
)

In [8]:
vocabs = {
    SRC_LANGUAGE: build_vocab(map(lambda x: tokenizers[SRC_LANGUAGE](x[0]), train_iter), context),
    TRG_LANGUAGE: build_vocab(map(lambda x: tokenizers[SRC_LANGUAGE](x[1]), train_iter), context),
}



In [9]:
src_transform = T.Sequential(
    T.VocabTransform(vocabs[SRC_LANGUAGE]),
    T.Truncate(context.max_seq_len - 2),
    T.AddToken(token=context.bos_idx, begin=True),
    T.AddToken(token=context.eos_idx, begin=False),
    T.ToTensor(),
    T.PadTransform(context.max_seq_len, context.padding_idx),
)

trg_transform = T.Sequential(
    T.VocabTransform(vocabs[TRG_LANGUAGE]),
    T.Truncate(context.max_seq_len - 2),
    T.AddToken(token=context.bos_idx, begin=True),
    T.AddToken(token=context.eos_idx, begin=False),
    T.ToTensor(),
    T.PadTransform(context.max_seq_len, context.padding_idx),
)

In [10]:
def collate_fn(batch):
    src_batch, tgt_batch = [], []

    for src_sample, trg_sample in batch:
        # paper p.2: reverse source sentence
        src_sample = src_sample[::-1]
        
        src_tokenizer  = tokenizers[SRC_LANGUAGE]
        src_batch.append(src_transform(src_tokenizer(src_sample.rstrip("\n"))))
        
        trg_tokenizer  = tokenizers[TRG_LANGUAGE]
        tgt_batch.append(trg_transform(trg_tokenizer(trg_sample.rstrip("\n"))))

    return torch.stack(src_batch), torch.stack(tgt_batch)

In [11]:
train_loader = DataLoader(
    train_iter,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

In [12]:
for src, trg in train_loader:
    print(src.shape, trg.shape)
    break

torch.Size([32, 256]) torch.Size([32, 256])


In [13]:
vocab_sizes = {
    SRC_LANGUAGE: len(vocabs[SRC_LANGUAGE]),
    TRG_LANGUAGE: len(vocabs[TRG_LANGUAGE]),
}

# Model

In [14]:
NUM_LAYERS = 2
BIDIRECTIONAL = True

encoder = Seq2SeqEncoder(
    vocab_sizes[SRC_LANGUAGE],
    128,
    128,
    num_layers=NUM_LAYERS,
    bidirectional=BIDIRECTIONAL,
    context=context,
)

# encoder = Seq2SeqAttentionEncoder(
#     vocab_sizes[SRC_LANGUAGE],
#     128,
#     128,
#     num_layers=NUM_LAYERS,
#     bidirectional=BIDIRECTIONAL,
#     context=context,
# )

decoder = Seq2SeqDecoder(
    vocab_sizes[TRG_LANGUAGE],
    128,
    128,
    output_size=vocab_sizes[TRG_LANGUAGE],
    num_layers=NUM_LAYERS,
    bidirectional=BIDIRECTIONAL,
    context=context,
)

attention = None 
# attention = GlobalAttention(128, num_layers=NUM_LAYERS, bidirectional=BIDIRECTIONAL)
# attention = LocalAttention(128, num_layers=NUM_LAYERS, bidirectional=BIDIRECTIONAL)

model = Seq2Seq(encoder, decoder, attention, context=context).to(device)

In [15]:
torchinfo.summary(model)

Layer (type:depth-idx)                   Param #
Seq2Seq                                  --
├─Seq2SeqEncoder: 1-1                    --
│    └─LstmClassifier: 2-1               --
│    │    └─Embedding: 3-1               464,640
│    │    └─LSTM: 3-2                    659,456
│    │    └─LayerNorm: 3-3               512
├─Seq2SeqDecoder: 1-2                    --
│    └─LstmClassifier: 2-2               --
│    │    └─Embedding: 3-4               430,720
│    │    └─LSTM: 3-5                    659,456
│    │    └─LayerNorm: 3-6               512
│    │    └─Linear: 3-7                  864,805
Total params: 3,080,101
Trainable params: 3,080,101
Non-trainable params: 0

# Training

In [16]:
criterion = nn.CrossEntropyLoss(ignore_index = context.padding_idx)
optimizer = torch.optim.Adam(model.parameters())

In [17]:
epoches = 10

In [18]:
model.train()
training_loss = []
for e in range(epoches):
    running_loss = 0.0
    data_count = 0
    
    for source_sentence, target_sentence in tqdm(train_loader):
        data_count += len(source_sentence)
        optimizer.zero_grad()

        text = source_sentence.to(device)
        label = target_sentence.to(device)

        output = model(text, label)
        loss = criterion(output.transpose(-1, -2), label)
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

    mean_loss = running_loss / data_count
    training_loss.append(mean_loss)
    print(f"epoch {e+1} : {mean_loss}")

32it [00:19,  1.80it/s]

In [None]:
plt.plot(training_loss)

# Evaluate

In [23]:
test_loader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [26]:
from torchmetrics import BLEUScore, Perplexity

In [35]:
metric = Perplexity()



In [36]:
model.eval()
running_loss = 0.0
data_count = 0
with torch.no_grad():
    for source_sentence, target_sentence in tqdm(test_loader):
        data_count += len(source_sentence)
        text = source_sentence.to(device)
        label = target_sentence.to(device)
        # label = label[1:].view(-1)

        output = model(text, label, 0)
        # output = torch.flatten(output[:, 1:], end_dim=-2)

        # loss = criterion(output.transpose(-1, -2), label)
        # running_loss += loss.item()
        metric.update(output.detach().cpu(), label.detach().cpu())

    # mean_loss = running_loss / data_count

# print(f"loss: {mean_loss}")
# print(f"perplexity: {math.exp(mean_loss)}")

8it [00:10,  1.30s/it]


In [37]:
metric.compute()

tensor(23113840.)

# Translate

In [22]:
MAX_LENGTH = 100

model.eval()
with torch.no_grad():
    for i, (source_sentence, target_sentence) in enumerate(val_iter):
        ori_source, ori_target = source_sentence, target_sentence
        source_sentence = tokenizers[SRC_LANGUAGE](source_sentence.rstrip('\n'))
        source_sentence = src_transform(source_sentence)
        source_sentence = source_sentence.unsqueeze(0)
        source_sentence = source_sentence.to(device)
        
        target_sentence = tokenizers[TRG_LANGUAGE](target_sentence.rstrip('\n'))
        target_sentence = trg_transform(target_sentence)
        target_sentence = target_sentence.unsqueeze(0)
        target_sentence = target_sentence.to(device)
        
        # outputs = [BOS_IDX]
        
        # h, c = model.encoder(source_sentence)
        
        # for _ in range(MAX_LENGTH):
            
        #     input_seq = torch.LongTensor([outputs[-1]]).to(device)
            
        #     output, _, _ = model.decoder(input_seq, h, c)
            
        #     next_token = output.argmax(1).item()
        #     outputs.append(next_token) 
            
        #     if next_token == EOS_IDX: break
        
        output = model(source_sentence, target_sentence)[0].argmax(-1)
        
        # if EOS_IDX in outputs:
        translated = vocabs[TRG_LANGUAGE].lookup_tokens(output.tolist())
        start_idx, end_idx = translated.index(context.bos_str), translated.index(context.eos_str)
        print(f'第{i+1}句')
        print('source:', ori_source.rstrip('\n'))
        print('target:', ori_target.rstrip('\n'))
        print('output:', *translated[start_idx+1:end_idx], sep=' ')
        
        if i+1==10: break

第1句
source: Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen
target: A group of men are loading cotton onto a truck
output: A group of people are loading a a a truck
第2句
source: Ein Mann schläft in einem grünen Raum auf einem Sofa.
target: A man sleeping in a green room on a couch.
output: A man sleeping in a green room with a <unk> .
第3句
source: Ein Junge mit Kopfhörern sitzt auf den Schultern einer Frau.
target: A boy wearing headphones sits on a woman's shoulders.
output: A man is headphones sits on a woman in in .
第4句
source: Zwei Männer bauen eine blaue Eisfischerhütte auf einem zugefrorenen See auf
target: Two men setting up a blue ice fishing hut on an iced over lake
output: Two men setting up a <unk> ice cream and on an <unk> over a . . .
第5句
source: Ein Mann mit beginnender Glatze, der eine rote Rettungsweste trägt, sitzt in einem kleinen Boot.
target: A balding man wearing a red life jacket is sitting in a small boat.
output: A man in wearing a blue life jacket is h

In [20]:
model_path = artifacts_path / 'seq2seq.pth'

In [21]:
torch.save(model, model_path.as_posix())

In [17]:
model = torch.load(model_path.as_posix())