<a href="https://colab.research.google.com/github/m3yrin/topic-aware-tag-prediction/blob/master/qiita_tag_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tag generation for Japanese article
Re-implementation of "Topic-Aware Neural Keyphrase Generation for Social Media Language"

Author : @m3yrin

## Reference
* Paper
Topic-Aware Neural Keyphrase Generation for Social Media Language  
Yue Wang, Jing Li, Hou Pong Chan, Irwin King, Michael R. Lyu, Shuming Shi  
https://arxiv.org/abs/1906.03889  
ACL 2019 Long paper

* https://github.com/yuewang-cuhk/TAKG
* https://github.com/m3yrin/NTM

* Qiita data gathering
    * https://qiita.com/pocket_kyoto/items/64a5ae16f02023df883e
    "Qiitaの記事データは、機械学習のためのデータセットに向いている"

## Dataset
Qiita articles. These are mainly technical articles written in Japanese.  
https://qiita.com/

You can gather articles through Qiita API.  
https://qiita.com/api/v2/docs?locale=en

## Memo
### Some methods are not implemented
Beam search and copy mechanism are not implemented.



## Preparation

In [0]:
# Clone the project repository; a later cell changes into this directory.
!git clone https://github.com/m3yrin/topic-aware-tag-prediction.git

In [0]:
# Install Janome, which is imported in the "Import packages" cell below.
!pip install janome

In [0]:
import os

# Work from inside the cloned repository. The guard keeps this cell idempotent:
# re-running it after the first chdir would otherwise raise FileNotFoundError,
# because the relative path no longer exists from inside the repository.
REPO_DIR = 'topic-aware-tag-prediction'
if os.path.basename(os.getcwd()) != REPO_DIR:
    os.chdir(REPO_DIR)

### Data Gathering
The dataset is not included in this repository.  
About Qiita API, please see https://qiita.com/api/v2/docs?locale=en

A sample script, `qiita_api.py`, is included in this repo. If you use it, please review its code before execution. An access token is required; you can issue one at https://qiita.com/settings/applications?locale=en  
Usage :  
```bash
python qiita-api.py -auth_token <your_qiita_api_access_token> -data_dir ./ -start_date 2019-01-01 -end_date 2019-02-01
```

In [0]:
# Run the data-gathering script for the given date range, writing into ./ .
# Replace <your_qiita_api_auth_token> with your own token before running, and
# do not save or share the notebook with the real token left in this cell.
!python qiita-api.py -auth_token <your_qiita_api_auth_token> -data_dir ./ -start_date 2019-01-01 -end_date 2019-02-01

### Import packages

In [0]:
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import pandas as pd

from tqdm import tqdm_notebook as tqdm

import gensim
from gensim import corpora, models

import janome
from janome import analyzer
from janome.charfilter import *
from janome.tokenfilter import *
from janome.tokenizer import Tokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.utils.tensorboard import SummaryWriter


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device :',device)

# set random seeds
# Python's `random` drives the teacher-forcing coin flip, torch drives weight
# initialisation and sampling. NumPy's global RNG is seeded as well because the
# batch shuffling in DataLoader draws from it; without this the batch order
# differs between runs.
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

# Seed handed to sklearn's train_test_split so the data splits are stable.
random_state = 42

### Loading Data

In [0]:
# Directory (with trailing slash) holding the pickled dataset and, later on,
# the model checkpoints written by this notebook.
HOME_DIR = "./"

# NOTE(review): unpickling executes arbitrary code -- only load files that you
# generated yourself with the data-gathering script.
# presumably one entry per article in each file, aligned by position -- verify
# against the script that wrote them.
bow = pd.read_pickle(f"{HOME_DIR}bow.pkl")
text = pd.read_pickle(f"{HOME_DIR}text.pkl")
target = pd.read_pickle(f"{HOME_DIR}target.pkl")

In [0]:
# Keep only the first 400 items of every article so that training stays fast;
# shorter articles are left untouched.
text_short = [article[:400] for article in text]

In [0]:
# Step 1: hold out 20% of all articles as the test split.
(train_X, test_X,
 train_B, test_B,
 train_Y, test_Y) = train_test_split(
    text_short, bow, target, test_size=0.2, random_state=random_state)

# Step 2: hold out 20% of the remainder as the validation split
# (roughly 64% / 16% / 20% of the data overall).
(train_X, valid_X,
 train_B, valid_B,
 train_Y, valid_Y) = train_test_split(
    train_X, train_B, train_Y, test_size=0.2, random_state=random_state)

print("# of train, valid, test :", len(train_X),len(valid_X), len(test_X))

### Making vocab for sequence data

In [0]:
def sentence_to_ids(vocab, sentence):
    """Encode a token sequence as a list of ids terminated by EOS.

    Args:
        vocab: object exposing a ``word2id`` dict.
        sentence: iterable of tokens.

    Returns:
        A new list with one id per token (UNK for tokens missing from
        ``vocab.word2id``) followed by the EOS id.
    """
    lookup = vocab.word2id
    encoded = []
    for token in sentence:
        encoded.append(lookup.get(token, UNK))
    encoded.append(EOS)
    return encoded

class Vocab(object):
    """Two-way mapping between words and integer ids.

    Attributes:
        word2id: dict mapping word -> id.
        id2word: dict mapping id -> word (inverse of ``word2id``).
    """

    def __init__(self, word2id=None):
        """Create a vocabulary, optionally pre-seeded with ``word2id``.

        Args:
            word2id: optional initial word -> id mapping (e.g. the special
                tokens). It is copied, so one dict can seed several Vocab
                instances without them sharing state.
        """
        self.word2id = dict(word2id) if word2id is not None else {}
        self.id2word = {v: k for k, v in self.word2id.items()}

    def build_vocab(self, sentences, min_count=1):
        """Add every word occurring at least ``min_count`` times.

        Words are added in order of decreasing frequency and receive the next
        free id (``len(self.word2id)``), which assumes the ids already present
        are contiguous from 0. Words that are already registered keep their id.

        Args:
            sentences: iterable of token sequences.
            min_count: minimum number of occurrences for a word to be added.
        """
        word_counter = {}
        for sentence in sentences:
            for word in sentence:
                word_counter[word] = word_counter.get(word, 0) + 1

        for word, count in sorted(word_counter.items(), key=lambda x: -x[1]):
            if count < min_count:
                # Sorted by descending count, so every remaining word is rarer.
                break
            if word in self.word2id:
                # Already registered (e.g. a special token occurring in the
                # data). Skipping keeps id2word the exact inverse of word2id
                # instead of writing a second id for the same word.
                continue
            _id = len(self.word2id)
            self.word2id[word] = _id
            self.id2word[_id] = word

