In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Load the autoreload extension and set mode 2, so imported modules are
# reloaded before each cell executes (useful while editing the local
# `categorization` package alongside this notebook).
%load_ext autoreload
%autoreload 2

In [2]:
from collections import Counter

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from tensorflow import keras as K

from categorization import data
from categorization.featurization import NameAndReviewTextFeaturizer
from categorization.evaluation import experiment, model_experiment, evaluate_model
from categorization.model import DenseTextualModel, RnnTextualModel

## Build the Train and Validation Datasets

In [3]:
# Load the business table through the project's data module and show how
# many entries it contains.
businesses = data.load_business_df()
len(businesses)

158525

In [4]:
# Restrict the business table to rows whose `state` column equals 'PA'.
is_pennsylvania = businesses['state'] == 'PA'
pen_businesses = businesses[is_pennsylvania]

In [5]:
# Preview the first rows of the 'PA' subset as a sanity check on the filter.
pen_businesses.head()

Unnamed: 0,business_id,business_name,review_count,stars,state,city,categories
20,1RHY4K3BD22FK7Cfftn8Mg,Marathon Diner,35,4.0,PA,Pittsburgh,"[Sandwiches, Salad, Restaurants, Burgers, Comf..."
43,qWWBVE5T_zMEF7UJ4iTfNw,"DJ Yonish, Inc.",3,2.5,PA,Bethel Park,"[Home Services, Heating & Air Conditioning/HVAC]"
51,dQj5DLZjeDK3KFysh1SYOQ,Apteka,242,4.5,PA,Pittsburgh,"[Nightlife, Bars, Polish, Modern European, Res..."
58,v-scZMU6jhnmV955RSzGJw,No. 1 Sushi Sushi,106,4.5,PA,Pittsburgh,"[Japanese, Sushi Bars, Restaurants]"
61,KFbUQ-RR2UOV62Ep7WnXHw,Westwood Bar & Grill,5,3.0,PA,West Mifflin,"[American (Traditional), Restaurants]"


In [6]:
# Root categories as exposed by the project's CategoryTree; passed below as
# the `accepted_categories` filter when loading examples.
root_categories = data.CategoryTree().root_categories

In [7]:
# Distinct business ids of the 'PA' subset; these select which businesses
# the examples are loaded for.
pa_business_ids = set(pen_businesses.business_id.unique())

# Load the examples and their label sets, restricted to the root categories.
examples, label_sets = data.load_examples(
    pa_business_ids,
    min_reviews=1,
    accepted_categories=root_categories,
)

# Sanity check: the two counts are expected to be equal (one label set per example).
len(examples), len(label_sets)

(9411, 9411)

In [8]:
# Hold out 20% of the examples (and their label sets) for validation.
# The split is seeded with a fixed `random_state` so that Restart & Run All
# reproduces the same partition; without it every run shuffles differently
# and the metrics reported below are not comparable between runs.
train_examples, validation_examples, train_label_sets, validation_label_sets = \
    train_test_split(examples, label_sets, test_size=.2, random_state=42)

# Sizes of the four resulting collections (train/validation x examples/labels).
len(train_examples), len(train_label_sets), len(validation_examples), len(validation_label_sets)

(7528, 7528, 1883, 1883)

## Models

In [9]:
# Binarizer that turns label sets into multi-hot vectors; it is fitted on
# the training label sets only.
labelizer = MultiLabelBinarizer()
labelizer.fit(train_label_sets)

# Size the model output from the fitted binarizer rather than from every
# label in `label_sets`: a label that occurs only in the validation split is
# not among `labelizer.classes_`, so counting it would make the model's
# output width disagree with the width of the binarized targets.
NUM_CLASSES = len(labelizer.classes_)

In [10]:
%%time

vocab_size = 5000
input_length = 10000
exp = experiment(
    NameAndReviewTextFeaturizer(max_vocab_size=vocab_size, max_length=input_length),
    DenseTextualModel(
        vocab_size=vocab_size, input_length=input_length,
        embedding_dimension=50, hidden_dimension=30, num_classes=NUM_CLASSES,
        learning_rate=0.01, epochs=10, batch_size=256),
    MultiLabelBinarizer(),
    train_examples, train_label_sets,
    validation_examples, validation_label_sets
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train Macro:
precision    0.917999
recall       0.046683
f1           0.063665
dtype: float64

Validation Macro:
precision    0.911365
recall       0.048892
f1           0.066421
dtype: float64
CPU times: user 7min 40s, sys: 4min 54s, total: 12min 35s
Wall time: 3min 12s


In [None]:
# Build a featurizer with the `vocab_size` / `input_length` settings defined
# in the baseline experiment cell, fit it on the training examples only, and
# then transform both splits with that fitted featurizer so the features can
# be reused across the model experiments below.
featurizer = NameAndReviewTextFeaturizer(max_vocab_size=vocab_size, max_length=input_length)
featurizer.fit(train_examples)
train_features = featurizer.transform(train_examples)
validation_features = featurizer.transform(validation_examples)

In [None]:
# Dense model variant: wider hidden layer (50) and a longer schedule (200 epochs).
dense_model = DenseTextualModel(
    vocab_size=vocab_size,
    input_length=input_length,
    embedding_dimension=50,
    hidden_dimension=50,
    num_classes=NUM_CLASSES,
    learning_rate=0.01,
    epochs=200,
    batch_size=256,
)

# Run the experiment on the precomputed features with the shared label binarizer.
dense_experiment = model_experiment(
    featurizer,
    dense_model,
    labelizer,
    train_features,
    train_label_sets,
    validation_features,
    validation_label_sets,
)

In [None]:
# Binarize the label sets of both splits with the binarizer that was fitted
# on the training label sets.
# NOTE(review): neither variable is referenced by the other cells visible
# here - confirm they are used further down, otherwise this cell can go.
train_labels = labelizer.transform(train_label_sets)
validation_labels = labelizer.transform(validation_label_sets)

In [None]:
# Recurrent model variant, trained for a single epoch.
rnn_model = RnnTextualModel(
    vocab_size=vocab_size,
    input_length=input_length,
    embedding_dimension=50,
    rnn_dimension=30,
    num_classes=NUM_CLASSES,
    learning_rate=0.01,
    epochs=1,
    batch_size=256,
)

# Run the experiment on the precomputed features with the shared label binarizer.
rnn_experiment = model_experiment(
    featurizer,
    rnn_model,
    labelizer,
    train_features,
    train_label_sets,
    validation_features,
    validation_label_sets,
)