In [2]:
from torch.utils.data import Dataset, DataLoader
import os
import scipy.sparse
import numpy as np
from collections import defaultdict
import torch.nn as nn
import torch
from tqdm import tqdm, trange

In [3]:
class TextData(Dataset):
    def __init__(self, data_dir, partition='train'):
        self.partition = partition
        self.texts_en, self.bow_matrix_en, self.vocab_en, self.word2id_en, self.id2word_en = self.read_data(data_dir, lang='en')
        self.texts_cn, self.bow_matrix_cn, self.vocab_cn, self.word2id_cn, self.id2word_cn = self.read_data(data_dir, lang='cn')
        
        self.size_en = len(self.texts_en)
        self.size_cn = len(self.texts_cn)
        self.vocab_size_en = len(self.vocab_en)
        self.vocab_size_cn = len(self.vocab_cn)
        
        self.trans_dict, self.trans_matrix_en, self.trans_matrix_cn = self.parse_dictionary()
        
        self.Map_en2cn = self.get_Map(self.trans_matrix_en, self.bow_matrix_en)
        self.Map_cn2en = self.get_Map(self.trans_matrix_cn, self.bow_matrix_cn)


    def __getitem__(self, idx):
        batch_en = self.bow_matrix_en[idx]
        batch_cn = self.bow_matrix_cn[idx]
        return torch.tensor(batch_en, dtype=torch.float32), torch.tensor(batch_cn, dtype=torch.float32)

    def __len__(self):
        return self.size_en
        
    def read_text(self, path):
        texts = []
        with open(path, 'r') as f:
            for line in f: texts.append(line.strip())
        return texts

    def read_data(self, data_dir, lang):
        texts = self.read_text(os.path.join(data_dir, '{}_texts_{}.txt'.format(self.partition,lang)))
        vocab = self.read_text(os.path.join(data_dir, 'vocab_{}'.format(lang)))
        word2id = dict(zip(vocab, range(len(vocab))))
        id2word = dict(zip(range(len(vocab)), vocab))

        bow_matrix = scipy.sparse.load_npz(os.path.join(data_dir, '{}_bow_matrix_{}.npz'.format(self.partition,lang))).toarray()
        return texts, bow_matrix, vocab, word2id, id2word
    
    def parse_dictionary(self):
        trans_dict = defaultdict(set)
        trans_matrix_en = np.zeros((self.vocab_size_en, self.vocab_size_cn), dtype='int32')
        trans_matrix_cn = np.zeros((self.vocab_size_cn, self.vocab_size_en), dtype='int32')
        
        with open('./ch_en_dict.dat') as f:
            for line in f:
                terms = (line.strip()).split()
                if len(terms) == 2:
                    cn_term = terms[0]
                    en_term = terms[1]
                    if cn_term in self.word2id_cn and en_term in self.word2id_en:
                        trans_dict[cn_term].add(en_term)
                        trans_dict[en_term].add(cn_term)
                        cn_term_id = self.word2id_cn[cn_term]
                        en_term_id = self.word2id_en[en_term]

                        trans_matrix_en[en_term_id][cn_term_id] = 1
                        trans_matrix_cn[cn_term_id][en_term_id] = 1

        return trans_dict, trans_matrix_en, trans_matrix_cn
    
    def get_Map(self, trans_matrix, bow_matrix):
        Map = (trans_matrix * bow_matrix.sum(0)[:, np.newaxis]).astype('float32')
        Map = Map + 1
        Map_sum = np.sum(Map, axis=1)
        t_index = Map_sum > 0
        Map[t_index, :] = Map[t_index, :] / Map_sum[t_index, np.newaxis]
        return torch.tensor(Map)

In [4]:
trainDocSet = TextData('./data/Amazon_Review','train')
testDocSet = TextData('./data/Amazon_Review','test')

In [5]:
nn.functional.batch_norm

<function torch.nn.functional.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05)>

