## RNN, LSTM, GRU in Pytorch

#### Import Library

In [1]:
import os
import time
import torch
from tqdm import tqdm

from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import random_split

from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator, Vectors

In [2]:
config = {'pre_trained' : 'glove', # 'glove','fasttext', None 
          'max_length': 300,
          'batch_size': 64,
          'model_type': 'gru', # 'rnn', 'lstm', 'gru','avg_not_pad', None
          'emb_dim' : 300,
          'hidden_dim':128,
          'is_bidirectional':True,
          'epoch' : 15,
          'LR': 5
          }

#### Reading Data

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# Tokenize & Vocab setup
tokenizer = get_tokenizer('basic_english')
train_iter = IMDB(split='train')
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), min_freq= 2, specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"]) # This index will be returned when OOV token is queried.

num_class = 2
vocab_size = len(vocab)
idx_pad = vocab.get_stoi()['<pad>']

In [5]:
if config['pre_trained'] == 'glove':
    pretrained_vectors = Vectors(name = 'glove.6B.300d.txt', 
                                #  cache = '[my_path]',
                                 url = 'http://nlp.stanford.edu/data/glove.6B.zip')
    pretrained_emb = pretrained_vectors.get_vecs_by_tokens(vocab.get_itos(), lower_case_backup=True)

elif config['pre_trained'] == 'fasttext':
    pretrained_vectors = Vectors(name = 'wiki.simple.vec', 
                                #  cache = '[my_path]',
                                 url = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.vec')
    pretrained_emb = pretrained_vectors.get_vecs_by_tokens(vocab.get_itos(), lower_case_backup=True)
else:
    pass

In [6]:
# DataLoader Setup
text_pipeline = lambda x: vocab(tokenizer(x))[:config['max_length']]
label_pipeline = lambda x: {"neg":0, "pos":1}.get(x)

def collate_batch(batch):
    label_list, text_list = [], [] 
    for (_label, _text) in batch:
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        label_list.append(label_pipeline(_label))
    
    text_list = pad_sequence(text_list, batch_first= True, padding_value= idx_pad)
    label_list = torch.tensor(label_list, dtype=torch.int64)
    return text_list.to(device), label_list.to(device)

train_iter, test_iter = IMDB()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_batch)

In [8]:
# Data Sample
print('---- Data Sample ----')
print('Input : ')
print(' '.join(vars(train_data.examples[1])['text']),'\n')
print('Label : ')
print(vars(train_data.examples[1])['label'])

---- Data Sample ----
Input : 
Homelessness (or Houselessness as George Carlin stated) has been an issue for years but never a plan to help those on the street that were once considered human who did everything from going to school, work, or vote for the matter. Most people think of the homeless as just a lost cause while worrying about things such as racism, the war on Iraq, pressuring kids to succeed, technology, the elections, inflation, or worrying if they'll be next to end up on the streets.<br /><br />But what if you were given a bet to live on the streets for a month without the luxuries you once had from a home, the entertainment sets, a bathroom, pictures on the wall, a computer, and everything you once treasure to see what it's like to be homeless? That is Goddard Bolt's lesson.<br /><br />Mel Brooks (who directs) who stars as Bolt plays a rich man who has everything in the world until deciding to make a bet with a sissy rival (Jeffery Tambor) to see if he can live in the str

#### Data Preprocessing

In [9]:
def PreProcessingText(input_sentence):
    input_sentence = input_sentence.lower() # 소문자화
    input_sentence = re.sub('<[^>]*>', repl= ' ', string = input_sentence) # "<br />" 처리
    input_sentence = re.sub('[!"#$%&\()*+,-./:;<=>?@[\\]^_`{|}~]', repl= ' ', string = input_sentence) # 특수문자 처리 ("'" 제외)
    input_sentence = re.sub('\s+', repl= ' ', string = input_sentence) # 연속된 띄어쓰기 처리
    if input_sentence:
        return input_sentence

#### Model Setup

