In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import numpy as np
import tensorflow as tf
import pandas as pd

import sys
from model import cnn_model
from utils import batch_gen
from evaluation_metrics import get_trec_eval_metrics
from tensorflow.keras.backend import set_session

In [None]:
# --- Training hyper-parameters -------------------------------------------
BATCH_SIZE = 64
EPOCHS = 1
# Non-improving dev epochs tolerated by the early-stopping check in the training cell.
PATIENCE = 2
# Lambda subtracted from every logged loss before it is divided by the logging
# propensity (see the weights_train computation in the training cell).
LAGRANGE_MULTIPLIER = 0.01

# --- Paths and names (an empty string makes os.path.join resolve relative to the CWD)
INPUT_DATA_PATH = ''
MODEL_DIR = ''
LOG_DIR = ''  # NOTE(review): not referenced by any visible cell
MODEL_NAME = 'click_logs'
# Handed to get_trec_eval_metrics; presumably the trec_eval executable — confirm.
TREC_EVAL_PATH = "/trec_eval"

# --- Logging --------------------------------------------------------------
OUTPUT_FREQUENCY = 10000  # number of batches to be processed before printing dev metrics

In [None]:
def parse_addn_feat(column):
    """Turn a column of ', '-separated number strings into a float array.

    Every cell (e.g. ``"0.1, 2, 3.5"``) becomes one row ``[0.1, 2.0, 3.5]``.

    Args:
        column: pandas Series of strings separated by ', '.

    Returns:
        numpy array built from the list of parsed rows.
    """
    return np.array(column.apply(lambda x: [float(elem) for elem in x.split(', ')]).tolist())


# Fix: the read_csv / np.load calls were missing their closing parenthesis.
train_data = pd.read_csv(os.path.join(INPUT_DATA_PATH, 'click_rel_train.csv'))
dev_data = pd.read_csv(os.path.join(INPUT_DATA_PATH, 'click_rel_dev.csv'))
test_data = pd.read_csv(os.path.join(INPUT_DATA_PATH, 'click_rel_test.csv'))

embeddings = np.load(os.path.join(INPUT_DATA_PATH, 'embedding.npy'))

# --- Train split: logged actions, their losses and logging propensities ---
q_train = np.load(os.path.join(INPUT_DATA_PATH, 'queries_train.npy'))
a_train = np.load(os.path.join(INPUT_DATA_PATH, 'products_train.npy'))

qids_train = train_data['qid']
loss_train = train_data['loss']
action_train = train_data['action']
addn_feat_train = parse_addn_feat(train_data['addn_feat'])
probs_train = train_data['control_policy_prob']

# The training loop zips these row-by-row, so they must be aligned.
assert len(q_train) == len(a_train) == len(train_data), 'train arrays and click_rel_train.csv are misaligned'

print('q_train.shape, a_train.shape, qids_train.shape, addn_feat_train.shape: ')
# Fix: the second value printed q_train.shape again instead of a_train.shape.
print(q_train.shape, a_train.shape, qids_train.shape, addn_feat_train.shape)

# --- Dev / test splits: click-relevance labels for ranking evaluation ---
# NOTE(review): train reads column 'qid' but dev/test read 'qids' — confirm
# the CSV schemas really differ.
q_dev = np.load(os.path.join(INPUT_DATA_PATH, 'queries_dev.npy'))
a_dev = np.load(os.path.join(INPUT_DATA_PATH, 'products_dev.npy'))
y_dev = dev_data['click_rel']
qids_dev = dev_data['qids']
addn_feat_dev = parse_addn_feat(dev_data['addn_feat'])
assert len(q_dev) == len(a_dev) == len(dev_data), 'dev arrays and click_rel_dev.csv are misaligned'

q_test = np.load(os.path.join(INPUT_DATA_PATH, 'queries_test.npy'))
a_test = np.load(os.path.join(INPUT_DATA_PATH, 'products_test.npy'))
y_test = test_data['click_rel']
qids_test = test_data['qids']
addn_feat_test = parse_addn_feat(test_data['addn_feat'])
assert len(q_test) == len(a_test) == len(test_data), 'test arrays and click_rel_test.csv are misaligned'