In [98]:
class NMTM(nn.Module):
    def __init__(self, config, Map_en2cn, Map_cn2en):
        super(NMTM, self).__init__()
        self.config = config
        self.Map_en2cn = Map_en2cn
        self.Map_cn2en = Map_cn2en
        
        # encoder
        self.phi_cn = nn.Parameter(torch.randn(self.config['topic_num'], self.config['vocab_size_cn']))
        self.phi_en = nn.Parameter(torch.randn(self.config['topic_num'], self.config['vocab_size_en']))
        
        self.W_cn = nn.Parameter(torch.randn(self.config['vocab_size_cn'], self.config['e1']))
        self.W_en = nn.Parameter(torch.randn(self.config['vocab_size_en'], self.config['e1']))
        
        self.B_cn = nn.Parameter(torch.randn(self.config['e1']))
        self.B_en = nn.Parameter(torch.randn(self.config['e1']))
        
        self.act_fun = nn.Softplus()
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(p=0.0)
        self.batch_norm_encode_en = nn.BatchNorm1d(self.config['topic_num'])
        self.batch_norm_encode_cn = nn.BatchNorm1d(self.config['topic_num'])
                
        self.W2 = nn.Parameter(torch.randn(self.config['e1'], self.config['e2']))
        self.B2 = nn.Parameter(torch.randn(self.config['e2']))
        
        self.W_m = nn.Parameter(torch.randn(self.config['e2'], self.config['topic_num']))
        self.B_m = nn.Parameter(torch.randn(self.config['topic_num']))
        
        self.W_s = nn.Parameter(torch.randn(self.config['e2'], self.config['topic_num']))
        self.B_s = nn.Parameter(torch.randn(self.config['topic_num']))
        
        self.init_params()
        # decoder
        beta_cn = (self.config['lam'] * torch.matmul(self.phi_en, self.Map_en2cn) + (1-self.config['lam']) * self.phi_cn).detach()
        self.beta_cn = nn.Parameter(beta_cn)
        beta_en = (self.config['lam'] * torch.matmul(self.phi_cn, self.Map_cn2en) + (1-self.config['lam']) * self.phi_en).detach()
        self.beta_en = nn.Parameter(beta_en)
        

        self.batch_norm_decode_en = nn.BatchNorm1d(self.config['vocab_size_en'])
        self.batch_norm_decode_cn = nn.BatchNorm1d(self.config['vocab_size_cn'])
        
        # loss
        self.a = 1 * torch.ones((1, int(self.config['topic_num'])))
        self.mu_priori = nn.Parameter((torch.log(self.a).T - torch.mean(torch.log(self.a),1).T).T, requires_grad=False)
        sigma_priori = (((1.0/self.a)*(1-(2.0/self.config['topic_num']))).T + 
                            (1.0/(self.config['topic_num']*self.config['topic_num']))*torch.sum(1.0/self.a, 1)).T
        self.sigma_priori = nn.Parameter(sigma_priori, requires_grad=False)
        
    def init_params(self):
        nn.init.xavier_uniform_(self.phi_cn)
        nn.init.xavier_uniform_(self.phi_en)
        
        nn.init.xavier_uniform_(self.W_cn)
        nn.init.xavier_uniform_(self.W_en)
        nn.init.zeros_(self.B_cn)
        nn.init.zeros_(self.B_en)
        
        nn.init.xavier_uniform_(self.W2)
        nn.init.xavier_uniform_(self.W_m)     
        nn.init.xavier_uniform_(self.W_s)
        
        nn.init.zeros_(self.B2)
        nn.init.zeros_(self.B_m)     
        nn.init.zeros_(self.B_s)   
        

    def encode(self, x, lang):
        if lang == 'en': 
            h = self.act_fun(torch.matmul(x, self.W_en) + self.B_en)
            batch_norm = self.batch_norm_encode_en
        else: 
            h = self.act_fun(torch.matmul(x, self.W_cn) + self.B_cn)
            batch_norm = self.batch_norm_encode_cn
        
        h = self.act_fun(torch.matmul(h, self.W2) + self.B2)
        h = self.dropout(h)
        
        mean = batch_norm(torch.matmul(h, self.W_m) + self.B_m)
        log_sigma_sq = batch_norm(torch.matmul(h, self.W_s) + self.B_s)
        val = torch.sqrt(torch.exp(log_sigma_sq))
        eps = torch.zeros_like(val).normal_()
        z = mean + torch.mul(val, eps)
        z = self.softmax(z)
        z = self.dropout(z)
        
        return z, mean, log_sigma_sq
    
    def decode(self, z, beta, lang):
        if lang == 'en': 
            batch_norm = self.batch_norm_decode_en
        else: 
            batch_norm = self.batch_norm_decode_cn
        
        x_recon = self.softmax(batch_norm(torch.matmul(z, beta)))
        return x_recon
    
    def get_loss(self, x, x_recon, z_mean, z_log_sigma_sq):
        sigma = torch.exp(z_log_sigma_sq)
        latent_loss = 0.5 * (torch.sum(torch.div(sigma, self.sigma_priori),1) + \
                        torch.sum(torch.mul(torch.div((self.mu_priori - z_mean), self.sigma_priori), (self.mu_priori - z_mean)), 1) 
                             - self.config['topic_num'] + torch.sum(torch.log(self.sigma_priori), 1) 
                             - torch.sum(z_log_sigma_sq, 1))
        recon_loss = torch.sum(-x * torch.log(x_recon), axis=1)
        loss = latent_loss + recon_loss
        return loss.mean()
    
    def forward(self, x_cn, x_en):
        # encode
        z_cn, z_mean_cn, z_log_sigma_sq_cn = self.encode(x_cn, 'cn')
        z_en, z_mean_en, z_log_sigma_sq_en = self.encode(x_en, 'en')
        
        # decode
        x_recon_cn = self.decode(z_cn, self.beta_cn, 'cn')
        x_recon_en = self.decode(z_en, self.beta_en, 'en')
        
        return z_cn, z_mean_cn, z_log_sigma_sq_cn, z_en, z_mean_en, z_log_sigma_sq_en, x_recon_cn, x_recon_en

