In [28]:
base_dir="/mnt/efs/shared/meg_shared_scripts/meg-kb"
data_ac="indeeda-meg-ac"
data_pt="indeeda-meg-pt"
yutong_base_dir="/home/ubuntu/users/yutong"

In [29]:
%cd $base_dir/src/concept_learning/

/mnt/efs/shared/meg_shared_scripts/meg-kb/src/concept_learning


In [30]:
from tqdm.notebook import tqdm
import argparse
import re
import numpy as np
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr, entropy, gmean
import random
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import json
from collections import defaultdict
import time
import importlib

import logging
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import os
import sys
import math
from annoy import AnnoyIndex
import matplotlib
from matplotlib import pyplot as plt
import networkx as nx

import spacy
from spacy.matcher import Matcher
from spacy.lang.en import English
nlp = English()
nlp.add_pipe(nlp.create_pipe('sentencizer'))
spacy_tokenizer = nlp.tokenizer


from compute_concept_clusters import knn
from compute_keyphrase_embeddings import ensure_tensor_on_device, mean_pooling

from lm_probes import LMProbe, LMProbe_GPT2, LMProbe_Joint, LMProbe_PMI, LMProbe_PMI_greedy
from utils import load_embeddings, load_seed_aligned_concepts, load_seed_aligned_relations, load_benchmark
from utils import get_masked_contexts, bert_untokenize
from utils import learn_patterns

from roberta_ses.interface import Roberta_SES_Entailment

In [31]:
import utils
importlib.reload(utils)
from utils import load_embeddings, load_seed_aligned_concepts, load_seed_aligned_relations, load_benchmark
from utils import get_masked_contexts, bert_untokenize
from utils import learn_patterns

In [32]:
seed_aligned_concepts_path = os.path.join(base_dir, f'data/indeed-benchmark/seed_aligned_concepts.csv')
seed_aligned_relations_path = os.path.join(base_dir, f'data/indeed-benchmark/seed_aligned_relations_nodup.csv')
benchmark_path = os.path.join(base_dir, f'data/indeed-benchmark/benchmark_evidence_clean.csv')


# EE-LM-probe (prompt)

In [33]:
'''
@Nikita: Here are the code blocks for exploring LM prompts for EE.

Some core code:
lm_probe = LMProbe()        // LMProbe: BERT; LMProbe_GPT2: GPT2; etc.
all_entitites = ...         // all the entities
_template = "Dress code like jeans, [MASK] and tattoos." // A string with [MASK] (LMProbe automatically takes care of all entity token lengths, so don't need to duplicate mask tokens)
_res = lm_probe.score_candidates(input_txt=_template, cands=all_entities)
list(enumerate(_res[:50]))  // Show results (rank, cand, score)

'''

'\n@Nikita: Here are the code blocks for exploring LM prompts for EE.\n\nSome core code:\nlm_probe = LMProbe()        // LMProbe: BERT; LMProbe_GPT2: GPT2; etc.\nall_entitites = ...         // all the entities\n_template = "Dress code like jeans, [MASK] and tattoos." // A string with [MASK] (LMProbe automatically takes care of all entity token lengths, so don\'t need to duplicate mask tokens)\n_res = lm_probe.score_candidates(input_txt=_template, cands=all_entities)\nlist(enumerate(_res[:50]))  // Show results (rank, cand, score)\n\n'

In [7]:
lm_probe = LMProbe()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
d_lm_probe = LMProbe('/home/ubuntu/users/nikita/models/bert_finetuned_lm/indeed_reviews_ques_ans')

In [36]:
seed_concepts_path = os.path.join(base_dir, f'data/indeed-benchmark/seed_aligned_concepts.csv')
seed_concepts_df = load_seed_aligned_concepts(seed_concepts_path)
print(seed_concepts_df.head())
emb_num_path = os.path.join(base_dir, f'data/{data_ac}/intermediate/BERTembednum+seeds.txt')
with open(emb_num_path, 'r') as f:
    all_entities = [l.rsplit(' ', 1)[0] for l in f]
print(all_entities[:20])
all_entities = list(set(all_entities))

seed_instances_dict = dict(zip(
        seed_concepts_df['alignedCategoryName'].tolist(),
        seed_concepts_df['seedInstances'].tolist()
    ))
    
seed_entities_lst = seed_concepts_df['seedInstances'].tolist()
seed_entities = [item for sublist in seed_entities_lst for item in sublist]
print(seed_entities)
concepts = seed_concepts_df['alignedCategoryName'].tolist()

  alignedCategoryName unalignedCategoryName generalizations  \
0             company               company             NaN   
1          dress_code            dress code             NaN   
2        job_position          job position             NaN   
3        pay_schedule            pay period             NaN   
4            benefits              benefits    compensation   

                                       seedInstances  
0       [walmart, amazon, subway, microsoft, target]  
1  [business casual, uniform, hair color, tattoos...  
2  [delivery driver, store manager, cashier, pack...  
3               [weekly, biweekly, friday, saturday]  
4  [health insurance, flexible schedule, 401k, pa...  
['multiple times', 'upper', 'management', 'wal mart', 'company', 'overnight stocker', 'walmart', 'prejudice', 'tuition assistance', 'mgr', 'department', 'manager', 'knowledge', 'leadership', 'job security', 'current position', 'positive attitude', 'business', '* *', 'talk']
['walmart', 'ama

In [37]:
gold_ee = pd.read_csv('../../../tmp/gold_ee.csv')
gold_ee = gold_ee[~gold_ee['neighbor'].isin(seed_entities)]
gold_ee = gold_ee.drop_duplicates(subset=['concept', 'neighbor'])

from tabulate import tabulate

def pratk(df, score='lm_score'):
    df = pd.merge(df, gold_ee, on=['concept', 'neighbor'])
    df = df.fillna(1)
    grped = df.groupby('concept')
    scores = {}
    for name, grp in grped:
        grp = grp.reset_index()
        grp = grp.sort_values(by=score, ascending=False)
        ks = [1,2,5,10,25,50,75,100,200]
        ks = [k for k in ks if k < len(grp)]
        all_correct = len(grp[grp['Majority'] > 0])
        patk = {}
        ratk = {}
        mink = {}
        for k in ks:
            sub = grp.head(k)
            correct = sub[sub['Majority'] > 0]
            precision = sub['Majority'].sum() / k
            recall = sub['Majority'].sum() / all_correct
            patk[k] = precision
            ratk[k] = recall
            mink[k] = correct.iloc[-1][score] if len(correct) > 0 else -1
        scores[name] = {"precision": patk, "recall": ratk, "min_score": mink}
    for cc, c_scores in scores.items():
        print(cc)
        print(pd.DataFrame(c_scores).to_string())
    return scores

## Avg. Scoring

In [38]:
def probe_and_eval(lm_probe, concept, probe_prompts, concept_phrase=None, evaluate=True):
    seeds = seed_instances_dict[concept]
    if concept_phrase is None:
        cc_phrase = ' '.join(concept.split('_'))
    else:
        cc_phrase = concept_phrase
    cand_scores_per_template = []
    cand_scores = []
    for template in probe_prompts:
        _input_txt = template.format(cc_phrase, ', '.join(seeds[:-1]), '[MASK]', seeds[-1])
        _cand_scores = lm_probe.score_candidates(_input_txt, all_entities)
        _cand_scores.sort(key=lambda d : d["cand"])
        cand_scores_per_template.append(_cand_scores)

    for _cand_score_lst in zip(*cand_scores_per_template):
                    # _cand_score_lst: List[Dict["cand", "score"]], for the same "cand" and different template 
            _cand = _cand_score_lst[0]["cand"]
            _score = sum([d["score"] for d in _cand_score_lst]) / len(_cand_score_lst)
                    # _score = np.log(_score)
            cand_scores.append({"cand": _cand, "score": _score})

    extraction_results = []
    for d in cand_scores:
        e = d["cand"]
        if e in seeds:
            continue
        lm_score = d["score"]
        extraction_results.append({'concept': concept,
                                           'neighbor': e,
                                           'lm_score': lm_score
                                          })

        extraction_results.sort(key=lambda d : d['lm_score'], reverse=True)
        results_df = pd.DataFrame(extraction_results)
    if evaluate:
        pratk(results_df)
    return results_df

In [14]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "{2} has many employees.",
        "I worked at {2}",
        "hired at {2}."
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'company', probe_prompts)
print('with relation prompts')
probe_and_eval(lm_probe,'company', probe_prompts2)

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'company', probe_prompts)
print('with relation prompts')
probe_and_eval(d_lm_probe,'company', probe_prompts2)

BERT-base
company
     precision    recall  min_score
1     1.000000  0.008772   0.068087
2     1.000000  0.017544   0.046398
5     1.000000  0.043860   0.020752
10    1.000000  0.087719   0.010284
25    1.000000  0.219298   0.002511
50    0.840000  0.368421   0.000986
75    0.706667  0.464912   0.000690
100   0.600000  0.526316   0.000380
with relation prompts
company
     precision    recall  min_score
1     1.000000  0.008772   0.034074
2     1.000000  0.017544   0.023606
5     1.000000  0.043860   0.012628
10    1.000000  0.087719   0.005199
25    0.960000  0.210526   0.001541
50    0.780000  0.342105   0.000706
75    0.706667  0.464912   0.000488
100   0.600000  0.526316   0.000320

Domain-adapted BERT
company
     precision    recall  min_score
1         1.00  0.008772   0.063044
2         1.00  0.017544   0.034719
5         1.00  0.043860   0.014492
10        1.00  0.087719   0.009953
25        1.00  0.219298   0.005252
50        0.88  0.385965   0.001830
75        0.76  0.50000

Unnamed: 0,concept,neighbor,lm_score
0,company,starbucks,3.476803e-02
1,company,mcdonalds,1.984275e-02
2,company,nike,1.751212e-02
3,company,amazons,1.739639e-02
4,company,mcdonald ' s,1.101652e-02
...,...,...,...
8032,company,gdansk,3.600702e-08
8033,company,screenplay,3.215033e-08
8034,company,mascara,2.928107e-08
8035,company,flute,2.463806e-08


In [17]:
_results_df = probe_and_eval(d_lm_probe, 'company', probe_prompts)

company
     precision    recall  min_score
1         1.00  0.008772   0.063044
2         1.00  0.017544   0.034719
5         1.00  0.043860   0.014492
10        1.00  0.087719   0.009953
25        1.00  0.219298   0.005252
50        0.88  0.385965   0.001830
75        0.76  0.500000   0.000813
100       0.72  0.631579   0.000405


In [22]:
''' YS: "fasts" and "etcs" are not company, why P@10 = 1.00 ?? '''

_results_df.head(10)

Unnamed: 0,concept,neighbor,lm_score
0,company,starbucks,0.063044
1,company,nike,0.034719
2,company,amazons,0.025174
3,company,apple,0.01896
4,company,etcs,0.01809
5,company,macs,0.017288
6,company,fasts,0.01469
7,company,mcdonalds,0.014492
8,company,wal mart,0.012479
9,company,wal mart,0.012479


In [102]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "The company did not do {2}.",
        "They check your {2}.",
        "{2} is not a card."
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'background_screening', probe_prompts)
print('with relation prompts')
probe_and_eval(lm_probe, 'background_screening', probe_prompts2)

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'background_screening', probe_prompts)
print('with relation prompts')
probe_and_eval(d_lm_probe,'background_screening', probe_prompts2)