In [0]:
# special tokens
# Surface forms of the four reserved vocabulary entries ...
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = '<PAD>', '<S>', '</S>', '<UNK>'
# ... and the ids reserved for them.
PAD, BOS, EOS, UNK = 0, 1, 2, 3

# Initial word -> id mapping shared by the input and output vocabularies.
word2id = dict(zip(
    (PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN),
    (PAD, BOS, EOS, UNK),
))

# Minimum number of occurrences for an input word to get its own id;
# rarer input words are encoded as UNK by sentence_to_ids.
MIN_COUNT = 3

# Input vocabulary: built from the training articles only.
vocab_X = Vocab(word2id=word2id)
vocab_X.build_vocab(train_X, min_count=MIN_COUNT)

# Output vocabulary: every target token seen in training is kept.
vocab_Y = Vocab(word2id=word2id)
vocab_Y.build_vocab(train_Y, min_count=1)

vocab_size_X, vocab_size_Y = len(vocab_X.id2word), len(vocab_Y.id2word)
print('# of input vocab ：', vocab_size_X)
print('# of output vocab：', vocab_size_Y)

In [0]:
# Tokenize
# Replace every split's token sequences by id sequences (EOS appended).
# NOTE: this rebinds the *_X / *_Y names, so the cell is not idempotent --
# running it a second time would look up the ids themselves in the vocabulary.
train_X, valid_X, test_X = [
    [sentence_to_ids(vocab_X, tokens) for tokens in split]
    for split in (train_X, valid_X, test_X)
]
train_Y, valid_Y, test_Y = [
    [sentence_to_ids(vocab_Y, tags) for tags in split]
    for split in (train_Y, valid_Y, test_Y)
]

### Making vocab for BoW data

In [0]:
def build_bow_vocab(data, no_below=5, no_above=0.2):
    """Build a filtered gensim dictionary for the bag-of-words input.

    Args:
        data: documents passed to ``gensim.corpora.Dictionary``.
        no_below: forwarded to ``Dictionary.filter_extremes``.
        no_above: forwarded to ``Dictionary.filter_extremes``.

    Returns:
        The filtered, compactified dictionary with ``id2token`` populated.
        Its size is also printed.
    """
    dictionary = gensim.corpora.Dictionary(data)
    dictionary.filter_extremes(no_below=no_below, no_above=no_above)

    # Re-id: close the gaps left by filtering, then rebuild the inverse map.
    dictionary.compactify()
    dictionary.id2token = {
        token_id: token for token, token_id in dictionary.token2id.items()
    }

    print("BOW dict length : %d" % len(dictionary))

    return dictionary

In [0]:
# Build the BoW dictionary from the training split only; its length is the
# input dimension of the neural topic model.
bow_vocab = build_bow_vocab(train_B)
bow_vocab_size=len(bow_vocab)

### Dataloader definition

