## Zero shot learning on roBERTa 

### Jamie Hunter - 40204692
#### For use in CSC4006 - Zero shot learning for Named Entity Recognition

In [18]:
import torch
import transformers
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaModel
import torch.nn.functional as F
from operator import itemgetter

tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
model = RobertaModel.from_pretrained('roberta-base')


def tokenize(txt):
    segments = txt.split('@')
    tokenized_segments = [tokenizer.tokenize(segment) for segment in segments]
    positions = [len(s) for s in tokenized_segments]
    positions = [positions[i]+positions[i-1] if i>0 else positions[i] for i in range(len(positions)) ]
    
    final_tokens = []
    for tokens in tokenized_segments:
        final_tokens.extend(tokens)
    return final_tokens, positions

def merge_embeddings(token_embeddings, start, end, operation = torch.mean):
    #print('token_embedding_to merge',token_embeddings[start:end].shape, token_embeddings[start:end][:,:5])
    merged =  operation(token_embeddings[start:end], dim = 0)
    #print('after merging',merged.shape, merged[:5])
    return merged

# merged = merge_embeddings(token_embeddings[start:end], input[1][0],input[1][1])
# print(merged.shape, merged[:5])
# return merged

def get_embeddings(txt, example = False):
    input = tokenize(txt)
    #print(input[0])
    input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(input[0])).unsqueeze(0)
    op = model(input_ids)
    if example:
        return input[0] , op[0].data.squeeze(), input[1][0], input[1][1]
    else:
        return input[0], op[0].data.squeeze()
    

def find_entities(example, test_text, k):
    """ Finds k-similar entites from the test_text, depending on the highlighted entity from the example
    
    Parameters
    ----------
    example: str
        A single sentence to compare embeddings, with a highlighted entity in @@'s
    test_text : str
        A single sentence to compare embeddings
    k: int
        The number of similar entities to produce
    """
    example_tokens, embeddings, start, end = get_embeddings(example, True)
    entity_embedding = merge_embeddings(embeddings, start, end)

    test_tokens , test_embeddings= get_embeddings(test_text)
    similarity = F.cosine_similarity(test_embeddings, entity_embedding , dim = -1)
    #print(similarity)
    
    #base case
    max_similarity_index = torch.argmax(similarity)
    
    #get top 3 most similar instead (1 or 2 if less than 3 available)
    if len(similarity.size()) != 0:
        if len(similarity) > k:
            result, max_similarity_index = torch.topk(similarity,k)
        else: 
            result, max_similarity_index = torch.topk(similarity,len(similarity))
            
    
    #print()
    #print('Example = ',example)
    #print('Test sentence = ', test_text)
    if len(similarity.size()) != 0:
        #print('Most similar entities  = to {} is {}'.format( example_tokens[start:end], itemgetter(*max_similarity_index)(test_tokens)))
        return itemgetter(*max_similarity_index)(test_tokens)
    else: 
        #print('Most similar entity  = to {} is {}'.format( example_tokens[start:end], test_tokens[max_similarity_index]))
        return test_tokens[max_similarity_index]
    
    
     

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Methods used for evaluation, retrieval of entities

In [47]:
def evaluate_script(examples, filename, label_to_evaluate, k, merge):
    """ Runs script that reads input file, finds a sentence example and its labels, and runs evaluate_model
    to complete full evaluation of the corpus.
    
    Parameters
    ----------
    examples : list
        list of input example sentences
    filename : str
        The filename and location of the corpus
    label_to_evaluate: str
        The label in which we would like to evaluate
    k: int
        The number of similar entities to produce
    merge: str
        The merging merthod we would like to use
    """
    
    #initialise text and value for retrieving label
    total_entities = 0
    found_entities = 0
    entities_selected = 0
    text = []
    labels = []
    sentence = ''
    label_loc = 4
    #read file
    with open(filename) as file:
        next(file)
        for line in file:
                #if blank line, process method and reset sentence
                if line.isspace() and text != []:
                    te, fe, es = evaluate_model(examples, text, labels, label_to_evaluate, k, merge)
                    total_entities += te
                    found_entities += fe
                    entities_selected += es
                    test_text = ''
                    text = []
                    labels = []
                #if not blank line, add to line, find label and assign it
                if not line.isspace():
                    word = line.split(' ')[0]
                    label = line.split(' ')[label_loc-1]
                    text.append(word)
                    labels.append(label)
    return total_entities, found_entities, entities_selected

                
