In [12]:
import os
import json
import time

import numpy as np
import pandas as pd
import torch

### KnowledgeGraph framework

In [2]:
from abc import ABC, abstractmethod

class KnowledgeGraph(ABC):
    def __init__(self, name=None, verbose = 0):
        self._name = name
        self._verbose = verbose
        
        self._entities = [] # list(string)
        self._relations = [] # list(string)
        # np.array([(head_entity, relation, tail_entity)])
        self._triples = np.zeros(shape=(0,3))
        
        self._built = False
        
    ####### PUBLIC #######
    @property
    def name(self):
        return self._name
    
    @property
    def entities(self):
        return self._entities
    
    @property
    def relations(self):
        return self._relations
    
    @property
    def triples(self):
        return self._triples
    
    def sample(self, k=1, negative=False):
        if negative:
            return self._sample_negative_loose(k)
        else:
            return self._sample_positive(k)
        
    def sample_english(self, k=1, negative=False):
        samples = self.sample(k, negative)
        
        english_samples = []
        for sample in samples:
            head_idx, relation_idx, tail_idx = sample
            head_id, relation_id, tail_id = self._entities[head_idx], self._relations[relation_idx], self._entities[tail_idx]
            head, relation, tail = self._id2entity(head_id), self._id2relation(relation_id), self._id2entity(tail_id)
            english_samples.append(relation.replace("{HEAD}", head).replace("{TAIL}", tail))
            
        return english_samples
            
        
    ####### PRIVATE #######
    
    @abstractmethod
    def _id2entity(self, eid):
        """
        A function that maps an entity id (eid) stored in the
        self._entities structure to an english identifier
        and/or description.
        """
        
    @abstractmethod
    def _id2relation(self, rid):
        """
        A function that maps an relation id (rid) stored in
        the self._relations structure to an english identifier
        and/or description.
        """
    
    @abstractmethod
    def _build_graph(self):
        """
        A function that builds the graph by reading in the data in
        its current format and populating self._entities, self._relations,
        self._triples, and at the end should set self._built to True.
        """
        pass
    
    @property
    def _is_built(self):
        return self._built
    
    @property
    def _num_entities(self):
        return len(self._entities)
    
    @property
    def _num_relations(self):
        return len(self._relations)
    
    @property
    def _num_triples(self):
        return self._triples.shape[0]
    
    def _validate_graph(self):
        # Make sure properties are filled out
        assert self._built, "The graph is not built. Please build " \
        "or check that your build_graph method sets self._build " \
        "to True after completion"
        
        # Make sure shape of self._triples is [N, 3]
        assert self._triples.shape[1] == 3, "The _triples property" \
        "must have a shape of 3 in the second dimension. " \
        
        # Make sure all head, tail entities and relations are valid
        head_entities = self._triples[:,0]
        assert head_entities.max() <= len(self._entities), "There" \
        "exists an entity in the head entities of the _triples " \
        "property that exceeds the number of available entities." \
        
        tail_entities = self._triples[:,2]
        assert tail_entities.max() <= len(self._entities), "There " \
        "exists an entity in the tail entities of the _triples " \
        "property that exceeds the number of available entities." \
        
        relations = self._triples[:,1]
        assert relations.max() <= len(self._relations), "There " \
        "exists an relations in the _triples " \
        "property that exceeds the number of available relations." \
        
        for eid in self._entities:
            assert self._id2entity(eid), f"One of the entities ({eid}) " \
            "has no mapping."
            
        for rid in self._relations:
            assert self._id2relation(rid), f"One of the relations ({rid}) " \
            "has no mapping."
            
        assert self.sample(10).shape == (10, 3), "Sampling yields the " \
            "wrong shape"
        
        assert self.sample(10, negative=True).shape == (10, 3), "Sampling " \
            "yields the wrong shape"
        
        if self._verbose >= 1:
            print("Graph was successfully validated!")
        
    def _sample_positive(self, k):
        triple_indices = np.random.choice(self._num_triples, k)
        positive_samples = self._triples[triple_indices]
        
        return positive_samples
    
    def _sample_negative_loose(self, k):
        # TODO(frg100): Make a strict version that makes sure not to
        # add existing triples
        head_entities = np.expand_dims(np.random.choice(self._num_entities, k), 0)
        relations = np.expand_dims(np.random.choice(self._num_relations, k), 0)
        tail_entities = np.expand_dims(np.random.choice(self._num_entities, k), 0)
        
        negative_samples = np.concatenate([head_entities, relations, tail_entities], axis=0).T
        
        return negative_samples
    
    def _load_json_mapping(self, json_path):
        # Load the map
        with open(json_path) as json_file:
            return json.load(json_file)
    
    