In [None]:
# Number of additional features per example; 1 whenever addn_feat_train is not a 2-D matrix.
addit_feat_len = addn_feat_train.shape[1] if addn_feat_train.ndim > 1 else 1

# Fix: `embeddings.shape` was the full shape tuple, not a dimension. Take the
# last axis (assumes embeddings is (vocab_size, embed_dim) — TODO confirm).
embed_dim = embeddings.shape[-1]
# Padded sequence lengths taken from the second axis of the input arrays.
max_ques_len = q_train.shape[1]
max_ans_len = a_train.shape[1]

def evaluate_split(model, q, a, addn_feat, qids, y_true):
    """Score one split and compute its trec_eval ranking metrics.

    Args:
        model: the model built by ``cnn_model``.
        q, a: query and product inputs for the split.
        addn_feat: additional per-example features for the split.
        qids: query ids handed to trec_eval to group examples by query.
        y_true: relevance labels for the split.

    Returns:
        ``(y_pred, metrics)``, where ``metrics`` is the 6-tuple
        ``(MAP, MRR, P@5, P@10, NDCG@5, NDCG@10)`` from ``get_trec_eval_metrics``.
    """
    # The 4th model input carries per-example training weights (see the
    # train_on_batch call below). At inference every example gets weight 1.
    y_pred = model.predict([q, a, addn_feat, np.ones(shape=len(q))])
    return y_pred, get_trec_eval_metrics(qids, y_pred, y_true, TREC_EVAL_PATH)


def print_full_metrics(stage, split_name, metrics):
    """Print all six trec_eval metrics for `split_name`, prefixed by `stage`."""
    print('{} results on {} set are: MAP: {}, MRR:{}, P@5: {}, P@10: {}, NDCG@5: {}, NDCG@10: {}'
          .format(stage, split_name, *metrics))


# Per-example training weights: each logged loss is shifted by the Lagrange
# multiplier and divided by the logging policy's propensity.
weights_train = (loss_train - LAGRANGE_MULTIPLIER) / probs_train

