# Load Train and Test Datasets

In [50]:
import pandas as pd


def load_df_surnames():
    """Read the pickled surname frames for both splits.

    Each file is a bz2-compressed pickle stored under ``data/pickles/``
    relative to the current working directory.

    Returns:
        tuple: ``(df_train, df_test)``, in that order, exactly as
        ``pd.read_pickle`` returns them.
    """
    split_paths = ('data/pickles/train.pickle', 'data/pickles/test.pickle')
    # Generator unpacking keeps the original read order: train first, then test.
    train_frame, test_frame = (
        pd.read_pickle(path, compression='bz2') for path in split_paths
    )
    return train_frame, test_frame

def load_df_categories():
    """Read the pickled categories frame.

    Returns:
        The object stored in ``data/pickles/df_categories.pickle``
        (a bz2-compressed pickle), as returned by ``pd.read_pickle``.
    """
    categories_path = 'data/pickles/df_categories.pickle'
    categories_frame = pd.read_pickle(categories_path, compression='bz2')
    return categories_frame

In [51]:
# Materialise the frames once; the widget callbacks further down read
# df_train, df_test and df_categories as module-level globals.
df_train, df_test = load_df_surnames()
df_categories = load_df_categories()

# Load RNN Model

In [52]:
from io import open
import glob
import os
import unicodedata
import string
# Disabled code: everything between the triple quotes is a plain string
# literal, so none of it executes. In particular all_letters, n_letters,
# category_lines, all_categories and n_categories are NOT defined by this cell.
'''
def findFiles(path): return glob.glob(path)


all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

# Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

# Build the category_lines dictionary, a list of names per language
category_lines = {}
all_categories = []

# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]

for filename in findFiles('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)
'''

In [53]:
import unicodedata
import string
import torch

# Disabled code: an inert string literal. all_letters and n_letters are not
# defined by this cell.
'''
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
'''
import torch.nn as nn

