# Testing FewRel few-shot relation extraction

* goal is to make it like the demo at http://opennre.thunlp.ai/#/fewshot_re - in the demo it seems to work pretty well
* based on https://github.com/thunlp/FewRel
* use torch 1.3.1
* copy val_wiki.json as test_wiki.json in the data folder to make it work
* python requirements same as opennre
* the checkpoint files are very big, so they aren't in the repository. One checkpoint is at https://drive.google.com/file/d/1yiz3q3xNz-llsY55g5OdodxiH1RThYuz/view?usp=sharing
* can't train on Prof. Song's GPUs: not enough VRAM. The HPC computers have enough VRAM, but I couldn't get PyTorch to work on them because they run a roughly 10-year-old version of Linux.
* however, testing with this code does actually work on the GPU: you need a CUDA build of PyTorch (refer to pytorch.org), and wherever the model or an input tensor is used you have to move it to the GPU with "model = model.cuda()" and "tensor = tensor.cuda()". Refer to test_script.py
* on my laptop cpu this code takes about 2 seconds per query, on the gpu it's a lot faster, runs in under 1 second. the speed seems fine, but more testing is required to really find out. 
* i noticed that after a lot of training the model seems to overfit, so it doesn't work properly on the testing data. therefore we should probably make our own benchmark in order to compare different models to each other
* the maximum length is in terms of the number of bert tokens, not in terms of the number of characters in the sentence. increasing it seems to work fine


In [None]:
# Path of the trained pair-BERT checkpoint (the file is large and not stored in the repository).
checkpoint_path = 'checkpoint/pair-bert-train_wiki-val_wiki-5-1.pth.tar'
# Name of the pretrained BERT weights passed to the sentence encoder.
bert_pretrained_checkpoint = "bert-base-uncased"
# Sequence limit, counted in BERT tokens rather than characters.
max_length = 128

In [None]:
from fewshot_re_kit.data_loader import FewRelDatasetPair, get_loader_pair
from fewshot_re_kit.framework import FewShotREFramework
from fewshot_re_kit.sentence_encoder import BERTPAIRSentenceEncoder
from models.pair import Pair
import os
import torch

from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
import spacy
import neuralcoref


In [None]:
# Build the BERT pair encoder from the pretrained weights name and token limit set earlier.
sentence_encoder = BERTPAIRSentenceEncoder(bert_pretrained_checkpoint, max_length)

In [None]:
# Alternative kept for reference: the batched loader from the FewRel kit.
# meow_loader = get_loader_pair('val_wiki', sentence_encoder,
#                 N=5, K=1, Q=1, na_rate=0, batch_size=1, encoder_name='bert')

# Iterator over the 'val_wiki' pair dataset (N=5, K=1, Q=1, no NA samples);
# the evaluation cell unpacks each next() result as (batch, label).
val_data_loader = iter(
    FewRelDatasetPair(
        'val_wiki',
        sentence_encoder,
        N=5,
        K=1,
        Q=1,
        na_rate=0,
        root='./data',
        encoder_name='bert',
    )
)


In [None]:
# Pair model on top of the sentence encoder; it lives on the GPU whenever CUDA is available.
model = Pair(sentence_encoder, hidden_size=768)
model = model.cuda() if torch.cuda.is_available() else model

In [None]:
# Peek at the type of the 'word' entry of the next sample.
# NOTE: next() consumes that sample from val_data_loader.
type(next(val_data_loader)[0]['word'])

In [None]:
def __load_model__(ckpt):
    '''
    Load a checkpoint dictionary from disk.

    ckpt: Path of the checkpoint
    return: Checkpoint dict
    raises: FileNotFoundError (a subclass of Exception) if ckpt is not an existing file
    '''
    if not os.path.isfile(ckpt):
        raise FileNotFoundError("No checkpoint found at '%s'" % ckpt)
    # A checkpoint written during a GPU run stores CUDA tensors; torch.load
    # cannot restore those on a CPU-only machine unless they are remapped.
    # When CUDA is available the tensors keep their saved device, as before.
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(ckpt, map_location=map_location)
    print("Successfully loaded checkpoint '%s'" % ckpt)
    return checkpoint

        