def retrieve_person_entities(text, labels):
    """ Returns list of the person entities from the given text and labels
    
    Parameters
    ----------
    text : list
        The list of text tokens
    labels : list
        The list of labels
    """

    person_entities = []
    b_pos = 0
    for i, label in enumerate(labels):
        if label.strip() == "B-PER" or label.strip() == "I-PER":
            if label.strip() == "I-PER": 
                person_entities[b_pos] = person_entities[b_pos] + " " + text[i]
            else: 
                person_entities.append(text[i])
                b_pos = len(person_entities) - 1
    return person_entities

def retrieve_location_entities(text, labels):
    """ Returns list of the location entities from the given text and labels
    
    Parameters
    ----------
    text : list
        The list of text tokens
    labels : list
        The list of labels
    """
    location_entities = []
    b_pos = 0
    for i, label in enumerate(labels):
        if label.strip() == "B-LOC" or label.strip() == "I-LOC":
            if label.strip() == "I-LOC": 
                location_entities[b_pos] = location_entities[b_pos] + " " + text[i]
            else: 
                location_entities.append(text[i])
                b_pos = len(location_entities) - 1
    return location_entities

def retrieve_organisation_entities(text, labels):
    """ Returns list of the organisation entities from the given text and labels
    
    Parameters
    ----------
    text : list
        The list of text tokens
    labels : list
        The list of labels
    """
    organisation_entities = []
    b_pos = 0
    for i, label in enumerate(labels):
        if label.strip() == "B-ORG" or label.strip() == "I-ORG":
            if label.strip() == "I-ORG": 
                organisation_entities[b_pos] = organisation_entities[b_pos] + " " + text[i]
            else: 
                organisation_entities.append(text[i])
                b_pos = len(organisation_entities) - 1
    return organisation_entities

def retrieve_miscellaneous_entities(text, labels):
    """ Returns list of the miscellaneous entities from the given text and labels
    
    Parameters
    ----------
    text : list
        The list of text tokens
    labels : list
        The list of labels
    """
    miscellaneous_entities = []
    b_pos = 0
    for i, label in enumerate(labels):
        if label.strip() == "B-MISC" or label.strip() == "I-MISC":
            if label.strip() == "I-MISC": 
                miscellaneous_entities[b_pos] = miscellaneous_entities[b_pos] + " " + text[i]
            else: 
                miscellaneous_entities.append(text[i])
                b_pos = len(miscellaneous_entities) - 1
    return miscellaneous_entities

def get_average_number_entities(filename):
    """ Returns average number of entities based on number of examples and total entities in the corpus
    
    Parameters
    ----------
    filename : str
        The filename and location of the corpus
    """
    #initialise values to count all labels
    example_count = 0
    label_count = 0
    label_loc = 4
    
    #read file
    with open(filename) as file:
        next(file)
        for line in file:
            #if empty line, new example, add to example_count
            if line.isspace():
                example_count += 1
            #if not blank line, check for label, if not 'O' then we have entity
            if not line.isspace() and line.split(' ')[label_loc-1].strip() != 'O':
                label_count += 1
    return label_count / example_count

def merge_lists(entity_lists, merge):
    similar_entities = []
    entities_selected = 0
    
    # union of Value Lists
    if merge == "union":
        for entity in entity_lists:
            similar_entities = list(set.union(set(similar_entities), entity))
        entities_selected += len(similar_entities)
    # intersection of value lists
    elif merge == "intersection":
        similar_entities = list(set.intersection(*map(set, entity_lists)))
        entities_selected += len(similar_entities)
    
    return similar_entities, entities_selected