with tf.Graph().as_default():
    # NOTE(review): tf.Session / keras.backend.set_session are TensorFlow 1.x
    # APIs (tf.compat.v1 under TF 2.x) — confirm the TF version in use.
    with tf.Session() as sess:
        set_session(sess)

        cnn_model_instance = cnn_model(max_ques_len, max_ans_len, embeddings, addit_feat_len=addit_feat_len)

        dev_split = (q_dev, a_dev, addn_feat_dev, qids_dev, y_dev)
        test_split = (q_test, a_test, addn_feat_test, qids_test, y_test)

        # Baseline scores of the freshly built model.
        _, initial_dev_metrics = evaluate_split(cnn_model_instance, *dev_split)
        print_full_metrics('Initial', 'DEV', initial_dev_metrics)
        _, initial_test_metrics = evaluate_split(cnn_model_instance, *test_split)
        print_full_metrics('Initial', 'TEST', initial_test_metrics)

        # Best dev MAP so far (drives model selection), seeded with the baseline.
        map_score_overall = initial_dev_metrics[0]
        best_model_weights = cnn_model_instance.get_weights()
        patience = 0

        # Running loss sum and the number of batches it covers. Both are reset
        # at each progress report but not per epoch, so a report can span epochs.
        loss_sum = 0.0
        batches_in_sum = 0

        for epoch in range(EPOCHS):
            print('Epoch: ', epoch)

            batch_no = 0
            for b_q_train, b_a_train, b_addn_feat_train, b_weights, b_action in zip(
                    batch_gen(q_train, BATCH_SIZE), batch_gen(a_train, BATCH_SIZE),
                    batch_gen(addn_feat_train, BATCH_SIZE), batch_gen(weights_train, BATCH_SIZE),
                    batch_gen(action_train, BATCH_SIZE)):

                loss_batch = cnn_model_instance.train_on_batch(
                    [b_q_train, b_a_train, b_addn_feat_train, b_weights], b_action)
                loss_sum += loss_batch
                batches_in_sum += 1

                if batch_no % OUTPUT_FREQUENCY == 0:
                    print('{} batches were already processed'.format(batch_no))
                    # Fix: average over the batches actually accumulated. The old code
                    # divided by a hard-coded 100 although reports come every
                    # OUTPUT_FREQUENCY batches (and the first covers one batch).
                    # Add back the lambda that weights_train subtracts instead of a
                    # magic 0.1. NOTE(review): confirm 0.1 was meant to be lambda.
                    avg_loss = loss_sum / batches_in_sum + LAGRANGE_MULTIPLIER
                    print('Average Loss of Last {} batches is: {}'.format(batches_in_sum, avg_loss))
                    loss_sum = 0.0
                    batches_in_sum = 0

                    # [NOTE] makes training significantly longer
                    _, progress_metrics = evaluate_split(cnn_model_instance, *dev_split)
                    map_score, mrr, _, _, ndcg_5, ndcg_10 = progress_metrics
                    print('Results on {} set after {} batches are: MAP: {}, MRR:{}, NDCG@5: {}, NDCG@10: {}'
                          .format('DEV', batch_no, map_score, mrr, ndcg_5, ndcg_10))

                batch_no += 1

            _, epoch_dev_metrics = evaluate_split(cnn_model_instance, *dev_split)
            map_score_current, mrr, _, _, ndcg_5, ndcg_10 = epoch_dev_metrics
            print('Results on {} set for epoch {} are: MAP: {}, MRR:{}, NDCG@5: {}, NDCG@10: {}'
                  .format('DEV', epoch, map_score_current, mrr, ndcg_5, ndcg_10))

            # S = mean over training rows of (model probability of the logged
            # action) / (logging propensity).
            y_pred_train = cnn_model_instance.predict([q_train, a_train, addn_feat_train, np.ones(shape=len(q_train))])
            # Presumably the model outputs P(action == 1 | x), so rows whose logged
            # action is 0 use the complement (TODO confirm).
            action_is_zero = np.asarray(action_train) == 0
            y_pred_train[action_is_zero] = 1 - y_pred_train[action_is_zero]
            y_pred_train = np.reshape(y_pred_train, (y_pred_train.shape[0], 1))
            # Fix: reshape into a local instead of overwriting the global probs_train.
            # The overwrite broke the weights_train line when this cell was re-run.
            propensities = np.reshape(np.asarray(probs_train), (y_pred_train.shape[0], 1))
            S_train = np.mean(y_pred_train / propensities)
            print('S after epoch {} is {}'.format(epoch, S_train))

            # Model selection on dev MAP. The first epoch is always accepted.
            # After that, `patience` counts non-improving epochs (it is never
            # reset), and the (PATIENCE + 1)-th non-improving epoch stops training.
            if map_score_current > map_score_overall or epoch == 0:
                map_score_overall = map_score_current
                best_model_weights = cnn_model_instance.get_weights()
            elif patience < PATIENCE:
                patience += 1
            else:
                break

        cnn_model_instance.set_weights(best_model_weights)
        cnn_model_instance.save(os.path.join(MODEL_DIR, MODEL_NAME + '.h5'))

        y_pred_dev, best_dev_metrics = evaluate_split(cnn_model_instance, *dev_split)
        print_full_metrics('BEST', 'DEV', best_dev_metrics)
        y_pred_test, best_test_metrics = evaluate_split(cnn_model_instance, *test_split)
        print_full_metrics('BEST', 'TEST', best_test_metrics)
        # Leave the final test metrics in the same globals the cell always produced.
        map_score, mrr, p_5, p_10, ndcg_5, ndcg_10 = best_test_metrics