# Style transfer exploration

In [134]:
import os
import sys
from typing import Dict, List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

# Put the project's `src` directory on the import path so that the
# `data_utils` and `utils` modules imported just below can be found.
sys.path.append('../deep-latent-sequence-model/src')
from data_utils import DataUtil
from utils import reorder
# NOTE(review): hard-coded, machine-specific absolute path. Presumably the
# working directory is switched so that relative paths used by the project
# code resolve - confirm, and adjust this path when running elsewhere.
os.chdir('/home/przemyslaw/text-style-transfer/deep-latent-sequence-model')

In [22]:
# Output directory of the training run whose checkpoint is explored here.
model_dir = '../deep-latent-sequence-model/outputs_yelp/yelp_wd0.0_wb0.0_ws0.0_an3_pool5_klw0.1_lr0.001_t0.01_lm_bt_hard_avglen'
model_path = os.path.join(model_dir, 'model.pt')
# torch.load unpickles the saved object; the repr printed in the next cell
# shows it is a whole Seq2Seq module rather than a bare state dict.
# NOTE(review): presumably this needs the project's classes to be importable
# (see the sys.path tweak in the import cell) - confirm.
model = torch.load(model_path)

In [23]:
# Display the module tree of the loaded model (output below is truncated).
model

Seq2Seq(
  (encoder): Encoder(
    (word_emb): Embedding(9653, 128, padding_idx=0)
    (layer): LSTM(128, 512, batch_first=True, dropout=0.3, bidirectional=True)
    (bridge): Linear(in_features=1024, out_features=512, bias=False)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (decoder): Decoder(
    (attention): MlpAttn(
      (dropout): Dropout(p=0.3, inplace=False)
      (w_trg): Linear(in_features=512, out_features=512, bias=True)
      (w_att): Linear(in_features=512, out_features=1, bias=True)
    )
    (ctx_to_readout): Linear(in_features=1536, out_features=512, bias=False)
    (readout): Linear(in_features=512, out_features=9653, bias=False)
    (word_emb): Embedding(9653, 128, padding_idx=0)
    (attr_emb): Embedding(2, 128, padding_idx=0)
    (layer): LSTMCell(1152, 512)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (enc_to_k): Linear(in_features=1024, out_features=512, bias=False)
  (noise): NoiseLayer()
  (LM0): LSTM_LM(
    (embed): Embedding(9653, 128, padding_id

In [24]:
class HParams(object):
    """Base hyper-parameter container.

    Every instance starts with the special-token strings, their vocabulary
    ids, batching defaults, unset vocabulary sizes and an ``inf`` constant.
    """

    def __init__(self, **args):
        # Keyword arguments are accepted but not used by this initializer.

        # Special-token surface forms.
        self.pad, self.unk, self.bos, self.eos = "<pad>", "<unk>", "<s>", "</s>"
        # Vocabulary ids of the special tokens, in the same order as above.
        self.pad_id, self.unk_id, self.bos_id, self.eos_id = 0, 1, 2, 3

        # Batching defaults.
        self.batcher, self.batch_size = "sent", 32
        # Vocabulary sizes start out unset.
        self.src_vocab_size = self.trg_vocab_size = None

        self.inf = float("inf")

In [26]:
class TranslationHparams(HParams):
    """Hyper-parameters for decoding (translating) with a trained model.

    Extends ``HParams`` with decoding settings. ``beam_size`` and ``max_len``
    are the values later passed to ``model.translate``.
    """

    dataset = "Translate dataset"

    def __init__(self):
        # Run the base initializer first so the special tokens, their ids and
        # the other HParams defaults exist on the instance; without this call
        # a fresh TranslationHparams lacks e.g. ``pad_id`` and ``eos_id``.
        super().__init__()
        self.cuda = True
        self.beam_size = 1  # forwarded as ``beam_size`` to model.translate
        self.max_len = 300  # forwarded as ``max_len`` to model.translate
        self.batch_size = 32
        self.merge_bpe = False
        self.decode = True

In [27]:
# Decoding hyper-parameters with their default values.
hparams = TranslationHparams()

In [28]:
# Load the hyper-parameters saved next to the checkpoint and copy every one
# of their instance attributes onto a fresh set of decoding defaults, so the
# training-time values win wherever both define the same name.
hparams_file_name = os.path.join(model_dir, "hparams.pt")
train_hparams = torch.load(hparams_file_name)
hparams = TranslationHparams()
for k, v in vars(train_hparams).items():
    setattr(hparams, k, v)

In [57]:
# Reuse the data object stored on the loaded model; it provides the test
# batches and the index-to-word vocabularies used in the cells below.
data = model.data

In [53]:
# Start with an empty list.
# NOTE(review): presumably meant to accumulate hypotheses across batches; no
# visible cell appends to it - confirm whether it is still needed.
hyps = []

In [39]:
# Fetch one test batch (requested size 64) and unpack the 14 values that
# data.next_test returns. Only some of them are used by the cells below.
(
    x_valid, x_mask, x_count, x_len, x_pos_emb_idxs,
    y_valid, y_mask, y_count, y_len, y_pos_emb_idxs,
    y_neg, batch_size, end_of_epoch, index,
) = data.next_test(test_batch_size=64)

In [41]:
# Decode the batch fetched above; beam size and maximum length come from the
# decoding hyper-parameters.
hs = model.translate(
    x_valid,
    x_mask,
    x_len,
    y_neg,
    y_mask,
    y_len,
    beam_size=hparams.beam_size,
    max_len=hparams.max_len,
    poly_norm_m=0,
)



In [51]:
# Rearrange the hypotheses using the `index` returned by data.next_test.
# NOTE(review): presumably this restores the original sentence order of the
# batch - confirm against utils.reorder.
hs = reorder(hs, index)

In [58]:
# Lazily turn every hypothesis (a sequence of token ids) into its list of
# target-vocabulary words. Each element of `hs` is a whole hypothesis, so the
# vocabulary lookup has to happen per id inside it; indexing the vocabulary
# with the hypothesis itself would fail as soon as the map is consumed.
h_best_words = map(lambda h: [data.trg_i2w_list[0][wi] for wi in h], hs)

In [68]:
# Render every hypothesis as a space-separated sentence by looking each
# token id up in the source vocabulary, then materialise them for display.
word_lists = (' '.join(data.src_i2w[w] for w in h) for h in hs)
list(word_lists)

['great job all from me .',
 "best service i 've ever experienced .",
 'always great and timely .',
 'love this place , you have good quality and quality staff .',
 "awesome , excellent service , they do n't have any good food .",
 'best service ever .',
 "i love the wings but they are n't worth the beyond great service .",
 "i 'll definitely b going back .",
 'loved the server , tried to charge me for a new ranch !',
 'totally reccomend for any family restaurant in my opinion .',
 "fantastic and that 's the inside of the restaurant is so perfect .",
 'the carrot and antipasto pancakes are tender and juicy and color .',
 'was definitely recommended when i walked in .',
 'place is in top of some local businesses .',
 'drink was served in a fresh plastic cup .',
 'so i ordered a salad dish .',
 'also pretty good quality food .',
 'highly recommended .',
 'i felt good for the owners of this arizona staple .',
 'it definitely rocks .',
 "i love this place , do n't have to have to never .",

## Tokenizer
Let's be honest: the `DataUtil` class is unwieldy. It carries a lot of functionality we don't need here; all we need is a vocabulary and a way to encode words into token indices (and decode them back). That should be extremely easy, but with `DataUtil` it isn't, so let's implement a minimal tokenizer ourselves.

All we need is a vocabulary.

In [81]:
# Peek at the first ten vocabulary entries; the output below shows the four
# special tokens occupying ids 0-3, followed by ordinary words.
data.src_i2w[:10]

['<pad>', '<unk>', '<s>', '</s>', 'i', 'was', 'sadly', 'mistaken', '.', 'so']

In [171]:
class Tokenizer:
    """Minimal whitespace tokenizer backed by a fixed vocabulary.

    The vocabulary is a list of words in which the position of a word is its
    token id. Words missing from the vocabulary are encoded as ``UNK_ID``, and
    ids below ``NUM_SPECIAL_TOKENS`` are treated as special tokens by
    :meth:`decode`.
    """

    # Id returned for out-of-vocabulary tokens.
    UNK_ID = 1
    # Number of leading vocabulary entries reserved for special tokens.
    NUM_SPECIAL_TOKENS = 4

    def __init__(self, vocab: List[str]):
        """Store ``vocab`` as the id-to-word table and build the inverse map."""
        self.idx2word = vocab
        self.word2idx = {word: i for i, word in enumerate(vocab)}

    def convert_ids_to_tokens(self, token_ids: Union[int, List[int]]) -> Union[str, List[str]]:
        """Map a single id to its word, or a list of ids to a list of words.

        Raises:
            TypeError: if ``token_ids`` is neither a list nor an int.
        """
        if isinstance(token_ids, list):
            return [self.idx2word[id_] for id_ in token_ids]
        if isinstance(token_ids, int):
            # Previously this branch referenced an undefined name and failed.
            return self.idx2word[token_ids]
        raise TypeError(f'Type of ids should be either list or int but is {type(token_ids)}')

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """Map a single word to its id, or a list of words to a list of ids.

        Unknown words map to ``UNK_ID``.

        Raises:
            TypeError: if ``tokens`` is neither a list nor a str.
        """
        if isinstance(tokens, list):
            return [self.word2idx.get(token, self.UNK_ID) for token in tokens]
        if isinstance(tokens, str):
            return self.word2idx.get(tokens, self.UNK_ID)
        raise TypeError(f'Type of tokens should be either list or str but is {type(tokens)}')

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join words into one space-separated string."""
        return ' '.join(tokens)

    def decode(self, token_ids: List[int], skip_special_tokens: bool = False) -> str:
        """Turn a list of token ids back into a space-separated string.

        When ``skip_special_tokens`` is true, ids below
        ``NUM_SPECIAL_TOKENS`` are dropped before conversion.
        """
        if skip_special_tokens:
            # Convert the filtered ids; the filter result used to be discarded.
            token_ids = [id_ for id_ in token_ids if id_ >= self.NUM_SPECIAL_TOKENS]
        tokens = self.convert_ids_to_tokens(token_ids)
        return self.convert_tokens_to_string(tokens)

    def get_vocab(self) -> Dict[str, int]:
        """Return the word-to-id mapping."""
        return self.word2idx

    def tokenize(self, text: str) -> List[int]:
        """Lower-case ``text``, split it on whitespace and return token ids.

        Punctuation is not separated here, so dots, commas etc. must already
        be whitespace-delimited in ``text``.
        """
        tokens = text.lower().split()
        return self.convert_tokens_to_ids(tokens)

In [172]:
# Build the tokenizer on the source vocabulary held by the data object.
tokenizer = Tokenizer(data.src_i2w)

In [175]:
# Encode a sample sentence. tokenize lower-cases and splits on whitespace
# only, which is why the punctuation is already space-separated here.
encoding = tokenizer.tokenize('I very much like research job , it is my passion .')
encoding

[4, 103, 435, 154, 6179, 2444, 14, 41, 16, 170, 2337, 8]

In [176]:
# Round trip: decoding the ids gives back the lower-cased sentence.
tokenizer.decode(encoding)

'i very much like research job , it is my passion .'