def evaluate_model(example_texts, test_text, test_labels, label_to_evaluate, k, merge):  
    """ Runs the model retrieving k-entities, chooses a merging method and returns 
     results for use in metric evaluation.
    
    Parameters
    ----------
    example_texts : list
        The list of example sentences to compare embeddings
    test_text : str
        A single sentence to compare embeddings
    test_labels: list
        A list of labels for the respective test_text
    label_to_evaluate: str
        The label in which we would like to evaluate
    k: int
        The number of similar entities to produce
    merge: str
        The merging merthod we would like to use
    """
        
    total_entities = 0
    found_entities = 0
    entities_selected = 0
    
    person_entities = []
    organisation_entities = []
    location_entities = []
    miscellaneous_entities = []
    
    #finding similar entities, intersection or union
    similar_entities = []
    entity_lists = []
    entities = []
    
    if label_to_evaluate == 'person':
        person_entities = retrieve_person_entities(test_text, test_labels)
    elif label_to_evaluate == 'organisation':
        organisation_entities = retrieve_organisation_entities(test_text, test_labels)
    elif label_to_evaluate == 'location':
        location_entities = retrieve_location_entities(test_text, test_labels)
    elif label_to_evaluate == 'miscellaneous':
        miscellaneous_entities = retrieve_miscellaneous_entities(test_text, test_labels)        

    text_labels = {
        "person": person_entities,
        "organisation": organisation_entities,
        "location": location_entities,
        "miscellaneous": miscellaneous_entities
    }
    
    full_sentence = ''
    for word in test_text:
        full_sentence = full_sentence + ' ' + word
    
    #works for k=1 entities
    if k == 1:
        for example_text in example_texts:
            entities.append(find_entities(example_text, full_sentence, k))
            entity_lists.append(list(entities))
            entities = []
    else: 
        #stored list of tuples for k>1
        for example_text in example_texts:
            entity_lists.append(list(find_entities(example_text, full_sentence, k)))
    
    joined_entities, num_selected = merge_lists(entity_lists, merge)
    
    similar_entities = joined_entities
    entities_selected += num_selected
    
    if similar_entities != None:
        for i, entity in enumerate(similar_entities):
            if entity[0] == 'Ġ':
                similar_entities[i] = entity[1:]
    
    for entity in text_labels[label_to_evaluate]:
        total_entities += 1
        entity_split = entity.split()
        for entity in entity_split:
            if entity in similar_entities:
                found_entities += 1
                break
                
    return total_entities, found_entities, entities_selected

## Main to run evaluation of model on all values of k and different merge types

In [6]:
#text examples with highlighted entity for model to find similar word embeddings to against unseen corpus

person_examples = ["State media quoted China's top negotiator with Taipei, @Tang Shubei@, as telling a visiting group from Taiwan on Wednesday that it was time for the rivals to hold political talks.",
                  "The president of the Nasdaq, @Alfred Berkeley@, was to hold a news conference Wednesday afternoon to elaborate on the new rules' effects on the market, the second largest in the world."]


location_examples = ["@Germany@'s representative to the European Union's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer.",
                    "@Canada@'s largest grain handling firm said Wednesday it expects to forge a partnership with hog farmers by 1997 with a view to expanding the company's scope into pork production."]

organisation_examples = ["Speaking only hours after Chinese state media said the time was right to engage in political talks with Taiwan, @Foreign Ministry@ spokesman Shen Guofang told Reuters: \"The necessary atmosphere for the opening of the talks has been disrupted by the Taiwan authorities\".",
                         "The Italian cabinet on Wednesday granted a reprieve for media mogul Silvio Berlusconi's @Mediaset@ television empire with a decree extending the current legal framework for television stations until Janurary 31, 1997."]

miscellaneous_examples = ["The pair last worked together when Scotland won the @Five Nations@ grand slam in 1990.", 
                          "\"If the (Tour's tournament) committee decides to change the rule I would not be against it, \"said Ballesteros, Olazabal's compatriot and @Ryder Cup@ captain."]

