In [90]:
import os
import json
import time

import numpy as np
import pandas as pd

### KnowledgeGraph framework

In [91]:
from abc import ABC, abstractmethod

class KnowledgeGraph(ABC):
    def __init__(self, name=None, verbose = 0):
        self._name = name
        self._verbose = verbose
        
        self._entities = [] # list(string)
        self._relations = [] # list(string)
        # np.array([(head_entity, relation, tail_entity)])
        self._triples = np.zeros(shape=(0,3))
        
        self._built = False
        
    ####### PUBLIC #######
    @property
    def name(self):
        return self._name
    
    @property
    def entities(self):
        return self._entities
    
    @property
    def relations(self):
        return self._relations
    
    @property
    def triples(self):
        return self._triples
    
    def sample(self, k=1, negative=False):
        if negative:
            return self._sample_negative_loose(k)
        else:
            return self._sample_positive(k)
        
    def sample_english(self, k=1, negative=False):
        samples = self.sample(k, negative)
        
        english_samples = []
        for sample in samples:
            head_idx, relation_idx, tail_idx = sample
            head_id, relation_id, tail_id = self._entities[head_idx], self._relations[relation_idx], self._entities[tail_idx]
            head, relation, tail = self._id2entity(head_id), self._id2relation(relation_id), self._id2entity(tail_id)
            english_samples.append(relation.replace("{HEAD}", head).replace("{TAIL}", tail))
            
        return english_samples
            
        
    ####### PRIVATE #######
    
    @abstractmethod
    def _id2entity(self, eid):
        """
        A function that maps an entity id (eid) stored in the
        self._entities structure to an english identifier
        and/or description.
        """
        
    @abstractmethod
    def _id2relation(self, rid):
        """
        A function that maps an relation id (rid) stored in
        the self._relations structure to an english identifier
        and/or description.
        """
    
    @abstractmethod
    def _build_graph(self):
        """
        A function that builds the graph by reading in the data in
        its current format and populating self._entities, self._relations,
        self._triples, and at the end should set self._built to True.
        """
        pass
    
    @property
    def _is_built(self):
        return self._built
    
    @property
    def _num_entities(self):
        return len(self._entities)
    
    @property
    def _num_relations(self):
        return len(self._relations)
    
    @property
    def _num_triples(self):
        return self._triples.shape[0]
    
    def _validate_graph(self):
        # Make sure properties are filled out
        assert self._built, "The graph is not built. Please build " \
        "or check that your build_graph method sets self._build " \
        "to True after completion"
        
        # Make sure shape of self._triples is [N, 3]
        assert self._triples.shape[1] == 3, "The _triples property" \
        "must have a shape of 3 in the second dimension. " \
        
        # Make sure all head, tail entities and relations are valid
        head_entities = self._triples[:,0]
        assert head_entities.max() <= len(self._entities), "There" \
        "exists an entity in the head entities of the _triples " \
        "property that exceeds the number of available entities." \
        
        tail_entities = self._triples[:,2]
        assert tail_entities.max() <= len(self._entities), "There " \
        "exists an entity in the tail entities of the _triples " \
        "property that exceeds the number of available entities." \
        
        relations = self._triples[:,1]
        assert relations.max() <= len(self._relations), "There " \
        "exists an relations in the _triples " \
        "property that exceeds the number of available relations." \
        
        for eid in self._entities:
            assert self._id2entity(eid), f"One of the entities ({eid}) " \
            "has no mapping."
            
        for rid in self._relations:
            assert self._id2relation(rid), f"One of the relations ({rid}) " \
            "has no mapping."
            
        assert self.sample(10).shape == (10, 3), "Sampling yields the " \
            "wrong shape"
        
        assert self.sample(10, negative=True).shape == (10, 3), "Sampling " \
            "yields the wrong shape"
        
        if self._verbose >= 1:
            print("Graph was successfully validated!")
        
    def _sample_positive(self, k):
        triple_indices = np.random.choice(self._num_triples, k)
        positive_samples = self._triples[triple_indices]
        
        return positive_samples
    
    def _sample_negative_loose(self, k):
        # TODO(frg100): Make a strict version that makes sure not to
        # add existing triples
        head_entities = np.expand_dims(np.random.choice(self._num_entities, k), 0)
        relations = np.expand_dims(np.random.choice(self._num_relations, k), 0)
        tail_entities = np.expand_dims(np.random.choice(self._num_entities, k), 0)
        
        negative_samples = np.concatenate([head_entities, relations, tail_entities], axis=0).T
        
        return negative_samples
    
    def _load_json_mapping(self, json_path):
        # Load the map
        with open(json_path) as json_file:
            return json.load(json_file)
    
    