class RNN(nn.Module):
    """One-step recurrent cell made of two linear layers.

    At every step the current input is concatenated with the previous hidden
    state; one linear layer maps that to the next hidden state and another
    maps it to log-probabilities over the output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers.

        Args:
            input_size: width of one input row.
            hidden_size: width of the hidden state.
            output_size: number of output classes.
        """
        super().__init__()
        joined_width = input_size + hidden_size

        self.hidden_size = hidden_size
        # Both layers consume the same [input | hidden] concatenation.
        self.i2h = nn.Linear(joined_width, hidden_size)
        self.i2o = nn.Linear(joined_width, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance one step.

        Args:
            input: 2-D tensor joined with ``hidden`` along dim 1.
            hidden: 2-D hidden state from the previous step.

        Returns:
            tuple: ``(log_probs, next_hidden)``.
        """
        joined = torch.cat([input, hidden], dim=1)
        next_hidden = self.i2h(joined)
        log_probs = self.softmax(self.i2o(joined))
        return log_probs, next_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros((1, self.hidden_size))

# Restore the trained network into the global `rnn` that evaluate() reads.
# NOTE(review): presumably the file holds a fully pickled RNN module, which is
# why the class is defined before this line runs — confirm. Unpickling can
# execute arbitrary code, so only load this file from a trusted source.
rnn = torch.load('data/model/rnn.pickle')

In [54]:
import torch
import json
import modeling.surname_common as sc

# Disabled code: an inert string literal, so letterToIndex, letterToTensor and
# lineToTensor are not defined here. The live code below encodes surnames via
# sc.surname_to_tensor instead.
'''
# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)

# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor

# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor
'''

def load_model(path='data/model/rnn.pickle'):
    """Load the trained surname classifier from disk.

    The default now matches the location the notebook actually loads the
    global ``rnn`` from; the previous hard-coded ``'model/rnn.pickle'``
    pointed at a different directory.

    Args:
        path: file to hand to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserialises from ``path``.

    Note:
        ``torch.load`` is pickle-based; only load files from a trusted source.
    """
    return torch.load(path)

def evaluate(line_tensor, model=None):
    """Run a sequence through the recurrent model one step at a time.

    Args:
        line_tensor: tensor whose first dimension indexes the sequence steps;
            ``line_tensor[i]`` is fed to the model at step ``i``.
        model: object exposing ``initHidden()`` and being callable as
            ``model(step_input, hidden) -> (output, hidden)``. Defaults to
            the notebook-level ``rnn``.

    Returns:
        The output produced by the model at the final step.

    Raises:
        ValueError: if ``line_tensor`` has no steps. Previously this case
            surfaced as an UnboundLocalError on ``output``.
    """
    n_steps = line_tensor.size(0)
    if n_steps == 0:
        raise ValueError('evaluate() needs at least one step; got an empty line_tensor')

    if model is None:
        model = rnn  # fall back to the globally loaded network

    hidden = model.initHidden()
    for step in range(n_steps):
        output, hidden = model(line_tensor[step], hidden)

    return output

def predict(input_line, n_predictions=3, categories=None):
    """Classify a surname and return the top categories as a JSON string.

    Args:
        input_line: surname text, encoded with ``sc.surname_to_tensor``.
        n_predictions: how many top categories to report; clamped to the
            number of model outputs so asking for more no longer raises.
        categories: sequence mapping output index -> category label. When
            omitted, the global ``all_categories`` is used if it exists;
            otherwise the labels come from ``df_categories['category']``.

    Returns:
        str: JSON list of ``[log_probability, category]`` pairs, best first.
    """
    if categories is None:
        try:
            categories = all_categories
        except NameError:
            # all_categories is only created inside the disabled tutorial
            # code, so a fresh kernel never defines it.
            # NOTE(review): assumes df_categories rows are in the same order
            # as the model's output units — confirm against training.
            categories = list(df_categories['category'])

    with torch.no_grad():
        output = evaluate(sc.surname_to_tensor(input_line))

        # Get top N categories (never more than the model can provide).
        top_k = min(n_predictions, output.size(1))
        topv, topi = output.topk(top_k, 1, True)

        predictions = [
            [topv[0][i].item(), categories[topi[0][i].item()]]
            for i in range(top_k)
        ]
    return json.dumps(predictions)

# Dashboard program

In [55]:
%%html
<style>
/* Stretch tables carrying the "dataframe" class across the full output width.
   Presumably these are the tables rendered via DataFrame.to_html in the
   cells below - confirm. */
table.dataframe {
    width: 100%
}
/* Sizing for iframes with class "wiki"; no such iframe is created in the
   cells visible here. */
iframe.wiki {
    width: 100%;
    height: 400px;
}
</style>

In [56]:
from IPython.display import display, clear_output, HTML, Javascript
import ipywidgets as widgets


def get_sample_df_surnames(category, dataset, n):
    """Return a random sample of surnames belonging to one category.

    Args:
        category: pattern matched against the ``category`` column with
            ``Series.str.match`` (a regex match anchored at the start).
        dataset: ``'test'`` selects ``df_test``; anything else selects
            ``df_train``.
        n: requested number of rows. Clamped to ``[0, number of matches]``
            so a negative or oversized request no longer raises.

    Returns:
        DataFrame: up to ``n`` randomly chosen matching rows.
    """
    # Compare by value: `is` tests object identity and only worked because
    # CPython happens to intern short string literals.
    df = df_test if dataset == 'test' else df_train
    # na=False keeps rows with a missing category out of the mask instead of
    # producing NaN entries that cannot be used for boolean indexing.
    matches = df[df['category'].str.match(category, na=False)]
    sample_size = max(0, min(n, len(matches)))
    return matches.sample(sample_size)

def on_pick_category(category, dataset, n):
    """Widget callback: show a random sample of surnames for a category.

    Draws the sample via ``get_sample_df_surnames`` and renders it as an
    unescaped HTML table.

    Returns:
        Whatever ``display`` returns.
    """
    sample = get_sample_df_surnames(category, dataset, n)
    table_markup = sample.to_html(escape=False)
    return display(HTML(table_markup))

def on_describe_dataset(dataset):
    """Widget callback: show summary statistics for the chosen dataset.

    Args:
        dataset: ``'test'`` or ``'train'`` shows ``DataFrame.describe()`` of
            the matching frame; any other value shows ``df_categories``
            itself.

    Returns:
        Whatever ``display`` returns.
    """
    # Compare by value; `is` on strings checks identity and is unreliable.
    if dataset == 'test':
        df = df_test.describe()
    elif dataset == 'train':
        df = df_train.describe()
    else:
        df = df_categories
    return display(HTML(df.to_html(escape=False)))

## Browse dataset

In [57]:
# Interactive sample browser. The sample size gets an explicit slider: the
# bare `n=5` abbreviation produced a slider spanning negative values, and a
# negative sample size makes DataFrame.sample raise.
picker_w = widgets.interact(
    on_pick_category,
    category=list(df_categories['category']),
    dataset=['train', 'test'],
    n=widgets.IntSlider(value=5, min=1, max=15, description='n'),
)

interactive(children=(Dropdown(description='category', options=('Arabic', 'Chinese', 'Czech', 'Dutch', 'Englis…

In [58]:
# Dataset summary picker. Any option other than 'train'/'test' makes
# on_describe_dataset show df_categories, so the third label is display-only;
# it was previously misspelled 'categies'.
picker_w2 = widgets.interact(on_describe_dataset, dataset=['train', 'test', 'categories'])

interactive(children=(Dropdown(description='dataset', options=('train', 'test', 'categies'), value='train'), O…

# Predict Classification

## 1. Search for a surname in the train or test dataset

In [59]:
def on_search_surname(surname, dataset):
    """Widget callback: list rows whose surname contains the typed text.

    Args:
        surname: text to look for, matched case-insensitively as a literal
            substring.
        dataset: ``'test'`` searches ``df_test``; anything else searches
            ``df_train``.

    Returns:
        Whatever ``display`` returns.
    """
    # Compare by value; `is` on strings checks identity and is unreliable.
    df = df_test if dataset == 'test' else df_train
    # regex=False: the callback fires on partial input, and text such as "("
    # is not a valid regular expression. na=False: rows with a missing
    # surname are excluded instead of putting NaN into the boolean mask.
    mask = df['surname'].str.contains(surname, case=False, regex=False, na=False)
    df_search_result = df[mask]
    html = HTML(df_search_result.to_html(escape=False))
    return display(html)
    
    
# Text box (pre-filled with 'james') plus a train/test dropdown; every change
# re-runs on_search_surname with the current values.
search_w = widgets.interact(on_search_surname, surname='james', dataset=['train','test'])

interactive(children=(Text(value='james', description='surname'), Dropdown(description='dataset', options=('tr…

## 2. Predict

In [60]:
# Result area: on_classify_surname writes the prediction into this widget's
# `value`. The bare name on the last line renders the widget as cell output.
output_text = widgets.Textarea()
output_text

Textarea(value='')

In [61]:
# Entry box whose current `value` is read by on_classify_surname.
input_text = widgets.Text()

def on_classify_surname(sender):
    """Submit handler: classify the typed surname and show the result.

    Reads ``input_text.value``, runs ``predict`` on it and writes the echoed
    surname plus the JSON prediction into ``output_text``. Blank input is
    reported in the output area instead of being sent to the model, which
    cannot evaluate an empty sequence.

    Args:
        sender: the widget that fired the event (unused).
    """
    surname = input_text.value.strip()
    if not surname:
        output_text.value = '>> Please enter a surname'
        return
    output_text.value = '\n'.join([surname, '>> Inferred class is ', predict(surname)])

    
# Register on_classify_surname as the submit handler for the text box, then
# render the box as the cell output.
# NOTE(review): newer ipywidgets releases reportedly deprecate Text.on_submit —
# confirm against the installed version before upgrading.
input_text.on_submit(on_classify_surname)
input_text

Text(value='')