In [1]:
from collections import defaultdict
import math
import time
import random
import torch
import torch.nn as nn
from torch.autograd import Variable

import torch.nn.functional as F

import numpy as np

In [2]:
# Functions to read in the corpus
# NOTE: We are using data from the Penn Treebank, which is already converted
#       into an easy-to-use format with "<unk>" symbols. If we were using other
#       data we would have to do pre-processing and consider how to choose
#       unknown words, etc.
# Growing vocabulary: looking up an unseen word assigns it the next free id.
w2i = defaultdict(lambda: len(w2i))
S = w2i["<s>"]      # sentence-boundary symbol (first id handed out)
UNK = w2i["<unk>"]  # unknown-word symbol (second id handed out)


def read_dataset(filename):
  """Lazily yield every line of `filename` as a list of word ids.

  Each line is stripped and split on single spaces; every token is looked up
  in the module-level `w2i` that is current at iteration time, so a lookup may
  add new entries to that mapping.
  """
  with open(filename, "r") as corpus:
    for raw_line in corpus:
      tokens = raw_line.strip().split(" ")
      yield [w2i[token] for token in tokens]

# Read in the data
train = list(read_dataset("../data/ptb/train.txt"))

# Freeze the vocabulary: from here on an unseen word maps to UNK instead of
# receiving a fresh id.
w2i = defaultdict(lambda: UNK, w2i)

# Snapshot the reverse mapping and vocabulary size BEFORE reading the dev set.
# A defaultdict stores every missing key it is asked for, so out-of-vocabulary
# dev words would otherwise be counted in n_words and could overwrite the
# "<unk>" entry of i2w with an arbitrary dev word.
i2w = {v: k for k, v in w2i.items()}
n_words = len(w2i)

dev = list(read_dataset("../data/ptb/valid.txt"))

In [3]:
n_words

10000

In [30]:
class CBOW(nn.Module):
    """Continuous bag-of-words model: summed context embeddings -> MLP -> vocabulary logits.

    Args:
        n_words: Vocabulary size; sets the number of embedding rows and the
            width of the output layer.
        embed_size: Dimension of each word embedding.
        hidden_size: Width of the hidden layer.
        dropout_rate: Dropout probability applied after the hidden ReLU.
        window_size: Number of context words on each side of the target word.
            Stored for reference only; no layer shape depends on it because
            the context embeddings are summed before the MLP.
    """

    def __init__(self, n_words, embed_size, hidden_size, dropout_rate, window_size):
        super(CBOW, self).__init__()
        self.window_size = window_size
        self.embedding = nn.Embedding(n_words, embed_size)

        self.cbow = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, n_words),
        )

    def _convert_word_idx_to_variable(self, word_idxs):
        """Convert (possibly nested) integer word ids into an int64 tensor.

        The method name is kept for backward compatibility; the deprecated
        `Variable` wrapper is no longer applied.
        """
        return torch.tensor(word_idxs, dtype=torch.long)

    def forward(self, word_idxs):
        """Return unnormalised vocabulary scores for a batch of context windows.

        Args:
            word_idxs: Context word ids, one row per example, given either as
                a tensor (used as is) or as nested lists of ints (converted).

        Returns:
            The output of the MLP applied to the summed context embeddings,
            one row of `n_words` scores per example.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # torch.Tensor subclasses without trying to re-wrap them.
        if not isinstance(word_idxs, torch.Tensor):
            word_idxs = self._convert_word_idx_to_variable(word_idxs)

        embed = self.embedding(word_idxs)  # [N, context_size (2 * window_size), embed_size]
        embed = embed.sum(1)               # [N, embed_size]
        return self.cbow(embed)

In [31]:
# Model hyperparameters.
embed_size = 128    # word-embedding dimension
hidden_size = 128   # hidden-layer width
dropout_rate = 0.3  # dropout probability inside the model
window_size = 2     # context words taken on each side of the target word

# Arguments follow the CBOW constructor order:
# (n_words, embed_size, hidden_size, dropout_rate, window_size).
model = CBOW(n_words, embed_size, hidden_size, dropout_rate, window_size)

# Adam over every parameter of the model.
lr = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [32]:
def convert_word_idx_to_variable(word_idxs):
    """Wrap (possibly nested) integer word ids in a LongTensor-backed Variable."""
    return Variable(torch.LongTensor(word_idxs))

In [33]:
def cal_logits(sents):
    """Run the module-level `model` on a batch of word ids and return its output.

    `sents` is passed through unchanged when it is exactly a `torch.Tensor`;
    anything else is first converted with `convert_word_idx_to_variable`.
    """
    batch = sents if type(sents) == torch.Tensor else convert_word_idx_to_variable(sents)
    return model(batch)

In [71]:
def cal_loss(sents):
    """Compute the summed cross-entropy CBOW loss for a batch of sentences.

    Every word of every sentence becomes one training example: the target is
    the word itself and the context is the `window_size` words before and
    after it, with positions that fall outside the sentence filled with `S`.

    Args:
        sents: Iterable of sentences, each an indexable sequence of word ids.

    Returns:
        The result of `F.cross_entropy(..., reduction='sum')` over all
        (context, target) pairs in the batch.

    Reads the module-level `window_size` and `S`, and scores contexts through
    `cal_logits`.
    """
    # Relative positions of the context words: the window around the target,
    # excluding the target itself (offset 0).
    offsets = [w for w in range(-window_size, window_size + 1) if w != 0]

    context_history = []
    target_word_history = []
    for sent in sents:
        sent_len = len(sent)
        for i, target_word in enumerate(sent):
            # Word `w` positions before (w < 0) or after (w > 0) the target;
            # S stands in for positions outside the sentence.
            context_history.append(
                [sent[i + w] if 0 <= i + w < sent_len else S for w in offsets]
            )
            target_word_history.append(target_word)

    # Build the tensors once, after all examples are collected, instead of
    # re-converting the whole history for every target word. Doing it here
    # also keeps both names bound when the batch contains no words: the model
    # then receives zero-row tensors instead of the function failing with an
    # UnboundLocalError.
    context_var = convert_word_idx_to_variable(context_history)
    context_var = context_var.view(len(target_word_history), len(offsets))
    target_var = convert_word_idx_to_variable(target_word_history)

    logits = cal_logits(context_var)
    cost = F.cross_entropy(logits, target_var, reduction='sum')

    return cost

# Training

In [None]:
epochs = 100
batch_size = 100  # sentences per mini-batch (mini-batch gradient descent)

# Switch the model to training mode.
model.train()

for epoch in range(epochs):
    # Reshuffles the module-level `train` list in place at the start of every epoch.
    random.shuffle(train)
    train_loss = 0.0

    start_time = time.time()

    # Only full batches are visited; a trailing partial batch is skipped.
    for start_idx in range(0, len(train) - batch_size + 1, batch_size):
        end_idx = start_idx + batch_size
        sents = train[start_idx:end_idx]
        my_loss = cal_loss(sents)
        # Accumulate a plain Python float: adding the loss tensor itself would
        # keep every batch's autograd graph reachable until the epoch ends.
        train_loss += my_loss.item()

        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()

    if (epoch+1)%10 == 0 :
        # `start_time` is reset every epoch, so this is the duration of the
        # epoch that just finished. (The original expression referenced an
        # undefined name `start` and subtracted in the wrong direction.)
        print('%s epoch 학습 완료. 경과 시간 : %s'%(epoch+1, time.time() - start_time))
        print('Train Loss : %s'%train_loss)