# SESSIONS ARE ALL YOU NEED
### Workshop on e-commerce personalization

This notebook showcases, with working code, the main ideas of our ML-in-retail workshop at MICES (https://mices.co/) on June 1st, 2021. Please refer to the README in the repo for some context!

While the code below is (well, should be!) fully functional, please note that we favor pedagogically clear functions over terse code: it should be fairly easy to take these ideas and refactor the code for more speed, better reusability, etc.

_If you want to use Google Colab, you can uncomment this cell:_

In [None]:
# if you need requirements....
# !pip install -r requirements.txt

# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
#%cd drive/MyDrive/path_to_directory_containing_train_folder
#LOCAL_FOLDER = 'train'

## Basic imports and some global vars to know where the data is!

Here we import the libraries we need and set the working folders - make sure your current python interpreter has all the dependencies installed. If you want to use the same real-world data as I'm using, please download the open dataset you find at: https://github.com/coveooss/SIGIR-ecom-data-challenge.

In [None]:
import os
from random import choice
import time
import ast
import json
import numpy as np
import csv
from collections import Counter,defaultdict
# viz stuff
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from IPython.display import Image 
# gensim stuff for prod2vec
import gensim  # gensim >= 4.0 (the code uses vector_size, epochs and index_to_key)
from gensim.similarities.annoy import AnnoyIndexer
# keras stuff for auto-encoder
from keras.layers import Dropout
from keras.layers import Dense
from keras.layers import Concatenate
from keras.models import Sequential
from keras.layers import Input
from keras.optimizers import SGD, Adam
from keras.models import Model
from keras.callbacks import EarlyStopping
from keras.utils import plot_model
from sklearn.model_selection import train_test_split
from keras import utils
import hashlib
from copy import deepcopy

In [None]:
%matplotlib inline

In [None]:
LOCAL_FOLDER = '/Users/jacopotagliabue/Documents/data_dump/train'  # change this to the folder where the dataset is stored
N_ROWS = 5000000  # max number of rows to read (keeps run times short for the tutorial)

## Step 1: build a prod2vec space

For more information on prod2vec and its use, you can also check our blog post: https://blog.coveo.com/clothes-in-space-real-time-personalization-in-less-than-100-lines-of-code/ or our latest NLP paper: https://arxiv.org/abs/2104.02061
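Conceptually, prod2vec treats each browsing session as a sentence and each SKU as a word, so word2vec's CBOW objective learns to predict a product from the products viewed around it. The toy sketch below (plain Python; `cbow_pairs` is our own hypothetical helper) shows which (context, target) pairs a given window produces. It is not gensim's implementation, which also subsamples frequent items and randomly shrinks the window at each position.

```python
from typing import List, Tuple


def cbow_pairs(session: List[str], window: int = 2) -> List[Tuple[List[str], str]]:
    """Return (context, target) pairs for every position in a session.

    Toy illustration only: the context is every item at most `window`
    positions away from the target, on both sides.
    """
    pairs = []
    for i, target in enumerate(session):
        left = session[max(0, i - window):i]
        right = session[i + 1:i + 1 + window]
        pairs.append((left + right, target))
    return pairs


# a session plays the role of a "sentence", each SKU the role of a "word"
session = ["sku_a", "sku_b", "sku_c", "sku_d"]
for context, target in cbow_pairs(session, window=1):
    print(target, "<-", context)
```

Products that keep appearing in similar contexts end up with similar vectors, which is what the kNN queries later in the notebook exploit.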

In [None]:
def read_sessions_from_training_file(training_file: str, K: int = None):
    """
    Read the training file containing product interactions, up to K rows.
    
    :return: a list of lists, each list being a session (sequence of product IDs)
    """
    user_sessions = []
    current_session_id = None
    current_session = []
    with open(training_file) as csvfile:
        reader = csv.DictReader(csvfile)
        for idx, row in enumerate(reader):
            # if a max number of rows K is specified, stop there and keep what we have
            if K and idx >= K:
                break
            # just append "detail" events in the order we see them
            # row will contain: session_id_hash, product_action, product_sku_hash
            _session_id_hash = row['session_id_hash']
            # when a new session begins, store the old one and start again
            if current_session_id and current_session and _session_id_hash != current_session_id:
                user_sessions.append(current_session)
                # reset session
                current_session = []
            # check for the right type and append
            if row['product_action'] == 'detail':
                current_session.append(row['product_sku_hash'])
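            # update the current session id
            current_session_id = _session_id_hash
    # the loop never closes the last session, so store it here
    if current_session:
        user_sessions.append(current_session)
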
    # print how many sessions we have...
    print("# total sessions: {}".format(len(user_sessions)))
    # print the first one to check (the asserts below are specific to the Coveo SIGIR dataset)
    print("First session is: {}".format(user_sessions[0]))
    assert user_sessions[0][0] == 'd5157f8bc52965390fa21ad5842a8502bc3eb8b0930f3f8eafbc503f4012f69c'
    assert user_sessions[0][-1] == '63b567f4cef976d1411aecc4240984e46ebe8e08e327f2be786beb7ee83216d0'

    return user_sessions
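`read_sessions_from_training_file` relies on all rows of a session being contiguous in the file. Under that same assumption, the grouping can be written more compactly with `itertools.groupby`, which also emits the last session in the file without extra bookkeeping. A sketch on a few hypothetical in-memory rows with the same columns (`sessions_from_rows` is our own helper name):

```python
import csv
import io
from itertools import groupby

# hypothetical toy rows, using the same column names as browsing_train.csv
TOY_CSV = """session_id_hash,product_action,product_sku_hash
s1,detail,p1
s1,add,p1
s1,detail,p2
s2,pageview,
s3,detail,p3
s3,detail,p1
"""


def sessions_from_rows(rows, action="detail"):
    """Group contiguous rows by session id and keep only `action` events, in order."""
    sessions = []
    for _, events in groupby(rows, key=lambda r: r["session_id_hash"]):
        skus = [e["product_sku_hash"] for e in events if e["product_action"] == action]
        if skus:  # drop sessions without any matching event
            sessions.append(skus)
    return sessions


sessions = sessions_from_rows(csv.DictReader(io.StringIO(TOY_CSV)))
print(sessions)  # [['p1', 'p2'], ['p3', 'p1']]
```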

In [None]:
def train_product_2_vec_model(sessions: list,
                              min_c: int = 3,
                              size: int = 48,
                              window: int = 5,
                              iterations: int = 15,
                              ns_exponent: float = 0.75):
    """
    Train CBOW to get product embeddings. We start with sensible defaults from the literature - please
    check https://arxiv.org/abs/2007.14906 for practical tips on how to optimize prod2vec.

    :param sessions: list of lists, as user sessions are list of interactions
    :param min_c: minimum number of occurrences for a product to get an embedding
    :param size: embedding dimension
    :param window: maximum distance between the current and the predicted product within a session
    :param iterations: number of training epochs
    :param ns_exponent: exponent used to shape the negative sampling distribution
    :return: trained product embedding model
    """
    model =  gensim.models.Word2Vec(sentences=sessions,
                                    min_count=min_c,
                                    vector_size=size,
                                    window=window,
                                    epochs=iterations,
                                    ns_exponent=ns_exponent)

    print("# products in the space: {}".format(len(model.wv.index_to_key)))

    return model.wv

Get sessions from the training file and train a prod2vec model with standard hyperparameters.

In [None]:
# get sessions
sessions = read_sessions_from_training_file(
    training_file=os.path.join(LOCAL_FOLDER, 'browsing_train.csv'),
    K=N_ROWS)
# get a counter on all items for later use
sku_cnt = Counter([item for s in sessions for item in s])
# print out most common SKUs
sku_cnt.most_common(3)

In [None]:
# leave some sessions aside
idx = int(len(sessions) * 0.8)
train_sessions = sessions[0: idx]
test_sessions = sessions[idx:]
print("Train sessions # {}, test sessions # {}".format(len(train_sessions), len(test_sessions)))
# finally, train prod2vec with the default hyperparameters
prod2vec_model = train_product_2_vec_model(train_sessions)

Show how to get predictions with k-nearest neighbours (kNN): the 3 products closest to the most popular SKU.

In [None]:
prod2vec_model.similar_by_word(sku_cnt.most_common(1)[0][0], topn=3)
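`similar_by_word` returns the `topn` keys whose vectors have the highest cosine similarity with the query SKU's vector, excluding the query itself. A brute-force, plain-Python sketch of that ranking on hypothetical 2-d toy vectors (gensim computes the same scores with vectorized operations on unit-normalized vectors):

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def knn(query_key, vectors, topn=3):
    """Rank all other keys by cosine similarity to `query_key` (exhaustive search)."""
    q = vectors[query_key]
    scores = [(k, cosine(q, v)) for k, v in vectors.items() if k != query_key]
    return sorted(scores, key=lambda kv: kv[1], reverse=True)[:topn]


# hypothetical toy "embeddings"
toy = {"shoe_1": [1.0, 0.1], "shoe_2": [0.9, 0.2], "shirt": [0.1, 1.0], "hat": [-1.0, 0.0]}
print(knn("shoe_1", toy, topn=2))
```

The exhaustive scan touches every vector for every query, which is exactly the cost that approximate indexes such as Annoy try to avoid.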

Visualize the prod2vec space, color-coding for categories in the catalog

In [None]:
def plot_scatter_by_category_with_lookup(title, 
                                         skus, 
                                         sku_to_target_cat,
                                         results, 
                                         custom_markers=None):
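    groups = {}
    # map each SKU to its row in results once (skus.index would scan the list for every SKU)
    sku_to_idx = {sku: idx for idx, sku in enumerate(skus)}
    for sku, target_cat in sku_to_target_cat.items():
        if sku not in sku_to_idx:
            continue

        sku_idx = sku_to_idx[sku]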
        x = results[sku_idx][0]
        y = results[sku_idx][1]
        if target_cat in groups:
            groups[target_cat]['x'].append(x)
            groups[target_cat]['y'].append(y)
        else:
            groups[target_cat] = {
                'x': [x], 'y': [y]
                }
    # DEBUG print
    print("Total of # groups: {}".format(len(groups)))
    
    fig, ax = plt.subplots(figsize=(10, 10))
    for group, data in groups.items():
        ax.scatter(data['x'], data['y'], 
                   alpha=0.3, 
                   edgecolors='none', 
                   s=25, 
                   marker='o' if not custom_markers else custom_markers,
                   label=group)

    plt.title(title)
    plt.show()
    
    return

In [None]:
def tsne_analysis(embeddings, perplexity=25, n_iter=1000):
    tsne = TSNE(n_components=2, verbose=1, perplexity=perplexity, n_iter=n_iter)
    return tsne.fit_transform(embeddings)

In [None]:
def get_sku_to_category_map(catalog_file, depth_index=1):
    """
    For each SKU, get category from catalog file (if specified)
    
    :return: dictionary, mapping SKU to a category
    """
    sku_to_cats = dict()
    with open(catalog_file) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            _sku = row['product_sku_hash']
            category_hash = row['category_hash']
            if not category_hash:
                continue
            # pick only category at a certain depth in the tree
            # e.g. x/xx/xxx, with depth=1, -> xx
            branches = category_hash.split('/')
            target_branch = branches[depth_index] if depth_index < len(branches) else None
            if not target_branch:
                continue
            # if all good, store the mapping
            sku_to_cats[_sku] = target_branch
            
    return sku_to_cats

In [None]:
sku_to_category = get_sku_to_category_map(os.path.join(LOCAL_FOLDER, 'sku_to_content.csv'))
print("Total of # {} categories".format(len(set(sku_to_category.values()))))
print("Total of # {} SKU with a category".format(len(sku_to_category)))
# debug with a sample SKU
print(sku_to_category[sku_cnt.most_common(1)[0][0]])
skus = prod2vec_model.index_to_key
print("Total of # {} skus in the model".format(len(skus)))
embeddings = [prod2vec_model[s] for s in skus]

In [None]:
# print out tsne plot with standard params
tsne_results = tsne_analysis(embeddings)
assert len(tsne_results) == len(skus)
plot_scatter_by_category_with_lookup('Prod2vec', skus, sku_to_category, tsne_results)

In [None]:
# do a version with only top K categories
TOP_K = 5
cnt_categories = Counter(list(sku_to_category.values()))
top_categories = [c[0] for c in cnt_categories.most_common(TOP_K)]

In [None]:
# filter out SKUs outside of top categories
top_skus = []
top_tsne_results = []
for _s, _t in zip(skus, tsne_results):
    if sku_to_category.get(_s, None) not in top_categories:
        continue
    top_skus.append(_s)
    top_tsne_results.append(_t)
# re-plot tsne with filtered SKUs
print("Top SKUs # {}".format(len(top_skus)))
plot_scatter_by_category_with_lookup('Prod2vec (top {})'.format(TOP_K), 
                                     top_skus, sku_to_category, top_tsne_results)

### Bonus: faster inference

Gensim is awesome and supports approximate, faster nearest-neighbour search! You need to install Annoy first, e.g. "pip install annoy". Here we re-run gensim's original word2vec benchmark on our product space!

See: https://radimrehurek.com/gensim/auto_examples/tutorials/run_annoy.html

In [None]:
# build an Annoy index with 100 trees over the prod2vec vectors, and pick a test SKU
annoy_index = AnnoyIndexer(prod2vec_model, 100)
test_sku = sku_cnt.most_common(1)[0][0]
# test all is good
print(prod2vec_model.most_similar([test_sku], topn=2, indexer=annoy_index))
print(prod2vec_model.most_similar([test_sku], topn=2))

In [None]:
def avg_query_time(model, annoy_index=None, queries=5000):
    """Average query time of a most_similar method over random queries."""
    total_time = 0
    for _ in range(queries):
        _v = model[choice(model.index_to_key)]
        start_time = time.process_time()
        model.most_similar([_v], topn=5, indexer=annoy_index)
        total_time += time.process_time() - start_time
        
    return total_time / queries

gensim_time = avg_query_time(prod2vec_model)
annoy_time = avg_query_time(prod2vec_model, annoy_index=annoy_index)
print("Gensim (s/query):\t{0:.5f}".format(gensim_time))
print("Annoy (s/query):\t{0:.5f}".format(annoy_time))
speed_improvement = gensim_time / annoy_time
print ("\nAnnoy is {0:.2f} times faster on average on this particular run".format(speed_improvement))
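The benchmark above only measures speed, while approximate search trades some accuracy for it. A simple way to quantify the loss is recall@k: the share of exact neighbours that the approximate search also returns. A minimal sketch (`recall_at_k` is our own helper) working on the `(key, score)` lists that `most_similar` returns, shown here on hypothetical outputs:

```python
def recall_at_k(exact, approx):
    """Fraction of the exact top-k neighbours that the approximate search also returned."""
    exact_keys = {key for key, _ in exact}
    approx_keys = {key for key, _ in approx}
    return len(exact_keys & approx_keys) / len(exact_keys) if exact_keys else 0.0


# hypothetical outputs of most_similar(...) without and with the Annoy indexer
exact = [("p1", 0.91), ("p2", 0.88), ("p3", 0.80), ("p4", 0.79)]
approx = [("p1", 0.91), ("p3", 0.80), ("p9", 0.75), ("p2", 0.88)]
print(recall_at_k(exact, approx))  # 0.75
```

With the objects above, one could average `recall_at_k(prod2vec_model.most_similar([_v], topn=5), prod2vec_model.most_similar([_v], topn=5, indexer=annoy_index))` over the same kind of random query vectors used for timing; more trees in `AnnoyIndexer` generally raise recall at the cost of a larger index.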

### Bonus: hyperparameter tuning

For more info on hyperparameter tuning in the context of product embeddings, please see our paper: https://arxiv.org/abs/2007.14906 and our data release: https://github.com/coveooss/fantastic-embeddings-sigir-2020.

We use the sessions we left out to simulate a small optimization loop...

In [None]:
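def calculate_HR_on_NEP(model, sessions, k=10, min_length=3):
    """
    Hit rate at k on Next Event Prediction (NEP): for each session with at least
    min_length items, check whether the last item is among the k nearest neighbours
    of the second-to-last one.
    """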
    _count = 0
    _hits = 0
    for session in sessions:
        # consider only decently-long sessions
        if len(session) < min_length:
            continue
        # update the counter
        _count += 1
        # get the item to predict
        target_item = session[-1]
        # get model prediction using before-last item
        query_item = session[-2]
        # if model cannot make the prediction, it's a failure
        if query_item not in model:
            continue
        predictions = model.similar_by_word(query_item, topn=k)
        # debug
        # print(target_item, query_item, predictions)
        if target_item in [p[0] for p in predictions]:
            _hits += 1
    # debug
    print("Total test cases: {}".format(_count))
    
    return _hits / _count
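To make the counting rule concrete: only sessions with at least `min_length` items are test cases, the second-to-last item is the query, and a query the model cannot answer counts as a miss. A plain-Python sketch on toy data, where a hypothetical dictionary of precomputed neighbours stands in for `model.similar_by_word`:

```python
def hit_rate_next_event(neighbours, sessions, k=10, min_length=3):
    """Hit rate at k for predicting session[-1] from session[-2].

    `neighbours` maps a SKU to its ranked list of similar SKUs; queries
    missing from it count as misses, as in calculate_HR_on_NEP.
    """
    count, hits = 0, 0
    for session in sessions:
        if len(session) < min_length:
            continue
        count += 1
        query, target = session[-2], session[-1]
        if target in neighbours.get(query, [])[:k]:
            hits += 1
    return hits / count if count else 0.0


toy_neighbours = {"a": ["b", "c"], "b": ["c", "a"], "c": ["a"]}
toy_sessions = [
    ["x", "a", "b"],  # query "a", target "b": hit
    ["x", "b", "d"],  # query "b", target "d": miss
    ["x", "z", "c"],  # query "z" unknown: miss
    ["a", "b"],       # too short: skipped
]
print(hit_rate_next_event(toy_neighbours, toy_sessions, k=2))  # 1 hit over 3 test cases
```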

In [None]:
# we simulate a test with 2 values for epochs in prod2vec
iterations_values = [1, 10]
# for each value we train a model, and use Next Event Prediction (NEP) to get a quality assessment
for i in iterations_values:
    print("\n ======> Hyper value: {}".format(i))
    cnt_model = train_product_2_vec_model(train_sessions, iterations=i)
    # use hold-out to have NEP performance
    _hr = calculate_HR_on_NEP(cnt_model, test_sessions)
    print("HR: {}\n".format(_hr))
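The loop above varies a single hyperparameter. To explore several at once (e.g. `window`, `ns_exponent`, `iterations`), one could enumerate a small grid. A generic sketch, where `grid_search` is our own hypothetical helper and the training and evaluation functions are passed in:

```python
from itertools import product


def grid_search(train_fn, eval_fn, grid):
    """Train and evaluate one model per combination in `grid`; return (score, params) best-first."""
    names = sorted(grid)
    results = []
    for values in product(*(grid[n] for n in names)):
        params = dict(zip(names, values))
        score = eval_fn(train_fn(**params))
        results.append((score, params))
    return sorted(results, key=lambda r: r[0], reverse=True)


# In the notebook this could be called as, e.g.:
# grid_search(lambda **p: train_product_2_vec_model(train_sessions, **p),
#             lambda m: calculate_HR_on_NEP(m, test_sessions),
#             {"iterations": [1, 10], "window": [3, 5]})

# toy check: the "model" is just the sum of the parameters
toy = grid_search(lambda a, b: a + b, lambda m: m, {"a": [1, 2], "b": [10, 20]})
print(toy[0])  # (22, {'a': 2, 'b': 20})
```

Note that picking the best configuration on the same held-out sessions used to report results gives an optimistic estimate; a separate validation split avoids this.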

## Step 2: improving low-count vectors

For more information about prod2vec in the cold start scenario, please see our paper: https://dl.acm.org/doi/10.1145/3383313.3411477 and video: https://vimeo.com/455641121

In [None]:
def build_mapper(prod2vec_dims=48):
    """
    Build a Keras model for content-based "fake" embeddings.
    
    :param prod2vec_dims: size of the prod2vec vectors to predict (48, as in train_product_2_vec_model)
    :return: a Keras model, mapping BERT-like catalog representations to the prod2vec space
    """
    # inputs: 50-d text and image vectors, matching the catalog vectors
    description_input = Input(shape=(50,))
    image_input = Input(shape=(50,))
    # model
    x = Dense(25, activation="relu")(description_input)
    y = Dense(25, activation="relu")(image_input)
    combined = Concatenate()([x, y])
    combined = Dropout(0.3)(combined)
    combined = Dense(25)(combined)
    output = Dense(prod2vec_dims)(combined)

    return Model(inputs=[description_input, image_input], outputs=output)
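As a sanity check on the architecture, the number of trainable weights can be computed by hand: a `Dense` layer with n inputs and m units has n·m weights plus m biases, while `Concatenate` and `Dropout` add none. For `build_mapper` with its defaults (50-d inputs, 48-d output), the total below should match what `rare_net.summary()` reports:

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


branch = dense_params(50, 25)        # one Dense(25) per input branch
merged = dense_params(25 + 25, 25)   # Dense(25) after concatenating the two branches
output = dense_params(25, 48)        # final projection into the prod2vec space
total = 2 * branch + merged + output
print(total)  # 5073
```

With a few thousand parameters, the mapper stays small compared to the number of popular SKUs it is trained on, which is one reason such a shallow network is a reasonable starting point.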

In [None]:
# get vectors representing text and images in the catalog
def get_sku_to_embeddings_map(catalog_file):
    """
    For each SKU, get the text and image embeddings, as provided pre-computed by the dataset
    
    :return: dictionary, mapping SKU to a tuple of embeddings
    """
    sku_to_embeddings = dict()
    with open(catalog_file) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            _sku = row['product_sku_hash']
            _description = row['description_vector']
            _image = row['image_vector']
            # skip when either vector is missing
            if not _description or not _image:
                continue
            # if all good, store the mapping
            sku_to_embeddings[_sku] = (json.loads(_description), json.loads(_image))
            
    return sku_to_embeddings

In [None]:
sku_to_embeddings = get_sku_to_embeddings_map(os.path.join(LOCAL_FOLDER, 'sku_to_content.csv'))
print("Total of # {} SKUs with embeddings".format(len(sku_to_embeddings)))
# print out an example
_d, _i = sku_to_embeddings['438630a8ba0320de5235ee1bedf3103391d4069646d640602df447e1042a61a3']
print(len(_d), len(_i), _d[:5], _i[:5])

In [None]:
# just make sure we have the SKUs in the model and a counter
skus = prod2vec_model.index_to_key
print("Total of # {} skus in the model".format(len(skus)))
print(sku_cnt.most_common(5))

In [None]:
# frequency percentile above which a SKU is considered popular enough to be in the training set
FREQUENT_PRODUCTS_PTILE = 80

In [None]:
_counts = [c[1] for c in sku_cnt.most_common()]
_counts[:3]

In [None]:
# keep only popular SKUs that are in the prod2vec space and have content embeddings
popular_threshold = np.percentile(_counts, FREQUENT_PRODUCTS_PTILE)
popular_skus = [s for s in skus if s in sku_to_embeddings and sku_cnt.get(s, 0) > popular_threshold]
product_embeddings = [prod2vec_model[s] for s in popular_skus]
description_embeddings = [sku_to_embeddings[s][0] for s in popular_skus]
image_embeddings = [sku_to_embeddings[s][1] for s in popular_skus]
# debug
print(popular_threshold, len(skus), len(popular_skus))
# print(description_embeddings[:1][:3])
# print(image_embeddings[:1][:3])
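`np.percentile` interpolates linearly between sorted values by default, so with long-tailed counts the 80th-percentile threshold can be low in absolute terms while still excluding most SKUs. A plain-Python sketch of the same computation on hypothetical counts (`percentile_linear` is our own helper, written to follow that default method), showing which SKUs clear the strict `>` test used above:

```python
def percentile_linear(values, p):
    """p-th percentile with linear interpolation between the closest sorted values."""
    xs = sorted(values)
    rank = (p / 100) * (len(xs) - 1)
    lo = int(rank)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (rank - lo) * (xs[hi] - xs[lo])


# hypothetical long-tailed view counts, one per SKU
counts = [1, 1, 1, 2, 2, 3, 4, 8, 40, 120]
threshold = percentile_linear(counts, 80)  # interpolates between 8 and 40, about 14.4
popular = [c for c in counts if c > threshold]
print(threshold, popular)
```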

In [None]:
# train the mapper now
training_data_X = [np.array(description_embeddings), np.array(image_embeddings)]
training_data_y = np.array(product_embeddings)

In [None]:
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20, restore_best_weights=True)
# build and display the model (plot_model needs pydot and graphviz installed)
rare_net = build_mapper()
plot_model(rare_net, show_shapes=True, show_layer_names=True, to_file='rare_net.png')
Image('rare_net.png')

In [None]:
# train!
rare_net.compile(loss='mse', optimizer='rmsprop')
rare_net.fit(training_data_X, 
             training_data_y, 
             batch_size=200, 
             epochs=20000, 
             validation_split=0.2, 
             callbacks=[es])

In [None]:
# rarest_skus = [_[0] for _ in sku_cnt.most_common()[-500:]]
# test_skus = [s for s in rarest_skus if s in sku_to_embeddings]

# rare SKUs: in the model, with content embeddings, and seen less than half as often as the popularity threshold
test_skus = [s for s in skus if s in sku_to_embeddings and sku_cnt.get(s, 0) < popular_threshold/2]
print(len(skus), len(test_skus))
# prepare embeddings for prediction
rare_description_embeddings = [sku_to_embeddings[s][0] for s in test_skus]
rare_image_embeddings = [sku_to_embeddings[s][1] for s in test_skus]

In [None]:
# prepare embeddings for prediction
test_data_X = [np.array(rare_description_embeddings), np.array(rare_image_embeddings)]
predicted_embeddings = rare_net.predict(test_data_X)
# debug
# print(len(predicted_embeddings))
# print(predicted_embeddings[0][:10])

In [None]:
def calculate_HR_on_NEP_rare(model, sessions, rare_skus, k=10, min_length=3):
    rare_skus = set(rare_skus)  # fast membership checks
    _count = 0
    _hits = 0
    _rare_hits = 0
    _rare_count = 0
    for session in sessions:
        # consider only decently-long sessions
        if len(session) < min_length:
            continue
        # update the counter
        _count += 1
        # get the item to predict
        target_item = session[-1]
        # get model prediction using before-last item
        query_item = session[-2]

        # if model cannot make the prediction, it's a failure
        if query_item not in model:
            continue
        
        # increment counter if rare sku
        if query_item in rare_skus:
            _rare_count+=1
        
        predictions = model.similar_by_word(query_item, topn=k)
    
        # debug
        # print(target_item, query_item, predictions)    
        if target_item in [p[0] for p in predictions]:
            _hits += 1
            # track hits if query is rare sku
            if query_item in rare_skus:
                _rare_hits+=1
    # debug
    print("Total test cases: {}".format(_count))
    print("Total rare test cases: {}".format(_rare_count))
    
    return _hits / _count, _rare_hits/_rare_count

In [None]:
# make copy of original prod2vec model
prod2vec_rare_model = deepcopy(prod2vec_model)
# update model with new vectors
prod2vec_rare_model.add_vectors(test_skus, predicted_embeddings, replace=True)
prod2vec_rare_model.fill_norms(force=True)
# check
assert np.array_equal(predicted_embeddings[0], prod2vec_rare_model[test_skus[0]])

# test new model
calculate_HR_on_NEP_rare(prod2vec_rare_model, test_sessions, test_skus)

In [None]:
# test original model
calculate_HR_on_NEP_rare(prod2vec_model, test_sessions, test_skus)

## Step 3: query scoping

For more information about query scoping, please see our paper: https://www.aclweb.org/anthology/2020.ecnlp-1.2/ and repository: https://github.com/jacopotagliabue/session-path

In [None]:
# build (query vector, category label) pairs from the search logs
def get_query_to_category_dataset(search_file, cat_2_id, sku_to_category):
    """
    For each query, get a label representing the category in items clicked after the query.
    It uses as input a mapping "sku_to_category" to join the search file with catalog meta-data!
    
    :return: two lists, matching query vectors to a label
    """
    query_X = list()
    query_Y = list()
    with open(search_file) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            _click_products = row['clicked_skus_hash']
            if not _click_products:
                continue
            # clean the string and extract SKUs from array
            cleaned_skus = ast.literal_eval(_click_products)
            for s in cleaned_skus: 
                if s in sku_to_category:
                    query_X.append(json.loads(row['query_vector']))
                    target_category_as_int = cat_2_id[sku_to_category[s]]
                    query_Y.append(utils.to_categorical(target_category_as_int, num_classes=len(cat_2_id)))
            
    return query_X, query_Y
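Note that a single query with several clicks yields several training rows, one per clicked SKU found in the catalog mapping, all sharing the same query vector; clicks on different categories therefore produce different labels for identical inputs. A plain-Python sketch of that expansion, with a hand-rolled one-hot encoding in place of `utils.to_categorical` and hypothetical toy data (`query_to_label_rows` is our own helper):

```python
def query_to_label_rows(query_vector, clicked_skus, sku_to_category, cat_2_id):
    """One (features, one-hot label) row per clicked SKU that has a known category."""
    rows = []
    for sku in clicked_skus:
        if sku not in sku_to_category:
            continue  # clicks on uncategorized SKUs are dropped
        label = [0.0] * len(cat_2_id)
        label[cat_2_id[sku_to_category[sku]]] = 1.0
        rows.append((query_vector, label))
    return rows


# hypothetical toy catalog
sku_to_cat = {"p1": "shoes", "p2": "shirts"}
cat_ids = {"shoes": 0, "shirts": 1, "hats": 2}
rows = query_to_label_rows([0.1, 0.3], ["p1", "p2", "p_unknown"], sku_to_cat, cat_ids)
for features, label in rows:
    print(features, label)
```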

In [None]:
sku_to_category = get_sku_to_category_map(os.path.join(LOCAL_FOLDER, 'sku_to_content.csv'))
print("Total of # {} categories".format(len(set(sku_to_category.values()))))
cats = list(set(sku_to_category.values()))
cat_2_id = {c: idx for idx, c in enumerate(cats)}
print(cat_2_id[cats[0]])
query_X, query_Y = get_query_to_category_dataset(os.path.join(LOCAL_FOLDER, 'search_train.csv'), 
                                                 cat_2_id,
                                                 sku_to_category)
print(len(query_X))
print(query_Y[0])

In [None]:
x_train, x_test, y_train, y_test = train_test_split(np.array(query_X), np.array(query_Y), test_size=0.2)

In [None]:
def build_query_scoping_model(input_d, target_classes):
    print('Shape tensor {}, target classes {}'.format(input_d, target_classes))
    # define model
    model = Sequential()
    model.add(Dense(64, activation='relu', input_dim=input_d))
    model.add(Dropout(0.5))
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(target_classes, activation='softmax'))
    
    return model

In [None]:
query_model = build_query_scoping_model(x_train[0].shape[0], y_train[0].shape[0])

In [None]:
# compile model
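from keras.optimizers.schedules import InverseTimeDecay
# recent Keras versions reject or ignore the legacy `lr`/`decay` arguments:
# this schedule reproduces the old behaviour, lr / (1 + decay * step)
sgd = SGD(learning_rate=InverseTimeDecay(0.01, decay_steps=1, decay_rate=1e-6), momentum=0.9, nesterov=True)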
query_model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
# train first
query_model.fit(x_train, y_train, epochs=10, batch_size=32)
# compute and print eval score
score = query_model.evaluate(x_test, y_test, batch_size=32)
score
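`evaluate` returns the test loss and accuracy; to actually scope a query, each softmax output row has to be mapped back to category hashes by inverting `cat_2_id`. A minimal sketch (`top_categories` is our own helper) on a hypothetical prediction row:

```python
def top_categories(probabilities, cat_2_id, k=3):
    """Turn one softmax output row into the k most likely (category, probability) pairs."""
    id_2_cat = {idx: cat for cat, idx in cat_2_id.items()}
    ranked = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(id_2_cat[i], probabilities[i]) for i in ranked[:k]]


# hypothetical output of query_model.predict(...) for a single query
cat_ids = {"shoes": 0, "shirts": 1, "hats": 2}
print(top_categories([0.2, 0.7, 0.1], cat_ids, k=2))  # [('shirts', 0.7), ('shoes', 0.2)]
```

In the notebook, something like `top_categories(query_model.predict(x_test[:1])[0].tolist(), cat_2_id)` would return the most likely category hashes for the first test query.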

In [None]:
# extract query metadata to join searches with browsing sessions
def get_query_info(search_file):
    """
    For each query with at least one click, extract the metadata needed to match it with session data

    :return: list of queries with metadata
    """
    queries = list()
    with open(search_file) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            _click_products = row['clicked_skus_hash']
            if not _click_products:
                continue
            # clean the string and extract SKUs from array
            cleaned_skus = ast.literal_eval(_click_products)
            queries.append({'session_id_hash' : row['session_id_hash'],
                            'server_timestamp_epoch_ms' : int(row['server_timestamp_epoch_ms']),
                            'clicked_skus' : cleaned_skus,
                            'query_vector' : json.loads(row['query_vector'])})
    print("# total queries: {}".format(len(queries)))        
    
    return queries

def get_session_info_for_queries(training_file: str, query_info: list, K: int = None):
    """
    Read the training file containing product interactions for sessions with at least one query, up to K rows.
    
    :return: dict of lists with session_id as key, each list being a session (sequence of product events with metadata) 
    """
    user_sessions = dict()
    current_session_id = None
    current_session = []
    
    query_session_ids = set([ _['session_id_hash'] for _ in query_info])

    with open(training_file) as csvfile:
        reader = csv.DictReader(csvfile)
        for idx, row in enumerate(reader):
            # if a max number of rows K is specified, stop there and keep what we have
            if K and idx >= K:
                break
            # just append "detail" events in the order we see them
            # row will contain: session_id_hash, product_action, product_sku_hash
            _session_id_hash = row['session_id_hash']
            # when a new session begins, store the old one and start again
            if current_session_id and current_session and _session_id_hash != current_session_id:
                user_sessions[current_session_id] = current_session
                # reset session
                current_session = []
            # check for the right type and append event info
            if row['product_action'] == 'detail' and _session_id_hash in query_session_ids :
                current_session.append({'product_sku_hash': row['product_sku_hash'],
                                        'server_timestamp_epoch_ms' : int(row['server_timestamp_epoch_ms'])})
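            # update the current session id
            current_session_id = _session_id_hash
    # the loop never closes the last session, so store it here
    if current_session:
        user_sessions[current_session_id] = current_session

    # print how many sessions we have...
    print("# total sessions: {}".format(len(user_sessions)))

    return user_sessions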

In [None]:
query_info = get_query_info(os.path.join(LOCAL_FOLDER, 'search_train.csv'))
session_info = get_session_info_for_queries(os.path.join(LOCAL_FOLDER, 'browsing_train.csv'), query_info)

In [None]:
def get_contextual_query_to_category_dataset(query_info, session_info, prod2vec_model, cat_2_id, sku_to_category):
    """
    For each query, get a label representing the category in items clicked after the query.
    It uses as input a mapping "sku_to_category" to join the search file with catalog meta-data!
    It also builds the input by concatenating the query vector with the average prod2vec vector
    of the SKUs viewed before the query was made.
    
    :return: two lists, matching query vectors to a label
    """
    query_X = list()
    query_Y = list()
    
    for row in query_info:
        query_timestamp = row['server_timestamp_epoch_ms']
        cleaned_skus = row['clicked_skus']
        session_id_hash = row['session_id_hash']
        if session_id_hash not in session_info or not cleaned_skus:
            continue

        session_skus = session_info[session_id_hash]
        context_skus = [ e['product_sku_hash'] for e in session_skus if query_timestamp > e['server_timestamp_epoch_ms'] 
                                                                        and e['product_sku_hash'] in prod2vec_model]
        if not context_skus:
            continue
        context_vector = np.mean([prod2vec_model[sku] for sku in context_skus], axis=0).tolist()
        for s in cleaned_skus: 
            if s in sku_to_category:
                query_X.append(row['query_vector'] + context_vector)
                target_category_as_int = cat_2_id[sku_to_category[s]]
                query_Y.append(utils.to_categorical(target_category_as_int, num_classes=len(cat_2_id)))
            
    return query_X, query_Y
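The contextual input is the query vector followed by the average prod2vec vector of the SKUs viewed strictly before the query timestamp, so its length is the query dimension plus the prod2vec dimension. A plain-Python sketch on toy data; `contextual_features` and the shortened event keys `sku`/`ts` (standing in for `product_sku_hash` and `server_timestamp_epoch_ms`) are our own:

```python
def contextual_features(query_vector, session_events, query_ts, sku_vectors):
    """Concatenate the query vector with the mean vector of SKUs viewed before the query.

    Returns None when no earlier SKU has a vector (such queries are skipped above).
    """
    context = [sku_vectors[e["sku"]] for e in session_events
               if e["ts"] < query_ts and e["sku"] in sku_vectors]
    if not context:
        return None
    dims = len(context[0])
    mean = [sum(vec[d] for vec in context) / len(context) for d in range(dims)]
    return list(query_vector) + mean


vectors = {"p1": [1.0, 0.0], "p2": [0.0, 1.0], "p3": [5.0, 5.0]}
events = [{"sku": "p1", "ts": 100}, {"sku": "p2", "ts": 200}, {"sku": "p3", "ts": 900}]
# p3 is viewed after the query, so only p1 and p2 form the context
print(contextual_features([0.5, 0.5, 0.5], events, query_ts=500, sku_vectors=vectors))
# [0.5, 0.5, 0.5, 0.5, 0.5]
```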

In [None]:
context_query_X, context_query_Y = get_contextual_query_to_category_dataset(query_info, 
                                                                            session_info, 
                                                                            prod2vec_model, 
                                                                            cat_2_id, 
                                                                            sku_to_category)
print(len(context_query_X))
print(context_query_Y[0])

In [None]:
x_train, x_test, y_train, y_test = train_test_split(np.array(context_query_X), np.array(context_query_Y), test_size=0.2)

In [None]:
contextual_query_model = build_query_scoping_model(x_train[0].shape[0], y_train[0].shape[0])

In [None]:
# compile model
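from keras.optimizers.schedules import InverseTimeDecay
# recent Keras versions reject or ignore the legacy `lr`/`decay` arguments:
# this schedule reproduces the old behaviour, lr / (1 + decay * step)
sgd = SGD(learning_rate=InverseTimeDecay(0.01, decay_steps=1, decay_rate=1e-6), momentum=0.9, nesterov=True)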
contextual_query_model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
# train first
contextual_query_model.fit(x_train, y_train, epochs=10, batch_size=32)
# compute and print eval score
score = contextual_query_model.evaluate(x_test, y_test, batch_size=32)
score