In [3]:
class FB15k237(KnowledgeGraph):
    def __init__(self, base_path=None, splits=['train', 'test', 'valid'], verbose = 0):
        super().__init__(name='FB15k-237', verbose = verbose)
        
        self._base_path = base_path
        self._splits = splits
        
        self._entity_mapping = None
        
        start = time.time()
        self._build_graph(verbose)
        end = time.time()
        if verbose >= 1:
            print(f"Building the graph took {round(end-start)} seconds")    
        
            
    def _id2entity(self, eid):
        if self._entity_mapping is None:
            assert False, "Entity mapping must be populated"
            
        if eid not in self._entity_mapping:
            #print(f"Entity with id ({eid}) is not mapped...")
            return None
            
        return self._entity_mapping[eid]['label']
    
    def _id2relation(self, rid):
        if self._relation_mapping is None:
            assert False, "Relation mapping must be populated"
            
        if rid not in self._relation_mapping:
            #print(f"Relation with id ({rid}) is not mapped...")
            return None
            
        return self._relation_mapping[rid]

    def _build_graph(self, verbose):
        # Load the mappings
        id2entity_path = os.path.join(self._base_path, "entity2wikidata.json")
        self._entity_mapping = self._load_json_mapping(id2entity_path)
        id2relation_path = os.path.join(self._base_path, "relation_mapping.json")
        self._relation_mapping = self._load_json_mapping(id2relation_path)
        
        # Initialize data structures for bookkeeping
        entities = set()
        relations = set()
        triples = set()

        num_data_points = sum(sum(1 for line in open(os.path.join(self._base_path, f"{split}.txt"))) for split in self._splits)
        
        # Load data
        for split in self._splits:
            path = os.path.join(self._base_path, f"{split}.txt")
            if verbose >= 1:
                print(f"Loading file {split}.txt")
                
            # Process into entities, relations, and triples
            with open(path, 'r') as f:
                for line in f:
                    # Check progress
                    last_percent_done = round((100*(self._num_triples-1))/num_data_points)
                    percent_done = round((100*self._num_triples)/num_data_points)
                    if verbose >= 2 and percent_done % 5 == 0 and last_percent_done % 5 != 0:
                        print(f"Data loading progress: [{percent_done}%]")
                    
                    # Initialize data
                    head, relation, tail = line.split()
                    head_id, relation_id, tail_id = None, None, None
                    
                    # If either of the entities has no natural language translation,
                    if not self._id2entity(head) or not self._id2entity(tail):
                        # Don't process it
                        continue
                    
                    if verbose >= 3 and percent_done % 5 == 0 and last_percent_done % 5 != 0:
                        print(f"{self._id2entity(head)} {relation} {self._id2entity(tail)}")
                    
                    # Process head
                    if head not in entities:
                        entities.add(head)
                        head_id = len(self._entities)
                        self._entities.append(head)
                    else:
                        head_id = self._entities.index(head)
                     
                    # Process tail
                    if tail not in entities:
                        entities.add(tail)
                        tail_id = len(self._entities)
                        self._entities.append(tail)
                    else:
                        tail_id = self._entities.index(tail)
                        
                    # Process relation
                    if relation not in relations:
                        relations.add(relation)
                        relation_id = len(self._relations)
                        self._relations.append(relation)
                    else:
                        relation_id = self._relations.index(relation)

                    # Create and add triple
                    triple = np.array([[head_id, relation_id, tail_id]], dtype=np.int32)  
                    if self._num_triples == 0:
                        self._triples = triple
                    else:
                        self._triples = np.append(self._triples, triple, axis=0)
                        
        # Build and validate
        self._built = True
        self._validate_graph()

### Modeling Framework

In [47]:
import transformers
from transformers import BertModel, BertTokenizer
from happytransformer import HappyTextToText, TTSettings