#path to unseen corpus
path = '../data/valid.txt'

#get average number of entities in corpus
average_entities = round(get_average_number_entities(path))

#set for k-number similarities
num_similar = [1,3,5,10,average_entities]
merges = ["intersection", "union"]

#for all merge types, run evaluation
for merge in merges:
    #for all values of k, run evaluation
    for k in num_similar:
    
        #evaluation for recall and precision
        entities_selected = 0
        entities_relevant = 0
        true_positives = 0
        num_entities_chosen = 0

        #evaLuation for person entities
        total_person_entities = 0
        found_person_entities = 0

        #evaluation for organisation entities
        total_organisation_entities = 0
        found_organisation_entities = 0

        #evaluation for location entities
        total_location_entities = 0
        found_location_entities = 0  
        
        #evaluation for miscellaneous entities
        total_miscellaneous_entities = 0
        found_miscellaneous_entities = 0
        
        #get returned results from entities found, selected and total - add to total results
        total_entities, found_entities, num_entities_chosen = evaluate_script(person_examples, path, 'person', k, merge)
        total_person_entities += total_entities
        found_person_entities += found_entities
        entities_selected += num_entities_chosen

        total_entities, found_entities, num_entities_chosen = evaluate_script(location_examples, path, 'location', k, merge)
        total_location_entities += total_entities
        found_location_entities += found_entities
        entities_selected += num_entities_chosen

        total_entities, found_entities, num_entities_chosen = evaluate_script(organisation_examples, path, 'organisation', k, merge)
        total_organisation_entities += total_entities
        found_organisation_entities += found_entities
        entities_selected += num_entities_chosen
        
        total_entities, found_entities, num_entities_chosen = evaluate_script(miscellaneous_examples, path, 'miscellaneous', k, merge)
        total_miscellaneous_entities += total_entities
        found_miscellaneous_entities += found_entities
        entities_selected += num_entities_chosen

        print("------------EVALUATION, K=" + str(k) + ", " + str(merge) + "------------")

        if found_person_entities == 0 and total_person_entities != 0:
            print("Person Accuracy: 0%")
        elif found_person_entities == 0 and total_person_entities == 0:
            print("No Person in Examples.")
        else: print("Person accuracy: " + str(found_person_entities/total_person_entities*100) + "%")

        if found_organisation_entities == 0 and total_organisation_entities != 0:
            print("Organisation Accuracy: 0%")
        elif found_organisation_entities == 0 and total_organisation_entities == 0:
            print("No Organisation in Examples.")
        else: print("Organisation accuracy: " + str(found_organisation_entities/total_organisation_entities*100) + "%")  

        if  found_location_entities == 0 and total_location_entities != 0:
            print("Location Accuracy: 0%")
        elif found_location_entities == 0 and total_location_entities == 0:
            print("No Location in Examples.")
        else: print("Location accuracy: " + str(found_location_entities/total_location_entities*100) + "%")
            
        if  found_miscellaneous_entities == 0 and total_miscellaneous_entities != 0:
            print("Miscellaneous Accuracy: 0%")
        elif found_miscellaneous_entities == 0 and total_miscellaneous_entities == 0:
            print("No Miscellaneous in Examples.")
        else: print("Miscellaneous accuracy: " + str(found_miscellaneous_entities/total_miscellaneous_entities*100) + "%")

        entities_relevant = total_location_entities + total_organisation_entities + total_person_entities + total_miscellaneous_entities
        true_positives = found_location_entities + found_organisation_entities + found_person_entities + found_miscellaneous_entities
        
        
        #accuracy and recall in this case are the same as we do not check for true negatives
        accuracy = true_positives / entities_relevant
        recall = true_positives / entities_relevant
        precision = true_positives / entities_selected
        f1 = 2 * (precision * recall) / (precision + recall)
        
        #metric evaluation
        print("Accuracy: " + str(accuracy))
        print("Recall: " + str(recall))
        print("Precision: " + str(precision))
        print("F1 score: " + str(f1))

        print("----------------------------------------------------")

