# Text CNN Autoencoder

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

### Do the basic imports

In [None]:
import os
import sys

import torch
import torch.nn as nn
import numpy as np
import json
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

In [None]:
import os
import sys

# Resolve the current user's home directory once; later cells build their
# data and model paths from it.
home_dir = os.path.expanduser('~')

# Put the local ml-toolkit checkout on the import path
# (presumably the source of the `pytorch.*` packages imported below).
toolkit_dir = "{}/dev/github/ml-toolkit".format(home_dir)
sys.path.append(toolkit_dir)

## Load Data

In [None]:
from pytorch.utils.data.text.vectorizer import Vectorizer

# Build the vectorizer with index 0 mapped to the '<pad>' token.
vectorizer = Vectorizer(default_indexes={0: '<pad>'})

# Resolve the dictionary relative to the user's home directory, like every
# other data path in this notebook, instead of a hard-coded /home/<user> prefix.
dictionary_path = '{}/data/datasets/hotel-reviews-txt/dictionary'.format(home_dir)
vectorizer.load_dictionary(dictionary_path, word_col=0)

print(vectorizer.vocab_size)

In [None]:
# Each line of the training file is parsed as whitespace-separated integer ids.
train_path = '{}/data/datasets/hotel-reviews-txt/train_permute.txt'.format(home_dir)
with open(train_path) as infile:
    train_seq_list = [[int(token) for token in line.split()] for line in infile]

# NOTE(review): presumably pads/truncates every sequence to max_len=100 and
# maps unknown tokens to index 1 -- confirm against Vectorizer.prepare_sequences.
X, indices = vectorizer.prepare_sequences(
    train_seq_list,
    auto_padding=True,
    max_len=100,
    unknown_idx=1,
)

print(X[0])

# Drop the reference to the raw list so its memory can be reclaimed.
train_seq_list = None

### Get max index (parameter for embedding layer)

In [None]:
# Largest token index found in any prepared sequence; `max_idx + 1` is used
# as the vocabulary size in the model parameters further down.
max_idx = int(max(np.max(seq) for seq in X))

### Print an example

In [None]:
# Show the second prepared sequence as raw indices, then as the words
# the vocabulary returns for those indices.
example = X[1]
print(example)
print(vectorizer.vocabulary.get_words(example))

### Sample dataset for testing

In [None]:
# Restrict training to the first `num_samples` prepared sequences.
num_samples = 1000
X_train = X[:num_samples]

print("Size of training set: {}".format(len(X_train)))

In [None]:
from pytorch.utils.data.text.wordvectorloader import WordVectorLoader

# Toggle: when True, build an embedding matrix from the GloVe dump below.
use_pretrained_embeddings = False

if use_pretrained_embeddings is True:
    glove_path = '{}/data/dumps/glove/glove.840B.300d.txt'.format(home_dir)
    word_vector_loader = WordVectorLoader(300)
    embed_mat = word_vector_loader.create_embedding_matrix(
        glove_path,
        vectorizer.vocabulary.word_to_index,
        max_idx,
        init='random',
        verbatim=True,
    )
    print(embed_mat.shape)

### Create training data iterator

In [None]:
# Batch size handed to the BucketBatchSampler in the next cell.
batch_size = 32

In [None]:
from pytorch.utils.data.text.dataset import BucketBatchSampler, BucketDataset

# The sampler is built from the training sequences and the batch size;
# the dataset wraps the same sequences with no targets (None).
bucket_batch_sampler = BucketBatchSampler(X_train, batch_size)
bucket_dataset = BucketDataset(X_train, None)

# `batch_sampler` already yields complete batches. DataLoader documents
# batch_size, shuffle, sampler and drop_last as mutually exclusive with it,
# so those arguments are left at their defaults rather than passed explicitly.
X_train_iter = DataLoader(
    bucket_dataset,
    batch_sampler=bucket_batch_sampler,
    num_workers=8,
)