class LargeLanguageModel(ABC):
    def __init__(self, name=None, verbose = 0):
        self._name = name
        self._verbose = verbose
        
        self._grammar_model = HappyTextToText("T5", "vennify/t5-base-grammar-correction")
        self._grammar_model_args = TTSettings(num_beams=5, min_length=1)
        
        self._built = False
        
    ####### PUBLIC #######
    @property
    def name(self):
        return self._name
    
    def correct_grammar(self, sentence):
        return self._grammar_model.generate_text(f"grammar: {sentence}", args=self._grammar_model_args).text
    
    ####### PRIVATE #######
    
    @abstractmethod
    def batch_perplexity(self, eid):
        """
        A function that calculates a batch perplexity for a set of
        samples.
        """
        
    @property
    def _is_built(self):
        return self._built

In [54]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

class GPT2(LargeLanguageModel):
    def __init__(self, model_id="gpt2", verbose = 0):
        super().__init__(name='GPT2', verbose = verbose)
        
        self._model_id = model_id
        self._model = GPT2LMHeadModel.from_pretrained(self._model_id)
        self._tokenizer = GPT2TokenizerFast.from_pretrained(self._model_id)
        
        self._verbose = verbose
        
        self._built = True
        
    def batch_perplexity(self, samples):
        if self._verbose >= 1:
            print(f"[{self._model_id}] Calculating perplexity for {len(samples)} samples")
        perplexities = []
        for sample in samples:
            sample = self.correct_grammar(sample)
            encoding = self._tokenizer(sample, return_tensors="pt")
            num_tokens = encoding.input_ids.shape[1]

            nlls = []
            for end_loc in range(1, num_tokens):
                input_ids = encoding.input_ids[:, 0:end_loc]
                target_ids = input_ids.clone()

                with torch.no_grad():
                    outputs = self._model(input_ids, labels=target_ids)
                    neg_log_likelihood = outputs[0] * end_loc

            nlls.append(neg_log_likelihood)

            perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
            if self._verbose >= 2:
                print(f"[{self._model_id}] Sample <{sample}> has perplexity [{perplexity}]")
            perplexities.append(perplexity)

        if self._verbose >= 1:
            print(f"[{self._model_id}] Final average perplexity: {sum(perplexities)/len(perplexities)}")

        return perplexities 

In [49]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

class BERT(LargeLanguageModel):
    def __init__(self, verbose = 0):
        super().__init__(name='BERT', verbose = verbose)
        
        self._name = name
        
        self._model_id = 'bert-base-cased'
        self._tokenizer = BertTokenizer.from_pretrained(self._model_id)
        self._model = BertModel.from_pretrained(self._model_id)
        
        self._verbose = verbose
        
        self._built = True
        
    def batch_perplexity(self, samples):
        if self._verbose >= 1:
            print(f"[{self._name}] Calculating perplexity for {len(samples)} samples")
        perplexities = []
        for sample in samples:
            sample = self.correct_grammar(sample)
            encoding = self._tokenizer(sample, return_tensors="pt")
            num_tokens = encoding.input_ids.shape[1]

            nlls = []
            for end_loc in range(1, num_tokens):
                input_ids = encoding.input_ids[:, 0:end_loc]
                target_ids = input_ids.clone()

                with torch.no_grad():
                    outputs = self._model(input_ids, labels=target_ids)
                    neg_log_likelihood = outputs[0] * end_loc

            nlls.append(neg_log_likelihood)

            perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
            if self._verbose >= 2:
                print(f"[{self._name}] Sample <{sample}> has perplexity [{perplexity}]")
            perplexities.append(perplexity)

        if self._verbose >= 1:
            print(f"[{self._name}] Final average perplexity: {sum(perplexities)/len(perplexities)}")

        return perplexities 

In [50]:
#graph = FB15k237(base_path='./data/FB15k-237', splits=['train', 'valid','test'], verbose=2)

In [55]:
model_gpt2 = GPT2(model_id="gpt2", verbose=2)
model_gpt2_large = GPT2(model_id="gpt2-large", verbose=2)
#model_bert = BERT(verbose=2)

05/24/2022 12:55:01 - INFO - happytransformer.happy_transformer -   Using model: cpu
05/24/2022 12:55:17 - INFO - happytransformer.happy_transformer -   Using model: cpu


In [57]:
samples = graph.sample_english(100)
gpt2_pos_perplexity = model_gpt2.batch_perplexity(samples)
gpt2_large_pos_perplexity = model_gpt2_large.batch_perplexity(samples)

