In [1]:
import tensorflow as tf

import keras
from keras.losses import cosine_similarity

from models import TextFeatureExtractorLayer, GloveEmbeddingLayer
from utils import GoogleRestaurantsReviewDataset

## Load the Train and Test Dataset

In [2]:
max_seq_length = 500
dataset = GoogleRestaurantsReviewDataset(max_seq_length=max_seq_length)
text_vectorize = dataset.text_vectorize

print('Load training data')
train_X_user, train_X_bus, train_y = dataset.load_train_or_test_dataset(train=True)
print(f'Total {train_y.shape} training data\n')

print('Load test data')
test_X_user, test_X_bus, test_y = dataset.load_train_or_test_dataset(train=False)
print(f'Total {test_y.shape} test data\n')

Build Training data


78422it [01:49, 719.07it/s]


Build Testing data


19606it [00:14, 1336.75it/s]


Load training data
Total (78422,) training data

Load test data
Total (19606,) test data



## Build the Model

In [3]:
print('Build model')
embedding = GloveEmbeddingLayer(num_tokens=len(text_vectorize),
                                vocabulary_dict=text_vectorize.vocabulary)

user_inputs = keras.Input(shape=(None,), dtype="int64")
x = embedding(user_inputs)
user_outputs = TextFeatureExtractorLayer(
    input_dim=(dataset.max_seq_length, embedding.embed_dim), output_dim=64)(x)

bus_inputs = keras.Input(shape=(None,), dtype="int64")
y = embedding(bus_inputs)
bus_outputs = TextFeatureExtractorLayer(
    input_dim=(dataset.max_seq_length, embedding.embed_dim), output_dim=64)(y)

outputs = -cosine_similarity(user_outputs, bus_outputs, axis=1)

model = keras.Model([user_inputs, bus_inputs], outputs)
model.compile(optimizer='rmsprop', loss='mse')
model.summary()

Build model
Metal device set to: Apple M1 Max


2023-04-23 17:42:45.565149: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-04-23 17:42:45.565714: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 input_3 (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 glove_embedding_layer (GloveEm  (None, None, 300)   10695900    ['input_1[0][0]',                
 beddingLayer)                                                    'input_3[0][0]']                
                                                                                                  
 text_feature_extractor_layer (  (None, 64)          2890012     ['glove_embedding_layer[0][0]

### Load Trained Params

In [4]:
checkpoint_path = "./training_4/cp.ckpt"
model.load_weights(checkpoint_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x318ebdf70>

## Predict

In [14]:
num_samples = 100
predicted = model.predict([test_X_user[:num_samples], test_X_bus[:num_samples]])
a = []
for i in range(num_samples):
    user_id, bus_id, rating = dataset.test_tuples[i]
    print('user_id: %s rest_id: %s similarity score: %.4f, actual rating: %.4f' % (user_id, bus_id, predicted[i], float(rating)))
    a.append([user_id, bus_id, predicted[i], float(rating)])

user_id: 114871201851697215045 rest_id: 60564a793019cb0a47838caf similarity score: 0.5313, actual rating: 4.0000
user_id: 116853678991849278441 rest_id: 6041e7357dfa7f1871839538 similarity score: 0.7326, actual rating: 5.0000
user_id: 113815392849087130851 rest_id: 60486676b1a0aaee3eef9e79 similarity score: 0.4859, actual rating: 3.0000
user_id: 103560483922364315683 rest_id: 60509d929c93e55e75b720f4 similarity score: 0.7321, actual rating: 5.0000
user_id: 116677519600987213388 rest_id: 604ccf013ada919c27677cce similarity score: 0.0814, actual rating: 3.0000
user_id: 110793767867665675464 rest_id: 6043a625b81264dfa846c8a4 similarity score: 0.7781, actual rating: 4.0000
user_id: 106439951675336292661 rest_id: 6050dc0888c7af3f893e6ea9 similarity score: 0.7783, actual rating: 5.0000
user_id: 109973197470710088921 rest_id: 60444e468be5d4454df9f0c9 similarity score: 0.8087, actual rating: 5.0000
user_id: 115609769405789742559 rest_id: 6041efd68be5d4454df985d8 similarity score: 0.6505, actua

In [15]:
import pandas as pd

df = pd.DataFrame(a, columns=['user id', 'restaurant id', 'similarity score', 'actual rating'])

In [17]:
df

Unnamed: 0,user id,restaurant id,similarity score,actual rating
0,114871201851697215045,60564a793019cb0a47838caf,0.531320,4.0
1,116853678991849278441,6041e7357dfa7f1871839538,0.732573,5.0
2,113815392849087130851,60486676b1a0aaee3eef9e79,0.485863,3.0
3,103560483922364315683,60509d929c93e55e75b720f4,0.732054,5.0
4,116677519600987213388,604ccf013ada919c27677cce,0.081354,3.0
...,...,...,...,...
95,112338591386446588083,604c76655a9e6adec8bf8ad2,0.592151,5.0
96,111363279650276794837,6050b1765b4ccec8d5cae7c9,0.666596,5.0
97,101334099430664308606,60443169ad733fba1bcfed1f,0.399905,4.0
98,107523156830941868214,60509469d8c08f462b93de5c,0.619996,5.0


# Evaluation

## Evaluating on Training Data

In [6]:
from tqdm import tqdm
import numpy as np

In [7]:
model.predict([train_X_user[0:1], train_X_bus[0:1]])[0]



0.35495102

In [8]:
predict = model.predict([np.vstack([[train_X_user[3]]*10000]), train_X_bus[:10000]])



In [9]:
np.sort(predict)

array([0.7693964, 0.7693964, 0.7693964, ..., 0.7693964, 0.7693964,
       0.7693964], dtype=float32)

In [1]:
%load_ext tensorboard

In [4]:
%tensorboard --logdir logs/

Reusing TensorBoard on port 6007 (pid 77382), started 0:09:17 ago. (Use '!kill 77382' to kill it.)