In [7]:
# Model setup
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, num_class, **config):
        super(TextClassificationModel, self).__init__()
        self.model_type = config['model_type']
        self.pretrained = config['pre_trained']
        self.is_bidirectional = config['is_bidirectional']
        self.embed_dim = config['emb_dim']
        self.hidden_dim = config['hidden_dim']
        
        self.embedding = nn.Embedding(vocab_size, self.embed_dim,)
        if self.pretrained:
            self.embedding = nn.Embedding(vocab_size, self.embed_dim,).from_pretrained(pretrained_emb, freeze = False)

        if self.model_type is None:
            self.fc = nn.Linear(self.embed_dim, num_class)

        elif self.model_type == 'avg_not_pad':
            self.embedding = nn.EmbeddingBag(vocab_size, self.embed_dim, sparse=True, padding_idx = idx_pad)
            if self.pretrained:
                self.embedding = self.embedding.from_pretrained(pretrained_emb, freeze = False, sparse=True)
            self.fc = nn.Linear(self.embed_dim, num_class)

        elif self.model_type in ['rnn','lstm','gru']:
            if self.model_type == 'rnn':
                self.Recurrent = nn.RNN(input_size = self.embed_dim, hidden_size = self.hidden_dim, 
                                        bidirectional = self.is_bidirectional, batch_first = True)
            elif self.model_type == 'lstm':
                self.Recurrent = nn.LSTM(input_size = self.embed_dim, hidden_size = self.hidden_dim, 
                                         bidirectional = self.is_bidirectional, batch_first = True)
            else:
                self.Recurrent = nn.GRU(input_size = self.embed_dim, hidden_size = self.hidden_dim, 
                                        bidirectional = self.is_bidirectional, batch_first = True)

            last_input_dim = self.hidden_dim * 2 if self.is_bidirectional else self.hidden_dim 
            self.fc = nn.Linear(last_input_dim, num_class)

        else:
            raise NameError('Select model_type in [rnn, lstm, gru, avg_not_pad]')

        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        if self.pretrained:
            self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        embedded = self.embedding(text)
        if self.model_type is None:
            embedded = torch.mean(embedded, dim=1)
            return self.fc(embedded)
        elif self.model_type == 'avg_not_pad':
            return self.fc(embedded)
        else:
            output, _ = self.Recurrent(embedded)
            last_output = output[:,-1,:]
            return self.fc(last_output)


model = TextClassificationModel(vocab_size, num_class, **config).to(device)
model

TextClassificationModel(
  (embedding): Embedding(51718, 300)
  (Recurrent): GRU(300, 128, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=2, bias=True)
)