In [0]:
class DataLoader(object):
    """Mini-batch iterator over (sequence, raw BoW, BoW vector, target) examples.

    Iterating yields ``(batch_X, batch_B, batch_BV, batch_Y, lengths_X)``.
    When the data is exhausted the loader resets itself (reshuffling if
    requested) and raises StopIteration, so it can be reused every epoch.

    Attributes:
        data: 1-D object array with one ``(x, b, bv, y)`` tuple per example.
        data_size: number of examples.
        word_count: total number of in-vocabulary BoW tokens over all examples.
    """

    def __init__(self, X, B, Y, bow_vocab, batch_size, shuffle=True):
        """
        Args:
            X: list of id sequences (encoder input).
            B: list of token lists, converted with ``bow_vocab.doc2bow``.
            Y: list of id sequences (decoder target).
            bow_vocab: object providing ``doc2bow`` and ``__len__``.
            batch_size: number of examples per batch.
            shuffle: whether to permute the example order on every reset.

        Raises:
            ValueError: if X, B and Y do not have the same length.
        """
        if not (len(X) == len(B) == len(Y)):
            raise ValueError(
                "X, B and Y must have the same length, got %d, %d and %d"
                % (len(X), len(B), len(Y)))

        self.batch_size = batch_size
        self.bow_vocab = bow_vocab

        self.index = 0
        self.pointer = np.arange(len(X))

        BV = [bow_vocab.doc2bow(s) for s in B]

        # Store the examples in a 1-D object array so they can be fancy-indexed
        # by `pointer`. The array is filled element by element because calling
        # np.array() directly on ragged tuples raises on current NumPy versions.
        records = list(zip(X, B, BV, Y))
        self.data = np.empty(len(records), dtype=object)
        for i, record in enumerate(records):
            self.data[i] = record

        # counting total word number
        self.word_count = sum(count for bow in BV for _, count in bow)
        self.data_size = len(X)

        self.shuffle = shuffle
        self.reset()

    def reset(self):
        """Rewind to the first batch, drawing a new example order if shuffling."""
        if self.shuffle:
            # np.random.permutation draws from NumPy's global RNG and does not
            # rely on a module-level `shuffle` function, a name that the
            # constructor argument shadows.
            self.pointer = np.random.permutation(self.pointer)
        self.index = 0

    # transform bow data into (1 x V) size vector.
    def _pad(self, batch):
        """Turn sparse ``[(token_id, count), ...]`` lists into a dense matrix.

        Returns:
            ``np.ndarray`` of shape ``(len(batch), len(self.bow_vocab))``.
        """
        bow_vocab = len(self.bow_vocab)
        res_src_bow = np.zeros((len(batch), bow_vocab))

        for idx, bow in enumerate(batch):
            for token_id, count in bow:
                res_src_bow[idx, token_id] = count

        return res_src_bow

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch.

        Returns:
            batch_X: LongTensor ``(max_len_X, batch)``, padded with PAD.
            batch_B: tuple with the raw BoW token lists of the batch.
            batch_BV: FloatTensor ``(batch, len(bow_vocab))`` of token counts.
            batch_Y: LongTensor ``(max_len_Y, batch)``, padded with PAD.
            lengths_X: unpadded lengths of the sequences in ``batch_X``.

        Raises:
            StopIteration: once all examples have been served.
        """
        if self.index >= self.data_size:
            self.reset()
            raise StopIteration()

        ids = self.pointer[self.index: self.index + self.batch_size]

        # sort for rnn: pack_padded_sequence expects descending lengths
        batch = sorted(self.data[ids], key=lambda rec: len(rec[0]), reverse=True)
        seqs_X, seqs_B, seqs_BV, seqs_Y = zip(*batch)

        lengths_X = [len(s) for s in seqs_X]
        lengths_Y = [len(s) for s in seqs_Y]
        max_length_X = max(lengths_X)
        max_length_Y = max(lengths_Y)
        padded_X = [pad_seq(s, max_length_X) for s in seqs_X]
        padded_Y = [pad_seq(s, max_length_Y) for s in seqs_Y]

        padded_BV = self._pad(seqs_BV)

        # transposed for rnn: (batch, time) -> (time, batch)
        batch_X = torch.tensor(padded_X, dtype=torch.long, device=device).transpose(0, 1)
        batch_Y = torch.tensor(padded_Y, dtype=torch.long, device=device).transpose(0, 1)

        batch_B = seqs_B
        batch_BV = torch.tensor(padded_BV, dtype=torch.float, device=device)

        self.index += self.batch_size

        return batch_X, batch_B, batch_BV, batch_Y, lengths_X


### Model definition of NTM

In [0]:
# cited : https://github.com/yuewang-cuhk/TAKG/blob/master/pykp/model.py

class NTM(nn.Module):
    """Variational neural topic model over bag-of-words vectors.

    Args:
        input_dim: size of the BoW vocabulary.
        hidden_dim: width of the encoder's hidden layers.
        topic_num: dimensionality of the latent topic vector.
        l1_strength: initial weight of the L1 sparsity penalty. It is stored on
            the module as ``self.l1_strength`` (a 1-element float tensor) so
            that ``update_sparsity_l1`` can rescale it in place.
    """

    def __init__(self, input_dim, hidden_dim, topic_num,  l1_strength=0.001):
        super(NTM, self).__init__()
        self.input_dim = input_dim
        self.topic_num = topic_num
        # encoder: two ReLU layers plus a linear shortcut from the input
        self.fc11 = nn.Linear(self.input_dim, hidden_dim)
        self.fc12 = nn.Linear(hidden_dim, hidden_dim)
        # heads producing mean and log-variance of the latent Gaussian
        self.fc21 = nn.Linear(hidden_dim, topic_num)
        self.fc22 = nn.Linear(hidden_dim, topic_num)
        self.fcs = nn.Linear(self.input_dim, hidden_dim, bias=False)
        # generator applied to the sampled latent vector
        self.fcg1 = nn.Linear(topic_num, topic_num)
        self.fcg2 = nn.Linear(topic_num, topic_num)
        self.fcg3 = nn.Linear(topic_num, topic_num)
        self.fcg4 = nn.Linear(topic_num, topic_num)

        # decoder from topic space back to the BoW vocabulary; bias disabled
        self.fcd1 = nn.Linear(topic_num, self.input_dim, bias=False)

        # Kept as a plain tensor attribute (neither parameter nor buffer) so
        # that it is not optimised and state_dict keys stay unchanged.
        self.l1_strength = torch.tensor([l1_strength], dtype=torch.float)

    def encode(self, x):
        """Map BoW input ``x`` to ``(mu, logvar)`` of the latent Gaussian."""
        e1 = F.relu(self.fc11(x))
        e1 = F.relu(self.fc12(e1))
        e1 = e1.add(self.fcs(x))
        return self.fc21(e1), self.fc22(e1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def generate(self, h):
        """Apply four tanh layers to ``h`` and add ``h`` back as a residual."""
        g1 = torch.tanh(self.fcg1(h))
        g1 = torch.tanh(self.fcg2(g1))
        g1 = torch.tanh(self.fcg3(g1))
        g1 = torch.tanh(self.fcg4(g1))
        g1 = g1.add(h)
        return g1

    def decode(self, z):
        """Return a softmax distribution over the BoW vocabulary for ``z``."""
        d1 = F.softmax(self.fcd1(z), dim=1)
        return d1

    def forward(self, x):
        """Run the full model.

        Returns:
            ``(z, g, recon, mu, logvar)``: the latent sample, the generated
            topic vector, the reconstructed word distribution, and the
            parameters of the latent Gaussian.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        g = self.generate(z)
        return z, g, self.decode(g), mu, logvar

    def print_topic_words(self, vocab_dic, fn, n_top_words=10):
        """Print the ``n_top_words`` highest-weighted words of every topic.

        Args:
            vocab_dic: mapping from BoW id to word.
            fn: unused; kept so existing callers keep working.
            n_top_words: number of words to print per topic.
        """
        # fcd1.weight is (input_dim, topic_num); transpose to one row per topic
        beta_exp = self.fcd1.weight.detach().cpu().numpy().T

        for k, beta_k in enumerate(beta_exp):
            # argsort is ascending, so walk backwards from the end
            topic_words = [vocab_dic[w_id] for w_id in np.argsort(beta_k)[:-n_top_words - 1:-1]]
            print('Topic {}: {}'.format(k, ' '.join(topic_words)))
            

### Model Definition of Seq2Seq