In [92]:
class FB15k237(KnowledgeGraph):
    def __init__(self, base_path=None, splits=['train', 'test', 'valid'], verbose = 0):
        super().__init__(name='FB15k-237', verbose = verbose)
        
        self._base_path = base_path
        self._splits = splits
        
        self._entity_mapping = None
        
        start = time.time()
        self._build_graph(verbose)
        end = time.time()
        if verbose >= 1:
            print(f"Building the graph took {round(end-start)} seconds")    
        
            
    def _id2entity(self, eid):
        if self._entity_mapping is None:
            assert False, "Entity mapping must be populated"
            
        if eid not in self._entity_mapping:
            #print(f"Entity with id ({eid}) is not mapped...")
            return None
            
        return self._entity_mapping[eid]['label']
    
    def _id2relation(self, rid):
        if self._relation_mapping is None:
            assert False, "Relation mapping must be populated"
            
        if rid not in self._relation_mapping:
            #print(f"Relation with id ({rid}) is not mapped...")
            return None
            
        return self._relation_mapping[rid]

    def _build_graph(self, verbose):
        # Load the mappings
        id2entity_path = os.path.join(self._base_path, "entity2wikidata.json")
        self._entity_mapping = self._load_json_mapping(id2entity_path)
        id2relation_path = os.path.join(self._base_path, "relation_mapping.json")
        self._relation_mapping = self._load_json_mapping(id2relation_path)
        
        # Initialize data structures for bookkeeping
        entities = set()
        relations = set()
        triples = set()

        num_data_points = sum(sum(1 for line in open(os.path.join(self._base_path, f"{split}.txt"))) for split in self._splits)
        
        # Load data
        for split in self._splits:
            path = os.path.join(self._base_path, f"{split}.txt")
            if verbose >= 1:
                print(f"Loading file {split}.txt")
                
            # Process into entities, relations, and triples
            with open(path, 'r') as f:
                for line in f:
                    # Check progress
                    last_percent_done = round((100*(self._num_triples-1))/num_data_points)
                    percent_done = round((100*self._num_triples)/num_data_points)
                    if verbose >= 2 and percent_done % 5 == 0 and last_percent_done % 5 != 0:
                        print(f"Data loading progress: [{percent_done}%]")
                    
                    # Initialize data
                    head, relation, tail = line.split()
                    head_id, relation_id, tail_id = None, None, None
                    
                    # If either of the entities has no natural language translation,
                    if not self._id2entity(head) or not self._id2entity(tail):
                        # Don't process it
                        continue
                    
                    if verbose >= 3 and percent_done % 5 == 0 and last_percent_done % 5 != 0:
                        print(f"{self._id2entity(head)} {relation} {self._id2entity(tail)}")
                    
                    # Process head
                    if head not in entities:
                        entities.add(head)
                        head_id = len(self._entities)
                        self._entities.append(head)
                    else:
                        head_id = self._entities.index(head)
                     
                    # Process tail
                    if tail not in entities:
                        entities.add(tail)
                        tail_id = len(self._entities)
                        self._entities.append(tail)
                    else:
                        tail_id = self._entities.index(tail)
                        
                    # Process relation
                    if relation not in relations:
                        relations.add(relation)
                        relation_id = len(self._relations)
                        self._relations.append(relation)
                    else:
                        relation_id = self._relations.index(relation)

                    # Create and add triple
                    triple = np.array([[head_id, relation_id, tail_id]], dtype=np.int32)  
                    if self._num_triples == 0:
                        self._triples = triple
                    else:
                        self._triples = np.append(self._triples, triple, axis=0)
                        
        # Build and validate
        self._built = True
        self._validate_graph()

### Modeling Framework

In [131]:
import transformers
from transformers import BertModel, BertTokenizer

class LargeLanguageModel(ABC):
    def __init__(self, name=None, verbose = 0):
        self._name = name
        self._verbose = verbose
        
        self._built = False
        
    ####### PUBLIC #######
    @property
    def name(self):
        return self._name
    
    ####### PRIVATE #######
    
    @abstractmethod
    def batch_perplexity(self, eid):
        """
        A function that calculates a batch perplexity for a set of
        samples.
        """
        
    @property
    def _is_built(self):
        return self._built

