In [None]:
Сравнить LSTM, RNN и GRU на задаче предсказания части речи (качество предсказания, скорость обучения, время инференса модели)

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import datetime

In [3]:
data_dir = './'
train_lang = 'en'

In [4]:
class DatasetSeq(Dataset):
    def __init__(self, data_dir, train_lang='en'):
        #open file
        with open(data_dir + train_lang + '.train', 'r', encoding="utf-8") as f:
            train = f.read().split('\n\n') 
        # delete extra tag markup
        train = [x for x in train if not '_ ' in x]
        #init vocabs of tokens for encoding { token:  id}
        self.target_vocab = {'': 0} # {p: 1, a: 2, r: 3, pu: 4}
        self.word_vocab = {'': 0} # {cat: 1, sat: 2, on: 3, mat: 4, '.': 5}
        self.char_vocab = {'': 0} # {c: 1, a: 2, t: 3, ' ': 4, s: 5}   
        # Cat sat on mat. -> [1, 2, 3, 4, 5]
        # p    a  r  p pu -> [1, 2, 3, 1, 4]
        # chars  -> [1, 2, 3, 4, 5, 2, 3, 4]

        #init encoded sequences lists (processed data)
        self.encoded_sequences = []
        self.encoded_targets = []
        self.encoded_char_sequences = []
        # n=1 because first value is padding
        n_word = 1
        n_target = 1
        n_char = 1
        for line in train:
            sequence = []
            target = []
            chars = []
            for item in line.split('\n'):
                if item != '':
                    word, label = item.split(' ')

                    if self.word_vocab.get(word) is None:
                        self.word_vocab[word] = n_word
                        n_word += 1
                    if self.target_vocab.get(label) is None:
                        self.target_vocab[label] = n_target
                        n_target += 1
                    for char in word:
                        if self.char_vocab.get(char) is None:
                            self.char_vocab[char] = n_char
                            n_char += 1
                    sequence.append(self.word_vocab[word])
                    target.append(self.target_vocab[label])
                    chars.append([self.char_vocab[char] for char in word])
            self.encoded_sequences.append(sequence)
            self.encoded_targets.append(target)
            self.encoded_char_sequences.append(chars)

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        return {
            'data': self.encoded_sequences[index], # [1, 2, 3, 4, 6] len=5
            'char': self.encoded_char_sequences[index],# [[1,2,3], [4,5], [1,2], [2,6,5,4], []] len=5
            'target': self.encoded_targets[index], # [1, 2, 3, 4, 6] len=5
        }
     

In [5]:
dataset = DatasetSeq(data_dir)

In [6]:
def collate_fn(batch):
    data = []
    target = []
    for item in batch:
        data.append(torch.as_tensor(item['data']))
        target.append(torch.as_tensor(item['target']))
    data = pad_sequence(data, batch_first=True, padding_value=0)
    target = pad_sequence(target, batch_first=True, padding_value=0)

    return {'data': data, 'target': target}
     

In [7]:
vocab_size = len(dataset.word_vocab) + 1
n_classes = len(dataset.target_vocab) + 1
n_chars = len(dataset.char_vocab) + 1

emb_dim = 400
hidden = 256
n_epochs = 5
batch_size = 128
device = torch.device('cuda')

In [8]:
class LSTMPredictor(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes, dropout_p=0.2):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(dropout_p)
    
    def forward(self, x):
        emb = self.word_emb(x)       
        hidden, _ = self.rnn(emb)   
        pred = self.clf(self.do(hidden)) 

        return pred

In [9]:
LSTM_model = LSTMPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
LSTM_model.train()
optim = torch.optim.Adam(LSTM_model.parameters(), lr=0.0001)
loss_func = nn.CrossEntropyLoss()

In [10]:
%%time
for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = LSTM_model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),           
                         batch['target'].to(device).view(-1),   
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
LSTM_train_loss = loss.item()
print(f" loss LSTM train : {LSTM_train_loss}")

epoch: 0, step: 0, loss: 3.160245656967163
epoch: 0, step: 100, loss: 0.5738717913627625
epoch: 1, step: 0, loss: 0.648611307144165
epoch: 1, step: 100, loss: 0.671148419380188
epoch: 2, step: 0, loss: 0.4876020550727844
epoch: 2, step: 100, loss: 0.3167836368083954
epoch: 3, step: 0, loss: 0.4221790134906769
epoch: 3, step: 100, loss: 0.3681943714618683
epoch: 4, step: 0, loss: 0.29771149158477783
epoch: 4, step: 100, loss: 0.23717890679836273
 loss LSTM train : 0.25567829608917236
Wall time: 23.2 s


In [11]:
phrase = 'I was surprised that the morning had gone by so quickly'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

start = datetime.datetime.now()
with torch.no_grad():
    LSTM_model.eval()
    predict = LSTM_model(torch.tensor(tokens).unsqueeze(0).to(device)) 
    labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
    LSTM_inference_time = datetime.datetime.now() - start