def item(x):
    '''
    Extract the Python scalar held by x.

    PyTorch releases before 0.4 have no .item(), so the first element is
    indexed there; every later release uses x.item().
    '''
    version_parts = torch.__version__.split('.')
    # The minor number is only parsed when the major number is 0.
    is_pre_0_4 = int(version_parts[0]) == 0 and int(version_parts[1]) < 4
    return x[0] if is_pre_0_4 else x.item()
    
def bert_tokenize(tokens, head_indices, tail_indices):
    '''
    Encode a tokenized sentence with the notebook-level sentence_encoder.

    tokens: list of word tokens of the sentence
    head_indices: positions in tokens that belong to the head entity
    tail_indices: positions in tokens that belong to the tail entity
    return: the result of sentence_encoder.tokenize (the cells below
            concatenate it with lists of token ids)
    '''
    return sentence_encoder.tokenize(tokens, head_indices, tail_indices)

In [None]:
# Load the small English pipeline and build a stand-alone tokenizer from its
# defaults (punctuation rules and tokenizer exceptions included).
# NOTE(review): Defaults.create_tokenizer looks like spaCy v2-only API - confirm before upgrading spaCy.
nlp = spacy.load("en_core_web_sm")
tokenizer = nlp.Defaults.create_tokenizer(nlp)
# Quick look at how a sample sentence gets split.
[str(token) for token in tokenizer("""hello meow. meow is donald trump's friend""")]

In [None]:
# Loading the weights from the model checkpoint state.

model.eval()  # switch the model to evaluation mode before inference
state_dict = __load_model__(checkpoint_path)['state_dict']
own_state = model.state_dict()
skipped_keys = []
for name, param in state_dict.items():
    if name not in own_state:
        # Entries unknown to the model are still ignored, but they are recorded:
        # silently skipping them would hide a key-naming mismatch that leaves
        # the model running with its initial weights.
        skipped_keys.append(name)
        continue
    own_state[name].copy_(param)
if skipped_keys:
    print("Warning: %d of %d checkpoint entries have no matching model parameter and were ignored, e.g. %s"
          % (len(skipped_keys), len(state_dict), skipped_keys[:5]))

In [None]:
# Evaluating on the wikidata dataset, which is what they have already implemented.

N = 5
K = 1
Q = 1
na_rate = 0
with torch.no_grad():
    for it in range(10):
        batch, label = next(val_data_loader)
        label = torch.tensor(label)
        # Each entry is a list of per-pair tensors; stack it into one tensor per field.
        for key in ('word', 'seg', 'mask'):
            batch[key] = torch.stack(batch[key])
        if torch.cuda.is_available():
            # The model was moved to the GPU when it was created, so its inputs
            # and the labels it is compared against have to follow.
            label = label.cuda()
            for key in ('word', 'seg', 'mask'):
                batch[key] = batch[key].cuda()
        logits, pred = model(batch, N, K, Q * N + Q * na_rate)
        print(pred, label)
        right = model.accuracy(pred, label)
        print(item(right.data))

In [None]:
# Episode shape for the hand-written examples: N relations with K examples each.
N = 5
K = 2
Q = 1
na_rate = 0