In [0]:
class Encoder(nn.Module):
    """Single-layer GRU encoder over embedded input tokens."""

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: number of entries in the embedding table.
            hidden_size: embedding width and GRU hidden size.
        """
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(
            num_embeddings=input_size, embedding_dim=hidden_size, padding_idx=PAD)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)

    def forward(self, seqs, input_lengths, hidden=None):
        """Encode a padded batch.

        Args:
            seqs: token ids; assumed (seq_len, batch) as built by DataLoader.
            input_lengths: unpadded lengths, passed to pack_padded_sequence.
            hidden: optional initial GRU state.

        Returns:
            ``(output, hidden)``: the re-padded GRU outputs and the final state.
        """
        embedded = self.embedding(seqs)
        packed_input = pack_padded_sequence(embedded, input_lengths)
        packed_output, final_hidden = self.gru(packed_input, hidden)
        unpacked_output, _lengths = pad_packed_sequence(packed_output)
        return unpacked_output, final_hidden

#### Memo
https://arxiv.org/abs/1906.03889  
Topic-Aware Neural Keyphrase Generation for Social Media Language


* decoder RNN
$$
\mathbf{s}_{j}=f_{G R U}\left(\left[\mathbf{u}_{j} ; \theta\right], \mathbf{s}_{j-1}\right)
$$

* Attention
$$
\alpha_{i j}=\frac{\exp \left(f_{\alpha}\left(\mathbf{h}_{i}, \mathbf{s}_{j}, \theta\right)\right)}{\sum_{i^{\prime}=1}^{|\mathbf{x}|} \exp \left(f_{\alpha}\left(\mathbf{h}_{i^{\prime}}, \mathbf{s}_{j}, \theta\right)\right)}, \ \mathbf{c}_{j}=\sum_{i=1}^{|\mathbf{x}|} \alpha_{i j} \mathbf{h}_{i}
$$

* Attention function

$$
f_{\alpha}(h_i, s_j , \theta) = v_{\alpha}^T tanh(W_{\alpha}[h_i; s_j ; \theta] + b_{\alpha}).
$$

* $p_{gen}$
$$
p_{g e n}={softmax}\left(\mathbf{W}_{g e n}\left[\mathbf{s}_{j} ; \mathbf{c}_{j}\right]+\mathbf{b}_{g e n}\right)
$$

* Copy mechanism (Not implemented.)
$$\lambda_{j}={sigmoid}\left(\mathbf{W}_{\lambda}\left[\mathbf{u}_{j} ; \mathbf{s}_{j} ; \mathbf{c}_{j} ; \theta\right]+\mathbf{b}_{\lambda}\right)$$
$$ 
p_{j}=\lambda_{j} \cdot p_{g e n}+\left(1-\lambda_{j}\right) \cdot \sum_{i=1}^{|\mathbf{x}|} \alpha_{i j}
$$


In [0]:
class Decoder(nn.Module):
    """Topic-aware GRU decoder with additive attention over encoder outputs.

    The attention weights ``W_a``, ``b`` and ``v`` are registered as
    ``nn.Parameter`` so that they are returned by ``model.parameters()``
    (and therefore optimised), follow ``model.to(device)`` and are part of the
    state dict. As plain tensors they were never updated and were re-drawn at
    random each time the model was rebuilt before loading a checkpoint.
    NOTE: checkpoints saved before this change lack these three keys and need
    ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, hidden_size, output_size):
        """
        Args:
            hidden_size: embedding width, GRU hidden size and the width
                expected of the topic vector.
            output_size: size of the output vocabulary.
        """
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.embedding = nn.Embedding(output_size, hidden_size, padding_idx=PAD)
        # GRU input is [token embedding ; topic vector]
        self.gru = nn.GRU(hidden_size * 2, hidden_size)
        # output projection over [decoder state ; attention context]
        self.out = nn.Linear(hidden_size * 2, output_size)

        # attention parameters: score = v^T tanh([s ; theta ; h] W_a + b)
        self.W_a = nn.Parameter(torch.rand(hidden_size * 3, hidden_size))
        self.b = nn.Parameter(torch.rand(hidden_size))
        self.v = nn.Parameter(torch.rand(hidden_size, 1))
        # The copy mechanism is not implemented.

    def forward(self, seqs, hidden, encoder_output, latent_topic):
        """Run one decoding step.

        Args:
            seqs: previous output tokens, (1, batch).
            hidden: previous decoder state, (1, batch, hidden_size).
            encoder_output: encoder outputs, (seq_len, batch, hidden_size).
            latent_topic: topic vector from the NTM, (1, batch, hidden_size).

        Returns:
            ``(output, hidden)``: unnormalised vocabulary scores of shape
            (1, batch, output_size) and the new decoder state.
        """
        emb = self.embedding(seqs)
        emb = torch.cat((emb, latent_topic), 2)

        _, hidden = self.gru(emb, hidden)

        attn, a_weight = self.attention(hidden, encoder_output, latent_topic)

        p_gen = torch.cat((hidden, attn), 2)
        output = self.out(p_gen)

        return output, hidden

    def attention(self, u, encoder_output, latent_topic):
        """Compute the attention context for one decoding step.

        Args:
            u: current decoder state (``forward`` passes the GRU hidden
                state), (1, batch, hidden_size).
            encoder_output: encoder outputs, (seq_len, batch, hidden_size).
            latent_topic: topic vector from the NTM, (1, batch, hidden_size).

        Returns:
            ``(context, weights)``: context of shape (1, batch, hidden_size)
            and attention weights of shape (batch, seq_len, 1).
        """
        # NOTE(review): padded encoder positions are not masked before the
        # softmax -- confirm whether that is intended.
        seq_len = encoder_output.size(0)

        # -> (batch, seq_len, hidden_size)
        e_output = encoder_output.transpose(0, 1)

        # [state ; topic] : (1, batch, hidden_size * 2), repeated for every
        # source position -> (batch, seq_len, hidden_size * 2)
        state_topic = torch.cat((u, latent_topic), 2)
        state_topic = state_topic.expand(seq_len, -1, -1).transpose(0, 1)

        # (batch, seq_len, hidden_size * 3)
        state_topic_output = torch.cat((state_topic, e_output), 2)

        # (batch, seq_len, hidden*3) x (hidden*3, hidden) -> (batch, seq_len, hidden)
        atten_weight = torch.einsum('ijk,kl->ijl', state_topic_output, self.W_a)
        atten_weight = torch.tanh(atten_weight + self.b)

        # (batch, seq_len, hidden) x (hidden, 1) -> (batch, seq_len, 1),
        # normalised over the source positions
        atten_weight = torch.matmul(atten_weight, self.v)
        atten_weight = F.softmax(atten_weight, dim=1)

        # weighted sum of encoder outputs -> (batch, hidden_size)
        c = torch.einsum('ijk,ijl->il', atten_weight, e_output)

        # -> (1, batch, hidden_size)
        return c.unsqueeze(0), atten_weight

### Gathering models