BERT-base
background_screening
     precision    recall  min_score
1     1.000000  0.012346   0.068725
2     1.000000  0.024691   0.061470
5     0.800000  0.049383   0.027790
10    0.800000  0.098765   0.010895
25    0.720000  0.222222   0.006133
50    0.680000  0.419753   0.004163
75    0.546667  0.506173   0.002976
100   0.580000  0.716049   0.002050
with relation prompts
background_screening
     precision    recall  min_score
1         1.00  0.012346   0.034897
2         1.00  0.024691   0.030772
5         0.80  0.049383   0.013992
10        0.90  0.111111   0.006143
25        0.76  0.234568   0.003983
50        0.68  0.419753   0.002365
75        0.60  0.555556   0.001736
100       0.58  0.716049   0.001147

Domain-adapted BERT
background_screening
     precision    recall  min_score
1     0.000000  0.000000  -1.000000
2     0.000000  0.000000  -1.000000
5     0.600000  0.037037   0.085545
10    0.700000  0.086420   0.024580
25    0.720000  0.222222   0.009247
50    0.600000  0.37

In [103]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "The job requires {2}.",
        "You need to have {2} to apply.",
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'hire_prerequisite', probe_prompts, 'requirements')
print('with relation prompts')
probe_and_eval(lm_probe, 'hire_prerequisite', probe_prompts2, 'requirements')

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'hire_prerequisite', probe_prompts, 'requirements')
print('with relation prompts')
probe_and_eval(d_lm_probe,'hire_prerequisite', probe_prompts2, 'requirements')

BERT-base
hire_prerequisite
     precision  recall  min_score
1     1.000000    0.01   0.052757
2     1.000000    0.02   0.013271
5     0.400000    0.02   0.013271
10    0.500000    0.05   0.005917
25    0.560000    0.14   0.004174
50    0.660000    0.33   0.002382
75    0.626667    0.47   0.001623
100   0.610000    0.61   0.001089
200   0.495000    0.99   0.000005
with relation prompts
hire_prerequisite
     precision  recall  min_score
1     1.000000    0.01   0.031668
2     0.500000    0.01   0.031668
5     0.600000    0.03   0.008726
10    0.500000    0.05   0.004849
25    0.600000    0.15   0.003153
50    0.640000    0.32   0.001702
75    0.653333    0.49   0.001265
100   0.620000    0.62   0.000885
200   0.495000    0.99   0.000004

Domain-adapted BERT
hire_prerequisite
     precision  recall  min_score
1     1.000000    0.01   0.019661
2     1.000000    0.02   0.016057
5     0.800000    0.04   0.005686
10    0.800000    0.08   0.003724
25    0.880000    0.22   0.002821
50    0.7

In [104]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "they will hire people {2}.",
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'person', probe_prompts, 'person')
print('with relation prompts')
probe_and_eval(lm_probe, 'person', probe_prompts2, 'person')

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'person', probe_prompts, 'person')
print('with relation prompts')
probe_and_eval(d_lm_probe,'person', probe_prompts2, 'person')

BERT-base
person
     precision    recall  min_score
1     1.000000  0.029412   0.027419
2     1.000000  0.058824   0.015791
5     0.800000  0.117647   0.005052
10    0.900000  0.264706   0.002278
25    0.640000  0.470588   0.001025
50    0.440000  0.647059   0.000608
75    0.346667  0.764706   0.000365
100   0.280000  0.823529   0.000266
with relation prompts
person
     precision    recall  min_score
1     1.000000  0.029412   0.020585
2     1.000000  0.058824   0.011879
5     0.800000  0.117647   0.003799
10    0.900000  0.264706   0.001718
25    0.680000  0.500000   0.000803
50    0.460000  0.676471   0.000445
75    0.346667  0.764706   0.000278
100   0.280000  0.823529   0.000204

Domain-adapted BERT
person
     precision    recall  min_score
1         1.00  0.029412   0.036885
2         1.00  0.058824   0.019011
5         1.00  0.147059   0.012808
10        0.60  0.176471   0.011755
25        0.68  0.500000   0.002846
50        0.48  0.705882   0.001528
75        0.36  0.794118  

In [106]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "they allow you to wear {2}.",
        "{2} are allowed."
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'dress_code', probe_prompts, 'dress_code')
print('with relation prompts')
probe_and_eval(lm_probe, 'dress_code', probe_prompts2, 'dress_code')

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'dress_code', probe_prompts, 'dress_code')
print('with relation prompts')
probe_and_eval(d_lm_probe,'dress_code', probe_prompts2, 'dress_code')

BERT-base
dress_code
     precision    recall  min_score
1         1.00  0.003922   0.044785
2         1.00  0.007843   0.030868
5         1.00  0.019608   0.017537
10        1.00  0.039216   0.008212
25        0.96  0.094118   0.002613
50        0.92  0.180392   0.001255
75        0.84  0.247059   0.000754
100       0.86  0.337255   0.000488
200       0.89  0.698039   0.000045
with relation prompts
dress_code
     precision    recall  min_score
1     1.000000  0.003922   0.026927
2     1.000000  0.007843   0.018830
5     1.000000  0.019608   0.010764
10    1.000000  0.039216   0.006065
25    0.960000  0.094118   0.002116
50    0.980000  0.192157   0.001280
75    0.933333  0.274510   0.000853
100   0.880000  0.345098   0.000605
200   0.885000  0.694118   0.000108

Domain-adapted BERT
dress_code
     precision    recall  min_score
1         0.00  0.000000  -1.000000
2         0.50  0.003922   0.056419
5         0.80  0.015686   0.033686
10        0.90  0.035294   0.015546
25        0.88

In [43]:
corpus_path = os.path.join(base_dir, f'data/{data_ac}/intermediate/sent_segmentation.txt')
with open(corpus_path, 'r') as f:
    corpus = f.readlines()
    corpus = [l.strip() for l in corpus]
print(len(corpus))
corpus[:20]

901796


["Hard , unless you 're favorited .",
 'I asked <phrase>multiple times</phrase> to be trained on other things through out the store and refused .',
 'nearly impossible , unless you have an in with <phrase>upper</phrase> level <phrase>management</phrase> your better off finding a better job anywhere else',
 'Very easy to get promoted as long as you work hard show your motivated and are reliable .',
 'It can be really easy to get promoted !',
 'The <phrase>management</phrase> team , promotes associates to <phrase>management</phrase> and above within the store , before they look for someone outside <phrase>wal mart</phrase>',
 'It is difficult to be promoted .',
 'A lot of factors go into a store employee promotion .',
 "It 's pretty easy to get the job done as long as you are on time , respectful , kind , trustworthy and get the job done when it 's suppose to .",
 "you 'll be promoted .",
 'Easy if you are willing to work around more hours and little pay raise .',
 'not hard at all',
 'V

In [44]:
import collections
def get_context(entities, corpus):
    
    matched_context = []
    contexts = []
    for entity in entities: 
        token = '<phrase>{}</phrase>'.format(entity)
        for sent in corpus:
            if token in sent:
                matched_context.append(sent)
        for sent in matched_context:
            splits = sent.split(token)
            if len(splits) < 2:
                continue
            left_context_words = splits[0].split(' ')
            left_context_words = left_context_words[-min(len(left_context_words), 4):]

            right_context_words = splits[1].split(' ')
            right_context_words = right_context_words[:min(len(right_context_words), 4)]
            context = ' '.join(left_context_words) + '<tgt>' + ' '.join(right_context_words)
            context_words = context.split(' ')
            if len(context_words) < 4:
                continue
            contexts.append(context)
    print(collections.Counter(contexts).most_common(20))

In [81]:
get_context(['walmart', 'amazon', 'subway', 'microsoft', 'target'], corpus)

[('to work at <tgt> .', 63), ('<tgt> does not drug', 56), ('to work at <tgt>', 50), ('to work for <tgt>', 22), ('No <tgt> does not drug', 19), ('while working at <tgt> .', 17), ('<tgt> is a good', 15), ('<tgt> is a great', 15), ('No , <tgt> does not drug', 15), ('<tgt> does not hire', 14), ('to work for <tgt> .', 13), ('<tgt> is a <phrase>great', 13), ('<tgt> does not do', 12), ('No <tgt> does not hire', 11), ('<tgt> will work with', 10), ('<tgt> is always hiring', 10), ('The <tgt> I worked at', 10), ('<tgt> pays bi weekly', 10), ('from home for <tgt>', 10), ('<tgt> is a <phrase>drug', 10)]


In [108]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]

probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "to work at {2}.",
        "{2} does not drug.",
        "while working at {2}."
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'company', probe_prompts, 'company')
print('with relation prompts')
probe_and_eval(lm_probe, 'company', probe_prompts2, 'company')

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'company', probe_prompts, 'company')
print('with relation prompts')
probe_and_eval(d_lm_probe,'company', probe_prompts2, 'company')

BERT-base
company
     precision    recall  min_score
1     1.000000  0.008772   0.068087
2     1.000000  0.017544   0.046398
5     1.000000  0.043860   0.020752
10    1.000000  0.087719   0.010284
25    1.000000  0.219298   0.002511
50    0.840000  0.368421   0.000986
75    0.706667  0.464912   0.000690
100   0.600000  0.526316   0.000380
with relation prompts
company
     precision    recall  min_score
1     1.000000  0.008772   0.034133
2     1.000000  0.017544   0.023354
5     1.000000  0.043860   0.012258
10    1.000000  0.087719   0.005157
25    0.960000  0.210526   0.001551
50    0.780000  0.342105   0.000703
75    0.666667  0.438596   0.000495
100   0.590000  0.517544   0.000362

Domain-adapted BERT
company
     precision    recall  min_score
1         1.00  0.008772   0.063044
2         1.00  0.017544   0.034719
5         1.00  0.043860   0.014492
10        1.00  0.087719   0.009953
25        1.00  0.219298   0.005252
50        0.88  0.385965   0.001830
75        0.76  0.50000

