# In [None]:
import random

import nltk
import numpy as np
import torch
from nltk.corpus import wordnet as wn
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

DEFINITION_TOKEN = "<DEF>"  # placeholder swapped for each candidate gloss in the prompt pattern

# Seed Python's global RNG so the random.sample() draws of padding contexts are reproducible.
random.seed(42)

def get_candidates(syn):
    """Build the multiple-choice option set for ``syn``.

    The options are ``syn`` itself followed by its co-hyponyms (every hyponym
    of every hypernym of ``syn``), in traversal order. A candidate is skipped
    when either of these duplicates an option already accepted:

    - its display name (the part of its synset name before the first ``.``,
      with underscores turned into spaces);
    - its gloss.

    Args:
        syn: a WordNet synset, i.e. anything exposing ``name()``,
            ``definition()``, ``hypernyms()`` and ``hyponyms()``.

    Returns:
        ``(option_words, definitions)``: two parallel lists.

        - ``option_words`` holds synset identifiers such as
          ``'boondoggle.n.01'``, not bare words despite the name.
        - ``definitions`` holds the matching glosses.

        The target synset is always at index 0.
    """
    def surface_form(synset):
        # 'block_vote.n.01' -> 'block vote'
        return synset.name().split('.')[0].replace("_", " ")

    target_gloss = syn.definition()
    accepted = [(syn.name(), target_gloss)]
    taken_forms = {surface_form(syn)}
    taken_glosses = {target_gloss}

    co_hyponyms = (child for parent in syn.hypernyms() for child in parent.hyponyms())
    for candidate in co_hyponyms:
        form = surface_form(candidate)
        gloss = candidate.definition()
        # Reject anything indistinguishable by word or by definition
        # (this also filters out the target when it reappears as a sibling).
        if form in taken_forms or gloss in taken_glosses:
            continue
        accepted.append((candidate.name(), gloss))
        taken_forms.add(form)
        taken_glosses.add(gloss)

    option_words = [ident for ident, _ in accepted]
    definitions = [gloss for _, gloss in accepted]
    return option_words, definitions


    
    
# ---------------------------------------------------------------------------
# Experiment: score how likely the LM is to continue
#   "<padding contexts> <Definition> is the definition of"
# with the target word, for the target's gloss and each co-hyponym's gloss.
# ---------------------------------------------------------------------------
model_name = 'gpt2-xl'
syn_name = 'boondoggle.n.01'
# Alternative target synsets:
# syn_name = 'clanger.n.01'
# syn_name = 'block_vote.n.01'
gpu_id = 0

batch_size = 1
# How many real usage contexts of the target word to prepend before the prompt.
padded_context_counts = [0, 1, 3]

# Prompt template; DEFINITION_TOKEN is replaced by each candidate gloss.
pattern = f'{DEFINITION_TOKEN} is the definition of'

# Surface form of the target lemma, e.g. 'block_vote.n.01' -> 'block vote'.
target_word = syn_name.split('.')[0].replace("_", " ")

CONTEXTS_PATH = "/mounts/work/kerem/datasets/WordNet/wordnet_words_max_count_100_contexts.txt"


def _read_word_contexts(path, word):
    """Return the contexts listed for ``word`` in a tab-separated contexts file.

    Each line holds a key followed by its contexts, all tab-separated. The key
    is compared with the space-joined NLTK word tokenization of ``word``.
    NOTE(review): this assumes the file keys were built the same way; confirm.

    Raises:
        LookupError: if no line's key matches.
    """
    key = ' '.join(nltk.word_tokenize(word))
    with open(path, encoding='utf-8') as context_file:
        for line in context_file:
            # Strip the newline so the last context on the line is clean.
            fields = line.rstrip('\n').split('\t')
            if fields[0] == key:
                return fields[1:]
    raise LookupError(f'No contexts found for {word!r} in {path}')


target_word_contexts = _read_word_contexts(CONTEXTS_PATH, target_word)

# AutoModel loads the bare transformer, which returns hidden states. Next-token
# scoring needs vocabulary logits, which come from the LM head.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token, so batched padding would raise. Reuse EOS instead:
# padded positions are masked out, so the id does not affect the scores.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# The scoring loop finds each row's last real token via the attention mask.
# That only works with right padding.
tokenizer.padding_side = 'right'
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

softmax = torch.nn.functional.softmax

syn = wn.synset(syn_name)

option_words, definitions = get_candidates(syn)
answer = option_words.index(syn.name())
option_count = len(option_words)

# Target tokens, encoded with the leading space that GPT-2 BPE uses for word-initial tokens.
target_ids = tokenizer.encode(" " + target_word, add_special_tokens=False)

for padded_context_count in padded_context_counts:
    contexts_to_pad = random.sample(target_word_contexts, padded_context_count)

    # One prompt per candidate: the sampled contexts, then the filled-in pattern.
    contexts = []
    for option_def in definitions:
        # Uppercase only the first character. str.capitalize() would also
        # lowercase the rest of the gloss (proper nouns, acronyms).
        defin = option_def[:1].upper() + option_def[1:]
        context = ' '.join(contexts_to_pad + [pattern.replace(DEFINITION_TOKEN, defin)])
        contexts.append(context)

    inputs = tokenizer(contexts, padding=True, return_tensors='pt')
    seq_len = inputs['input_ids'].shape[1]

    print(f'Prepadding {padded_context_count} contexts before definition')
    print(f'Input sequence length is {seq_len} tokens')

    # Teacher forcing writes each gold sub-token right after a row's last real
    # token. Reserve len(target_ids) - 1 extra masked columns so even the
    # longest row has room and no context token has to be dropped.
    extra = len(target_ids) - 1
    input_ids = torch.cat(
        [inputs['input_ids'],
         torch.full((option_count, extra), tokenizer.pad_token_id, dtype=inputs['input_ids'].dtype)],
        dim=1,
    )
    seq_mask = torch.cat(
        [inputs['attention_mask'],
         torch.zeros((option_count, extra), dtype=inputs['attention_mask'].dtype)],
        dim=1,
    )

    # P(target word | prompt) per option: a product of per-sub-token probabilities.
    prediction_probs = np.ones([option_count])
    for step, target_id in enumerate(target_ids):
        is_last_step = step == len(target_ids) - 1
        for start in range(0, option_count, batch_size):
            batch_ids = input_ids[start:start + batch_size].to(device)
            batch_mask = seq_mask[start:start + batch_size].to(device)

            with torch.no_grad():
                logits = model(input_ids=batch_ids, attention_mask=batch_mask).logits

            for offset in range(batch_ids.shape[0]):
                sample_no = start + offset
                # Index of this row's last real token (right padding).
                target_ind = int(seq_mask[sample_no].sum()) - 1
                prob = softmax(logits[offset, target_ind], dim=0)[target_id].item()
                prediction_probs[sample_no] *= prob

                if not is_last_step:
                    # Append the gold sub-token so the next step scores the one after it.
                    input_ids[sample_no, target_ind + 1] = target_id
                    seq_mask[sample_no, target_ind + 1] = 1

    predicted = int(np.argmax(prediction_probs))
    print(f'Predicted {option_words[predicted]} (gold: {option_words[answer]}, '
          f'P={prediction_probs[predicted]:.3e})')

