In [1]:
import pandas as pd
import numpy as np

In [2]:
if False:
    data = pd.read_csv("data.csv")
    sizes_dict = dict(data.theme.value_counts())
    max_size = 380768
    others_index = data[data.theme == 'other'].index.values
    np.random.seed(999)
    np.random.shuffle(others_index)
    remove_index = others_index[max_size:]
    print(remove_index)
    new_data = data[~data.index.isin(remove_index)]
    print(new_data.theme.value_counts())
    new_data.to_csv('balanced_data.csv', index=False)

In [3]:
VAL_RATIO = 0.1

def prepare_csv(path="balanced_data.csv", seed=999):
    data = pd.read_csv(path)
    
    idx = data.index.values
    np.random.seed(seed)
    np.random.shuffle(idx)
    
    val_size = int(len(idx) * VAL_RATIO)
    
    data.iloc[idx[:val_size]].to_csv(
        "cache/dataset_val.csv", index=False)
    data.iloc[idx[val_size:2*val_size]].to_csv(
        "cache/dataset_test.csv", index=False)
    data.iloc[idx[2*val_size:]].to_csv(
        "cache/dataset_train.csv", index=False)

In [4]:
import re
import spacy
NLP = spacy.load('en')
MAX_CHARS = 20000

In [5]:
def tokenizer(comment):
    comment = re.sub(
        r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", 
        str(comment))
    comment = re.sub(r"[ ]+", " ", comment)
    comment = re.sub(r"\!+", "!", comment)
    comment = re.sub(r"\,+", ",", comment)
    comment = re.sub(r"\?+", "?", comment)
    if (len(comment) > MAX_CHARS):
        comment = comment[:MAX_CHARS]
    return [
        x.text for x in NLP.tokenizer(comment) if x.text != " "]

In [6]:
import logging
import torch
from torchtext import data
LOGGER = logging.getLogger("reviews_dataset")

In [7]:
#prepare_csv()

In [10]:
def get_dataset(fix_length=100, lower=False, vectors=None, 
                prepare_data=False):
    
    if vectors is not None:
        # pretrain vectors only supports all lower cases
        lower = True
    
    LOGGER.debug("Preparing CSV files...")
    
    if prepare_data:
        prepare_csv()
    
    review_text = data.Field(
        sequential = True,
        #fix_length=fix_length,
        tokenize = tokenizer,
        pad_first = False,
        dtype = torch.int64,
        lower = lower,
        init_token = '<sos>',
        eos_token = '<eos>',
        
    )
    
    theme = data.Field(
                use_vocab=True, 
                sequential=False, 
                dtype=torch.int64)
    
    meta_id = data.Field(
                use_vocab=True, sequential=False, 
                dtype=torch.int64)
    
    fields=[
            ('meta_id', meta_id),
            ('review_text', review_text),
            ('theme', theme)]
    
    LOGGER.debug("Reading train csv file...")
    train, val = data.TabularDataset.splits(
        path='cache/', format='csv', skip_header=True,
        train='dataset_train.csv', validation='dataset_val.csv',
        fields = fields
        )
    
    LOGGER.debug("Reading test csv file...")
    test = data.TabularDataset(
        path='cache/dataset_test.csv', format='csv', 
        skip_header=True, fields=fields)
    
    LOGGER.debug("Building vocabulary...")
    
    review_text.build_vocab(
        train, val, test,
        max_size=30000,
        min_freq=5,
        vectors=vectors
    )

    meta_id.build_vocab(
        train, val, test,
        max_size=float('inf'),
        min_freq=0,
    )
    
    theme.build_vocab(
        train, val, test,
        max_size=10,
        min_freq=0,
    )
    
    
    LOGGER.debug("Done preparing the datasets")
    return train, val, test

In [11]:
def get_iterator(dataset, batch_size, train=True, shuffle=True, repeat=False):
    
    dataset_iter = data.Iterator(
        dataset, batch_size=batch_size, device='cuda',
        train=train, shuffle=shuffle, repeat=repeat,
        
        sort_key = lambda x: len(x.review_text),
        sort_within_batch=False,
        sort=True
    )
    return dataset_iter

In [12]:
# ff = val_dataset.fields['review_text']
# len(val_dataset)

In [11]:
train_dataset, val_dataset, test_dataset = get_dataset(prepare_data=True)

In [36]:
for examples in get_iterator(
            test_dataset, 32, train=False,
            shuffle=False, repeat=False,
        ):
    x = examples.review_text # (fix_length, batch_size) Tensor
    y = examples.theme
    print(x.shape, y.shape)

torch.Size([0, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([1, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
torch.Size([2, 32]) torch.Size([32])
t

torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
torch.Size([6, 32]) torch.Size([32])
t

torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([9, 32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([

torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([12, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([13, 32]) torch.Size([32])
torch.Size([

torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([15, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([16, 32]) torch.Size([32])
torch.Size([

torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([19, 32]) torch.Size([32])
torch.Size([

torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([22, 32]) torch.Size([32])
torch.Size([

torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([26, 32]) torch.Size([32])
torch.Size([27, 32]) torch.Size([32])
torch.Size([27, 32]) torch.Size([32])
torch.Size([

torch.Size([31, 32]) torch.Size([32])
torch.Size([31, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([32, 32]) torch.Size([32])
torch.Size([

torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([40, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([41, 32]) torch.Size([32])
torch.Size([