In [100]:
import nltk
import pickle
import wikipedia
import pandas as pd
import numpy as np
from sklearn import preprocessing
from statistics import mean
from semantic_text_similarity.models import ClinicalBertSimilarity

import import_ipynb
import preprocess_utils
import wikt_def_parse

import gensim.downloader
from nltk.stem import WordNetLemmatizer
from nltk.stem.porter import PorterStemmer

# Pre-trained GloVe word vectors, used by title_domain_predict for word similarity.
# The 50-d model loads in ~18s. The 300-d 'glove-wiki-gigaword-300' variant takes
# ~75s and noticeably slows down notebook initialisation, so it is not used.
glove_vectors = gensim.downloader.load('glove-wiki-gigaword-50')

porter_stemmer = PorterStemmer()
lemmatizer = WordNetLemmatizer()

In [101]:
# Clinical-BERT sentence-pair similarity model, used by def_semantic_predict and
# ex_semantic_predict. Fall back to CPU when no CUDA device is available, so the
# notebook still runs (more slowly) instead of failing on CPU-only machines.
import torch

clinical_model = ClinicalBertSimilarity(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=10,
)

In [102]:
def title_domain_predict(title, domain_arr, domain_exclusions):
    """Score each sense's domain labels by GloVe similarity to the document title.

    Each entry's score is built up in three averaging steps:
    1. For each word of each label, average its GloVe similarity to every title token.
    2. Average those word scores within each label.
    3. Average the label scores within the entry.

    Args:
        title: Title of the document the keyword appears in.
        domain_arr: One list of domain-label strings per candidate definition
            (driver builds these from the text before each definition's ')').
        domain_exclusions: Labels (e.g. 'obsolete') that force an entry's score to 0.

    Returns:
        List with one score per entry of ``domain_arr``. The following score 0:
        - entries containing an excluded label;
        - generic words ('general', 'uncountable', 'countable');
        - words compared against an empty title;
        - empty labels or label lists.
    """
    title_tokens = nltk.word_tokenize(title.lower())
    # Labels with no topical meaning: comparing them with the title is only noise.
    generic_words = ('general', 'uncountable', 'countable')

    scores = []
    for labels in domain_arr:
        if any(excl in labels for excl in domain_exclusions):
            scores.append(0)
            continue

        label_scores = []
        for label in labels:
            word_scores = []
            for word in label.split():
                if word in generic_words or not title_tokens:
                    word_scores.append(0.0)
                    continue
                sims = []
                for token in title_tokens:
                    try:
                        sims.append(glove_vectors.similarity(word, token))
                    except KeyError:
                        # Word or token missing from the GloVe vocabulary: treat as unrelated.
                        sims.append(0)
                word_scores.append(np.mean(sims))
            # An empty label has no words; np.mean([]) would yield NaN and corrupt
            # the mean/std threshold used downstream, so score it 0.0 instead.
            label_scores.append(np.mean(word_scores) if word_scores else 0.0)
        scores.append(np.mean(label_scores) if label_scores else 0.0)

    return scores

In [103]:
def def_semantic_predict(curtext, clean_defs, domain_exclusions):
    """Score each candidate definition against the document text with Clinical-BERT.

    Args:
        curtext: Cleaned text of the document the keyword appears in.
        clean_defs: Cleaned definitions, one per candidate sense.
        domain_exclusions: Words (e.g. 'obsolete') that, when they are the first
            token of a definition, force its score to 0.

    Returns:
        List with one score per definition. Excluded or empty definitions score 0.
        Every other definition gets the model's similarity score for the pair
        (definition, curtext).
    """
    scores = [0] * len(clean_defs)
    pairs, positions = [], []
    for idx, definition in enumerate(clean_defs):
        tokens = nltk.word_tokenize(definition)
        # An empty definition has no first token and nothing to compare, so it
        # keeps a score of 0 instead of raising IndexError.
        if not tokens or tokens[0] in domain_exclusions:
            continue
        pairs.append((definition, curtext))
        positions.append(idx)

    # Send all pairs in a single predict() call so the model can use its
    # configured batch size, instead of running one pair at a time.
    if pairs:
        for idx, score in zip(positions, clinical_model.predict(pairs)):
            scores[idx] = score
    return scores