# Number of batches per epoch.
print(len(X_train_iter))

## Create network model

In [None]:
from pytorch.models.text.autoencoder.textcnnae import Parameters, ConvMode

### Use GPU if available

In [None]:
# Use the first GPU only when CUDA is actually available; otherwise fall back
# to the CPU so the notebook also runs on machines without a GPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)

In [None]:
# Directory that receives params.json (the checkpoint paths in the training
# cell point into the same directory).
path = '{}/data/ml-toolkit/pytorch-models/text-cnn-ae/'.format(home_dir)

# Create the directory up front so writing params.json does not fail with
# FileNotFoundError on a machine where it does not exist yet.
os.makedirs(path, exist_ok=True)

# Hyperparameters for the text CNN autoencoder.
# 'vocab_size' is max_idx + 1 because token indices start at 0.
params = {'conv_mode': ConvMode.D1,
          'max_seq_len': 100,
          'vocab_size': max_idx+1,
          'embed_dim': 300,
          'encoder_lr': 0.0001,
          'decoder_lr': 0.0001,
          'kernel_sizes': [3, 3, -1],
          'strides': [2, 2, 2],
          'num_filters': [300, 600, 500],
          'output_paddings': [1, 0, 0],
          'do_batch_norm': True,
          'dropout_ratio': 0.5,
          'tau': 0.01,
          'clip': 0.25
          }


print(params)
# NOTE(review): json.dump only succeeds if ConvMode.D1 is JSON-serializable
# (e.g. an int-backed enum) -- confirm against the ConvMode definition.
with open(os.path.join(path, 'params.json'), 'w') as outfile:
    json.dump(params, outfile)

# From here on `params` is the Parameters wrapper, not the plain dict.
params = Parameters(params)

In [None]:
from pytorch.models.text.autoencoder.textcnnae import TextCnnAE
# Negative log-likelihood loss, handed to the autoencoder at construction.
criterion = nn.NLLLoss()
text_cnn_ae = TextCnnAE(device, params, criterion)

# Print the encoder, the decoder and the kernel sizes stored on the model's params.
for component in (text_cnn_ae.encoder, text_cnn_ae.decoder, text_cnn_ae.params.kernel_sizes):
    print(component)

### Set pretrained word embeddings if needed

In [None]:
if use_pretrained_embeddings is True:
    print("Initialize embdding layer with pretrained embeddings")
    # Copy the pretrained matrix into the embedding layer in place.
    text_cnn_ae.embedding.weight.data.copy_(torch.from_numpy(embed_mat))
    # L2-normalize each embedding row. This goes through the already-imported
    # `nn` because no `F` alias for torch.nn.functional exists in this notebook.
    text_cnn_ae.embedding.weight.data = nn.functional.normalize(text_cnn_ae.embedding.weight.data, p=2, dim=1)
    # Keep the pretrained vectors fixed during training.
    text_cnn_ae.embedding.weight.requires_grad = False
else:
    # Without pretrained vectors the embedding layer is trained with the model.
    text_cnn_ae.embedding.weight.requires_grad = True


## Train model

In [None]:
# Per-epoch losses; kept in its own cell so re-running the training cell
# appends to the existing history instead of resetting it.
losses = []

In [None]:
# Number of passes over X_train_iter.
num_epochs = 100
# When True, encoder and decoder are saved after every epoch.
safe_after_epoch = False

# Target files passed to save_models() when per-epoch saving is enabled.
encoder_file_name = '{}/data/ml-toolkit/pytorch-models/text-cnn-ae/textcnnae-encoder.model'.format(home_dir)
decoder_file_name = '{}/data/ml-toolkit/pytorch-models/text-cnn-ae/textcnnae-decoder.model'.format(home_dir)

# Presumably switches the model to training mode -- confirm in TextCnnAE.
text_cnn_ae.train()

