In [32]:
import pickle
import random

import keras
import pandas as pd
import numpy as np

from modelling import *
from evaluating import *
from utils import TextProcessor

In [46]:
class BatchGenerator(keras.utils.Sequence):
    def __init__(self, pos_size, neg_size, data_path):
        self.pos_size = pos_size
        self.neg_size = neg_size
        
        que = pd.read_csv(data_path + 'questions.csv')
        tag_que = pd.read_csv(data_path + 'tag_questions.csv')
        tags = pd.read_csv(data_path + 'tags.csv')
        pro = pd.read_csv(data_path + 'professionals.csv')
        ans = pd.read_csv(data_path + 'answers.csv')
        
        self.tp = TextProcessor()
        pro['professionals_industry'] = pro['professionals_industry'].apply(self.tp.process)
        pro['professionals_industry'] = pro['professionals_industry'].apply(lambda x: ' '.join(x))
        
        self.pro_ind = {row['professionals_id']: row['professionals_industry'] for i, row in pro.iterrows()}
        
        que_tags = que.merge(tag_que, left_on = 'questions_id', right_on = 'tag_questions_question_id').merge(tags, left_on = 'tag_questions_tag_id', right_on = 'tags_tag_id')
        que_tags = que_tags[['questions_id', 'tags_tag_name']].groupby(by = 'questions_id', as_index = False).aggregate(lambda x: ' '.join(x))
        self.que_tag = {row['questions_id']: row['tags_tag_name'].split() for i, row in que_tags.iterrows()}
        
        ans_que = ans.merge(que, left_on = 'answers_question_id', right_on = 'questions_id')
        ans_que_pro = ans_que.merge(pro, left_on = 'answers_author_id', right_on = 'professionals_id')
        
        self.ques = list(ans_que_pro['questions_id'])
        self.pros = list(ans_que_pro['professionals_id']) 
        
        self.que_pro_set = {(row['questions_id'], row['professionals_id']) for i, row in ans_que_pro.iterrows()}
        self.que_pro_list = list(self.que_pro_set)
        
        with open('tags_embs.pickle', 'rb') as file:
            self.tag_emb = pickle.load(file)
        with open('industries_embs.pickle', 'rb') as file:
            self.ind_emb = pickle.load(file)
        
    def __len__(self):
        return len(self.que_pro) // self.pos_size
    
    def __convert(self, pairs):
        x_que, x_pro = [], []
        for que, pro in pairs:
            tmp = []
            for tag in self.que_tag:
                tmp.append(self.tag_emb[tag])
            x_que.append(np.vstack(tmp).mean(axis = 0))
            x_pro.append(self.ind_emb[self.pro_ind[pro]])
        return np.vstack(x_que), np.vstack(x_que)
            
    def __getitem__(self, index):
        pos_pairs = self.que_pro_list[self.pos_size * index: self.pos_size * (index + 1)]
        neg_pairs = []
        for i in range(self.neg_size):
            while True:
                que = random.choice(self.ques)
                pro = random.choice(self.pros)
                if (que, pro) not in self.que_pro_set:
                    neg_pairs.append((que, pro))
                    break
        print(pos_pairs)
        print(neg_pairs)
        x_pos_que, x_pos_pro = self.__convert(pos_pairs)
        x_neg_que, x_neg_pro = self.__convert(neg_pairs)
        
        return [np.vstack([x_pos_que, x_neg_que]), np.vstack([x_pos_pro, x_neg_pro])], \
                np.vstack([np.ones((len(x_pos_que), )), np.zeros((len(x_neg_que), ))])

In [49]:
bg = BatchGenerator(1, 1, '../../data/')

In [50]:
bg.__getitem__(0)

[('27ef6db256fd41ffab6ad90d1c27ebac', 'c62e0c865dd543098b45dab7781310e2')]
[('80dc85ecc9524ac5a46716a62691e491', '6a388a0aecba4e25a7c06cb389577386')]


KeyError: '0003e7bf48f24b5c985f8fce96e611f3'

In [51]:
bg.tag_emb

{'career': array([-0.37037852,  0.00895506,  0.15439406, -0.00267434, -0.15145998,
        -0.3769104 ,  1.1146779 , -0.31231797,  0.7811482 ,  0.36156967],
       dtype=float32),
 'finance': array([ 0.06638838,  0.5124224 ,  0.6222087 , -0.37416294, -0.19592267,
        -1.4379241 ,  0.61562103,  0.711204  , -0.09105965,  1.3037035 ],
       dtype=float32),
 'art': array([-1.6796117 , -0.05714934,  0.5607866 , -0.06386299, -0.85911447,
        -0.1489143 ,  0.21528502, -0.21861446,  0.44828528,  0.792982  ],
       dtype=float32),
 'college-major': array([-0.3505281 ,  0.2867531 , -0.0532017 , -0.191701  , -0.6257066 ,
        -0.06549475,  0.7878797 ,  0.34178403,  1.0999756 ,  0.52936906],
       dtype=float32),
 'nursing': array([-0.15025817,  1.6489999 , -0.13229975,  0.31736776,  0.16912708,
        -0.7903406 ,  1.1541829 ,  0.1010666 ,  0.73791355, -0.3667346 ],
       dtype=float32),
 'video-games': array([-1.3961504 ,  0.17308146,  0.7201859 , -0.37004372, -1.5393287 ,
      