In [116]:
config = dict()
config['topic_num'] = 20
config['batch_size'] = 128
config['epoch'] = 1
config['e1'] = 100
config['e2'] = 100
config['vocab_size_en'] = trainDocSet[0][0].size(0)
config['vocab_size_cn'] = trainDocSet[0][1].size(0)
config['lam'] = 0.8
config['learning_rate'] = 0.001
config['output_dir'] = './output'

model = NMTM(config, trainDocSet.Map_en2cn, trainDocSet.Map_cn2en)

In [117]:
def train(config, model, dataset):
    lr = config['learning_rate']
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
    for epoch in trange(config['epoch']):
        for idx, batch in tqdm(enumerate(train_loader)):
            optimizer.zero_grad()
            batch_data_en, batch_data_cn = batch
            z_cn, z_mean_cn, z_log_sigma_sq_cn, z_en, z_mean_en, z_log_sigma_sq_en, x_recon_cn, x_recon_en = model(batch_data_en, batch_data_cn)

            # get_loss
            loss_cn = model.get_loss(batch_data_cn, x_recon_cn, z_mean_cn, z_log_sigma_sq_cn)
            loss_en = model.get_loss(batch_data_en, x_recon_en, z_mean_en, z_log_sigma_sq_en)
            loss = loss_cn + loss_en
            loss.backward()
            optimizer.step()

        print('Epoch {} \t Loss: {}'.format(epoch, loss))
    
    return model

In [118]:
model = train(config, model, trainDocSet)

  0%|                                                                                                                                                                               | 0/1 [00:00<?, ?it/s]

1it [00:00,  1.59it/s][A
2it [00:01,  1.32it/s][A
3it [00:01,  1.54it/s][A
4it [00:02,  1.38it/s][A
5it [00:03,  1.39it/s][A
6it [00:04,  1.28it/s][A
7it [00:05,  1.35it/s][A
8it [00:05,  1.30it/s][A
9it [00:06,  1.33it/s][A
10it [00:07,  1.35it/s][A
11it [00:08,  1.30it/s][A
12it [00:08,  1.33it/s][A
13it [00:09,  1.33it/s][A
14it [00:10,  1.32it/s][A
15it [00:11,  1.35it/s][A
16it [00:12,  1.28it/s][A
17it [00:12,  1.31it/s][A
18it [00:13,  1.32it/s][A
19it [00:14,  1.34it/s][A
20it [00:14,  1.45it/s][A
21it [00:15,  1.32it/s][A
22it [00:16,  1.42it/s][A
23it [00:17,  1.33it/s][A
24it [00:17,  1.34it/s][A
25it [00:18,  1.28it/s][A
26it [00:19,  1.34it/s][A
27it [00:20,  1.26it/s][A
28it [00:20,  1.31it/s][A
29it [00:21,  1.36it/s][A
30it [00:22,  1.34it/s

Epoch 0 	 Loss: 638.4900512695312





In [131]:
def print_top_words(config, beta, id2word, lang, n_top_words=15):
    top_words = []
    for i in range(len(beta)):
        top_words.append(" ".join([id2word[j] for j in beta[i].argsort().tolist()[:-n_top_words-1:-1]]))
    
    with open(os.path.join(config['output_dir'], 'top_words_T{}_K{}_{}'.format(n_top_words, config['topic_num'], lang)), 'w') as f:
        for line in top_words:
            f.write(line + '\n')
            print(line)

def export_beta(config, model, data):
    beta_en, beta_cn = model.beta_en, model.beta_cn
    print_top_words(config, beta_en, data.id2word_en, lang='en')
    print_top_words(config, beta_cn, data.id2word_cn, lang='cn')

In [None]:
export_beta(config, model, trainDocSet)

like works books want ever bought one detail really fit even made two written printer
recommend expect worth good feel price old fit found excellent acid characters change conclusion important
day go parts back review floor holy house value instead like quality behind amazon times
within found todays provide similar another software last guide looked shelf specifically captain edge singer
year know book really order got low read new people truly im friend detail even
movie well said book feel film made think started buy funny keep first much went
cant youre old want plot quality waste reviews music book operating prize actual someone world
sound bad old every without making set money comparing ever used go nice type work
think story though flaw set little read back less characters band enough however law thrill
maybe could worked see product songs sound great player short difficult directions enjoyed next together
cd get must song youll great box way may release right best john advice 