In [63]:
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm_notebook

from data_loader import get_dataset_loader
from model import Encoder, AttnDecoder

In [70]:
class Config:
    """Settings for data loading, the encoder/decoder model and training.

    Fixed settings are class attributes. Values that depend on the loaded
    dataset (vocabulary sizes and special-token indices) are attached to the
    instance by ``add_dataset_info``.
    """

    # ---- Data ----
    # Alternative input file:
    # csv_file = 'data/complete_df.csv'
    csv_file = 'data/toy_df.csv'
    vocab_file = 'crawling/Reviews_csv/vocab.txt'
    tag_vocab = 'crawling/tags_txt/tag_vocab.txt'
    # Raw rating label -> rating class index (several labels share one class).
    rating_dict = {
        '불만': 0,
        '추천안함': 0,
        '보통': 1,
        '추천': 2,
        '만족': 2,
        '적극추천': 3,
    }
    category = 'subcat'  # 'subcat' or 'category'

    # ---- Encoder ----
    # pretrained = False
    attribute_size = 64

    # ---- Decoder ----
    hidden_size = 512
    num_layers = 2
    num_attr = 3  # for attention!

    # ---- Training ----
    batch_size = 50
    dropout = 0.2
    num_steps = 100
    print_every = 1

    def add_dataset_info(self, dataset):
        """Store the settings that are only known after the data is loaded.

        Args:
            dataset: object exposing ``rating2idx``, ``category2idx``,
                ``tag2idx`` and ``word2idx`` mappings; ``word2idx`` must
                contain the keys 'PAD', 'SOS' and 'EOS'.
        """
        self.rating_size = len(dataset.rating2idx)
        self.category_size = len(dataset.category2idx)
        self.tag_size = len(dataset.tag2idx)

        vocab = dataset.word2idx
        self.output_size = len(vocab)
        # Special-token indices; the original author's notes give 0 / 1 / 2.
        self.padding_idx = vocab['PAD']
        self.SOS_token = vocab['SOS']
        self.EOS_token = vocab['EOS']

In [79]:
def train(encoder, decoder, dataloader, loss_fn, optimizer, config, verbose=False):
    """Run ``config.num_steps`` optimisation steps on the encoder/decoder pair.

    Every step draws one batch, encodes the attributes, decodes greedily (the
    decoder's own top-1 prediction is fed back as its next input; the targets
    are never fed to the decoder) and back-propagates ``loss_fn`` divided by
    the number of non-padding target tokens.

    Batches are drawn from a single iterator over ``dataloader`` that is
    restarted when exhausted, so consecutive steps see consecutive batches.
    The batch size is read from each batch rather than from
    ``config.batch_size``, so a shorter final batch is accepted.

    Args:
        encoder: module called as ``encoder(rating, category, tag)`` and
            returning ``(attrs, encoder_output)``.
        decoder: module called as ``decoder(decoder_input, hidden, attrs)``
            and returning ``(output, hidden, attention_weights)``.
        dataloader: iterable yielding ``(rating, category, tag, target)``
            tensor tuples.
        loss_fn: criterion applied to the flattened decoder outputs and
            targets. Presumably a summed (not averaged) loss, since the
            result is divided by the token count here -- confirm at call site.
        optimizer: optimiser holding the encoder and decoder parameters.
        config: object providing ``num_layers``, ``hidden_size``,
            ``num_steps``, ``print_every``, ``SOS_token`` and ``padding_idx``.
        verbose: when true, print the loss every ``config.print_every`` steps.

    Returns:
        None. The models are updated in place through ``optimizer``.
    """
    encoder.train()
    decoder.train()

    def splitHidden(encoder_output):
        """Turn the encoder output into an initial ``(h_0, c_0)`` pair."""
        # The middle dimension is inferred so that any batch size works.
        # NOTE(review): a plain view assumes encoder_output is already laid
        # out layer-major; confirm against the Encoder implementation.
        h_0 = encoder_output.view(config.num_layers, -1, config.hidden_size)
        c_0 = torch.zeros_like(h_0)
        return (h_0, c_0)

    # One iterator for the whole run; re-creating it on every step would
    # only ever consume the first batch of each fresh iterator.
    data_iter = iter(dataloader)

    for t in tqdm_notebook(range(config.num_steps)):
        optimizer.zero_grad()

        try:
            batch = next(data_iter)
        except StopIteration:
            # Dataset exhausted: start another pass.
            data_iter = iter(dataloader)
            batch = next(data_iter)
        rating_tensor, category_tensor, tag_tensor, target_tensor = batch

        # assumes target_tensor is (batch, target_length) -- TODO confirm
        batch_size = target_tensor.size(0)
        target_length = target_tensor.size(-1)

        attrs, encoder_output = encoder(rating_tensor, category_tensor, tag_tensor)
        decoder_hidden = splitHidden(encoder_output)
        # Start-of-sequence token for every sequence, on the targets' device.
        decoder_input = torch.full((batch_size, 1), config.SOS_token,
                                   dtype=torch.long, device=target_tensor.device)

        decoder_outputs = []
        for idx in range(target_length):
            decoder_output, decoder_hidden, attention_weights = \
                decoder(decoder_input, decoder_hidden, attrs)
            # Greedy decoding: the most likely token becomes the next input.
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.detach().view(batch_size, 1)
            decoder_outputs.append(decoder_output)

        decoder_outputs = torch.cat(decoder_outputs, 1).view(batch_size * target_length, -1)
        target_tensor = target_tensor.view(-1)
        loss = loss_fn(decoder_outputs, target_tensor)

        # Normalise by the number of real (non-padding) tokens; the floor of
        # one keeps an all-padding batch from dividing by zero.
        num_actual_token = torch.sum(target_tensor != config.padding_idx).item()
        loss = loss / max(num_actual_token, 1)

        if verbose and t % config.print_every == 0:
            print("loss at %d step: %f" % (t, loss.item()))

        loss.backward()
        optimizer.step()