In [83]:
get_context(['drug test', 'criminal background check', 'employment verification', 'driving record', 'credit report', 'criminal record'], corpus)

[('to take a <tgt> .', 167), ('to take a <tgt>', 114), ('Yes they do <tgt>', 85), ('Yes they do <tgt> .', 44), ('to do a <tgt> .', 40), ('Yes they <tgt> .', 38), ('<phrase>background check</phrase> and <tgt> .', 34), ('There was no <tgt>', 32), ('they do nt <tgt>', 31), ('<phrase>background check</phrase> and <tgt>', 27), ('to do a <tgt>', 22), ('There is no <tgt>', 22), ('They do nt <tgt>', 22), ('to pass a <tgt> .', 21), ('They do <tgt> .', 20), ('never failed a <tgt> .', 20), ('There was no <tgt> .', 19), ('not have a <tgt> .', 19), ("do n't do <tgt>", 19), ('What kind of <tgt> do they do', 16)]


In [109]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "to take a {2}.",
        "Yes they do {2}.",
        "to pass a {2}.", 
        "they run a {2} before hiring."
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'background_screening', probe_prompts, 'background_screening')
print('with relation prompts')
probe_and_eval(lm_probe, 'background_screening', probe_prompts2, 'background_screening')

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'background_screening', probe_prompts, 'background_screening')
print('with relation prompts')
probe_and_eval(d_lm_probe,'background_screening', probe_prompts2, 'background_screening')

BERT-base
background_screening
     precision    recall  min_score
1     1.000000  0.012346   0.092631
2     1.000000  0.024691   0.046511
5     1.000000  0.061728   0.025006
10    0.800000  0.098765   0.016128
25    0.640000  0.197531   0.007512
50    0.640000  0.395062   0.004474
75    0.586667  0.543210   0.003032
100   0.570000  0.703704   0.002000
with relation prompts
background_screening
     precision    recall  min_score
1     1.000000  0.012346   0.039700
2     1.000000  0.024691   0.019934
5     1.000000  0.061728   0.012520
10    0.800000  0.098765   0.008386
25    0.680000  0.209877   0.003598
50    0.640000  0.395062   0.002311
75    0.586667  0.543210   0.001503
100   0.580000  0.716049   0.000915

Domain-adapted BERT
background_screening
     precision    recall  min_score
1     0.000000  0.000000  -1.000000
2     0.000000  0.000000  -1.000000
5     0.600000  0.037037   0.081086
10    0.700000  0.086420   0.021252
25    0.760000  0.234568   0.008927
50    0.620000  0.38

In [89]:
get_context(["felons", "criminals", "disabled", "drug addicts", "high schoolers", "misdemeanor", "students", "seniors"], corpus)

[('you have a <tgt> .', 8), ('I had a <tgt> ', 8), ('you have a <tgt>', 8), ('depends on the <tgt>', 8), ('not have a <tgt>', 6), ('I had a <tgt>', 6), ('on what the <tgt> was for .', 6), ('depends on the <tgt> .', 4), ('I have a <tgt> on my background', 4), ('people with a <tgt> on their record', 4), ('anyone with a <tgt> .', 4), ('someone with a <tgt> .', 4), ('you with a <tgt> ', 4), ('you have a <tgt> of <phrase>theft</phrase> but', 4), ('I have a <tgt> .', 4), ('you with a <tgt>', 4), ('person with a <tgt> ', 4), ('depending on the <tgt>', 4), ('I have a <tgt> for <phrase>assault</phrase> and', 4), ('hired with a <tgt> .', 4)]


In [110]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "hire {0}, such as {1}, {2} and {3}.",
    "people with a {2} on their record."
    ]

print('BERT-base')
probe_and_eval(lm_probe, 'person', probe_prompts, 'person')
print('with relation prompts')
probe_and_eval(lm_probe, 'person', probe_prompts2, 'person')

print()
print('Domain-adapted BERT')
probe_and_eval(d_lm_probe,'person', probe_prompts, 'person')
print('with relation prompts')
probe_and_eval(d_lm_probe,'person', probe_prompts2, 'person')

BERT-base
person
     precision    recall  min_score
1     1.000000  0.029412   0.027419
2     1.000000  0.058824   0.015791
5     0.800000  0.117647   0.005052
10    0.900000  0.264706   0.002278
25    0.640000  0.470588   0.001025
50    0.440000  0.647059   0.000608
75    0.346667  0.764706   0.000365
100   0.280000  0.823529   0.000266
with relation prompts
person
     precision    recall  min_score
1     1.000000  0.029412   0.018603
2     1.000000  0.058824   0.011497
5     0.800000  0.117647   0.003313
10    0.800000  0.235294   0.001663
25    0.600000  0.441176   0.000873
50    0.420000  0.617647   0.000513
75    0.333333  0.735294   0.000266
100   0.270000  0.794118   0.000237

Domain-adapted BERT
person
     precision    recall  min_score
1         1.00  0.029412   0.036885
2         1.00  0.058824   0.019011
5         1.00  0.147059   0.012808
10        0.60  0.176471   0.011755
25        0.68  0.500000   0.002846
50        0.48  0.705882   0.001528
75        0.36  0.794118  

## Different scoring schemes

In [114]:
def probe_and_eval_grps(lm_probe, concept, probe_prompts, prompt_groups, concept_phrase=None):
    seeds = seed_instances_dict[concept]
    if concept_phrase is None:
        cc_phrase = ' '.join(concept.split('_'))
    else:
        cc_phrase = concept_phrase
    cand_scores_per_template = []
    cand_scores = []
    
    for template in probe_prompts:
        _input_txt = template.format(cc_phrase, ', '.join(seeds[:-1]), '[MASK]', seeds[-1])
        _cand_scores = lm_probe.score_candidates(_input_txt, all_entities)
        _cand_scores.sort(key=lambda d : d["cand"])
        cand_scores_per_template.append(_cand_scores)

    for _cand_score_lst in zip(*cand_scores_per_template):
                    # _cand_score_lst: List[Dict["cand", "score"]], for the same "cand" and different template 
            _cand = _cand_score_lst[0]["cand"]
            grp_scores = {}
            for i, g in enumerate(prompt_groups):
                _score_i = _cand_score_lst[i]['score']
                grp_score_lst = grp_scores.get(g, [])
                grp_score_lst.append(_score_i)
                grp_scores[g] = grp_score_lst
            _score = 1.0
            for val in grp_scores.values():
                _score = _score * (sum(val) / len(val))
            cand_scores.append({"cand": _cand, "score": _score})

    extraction_results = []
    for d in cand_scores:
        e = d["cand"]
        if e in seeds:
            continue
        lm_score = d["score"]
        extraction_results.append({'concept': concept,
                                           'neighbor': e,
                                           'lm_score': lm_score
                                          })

        extraction_results.sort(key=lambda d : d['lm_score'], reverse=True)
    pratk(pd.DataFrame(extraction_results))

In [115]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
prompt_grps = [1,1,1]

probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "to work at {2}.",
        "{2} does not drug.",
        "while working at {2}."
    ]
prompt_grps2 = [1,1,1,2,3,2]

print('BERT-base')
probe_and_eval_grps(lm_probe, 'company', probe_prompts, prompt_grps, 'company')
print('with relation prompts')
probe_and_eval_grps(lm_probe, 'company', probe_prompts2, prompt_grps2, 'company')

print()
print('Domain-adapted BERT')
probe_and_eval_grps(d_lm_probe,'company', probe_prompts, prompt_grps, 'company')
print('with relation prompts')
probe_and_eval_grps(d_lm_probe,'company', probe_prompts2, prompt_grps2,'company')

BERT-base
company
     precision    recall  min_score
1     1.000000  0.008772   0.068087
2     1.000000  0.017544   0.046398
5     1.000000  0.043860   0.020752
10    1.000000  0.087719   0.010284
25    1.000000  0.219298   0.002511
50    0.840000  0.368421   0.000986
75    0.706667  0.464912   0.000690
100   0.600000  0.526316   0.000380
with relation prompts
company
     precision    recall     min_score
1         1.00  0.008772  6.400130e-09
2         1.00  0.017544  9.283904e-10
5         0.80  0.035088  7.966789e-10
10        0.60  0.052632  4.044021e-10
25        0.56  0.122807  8.737789e-11
50        0.58  0.254386  1.661150e-11
75        0.60  0.394737  4.061542e-12
100       0.63  0.552632  1.156959e-12

Domain-adapted BERT
company
     precision    recall  min_score
1         1.00  0.008772   0.063044
2         1.00  0.017544   0.034719
5         1.00  0.043860   0.014492
10        1.00  0.087719   0.009953
25        1.00  0.219298   0.005252
50        0.88  0.385965   0.001

In [117]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
prompt_grps = [1,1,1]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "to take a {2}.",
        "Yes they do {2}.",
        "to pass a {2}.", 
        "they run a {2} before hiring."
    ]
prompt_grps2 = [1,1,1,2,3,2,3]

print('BERT-base')
probe_and_eval_grps(lm_probe, 'background_screening', probe_prompts, prompt_grps, 'background_screening')
print('with relation prompts')
probe_and_eval_grps(lm_probe, 'background_screening', probe_prompts2, prompt_grps2, 'background_screening')

print()
print('Domain-adapted BERT')
probe_and_eval_grps(d_lm_probe,'background_screening', probe_prompts, prompt_grps, 'background_screening')
print('with relation prompts')
probe_and_eval_grps(d_lm_probe,'background_screening', probe_prompts2, prompt_grps2, 'background_screening')

BERT-base
background_screening
     precision    recall  min_score
1     1.000000  0.012346   0.092631
2     1.000000  0.024691   0.046511
5     1.000000  0.061728   0.025006
10    0.800000  0.098765   0.016128
25    0.640000  0.197531   0.007512
50    0.640000  0.395062   0.004474
75    0.586667  0.543210   0.003032
100   0.570000  0.703704   0.002000
with relation prompts
background_screening
     precision    recall     min_score
1     1.000000  0.012346  7.751327e-07
2     1.000000  0.024691  3.283780e-08
5     0.800000  0.049383  1.165075e-08
10    0.600000  0.074074  3.171605e-09
25    0.640000  0.197531  5.394475e-10
50    0.580000  0.358025  6.611104e-11
75    0.586667  0.543210  8.228445e-12
100   0.520000  0.641975  1.107889e-12

Domain-adapted BERT
background_screening
     precision    recall  min_score
1     0.000000  0.000000  -1.000000
2     0.000000  0.000000  -1.000000
5     0.600000  0.037037   0.081086
10    0.700000  0.086420   0.021252
25    0.760000  0.234568   0.