------------EVALUATION, K=1, intersection------------
Person accuracy: 10.640608034744844%
Organisation accuracy: 6.343283582089552%
Location accuracy: 9.744148067501362%
Miscellaneous accuracy: 14.316702819956618%
Accuracy: 0.09964652415418279
Recall: 0.09964652415418279
Precision: 0.06341045415595545
F1 score: 0.07750212738103031
----------------------------------------------------
------------EVALUATION, K=3, intersection------------
Person accuracy: 27.035830618892508%
Organisation accuracy: 19.25373134328358%
Location accuracy: 15.351115949918345%
Miscellaneous accuracy: 33.40563991323211%
Accuracy: 0.22656118498569264
Recall: 0.22656118498569264
Precision: 0.04205855701028029
F1 score: 0.07094665823318574
----------------------------------------------------
------------EVALUATION, K=5, intersection------------
Person accuracy: 34.473398479913136%
Organisation accuracy: 28.059701492537314%
Location accuracy: 23.897659227000545%
Miscellaneous accuracy: 44.36008676789588%
Accuracy: 

## Pre-requisites for testing

In [44]:
import unittest
import pytest

## Definition and execution of tests

In [58]:
class TestRetrieval(unittest.TestCase):
    
    def test_retrieve_organisation_entities(self):
        
        text = ["John", "went", "to", "Starbucks", "in", "Belfast"]
        labels = ["B-PER", "O", "O", "B-ORG", "O", "B-LOC"]
        
        organisation_entities = retrieve_organisation_entities(text, labels)
        print(organisation_entities)
        
        assert organisation_entities == ["Starbucks"]
    
    def test_retrieve_person_entities(self):
        text = ["John", "went", "to", "Starbucks", "in", "Belfast"]
        labels = ["B-PER", "O", "O", "B-ORG", "O", "B-LOC"]
        
        person_entities = retrieve_person_entities(text, labels)
        print(person_entities)
        
        assert person_entities == ["John"]
    
    def test_retrieve_location_entities(self):
        text = ["John", "went", "to", "Starbucks", "in", "Belfast"]
        labels = ["B-PER", "O", "O", "B-ORG", "O", "B-LOC"]
        
        location_entities = retrieve_location_entities(text, labels)
        print(location_entities)
        
        assert location_entities == ["Belfast"]
        
    def test_retrieve_miscellaneous_entities(self):
        text = ["John", "went", "to", "test", "miscellaneous"]
        labels = ["B-PER", "O", "O", "B-ORG", "B-MISC"]
        
        miscellaneous_entities = retrieve_miscellaneous_entities(text, labels)
        print(miscellaneous_entities)
        
        assert miscellaneous_entities == ["miscellaneous"]

    def test_retrieve_grouped_entity(self):
        text = ["Joe", "Bloggs", "went", "to", "Starbucks", "in", "Belfast"]
        labels = ["B-PER", "I-PER", "O", "O", "B-ORG", "O", "B-LOC"]
        
        person_entities = retrieve_person_entities(text, labels)
        print(person_entities)
        
        assert person_entities == ["Joe Bloggs"]

class TestFindEntities(unittest.TestCase):
    
    def test_find_entities(self):
        example_text = "Sentence with @Entity@ tagged"
        test_text = "Another sentence to find entity from"
        entities = find_entities(example_text, test_text, 3)
        
        print(len(entities))
        
        assert len(entities) == 3
        
    def test_max_found(self):
        
        #make sure more entities than possible are not selected
        #should be 6 as only 6 entities to choose instead of 10
        
        example_text = "Sentence with @Entity@ tagged"
        test_text = "Another sentence to find entity from" # 6 tokens
        entities = find_entities(example_text, test_text, 10) # find k=10 entities
        
        assert len(entities) == len(test_text.split())
        