In [132]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

class GPT2(LargeLanguageModel):
    def __init__(self, verbose = 0):
        super().__init__(name='GPT2', verbose = verbose)
        
        self._model_id = "gpt2"
        self._model = GPT2LMHeadModel.from_pretrained(model_id)
        self._tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
        
        self._verbose = verbose
        
    def batch_perplexity(self, samples):
        if self._verbose >= 1:
            print(f"Calculating perplexity for {len(samples)} samples")
        perplexities = []
        for sample in samples:
            encoding = self._tokenizer(sample, return_tensors="pt")
            num_tokens = encoding.input_ids.shape[1]

            nlls = []
            for end_loc in range(1, num_tokens):
                input_ids = encoding.input_ids[:, 0:end_loc]
                target_ids = input_ids.clone()

                with torch.no_grad():
                    outputs = self._model(input_ids, labels=target_ids)
                    neg_log_likelihood = outputs[0] * end_loc

            nlls.append(neg_log_likelihood)

            perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
            if self._verbose >= 2:
                print(f"Sample <{sample}> has perplexity [{perplexity}]")
            perplexities.append(perplexity)

        if self._verbose >= 1:
            print(f"Final average perplexity: {sum(perplexities)/len(perplexities)}")

        return perplexities 

In [133]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

class BERT(LargeLanguageModel):
    def __init__(self, verbose = 0):
        super().__init__(name='BERT', verbose = verbose)
        
        self._model_id = 'bert-base-cased'
        self._model = BertTokenizer.from_pretrained(self._model_id)
        self._tokenizer = BertModel.from_pretrained(self._model_id)
        
        self._verbose = verbose
        
    def batch_perplexity(self, samples):
        if self._verbose >= 1:
            print(f"Calculating perplexity for {len(samples)} samples")
        perplexities = []
        for sample in samples:
            encoding = self._tokenizer(sample, return_tensors="pt")
            num_tokens = encoding.input_ids.shape[1]

            nlls = []
            for end_loc in range(1, num_tokens):
                input_ids = encoding.input_ids[:, 0:end_loc]
                target_ids = input_ids.clone()

                with torch.no_grad():
                    outputs = self._model(input_ids, labels=target_ids)
                    neg_log_likelihood = outputs[0] * end_loc

            nlls.append(neg_log_likelihood)

            perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
            if self._verbose >= 2:
                print(f"Sample <{sample}> has perplexity [{perplexity}]")
            perplexities.append(perplexity)

        if self._verbose >= 1:
            print(f"Final average perplexity: {sum(perplexities)/len(perplexities)}")

        return perplexities 

In [134]:
graph = FB15k237(base_path='./data/FB15k-237', splits=['train', 'valid','test'], verbose=2)
model_gpt2 = GPT2(verbose=2)
model_bert = BERT(verbose=2)

Loading file train.txt
Data loading progress: [5%]
Data loading progress: [10%]


KeyboardInterrupt: 

### Random Experiments

In [97]:
weights_name = 'bert-base-cased'
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)
bert_model = BertModel.from_pretrained(weights_name)

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [99]:
bert_tokenizer.batch_encode_plus(
    graph.sample_english(10),
    add_special_tokens=True,
    return_attention_mask=True,
    padding='longest')