In [0]:
class Model(nn.Module):
    """Neural topic model combined with a topic-aware seq2seq tag generator.

    Args:
        input_size_ntm: BoW vocabulary size (NTM input dimension).
        output_topic_num: accepted for compatibility but not used; the NTM
            topic dimension is set to ``hidden_size_s2s`` because the decoder
            concatenates the topic vector with hidden-size vectors.
        hidden_size_ntm: hidden width of the NTM encoder.
        input_size_s2s: input vocabulary size.
        output_size_s2s: output vocabulary size.
        hidden_size_s2s: embedding / GRU width of encoder and decoder.
        l1_strength: forwarded to the NTM.
    """

    def __init__(self, input_size_ntm, output_topic_num, hidden_size_ntm, input_size_s2s, output_size_s2s, hidden_size_s2s, l1_strength=0.001):
        super(Model, self).__init__()
        self.ntm = NTM(input_size_ntm, hidden_size_ntm, hidden_size_s2s, l1_strength=l1_strength)
        self.encoder = Encoder(input_size_s2s, hidden_size_s2s)
        self.decoder = Decoder(hidden_size_s2s, output_size_s2s)

    def forward(self, batch_X, batch_BV, lengths_X, max_length, mode, batch_Y=None, use_teacher_forcing=False):
        """Run the model in one of three modes.

        Args:
            batch_X: input token ids, (seq_len, batch).
            batch_BV: BoW vectors, (batch, input_size_ntm).
            lengths_X: unpadded lengths of ``batch_X``.
            max_length: number of decoding steps.
            mode: ``'ntm'`` runs only the topic model; ``'s2s'`` runs both but
                detaches the topic vector so that no gradient reaches the NTM;
                any other value runs both jointly.
            batch_Y: target ids, (tgt_len, batch); used for teacher forcing.
            use_teacher_forcing: feed ``batch_Y[t]`` as the next decoder input
                instead of the model's own prediction.

        Returns:
            ``(decoder_outputs, z, g, recon_batch, mu, logvar)``;
            ``decoder_outputs`` is None in ``'ntm'`` mode and otherwise has
            shape (max_length, batch, output_size).
        """
        z, g, recon_batch, mu, logvar = self.ntm(batch_BV)

        if mode == 'ntm':
            return None, z, g, recon_batch, mu, logvar

        # use g as latent topic vector; in 's2s' mode prohibit back-prop to
        # the ntm module
        topic = g.detach() if mode == 's2s' else g
        decoder_outputs = self._decode(
            batch_X, lengths_X, topic.unsqueeze(0), max_length, batch_Y, use_teacher_forcing)

        return decoder_outputs, z, g, recon_batch, mu, logvar

    def _decode(self, batch_X, lengths_X, latent_topic, max_length, batch_Y, use_teacher_forcing):
        """Encode ``batch_X`` and decode ``max_length`` steps, starting from BOS."""
        encoder_output, encoder_hidden = self.encoder(batch_X, lengths_X)

        _batch_size = batch_X.size(1)
        # Allocate on the input's device instead of relying on a global.
        decoder_input = torch.full((1, _batch_size), BOS, dtype=torch.long, device=batch_X.device)
        decoder_hidden = encoder_hidden

        decoder_outputs = torch.zeros(
            max_length, _batch_size, self.decoder.output_size, device=batch_X.device)

        for t in range(max_length):
            decoder_output, decoder_hidden = self.decoder(
                decoder_input, decoder_hidden, encoder_output, latent_topic)

            decoder_outputs[t] = decoder_output
            if use_teacher_forcing and batch_Y is not None:
                decoder_input = batch_Y[t].unsqueeze(0)
            else:
                # greedy choice: feed back the highest-scoring token
                decoder_input = decoder_output.max(-1)[1]

        return decoder_outputs
        
    

### AUX functions

In [0]:
def init_weights(m):
    """Re-initialise the weight of ``nn.Linear`` modules with Kaiming-uniform.

    Intended for ``model.apply(init_weights)``; other module types, and the
    bias of linear layers, are left untouched.
    """
    if isinstance(m, nn.Linear):
        # In-place variant; the non-underscore name is deprecated.
        torch.nn.init.kaiming_uniform_(m.weight)

def pad_seq(seq, max_length):
    """Return a new list: ``seq`` right-padded with PAD up to ``max_length``.

    Sequences that are already ``max_length`` or longer are returned as an
    unchanged copy.
    """
    missing = max_length - len(seq)
    return seq + [PAD] * missing

# Summed (not averaged) cross entropy that skips PAD targets.
# reduction='sum' replaces the deprecated size_average=False.
mce = nn.CrossEntropyLoss(reduction='sum', ignore_index=PAD)
def masked_cross_entropy(logits, target):
    """Sum of token-level cross entropies, ignoring PAD positions.

    Args:
        logits: unnormalised scores; the last dimension is the vocabulary.
        target: gold ids with one entry per row of the flattened logits.

    Returns:
        Scalar tensor with the summed loss.
    """
    return mce(logits.view(-1, logits.size(-1)), target.view(-1))

def l1_penalty(para):
    """Return the mean absolute value of ``para`` (L1 loss against zeros)."""
    zeros = torch.zeros_like(para)
    return F.l1_loss(para, zeros)

def check_sparsity(para, sparsity_threshold=1e-3):
    """Fraction of entries of ``para`` whose magnitude is below the threshold.

    Works for tensors of any number of dimensions.

    Args:
        para: tensor to inspect.
        sparsity_threshold: entries with ``abs(value) < sparsity_threshold``
            count as zero.

    Returns:
        0-dim float tensor on ``para``'s device.
    """
    # numel() instead of shape[0] * shape[1], which is only right for 2-D input
    num_weights = para.numel()
    num_zero = (para.abs() < sparsity_threshold).sum().float()
    return num_zero / float(num_weights)

def update_sparsity_l1(model, sparsity_target):
    """Rescale the NTM's L1 strength towards a target weight sparsity.

    The strength is multiplied by ``2 ** (sparsity_target - current_sparsity)``,
    so it grows while the decoder weights are less sparse than the target and
    shrinks once they are sparser.

    Args:
        model: object exposing ``ntm.fcd1.weight`` and ``ntm.l1_strength``.
        sparsity_target: desired fraction of near-zero decoder weights.
    """
    cur_sparsity = check_sparsity(model.ntm.fcd1.weight.data)
    cur_l1 = model.ntm.l1_strength

    # Use a plain Python float as the factor, so the update works no matter
    # which device the sparsity tensor or the strength lives on.
    factor = 2.0 ** float(sparsity_target - cur_sparsity)
    if torch.is_tensor(cur_l1):
        cur_l1.mul_(factor)
    else:
        model.ntm.l1_strength = cur_l1 * factor

## Train

In [0]:
# Tensorboard
# Load the notebook extension and open TensorBoard on the "runs" directory.

%load_ext tensorboard
%tensorboard --logdir runs

### 1st Training (NTM)