target_labels = list(dataset.target_vocab.keys())
print([target_labels[l] for l in labels])
print(f'LSTM_inference_time: {LSTM_inference_time}')

['PRON', 'AUX', 'VERB', 'SCONJ', 'DET', 'NOUN', 'AUX', 'VERB', 'ADP', 'ADV', 'ADV']
LSTM_inference_time: 0:00:00.002019


In [12]:
class RNNPredictor(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.RNN(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.2)

    def forward(self, x):
        emb = self.word_emb(x) 
        hidden, _ = self.rnn(emb)   
        pred = self.clf(self.do(hidden)) 

        return pred

In [13]:
RNN_model = RNNPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
RNN_model.train()
optim = torch.optim.Adam(RNN_model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [14]:
%%time
for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = RNN_model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),           
                         batch['target'].to(device).view(-1),   
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')
RNN_train_loss = loss.item()
print(f"RNN_train_loss: {RNN_train_loss}")

epoch: 0, step: 0, loss: 3.2458600997924805
epoch: 0, step: 100, loss: 0.23593257367610931
epoch: 1, step: 0, loss: 0.21360842883586884
epoch: 1, step: 100, loss: 0.1800975650548935
epoch: 2, step: 0, loss: 0.12330935150384903
epoch: 2, step: 100, loss: 0.10707778483629227
epoch: 3, step: 0, loss: 0.11442644894123077
epoch: 3, step: 100, loss: 0.09905164688825607
epoch: 4, step: 0, loss: 0.09584056586027145
epoch: 4, step: 100, loss: 0.09146595746278763
RNN_train_loss: 0.07329050451517105
Wall time: 17.3 s


In [15]:
phrase = 'I was surprised that the morning had gone by so quickly'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

start = datetime.datetime.now()
with torch.no_grad():
    RNN_model.eval()
    predict = RNN_model(torch.tensor(tokens).unsqueeze(0).to(device)) 
    labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
    RNN_inference_time = datetime.datetime.now() - start

target_labels = list(dataset.target_vocab.keys())
print([target_labels[l] for l in labels])
print(f'RNN_inference_time: {RNN_inference_time}')

['PRON', 'AUX', 'VERB', 'SCONJ', 'DET', 'NOUN', 'AUX', 'VERB', 'ADP', 'ADV', 'ADV']
RNN_inference_time: 0:00:00.001968


In [16]:
class GRUPredictor(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.clf = nn.Linear(hidden_dim, n_classes)
        self.do = nn.Dropout(0.1)

    def forward(self, x):
        emb = self.word_emb(x)
        hidden, _ = self.rnn(emb)   
        pred = self.clf(self.do(hidden))

        return pred

In [17]:
GRU_model = GRUPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
GRU_model.train()
optim = torch.optim.Adam(GRU_model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss()

In [18]:
%%time
for epoch in range(n_epochs):
    dataloader = DataLoader(dataset, 
                            batch_size, 
                            shuffle=True, 
                            collate_fn=collate_fn,
                            drop_last = True,
                            )
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = GRU_model(batch['data'].to(device))
        loss = loss_func(predict.view(-1, n_classes),           
                         batch['target'].to(device).view(-1),   
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')

GRU_train_loss = loss.item()

print(f"GRU_train_loss: {GRU_train_loss}")

epoch: 0, step: 0, loss: 2.8637948036193848
epoch: 0, step: 100, loss: 0.106620654463768
epoch: 1, step: 0, loss: 0.1675812005996704
epoch: 1, step: 100, loss: 0.14275148510932922
epoch: 2, step: 0, loss: 0.11082882434129715
epoch: 2, step: 100, loss: 0.06928535550832748
epoch: 3, step: 0, loss: 0.06799425184726715
epoch: 3, step: 100, loss: 0.09315257519483566
epoch: 4, step: 0, loss: 0.020670292899012566
epoch: 4, step: 100, loss: 0.059901539236307144
GRU_train_loss: 0.04385297745466232
Wall time: 21 s


In [19]:
phrase = 'I was surprised that the morning had gone by so quickly'
words = phrase.split(' ')
tokens = [dataset.word_vocab[w] for w in words]

start = datetime.datetime.now()
with torch.no_grad():
    GRU_model.eval()
    predict = GRU_model(torch.tensor(tokens).unsqueeze(0).to(device)) 
    labels = torch.argmax(predict, dim=-1).squeeze().cpu().detach().tolist()
    GRU_inference_time = datetime.datetime.now() - start

target_labels = list(dataset.target_vocab.keys())
print([target_labels[l] for l in labels])
print(f'GRU_inference_time: {GRU_inference_time}')

['PRON', 'AUX', 'ADJ', 'SCONJ', 'DET', 'NOUN', 'AUX', 'VERB', 'ADP', 'ADV', 'ADV']
GRU_inference_time: 0:00:00.002999


In [None]:
Лучший Loss - GRU, лучшее время - RNN, лучшее время инференса - RNN