In [104]:
def ex_semantic_predict(curtext, clean_exs):
    """Score each usage example against the document text with Clinical-BERT.

    Args:
        curtext: Cleaned text of the document the keyword appears in.
        clean_exs: Cleaned usage examples, one per candidate sense. The literal
            string 'null' marks a sense with no example.

    Returns:
        List with one score per example. 'null' examples score 0. Every other
        example gets the model's similarity score for the pair (curtext, example).
    """
    scores = [0] * len(clean_exs)
    positions = [idx for idx, ex in enumerate(clean_exs) if ex != 'null']
    # Send all real examples in a single predict() call so the model can use its
    # configured batch size, instead of running one pair at a time.
    if positions:
        pairs = [(curtext, clean_exs[idx]) for idx in positions]
        for idx, score in zip(positions, clinical_model.predict(pairs)):
            scores[idx] = score
    return scores

In [105]:
def select_best_index(predict1_probs, predict2_probs):
    """Choose a sense index from a primary score list and a secondary score list.

    Only the two highest-ranked candidates by ``predict1_probs`` are considered.
    - If there is more than one secondary score and either candidate's secondary
      score is exactly 0, the top primary candidate wins outright.
    - Otherwise the first of the two, in primary-rank order, whose secondary score
      is at least one standard deviation above the mean secondary score is chosen.
    - Failing that, the top primary candidate is returned.

    Args:
        predict1_probs: Primary per-sense scores.
        predict2_probs: Secondary per-sense scores, aligned with ``predict1_probs``.

    Returns:
        The chosen index, as a numpy integer taken from ``np.argsort``.
    """
    threshold = np.mean(predict2_probs) + np.std(predict2_probs)

    finalists = np.argsort(predict1_probs)[::-1][:2]
    leader = finalists[0]

    # A zero secondary score on either finalist means the secondary signal
    # cannot arbitrate between them.
    if len(predict2_probs) > 1 and (predict2_probs[leader] == 0 or predict2_probs[finalists[1]] == 0):
        return leader

    return next(
        (candidate for candidate in finalists if predict2_probs[candidate] >= threshold),
        leader,
    )

In [106]:
def select_best_index2(predict1_probs, predict2_probs, predict3_probs, ex_weight=1.3, verbose=True):
    """Experimental sense selector combining three score lists (not called by driver).

    A composite score is formed from the L2-normalised ``predict2_probs`` plus
    ``ex_weight`` times the L2-normalised ``predict3_probs``. The composite is used
    only when there are at least two candidates and it is non-zero for both of the
    two best candidates by ``predict1_probs``. In that case, the first rank
    position (of the top two) at which the ``predict1_probs`` ranking and the
    composite ranking agree is returned.

    Args:
        predict1_probs: Primary per-sense scores (presumably definition similarity).
        predict2_probs: Secondary per-sense scores (presumably domain/title similarity).
        predict3_probs: Tertiary per-sense scores (presumably example similarity).
        ex_weight: Weight of ``predict3_probs`` in the composite (default 1.3).
        verbose: If True, print the inputs and which path was taken.

    Returns:
        The selected index; the top ``predict1_probs`` index on fallback.
    """
    predict2_norm = preprocessing.normalize([np.array(predict2_probs)])[0]
    predict3_norm = preprocessing.normalize([np.array(predict3_probs)])[0]
    composite = [predict2_norm[i] + ex_weight * predict3_norm[i] for i in range(len(predict2_norm))]
    if verbose:
        print(predict1_probs)
        print(composite)

    cutoff = 2
    argsort_main = np.argsort(predict1_probs)[::-1][:cutoff]
    argsort_add = np.argsort(composite)[::-1][:cutoff]

    # The composite cannot arbitrate with fewer than two candidates (indexing
    # argsort_main[1] would fail), or when it is 0 for either leading candidate.
    if (len(argsort_main) < 2
            or composite[argsort_main[0]] == 0
            or composite[argsort_main[1]] == 0):
        if verbose:
            print('didnt use composite')
        return argsort_main[0]

    for main_idx, add_idx in zip(argsort_main, argsort_add):
        if main_idx == add_idx:
            if verbose:
                print('used composite')
            return main_idx
    return argsort_main[0]
    