In [118]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
prompt_grps = [1,1,1]
probe_prompts2 = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
        "hire {0}, such as {1}, {2} and {3}.",
    "people with a {2} on their record."
    ]
prompt_grps2 = [1,1,1,2,2]

print('BERT-base')
probe_and_eval_grps(lm_probe, 'person', probe_prompts, prompt_grps, 'person')
print('with relation prompts')
probe_and_eval_grps(lm_probe, 'person', probe_prompts2, prompt_grps2, 'person')

print()
print('Domain-adapted BERT')
probe_and_eval_grps(d_lm_probe,'person', probe_prompts, prompt_grps, 'person')
print('with relation prompts')
probe_and_eval_grps(d_lm_probe,'person', probe_prompts2, prompt_grps2, 'person')

BERT-base
person
     precision    recall  min_score
1     1.000000  0.029412   0.027419
2     1.000000  0.058824   0.015791
5     0.800000  0.117647   0.005052
10    0.900000  0.264706   0.002278
25    0.640000  0.470588   0.001025
50    0.440000  0.647059   0.000608
75    0.346667  0.764706   0.000365
100   0.280000  0.823529   0.000266
with relation prompts
person
     precision    recall     min_score
1         1.00  0.029412  1.474955e-04
2         1.00  0.058824  7.982930e-05
5         0.80  0.117647  4.348215e-06
10        0.70  0.205882  2.206297e-06
25        0.56  0.411765  4.635252e-07
50        0.38  0.558824  1.242914e-07
75        0.32  0.705882  4.555480e-08
100       0.28  0.823529  1.098820e-08

Domain-adapted BERT
person
     precision    recall  min_score
1         1.00  0.029412   0.036885
2         1.00  0.058824   0.019011
5         1.00  0.147059   0.012808
10        0.60  0.176471   0.011755
25        0.68  0.500000   0.002846
50        0.48  0.705882   0.001528

# Relation templates

In [120]:
from roberta_ses.interface import Roberta_SES_Entailment
yutong_base_dir="/home/ubuntu/users/yutong"
roberta_ses_dir = os.path.join(yutong_base_dir, "repos", "Roberta_SES")
# 0 = contra, 1 = neutral, 2 = entail
entailment_model = Roberta_SES_Entailment(roberta_path='/home/ubuntu/users/yutong/models/roberta-large',
        ckpt_path=os.path.join(roberta_ses_dir, 'checkpoints/epoch=2-valid_loss=-0.2620-valid_acc_end=0.9223.ckpt'),
        max_length=512,
        device_name='cpu')

In [126]:
from sentence_transformers import SentenceTransformer, util
import numpy as np
d_entailment_model = SentenceTransformer('/home/ubuntu/users/nikita/models/bert_finetuned_lm_sts/indeed_all/')

In [132]:
def predict_rel(sent, rel):
    return entailment_model.predict(sent, rel)[1].tolist()

def d_predict_rel(sent, rel):
    embedding1 = d_entailment_model.encode(sent, convert_to_tensor=True)
    embedding2 = d_entailment_model.encode(rel, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(embedding1, embedding2)
    return cosine_scores.item()

In [134]:
s1 = 'walmart : we have to wear uniform'
s2 = 'walmart requires uniform'
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.0018634856678545475, 0.008392865769565105, 0.9897436499595642]
0.9119620323181152


In [135]:
s1 = "kroger stores : our kroger did n't have a pharmacy when i was there , but they paid for the permit test to sell alcohol ."
s2 = "permit is needed for pharmacy ."
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.003196667181327939, 0.039147090166807175, 0.9576562643051147]
0.4612185060977936


In [137]:
s1 = "cvs health : general orientation is usually about six hours ; after , pharmacy techs will be brought to the pharmacy for a separate orientation ."
s2 = "There is orientation for pharmacy ."
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.002546712290495634, 0.016983652487397194, 0.9804696440696716]
0.6139006614685059


In [138]:
s1 = "walmart : 401k , health insurance , full time status as soon as there was a space open in your area ."
s2 = "walmart offers health insurance."
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.0007973050232976675, 0.001054643071256578, 0.9981480836868286]
0.5119340419769287


In [139]:
s1 = "walmart : 30 - 40hrs weekly , flexible schedule ."
s2 = "walmart provides flexible schedule."
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.0008081113919615746, 0.009821068495512009, 0.9893708229064941]
0.5245642066001892


In [140]:
s1 = "subway : it was a flexible schedule to go to school either morning or evenings ."
s2 = "subway provides flexible schedule."
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.0009289175504818559, 0.023315617814660072, 0.975755512714386]
0.6597525477409363


In [141]:
s1 = "cvs health : after completing their training you get a raise switching you from pharmacy associate to pharmacy tech ."
s2 = "You can work training as pharmacy ."
print(predict_rel(s1,s2))
print(d_predict_rel(s1,s2))

[0.005240839906036854, 0.06974846869707108, 0.9250107407569885]
0.733309805393219


In [46]:
probe_prompts = [
        "{0}, such as {1}, {2} and {3}.",
        "{0}, including {1}, {2} and {3}.",
        "{1}, {2}, {3} and other {0}.",
    ]
concepts_to_expand = concepts
expanded_instances = {}
for c in concepts_to_expand:
    print(c)
    results = probe_and_eval(d_lm_probe,c, probe_prompts, c, evaluate=False)
    expanded_instances[c] = results

company
dress_code
job_position
pay_schedule
benefits
compensation
payment_option
background_screening
person
hire_prerequisite
shifts
schedule
employee_type
onboarding_steps


In [47]:
expanded_instances['company']

Unnamed: 0,concept,neighbor,lm_score
0,company,starbucks,6.304385e-02
1,company,nike,3.471876e-02
2,company,amazons,2.517403e-02
3,company,apple,1.895988e-02
4,company,etcs,1.808982e-02
...,...,...,...
8032,company,romans,7.788047e-09
8033,company,divorce,7.192135e-09
8034,company,jurisdiction,6.065130e-09
8035,company,gdansk,3.015721e-09


In [48]:
import re
s = "<phrase>nearly impossible</phrase> , unless you have an in with <phrase>upper</phrase> level <phrase>management</phrase> your better off finding a better job anywhere else"
phrases = re.findall(r'<phrase>(.*?)</phrase>', s)
pairs = []
if len(phrases) > 1:
    for i in range(0, len(phrases)):
        for j in range(i, len(phrases)):
            if i != j:
                pairs.append((phrases[i], phrases[j]))
                
print(pairs)

[('nearly impossible', 'upper'), ('nearly impossible', 'management'), ('upper', 'management')]


In [49]:
def get_all_mentions(corpus):
    mapping = {}
    for sent in corpus:
        phrases = re.findall(r'<phrase>(.*?)</phrase>', sent)
        pairs = []
        if len(phrases) > 1:
            for i in range(0, len(phrases)):
                for j in range(i, len(phrases)):
                    if i != j:
                        w1 = phrases[i]
                        w2 = phrases[j]
                        if w1 == w2:
                            continue
                        pair_key = str(set([w1, w2]))
                        sents = mapping.get(pair_key, [])
                        sents.append(sent)
                        mapping[pair_key] = sents
    return mapping

In [50]:
import operator
def find_context(concept1, concept2, mapping, instances):
    concept1_instances = instances[concept1]
    concept2_instances = instances[concept2]
    patterns = {}
    for c1 in concept1_instances:
        for c2 in concept2_instances:
            pair_key = str(set([c1,c2]))
#             print(pair_key)
            sents = mapping.get(pair_key, [])
            for sent in sents:
                sent = sent.replace("<phrase>", "")
                sent = sent.replace("</phrase>", "")
                i1 = sent.index(c1)
                i2 = sent.index(c2)
                if i1 == -1 or i2 == -1: 
                    continue
                start = min(i1, i2)
                end = max(i1, i2)
                offset = len(c2) if end == i2 else len(c1)
                pat = sent[start: end+offset]
#                 print(sent)
#                 print(pair_key)
#                 print(pat)
                pat = pat.replace(c1, '<src>')
                pat = pat.replace(c2, '<tgt>')
                
                left_context_words = sent[:start].split(' ')
                left_context_words = left_context_words[-min(len(left_context_words), 2):]

                right_context_words = sent[end+offset:].split(' ')
                right_context_words = right_context_words[:min(len(right_context_words), 2)]
            
                context_pat = ' '.join(left_context_words) + pat + ' '.join(right_context_words)
#                 print(context_pat)
                pat_ct = patterns.get(context_pat, 0)
                pat_ct += 1
                patterns[context_pat] = pat_ct
    return sorted(patterns.items(), key=operator.itemgetter(1),reverse=True)

In [59]:
sent_mapping = get_all_mentions(corpus)

reliable_expanded_instances = {}
all_expanded_instances = {}
for key, val in expanded_instances.items():
    sub_df = val.iloc[:min(20, len(val))]
    reliable_expanded_instances[key] = sub_df['neighbor'].tolist() + seed_instances_dict.get(key, [])
    all_expanded_instances[key] = val.iloc[:min(100, len(val))]['neighbor'].tolist() + seed_instances_dict.get(key, [])

In [232]:
find_context('company', 'job_position', sent_mapping, reliable_expanded_instances)