# NOTE(review): these rates (0.001) differ from 'encoder_lr'/'decoder_lr'
# (0.0001) in the params dict written to params.json -- confirm which is intended.
text_cnn_ae.set_learning_rates(0.001, 0.001)
for epoch in range(num_epochs):
    # One pass over the training batches; the returned value is recorded as the epoch loss.
    epoch_loss = text_cnn_ae.train_epoch(epoch, X_train_iter, verbatim=True)
    print(epoch_loss)
    losses.append(epoch_loss)
    if safe_after_epoch:
        text_cnn_ae.save_models(encoder_file_name, decoder_file_name)
    # Presumably scales both learning rates by 0.99 after each epoch -- confirm in TextCnnAE.
    text_cnn_ae.update_learning_rates(0.99, 0.99)        
        
# Presumably switches the model to evaluation mode for the cells below -- confirm.
text_cnn_ae.eval()

In [None]:
# Scale the recorded losses by their maximum so the curve peaks at 1.0.
max_loss = np.max(losses)
losses_normalized = np.asarray(losses) / max_loss

# Single curve; the y-axis label carries the embedding dimension.
plt.plot(losses_normalized, label='loss')
plt.ylabel('CNN-AE (e_dim={})'.format(params.embed_dim))
plt.legend(loc='upper right')

plt.show()

## Evaluate model

In [None]:
def check_sequence(sequence, model, vectorizer):
    """Round-trip one index sequence through the model and return both texts.

    Args:
        sequence: Sequence of integer token indices for a single example.
        model: Object exposing a ``device`` attribute and a ``generate(batch)``
            method whose result ``vectorizer.sequence_to_text`` accepts.
        vectorizer: Object exposing ``sequence_to_text(indices)``, which
            returns an iterable of token strings.

    Returns:
        A tuple ``(original_text, decoded_text)``; each is the corresponding
        token list joined with single spaces.
    """
    original_sequence = vectorizer.sequence_to_text(sequence)
    # Wrap the single sequence in a batch of size one on the model's device.
    batch = torch.tensor([sequence], dtype=torch.long).to(model.device)
    # Pure inference: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        decoded_indices = model.generate(batch)
    decoded_sequence = vectorizer.sequence_to_text(decoded_indices)
    return ' '.join(original_sequence), ' '.join(decoded_sequence)
    
# Smoke test: print the (original, decoded) text pair for the first training sequence.
print(check_sequence(X[0] ,text_cnn_ae, vectorizer))

### Check a sample of the training data

In [None]:
# Print original vs. reconstructed text for the first 202 sequences of X
# (the same count the index check `idx > 200` lets through).
separator = "================================================"
for s in X[:202]:
    original, decoded = check_sequence(s, text_cnn_ae, vectorizer)
    print(separator)
    print()
    print(original)
    print(">>>")
    print(decoded)
    print()

### Check test data

In [None]:
# Each line of the test file is parsed as whitespace-separated integer ids.
test_path = '{}/data/datasets/hotel-reviews-txt/test.txt'.format(home_dir)
with open(test_path) as infile:
    test_seq_list = [[int(token) for token in line.split()] for line in infile]

# Same preparation arguments as for the training data; the second return
# value is not needed here.
X_test, _ = vectorizer.prepare_sequences(
    test_seq_list,
    auto_padding=True,
    max_len=100,
    unknown_idx=1,
)

print(X_test[0])

# Drop the reference to the raw list so its memory can be reclaimed.
test_seq_list = None

In [None]:
# Print original vs. reconstructed text for the first 202 test sequences
# (the same count the index check `idx > 200` lets through), hiding padding.
pad_token = '<pad>'
for s in X_test[:202]:
    original, decoded = check_sequence(s, text_cnn_ae, vectorizer)
    original_tokens = [t for t in original.split() if t != pad_token]
    decoded_tokens = [t for t in decoded.split() if t != pad_token]
    print("================================================\n")
    print(' '.join(original_tokens))
    print(">>>")
    print(' '.join(decoded_tokens))
    print("\n")