In [107]:
# Penn Treebank tags (as produced by nltk.pos_tag) mapped to the part-of-speech
# keys used to index wikt_def_parse.define() results. driver retries a lookup
# with the first entry. An empty list means there is no usable key for that tag.
_PTB_TO_WIKT_POS = {
    'CC': ['conjunction'],
    'CD': ['numeral'],
    'DT': ['determiner'],
    'EX': [],
    'FW': [],
    'IN': ['preposition', 'conjunction'],
    'JJ': ['adjective'],
    'JJR': ['adjective'],
    'JJS': ['adjective'],
    'LS': [],
    'MD': ['verb'],
    'NN': ['noun', 'proper noun'],
    'NNS': ['noun'],  # plural
    'NNP': ['noun', 'proper noun'],
    'NNPS': ['noun', 'proper noun'],  # plural
    'PDT': ['determiner'],
    'POS': [],
    'PRP': ['pronoun'],
    'PRP$': ['pronoun'],
    'RB': ['adverb'],
    'RBR': ['adverb'],
    'RBS': ['adverb'],
    'RP': ['preposition'],  # unsure
    'TO': [],  # unsure
    'UH': ['interjection'],
    'VB': ['verb'],
    'VBG': ['verb'],
    'VBD': ['verb'],
    'VBN': ['verb'],
    'VBP': ['verb'],
    'VBZ': ['verb'],
    'WDT': ['determiner'],
    'WP': ['pronoun'],
    'WRB': ['adverb'],
    # to deal with keyphrases
    'noun': ['noun', 'proper noun'],
    'verb': ['verb'],
}


def _lookup_senses(kw, pos):
    """Fetch the Wiktionary senses of ``kw`` under ``pos``, or the result driver should return.

    Returns ``(senses, None)`` on success. Otherwise returns ``(None, result)``,
    where ``result`` is one of:
    - 'invalid-pos', when ``define(kw)[pos]`` raises KeyError (presumably there is
      no ``pos`` entry);
    - the first sentence of the Wikipedia page for ``kw``, prefixed with '(wiki) ',
      when neither ``kw`` nor ``kw.lower()`` can be looked up;
    - 'invalid-term', when that Wikipedia fallback fails too.
    """
    try:
        return wikt_def_parse.define(kw)[pos], None
    except KeyError:
        return None, 'invalid-pos'
    except IndexError:
        pass

    try:
        return wikt_def_parse.define(kw.lower())[pos], None
    except (IndexError, KeyError):
        pass

    try:
        page = wikipedia.page(kw, auto_suggest=False)
        return None, '(wiki) ' + nltk.sent_tokenize(page.content)[0]
    except Exception:
        # Best effort only: any lookup/network/parse failure means the term is unknown.
        return None, 'invalid-term'


def _domain_labels(definition):
    """Extract the comma-separated domain labels from a definition's leading '(...)' tag."""
    close = definition.find(')')
    if close == -1:
        # With no tag, find() returns -1 and definition[:-1] would treat the whole
        # definition as labels. Use the generic label instead, which
        # title_domain_predict scores as 0.
        return ['general']
    labels = preprocess_utils.clean_text(definition[:close], False, False, True, False, True)
    return labels.replace('.', '').replace('(', '').strip().split(', ')