In [0]:
def compute_loss_ntm(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer=None, is_train=True, l1_strength = 1e5):
    """Compute the topic-model loss for one batch and optionally update.

    The loss is reconstruction BCE + KL divergence + L1 penalty on the NTM
    decoder weight. Only the model's 'ntm' mode is run.

    Args:
        batch_X, lengths_X, batch_Y: passed through to the model; batch_Y only
            supplies the ``max_length`` argument.
        batch_BV: BoW count vectors, (batch, vocab).
        model: the combined model.
        optimizer: required when ``is_train`` is True.
        is_train: switch between train mode with an optimizer step and eval
            mode without gradient tracking.
        l1_strength: weight of the L1 penalty.

    Returns:
        ``(loss, bce, None)`` with both values as Python floats summed over
        the batch; the third slot mirrors the seq2seq loss functions.

    Raises:
        ValueError: if ``is_train`` is True and no optimizer is given.
    """
    if is_train and optimizer is None:
        raise ValueError("optimizer is required when is_train=True")

    model.train(is_train)

    # No autograd graph is needed for validation batches.
    with torch.set_grad_enabled(is_train):
        # norm bow vector
        batch_BV_norm = F.normalize(batch_BV)

        # Teacher forcing has no effect in 'ntm' mode, so it is simply off.
        max_length = batch_Y.size(0)
        _, z, g, recon_batch, mu, logvar = model(
            batch_X, batch_BV_norm, lengths_X, max_length, 'ntm', batch_Y, False)

        # loss for ntm (reduction='sum' replaces deprecated size_average=False)
        # NOTE(review): the BCE target is the raw count vector, not a 0/1
        # indicator -- confirm this is intended.
        bce = F.binary_cross_entropy(recon_batch, batch_BV, reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # add l1 penalty for sparsity of ntm decoder FC weight
        loss = bce + kld + l1_strength * l1_penalty(model.ntm.fcd1.weight)

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item(), bce.item(), None

In [0]:
# Training parameter
# Configuration for the first stage: training the topic model only.
batch_size = 128
num_epochs_ntm = 100
lr = 1e-3
target_sparsity=0.85  # only reported in the training log of this stage
teacher_forcing_rate = 0.0  # has no effect on the NTM-only stage
l1_strength = 1e6  # weight of the L1 penalty on the NTM decoder weight

# Model parameter
# NOTE: Model sizes the topic vector with 'hidden_size_s2s';
# 'output_topic_num' is accepted but not used by Model.__init__.
model_args = {
    'input_size_ntm'  : bow_vocab_size,
    'output_topic_num': 256,
    'hidden_size_ntm' : 256,
    'input_size_s2s'  : vocab_size_X, 
    'output_size_s2s' : vocab_size_Y,
    'hidden_size_s2s' : 256
}

# Validation data is served in a fixed order.
train_dataloader = DataLoader(train_X, train_B, train_Y, bow_vocab, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_X, valid_B, valid_Y, bow_vocab, batch_size=batch_size, shuffle=False)

model = Model(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Re-initialise every nn.Linear weight; this happens in place, so the
# optimizer created above still refers to the same parameters.
model.apply(init_weights)

In [0]:
# Tensorboard
# Stage 1: train the topic model and log loss, perplexity and sparsity.
writer = SummaryWriter()

for epoch in range(1, num_epochs_ntm+1):
    # per-epoch running sums of the batch-summed losses
    train_loss = 0.
    train_bce  = 0.
    valid_loss = 0.
    valid_bce  = 0.
    
    # train loop
    for batch in train_dataloader:
        batch_X, _, batch_BV, batch_Y, lengths_X = batch
        loss, bce, _ = compute_loss_ntm(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer, is_train=True, l1_strength = l1_strength)
        train_loss += loss
        train_bce  += bce
        
    # validation loop
    for batch in valid_dataloader:
        batch_X, _, batch_BV, batch_Y, lengths_X = batch
        loss, bce, _ = compute_loss_ntm(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer, is_train=False, l1_strength = l1_strength)
        valid_loss += loss
        valid_bce  += bce
        
    # calc ave total loss (per example)
    train_loss = np.sum(train_loss) / len(train_dataloader.data)
    valid_loss = np.sum(valid_loss) / len(valid_dataloader.data)
    
    # calc ntm weight sparsity
    sparsity = check_sparsity(model.ntm.fcd1.weight.data)
    
    # calc perplexity: exp(summed BCE / total number of BoW tokens)
    train_bce = train_bce / train_dataloader.word_count
    valid_bce = valid_bce / valid_dataloader.word_count
    train_ppl = np.exp(train_bce)
    valid_ppl = np.exp(valid_bce)
    
    print('Epoch {:4d} | loss(train/valid) {:5.2f} / {:5.2f}, ppl(train/valid) {:5.2f} / {:5.2f}, sparsity(current/target) {:.3f} / {:.3f},'.format(
        epoch, train_loss, valid_loss,train_ppl, valid_ppl,float(sparsity.cpu()), target_sparsity))
    
    # tensorboard
    writer.add_scalars('ntm/loss',{'train_loss': train_loss,'valid_loss': valid_loss},epoch)
    writer.add_scalars('ntm/perplexity',{'train_ppl': train_ppl,'valid_ppl': valid_ppl},epoch)
    writer.add_scalars('ntm/sparsity',{'sparsity': float(sparsity.cpu())},epoch)
    
    print('-'*80)

writer.close()

In [0]:
# Save the weights after stage 1; the stage-2 setup cell reloads this file.
ckpt = model.state_dict()
torch.save(ckpt, HOME_DIR + 'model_ntm.pt')

### 2nd Training (Seq2seq)

In [0]:
# Training parameter
# Configuration for the second stage: training the seq2seq part.
batch_size = 128
num_epochs_s2s = 10
lr = 1e-3
target_sparsity=0.85
teacher_forcing_rate = 0.4  # read as a global by the loss function: chance of teacher forcing per training batch
l1_strength = 1e6

# Model parameter
# NOTE: Model sizes the topic vector with 'hidden_size_s2s';
# 'output_topic_num' is accepted but not used by Model.__init__.
model_args = {
    'input_size_ntm'  : bow_vocab_size,
    'output_topic_num': 256,
    'hidden_size_ntm' : 256,
    'input_size_s2s'  : vocab_size_X, 
    'output_size_s2s' : vocab_size_Y,
    'hidden_size_s2s' : 256
}


train_dataloader = DataLoader(train_X, train_B, train_Y, bow_vocab, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_X, valid_B, valid_Y, bow_vocab, batch_size=batch_size, shuffle=False)

# A fresh model and optimizer; weights are restored from the stage-1
# checkpoint just below.
model = Model(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Restore the stage-1 weights. map_location lets a checkpoint written on a GPU
# machine be loaded when this session only has a CPU (and vice versa).
ckpt = torch.load(HOME_DIR + 'model_ntm.pt', map_location=device)
model.load_state_dict(ckpt)

In [0]:
def compute_loss_s2s(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer=None, is_train=True, gamma = 1.0, l1_strength = 1e6):
    """Compute the seq2seq loss for one batch and optionally update.

    The model runs in 's2s' mode, where the topic vector is detached, so only
    the cross-entropy term (scaled by ``gamma``) is optimised. The NTM
    reconstruction BCE is still computed, for reporting only.

    Args:
        batch_X, lengths_X: encoder input and its unpadded lengths.
        batch_BV: BoW count vectors, (batch, vocab).
        batch_Y: target ids, (tgt_len, batch).
        model: the combined model.
        optimizer: required when ``is_train`` is True.
        is_train: switch between train mode with an optimizer step and eval
            mode without gradient tracking.
        gamma: weight of the seq2seq loss.
        l1_strength: unused in this stage; kept for a uniform signature.

    Returns:
        ``(loss, bce, pred)``: batch-summed loss and BCE as floats, and the
        greedy predictions as a list of id lists (one per example).

    Raises:
        ValueError: if ``is_train`` is True and no optimizer is given.
    """
    if is_train and optimizer is None:
        raise ValueError("optimizer is required when is_train=True")

    model.train(is_train)

    # set params for s2s; teacher forcing is decided once per batch from the
    # module-level teacher_forcing_rate
    use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)
    max_length = batch_Y.size(0)

    # No autograd graph is needed for validation batches.
    with torch.set_grad_enabled(is_train):
        # norm bow vector
        batch_BV_norm = F.normalize(batch_BV)

        # forward all model
        pred_Y, z, g, recon_batch, mu, logvar = model(batch_X, batch_BV_norm, lengths_X, max_length, 's2s', batch_Y, use_teacher_forcing)

        # loss for s2s
        loss_s2s = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())

        # NTM reconstruction loss, reported but not optimised in this stage
        # (reduction='sum' replaces the deprecated size_average=False)
        bce = F.binary_cross_entropy(recon_batch, batch_BV, reduction='sum')

        # gamma is a weight for loss_s2s
        loss = loss_s2s * gamma

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # (tgt_len, batch) -> list with one id sequence per example
    pred = pred_Y.max(dim=-1)[1].detach().cpu().numpy().T.tolist()

    return loss.item(), bce.item(), pred

In [0]:
# Tensorboard
# Stage 2: train the seq2seq part and log the loss.
writer = SummaryWriter()

for epoch in range(1, num_epochs_s2s+1):
    # per-epoch running sums of the batch-summed losses
    train_loss = 0.
    valid_loss = 0.
    
    # train loop
    for batch in train_dataloader:
        batch_X, _, batch_BV, batch_Y, lengths_X = batch
        loss, _, pred = compute_loss_s2s(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer, is_train=True, gamma=200)
        train_loss += loss
    
    # validation loop
    for batch in valid_dataloader:
        batch_X, _, batch_BV, batch_Y, lengths_X = batch
        loss, _, pred = compute_loss_s2s(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer, is_train=False, gamma=200)
        valid_loss += loss
        
    # calc ave total loss (per example)
    train_loss = np.sum(train_loss) / len(train_dataloader.data)
    valid_loss = np.sum(valid_loss) / len(valid_dataloader.data)
    
    print('Epoch {:4d} | loss(train/valid) {:5.2f} / {:5.2f}'.format(epoch, train_loss, valid_loss))
    
    # tensorboard
    writer.add_scalars('s2s/loss',{'train_loss': train_loss,'valid_loss': valid_loss},epoch)

    print('-'*80)
    
writer.close()

In [0]:
# Save the weights after stage 2; the stage-3 setup cell reloads this file.
ckpt = model.state_dict()
torch.save(ckpt, HOME_DIR + 'model_ntm_s2s.pt')

### 3rd Training (Joint)

In [0]:
def compute_loss(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer=None, is_train=True, gamma = 1.0, l1_strength = 1e4):
    """Compute the joint NTM + seq2seq loss for one batch and optionally update.

    loss = BCE + KLD + l1_strength * L1(NTM decoder weight)
           + gamma * cross entropy.
    The model runs in joint mode, so gradients of the seq2seq loss also reach
    the topic model.

    Args:
        batch_X, lengths_X: encoder input and its unpadded lengths.
        batch_BV: BoW count vectors, (batch, vocab).
        batch_Y: target ids, (tgt_len, batch).
        model: the combined model.
        optimizer: required when ``is_train`` is True.
        is_train: switch between train mode with an optimizer step and eval
            mode without gradient tracking.
        gamma: weight of the seq2seq loss.
        l1_strength: weight of the L1 penalty.

    Returns:
        ``(loss, bce, pred)``: batch-summed loss and BCE as floats, and the
        greedy predictions as a list of id lists (one per example).

    Raises:
        ValueError: if ``is_train`` is True and no optimizer is given.
    """
    if is_train and optimizer is None:
        raise ValueError("optimizer is required when is_train=True")

    model.train(is_train)

    # set params for s2s; teacher forcing is decided once per batch from the
    # module-level teacher_forcing_rate
    use_teacher_forcing = is_train and (random.random() < teacher_forcing_rate)
    max_length = batch_Y.size(0)

    # No autograd graph is needed for validation / test batches.
    with torch.set_grad_enabled(is_train):
        # norm bow vector
        batch_BV_norm = F.normalize(batch_BV)

        # forward all model ('' selects the joint branch)
        pred_Y, z, g, recon_batch, mu, logvar = model(batch_X, batch_BV_norm, lengths_X, max_length, '', batch_Y, use_teacher_forcing)

        # loss for s2s
        loss_s2s = masked_cross_entropy(pred_Y.contiguous(), batch_Y.contiguous())

        # loss for ntm (reduction='sum' replaces deprecated size_average=False)
        bce = F.binary_cross_entropy(recon_batch, batch_BV, reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss_ntm = bce + kld + l1_strength * l1_penalty(model.ntm.fcd1.weight)

        # sum up losses
        loss = loss_ntm + loss_s2s * gamma

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # (tgt_len, batch) -> list with one id sequence per example
    pred = pred_Y.max(dim=-1)[1].detach().cpu().numpy().T.tolist()

    return loss.item(), bce.item(), pred

In [0]:
# Training parameter
# Configuration for the third stage: joint fine-tuning, with a larger batch
# and a much smaller learning rate than the first two stages.
batch_size = 256
num_epochs_all = 20
lr = 1e-5
target_sparsity=0.85  # only reported in the training log of this stage
teacher_forcing_rate = 0.4  # read as a global by the loss function: chance of teacher forcing per training batch
l1_strength = 1e6  # weight of the L1 penalty on the NTM decoder weight

# Model parameter
# NOTE: Model sizes the topic vector with 'hidden_size_s2s';
# 'output_topic_num' is accepted but not used by Model.__init__.
model_args = {
    'input_size_ntm'  : bow_vocab_size,
    'output_topic_num': 256,
    'hidden_size_ntm' : 256,
    'input_size_s2s'  : vocab_size_X, 
    'output_size_s2s' : vocab_size_Y,
    'hidden_size_s2s' : 256
}

train_dataloader = DataLoader(train_X, train_B, train_Y, bow_vocab, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_X, valid_B, valid_Y, bow_vocab, batch_size=batch_size, shuffle=False)

# A fresh model and optimizer; weights are restored from the stage-2
# checkpoint just below.
model = Model(**model_args).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Restore the stage-2 weights. map_location lets a checkpoint written on a GPU
# machine be loaded when this session only has a CPU (and vice versa).
ckpt = torch.load(HOME_DIR + 'model_ntm_s2s.pt', map_location=device)
model.load_state_dict(ckpt)

In [0]:
# Tensorboard
# Stage 3: train topic model and seq2seq jointly; log loss, perplexity and
# sparsity.
writer = SummaryWriter()

for epoch in range(1, num_epochs_all+1):
    # per-epoch running sums of the batch-summed losses
    train_loss = 0.
    train_bce  = 0.
    valid_loss = 0.
    valid_bce  = 0.
    
    # train loop
    for batch in train_dataloader:
        batch_X, _, batch_BV, batch_Y, lengths_X = batch
        loss, bce, pred = compute_loss(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer, is_train=True, gamma=200, l1_strength = l1_strength)
        train_loss += loss
        train_bce  += bce
        
    # validation loop
    for batch in valid_dataloader:
        batch_X, _, batch_BV, batch_Y, lengths_X = batch
        loss, bce, pred = compute_loss(batch_X, batch_BV, batch_Y, lengths_X, model, optimizer, is_train=False, gamma=200, l1_strength = l1_strength)
        valid_loss += loss
        valid_bce  += bce
        
    # calc ave total loss (per example)
    train_loss = np.sum(train_loss) / len(train_dataloader.data)
    valid_loss = np.sum(valid_loss) / len(valid_dataloader.data)
    
    # calc ntm weight sparsity
    sparsity = check_sparsity(model.ntm.fcd1.weight.data)
    
    # calc perplexity: exp(summed BCE / total number of BoW tokens)
    train_bce = train_bce / train_dataloader.word_count
    valid_bce = valid_bce / valid_dataloader.word_count
    train_ppl = np.exp(train_bce)
    valid_ppl = np.exp(valid_bce)
    
    print('Epoch {:4d} | loss(train/valid) {:5.2f} / {:5.2f}, ppl(train/valid) {:5.2f} / {:5.2f}, sparsity(current/target) {:.3f} / {:.3f}'.format(
        epoch, train_loss, valid_loss,train_ppl, valid_ppl,float(sparsity.cpu()), target_sparsity))
    
    # tensorboard
    writer.add_scalars('joint/loss',{'train_loss': train_loss,'valid_loss': valid_loss},epoch)
    writer.add_scalars('joint/perplexity',{'train_ppl': train_ppl,'valid_ppl': valid_ppl},epoch)
    writer.add_scalars('joint/sparsity',{'sparsity': float(sparsity.cpu())},epoch)
    
    print('-'*80)
    
writer.close()

In [0]:
# Save the jointly trained weights; the "Result" section reloads this file.
ckpt = model.state_dict()
torch.save(ckpt, HOME_DIR + 'model_ntm_s2s_joint.pt')

## Result

In [0]:
# Restore the jointly trained weights. map_location lets a checkpoint written
# on a GPU machine be loaded when this session only has a CPU (and vice versa).
ckpt = torch.load(HOME_DIR + 'model_ntm_s2s_joint.pt', map_location=device)
model.load_state_dict(ckpt)

In [0]:
def ids_to_sentence(vocab, ids):
    """Map each id in ``ids`` to its word via ``vocab.id2word``.

    Raises:
        KeyError: if an id is missing from ``vocab.id2word``.
    """
    words = []
    for token_id in ids:
        words.append(vocab.id2word[token_id])
    return words

def trim_eos(ids):
    """Return ``ids`` cut off before the first EOS; unchanged if there is none."""
    if EOS not in ids:
        return ids
    return ids[:ids.index(EOS)]

### Loss

In [0]:
# Evaluate the joint loss, NTM perplexity and decoder-weight sparsity on the
# held-out test split.
test_dataloader = DataLoader(test_X, test_B, test_Y, bow_vocab, batch_size=256,shuffle=False)

test_loss = 0.
test_bce  = 0.
test_hyps = []
# test loop (no optimizer is needed because is_train=False)
for batch in test_dataloader:
    batch_X, _, batch_BV, batch_Y, lengths_X = batch
    loss, bce, pred = compute_loss(batch_X, batch_BV, batch_Y, lengths_X, model, None, is_train=False, gamma=200, l1_strength = l1_strength)
    test_loss += loss
    test_bce  += bce

# calc ave total loss (per example). The average used to be written to
# `valid_loss`, so the report below printed the unnormalised sum and the
# validation figure was silently overwritten.
test_loss = test_loss / len(test_dataloader.data)

# calc ntm weight sparsity
sparsity = check_sparsity(model.ntm.fcd1.weight.data)

# calc perplexity: exp(summed BCE / total number of BoW tokens)
test_bce = test_bce / test_dataloader.word_count
test_ppl = np.exp(test_bce)


print('Loss(test) {:5.2f}, PPL(test) {:5.2f}, sparsity(Test) {:.3f}'.format(
    test_loss, test_ppl, float(sparsity.cpu()), ))


#### Test prediction

In [0]:
# Serve the test split one example at a time, in a fixed order, so that the
# next cell can be re-run to inspect successive examples.
test_dataloader = DataLoader(test_X, test_B, test_Y, bow_vocab, batch_size=1,shuffle=False)

In [0]:
# Show source, reference tags and predicted tags for the next test example.
model.eval()

batch_X, batch_B, batch_BV , batch_Y, lengths_X = next(test_dataloader)
# [:-1] drops the trailing EOS id before converting back to words
sentence_X = ' '.join(ids_to_sentence(vocab_X, batch_X.data.cpu().numpy()[:-1, 0]))
sentence_Y = ' '.join(ids_to_sentence(vocab_Y, batch_Y.data.cpu().numpy()[:-1, 0]))

print('src: {}'.format(sentence_X))
print('tgt: {}'.format(sentence_Y))

# The model was trained on L2-normalised BoW vectors (see the loss functions),
# so the same normalisation is applied here. Decoding runs for 5 steps.
with torch.no_grad():
    output, z, g, recon_batch, mu, logvar  = model(batch_X, F.normalize(batch_BV), lengths_X, 5, '')
output = output.max(dim=-1)[1].view(-1).data.cpu().tolist()

# Drop repeated tags but keep them in the order they were generated
# (a set would print them in arbitrary order).
output_sentence = ' '.join(dict.fromkeys(ids_to_sentence(vocab_Y, trim_eos(output))))
print('out: {}'.format(output_sentence))