In [4]:
import math
import os

import numpy
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

# LM probing

In [5]:
class LMProbe(object):
    """Wrapper around a BERT masked language model for fill-in-the-mask probing.

    Attributes set by ``__init__``:
        device: torch device the model is moved to.
        tokenizer: ``BertTokenizer`` loaded from ``model_name``.
        model: ``BertForMaskedLM`` loaded from ``model_name`` and put in eval mode.
        mask_token: the tokenizer's mask token string.
    """

    def __init__(self, model_name='bert-base-uncased', use_gpu=False):
        """Load the tokenizer and the masked-LM weights.

        Args:
            model_name: model identifier or path handed to ``from_pretrained``.
            use_gpu: use CUDA only when this is True *and* CUDA is available;
                otherwise the model stays on the CPU.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

        self.mask_token = self.tokenizer.mask_token

    def fill_multi_mask(self, input_txt, topk=3):
        """Predict the ``topk`` most probable tokens for every mask in ``input_txt``.

        Args:
            input_txt: text that already contains the special tokens; it must
                start with ``[CLS]``, end with ``[SEP]`` (surrounding whitespace
                is ignored for this check) and contain at least one mask token.
            topk: number of candidates per mask position. Values larger than the
                vocabulary size are clamped; values <= 0 yield an empty list.

        Returns:
            A list with one entry per rank (best first). Each entry is a list
            with one dict per mask position, in order of position, holding
            ``token`` (vocabulary id), ``token_str``, ``prob`` and
            ``masked_pos`` (index into the tokenized input). Masks are ranked
            independently of each other: entry ``r`` simply pairs the
            ``r``-th best token of every mask.

        Raises:
            ValueError: if the input is not wrapped in ``[CLS]`` ... ``[SEP]``
                or contains no mask token.
        """
        stripped_txt = input_txt.strip()
        # Both delimiters are required; the check must fail when EITHER is missing.
        if not (stripped_txt.startswith('[CLS]') and stripped_txt.endswith('[SEP]')):
            raise ValueError('Input string must start with [CLS] and end with [SEP]')
        if self.mask_token not in input_txt:
            raise ValueError('Input string must have at least one mask token')

        tokenized_txt = self.tokenizer.tokenize(input_txt)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_txt)
        mask_indices = [i for i, tok in enumerate(tokenized_txt) if tok == self.mask_token]

        tokens_tensor = torch.tensor([indexed_tokens], device=self.device)
        # Every token is assigned segment id 0, including any text that follows
        # an inner [SEP].
        segments_tensor = torch.zeros_like(tokens_tensor)

        with torch.no_grad():
            outputs = self.model(tokens_tensor, token_type_ids=segments_tensor)
            predictions = outputs[0]

        # Only the rows at mask positions are needed, so select them first and
        # take a top-k instead of sorting the full (sequence x vocabulary) matrix.
        mask_rows = torch.tensor(mask_indices, dtype=torch.long, device=predictions.device)
        probs = torch.softmax(predictions[0].index_select(0, mask_rows), dim=-1)

        k = max(0, min(topk, probs.shape[-1]))
        if k == 0:
            return []
        top_probs, top_ids = torch.topk(probs, k, dim=-1)
        top_probs = top_probs.cpu().tolist()
        top_ids = top_ids.cpu().tolist()

        masked_cands = []
        for rank in range(k):
            predicted_indices = [row[rank] for row in top_ids]
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_indices)
            seq = [
                {"token": token_id, "token_str": token, "prob": row[rank], "masked_pos": masked_index}
                for token_id, token, row, masked_index
                in zip(predicted_indices, predicted_tokens, top_probs, mask_indices)
            ]
            masked_cands.append(seq)

        return masked_cands

## Class name from a pair of entities (Hearst patterns)

In [8]:
def get_class_name_hearst(entities, lm_probe=None, context=None, topk=3):
    """Guess a class name shared by two entities using Hearst-style templates.

    The first two items of ``entities`` are inserted into a fixed set of
    templates containing one or two mask tokens. Each template is sent to
    ``lm_probe.fill_multi_mask`` and the predicted tokens at the mask
    positions are joined with a space to form a candidate name.

    Args:
        entities: sequence with at least two entity strings; only the first
            two are used.
        lm_probe: object exposing ``mask_token`` and ``fill_multi_mask``. A
            default ``LMProbe()`` is created when omitted.
        context: optional text appended to the query as a second
            ``[SEP]``-terminated segment.
        topk: number of candidates requested per template.

    Returns:
        dict mapping candidate name to its score, ordered by descending score.
        A candidate's score for one template is the product of its per-mask
        probabilities; the final score is the mean over the templates that
        produced that candidate.

    Raises:
        ValueError: if fewer than two entities are given.
    """
    # Validate before (possibly) loading a model, so bad input fails fast.
    if len(entities) < 2:
        raise ValueError("not enough entity instances")

    if lm_probe is None:
        lm_probe = LMProbe()
    mask_token = lm_probe.mask_token
    generation_templates = [
        mask_token + ' such as {} , and {} .',
        'such ' + mask_token + ' as {} , and {} .',
        '{} , {} or other ' + mask_token + ' .',
        '{} , {} and other ' + mask_token + ' .',
        mask_token + ' including {} , and {} .',
        mask_token + ' , especially {}, and {} .',
        mask_token + ' ' + mask_token + ' such as {} , and {} .',
        'such ' + mask_token + ' ' + mask_token  + ' as {} , and {} .',
        '{} , {} or other ' + mask_token + ' ' + mask_token  + ' .',
        '{} , {} and other ' + mask_token + ' ' + mask_token  + ' .',
        mask_token + ' ' + mask_token  + ' including {} , and {} .',
        mask_token + ' ' + mask_token  + ' , especially {}, and {} .',
    ]

    e1, e2 = entities[0], entities[1]

    names_scores = {}
    for template in generation_templates:
        # Space-separate the special tokens so the query does not rely on the
        # tokenizer splitting '[SEP]' away from adjacent text.
        parts = ['[CLS]', template.format(e1, e2), '[SEP]']
        if context:
            parts += [context, '[SEP]']
        query = ' '.join(parts)

        preds = lm_probe.fill_multi_mask(query, topk=topk)
        for pred in preds:
            name = ' '.join(p['token_str'] for p in pred)
            score = numpy.prod([p['prob'] for p in pred])
            names_scores.setdefault(name, []).append(score)

    names_avg_scores = {k: float(sum(v)) / len(v) for k, v in names_scores.items()}
    return dict(sorted(names_avg_scores.items(), key=lambda item: item[1], reverse=True))

## Part-of names for a given entity

In [7]:
def get_part_of_name(entity, lm_probe=None, context=None, topk=10):
    """Propose names related to ``entity`` through "is part of" / "has" templates.

    ``entity`` is inserted into four templates ("<mask> is part of X ." and
    "X has <mask> .", each with one and with two mask tokens). Each template
    is sent to ``lm_probe.fill_multi_mask`` and the predicted tokens at the
    mask positions are joined with a space to form a candidate name.

    Args:
        entity: entity string substituted into the templates.
        lm_probe: object exposing ``mask_token`` and ``fill_multi_mask``. A
            default ``LMProbe()`` is created when omitted.
        context: optional text appended to the query as a second
            ``[SEP]``-terminated segment.
        topk: number of candidates requested per template.

    Returns:
        dict mapping candidate name to its score, ordered by descending score.
        A candidate's score for one template is the product of its per-mask
        probabilities; the final score is the mean over the templates that
        produced that candidate.
    """
    if lm_probe is None:
        lm_probe = LMProbe()
    mask_token = lm_probe.mask_token
    # The mask token is separated from 'is' by a space so the template does not
    # depend on the tokenizer splitting '[MASK]is' into two tokens.
    generation_templates = [
        mask_token + ' is part of {} .',
        mask_token + ' ' + mask_token + ' is part of {} .',
        '{} has ' + mask_token + ' .',
        '{} has ' + mask_token + ' ' + mask_token + ' .'
    ]

    names_scores = {}
    for template in generation_templates:
        # Space-separate the special tokens for the same reason as above.
        parts = ['[CLS]', template.format(entity), '[SEP]']
        if context:
            parts += [context, '[SEP]']
        query = ' '.join(parts)

        preds = lm_probe.fill_multi_mask(query, topk=topk)
        for pred in preds:
            name = ' '.join(p['token_str'] for p in pred)
            score = numpy.prod([p['prob'] for p in pred])
            names_scores.setdefault(name, []).append(score)

    names_avg_scores = {k: float(sum(v)) / len(v) for k, v in names_scores.items()}
    return dict(sorted(names_avg_scores.items(), key=lambda item: item[1], reverse=True))

In [6]:
# Checkpoint of the fine-tuned masked LM. The default is a machine-specific
# absolute path; set the LM_PROBE_MODEL_PATH environment variable to point at
# a different checkpoint without editing this cell.
LM_PROBE_MODEL_PATH = os.environ.get(
    'LM_PROBE_MODEL_PATH',
    '/home/ubuntu/users/nikita/models/bert_finetuned_lm/indeed_reviews_ques_ans',
)
lm_probe = LMProbe(LM_PROBE_MODEL_PATH)

Some weights of the model checkpoint at /home/ubuntu/users/nikita/models/bert_finetuned_lm/indeed_reviews_ques_ans were not used when initializing BertForMaskedLM: ['bert.embeddings.position_ids']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Single-mask probe with no context segment; request the 10 most probable fillers.
lm_probe.fill_multi_mask('[CLS] drug test is part of [MASK] . [SEP]', topk=10)

[[{'token': 3105,
   'token_str': 'job',
   'prob': 0.2626919448375702,
   'masked_pos': 6}],
 [{'token': 2832,
   'token_str': 'process',
   'prob': 0.13146057724952698,
   'masked_pos': 6}],
 [{'token': 2731,
   'token_str': 'training',
   'prob': 0.10522313416004181,
   'masked_pos': 6}],
 [{'token': 3343,
   'token_str': 'policy',
   'prob': 0.04612913727760315,
   'masked_pos': 6}],
 [{'token': 2009,
   'token_str': 'it',
   'prob': 0.024157783016562462,
   'masked_pos': 6}],
 [{'token': 7709,
   'token_str': 'procedure',
   'prob': 0.01594134047627449,
   'masked_pos': 6}],
 [{'token': 2147,
   'token_str': 'work',
   'prob': 0.015837358310818672,
   'masked_pos': 6}],
 [{'token': 2673,
   'token_str': 'everything',
   'prob': 0.01579480804502964,
   'masked_pos': 6}],
 [{'token': 5918,
   'token_str': 'requirements',
   'prob': 0.015557022765278816,
   'masked_pos': 6}],
 [{'token': 9095,
   'token_str': 'requirement',
   'prob': 0.015215539373457432,
   'masked_pos': 6}]]

In [38]:
# Same 'drug test' probe, now followed by a second [SEP]-terminated segment of keywords.
lm_probe.fill_multi_mask('[CLS] drug test is part of [MASK] . [SEP] drug test email background hire [SEP]', topk=10)

[[{'token': 2832,
   'token_str': 'process',
   'prob': 0.202988401055336,
   'masked_pos': 6}],
 [{'token': 3105,
   'token_str': 'job',
   'prob': 0.16196191310882568,
   'masked_pos': 6}],
 [{'token': 2731,
   'token_str': 'training',
   'prob': 0.10627923160791397,
   'masked_pos': 6}],
 [{'token': 14763,
   'token_str': 'hiring',
   'prob': 0.044770751148462296,
   'masked_pos': 6}],
 [{'token': 2009,
   'token_str': 'it',
   'prob': 0.02841365523636341,
   'masked_pos': 6}],
 [{'token': 6107,
   'token_str': 'employment',
   'prob': 0.026154039427638054,
   'masked_pos': 6}],
 [{'token': 10296,
   'token_str': 'orientation',
   'prob': 0.020505793392658234,
   'masked_pos': 6}],
 [{'token': 3343,
   'token_str': 'policy',
   'prob': 0.019412264227867126,
   'masked_pos': 6}],
 [{'token': 11326,
   'token_str': 'screening',
   'prob': 0.014734169468283653,
   'masked_pos': 6}],
 [{'token': 5918,
   'token_str': 'requirements',
   'prob': 0.014058711007237434,
   'masked_pos': 6}]]

In [35]:
# Single-mask probe for 'pay schedule' with no context segment (top 10 fillers).
lm_probe.fill_multi_mask('[CLS] pay schedule is [MASK] . [SEP]', topk=10)

[[{'token': 12379,
   'token_str': 'flexible',
   'prob': 0.11059394478797913,
   'masked_pos': 4}],
 [{'token': 2204,
   'token_str': 'good',
   'prob': 0.09503347426652908,
   'masked_pos': 4}],
 [{'token': 2307,
   'token_str': 'great',
   'prob': 0.06696714460849762,
   'masked_pos': 4}],
 [{'token': 1012,
   'token_str': '.',
   'prob': 0.06632491946220398,
   'masked_pos': 4}],
 [{'token': 7929,
   'token_str': 'ok',
   'prob': 0.051000095903873444,
   'masked_pos': 4}],
 [{'token': 4882,
   'token_str': 'weekly',
   'prob': 0.04453660175204277,
   'masked_pos': 4}],
 [{'token': 4189,
   'token_str': 'fair',
   'prob': 0.029096927493810654,
   'masked_pos': 4}],
 [{'token': 2275,
   'token_str': 'set',
   'prob': 0.025036100298166275,
   'masked_pos': 4}],
 [{'token': 11519,
   'token_str': 'decent',
   'prob': 0.020469404757022858,
   'masked_pos': 4}],
 [{'token': 3100,
   'token_str': 'okay',
   'prob': 0.017768269404768944,
   'masked_pos': 4}]]

In [39]:
# Same 'pay schedule' probe with a keyword segment appended after the first [SEP].
lm_probe.fill_multi_mask('[CLS] pay schedule is [MASK] . [SEP] pay schedule friday monthly paid  [SEP]', topk=10)

[[{'token': 4882,
   'token_str': 'weekly',
   'prob': 0.177995964884758,
   'masked_pos': 4}],
 [{'token': 9857,
   'token_str': 'tuesday',
   'prob': 0.14633633196353912,
   'masked_pos': 4}],
 [{'token': 9432,
   'token_str': 'thursday',
   'prob': 0.12900091707706451,
   'masked_pos': 4}],
 [{'token': 6928,
   'token_str': 'monday',
   'prob': 0.08406960219144821,
   'masked_pos': 4}],
 [{'token': 5958,
   'token_str': 'friday',
   'prob': 0.07963141053915024,
   'masked_pos': 4}],
 [{'token': 9317,
   'token_str': 'wednesday',
   'prob': 0.054221030324697495,
   'masked_pos': 4}],
 [{'token': 7058,
   'token_str': 'monthly',
   'prob': 0.045881446450948715,
   'masked_pos': 4}],
 [{'token': 4465,
   'token_str': 'sunday',
   'prob': 0.010150549933314323,
   'masked_pos': 4}],
 [{'token': 1012,
   'token_str': '.',
   'prob': 0.006899422034621239,
   'masked_pos': 4}],
 [{'token': 2800,
   'token_str': 'available',
   'prob': 0.006658394355326891,
   'masked_pos': 4}]]

In [12]:
# Hearst-template class-name probe for a pair of shift names; no context is passed.
get_class_name_hearst(['morning shift', 'night shift'], lm_probe) 

{'shifts': 0.40606245398521423,
 ',': 0.26602329313755035,
 'shift': 0.2167864441871643,
 'as such': 0.17379335917342864,
 'departments': 0.12587101757526398,
 '.': 0.10468659549951553,
 'such': 0.08616548031568527,
 'shift .': 0.08373006346583474,
 'shift ,': 0.06922138099221087,
 'thing': 0.06822600960731506,
 'different shifts': 0.05819503914694368,
 'days': 0.030948365107178688,
 'duties': 0.030666278675198555,
 'people': 0.019882088527083397,
 'all shift': 0.018312630556953785,
 'shifts shifts': 0.015264405753421417,
 'the shifts': 0.012394977859977452,
 'shifts .': 0.0087607046969298,
 'other shift': 0.005827367802047068,
 'night shift': 0.005223425829630202,
 'a ,': 0.00369603466209302,
 'shifts far': 0.002672503210559174,
 'departments shift': 0.0013202989668067572}

In [13]:
# Direct '... is a [MASK]' probe; topk is left at fill_multi_mask's default of 3.
lm_probe.fill_multi_mask('[CLS] morning shift is a [MASK] . [SEP]')

[[{'token': 10103,
   'token_str': 'nightmare',
   'prob': 0.22245627641677856,
   'masked_pos': 5}],
 [{'token': 9478,
   'token_str': 'breeze',
   'prob': 0.08262979239225388,
   'masked_pos': 5}],
 [{'token': 6752,
   'token_str': 'mess',
   'prob': 0.049640536308288574,
   'masked_pos': 5}]]

In [14]:
# Hearst-template class-name probe for two single-word entities; no context is passed.
get_class_name_hearst(['tattoos', 'piercings'], lm_probe) 

{'as such': 0.43733613785742875,
 ',': 0.42075057327747345,
 'things': 0.2776534240692854,
 'thing': 0.17522531747817993,
 'such': 0.17029350996017456,
 '.': 0.14481838420033455,
 'piercing .': 0.14278165690318012,
 'colors': 0.11828122287988663,
 'all ,': 0.08104194746758253,
 '##s': 0.07992054584125678,
 'tattoos': 0.057640042155981064,
 'certain ,': 0.041750938464703014,
 'hair': 0.0404670424759388,
 'hair .': 0.016858414259972143,
 ', far': 0.013993208231601051,
 'offensive ##s': 0.012718825611116868,
 'of .': 0.012042091326530668,
 'all ##s': 0.00942369014160055,
 'of hair': 0.008287945854710155,
 'have ##s': 0.004615252299980166,
 'your tattoos': 0.0025069542206117568,
 'of ##s': 0.0016842648931715831,
 'colored ##ities': 0.001684199194723776,
 'things well': 0.0006482934268544037}

In [15]:
# Hearst-template class-name probe for two two-word entities; no context is passed.
get_class_name_hearst(['urine test', 'swab test'], lm_probe) 

{'tests': 0.2956191450357437,
 'test': 0.2912965764602025,
 ',': 0.22519341111183167,
 'testing': 0.22195589169859886,
 'as such': 0.18646264151812808,
 'testing .': 0.1794008589777789,
 '.': 0.14398349821567535,
 'drug test': 0.11761508577432522,
 'things': 0.07416696846485138,
 'drug ,': 0.06666907243964815,
 'random test': 0.026148756885049806,
 'test test': 0.019849586816945153,
 'drug tests': 0.011564690442011555,
 'of tests': 0.011472677921928742,
 'test testing': 0.010680313054750099,
 'test ,': 0.008002044270066788,
 'tests test': 0.005254073502398537,
 'drug testing': 0.00470752540166286,
 'all .': 0.004546203769800888}

In [16]:
# Class-name probe with a free-text sentence passed as `context`, which the
# function appends to every query as a second [SEP]-terminated segment.
get_class_name_hearst(['dress code', 'vacation policy'], lm_probe, context='You get 20 vacation days and unlimited sick leaves .') 

{',': 0.4591578394174576,
 'as far': 0.3131515095225179,
 'thing': 0.2881258353590965,
 'benefits': 0.2380053705225388,
 'things': 0.2274885829538107,
 'policies': 0.09467612951993942,
 '.': 0.0833433698862791,
 'policy': 0.07025929540395737,
 'all ,': 0.060510486888192716,
 'rules': 0.05586520582437515,
 'such': 0.05194425210356712,
 'benefits benefits': 0.038104624996751824,
 'stuff': 0.025048289448022842,
 'the ,': 0.02315254734827299,
 'things such': 0.013054041127456006,
 'of .': 0.009818693481176366,
 'per things': 0.009229950335676862,
 'personal ##s': 0.009148165204145298,
 'company ##s': 0.005164277678788576,
 'important benefits': 0.005003145212758686,
 '. things': 0.004069842443769456,
 'the policies': 0.003685950651606547,
 'benefits .': 0.0036825491692631385,
 'work policy': 0.0028935247978831014,
 'per .': 0.0027636477912691693,
 'of rules': 0.002558088221109589,
 'thing well': 0.0014380955690048713}