In [None]:
# Old version
class SentenceClassification(nn.Module):
    def __init__(self, **model_config):
        super(SentenceClassification, self).__init__()

        if model_config['emb_type'] == 'glove' or 'fasttext':
            # nn.Embedding을 통해 쉽게 임베딩 설정 가능
            self.emb = nn.Embedding(num_embeddings = model_config['vocab_size'],
                                    embedding_dim = model_config['emb_dim'],
                                    _weight = TEXT.vocab.vectors)
            # num_embeddings = Vocab Size, embedding_dim = 원하는 임베딩 차원을 설정 (Pre-trained Vector 사용 시, 차원과 일치 시켜야함)
            # _weight = Pre-trained Vector를 Embedding 행렬의 Initial Value로 설정, 모델 학습을 이대로 진행하면 해당 Embedding Layer의 행렬도 학습
            # nn.Embedding.from_pretrained(TEXT.vocab.vectors)은 학습 freeze 상태, 현재 옵션 아니라면 정규 분포에서 생성한 값을 Initial Value로 사용 
        else:
            self.emb = nn.Embedding(num_embeddings = model_config['vocab_size'],
                                    embedding_dim = model_config['emb_dim'])
        
        self.bidirectional = model_config['bidirectional']
        self.num_direction = 2 if model_config['bidirectional'] else 1
        self.model_type = model_config['model_type'] 

        # RNN, GRU, LSTM의 필수 옵션은 대부분 유사
        self.RNN = nn.RNN (input_size = model_config['emb_dim'],
                           hidden_size = model_config['hidden_dim'],
                           dropout = model_config['dropout'],
                           bidirectional = model_config['bidirectional'],
                           batch_first = model_config['batch_first'])
        # input_size : 입력받을 Data의 크기이므로, Embedding Dimension으로 설정
        # hidden_size : Hidden Layer의 Dimension을 설정
        # Dropout : Dropout 확률, bidirectional: 양방향 모델을 사용할 경우 설정, batch_first: Data의 제일 처음 Axis에 Batch_size가 오도록 설정
        
        self.LSTM= nn.LSTM(input_size = model_config['emb_dim'],
                           hidden_size = model_config['hidden_dim'],
                           dropout = model_config['dropout'],
                           bidirectional = model_config['bidirectional'],
                           batch_first = model_config['batch_first'])
        
        self.GRU = nn.GRU (input_size = model_config['emb_dim'],
                           hidden_size = model_config['hidden_dim'],
                           dropout = model_config['dropout'],
                           bidirectional = model_config['bidirectional'],
                           batch_first = model_config['batch_first'])
    
        self.fc = nn.Linear(model_config['hidden_dim'] * self.num_direction,
                            model_config['output_dim'])
        # 위 정보를 이용해 분류 문제룰 푸는 Task를 할 예정이므로, Class에 대한 Score를 생성하기 위해 FC layer 1개를 통과, 추가 시그모이드가 없는 이유는 추후 Loss Function에 포함돼 있기 때문
        
        self.drop = nn.Dropout(model_config['dropout'])

    def forward(self, x):
        
        emb = self.emb(x) 
        # emb : (Batch_Size, Max_Seq_Length, Emb_dim)

        if self.model_type == 'RNN':
            output, hidden = self.RNN(emb) 
        elif self.model_type == 'LSTM':
            output, (hidden, cell) = self.LSTM(emb)
        elif self.model_type == 'GRU':
            output, hidden = self.GRU(emb)
        else:
            raise NameError('Select model_type in [RNN, LSTM, GRU]')
        
        # output : (Batch_Size, Max_Seq_Length, Hidden_dim * num_direction) 
        # hidden : (num_direction, Batch_Size, Hidden_dim)
        # hidden의 경우, batch_first 옵션이 안먹는 문제가 있음
        
        last_output = output[:,-1,:]

        # last_output : (Batch_Size, Hidden_dim * num_direction)
        return self.fc(self.drop(last_output))

#### Training

In [8]:
# Training Setup
# Hyperparameters
EPOCHS = config['epoch']
LR = config['LR']  