{'input_ids': [[101, 2290, 112, 188, 1390, 1110, 3618, 2067, 102, 0, 0, 0, 0, 0, 0, 0, 0], [101, 16745, 1110, 1388, 1107, 1103, 1970, 1735, 2614, 102, 0, 0, 0, 0, 0, 0, 0], [101, 11028, 27238, 1108, 3639, 1111, 1103, 2127, 1698, 1111, 1798, 12440, 7617, 2574, 102, 0, 0], [101, 24432, 1110, 170, 21204, 1273, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1847, 16752, 11088, 1182, 1108, 1255, 1107, 1244, 1311, 1104, 1738, 102, 0, 0, 0, 0], [101, 1109, 3207, 6486, 1104, 2715, 1104, 18959, 5208, 3293, 1110, 19349, 1777, 102, 0, 0, 0], [101, 139, 2858, 5521, 14467, 11656, 1394, 1110, 1388, 1107, 1375, 2201, 102, 0, 0, 0, 0], [101, 1789, 1234, 1104, 18787, 1234, 6585, 2936, 9001, 102, 0, 0, 0, 0, 0, 0, 0], [101, 1109, 2523, 1109, 140, 23156, 1179, 1107, 1103, 8726, 1108, 1308, 1107, 1699, 102, 0, 0], [101, 1109, 8645, 1698, 1111, 1798, 10018, 6429, 2574, 1108, 1549, 1120, 1103, 1949, 8645, 2763, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0

In [124]:
def batch_perplexity(samples, verbose=0):
    print(f"Calculating perplexity for {len(samples)} samples")
    perplexities = []
    for sample in samples:
        encoding = tokenizer(sample, return_tensors="pt")
        num_tokens = encoding.input_ids.shape[1]

        nlls = []
        for end_loc in range(1, num_tokens):
            input_ids = encoding.input_ids[:, 0:end_loc]
            target_ids = input_ids.clone()

            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
                neg_log_likelihood = outputs[0] * end_loc

        nlls.append(neg_log_likelihood)

        perplexity = torch.exp(torch.stack(nlls).sum() / end_loc)
        if verbose >= 2:
            print(f"Sample <{sample}> has perplexity [{perplexity}]")
        perplexities.append(perplexity)

    if verbose >= 1:
        print(f"Final average perplexity: {sum(perplexities)/len(perplexities)}")
        
    return perplexities

batch_perplexity(graph.sample_english(10, negative=False), verbose=2)

Calculating perplexity for 10 samples
Sample <Eugene Levy is a graduate of McMaster University> has perplexity [73.97415161132812]
Sample <Joe Morton appears in the film Terminator 2: Judgment Day> has perplexity [61.194366455078125]
Sample <The profession of will.i.am is record producer> has perplexity [262.07037353515625]
Sample <There is a quarterback on the roster of the Dallas Cowboys> has perplexity [38.15388488769531]
Sample <Metro-Goldwyn-Mayer distributed the film The Living Daylights> has perplexity [84.44568634033203]
Sample <Michael McKean won an award along with Harry Shearer> has perplexity [99.71240997314453]
Sample <Norway is located next to Russia> has perplexity [120.02970886230469]
Sample <Jane's Addiction's music is art rock> has perplexity [449.6371154785156]
Sample <University of Maryland, College Park had a draft pick in the 2002 Major League Baseball draft> has perplexity [37.60438919067383]
Sample <2010: The Year We Make Contact is a drama film> has perplexity 

[tensor(73.9742),
 tensor(61.1944),
 tensor(262.0704),
 tensor(38.1539),
 tensor(84.4457),
 tensor(99.7124),
 tensor(120.0297),
 tensor(449.6371),
 tensor(37.6044),
 tensor(207.2831)]

In [104]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

encodings = tokenizer("\n\n".join(graph.sample_english(10)), return_tensors="pt")

import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 1

nlls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        print(begin_loc, end_loc)
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]


  0%|                                                   | 0/137 [00:00<?, ?it/s][A
  1%|▋                                          | 2/137 [00:00<00:08, 15.45it/s][A
  4%|█▌                                         | 5/137 [00:00<00:07, 17.70it/s][A
  5%|██▏                                        | 7/137 [00:00<00:08, 15.94it/s][A
  7%|██▊                                        | 9/137 [00:00<00:10, 12.65it/s][A
  8%|███▎                                      | 11/137 [00:00<00:11, 10.61it/s][A
  9%|███▉                                      | 13/137 [00:01<00:13,  9.19it/s][A
 10%|████▎                                     | 14/137 [00:01<00:14,  8.51it/s][A
 11%|████▌                                     | 15/137 [00:01<00:15,  7.96it/s][A
 12%|████▉                                     | 16/137 [00:01<00:15,  7.69it/s][A
 12%|█████▏                                    | 17/137 [00:01<00:16,  7.12it/s][A
 13%|█████▌                                    | 18/137 [00:01<00:18,  6.58

 76%|███████████████████████████████          | 104/137 [00:40<00:23,  1.42it/s][A
 77%|███████████████████████████████▍         | 105/137 [00:41<00:22,  1.40it/s][A
 77%|███████████████████████████████▋         | 106/137 [00:42<00:22,  1.39it/s][A
 78%|████████████████████████████████         | 107/137 [00:42<00:21,  1.38it/s][A
 79%|████████████████████████████████▎        | 108/137 [00:43<00:21,  1.36it/s][A
 80%|████████████████████████████████▌        | 109/137 [00:44<00:20,  1.34it/s][A
 80%|████████████████████████████████▉        | 110/137 [00:45<00:20,  1.32it/s][A
 81%|█████████████████████████████████▏       | 111/137 [00:45<00:19,  1.30it/s][A
 82%|█████████████████████████████████▌       | 112/137 [00:46<00:19,  1.31it/s][A
 82%|█████████████████████████████████▊       | 113/137 [00:47<00:18,  1.30it/s][A
 83%|██████████████████████████████████       | 114/137 [00:48<00:17,  1.29it/s][A
 84%|██████████████████████████████████▍      | 115/137 [00:48<00:17,  1.29i

In [105]:
ppl

tensor(nan)

### Running the code

In [93]:
graph = FB15k237(base_path='./data/FB15k-237', splits=['train', 'valid','test'], verbose=2)

Loading file train.txt
Data loading progress: [5%]
Data loading progress: [10%]
Data loading progress: [15%]
Data loading progress: [20%]
Data loading progress: [25%]
Data loading progress: [30%]
Data loading progress: [35%]
Data loading progress: [40%]
Data loading progress: [45%]
Data loading progress: [50%]
Data loading progress: [55%]
Data loading progress: [60%]
Data loading progress: [65%]
Data loading progress: [70%]
Data loading progress: [75%]
Data loading progress: [80%]
Data loading progress: [85%]
Loading file valid.txt
Data loading progress: [90%]
Loading file test.txt
Data loading progress: [95%]
Graph was successfully validated!
Building the graph took 56 seconds


In [95]:
graph.sample_english(10, negative=False)

['The currency of Frederick County is the United States dollar',
 'Forrest Gump was nominated for the Academy Award for Best Visual Effects award',
 'Phil Ramone has been nominated for an award along with Billy Joel',
 'Rich Sommer won an award along with Mark Moses',
 'Tara Reid was nominated for the Golden Raspberry Award for Worst Supporting Actress award',
 'Dilwale Dulhania Le Jayenge is a drama film',
 'The Gay Divorcee was nominated for the Academy Award for Best Production Design award',
 'J. G. Ballard was influenced by William S. Burroughs',
 'Rick Kline was nominated for an award for Terms of Endearment',
 'The Blues Brothers is a musical film']

### Random Experiments

In [21]:
# base_path='./data/FB15k-237'

# mapping = {}

# # Load file if it exists
# json_path = os.path.join(base_path, "relation_mapping.json")

# if os.path.exists(json_path):
#     json_file_read = open(json_path, 'r')
#     mapping = json.load(json_file_read)
#     json_file_read.close()


# # for rel in graph.relations:
# #     if rel not in mapping:
# #         relations_done = len(mapping.keys())
# #         relations_total = len(graph.relations)
# #         relations_left = relations_total - relations_done
# #         print(f"[{round(100*(relations_done/relations_total), 2)}%] done describing relations ({relations_left} left)")
        
# #         instance_idx = np.where(graph.triples[:, 1] == graph.relations.index(rel))[0][0]
# #         head, relation, tail = graph.triples[instance_idx]
# #         head, tail = graph._mid2entity(graph.entities[head]), graph._mid2entity(graph.entities[tail])
        
# #         format_string = input(f"{head} {rel} {tail}: ")
# #         mapping[rel] = format_string
# #         json_file_write = open(json_path, 'w')
# #         json.dump(mapping, json_file_write)
# #         json_file_write.close()
        
    

In [None]:
# Generate the dataset using the functions


In [144]:
# def see_relation_examples(relation, k = 1):
#     instance_indices = np.where(graph.triples[:, 1] == graph.relations.index(rel))[0][:k]
#     samples = graph.triples[instance_indices]
#     for sample in samples:
#         h, r, t = sample
#         print(graph._id2entity(graph.entities[h]), graph.relations[r], graph._id2entity(graph.entities[t]))
    
# see_relation_examples('/award/award_winner/awards_won./award/award_honor/award_winner', k=5)

Michelle Rodriguez /award/award_winner/awards_won./award/award_honor/award_winner Naveen Andrews
Scott Rudin /award/award_winner/awards_won./award/award_honor/award_winner Alan Bennett
Don Cheadle /award/award_winner/awards_won./award/award_honor/award_winner Larenz Tate
Freddy Rodriguez /award/award_winner/awards_won./award/award_honor/award_winner Justina Machado
Vincent Pastore /award/award_winner/awards_won./award/award_honor/award_winner Michael Imperioli


In [16]:
len(graph.relations)

235