<a href="https://colab.research.google.com/github/heugyu/notebook/blob/master/GRU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data
from torchtext import datasets

In [0]:
BATCH_SIZE = 64
LR = 0.001
EPOCHS = 40
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [5]:
TEXT = data.Field(sequential=True, batch_first=True, lower=True)
LABEL = data.Field(sequential=False, batch_first=True)

trainset, testset = datasets.IMDB.splits(TEXT, LABEL)

TEXT.build_vocab(trainset, min_freq=5)
LABEL.build_vocab(trainset)

trainset, valset = trainset.split(split_ratio=0.8)
train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (trainset, valset, testset),
    batch_size = BATCH_SIZE,
    shuffle=True,
    repeat=False
)

vocab_size = len(TEXT.vocab)
n_classes = 2
print(f'train : {len(train_iter)}  test : {len(test_iter)}  vocab : {vocab_size}  classes : {n_classes}')

train : 313  test : 391  vocab : 46159  classes : 2


In [8]:
class BasicGRU(nn.Module):
    def __init__(self, n_layers, hidden_dim, n_vocab, embed_dim, n_classes, dropout_p=0.2):
        super(BasicGRU, self).__init__()
        print('building basic gru model ...')
        self.n_layers = n_layers
        self.emded = nn.Embedding(n_vocab, embed_dim)
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(
            embed_dim,
            self.hidden_dim,
            num_layers=self.n_layers,
            batch_first=True
        )
        self.out = nn.Linear(self.hidden_dim, n_classes)

    def forward(self, x):
        x = self.emded(x)
        h_0 = self._init_state(batch_size=x.size(0))
        x, _ = self.gru(x, h_0)
        h_t = x[:, -1, :]
        self.dropout(h_t)
        logit = self.out(h_t)
        return logit

    def _init_state(self, batch_size=1):
        weight = next(self.parameters()).data
        return weight.new(self.n_layers, batch_size, self.hidden_dim).zero_()

def train(model, optimizer, train_iter):
    model.train()
    for b, batch in enumerate(train_iter):
        x, y = batch.text.to(DEVICE), batch.label.to(DEVICE)
        y.data.sub_(1)
        optimizer.zero_grad()

        logit = model(x)
        loss = F.cross_entropy(logit, y)
        loss.backward()
        optimizer.step()

def evaluate(model, val_iter):
    model.eval()
    corrects, total_loss = 0, 0
    for batch in val_iter:
        x, y = batch.text.to(DEVICE), batch.label.to(DEVICE)
        y.data.sub_(1)
        logit = model(x)
        loss = F.cross_entropy(logit, y, reduction='sum')
        total_loss += loss.item()
        corrects += (logit.max(1)[1].view(y.size()).data == y.data).sum()
    size = len(val_iter.dataset)
    avg_loss = total_loss / size
    avg_accuracy = 100.0 * corrects / size
    return avg_loss, avg_accuracy

model = BasicGRU(1, 256, vocab_size, 128, n_classes, 0.5).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

best_val_loss = None
for e in range(1, EPOCHS+1):
    train(model, optimizer, train_iter)
    val_loss, val_accuray = evaluate(model, val_iter)

    print(f'에폭 : {e}  검증오차 : {val_loss}   검증정확도 : {val_accuray}')

test_loss, test_acc = evaluate(model, test_iter)
print(f'테스트 오차 : {test_loss}   테스트 정학도 : {test_acc}')

building basic gru model ...
에폭 : 1  검증오차 : 0.6935031635284424   검증정확도 : 50.63999938964844
에폭 : 2  검증오차 : 0.6955553000450134   검증정확도 : 51.68000030517578
에폭 : 3  검증오차 : 0.7093723979949951   검증정확도 : 49.65999984741211
에폭 : 4  검증오차 : 0.5235813743114471   검증정확도 : 76.15999603271484
에폭 : 5  검증오차 : 0.39257103457450865   검증정확도 : 83.31999969482422
에폭 : 6  검증오차 : 0.31888863639831544   검증정확도 : 86.68000030517578
에폭 : 7  검증오차 : 0.32119335041046143   검증정확도 : 87.05999755859375
에폭 : 8  검증오차 : 0.3444256635665894   검증정확도 : 86.87999725341797
에폭 : 9  검증오차 : 0.40229130001068114   검증정확도 : 86.83999633789062
에폭 : 10  검증오차 : 0.41135893411636354   검증정확도 : 86.07999420166016
에폭 : 11  검증오차 : 0.4544870388031006   검증정확도 : 86.83999633789062
에폭 : 12  검증오차 : 0.39552955169677734   검증정확도 : 87.18000030517578
에폭 : 13  검증오차 : 0.42498252353668214   검증정확도 : 86.75999450683594
에폭 : 14  검증오차 : 0.46009779634475706   검증정확도 : 86.73999786376953
에폭 : 15  검증오차 : 0.499126589012146   검증정확도 : 86.5999984741211
에폭 : 16  검증오차 : 0.46410519771