# Entity names have to be capitalized (upper-case initial), otherwise they are not detected by NER.
example_relation_data = [
    {
        "name": "love",
        "examples": [
            {"sentence": "Meow loves Mo", "head": "Meow", "tail": "Mo"},
            {"sentence": "Tom is in love with Jull", "head": "Tom", "tail": "Jull"},
        ],
    },
    {
        "name": "hate",
        "examples": [
            {"sentence": "Trump hates the Mooch", "head": "Trump", "tail": "Mooch"},
            {"sentence": "Ivanka and Jared dislike each other intensely", "head": "Ivanka", "tail": "Jared"},
        ],
    },
    {
        "name": "spouse",
        "examples": [
            {"sentence": "Trump is married to Ivanka", "head": "Trump", "tail": "Ivanka"},
            {"sentence": "Bill went out with his wife Jill on saturday", "head": "Bill", "tail": "Jill"},
        ],
    },
    {
        "name": "insult",
        "examples": [
            {"sentence": "The President said that Michael Cohen is a rat", "head": "The President", "tail": "Michael"},
            {"sentence": "Meow and Tom threw jabs at each other", "head": "Meow", "tail": "Tom"},
        ],
    },
    {
        "name": "capital",
        "examples": [
            {"sentence": "Austin is the capital of Texas", "head": "Austin", "tail": "Texas"},
            {"sentence": "the capital of China is located in Beijing", "head": "Beijing", "tail": "China"},
        ],
    },
]

# Sentences to classify, each with its head and tail entity spelled out.
queries = [
    {
        "sentence": "Cohen and Fluffy are very loving to each other",
        "head": "Cohen",
        "tail": "Fluffy",
    },
    {
        "sentence": "The US's capital is Washington",
        "head": "Washington",
        "tail": "US",
    },
]

## using already specified entities in the query sentence

In [None]:
nlp = spacy.load("en_core_web_sm")

def spacy_tokenize(sentence):
    '''
    sentence: raw text
    return: list with the string form of every spaCy token, in order
    '''
    return [str(token) for token in nlp(sentence)]

max_length = 128   # max length in terms of the number of BERT tokens (not characters)

# The ids of the special tokens never change, so look them up once instead of once per example.
SEP = sentence_encoder.tokenizer.convert_tokens_to_ids(['[SEP]'])
CLS = sentence_encoder.tokenizer.convert_tokens_to_ids(['[CLS]'])

for q in queries:
    fusion_set = {'word': [], 'mask': [], 'seg': []}
    # Tokenize with spaCy (instead of a plain split on spaces) so punctuation is separated.
    query_tokens = spacy_tokenize(q['sentence'])
    tokenized_head = spacy_tokenize(q['head'])
    tokenized_tail = spacy_tokenize(q['tail'])
    # An entity span starts at the first occurrence of the entity's first token
    # and covers as many tokens as the entity itself has (multi-word entities included).
    head_start = query_tokens.index(tokenized_head[0])
    tail_start = query_tokens.index(tokenized_tail[0])
    head_indices = list(range(head_start, head_start + len(tokenized_head)))
    tail_indices = list(range(tail_start, tail_start + len(tokenized_tail)))
    bert_query_tokens = bert_tokenize(query_tokens, head_indices, tail_indices)

    # Pair the query with every example of every relation, in relation order.
    for relation in example_relation_data:
        for ex in relation['examples']:
            support_tokens = spacy_tokenize(ex['sentence'])
            tokenized_head = spacy_tokenize(ex['head'])
            tokenized_tail = spacy_tokenize(ex['tail'])
            head_start = support_tokens.index(tokenized_head[0])
            tail_start = support_tokens.index(tokenized_tail[0])
            head_indices = list(range(head_start, head_start + len(tokenized_head)))
            tail_indices = list(range(tail_start, tail_start + len(tokenized_tail)))
            bert_relation_example_tokens = bert_tokenize(support_tokens, head_indices, tail_indices)

            # Layout: [CLS] example [SEP] query [SEP], cut off / zero-padded to max_length.
            new_word = CLS + bert_relation_example_tokens + SEP + bert_query_tokens + SEP
            used = min(max_length, len(new_word))
            word_tensor = torch.zeros((max_length)).long()
            word_tensor[:used] = torch.tensor(new_word[:used], dtype=torch.long)
            # mask is 1 on real tokens and 0 on padding.
            mask_tensor = torch.zeros((max_length)).long()
            mask_tensor[:used] = 1
            # seg is 0 for [CLS] + example and 1 for everything after it.
            seg_tensor = torch.ones((max_length)).long()
            seg_tensor[:min(max_length, len(bert_relation_example_tokens) + 1)] = 0
            fusion_set['word'].append(word_tensor)
            fusion_set['mask'].append(mask_tensor)
            fusion_set['seg'].append(seg_tensor)

    fusion_set['word'] = torch.stack(fusion_set['word'])
    fusion_set['seg'] = torch.stack(fusion_set['seg'])
    fusion_set['mask'] = torch.stack(fusion_set['mask'])

    if torch.cuda.is_available():
        # The model is moved to the GPU when CUDA is available, so its inputs must be as well.
        fusion_set['word'] = fusion_set['word'].cuda()
        fusion_set['seg'] = fusion_set['seg'].cuda()
        fusion_set['mask'] = fusion_set['mask'].cuda()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits, pred = model(fusion_set, N, K, 1)
    print(pred, logits)
    
            
            

