In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import nltk
from torchtext.data import Field,BucketIterator, TabularDataset
from model import Encoder, Decoder

In [4]:
# A single Field shared by the `input` and `target` columns, so both sides of
# each pair are tokenized the same way and share one vocabulary.
TEXT = Field(
    batch_first=True,        # batch tensors are laid out (batch, seq_len)
    include_lengths=True,    # batch.<column> yields a (tokens, lengths) pair
    lower=True,
    tokenize=nltk.word_tokenize,
    use_vocab=True,
)

In [8]:
BATCH_SIZE = 32  # number of examples per BucketIterator batch

In [5]:
# Load (input, target) pairs from a tab-separated file. Both columns use the
# shared TEXT field and therefore one vocabulary.
train_data = TabularDataset(
                                   path="train_data.txt", # path to the TSV data file
                                   format='tsv', # columns separated by '\t'
                                   fields=[('input',TEXT),('target',TEXT)])

TEXT.build_vocab(train_data,min_freq=2)

# The training loop starts decoding from TEXT.vocab.stoi['<s>'], but TEXT is
# created without an init_token, so build_vocab never adds '<s>' and that
# lookup would silently return the <unk> index. Register the token explicitly
# here -- before VOCAB = len(TEXT.vocab) is read -- so the decoder gets a
# dedicated start-of-decoding embedding. Inputs/targets are not changed.
DECODER_START_TOKEN = '<s>'
if DECODER_START_TOKEN not in TEXT.vocab.itos:
    TEXT.vocab.itos.append(DECODER_START_TOKEN)
    TEXT.vocab.stoi[DECODER_START_TOKEN] = len(TEXT.vocab.itos) - 1

train_loader = BucketIterator(train_data,batch_size=BATCH_SIZE, device=-1, # device -1: CPU, device 0: GPU 0
    sort_key=lambda x: len(x.input),sort_within_batch=True,repeat=False,shuffle=True)

In [7]:
num_train_examples = len(train_data)  # number of (input, target) pairs loaded
num_train_examples

6973

In [10]:
# Model / optimisation hyperparameters.
HIDDEN = 200
EMBED = 100
VOCAB = len(TEXT.vocab)
LR = 0.001

# NOTE(review): the decoder is sized HIDDEN*2, presumably because the
# bidirectional encoder concatenates both directions -- confirm in model.py.
encoder = Encoder(VOCAB,EMBED,HIDDEN,bidirec=True)
decoder = Decoder(VOCAB,EMBED,HIDDEN*2)
# Tie embeddings: encoder and decoder now share one embedding module.
decoder.embedding = encoder.embedding

# Padding positions are excluded from the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=TEXT.vocab.stoi['<pad>'])

# Because the embedding is shared, its weight appears in BOTH
# encoder.parameters() and decoder.parameters(). Registering it in both
# optimizers would make each Adam step it once per iteration (double update,
# two inconsistent moment estimates). enc_optim owns the shared weight;
# dec_optim only gets parameters that belong to the decoder alone.
enc_optim = optim.Adam(encoder.parameters(),lr=LR)
_encoder_param_ids = {id(p) for p in encoder.parameters()}
dec_optim = optim.Adam(
    [p for p in decoder.parameters() if id(p) not in _encoder_param_ids],
    lr=LR)

### Sanity Check 

In [16]:
# Sanity check: push the first six batches through encoder and decoder, print
# each loss and take one optimiser step per batch. zip() stops on range(6)
# first, so no seventh batch is drawn from the loader.
for i, batch in zip(range(6), train_loader):
    inputs, lengths = batch.input      # include_lengths=True -> (tokens, lengths)
    targets, _ = batch.target
    n_examples = targets.size(0)       # batch_first=True -> dim 0 is the batch

    # One start-of-decoding id per example, shaped as a single-step column.
    start_ids = torch.LongTensor([TEXT.vocab.stoi['<s>']] * n_examples)
    decoding_start = Variable(start_ids).unsqueeze(1)

    encoder.zero_grad()
    decoder.zero_grad()

    output, hidden = encoder(inputs, lengths.tolist())
    score = decoder(decoding_start, hidden, targets.size(1), output, lengths)

    loss = loss_function(score, targets.view(-1))
    print(loss)
    loss.backward()

    enc_optim.step()
    dec_optim.step()

Variable containing:
 10.5353
[torch.FloatTensor of size 1]

Variable containing:
 10.5345
[torch.FloatTensor of size 1]

Variable containing:
 10.5326
[torch.FloatTensor of size 1]

Variable containing:
 10.5294
[torch.FloatTensor of size 1]

Variable containing:
 10.5257
[torch.FloatTensor of size 1]

Variable containing:
 10.5266
[torch.FloatTensor of size 1]



* Pointer supervision (build the training set with pointer targets)
* Self-critical training with REINFORCE (mixed objective function)
* Beam search decoding
* ROUGE evaluation