[gpt2] Calculating perplexity for 100 samples
[gpt2] Sample <The GNI per capita in PPP dollars of Mauritania is measured in United States dollars.> has perplexity [104.25950622558594]
[gpt2] Sample <The Academy Award for Best Animated Feature was given at the 85th Academy Awards.> has perplexity [16.669166564941406]
[gpt2] Sample <The Kids Are All Right was nominated for the Golden Globe Award for Best Screenplay.> has perplexity [28.549013137817383]
[gpt2] Sample <Rory Cochrane is Scottish American.> has perplexity [636.8055419921875]
[gpt2] Sample <The movie The Birds was released in West Germany.> has perplexity [136.253662109375]
[gpt2] Sample <There is a goalkeeper on the roster of the S.C. Bastia.> has perplexity [45.30856704711914]
[gpt2] Sample <Elizabeth McGovern has been nominated for an award along with Maggie Smith.> has perplexity [58.49312210083008]
[gpt2] Sample <The genre of the TV show "Girls" is black comedy.> has perplexity [47.039276123046875]
[gpt2] Sample <The pro

[gpt2] Sample <Patty Loveless has been nominated for an award along with Vince Gill.> has perplexity [45.83945846557617]
[gpt2] Sample <Julia was nominated for the National Society of Film Critics Award for Best Actress.> has perplexity [23.54741668701172]
[gpt2] Sample <Ellen Burstyn practices the religion of Catholicism.> has perplexity [308.926025390625]
[gpt2] Sample <The profession of Wes Anderson is screenwriter.> has perplexity [1083.2369384765625]
[gpt2] Sample <There is a significant population in Nebraska that practices United Church of Christ.> has perplexity [68.44824981689453]
[gpt2] Sample <The fiddle and electric guitar are often part of the same performance.> has perplexity [58.51577377319336]
[gpt2] Sample <Coming Home was released as a DVD.> has perplexity [82.7439193725586]
[gpt2] Sample <Gustavo Cerati is in a marriage union.> has perplexity [257.96429443359375]
[gpt2] Sample <The movie Bad Education was released in Finland.> has perplexity [376.0457763671875]
[gpt2

[gpt2-large] Sample <Mumbai is located in India.> has perplexity [118.16363525390625]
[gpt2-large] Sample <20th Century Fox distributed the film Barton Fink.> has perplexity [39.54825210571289]
[gpt2-large] Sample <Laurie Metcalf has been nominated for an award along with Marcia Cross.> has perplexity [24.602293014526367]
[gpt2-large] Sample <The drum and Bouzouki are often part of the same performance.> has perplexity [57.37322235107422]
[gpt2-large] Sample <Steven Tyler was born in the United States of America.> has perplexity [14.699027061462402]
[gpt2-large] Sample <Crossroads was nominated for the Golden Raspberry Award for Worst Original Song.> has perplexity [29.125059127807617]
[gpt2-large] Sample <The profession of Eugene Levy is actor.> has perplexity [2305.234130859375]
[gpt2-large] Sample <Foo Fighters has been nominated for an award along with Green Day.> has perplexity [39.339393615722656]
[gpt2-large] Sample <Eddie Murphy won an award for their work, Showtime.> has perpl

In [None]:
negative_samples = graph.sample_english(100, negative=True)
gpt2_neg_perplexity = model_gpt2.batch_perplexity(negative_samples)
gpt2_large_neg_perplexity = model_gpt2_large.batch_perplexity(negative_samples)

[gpt2] Calculating perplexity for 100 samples
[gpt2] Sample <Saskatoon was a peer of Gods and Monsters.> has perplexity [514.7250366210938]
[gpt2] Sample <One of the sports at the Nelson Riddle was Portland State University.> has perplexity [187.6018829345703]
[gpt2] Sample <The release region of the film "Minnesota Golden Gophers men's basketball" is Patricia Arquette.> has perplexity [101.3570556640625]
[gpt2] Sample <Home Alone is a Manmohan Krishna.> has perplexity [997.503173828125]
[gpt2] Sample <Ecuador impersonated Marie Antoinette.> has perplexity [334.19293212890625]
[gpt2] Sample <The administrative parent of A.C. ChievoVerona is lounge music.> has perplexity [661.1280517578125]
[gpt2] Sample <Michael Redgrave wrote the music for the movie "abdominal pain".> has perplexity [124.79692840576172]


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def compare_models(samples, model_a, perplexities_a, model_b, perplexities_b):
    f, ax = plt.subplots(figsize=(20,10))
    ax.set(yscale="log")

    model_one_ppx = np.array([ppx.item() for ppx in perplexities_a])
    model_two_ppx = np.array([ppx.item() for ppx in perplexities_b])

    sort_order = np.argsort(model_one_ppx).astype(int)

    sns.lineplot(x=samples, y=model_one_ppx[sort_order], label=model_a)
    sns.lineplot(x=samples, y=model_two_ppx[sort_order], label=model_b)

    plt.xticks(rotation=90)

    plt.legend()
    plt.ylabel('Perplexity')
    plt.title("Comparing perplexity on sample sentences between two models")

    plt.show()

compare_models(samples, 'GPT2', gpt2_pos_perplexity, "GPT2-Large", gpt2_large_pos_perplexity)
compare_models(negative_samples, 'GPT2', gpt2_neg_perplexity, "GPT2-Large", gpt2_large_neg_perplexity)

### Random Experiments

In [None]:
weights_name = 'bert-base-cased'
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)
bert_model = BertModel.from_pretrained(weights_name)