class TestMerge(unittest.TestCase):
    
    def test_union(self):
        entity_lists = [["John"], ["Nathan"]]
        similar_entities, entities_selected = merge_lists(entity_lists, 'union')
        
        print(similar_entities)
        
        assert similar_entities == ["John", "Nathan"]
        assert entities_selected == 2
    
    def test_intersection(self):
        entity_lists = [["John"], ["John", "Nathan"]]
        similar_entities, entities_selected = merge_lists(entity_lists, 'intersection')
        
        print(similar_entities)
        
        assert similar_entities == ["John"]
        assert entities_selected == 1
        
        
class TestCalculations(unittest.TestCase):
    
    # sample data includes 2 sentences with 9 entities
    
    def test_get_average_number_entities(self):
        test_file = '../data/sample_data.txt'
        avg = get_average_number_entities(test_file)
        # 14 entities, 2 examples
        print(avg)
        
        assert avg == 4.5
        
    def test_total_entities(self):
        person_examples = ["State media quoted China's top negotiator with Taipei, @Tang Shubei@, as telling a visiting group from Taiwan on Wednesday that it was time for the rivals to hold political talks.",
                  "The president of the Nasdaq, @Alfred Berkeley@, was to hold a news conference Wednesday afternoon to elaborate on the new rules' effects on the market, the second largest in the world."]
        
        test_file = '../data/sample_data.txt'
        te, ef, es = evaluate_script(person_examples, test_file, 'person', 1, 'union')
        
        print(te)
        total_entities = te
        
        assert total_entities == 2
        
    def test_entities_found(self):
        person_examples = ["State media quoted China's top negotiator with Taipei, @Tang Shubei@, as telling a visiting group from Taiwan on Wednesday that it was time for the rivals to hold political talks.",
                  "The president of the Nasdaq, @Alfred Berkeley@, was to hold a news conference Wednesday afternoon to elaborate on the new rules' effects on the market, the second largest in the world."]
        
        test_file = '../data/sample_data.txt'
        te, ef, es = evaluate_script(person_examples, test_file, 'person', 1, 'union')
        
        print(ef)
        entities_found = ef
        
        ##entities_found can be confirmed by using the find_entities method
        assert entities_found == 2
        
    def test_entities_selected(self):
        person_examples = ["State media quoted China's top negotiator with Taipei, @Tang Shubei@, as telling a visiting group from Taiwan on Wednesday that it was time for the rivals to hold political talks.",
                  "The president of the Nasdaq, @Alfred Berkeley@, was to hold a news conference Wednesday afternoon to elaborate on the new rules' effects on the market, the second largest in the world."]
        
        test_file = '../data/sample_data.txt'
        te, ef, es = evaluate_script(person_examples, test_file, 'person', 1, 'union')
        
        print(es)
        entities_selected = es
        
        assert entities_selected == 2

unittest.main(argv=[''], verbosity=2, exit=False)

test_entities_found (__main__.TestCalculations) ... 

2


ok
test_entities_selected (__main__.TestCalculations) ... 

2


ok
test_get_average_number_entities (__main__.TestCalculations) ... 

4.5


ok
test_total_entities (__main__.TestCalculations) ... 

2


ok
test_find_entities (__main__.TestFindEntities) ... 

3


ok
test_max_found (__main__.TestFindEntities) ... 

6


ok
test_intersection (__main__.TestMerge) ... 

['John']


ok
test_union (__main__.TestMerge) ... 

['John', 'Nathan']


ok
test_retrieve_grouped_entity (__main__.TestRetrieval) ... 

['Joe Bloggs']


ok
test_retrieve_location_entities (__main__.TestRetrieval) ... 

['Belfast']


ok
test_retrieve_miscellaneous_entities (__main__.TestRetrieval) ... 

['miscellaneous']


ok
test_retrieve_organisation_entities (__main__.TestRetrieval) ... 

['Starbucks']


ok
test_retrieve_person_entities (__main__.TestRetrieval) ... 

['John']


ok

----------------------------------------------------------------------
Ran 13 tests in 2.798s

OK


<unittest.main.TestProgram at 0x1ca74bc2310>