In [1]:
import tensorflow as tf

import keras
from keras.losses import cosine_similarity

from models import TextFeatureExtractorLayer, GloveEmbeddingLayer
from utils import GoogleRestaurantsReviewDataset

## Load the Training and Test Datasets

In [2]:
max_seq_length = 500  # passed to the dataset; the model cell reads it back as dataset.max_seq_length
dataset = GoogleRestaurantsReviewDataset(max_seq_length=max_seq_length)
# The model cell sizes its embedding layer from this vectorizer's vocabulary.
text_vectorize = dataset.text_vectorize


def _load_split(split_name, train):
    """Load one split via ``dataset.load_train_or_test_dataset`` and report its size.

    Args:
        split_name: Label used in the progress messages ('training' or 'test').
        train: Forwarded to ``load_train_or_test_dataset``; True loads the
            training split, False the test split.

    Returns:
        ``(X_user, X_bus, y)`` exactly as returned by the dataset.

    Raises:
        AssertionError: if the three returned arrays differ in length.
    """
    print(f'Load {split_name} data')
    X_user, X_bus, y = dataset.load_train_or_test_dataset(train=train)
    # The predict cell pairs row i of X_user, X_bus and y, so they must line up.
    assert len(X_user) == len(X_bus) == len(y), f'{split_name} arrays have mismatched lengths'
    # Report the number of examples, not the shape tuple.
    print(f'Total {len(y)} {split_name} data\n')
    return X_user, X_bus, y


train_X_user, train_X_bus, train_y = _load_split('training', train=True)
test_X_user, test_X_bus, test_y = _load_split('test', train=False)

Build Training data


78422it [01:50, 711.91it/s]


Build Testing data


19606it [00:15, 1302.48it/s]


Load training data
Total (78422,) training data

Load test data
Total (19606,) test data



## Build the Model

In [6]:
print('Build model')
# One embedding layer is called on both inputs, so user and business texts
# share a single token-embedding table.
embedding = GloveEmbeddingLayer(num_tokens=len(text_vectorize),
                                vocabulary_dict=text_vectorize.vocabulary)


def _build_text_tower(shared_embedding):
    """Build one text tower: token-id input -> shared embedding -> feature extractor.

    Each call constructs a fresh ``TextFeatureExtractorLayer`` (output_dim=64),
    so the user and business towers share only the embedding, not the extractor.

    Returns:
        ``(inputs, features)``: the symbolic Keras input and the tower's output.
    """
    token_ids = keras.Input(shape=(None,), dtype="int64")
    embedded = shared_embedding(token_ids)
    extractor = TextFeatureExtractorLayer(
        input_dim=(dataset.max_seq_length, shared_embedding.embed_dim),
        output_dim=64)
    return token_ids, extractor(embedded)


# User tower first, then business tower, which keeps the layer creation order fixed.
user_inputs, user_outputs = _build_text_tower(embedding)
bus_inputs, bus_outputs = _build_text_tower(embedding)

# keras.losses.cosine_similarity returns the *negated* cosine similarity, so the
# leading minus yields the plain cosine in [-1, 1], which is regressed onto the labels.
# NOTE(review): every prediction in the Predict output is >= 0 while labels are +/-1;
# if the extractor ends in a non-negative activation (e.g. ReLU) the cosine can
# never go negative. TODO confirm in `models`.
outputs = -cosine_similarity(user_outputs, bus_outputs, axis=1)

model = keras.Model([user_inputs, bus_inputs], outputs)
model.compile(optimizer='rmsprop', loss='mse')  # MSE between cosine score and label
model.summary()

Build model
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 input_7 (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 glove_embedding_layer_1 (Glove  (None, None, 300)   10695900    ['input_5[0][0]',                
 EmbeddingLayer)                                                  'input_7[0][0]']                
                                                                                                  
 text_feature_extractor_layer_2  (None, 64)          2890012     ['glove_embeddi

### Load the Trained Weights

In [7]:
# Restore the weights written by the training run into the freshly built model.
# For a TF-format checkpoint Keras returns a CheckpointLoadStatus (displayed below);
# its methods, e.g. `.assert_existing_objects_matched()`, can verify the restore.
checkpoint_path = "./training_2/cp.ckpt"
load_status = model.load_weights(checkpoint_path)
load_status

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x3205b0eb0>

## Predict on a Sample of the Test Set

In [10]:
num_samples = 100  # number of test pairs to spot-check
# Slicing already clamps to the test-set size when it has fewer rows.
predicted = model.predict([test_X_user[:num_samples], test_X_bus[:num_samples]])
# Zipping over the two clamped sequences keeps the loop in step with them.
# Indexing up to `num_samples` would raise IndexError on a smaller test set.
# The model emits one cosine score per pair; reshape(-1) drops any trailing
# singleton axis so each item formats as a scalar.
for pred_score, actual_label in zip(predicted.reshape(-1), test_y[:num_samples]):
    print('predicted: %.4f, actual: %.4f' % (pred_score, actual_label))

predicted: 0.0000, actual: -1.0000
predicted: 0.6140, actual: 1.0000
predicted: 0.0139, actual: -1.0000
predicted: 0.2976, actual: 1.0000
predicted: 0.0000, actual: -1.0000
predicted: 0.3293, actual: -1.0000
predicted: 0.6051, actual: 1.0000
predicted: 0.5342, actual: 1.0000
predicted: 0.2366, actual: -1.0000
predicted: 0.0000, actual: 1.0000
predicted: 0.3357, actual: -1.0000
predicted: 0.0000, actual: -1.0000
predicted: 0.0000, actual: -1.0000
predicted: 0.0000, actual: 1.0000
predicted: 0.7817, actual: 1.0000
predicted: 0.7039, actual: 1.0000
predicted: 0.1982, actual: 1.0000
predicted: 0.0000, actual: 1.0000
predicted: 0.3506, actual: -1.0000
predicted: 0.0000, actual: 1.0000
predicted: 0.6307, actual: -1.0000
predicted: 0.2608, actual: -1.0000
predicted: 0.5730, actual: -1.0000
predicted: 0.0000, actual: -1.0000
predicted: 0.7568, actual: 1.0000
predicted: 0.7010, actual: 1.0000
predicted: 0.8220, actual: -1.0000
predicted: 0.0744, actual: 1.0000
predicted: 0.2582, actual: -1.0000