total_accu = None

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()

    model.train()
    total_acc, total_count = 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in tqdm(enumerate(train_dataloader)):
        # Training
        optimizer.zero_grad()
        predicted_label = model(text)
        loss = criterion(predicted_label, label)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(f'| epoch {epoch:3d} | {idx:5d}/{len(train_dataloader):5d} batches | accuracy {total_acc/total_count:8.3f}')                             
            total_acc, total_count = 0, 0
            start_time = time.time()
        
    # Evaluation
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (text, label) in enumerate(valid_dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    accu_val =  total_acc/total_count  
     
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print(f'| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | valid accuracy {accu_val:8.3f}')
    print('-' * 59)

52it [00:07, 10.07it/s]

| epoch   1 |    50/  372 batches | accuracy    0.492


101it [00:12, 10.04it/s]

| epoch   1 |   100/  372 batches | accuracy    0.514


153it [00:17, 10.83it/s]

| epoch   1 |   150/  372 batches | accuracy    0.496


203it [00:22, 10.24it/s]

| epoch   1 |   200/  372 batches | accuracy    0.500


251it [00:26, 10.34it/s]

| epoch   1 |   250/  372 batches | accuracy    0.508


303it [00:31, 10.78it/s]

| epoch   1 |   300/  372 batches | accuracy    0.512


353it [00:36, 10.41it/s]

| epoch   1 |   350/  372 batches | accuracy    0.496


372it [00:37,  9.79it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 38.86s | valid accuracy    0.495
-----------------------------------------------------------


52it [00:05, 10.40it/s]

| epoch   2 |    50/  372 batches | accuracy    0.502


103it [00:09, 10.23it/s]

| epoch   2 |   100/  372 batches | accuracy    0.499


151it [00:14, 10.51it/s]

| epoch   2 |   150/  372 batches | accuracy    0.499


203it [00:19, 10.97it/s]

| epoch   2 |   200/  372 batches | accuracy    0.504


253it [00:24, 10.82it/s]

| epoch   2 |   250/  372 batches | accuracy    0.493


303it [00:28, 10.90it/s]

| epoch   2 |   300/  372 batches | accuracy    0.507


353it [00:33, 10.97it/s]

| epoch   2 |   350/  372 batches | accuracy    0.503


372it [00:35, 10.50it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 36.22s | valid accuracy    0.495
-----------------------------------------------------------


52it [00:04, 10.60it/s]

| epoch   3 |    50/  372 batches | accuracy    0.517


102it [00:09, 10.37it/s]

| epoch   3 |   100/  372 batches | accuracy    0.492


152it [00:14, 10.68it/s]

| epoch   3 |   150/  372 batches | accuracy    0.512


202it [00:19, 10.89it/s]

| epoch   3 |   200/  372 batches | accuracy    0.505


252it [00:24, 10.33it/s]

| epoch   3 |   250/  372 batches | accuracy    0.529


303it [00:29,  9.71it/s]

| epoch   3 |   300/  372 batches | accuracy    0.520


351it [00:34, 10.28it/s]

| epoch   3 |   350/  372 batches | accuracy    0.505


372it [00:36, 10.22it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 37.27s | valid accuracy    0.510
-----------------------------------------------------------


52it [00:04, 10.36it/s]

| epoch   4 |    50/  372 batches | accuracy    0.510


102it [00:09, 11.19it/s]

| epoch   4 |   100/  372 batches | accuracy    0.504


152it [00:14, 10.92it/s]

| epoch   4 |   150/  372 batches | accuracy    0.504


202it [00:18, 11.55it/s]

| epoch   4 |   200/  372 batches | accuracy    0.504


252it [00:23, 10.92it/s]

| epoch   4 |   250/  372 batches | accuracy    0.509


303it [00:28, 10.13it/s]

| epoch   4 |   300/  372 batches | accuracy    0.522


351it [00:32, 10.40it/s]

| epoch   4 |   350/  372 batches | accuracy    0.497


372it [00:34, 10.65it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 35.63s | valid accuracy    0.497
-----------------------------------------------------------


51it [00:04, 11.50it/s]

| epoch   5 |    50/  372 batches | accuracy    0.510


103it [00:09, 10.52it/s]

| epoch   5 |   100/  372 batches | accuracy    0.520


153it [00:13, 11.47it/s]

| epoch   5 |   150/  372 batches | accuracy    0.526


203it [00:18, 11.12it/s]

| epoch   5 |   200/  372 batches | accuracy    0.520


253it [00:22, 11.52it/s]

| epoch   5 |   250/  372 batches | accuracy    0.524


303it [00:27, 10.92it/s]

| epoch   5 |   300/  372 batches | accuracy    0.520


351it [00:31, 11.04it/s]

| epoch   5 |   350/  372 batches | accuracy    0.525


372it [00:33, 11.17it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 34.05s | valid accuracy    0.520
-----------------------------------------------------------


52it [00:04, 10.56it/s]

| epoch   6 |    50/  372 batches | accuracy    0.542


102it [00:09, 10.99it/s]

| epoch   6 |   100/  372 batches | accuracy    0.535


152it [00:13, 11.56it/s]

| epoch   6 |   150/  372 batches | accuracy    0.534


202it [00:18, 11.01it/s]

| epoch   6 |   200/  372 batches | accuracy    0.532


252it [00:22, 10.45it/s]

| epoch   6 |   250/  372 batches | accuracy    0.522


302it [00:27, 11.38it/s]

| epoch   6 |   300/  372 batches | accuracy    0.517


352it [00:31, 11.20it/s]

| epoch   6 |   350/  372 batches | accuracy    0.538


372it [00:33, 11.17it/s]


-----------------------------------------------------------
| end of epoch   6 | time: 34.01s | valid accuracy    0.523
-----------------------------------------------------------


52it [00:04, 10.87it/s]

| epoch   7 |    50/  372 batches | accuracy    0.540


102it [00:09, 11.90it/s]

| epoch   7 |   100/  372 batches | accuracy    0.521


152it [00:13, 11.10it/s]

| epoch   7 |   150/  372 batches | accuracy    0.534


202it [00:18, 10.83it/s]

| epoch   7 |   200/  372 batches | accuracy    0.524


252it [00:22, 11.53it/s]

| epoch   7 |   250/  372 batches | accuracy    0.525


302it [00:27, 11.39it/s]

| epoch   7 |   300/  372 batches | accuracy    0.535


352it [00:31, 10.50it/s]

| epoch   7 |   350/  372 batches | accuracy    0.533


372it [00:33, 11.09it/s]


-----------------------------------------------------------
| end of epoch   7 | time: 34.21s | valid accuracy    0.518
-----------------------------------------------------------


52it [00:04, 11.11it/s]

| epoch   8 |    50/  372 batches | accuracy    0.541


102it [00:09, 10.97it/s]

| epoch   8 |   100/  372 batches | accuracy    0.535


152it [00:13, 11.51it/s]

| epoch   8 |   150/  372 batches | accuracy    0.526


202it [00:18, 10.72it/s]

| epoch   8 |   200/  372 batches | accuracy    0.550


252it [00:23, 10.75it/s]

| epoch   8 |   250/  372 batches | accuracy    0.532


302it [00:27, 11.67it/s]

| epoch   8 |   300/  372 batches | accuracy    0.530


352it [00:32, 11.72it/s]

| epoch   8 |   350/  372 batches | accuracy    0.540


372it [00:33, 10.99it/s]


-----------------------------------------------------------
| end of epoch   8 | time: 34.59s | valid accuracy    0.517
-----------------------------------------------------------


51it [00:04, 11.66it/s]

| epoch   9 |    50/  372 batches | accuracy    0.535


103it [00:09, 11.56it/s]

| epoch   9 |   100/  372 batches | accuracy    0.552


153it [00:13, 10.95it/s]

| epoch   9 |   150/  372 batches | accuracy    0.546


203it [00:18, 10.94it/s]

| epoch   9 |   200/  372 batches | accuracy    0.556


252it [00:23,  9.61it/s]

| epoch   9 |   250/  372 batches | accuracy    0.531


302it [00:27, 10.35it/s]

| epoch   9 |   300/  372 batches | accuracy    0.537


352it [00:32, 10.83it/s]

| epoch   9 |   350/  372 batches | accuracy    0.559


372it [00:34, 10.89it/s]


-----------------------------------------------------------
| end of epoch   9 | time: 34.96s | valid accuracy    0.517
-----------------------------------------------------------


53it [00:05, 10.64it/s]

| epoch  10 |    50/  372 batches | accuracy    0.544


102it [00:10, 10.10it/s]

| epoch  10 |   100/  372 batches | accuracy    0.538


152it [00:14, 11.60it/s]

| epoch  10 |   150/  372 batches | accuracy    0.547


203it [00:19, 10.65it/s]

| epoch  10 |   200/  372 batches | accuracy    0.536


253it [00:23, 10.68it/s]

| epoch  10 |   250/  372 batches | accuracy    0.562


302it [00:28, 10.46it/s]

| epoch  10 |   300/  372 batches | accuracy    0.541


352it [00:33, 10.42it/s]

| epoch  10 |   350/  372 batches | accuracy    0.551


372it [00:35, 10.52it/s]


-----------------------------------------------------------
| end of epoch  10 | time: 36.12s | valid accuracy    0.517
-----------------------------------------------------------


52it [00:04, 10.44it/s]

| epoch  11 |    50/  372 batches | accuracy    0.544


102it [00:09, 10.38it/s]

| epoch  11 |   100/  372 batches | accuracy    0.540


152it [00:14, 11.08it/s]

| epoch  11 |   150/  372 batches | accuracy    0.539


202it [00:19, 10.92it/s]

| epoch  11 |   200/  372 batches | accuracy    0.551


252it [00:23, 10.62it/s]

| epoch  11 |   250/  372 batches | accuracy    0.547


302it [00:28, 10.28it/s]

| epoch  11 |   300/  372 batches | accuracy    0.541


352it [00:33, 10.29it/s]

| epoch  11 |   350/  372 batches | accuracy    0.559


372it [00:35, 10.57it/s]


-----------------------------------------------------------
| end of epoch  11 | time: 36.00s | valid accuracy    0.517
-----------------------------------------------------------


52it [00:04, 10.28it/s]

| epoch  12 |    50/  372 batches | accuracy    0.550


102it [00:09, 10.90it/s]

| epoch  12 |   100/  372 batches | accuracy    0.552


152it [00:14, 10.65it/s]

| epoch  12 |   150/  372 batches | accuracy    0.544


202it [00:19, 10.53it/s]

| epoch  12 |   200/  372 batches | accuracy    0.553


252it [00:24, 10.65it/s]

| epoch  12 |   250/  372 batches | accuracy    0.537


302it [00:28, 10.40it/s]

| epoch  12 |   300/  372 batches | accuracy    0.547


353it [00:33, 10.07it/s]

| epoch  12 |   350/  372 batches | accuracy    0.536


372it [00:35, 10.51it/s]


-----------------------------------------------------------
| end of epoch  12 | time: 36.22s | valid accuracy    0.517
-----------------------------------------------------------


52it [00:05, 10.41it/s]

| epoch  13 |    50/  372 batches | accuracy    0.531


102it [00:09, 10.35it/s]

| epoch  13 |   100/  372 batches | accuracy    0.553


152it [00:14, 10.20it/s]

| epoch  13 |   150/  372 batches | accuracy    0.549


202it [00:19, 10.48it/s]

| epoch  13 |   200/  372 batches | accuracy    0.541


252it [00:24, 10.20it/s]

| epoch  13 |   250/  372 batches | accuracy    0.557


302it [00:29, 10.31it/s]

| epoch  13 |   300/  372 batches | accuracy    0.537


352it [00:34, 10.15it/s]

| epoch  13 |   350/  372 batches | accuracy    0.561


372it [00:36, 10.26it/s]


-----------------------------------------------------------
| end of epoch  13 | time: 37.09s | valid accuracy    0.517
-----------------------------------------------------------


53it [00:05, 10.89it/s]

| epoch  14 |    50/  372 batches | accuracy    0.528


103it [00:09, 10.44it/s]

| epoch  14 |   100/  372 batches | accuracy    0.541


153it [00:14, 10.40it/s]

| epoch  14 |   150/  372 batches | accuracy    0.548


203it [00:19, 10.98it/s]

| epoch  14 |   200/  372 batches | accuracy    0.543


253it [00:24, 10.25it/s]

| epoch  14 |   250/  372 batches | accuracy    0.552


302it [00:28, 10.77it/s]

| epoch  14 |   300/  372 batches | accuracy    0.562


352it [00:33, 10.95it/s]

| epoch  14 |   350/  372 batches | accuracy    0.539


372it [00:35, 10.57it/s]


-----------------------------------------------------------
| end of epoch  14 | time: 35.95s | valid accuracy    0.517
-----------------------------------------------------------


53it [00:05, 10.22it/s]

| epoch  15 |    50/  372 batches | accuracy    0.541


103it [00:09, 10.73it/s]

| epoch  15 |   100/  372 batches | accuracy    0.547


153it [00:14, 10.74it/s]

| epoch  15 |   150/  372 batches | accuracy    0.557


202it [00:19, 10.48it/s]

| epoch  15 |   200/  372 batches | accuracy    0.536


252it [00:24, 10.86it/s]

| epoch  15 |   250/  372 batches | accuracy    0.532


302it [00:28, 10.78it/s]

| epoch  15 |   300/  372 batches | accuracy    0.541


352it [00:33, 10.79it/s]

| epoch  15 |   350/  372 batches | accuracy    0.564


372it [00:35, 10.43it/s]


-----------------------------------------------------------
| end of epoch  15 | time: 36.41s | valid accuracy    0.517
-----------------------------------------------------------


#### Test

In [9]:
print('Checking the results of test dataset.')
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
    for idx, (text, label) in enumerate(test_dataloader):
        predicted_label = model(text, )
        loss = criterion(predicted_label, label)
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
accu_test =  total_acc/total_count
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.508


In [10]:
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        text = pad_sequence([text], batch_first=True, padding_value=idx_pad)
        output = model(text)
        return output.argmax(1).item()

ex_text_str = "It was very bad movie"

model = model.to("cpu")
label_dict = {0:'neg', 1:'pos'}
print(f"This is a {label_dict.get(predict(ex_text_str, text_pipeline))} comment")

This is a neg comment
