In [13]:
import pandas as pd

# Location of the tab-separated sense/gloss table used when the caller does not
# pass ``sense_gloss_path``.
_DEFAULT_SENSE_GLOSS_PATH = "/home/gerard/ownCloud/varis_tesi/GlossBERT_datasets/wordnet/index.sense.gloss"

# Parsed sense indexes keyed by file path, so the table is read and indexed only
# once per path instead of on every call.
_SENSE_INDEX_CACHE = {}


def _load_sense_index(sense_gloss_path):
    """Read the sense/gloss table and group its rows by ``<lemma>%<digit>``.

    :param sense_gloss_path: path of a tab-separated, header-less file whose first
        column is the sense key and whose last column is the gloss
    :return: dict mapping the sense-key prefix up to and including the character
        after ``%`` to a list of ``(sense_key, gloss)`` tuples in file order
    """
    cached = _SENSE_INDEX_CACHE.get(sense_gloss_path)
    if cached is not None:
        return cached

    try:
        # ``on_bad_lines`` replaces ``error_bad_lines``, which newer pandas
        # releases no longer accept. "warn" keeps the old behaviour of skipping
        # malformed lines while still reporting them.
        table = pd.read_csv(sense_gloss_path, sep="\t", header=None, on_bad_lines="warn")
    except TypeError:
        # Older pandas that does not know ``on_bad_lines`` yet.
        table = pd.read_csv(sense_gloss_path, sep="\t", header=None, error_bad_lines=False)

    index = {}
    for row in table.values:
        sense_key = row[0]
        # A missing first field is parsed as NaN (a float); such rows carry no
        # usable sense key.
        if not isinstance(sense_key, str):
            continue
        pos = sense_key.find("%")
        # Keys without "%" can never equal a ``lemma + "%<digit>"`` query.
        if pos < 0:
            continue
        index.setdefault(sense_key[:pos + 2], []).append((sense_key, row[-1]))

    _SENSE_INDEX_CACHE[sense_gloss_path] = index
    return index


def construct_context_gloss_pairs(input, target_start_id, target_end_id, lemma,
                                  sense_gloss_path=_DEFAULT_SENSE_GLOSS_PATH):
    """
    construct context gloss pairs like sent_cls_ws

    The sentence is split on single spaces, the target span is wrapped in
    stand-alone ``"`` tokens, and one candidate is produced for every sense of
    ``lemma`` found under the categories ``%1`` .. ``%5`` in the sense table.

    :param input: str, a sentence (the name shadows the builtin but is kept so
        existing keyword calls continue to work)
    :param target_start_id: int, index of the first target token (inclusive)
    :param target_end_id: int, index one past the last target token (exclusive)
    :param lemma: lemma of the target word
    :param sense_gloss_path: path of the tab-separated sense/gloss file; defaults
        to the location that was previously hard-coded
    :return: candidate lists; each entry is the tuple
        ``(marked_sentence, "<target> : <gloss>", target, lemma, sense_key, gloss)``
    :raises AssertionError: if the span is empty or out of range, or if no sense
        of ``lemma`` is found
    """
    sent = input.split(" ")
    assert 0 <= target_start_id < target_end_id <= len(sent), \
        f'invalid target span [{target_start_id}, {target_end_id}) for a sentence of {len(sent)} tokens'

    target_tokens = sent[target_start_id:target_end_id]
    target = " ".join(target_tokens)
    # Slicing past the end yields [], so one expression covers both a target in
    # the middle of the sentence and a target at its very end.
    sent = " ".join(sent[:target_start_id] + ['"'] + target_tokens + ['"'] + sent[target_end_id:])

    sense_index = _load_sense_index(sense_gloss_path)

    candidate = []
    for category in ("%1", "%2", "%3", "%4", "%5"):
        for sense_key, gloss in sense_index.get(lemma + category, ()):
            candidate.append((sent, f"{target} : {gloss}", target, lemma, sense_key, gloss))

    assert len(candidate) != 0, f'there is no candidate sense of "{lemma}" in WordNet, please check'
    print(f'there are {len(candidate)} candidate senses of "{lemma}"')

    return candidate

In [30]:
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

In [35]:
class InputFeatures:
    """Container for the three parallel sequences describing one example.

    The constructor stores its arguments unchanged (no copying, no validation)
    as the attributes ``input_ids``, ``input_mask`` and ``segment_ids``.
    """

    def __init__(self, input_ids, input_mask, segment_ids):
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids,
            input_mask,
            segment_ids,
        )

In [25]:
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
import torch

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two output classes, labelled "0" and "1".
label_list = ["0", "1"]
num_labels = len(label_list)

# The tokenizer and the sequence classifier are loaded from the same checkpoint
# name.
# NOTE(review): the load log reports checkpoint weights that were not used when
# initializing BertForSequenceClassification -- confirm whether a checkpoint
# fine-tuned for this classification task was intended here.
tokenizer = BertTokenizer.from_pretrained(
    "google/bert_uncased_L-12_H-768_A-12",
    do_lower_case=True,
)
model = BertForSequenceClassification.from_pretrained(
    "google/bert_uncased_L-12_H-768_A-12",
    num_labels=num_labels,
)
model.to(device)
print('Done')


Some weights of the model checkpoint at google/bert_uncased_L-12_H-768_A-12 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification

Done


In [27]:
# Example sentence; the target is the whitespace-token span [3, 4), i.e. "plan".
# Note: the name `input` shadows the builtin for the rest of the session.
input = "U.N. group drafts plan to reduce emissions"
target_start_id, target_end_id = 3, 4
lemma = "plan"

candidate = construct_context_gloss_pairs(
    input=input,
    target_start_id=target_start_id,
    target_end_id=target_end_id,
    lemma=lemma,
)

there are 7 candidate senses of "plan"


In [36]:
max_seq_length = 512

candidate_results = []
features = []
for item in candidate:
    # item[0] is the sentence, item[1] the gloss text; the last two fields are
    # (sense_key, gloss).
    text_a, text_b = item[0], item[1]
    candidate_results.append(tuple(item[-2:]))

    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = tokenizer.tokenize(text_b)
    # Reserve three positions for [CLS] and the two [SEP] markers.
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    # Layout: [CLS] sentence [SEP] gloss [SEP]; segment 0 covers everything up
    # to and including the first [SEP], segment 1 the rest.
    tokens = ["[CLS]", *tokens_a, "[SEP]", *tokens_b, "[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # 1 marks a real token, 0 marks padding; only real tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Right-pad all three sequences with zeros up to max_seq_length.
    padding = [0] * (max_seq_length - len(input_ids))
    input_ids.extend(padding)
    input_mask.extend(padding)
    segment_ids.extend(padding)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    features.append(
        InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
        )
    )

In [None]:
# Display the zero-padded token-id list (max_seq_length entries) built for the
# first candidate.
features[0].input_ids