[("candidates <tgt> ca n't be trusted to make that decision ) but the panels have completely devolved into a murder trial that are the most painful and bureaucratic thing done at <src> amazon",
  14),
 ('a <tgt> or leader of the warehouse i am sure Fulltime employement at <src> amazon',
  8),
 ('A <tgt> at <src> amazon', 8),
 ('successful <tgt> at <src> amazon', 8),
 ('a <tgt> at <src> amazon', 8),
 ('the <tgt> who hired me started <src> as', 6),
 ('at <src> I was interviewed by both a <tgt> and', 6),
 ('<src> has many job openings and options from door greeter , to a cart pusher , receiving associate to a supervisor , a department <tgt> to',
  6),
 ('The <src> <src> warehouse does offer flexible hours but <src> flex is <tgt> position',
  6),
 (', <src> offers a great program to relocate to another <src> , so long as they are hiring , transfers are put above standard applications to ensure that a <tgt> gets',
  5),
 ('<src> <tgt>', 4),
 ('all <src> <src> <tgt> employees', 4),
 ('or <sr

In [234]:
find_context('job_position', 'background_screening', sent_mapping, reliable_expanded_instances)

[('hire <tgt> tested , but on a monthly basis , Corporate would randomly choose anywhere from five to twenty associate names that a <src> in',
  6),
 ('Usually <tgt> testing is random ; However , if a <src> suspects', 6),
 (', <tgt> tests are always possible , and a <src> can', 6),
 (', <tgt> are always possible , and a <src> can', 6),
 ("hire <tgt> tested , but on a monthly basis , Corporate would randomly choose anywhere from five to twenty associate names that a manager in charge had to immediately call into the <src> 's",
  6),
 ('in <src> call center we are not involve with the human resources group , which is a totally different department , and to my knowledge they do a complete <tgt> for',
  4),
 ('a <src> position they did not <tgt> test', 4),
 ('to <src> , and never once seen anyone get <tgt> drug', 4),
 ('a <src> and I can tell you that <tgt> drug', 4),
 ('ex <src> and asked to come in and complete a physical and <tgt> .', 3),
 ('the <tgt> ( they have a toll free number for 

In [235]:
find_context('job_position', 'hire_prerequisite', sent_mapping, reliable_expanded_instances)

[('good <src> and seeing how the store is free from <tgt> safety', 4),
 ('great <src> but often has to use her vacation hours to hit <tgt> labor', 4),
 ('shift <src> , they start as a shift <src> with no <tgt> to', 4),
 ('your <tgt> to your <src> and', 4),
 ('on <src> , <tgt> ,', 2),
 ('on <src> and food handling and <tgt> tips', 2),
 (', <src> and <tgt> for', 2),
 ('of <src> especially in food preparation or sales , I completed an ekg tech program , I had <tgt> in',
  2),
 ('More <tgt> intense then <src>', 2),
 ('<src> skills , computer skills , <tgt> ,', 2),
 ("'s <src> , and they will help you reset your passwords and <tgt> answers",
  2),
 ('off <tgt> , have 4 year <src> experience', 2),
 ('a <tgt> to become an <src>', 2),
 ('an <src> application was as little as half a year to a year of <tgt> ,', 2),
 ('<tgt> <src> .', 2),
 ('a <tgt> vest while spotting another associate that would be on a <src> picker',
  2),
 ("did <tgt> , q'ing ovens , grill , fryers , <src> ,", 2),
 ('of <src>

# Relation prompting

In [23]:
probe_prompts = [
    'Walmart hires {2}.'
]
probe_and_eval(d_lm_probe, 'job_position', probe_prompts, evaluate=False)

Unnamed: 0,concept,neighbor,lm_score
0,job_position,seasonal employees,2.050211e-02
1,job_position,younger people,1.952814e-02
2,job_position,seasonal workers,1.401879e-02
3,job_position,seasonal,1.048701e-02
4,job_position,hard workers,9.887002e-03
...,...,...,...
8030,job_position,diameter,3.434599e-10
8031,job_position,vitamin,2.737315e-10
8032,job_position,loch,2.378370e-10
8033,job_position,afterlife,1.941458e-10


In [60]:
_input_txt = 'Walmart hires [MASK].'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['job_position'])

[{'cand': 'manager', 'score': 4.044538218295201e-05},
 {'cand': 'stock', 'score': 9.696713277662635e-06},
 {'cand': 'bakery', 'score': 2.1099281184433502e-06},
 {'cand': 'sanitation', 'score': 4.320804976032375e-07},
 {'cand': 'chef', 'score': 2.816254607296287e-07},
 {'cand': 'bartender', 'score': 1.1553774470485216e-07},
 {'cand': 'trainer', 'score': 1.0000942296528609e-07},
 {'cand': 'porter', 'score': 8.961000474982941e-08},
 {'cand': 'cooking', 'score': 5.6983544993727185e-08},
 {'cand': 'cash management', 'score': 0.0021109164546455673},
 {'cand': 'customer experience', 'score': 0.0007765050211323792},
 {'cand': 'co worker', 'score': 0.00041832973445144405},
 {'cand': 'store manager', 'score': 0.0003973176240863882},
 {'cand': 'meats', 'score': 0.00038776102046236187},
 {'cand': 'line cooks', 'score': 0.0002671291271867165},
 {'cand': 'sales associate', 'score': 0.00023135770015185136},
 {'cand': 'cashier', 'score': 0.000195295312652414},
 {'cand': 'district manager', 'score': 0.

In [62]:
_input_txt = 'Walmart hires [MASK] job position.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['job_position'])

[{'cand': 'manager', 'score': 0.00010009146353695548},
 {'cand': 'stock', 'score': 1.6777270502643653e-05},
 {'cand': 'bakery', 'score': 8.81166954513901e-07},
 {'cand': 'trainer', 'score': 4.860810918216882e-07},
 {'cand': 'bartender', 'score': 3.619610708938129e-07},
 {'cand': 'chef', 'score': 2.582919194082935e-07},
 {'cand': 'sanitation', 'score': 1.3348098093501906e-07},
 {'cand': 'cooking', 'score': 1.2421901374182196e-07},
 {'cand': 'porter', 'score': 6.241836558729115e-08},
 {'cand': 'cash management', 'score': 0.0005291664835159803},
 {'cand': 'general manager', 'score': 0.00019238741103562093},
 {'cand': 'customer experience', 'score': 0.0001418161947395106},
 {'cand': 'store manager', 'score': 0.00010484555664704782},
 {'cand': 'cash office', 'score': 7.140986948596141e-05},
 {'cand': 'assistant manager', 'score': 7.043573861223843e-05},
 {'cand': 'customer service', 'score': 6.599919503560219e-05},
 {'cand': 'cashier', 'score': 5.771988724237656e-05},
 {'cand': 'sales speci

In [63]:
_input_txt = 'Starbucks hires [MASK] job position.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['job_position'])

[{'cand': 'manager', 'score': 0.000270383694441989},
 {'cand': 'stock', 'score': 1.8286536942468946e-05},
 {'cand': 'bartender', 'score': 5.0582825679157395e-06},
 {'cand': 'chef', 'score': 2.519073404982919e-06},
 {'cand': 'trainer', 'score': 1.8160425270252755e-06},
 {'cand': 'bakery', 'score': 1.7330881973975929e-06},
 {'cand': 'cooking', 'score': 5.276066872283989e-07},
 {'cand': 'porter', 'score': 3.679531346278963e-07},
 {'cand': 'sanitation', 'score': 3.5608999837677397e-07},
 {'cand': 'cash management', 'score': 0.0006751821374371135},
 {'cand': 'customer experience', 'score': 0.00022238398866484966},
 {'cand': 'general manager', 'score': 0.00020506839012965063},
 {'cand': 'store manager', 'score': 0.00011284662924871513},
 {'cand': 'customer host', 'score': 0.00010285030110222894},
 {'cand': 'assistant manager', 'score': 0.00010128185616733264},
 {'cand': 'cash office', 'score': 8.379232337931334e-05},
 {'cand': 'customer service', 'score': 7.168965644495033e-05},
 {'cand': 's

In [64]:
_input_txt = 'USPS hires [MASK] job position.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['job_position'])

[{'cand': 'manager', 'score': 8.687973604537548e-05},
 {'cand': 'stock', 'score': 8.656343197799291e-06},
 {'cand': 'bakery', 'score': 9.467102017879368e-07},
 {'cand': 'trainer', 'score': 7.775967105772003e-07},
 {'cand': 'bartender', 'score': 5.26179348980804e-07},
 {'cand': 'porter', 'score': 3.791283518239649e-07},
 {'cand': 'chef', 'score': 2.95824918339349e-07},
 {'cand': 'cooking', 'score': 2.5898702915583256e-07},
 {'cand': 'sanitation', 'score': 1.7764372728379393e-07},
 {'cand': 'cash management', 'score': 0.0003529656935413205},
 {'cand': 'customer experience', 'score': 0.00018121734094852303},
 {'cand': 'cash office', 'score': 0.00016659765980818147},
 {'cand': 'general manager', 'score': 0.00012361444023113518},
 {'cand': 'customer service', 'score': 0.0001232646929914764},
 {'cand': 'courtesy clerk', 'score': 0.00012151121222314429},
 {'cand': 'delivery driver', 'score': 7.858986010762586e-05},
 {'cand': 'package handler', 'score': 7.055162493876149e-05},
 {'cand': 'cashi

In [67]:
_input_txt = 'Walmart pays the employees [MASK].'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['pay_schedule'])

[{'cand': 'weekly', 'score': 0.02925265394151211},
 {'cand': 'friday', 'score': 2.15337386180181e-05},
 {'cand': 'sunday', 'score': 8.75437490321929e-06},
 {'cand': 'saturday', 'score': 8.235551831603518e-06},
 {'cand': 'weekend', 'score': 1.2102042319384044e-06},
 {'cand': 'sun', 'score': 3.3790175280046263e-07},
 {'cand': 'bi weekly', 'score': 0.0035116657133234183},
 {'cand': 'bi weekly', 'score': 0.0035116657133234183},
 {'cand': 'sos', 'score': 0.0031904917231204977},
 {'cand': 'full time', 'score': 0.0029638933175443216},
 {'cand': 'premiums', 'score': 0.0019574848201432435},
 {'cand': 'half hour', 'score': 0.001049060938753325},
 {'cand': 'a 10', 'score': 0.0009497050029383174},
 {'cand': 'paid weekly', 'score': 0.0007831518314394456},
 {'cand': 'weekly basis', 'score': 0.0007585972736185932},
 {'cand': 'bi monthly', 'score': 0.00062675970127579},
 {'cand': 'a 4', 'score': 0.0006056728361380368},
 {'cand': 'ato', 'score': 0.0005708528863031589},
 {'cand': 'fasts', 'score': 0.000

In [68]:
_input_txt = 'USPS pays the employees [MASK].'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['pay_schedule'])

[{'cand': 'weekly', 'score': 0.0973866358399391},
 {'cand': 'friday', 'score': 0.00010338638821849601},
 {'cand': 'sunday', 'score': 3.703141919686457e-05},
 {'cand': 'saturday', 'score': 3.693808685056866e-05},
 {'cand': 'sun', 'score': 2.5192748580593616e-06},
 {'cand': 'weekend', 'score': 1.7645753587203215e-06},
 {'cand': 'bi weekly', 'score': 0.0073341648851916776},
 {'cand': 'bi weekly', 'score': 0.0073341648851916776},
 {'cand': 'full time', 'score': 0.004306877357809024},
 {'cand': 'sos', 'score': 0.002426777719371289},
 {'cand': 'premiums', 'score': 0.002290355502246083},
 {'cand': 'weekly basis', 'score': 0.001676511440710488},
 {'cand': 'bi monthly', 'score': 0.001524808952460235},
 {'cand': 'a 10', 'score': 0.0013304599655636386},
 {'cand': 'paid weekly', 'score': 0.0012026205786234358},
 {'cand': 'half hour', 'score': 0.0011395390838692884},
 {'cand': 'a 4', 'score': 0.0009298401443274885},
 {'cand': 'ato', 'score': 0.0008118346603158046},
 {'cand': 'fasts', 'score': 0.000

In [74]:
_input_txt = 'USPS requires [MASK] to work.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['hire_prerequisite'])

[{'cand': 'security', 'score': 0.0002408516011200847},
 {'cand': 'labor', 'score': 0.00020930019672960035},
 {'cand': 'degree', 'score': 0.0001575442292960362},
 {'cand': 'training', 'score': 7.807995280018072e-05},
 {'cand': 'production', 'score': 6.87642386765219e-05},
 {'cand': 'warehouse', 'score': 6.018155181664042e-05},
 {'cand': 'permit', 'score': 5.310527558322068e-05},
 {'cand': 'health', 'score': 4.483920929487797e-05},
 {'cand': 'orientation', 'score': 3.529769310262053e-05},
 {'cand': 'safety', 'score': 3.38168247253634e-05},
 {'cand': 'insurance', 'score': 1.5435454770340584e-05},
 {'cand': 'welding', 'score': 1.4564506273018191e-05},
 {'cand': 'construction', 'score': 1.2023656381643386e-05},
 {'cand': 'wheelchair', 'score': 1.1931384506169712e-05},
 {'cand': 'disability', 'score': 1.1664723388093996e-05},
 {'cand': 'property', 'score': 8.178904863598296e-06},
 {'cand': 'physical fitness', 'score': 0.003122204680007453},
 {'cand': 'fasts', 'score': 0.0030584829219061685},

In [75]:
_input_txt = 'Walmart requires [MASK] to work.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['hire_prerequisite'])

[{'cand': 'labor', 'score': 0.0005171514349058269},
 {'cand': 'security', 'score': 0.00018792266200762253},
 {'cand': 'production', 'score': 8.835100015858191e-05},
 {'cand': 'safety', 'score': 4.8382073146058274e-05},
 {'cand': 'health', 'score': 3.531390029820612e-05},
 {'cand': 'warehouse', 'score': 2.2596563212573518e-05},
 {'cand': 'training', 'score': 1.927762787090615e-05},
 {'cand': 'degree', 'score': 1.567751860420684e-05},
 {'cand': 'welding', 'score': 9.684446013125129e-06},
 {'cand': 'permit', 'score': 7.751629709673583e-06},
 {'cand': 'construction', 'score': 6.792131443944532e-06},
 {'cand': 'disability', 'score': 3.5656487398227922e-06},
 {'cand': 'orientation', 'score': 3.4482220598874834e-06},
 {'cand': 'insurance', 'score': 3.0071223591221484e-06},
 {'cand': 'property', 'score': 2.0617114842025347e-06},
 {'cand': 'wheelchair', 'score': 6.68701716222131e-07},
 {'cand': 'fasts', 'score': 0.0041359518363891165},
 {'cand': 'fasting', 'score': 0.0026000696824879177},
 {'ca

In [78]:
_input_txt = 'Walmart checks employees background for [MASK].'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['background_screening'])

[{'cand': 'employment', 'score': 0.518429696559906},
 {'cand': 'theft', 'score': 0.10181621462106703},
 {'cand': 'insurance', 'score': 0.0037357669789344055},
 {'cand': 'drug', 'score': 0.0027832468040287495},
 {'cand': 'unemployment', 'score': 0.0016500866040587425},
 {'cand': 'felony', 'score': 0.0012245919788256288},
 {'cand': 'medical', 'score': 0.0010381153551861646},
 {'cand': 'training', 'score': 0.0006999548058956858},
 {'cand': 'state', 'score': 0.00041898697963915776},
 {'cand': 'disability', 'score': 0.0002645253262016922},
 {'cand': 'orientation', 'score': 0.0002158574934583159},
 {'cand': 'education', 'score': 0.00011932851339224727},
 {'cand': 'fingerprints', 'score': 3.827164982794783e-05},
 {'cand': 'citizenship', 'score': 2.5426430511288365e-05},
 {'cand': 'email', 'score': 1.5739538866910156e-05},
 {'cand': 'passport', 'score': 9.27551639051671e-07},
 {'cand': 'criminal charges', 'score': 0.009037463312988351},
 {'cand': 'criminal activity', 'score': 0.007918861018804

In [79]:
_input_txt = 'Starbucks checks employees background for [MASK].'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['background_screening'])

[{'cand': 'employment', 'score': 0.543633759021759},
 {'cand': 'theft', 'score': 0.08711812645196915},
 {'cand': 'insurance', 'score': 0.003723518224433064},
 {'cand': 'drug', 'score': 0.003293673740699887},
 {'cand': 'felony', 'score': 0.0013719348935410378},
 {'cand': 'unemployment', 'score': 0.0010913558071479206},
 {'cand': 'training', 'score': 0.0009604228544048966},
 {'cand': 'medical', 'score': 0.0007497334736399349},
 {'cand': 'state', 'score': 0.0002884544664993884},
 {'cand': 'orientation', 'score': 0.00022699564578942938},
 {'cand': 'disability', 'score': 0.0001586443395353855},
 {'cand': 'citizenship', 'score': 9.581060294294723e-05},
 {'cand': 'education', 'score': 8.20969653432258e-05},
 {'cand': 'fingerprints', 'score': 4.7227793402271386e-05},
 {'cand': 'email', 'score': 1.4699212442792493e-05},
 {'cand': 'passport', 'score': 1.6588834341746412e-06},
 {'cand': 'drug testing', 'score': 0.009025568192898717},
 {'cand': 'criminal charges', 'score': 0.008132452841362377},
 

In [81]:
_input_txt = 'Starbucks hires people with [MASK] record.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['person'])

[{'cand': 'criminals', 'score': 0.00227281847037375},
 {'cand': 'military', 'score': 0.00013414060231298206},
 {'cand': 'school', 'score': 0.00010547983401920653},
 {'cand': 'college', 'score': 6.402123835869138e-05},
 {'cand': 'family', 'score': 3.531275433488189e-05},
 {'cand': 'disabled', 'score': 1.5425030142068866e-05},
 {'cand': 'pregnant', 'score': 1.9904146029148255e-06},
 {'cand': 'lawyers', 'score': 1.277599267268669e-06},
 {'cand': 'student', 'score': 1.1671477295749364e-06},
 {'cand': 'homeless', 'score': 8.123524253278444e-07},
 {'cand': 'teacher', 'score': 7.053063768580618e-07},
 {'cand': 'students', 'score': 5.290506237543011e-07},
 {'cand': 'lgbt', 'score': 1.4081409460686703e-07},
 {'cand': 'seniors', 'score': 3.65334358320979e-08},
 {'cand': 'high school', 'score': 0.0015656204702118167},
 {'cand': 'goodies', 'score': 0.00030054536652657324},
 {'cand': 'high times', 'score': 0.00012596221357192707},
 {'cand': 'hard workers', 'score': 0.00011253336550451461},
 {'cand'

In [82]:
_input_txt = 'Walmart hires people with [MASK] record.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['person'])

[{'cand': 'criminals', 'score': 0.0019037112360820176},
 {'cand': 'military', 'score': 0.00012767614680342384},
 {'cand': 'school', 'score': 5.7169123465428156e-05},
 {'cand': 'family', 'score': 5.712851998396221e-05},
 {'cand': 'college', 'score': 1.2024042007396939e-05},
 {'cand': 'disabled', 'score': 1.1802534572780126e-05},
 {'cand': 'pregnant', 'score': 2.1373982690420247e-06},
 {'cand': 'homeless', 'score': 1.0495664355403276e-06},
 {'cand': 'teacher', 'score': 3.358348124038456e-07},
 {'cand': 'lawyers', 'score': 3.328109983158355e-07},
 {'cand': 'student', 'score': 3.267584816057932e-07},
 {'cand': 'students', 'score': 1.1148632950153105e-07},
 {'cand': 'lgbt', 'score': 6.609695191173171e-08},
 {'cand': 'seniors', 'score': 1.320630893530959e-08},
 {'cand': 'high school', 'score': 0.0008360226454624296},
 {'cand': 'goodies', 'score': 0.00047158681403271246},
 {'cand': 'hard workers', 'score': 7.533298857002207e-05},
 {'cand': 'high schools', 'score': 6.864543739391981e-05},
 {'c

In [90]:
_input_txt = 'Walmart gives job to [MASK] employees.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['employee_type'])

[{'cand': 'management', 'score': 0.0002023133711190894},
 {'cand': 'union', 'score': 7.245303277159116e-05},
 {'cand': 'seasonal', 'score': 7.033193105598909e-05},
 {'cand': 'training', 'score': 5.726950257667345e-05},
 {'cand': 'school', 'score': 5.5627162510063526e-05},
 {'cand': 'warehouse', 'score': 5.4823947721160996e-05},
 {'cand': 'contract', 'score': 4.069416900165376e-05},
 {'cand': 'manager', 'score': 3.4320320992264924e-05},
 {'cand': 'retail', 'score': 2.8462314730859385e-05},
 {'cand': 'benefits', 'score': 2.7924636015086428e-05},
 {'cand': 'long', 'score': 2.2153495592647236e-05},
 {'cand': 'stock', 'score': 2.1317267965059714e-05},
 {'cand': 'student', 'score': 2.057677193079145e-05},
 {'cand': 'reserve', 'score': 1.8942944734590118e-05},
 {'cand': 'production', 'score': 1.871063432190566e-05},
 {'cand': 'short', 'score': 1.1610242836468377e-05},
 {'cand': 'variety', 'score': 6.656121968262591e-06},
 {'cand': 'weekend', 'score': 5.218872502155135e-06},
 {'cand': 'season'

In [91]:
_input_txt = 'Starbucks gives job to [MASK] employees.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['employee_type'])

[{'cand': 'training', 'score': 0.00029888097196817376},
 {'cand': 'seasonal', 'score': 0.0002584997564554213},
 {'cand': 'student', 'score': 0.00021199595357757056},
 {'cand': 'management', 'score': 0.00019744843302760284},
 {'cand': 'school', 'score': 0.00017765986558515568},
 {'cand': 'retail', 'score': 0.0001127875148085878},
 {'cand': 'warehouse', 'score': 9.721894457470623e-05},
 {'cand': 'union', 'score': 9.24404230318032e-05},
 {'cand': 'manager', 'score': 9.162505739368493e-05},
 {'cand': 'benefits', 'score': 8.565753523726018e-05},
 {'cand': 'production', 'score': 7.808458030922333e-05},
 {'cand': 'contract', 'score': 5.975623935228209e-05},
 {'cand': 'long', 'score': 5.1885985158151016e-05},
 {'cand': 'stock', 'score': 4.135508424951694e-05},
 {'cand': 'reserve', 'score': 3.0646820960100754e-05},
 {'cand': 'variety', 'score': 1.8180080587626445e-05},
 {'cand': 'short', 'score': 1.7994307199842293e-05},
 {'cand': 'weekend', 'score': 1.1185233233845786e-05},
 {'cand': 'peak', '

In [93]:
_input_txt = 'Walmart allows to wear [MASK] .'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['dress_code'])

[{'cand': 'shoes', 'score': 0.016255917027592656},
 {'cand': 'uniform', 'score': 0.010081397369503978},
 {'cand': 'tattoos', 'score': 0.009606749750673771},
 {'cand': 'hair', 'score': 0.0026066340506076817},
 {'cand': 'logos', 'score': 0.00042903301073238243},
 {'cand': 'piercings', 'score': 0.02605469818890319},
 {'cand': 'sos', 'score': 0.022821717736486224},
 {'cand': 'dress clothes', 'score': 0.02125828807171149},
 {'cand': 'dress shoes', 'score': 0.02110718698794935},
 {'cand': 'black pants', 'score': 0.014208561862990231},
 {'cand': 'dress pants', 'score': 0.010357910923406387},
 {'cand': 'colored shoes', 'score': 0.010226824353455207},
 {'cand': 'dress shirts', 'score': 0.009068602088512826},
 {'cand': 'black shirt', 'score': 0.006964915215190608},
 {'cand': 'victorias', 'score': 0.005558355962268203},
 {'cand': 'colored jeans', 'score': 0.005541391798949053},
 {'cand': 'dress shirt', 'score': 0.005077359136249674},
 {'cand': 'colored pants', 'score': 0.005018600334705397},
 {'c

In [94]:
_input_txt = 'Starbucks allows to wear [MASK] .'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['dress_code'])

[{'cand': 'tattoos', 'score': 0.018296016380190853},
 {'cand': 'uniform', 'score': 0.014138111844658846},
 {'cand': 'shoes', 'score': 0.00922471284866333},
 {'cand': 'hair', 'score': 0.002794329077005386},
 {'cand': 'logos', 'score': 0.0007846390944905582},
 {'cand': 'dress clothes', 'score': 0.046074741354830286},
 {'cand': 'dress shoes', 'score': 0.04171952682728462},
 {'cand': 'black pants', 'score': 0.0288653639329152},
 {'cand': 'dress shirts', 'score': 0.026836631645269126},
 {'cand': 'dress pants', 'score': 0.024412514108389866},
 {'cand': 'colored jeans', 'score': 0.019477692560183935},
 {'cand': 'black shirt', 'score': 0.018434484449661667},
 {'cand': 'colored shoes', 'score': 0.018232965852295266},
 {'cand': 'piercings', 'score': 0.015925774142625825},
 {'cand': 'dress shirt', 'score': 0.015590730563943701},
 {'cand': 'sos', 'score': 0.014600164004428665},
 {'cand': 'colored pants', 'score': 0.010669165495324958},
 {'cand': 'colored shirt', 'score': 0.006813722005083044},
 {'

In [96]:
_input_txt = 'Starbucks offers benefits like [MASK] .'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['benefits'])

[{'cand': 'health', 'score': 0.0313359834253788},
 {'cand': 'medical', 'score': 0.025395087897777554},
 {'cand': 'life', 'score': 0.021267509087920192},
 {'cand': 'pension', 'score': 0.00937850214540958},
 {'cand': 'disability', 'score': 0.000223886236199178},
 {'cand': '401k', 'score': 0.20908846424248387},
 {'cand': 'health insurance', 'score': 0.12282314051605882},
 {'cand': 'health care', 'score': 0.04827742550735275},
 {'cand': 'life insurance', 'score': 0.04149468677202955},
 {'cand': 'medical insurance', 'score': 0.039080276010964775},
 {'cand': 'health dental', 'score': 0.038807587922003464},
 {'cand': '401 k', 'score': 0.033417083661547275},
 {'cand': 'dental insurance', 'score': 0.020579941598597086},
 {'cand': 'vision insurance', 'score': 0.020333953842700453},
 {'cand': 'health food', 'score': 0.01889850944788388},
 {'cand': 'health benefits', 'score': 0.016947272865852434},
 {'cand': 'medical dental', 'score': 0.012347927605005221},
 {'cand': 'car insurance', 'score': 0.00

In [97]:
_input_txt = 'Walmart offers benefits like [MASK] .'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['benefits'])

[{'cand': 'medical', 'score': 0.07124094665050508},
 {'cand': 'life', 'score': 0.022006522864103314},
 {'cand': 'health', 'score': 0.018842017278075218},
 {'cand': 'pension', 'score': 0.01282632350921631},
 {'cand': 'disability', 'score': 0.00033496995456516726},
 {'cand': '401k', 'score': 0.18922815469094245},
 {'cand': 'health insurance', 'score': 0.11916881404899644},
 {'cand': 'medical insurance', 'score': 0.06573124942535141},
 {'cand': 'life insurance', 'score': 0.04036680382759231},
 {'cand': 'health dental', 'score': 0.03923406168168776},
 {'cand': 'health care', 'score': 0.03758870511992955},
 {'cand': 'dental insurance', 'score': 0.03301385843858297},
 {'cand': 'vision insurance', 'score': 0.0271060373545835},
 {'cand': '401 k', 'score': 0.021924331320821187},
 {'cand': 'medical dental', 'score': 0.021640761594793764},
 {'cand': 'medical card', 'score': 0.01195141174198571},
 {'cand': 'health benefits', 'score': 0.01106779980966574},
 {'cand': 'health food', 'score': 0.010903

In [99]:
_input_txt = 'Walmart requires employees to work [MASK] shifts.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['shifts'])

[{'cand': 'weekend', 'score': 0.001449428149498999},
 {'cand': 'sunday', 'score': 0.00012698728824034325},
 {'cand': 'saturday', 'score': 7.614116475451738e-05},
 {'cand': 'extra hours', 'score': 0.0033304013082739283},
 {'cand': 'rush hour', 'score': 0.0030810946352085346},
 {'cand': '7 days', 'score': 0.0024497405308976906},
 {'cand': 'rotating shift', 'score': 0.0016808370224549625},
 {'cand': 'late night', 'score': 0.0015779183352873046},
 {'cand': 'swing shift', 'score': 0.0015245780472929404},
 {'cand': 'double shift', 'score': 0.0014214638393428042},
 {'cand': '3rd shift', 'score': 0.0013405070511752957},
 {'cand': 'night shift', 'score': 0.0010209134879879446},
 {'cand': 'rotating schedule', 'score': 0.0009828430609586825},
 {'cand': 'mid shift', 'score': 0.0008786266218655466},
 {'cand': 'a 4', 'score': 0.0008706409089473324},
 {'cand': 'a 7', 'score': 0.0007150173759617601},
 {'cand': 'third shift', 'score': 0.0006999384468982926},
 {'cand': 'early morning', 'score': 0.000601

In [100]:
_input_txt = 'Starbucks requires employees to work [MASK] shifts.'

d_lm_probe.score_candidates(_input_txt, all_expanded_instances['shifts'])

[{'cand': 'weekend', 'score': 0.0006425043684430419},
 {'cand': 'sunday', 'score': 8.818963397061445e-05},
 {'cand': 'saturday', 'score': 3.5667224437929676e-05},
 {'cand': 'rush hour', 'score': 0.004741164207799525},
 {'cand': 'extra hours', 'score': 0.003282868440783734},
 {'cand': 'rotating shift', 'score': 0.002790120100157616},
 {'cand': 'rotating schedule', 'score': 0.0021292528433847504},
 {'cand': '7 days', 'score': 0.0017376892857962265},
 {'cand': 'double shift', 'score': 0.0016502372424610737},
 {'cand': 'late night', 'score': 0.0016251583183810477},
 {'cand': 'swing shift', 'score': 0.0016129829775410596},
 {'cand': 'mid shift', 'score': 0.0014765735137054487},
 {'cand': '3rd shift', 'score': 0.0012173756663064915},
 {'cand': 'night shift', 'score': 0.0011287952507734386},
 {'cand': 'third shift', 'score': 0.0008912379588662466},
 {'cand': 'rotating shifts', 'score': 0.0008755594993033714},
 {'cand': 'early morning', 'score': 0.0007842748072152108},
 {'cand': 'entire shift'

In [107]:
sent_mapping[(str(set(['walmart', 'saturday'])))]
#predict_rel(Walmart requires employees to work)

KeyError: "{'saturday', 'walmart'}"

In [103]:
sent_mapping

{"{'upper', 'management'}": ['nearly impossible , unless you have an in with <phrase>upper</phrase> level <phrase>management</phrase> your better off finding a better job anywhere else',
  "then there is n't any reason at all you could n't make it to an <phrase>upper</phrase> level of <phrase>management</phrase> <phrase>management</phrase> within 3 years",
  "then there is n't any reason at all you could n't make it to an <phrase>upper</phrase> level of <phrase>management</phrase> <phrase>management</phrase> within 3 years",
  'part time job for the part time <phrase>general warehouse</phrase> employees as <phrase>upper</phrase> <phrase>management</phrase> <phrase>management</phrase> is constantly getting promoted to higher and higher levels .',
  'part time job for the part time <phrase>general warehouse</phrase> employees as <phrase>upper</phrase> <phrase>management</phrase> <phrase>management</phrase> is constantly getting promoted to higher and higher levels .',
  'I never saw <phr

# Class name ranking

In [148]:
_input_txt = 'Walmart has many [MASK] such as pharmacy.'
concept_names = [c.replace('_', ' ') for c in gold_ee['concept'].drop_duplicates().tolist()]
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'department', 'score': 7.928592822281645e-05},
 {'cand': 'benefits', 'score': 3.428554191486911e-05},
 {'cand': 'role', 'score': 9.05354454516782e-07},
 {'cand': 'schedule', 'score': 2.514392747343664e-07},
 {'cand': 'dress', 'score': 1.520515127140242e-07},
 {'cand': 'person', 'score': 4.4079396133156525e-08}]

In [139]:
_input_txt = 'Walmart has many [MASK] such as store manager.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'role', 'score': 0.000921929895412177},
 {'cand': 'benefits', 'score': 0.00017365952953696254},
 {'cand': 'department', 'score': 1.0002711860579432e-05},
 {'cand': 'person', 'score': 7.516615028180241e-07},
 {'cand': 'schedule', 'score': 1.894262311452623e-07},
 {'cand': 'dress', 'score': 5.3246697717668285e-08}]

In [138]:
_input_txt = 'Walmart has many [MASK] such as team member.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'benefits', 'score': 0.007540588732808828},
 {'cand': 'role', 'score': 0.0008790147257968784},
 {'cand': 'department', 'score': 6.689117526548216e-06},
 {'cand': 'person', 'score': 1.7962389620151957e-06},
 {'cand': 'schedule', 'score': 8.820003358778199e-07},
 {'cand': 'dress', 'score': 4.074924220276442e-07}]

In [134]:
_input_txt = 'You can work in different [MASK] such as team member at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'role', 'score': 0.0011707723606377844},
 {'cand': 'department', 'score': 6.305181887000794e-05},
 {'cand': 'benefits', 'score': 1.287839768338018e-05},
 {'cand': 'person', 'score': 7.303944130399032e-06},
 {'cand': 'schedule', 'score': 3.885445039486509e-06},
 {'cand': 'dress', 'score': 3.5500202102411983e-06}]

In [133]:
_input_txt = 'You can work in different [MASK] such as store manager at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'role', 'score': 0.0003541113110259175},
 {'cand': 'department', 'score': 2.7457364922156564e-05},
 {'cand': 'person', 'score': 2.5015324354171736e-06},
 {'cand': 'benefits', 'score': 2.30706609727349e-06},
 {'cand': 'schedule', 'score': 1.7692718756734404e-06},
 {'cand': 'dress', 'score': 9.699878091851135e-07}]

In [132]:
_input_txt = 'You can work in different [MASK] such as pharmacy at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'department', 'score': 0.0002710407134145497},
 {'cand': 'role', 'score': 8.917191007640207e-05},
 {'cand': 'benefits', 'score': 2.308433431608137e-05},
 {'cand': 'dress', 'score': 1.0018445209425412e-05},
 {'cand': 'person', 'score': 7.824675776646476e-06},
 {'cand': 'schedule', 'score': 3.6996345897932782e-06}]

In [131]:
_input_txt = 'Walmart has many [MASK] openings such as pharmacy.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'department', 'score': 0.004289589356631039},
 {'cand': 'dress', 'score': 2.1259726054267968e-05},
 {'cand': 'schedule', 'score': 1.7378777556587014e-05},
 {'cand': 'benefits', 'score': 6.290527380770068e-06},
 {'cand': 'role', 'score': 6.095604476286101e-06},
 {'cand': 'person', 'score': 3.8145537928357957e-06}]

In [130]:
_input_txt = 'Walmart has many [MASK] openings such as store manager.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'role', 'score': 0.0026508672162890426},
 {'cand': 'department', 'score': 0.0003090718237217515},
 {'cand': 'person', 'score': 3.758846651180649e-05},
 {'cand': 'benefits', 'score': 1.8178745449404236e-05},
 {'cand': 'schedule', 'score': 1.0198898962698877e-05},
 {'cand': 'dress', 'score': 2.3267637061508127e-06}]

In [135]:
_input_txt = 'Walmart has many [MASK] openings such as team member.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department'])

[{'cand': 'role', 'score': 0.0008623945177532729},
 {'cand': 'department', 'score': 0.000592551950830966},
 {'cand': 'schedule', 'score': 0.0001301330048590898},
 {'cand': 'person', 'score': 5.676584623870439e-05},
 {'cand': 'benefits', 'score': 5.27862866874784e-05},
 {'cand': 'dress', 'score': 3.5004035453312126e-05}]

In [11]:
_input_txt = 'Walmart does [MASK] such as employment verification on its employees before hiring.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks'])

[{'cand': 'checks', 'score': 0.004352547693997621},
 {'cand': 'benefits', 'score': 0.0026564912404865035},
 {'cand': 'schedule', 'score': 6.7971504904562624e-06},
 {'cand': 'role', 'score': 6.10022880209726e-06},
 {'cand': 'department', 'score': 3.482102556517932e-06},
 {'cand': 'dress', 'score': 2.0682260526427872e-07},
 {'cand': 'person', 'score': 1.731605259180923e-08}]

In [15]:
_input_txt = 'Walmart does [MASK] such as credit card on its employees before hiring.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks'])

[{'cand': 'benefits', 'score': 0.0901758223772049},
 {'cand': 'checks', 'score': 0.01961134932935238},
 {'cand': 'department', 'score': 4.328676368459127e-06},
 {'cand': 'schedule', 'score': 1.107572074943164e-06},
 {'cand': 'role', 'score': 8.943067655309284e-07},
 {'cand': 'dress', 'score': 6.424879188671187e-07},
 {'cand': 'person', 'score': 2.8385457540025554e-07}]

In [19]:
_input_txt = 'Walmart does [MASK] such as drug testing on its employees before hiring.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks'])

[{'cand': 'checks', 'score': 0.0059774932451546175},
 {'cand': 'benefits', 'score': 0.0003622052900027485},
 {'cand': 'schedule', 'score': 6.428399501601232e-07},
 {'cand': 'department', 'score': 5.216043632572106e-07},
 {'cand': 'role', 'score': 4.566606435219003e-07},
 {'cand': 'dress', 'score': 3.947892324163143e-08},
 {'cand': 'person', 'score': 6.2939999878608325e-09}]

In [23]:
_input_txt = 'Walmart requires employees to typically have [MASK] such as bachelors degree.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'qualification', 'score': 0.012046223506331444},
 {'cand': 'benefits', 'score': 0.0001001386917778291},
 {'cand': 'checks', 'score': 1.7845165984908822e-06},
 {'cand': 'role', 'score': 3.0199649359019526e-07},
 {'cand': 'department', 'score': 2.973295067931759e-07},
 {'cand': 'person', 'score': 1.1794387688723869e-07},
 {'cand': 'dress', 'score': 5.511042999728493e-08},
 {'cand': 'schedule', 'score': 4.372743234171138e-08}]

In [26]:
_input_txt = 'Walmart requires employees to typically have [MASK] such as age requirement.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'benefits', 'score': 0.0004413749848026781},
 {'cand': 'qualification', 'score': 0.0003532252740114928},
 {'cand': 'checks', 'score': 8.017945947358386e-05},
 {'cand': 'schedule', 'score': 1.3428601960185905e-05},
 {'cand': 'role', 'score': 2.2132369394967104e-06},
 {'cand': 'department', 'score': 7.283127843038528e-07},
 {'cand': 'dress', 'score': 6.437674642256751e-07},
 {'cand': 'person', 'score': 4.5268464532455215e-08}]

In [25]:
_input_txt = 'Walmart requires employees to typically have [MASK] such as permit.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'qualification', 'score': 0.007409446407109499},
 {'cand': 'benefits', 'score': 0.00538267381489277},
 {'cand': 'checks', 'score': 0.00029910303419455875},
 {'cand': 'role', 'score': 7.559025561931783e-06},
 {'cand': 'department', 'score': 5.055914698459676e-06},
 {'cand': 'schedule', 'score': 3.987564468843633e-06},
 {'cand': 'person', 'score': 1.8269201973453164e-06},
 {'cand': 'dress', 'score': 9.677619345893613e-07}]

In [27]:
_input_txt = 'Walmart requires employees to typically have [MASK] such as background check.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'checks', 'score': 0.0472446121275425},
 {'cand': 'benefits', 'score': 0.0014866552082821731},
 {'cand': 'qualification', 'score': 0.0006987695815041662},
 {'cand': 'role', 'score': 8.052193879848341e-05},
 {'cand': 'department', 'score': 2.2020989490556537e-05},
 {'cand': 'schedule', 'score': 6.8006934270670155e-06},
 {'cand': 'person', 'score': 9.263325750907822e-07},
 {'cand': 'dress', 'score': 3.0549418283953845e-07}]

In [29]:
_input_txt = 'Starbucks requires employees to typically have [MASK] such as food handlers card.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'benefits', 'score': 0.06787548214197159},
 {'cand': 'qualification', 'score': 0.006470214109867809},
 {'cand': 'checks', 'score': 0.0015185780357569458},
 {'cand': 'role', 'score': 2.171240521420258e-05},
 {'cand': 'dress', 'score': 2.4999967536132334e-06},
 {'cand': 'department', 'score': 2.2235876713239105e-06},
 {'cand': 'person', 'score': 1.7369054603477703e-06},
 {'cand': 'schedule', 'score': 6.211256504684574e-07}]

In [30]:
_input_txt = 'Walmart requires employees to typically have [MASK] such as age limit.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'benefits', 'score': 0.0009525496861897413},
 {'cand': 'qualification', 'score': 0.000348220724845305},
 {'cand': 'checks', 'score': 0.00012878529378212986},
 {'cand': 'schedule', 'score': 2.2948663172428493e-05},
 {'cand': 'role', 'score': 4.899276063952128e-06},
 {'cand': 'department', 'score': 2.1561963876592935e-06},
 {'cand': 'dress', 'score': 8.632335948277616e-07},
 {'cand': 'person', 'score': 2.2145582079247095e-07}]

In [31]:
_input_txt = '[MASK] such as age limit is required to work at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'qualification', 'score': 0.013415457680821414},
 {'cand': 'benefits', 'score': 0.0018910061335191137},
 {'cand': 'schedule', 'score': 0.0004384078201837838},
 {'cand': 'checks', 'score': 0.00010429541725898159},
 {'cand': 'role', 'score': 5.66959970456082e-05},
 {'cand': 'department', 'score': 2.958809454867153e-05},
 {'cand': 'dress', 'score': 2.8637126888497712e-05},
 {'cand': 'person', 'score': 1.7369668057654056e-05}]

In [32]:
_input_txt = '[MASK] such as food handlers card is required to work at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'qualification', 'score': 0.03316840529441834},
 {'cand': 'benefits', 'score': 0.008951153606176376},
 {'cand': 'checks', 'score': 0.00024977844441309566},
 {'cand': 'person', 'score': 0.00016404486086685202},
 {'cand': 'role', 'score': 9.962370677385478e-05},
 {'cand': 'schedule', 'score': 1.387066004099325e-05},
 {'cand': 'department', 'score': 1.2791662811650883e-05},
 {'cand': 'dress', 'score': 4.800461738341253e-06}]

In [33]:
_input_txt = '[MASK] such as background check is required to work at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])

[{'cand': 'qualification', 'score': 0.021259602159261703},
 {'cand': 'checks', 'score': 0.010780897922813896},
 {'cand': 'benefits', 'score': 0.000996787217445672},
 {'cand': 'role', 'score': 0.00011230406380491331},
 {'cand': 'department', 'score': 3.806739550782371e-05},
 {'cand': 'schedule', 'score': 2.99815419566585e-05},
 {'cand': 'person', 'score': 8.421255188295623e-06},
 {'cand': 'dress', 'score': 3.3125306799775015e-06}]

In [None]:
_input_txt = 'The schedule is flexible. You can work [MASK] such as background check is required to work at Walmart.'
lm_probe.score_candidates(_input_txt, ['benefits', 'role', 'person', 'schedule', 'dress', 'department', 'checks', 'qualification'])