In [None]:
bert_tokenizer.batch_encode_plus(
    graph.sample_english(10),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='longest')

In [None]:
def batch_perplexity(samples, verbose=0):
    print(f"Calculating perplexity for {len(samples)} samples")
    perplexities = []
    for sample in samples:
        encoding = tokenizer(sample, return_tensors="pt")
        num_tokens = encoding.input_ids.shape[1]

        nlls = []
        for end_loc in range(1, num_tokens):
            input_ids = encoding.input_ids[:, 0:end_loc]
            target_ids = input_ids.clone()

            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
                neg_log_likelihood = outputs[0] * end_loc

        nlls.append(neg_log_likelihood)

        perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
        if verbose >= 2:
            print(f"Sample <{sample}> has perplexity [{perplexity}]")
        perplexities.append(perplexity)

    if verbose >= 1:
        print(f"Final average perplexity: {sum(perplexities)/len(perplexities)}")
        
    return perplexities

batch_perplexity(graph.sample_english(10, negative=False), verbose=2)

In [None]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

encodings = tokenizer("\n\n".join(graph.sample_english(10)), return_tensors="pt")

import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 1

nlls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        print(begin_loc, end_loc)
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)

In [None]:
ppl

### Running the code

In [None]:
graph = FB15k237(base_path='./data/FB15k-237', splits=['train', 'valid','test'], verbose=2)

In [None]:
graph.sample_english(10, negative=False)

### Random Experiments

In [None]:
# base_path='./data/FB15k-237'

# mapping = {}

# # Load file if it exists
# json_path = os.path.join(base_path, "relation_mapping.json")

# if os.path.exists(json_path):
#     json_file_read = open(json_path, 'r')
#     mapping = json.load(json_file_read)
#     json_file_read.close()


# # for rel in graph.relations:
# #     if rel not in mapping:
# #         relations_done = len(mapping.keys())
# #         relations_total = len(graph.relations)
# #         relations_left = relations_total - relations_done
# #         print(f"[{round(100*(relations_done/relations_total), 2)}%] done describing relations ({relations_left} left)")
        
# #         instance_idx = np.where(graph.triples[:, 1] == graph.relations.index(rel))[0][0]
# #         head, relation, tail = graph.triples[instance_idx]
# #         head, tail = graph._mid2entity(graph.entities[head]), graph._mid2entity(graph.entities[tail])
        
# #         format_string = input(f"{head} {rel} {tail}: ")
# #         mapping[rel] = format_string
# #         json_file_write = open(json_path, 'w')
# #         json.dump(mapping, json_file_write)
# #         json_file_write.close()
        
    

In [None]:
# Generate the dataset using the functions


In [None]:
# def see_relation_examples(relation, k = 1):
#     instance_indices = np.where(graph.triples[:, 1] == graph.relations.index(rel))[0][:k]
#     samples = graph.triples[instance_indices]
#     for sample in samples:
#         h, r, t = sample
#         print(graph._id2entity(graph.entities[h]), graph.relations[r], graph._id2entity(graph.entities[t]))
    
# see_relation_examples('/award/award_winner/awards_won./award/award_honor/award_winner', k=5)

In [None]:
len(graph.relations)