def driver(title, curtext, kw, pos, depth):
    """Predict the Wiktionary definition of ``kw`` as it is used in ``curtext``.

    Candidate senses are scored in two ways:
    - Clinical-BERT similarity between each definition and the cleaned document text;
    - GloVe similarity between each sense's domain labels and ``title``.
    select_best_index combines the two scores.

    On the top-level call (``depth == 0``), the chosen definition may end with a
    word that shares ``kw``'s Porter stem. In that case, the definition of that
    word is looked up once and appended after '. ', provided the lookup yields
    a real result.

    Args:
        title: Document title, compared against each sense's domain labels.
        curtext: Raw document text the keyword appears in.
        kw: Keyword or keyphrase to define.
        pos: Part-of-speech key into the wikt_def_parse.define() result.
        depth: Recursion depth; pass 0 for a top-level call.

    Returns:
        The predicted definition with periods removed, or a fallback result: a
        '(wiki) ...' first sentence from Wikipedia, 'invalid-pos', or 'invalid-term'.
    """
    senses, early_result = _lookup_senses(kw, pos)
    if early_result is not None:
        return early_result
    if not senses:
        # No senses to rank: report like a missing part of speech instead of
        # crashing inside select_best_index.
        return 'invalid-pos'

    orig_defs = [sense['def'] for sense in senses]
    domain_arr = [_domain_labels(definition) for definition in orig_defs]
    clean_defs = [preprocess_utils.clean_text(definition, False, True, True, True, True)
                  for definition in orig_defs]
    curtext = preprocess_utils.clean_text(curtext, True, True, True, True, True)
    domain_exclusions = ['dated', 'obsolete', 'rare']

    predict1_probs = def_semantic_predict(curtext, clean_defs, domain_exclusions)
    # Example-based scoring (ex_semantic_predict on the senses' 'ex' fields) was
    # evaluated as an alternative second signal; the domain/title score is used.
    predict2_probs = title_domain_predict(title, domain_arr, domain_exclusions)

    best_index = select_best_index(predict1_probs, predict2_probs)
    prediction = orig_defs[best_index].replace('.', '')

    if depth == 0:
        tokens = nltk.word_tokenize(prediction)
        if tokens and porter_stemmer.stem(tokens[-1]) == porter_stemmer.stem(kw):
            base_word = tokens[-1]
            depth += 1
            # Run the recursive lookup once and reuse its result (it runs the models).
            base_def = driver(title, curtext, base_word, pos, depth)
            if base_def == 'invalid-pos':
                # Retry with the part of speech nltk assigns to the base word. Tags
                # with no known mapping (punctuation, 'EX', 'FW', ...) cannot be retried.
                mapped = _PTB_TO_WIKT_POS.get(nltk.pos_tag([base_word])[0][1])
                if mapped:
                    base_def = driver(title, curtext, base_word, mapped[0], depth)
            # Append only real definitions, never the fallback sentinel strings.
            if base_def not in ('invalid-pos', 'invalid-term'):
                prediction += '. ' + base_def

    return prediction


In [108]:
# eval_data = pd.read_csv('def_predict_eval_data.csv')
# for i in range(eval_data.shape[0]):
#     if type(eval_data['def1'][i])==str and eval_data['def1'][i][0]!='(':
#         eval_data['def1'][i] = '(general) ' + eval_data['def1'][i]
#     if type(eval_data['def2'][i])==str and eval_data['def2'][i][0]!='(':
#         eval_data['def2'][i] = '(general) ' + eval_data['def2'][i]

In [109]:
# def evaluate_prediction(eval_data):
#     correct_count = 0
#     for i in range(eval_data.shape[0]):
#         curtext = wikipedia.page(eval_data['wiki'][i], auto_suggest=False).content
#         temp_pred = driver(eval_data['wiki'][i], curtext, eval_data['kw'][i], eval_data['pos'][i], 0)
#         if temp_pred == str(eval_data['def1'][i]).strip() or temp_pred == str(eval_data['def2'][i]).strip():
#             print(eval_data['kw'][i],'---',eval_data['wiki'][i],'TRUE','\n',temp_pred, '\n', eval_data['def1'][i], '\n', eval_data['def2'][i],'\n')
#             correct_count += 1
#         else:
#               print(eval_data['kw'][i],'---',eval_data['wiki'][i],'FALSE','\n',temp_pred, '\n', eval_data['def1'][i], '\n', eval_data['def2'][i],'\n')
#     return int(100*correct_count/eval_data.shape[0])

In [110]:
# evaluate_prediction(eval_data)

In [None]:
# sentence semantic predict only = 76% accuracy
# title to domain word embeddings = 38% accuracy
# example semantic predict = 42% accuracy
# sentence semantic + example predict = 78% accuracy
# sentence semantic + domain to title predict = 78% accuracy
# semantic predict + composite = 76% accuracy
# semantic predict + choose method = 80% accuracy