In [1]:
from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn

In [10]:
device = "mps" if torch.backends.mps.is_available() else "cpu"

In [2]:
train_src_file = "../data/parallel/train.ja"
train_trg_file = "../data/parallel/train.en"

dev_src_file = "../data/parallel/dev.ja"
dev_trg_file = "../data/parallel/dev.en"

test_src_file = "../data/parallel/test.ja"
test_trg_file = "../data/parallel/test.en"

In [3]:
w2i_src = defaultdict(lambda: len(w2i_src))
w2i_trg = defaultdict(lambda: len(w2i_trg))

In [4]:
def read(source_filename, target_filename):
    with open(source_filename, "r") as f_source, open(target_filename, "r") as f_target:
        for line_src, line_target in zip(f_source, f_target):
            sent_src = [w2i_src[x] for x in line_src.strip().split() + ['</s>']]
            sent_trg = [w2i_trg[x] for x in ['<s>'] + line_target.strip().split() + ['</s>']]
            yield sent_src, sent_trg

In [5]:
# Training
train = list(read(train_src_file, train_trg_file))
unk_src = w2i_src["<unk>"]
eos_src = w2i_src["</s>"]
w2i_src = defaultdict(lambda: unk_src, w2i_src)

# Target
unk_trg = w2i_trg["<unk>"]
eos_trg = w2i_trg["</s>"]
sos_trg = w2i_trg["<s>"]
w2i_trg = defaultdict(lambda: unk_trg, w2i_trg)
i2w_trg = {v: k for k, v in w2i_trg.items()}

In [6]:
nwords_src = len(w2i_src)
nwords_trg = len(w2i_trg)

In [7]:
dev = list(read(dev_src_file, dev_trg_file))
test = list(read(test_src_file, test_trg_file))

In [8]:
EMBED_SIZE = 64
HIDDEN_SIZE = 128
BATCH_SIZE = 16

In [9]:
#Especially in early training, the model can generate basically infinitly without generating an EOS
#have a max sent size that you end at
MAX_SENT_SIZE = 50

# Encoder-Decoder LSTM

In [58]:
class ENC_DEC(nn.Module):
    def __init__(self, nwords_source, nwords_target, embedding_size_enc, embedding_size_dec, hidden_size_encoder, hidden_size_decoder):
        super(ENC_DEC, self).__init__()
        self.emb_encoder = nn.Embedding(num_embeddings=nwords_source, embedding_dim=embedding_size_enc)
        self.emb_decoder = nn.Embedding(num_embeddings=nwords_target, embedding_dim=embedding_size_dec)
        self.lstm_encoder = nn.LSTM(input_size=embedding_size_enc, hidden_size=hidden_size_encoder, batch_first=True)
        self.lstm_decoder = nn.LSTM(input_size=embedding_size_dec, hidden_size=hidden_size_decoder, batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size_decoder, out_features=nwords_target)
    
    def forward(self, input_source, input_target):
        # input_source: batch_size, input_length
        input_embedding = self.emb_encoder(input_source) # Size: batch_size, input_length, embedding_size_enc
        # output_encoder => Size: batch_size, input_length, hidden_size_encoder
        output_encoder, (h_n_encoder, c_n_encoder) = self.lstm_encoder(input_embedding)
        print(output_encoder.shape)
        print(h_n_encoder.shape)
        print(c_n_encoder.shape)
        
        
        output_embedding = self.emb_decoder(input_target) # Size: batch_size, output_length, embedding_size_dec
        # output_decoder => size: batch_size, output_length, hidden_size_dec
        output_decoder, (h_n_decoder, c_n_decoder) = self.lstm_decoder(output_embedding) 
        print(output_decoder.shape)
        print(h_n_decoder.shape)
        print(c_n_decoder.shape)
        
        logits = self.linear(output_decoder) # Size: batch_size, output_length, nwords_target
        print(logits.shape)
        return logits

In [63]:
model = ENC_DEC(
            nwords_source=nwords_src, 
            nwords_target=nwords_trg, 
            embedding_size_enc=EMBED_SIZE, 
            embedding_size_dec=EMBED_SIZE, 
            hidden_size_encoder=HIDDEN_SIZE, 
            hidden_size_decoder=HIDDEN_SIZE
)

In [65]:
for sent_input, sent_target in train:
    sent_input = torch.tensor(sent_input)
    sent_target = torch.tensor(sent_target)
    model(sent_input, sent_target)
    break

torch.Size([11, 128])
torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([8, 128])
torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([8, 7043])