In [None]:
nlp = spacy.load("en_core_web_sm")
# Create a Tokenizer with the default settings for English
# including punctuation rules and exceptions
tokenizer = nlp.Defaults.create_tokenizer(nlp)
# Register neuralcoref on the pipeline; the code below reads `doc._.coref_resolved` from its docs.
neuralcoref.add_to_pipe(nlp)

# ex = "hello Meow. Meow is Donald Trump's friend"
ex = """The US's capital is Washington"""
doc = nlp(ex)
# Show the coreference-resolved text, then parse that resolved text again;
# the next cells inspect this second `doc`.
print(doc._.coref_resolved)
doc = nlp(doc._.coref_resolved)

In [None]:
# Named entities found in `doc`, as (text, label) pairs.
print([(ent.text, ent.label_) for ent in doc.ents])

In [None]:
# String form of every token in `doc`.
[str(token) for token in doc]

## using entities detected in the query sentence

In [None]:
from itertools import combinations
# Fresh English pipeline with neuralcoref registered on it.
nlp = spacy.load("en_core_web_sm")
neuralcoref.add_to_pipe(nlp)

def spacy_tokenize(sentence):
    '''
    Resolve coreferences in the sentence, then tokenize the resolved text.

    sentence: raw text
    return: list with the string form of every token of the resolved text
    '''
    resolved_text = nlp(sentence)._.coref_resolved
    return [str(token) for token in nlp(resolved_text)]

def get_head_tail_pairs(sentence):
    '''
    Find every pair of distinct named entities in the coreference-resolved sentence.

    sentence: raw text
    return: iterator of (head, tail) entity-text pairs

    Entity texts are de-duplicated in order of first appearance. The earlier
    set-based version made the pair order, and therefore which entity acts
    as head, depend on string hashing and change between runs.
    '''
    acceptable_entity_types = {'PERSON', 'NORP', 'ORG', 'GPE', 'PRODUCT', 'EVENT', 'LAW', 'LOC', 'FAC'}
    doc = nlp(sentence)
    doc = nlp(doc._.coref_resolved)

    entity_names = []
    seen = set()
    for ent in doc.ents:
        if ent.label_ in acceptable_entity_types and ent.text not in seen:
            seen.add(ent.text)
            entity_names.append(ent.text)

    return combinations(entity_names, 2)
    

max_length = 128   # max length in terms of the number of BERT tokens (not characters) - by default it was 128, seems to work with longer lengths also though.
# the actual length of the sentence doesn't matter, only the number of bert tokens which are created. and this is managed automatically.

# The ids of the special tokens never change, so look them up once instead of once per example.
SEP = sentence_encoder.tokenizer.convert_tokens_to_ids(['[SEP]'])
CLS = sentence_encoder.tokenizer.convert_tokens_to_ids(['[CLS]'])

