In [32]:
import pickle
import random

import keras
import pandas as pd
import numpy as np

from modelling import *
from evaluating import *
from utils import TextProcessor

In [142]:
class BatchGenerator(keras.utils.Sequence):
    def __init__(self, pos_size, neg_size, data_path):
        self.pos_size = pos_size
        self.neg_size = neg_size
        
        que = pd.read_csv(data_path + 'questions.csv')
        tag_que = pd.read_csv(data_path + 'tag_questions.csv')
        tags = pd.read_csv(data_path + 'tags.csv')
        pro = pd.read_csv(data_path + 'professionals.csv')
        ans = pd.read_csv(data_path + 'answers.csv')
        
        self.tp = TextProcessor()
        pro['professionals_industry'] = pro['professionals_industry'].apply(self.tp.process)
        pro['professionals_industry'] = pro['professionals_industry'].apply(lambda x: ' '.join(x))
        
        self.pro_ind = {row['professionals_id']: row['professionals_industry'] for i, row in pro.iterrows()}
        
        que_tags = que.merge(tag_que, left_on = 'questions_id', right_on = 'tag_questions_question_id').merge(tags, left_on = 'tag_questions_tag_id', right_on = 'tags_tag_id')
        que_tags = que_tags[['questions_id', 'tags_tag_name']].groupby(by = 'questions_id', as_index = False).aggregate(lambda x: ' '.join(x))
        self.que_tag = {row['questions_id']: row['tags_tag_name'].split() for i, row in que_tags.iterrows()}
        
        ans_que = ans.merge(que, left_on = 'answers_question_id', right_on = 'questions_id')
        ans_que_pro = ans_que.merge(pro, left_on = 'answers_author_id', right_on = 'professionals_id')
        
        self.ques = list(set(ans_que_pro['questions_id']))
        self.pros = list(set(ans_que_pro['professionals_id']))
        
        self.que_pro_set = {(row['questions_id'], row['professionals_id']) for i, row in ans_que_pro.iterrows()}
        self.que_pro_list = list(self.que_pro_set)
        
        with open('tags_embs.pickle', 'rb') as file:
            self.tag_emb = pickle.load(file)
        with open('industries_embs.pickle', 'rb') as file:
            self.ind_emb = pickle.load(file)
        
    def __len__(self):
        return len(self.que_pro_list) // self.pos_size
    
    def __convert(self, pairs):
        x_que, x_pro = [], []
        for que, pro in pairs:
            tmp = []
            for tag in self.que_tag.get(que, ['#']):
                tmp.append(self.tag_emb.get(tag, np.zeros(10)))
            x_que.append(np.vstack(tmp).mean(axis = 0))
            x_pro.append(self.ind_emb.get(self.pro_ind[pro], np.zeros(10)))
        return np.vstack(x_que), np.vstack(x_pro)
            
    def __getitem__(self, index):
        pos_pairs = self.que_pro_list[self.pos_size * index: self.pos_size * (index + 1)]
        neg_pairs = []
        for i in range(self.neg_size):
            while True:
                que = random.choice(self.ques)
                pro = random.choice(self.pros)
                if (que, pro) not in self.que_pro_set:
                    neg_pairs.append((que, pro))
                    break
        x_pos_que, x_pos_pro = self.__convert(pos_pairs)
        x_neg_que, x_neg_pro = self.__convert(neg_pairs)
        
        return [np.vstack([x_pos_que, x_neg_que]), np.vstack([x_pos_pro, x_neg_pro])], \
                np.vstack([np.ones((len(x_pos_que), 1)), np.zeros((len(x_neg_que), 1))])
    
    def on_epoch_end(self):
        self.que_pro_list = random.sample(self.que_pro_list, len(self.que_pro_list))

In [143]:
bg = BatchGenerator(64, 64, '../../data/')

In [144]:
len(bg)

773

In [145]:
bg.__getitem__(0)

([array([[-0.09613705,  0.21159824, -0.0252068 , ..., -0.10991801,
           0.47798666,  0.35002603],
         [-0.4919048 ,  0.31844521,  0.00869839, ...,  0.32226032,
          -0.5686155 ,  0.92960918],
         [-0.61165535, -0.00737908, -0.17614603, ..., -0.42834923,
           1.25261104,  0.23862971],
         ...,
         [-0.60735184,  0.11329423, -0.02505915, ..., -0.16446325,
           0.40958168,  0.78099154],
         [-0.68318015,  0.44432926,  0.48859978, ...,  0.38601133,
          -0.17467406,  1.41676748],
         [ 0.        ,  0.        ,  0.        , ...,  0.        ,
           0.        ,  0.        ]]),
  array([[ 0.75101042,  0.59809995, -0.15109728, ..., -1.29697061,
          -0.41889337,  0.15616894],
         [ 0.25563693, -0.05060307,  0.62323833, ..., -1.52895582,
          -0.89223313,  1.18154633],
         [ 1.0357312 , -0.12629689, -1.36398983, ..., -1.06989944,
          -0.59437275,  0.42159066],
         ...,
         [ 0.98593777,  0.3946355 