### Get data & set config

In [77]:
# Build the configuration, load the data, then record the dataset-derived
# sizes and special-token indices on the configuration.
config = Config()
dataset, dataloader = get_dataset_loader(
    config.csv_file,
    config.vocab_file,
    config.tag_vocab,
    config.rating_dict,
    config.category,
    config.batch_size,
)
config.add_dataset_info(dataset)

### Instantiate model and start training

In [78]:
# Build both halves of the model from the shared configuration.
encoder = Encoder(config)
decoder = AttnDecoder(config)

# A single optimiser updates the encoder and decoder parameters together.
params = list(encoder.parameters()) + list(decoder.parameters())
# Summed loss over non-padding targets; train() divides it by the number of
# non-padding tokens. reduction='sum' is the replacement for the deprecated
# size_average=False.
loss_fn = nn.NLLLoss(reduction='sum', ignore_index=config.padding_idx)
# Other optimisers, kept for reference:
#optimizer = optim.Adam(params, lr=0.001)
#optimizer = optim.SGD(params, lr=1)
optimizer = optim.RMSprop(params, lr=0.002, alpha=0.95)

train(encoder, decoder, dataloader, loss_fn, optimizer, config, verbose=True)

HBox(children=(IntProgress(value=0), HTML(value='')))

loss at 0 step: 10.102799
loss at 1 step: 9.550422
loss at 2 step: 8.512295
loss at 3 step: 9.510113
loss at 4 step: 7.948709
loss at 5 step: 7.725322
loss at 6 step: 6.978649
loss at 7 step: 6.665123
loss at 8 step: 6.229571
loss at 9 step: 5.961671
loss at 10 step: 6.047873
loss at 11 step: 5.925772
loss at 12 step: 6.324553
loss at 13 step: 5.960098
loss at 14 step: 5.971672
loss at 15 step: 5.657265
loss at 16 step: 5.632519
loss at 17 step: 5.652637
loss at 18 step: 5.376629
loss at 19 step: 5.681334
loss at 20 step: 5.432980
loss at 21 step: 5.555317
loss at 22 step: 5.258498
loss at 23 step: 5.551787
loss at 24 step: 5.542890
loss at 25 step: 5.481921
loss at 26 step: 5.254322
loss at 27 step: 5.465711
loss at 28 step: 5.257216
loss at 29 step: 5.189991
loss at 30 step: 5.476954
loss at 31 step: 5.233053
loss at 32 step: 5.107705
loss at 33 step: 5.202115
loss at 34 step: 5.221690
loss at 35 step: 4.952003
loss at 36 step: 4.953732
loss at 37 step: 5.004900
loss at 38 step: 5.17