for q in queries:
    # Kept under its own name: the support sentences below must not overwrite
    # the query tokens, or later entity pairs would be located in the wrong sentence.
    query_tokens = spacy_tokenize(q['sentence'])

    for head, tail in get_head_tail_pairs(q['sentence']):  # iterating through all possible combinations of 2 named entities
        # A fresh set per entity pair: it is turned into tensors further down,
        # so reusing it for the next pair would fail on .append().
        fusion_set = {'word': [], 'mask': [], 'seg': []}
        tokenized_head = spacy_tokenize(head)
        tokenized_tail = spacy_tokenize(tail)
        # An entity span starts at the first occurrence of the entity's first token
        # and covers as many tokens as the entity itself has.
        head_start = query_tokens.index(tokenized_head[0])
        tail_start = query_tokens.index(tokenized_tail[0])
        head_indices = list(range(head_start, head_start + len(tokenized_head)))
        tail_indices = list(range(tail_start, tail_start + len(tokenized_tail)))
        bert_query_tokens = bert_tokenize(query_tokens, head_indices, tail_indices)

        for relation in example_relation_data:
            for ex in relation['examples']:
                support_tokens = spacy_tokenize(ex['sentence'])
                # head and tail spelling and punctuation should match the corefered output exactly
                tokenized_head = spacy_tokenize(ex['head'])
                tokenized_tail = spacy_tokenize(ex['tail'])
                head_start = support_tokens.index(tokenized_head[0])
                tail_start = support_tokens.index(tokenized_tail[0])
                head_indices = list(range(head_start, head_start + len(tokenized_head)))
                tail_indices = list(range(tail_start, tail_start + len(tokenized_tail)))
                bert_relation_example_tokens = bert_tokenize(support_tokens, head_indices, tail_indices)

                # Layout: [CLS] example [SEP] query [SEP], cut off / zero-padded to max_length.
                new_word = CLS + bert_relation_example_tokens + SEP + bert_query_tokens + SEP
                used = min(max_length, len(new_word))
                word_tensor = torch.zeros((max_length)).long()
                word_tensor[:used] = torch.tensor(new_word[:used], dtype=torch.long)
                # mask is 1 on real tokens and 0 on padding.
                mask_tensor = torch.zeros((max_length)).long()
                mask_tensor[:used] = 1
                # seg is 0 for [CLS] + example and 1 for everything after it.
                seg_tensor = torch.ones((max_length)).long()
                seg_tensor[:min(max_length, len(bert_relation_example_tokens) + 1)] = 0
                fusion_set['word'].append(word_tensor)
                fusion_set['mask'].append(mask_tensor)
                fusion_set['seg'].append(seg_tensor)

        fusion_set['word'] = torch.stack(fusion_set['word'])
        fusion_set['seg'] = torch.stack(fusion_set['seg'])
        fusion_set['mask'] = torch.stack(fusion_set['mask'])

        if torch.cuda.is_available():
            fusion_set['word'] = fusion_set['word'].cuda()
            fusion_set['seg'] = fusion_set['seg'].cuda()
            fusion_set['mask'] = fusion_set['mask'].cuda()

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits, pred = model(fusion_set, N, K, 1)

        # An index past the last relation is the NA case, which used to raise an IndexError here.
        pred_index = pred.item()
        if pred_index < len(example_relation_data):
            prediction = example_relation_data[pred_index]['name']
        else:
            prediction = 'none of the above'
        print('Sentence: \"{}\", head: \"{}\", tail: \"{}\", prediction: {}'.format(q['sentence'], head, tail, prediction))
    
    
            
            

In [None]:
from fyp_detection_functions import Detector

In [None]:
# Create the Detector imported from fyp_detection_functions and invoke its
# sample-data run (what that run does is defined in that module, not here).
d = Detector()
d.run_on_sample_data()