In [31]:
from collections import namedtuple
import torch
import numpy as np
from datasets import load_dataset

In [2]:
# Keep only the training split of the English configuration of Enriched WebNLG.
web = load_dataset(path='enriched_web_nlg', name='en')['train']

Reusing dataset enriched_web_nlg (/home/dashiell/.cache/huggingface/datasets/enriched_web_nlg/en/0.0.0/71bb7723be90037b52b022d82fd8928fb6f7544fd331aec278aaffb6814be4b5)


In [3]:
# Number of records in the training split loaded above.
len(web)

6940

In [4]:
# Show the first record to inspect the fields each row carries.
web[0]

{'category': 'Airport',
 'eid': 'Id1',
 'lex': {'comment': ['good', 'good'],
  'lexicalization': ['AGENT-1 VP[aspect=simple,tense=present,voice=active,person=3rd,number=singular] be DT[form=defined] the airport of PATIENT-1 .',
   'AGENT-1 VP[aspect=simple,tense=present,voice=active,person=3rd,number=null] serve DT[form=defined] the city of PATIENT-1 .'],
  'lid': ['Id1', 'Id2'],
  'sorted_triple_sets': [['Aarhus_Airport | cityServed | "Aarhus, Denmark"'],
   ['Aarhus_Airport | cityServed | "Aarhus, Denmark"']],
  'template': ['AGENT-1 is the airport of PATIENT-1 .',
   'AGENT-1 serves the city of PATIENT-1 .'],
  'text': ['The Aarhus is the airport of Aarhus, Denmark.',
   'Aarhus Airport serves the city of Aarhus, Denmark.']},
 'modified_triple_sets': {'mtriple_set': [['Aarhus_Airport | cityServed | "Aarhus, Denmark"']]},
 'original_triple_sets': {'otriple_set': [['Aarhus_Airport | cityServed | "Aarhus, Denmark"@en']]},
 'shape': '',
 'shape_type': '',
 'size': 1}

In [5]:
# Show record 5467, the multi-triple example parsed in the following cells.
web[5467]

{'category': 'University',
 'eid': 'Id71',
 'lex': {'comment': ['good', 'good', 'good'],
  'lexicalization': ['AGENT-1 in PATIENT-3 VP[aspect=simple,tense=present,voice=passive,person=3rd,number=singular] affiliate with BRIDGE-1 in PATIENT-1 . AGENT-1 director VP[aspect=simple,tense=present,voice=active,person=3rd,number=singular] be PATIENT-2 .',
   'AGENT-1 VP[aspect=simple,tense=present,voice=active,person=3rd,number=singular] be located in PATIENT-3 . AGENT-1 VP[aspect=simple,tense=present,voice=passive,person=3rd,number=singular] affiliate with BRIDGE-1 in PATIENT-1 and AGENT-1 director VP[aspect=simple,tense=present,voice=active,person=3rd,number=singular] be PATIENT-2 .',
   'AGENT-1 VP[aspect=simple,tense=past,voice=active,person=null,number=null] base in DT[form=defined] the state of PATIENT-3 VP[aspect=simple,tense=present,voice=active,person=3rd,number=null] have DT[form=defined] the director PATIENT-2 . AGENT-1 VP[aspect=simple,tense=present,voice=passive,person=3rd,number=

In [6]:
# A knowledge-graph triple: head entity, relation name, tail entity.
Triple = namedtuple('Triple', ['h', 'r', 't'])


def str_to_triple(triple: str) -> Triple:
    """Parse a triple string of the form ``'head | relation | tail'``.

    Only the first two ``' | '`` separators are treated as field boundaries,
    so a tail value that itself contains ``' | '`` is kept in one piece
    instead of producing a fourth field.

    Args:
        triple: the raw triple string.

    Returns:
        A ``Triple`` whose ``h``, ``r`` and ``t`` fields hold the three parts
        exactly as they appear in the string (quotes are not stripped).

    Raises:
        TypeError: if the string has fewer than three ``' | '``-separated
            fields.
    """
    return Triple(*triple.split(' | ', 2))
    

In [7]:
# Parse the first triple string of record 5467's first modified triple set.
first_triple = web[5467]['modified_triple_sets']['mtriple_set'][0][0]
print(str_to_triple(first_triple))

Triple(h='Visvesvaraya_Technological_University', r='city', t='Belgaum')


In [8]:
# Parse and print every triple string in record 5467's first modified triple set.
triple_strings = web[5467]['modified_triple_sets']['mtriple_set'][0]
for relation in triple_strings:
    parsed = str_to_triple(relation)
    print(parsed)

Triple(h='Visvesvaraya_Technological_University', r='city', t='Belgaum')
Triple(h='Acharya_Institute_of_Technology', r='director', t='"Dr. G. P. Prabhukumar"')
Triple(h='Acharya_Institute_of_Technology', r='state', t='Karnataka')
Triple(h='Acharya_Institute_of_Technology', r='affiliation', t='Visvesvaraya_Technological_University')


In [9]:
# Sanity check: print the index of any record whose 'mtriple_set' does not
# contain exactly one triple set. No output means every record has one.
for i, record in enumerate(web):
    n_sets = len(record['modified_triple_sets']['mtriple_set'])
    if n_sets == 1:
        continue
    print(i)

In [10]:
from torch_geometric.data import Data


# (name, index) pairs returned by TripleCache lookups.
Entity = namedtuple('Entity', ['name', 'index'])
# The typename matches the variable so instances print as Relationship(...).
Relationship = namedtuple('Relationship', ['name', 'index'])


class TripleCache:
    """Hands out integer ids for entity and relationship names.

    Entities and relationships are numbered independently. A name receives
    its id the first time it is looked up and keeps it afterwards; ids start
    at 1 because the counter is incremented before it is assigned.
    """

    def __init__(self):
        self.e_cache = {}  # entity name -> id
        self.r_cache = {}  # relationship name -> id
        self.e_index = 0   # last entity id handed out (0 = none yet)
        self.r_index = 0   # last relationship id handed out (0 = none yet)

    def get_entity(self, entity):
        """Return ``Entity(entity, id)``, allocating a new id on first lookup."""
        idx = self.e_cache.get(entity)
        if idx is None:
            self.e_index += 1
            idx = self.e_index
            self.e_cache[entity] = idx
        return Entity(entity, idx)

    def get_relationship(self, relationship):
        """Return ``Relationship(relationship, id)``, allocating a new id on first lookup."""
        idx = self.r_cache.get(relationship)
        if idx is None:
            self.r_index += 1
            idx = self.r_index
            self.r_cache[relationship] = idx
        return Relationship(relationship, idx)

    
class WebNLGSample:
    """One dataset record: its reference texts plus its parsed triples."""

    def __init__(self, record):
        """Extract texts and triples from a dataset row.

        Args:
            record: a mapping that provides ``record['lex']['text']`` and
                ``record['modified_triple_sets']['mtriple_set']``; only the
                first element (index 0) of the latter is used.
        """
        self.texts = record['lex']['text']
        self.triples, self.entities, self.relations = self.process_record(
            record['modified_triple_sets']['mtriple_set'][0]
        )

    @staticmethod
    def process_record(sample):
        """Parse a list of triple strings.

        Args:
            sample: iterable of strings accepted by ``str_to_triple``.

        Returns:
            ``(triples, entities, relations)``: the parsed triples in input
            order, the set of all head/tail names, and the set of relation
            names.
        """
        triples = [str_to_triple(record) for record in sample]
        entities = {name for triple in triples for name in (triple.h, triple.t)}
        relations = {triple.r for triple in triples}
        return triples, entities, relations

    def make_graph(self, cache):
        """Build a graph ``Data`` object for this sample.

        Nodes are this sample's entities, numbered 0..n-1 in sorted name
        order. Each triple contributes two edges, head -> tail followed by
        tail -> head. The single node feature is the entity's id in ``cache``.
        Relation names are not encoded in the graph.

        Args:
            cache: object whose ``get_entity(name)`` returns a value with an
                ``index`` attribute (e.g. a ``TripleCache``). Names it has
                not seen before are registered in it by this call.

        Returns:
            ``Data`` with ``x`` of shape ``(n_entities, 1)`` and
            ``edge_index`` of shape ``(2, 2 * n_triples)`` with dtype long.
        """
        # A set of strings iterates in a hash-dependent order that can change
        # between interpreter runs; sorting keeps node numbering and the order
        # in which ids are allocated in `cache` reproducible.
        names = sorted(self.entities)
        node_of = {}
        x = torch.zeros(len(names), 1)
        for node, name in enumerate(names):
            node_of[name] = node
            x[node][0] = cache.get_entity(name).index
        edge_index = torch.zeros(2, 2 * len(self.triples), dtype=torch.long)  # 2 columns per triple, one per direction
        for k, triple in enumerate(self.triples):
            fwd, bwd = 2 * k, 2 * k + 1
            h, t = node_of[triple.h], node_of[triple.t]
            edge_index[0][fwd], edge_index[1][fwd] = h, t  # head -> tail
            edge_index[0][bwd], edge_index[1][bwd] = t, h  # tail -> head
        return Data(x=x, edge_index=edge_index)
        

In [46]:
# Wrap a single record (index 457) to try the class out.
sample = WebNLGSample(web[457])

In [47]:
# Parsed triples of the sample built above.
sample.triples

[Triple(h='Adisham_Hall', r='country', t='Sri_Lanka')]

In [48]:
# Wrap every record of the split in a WebNLGSample.
all_records = list(map(WebNLGSample, web))

In [49]:
# Fresh id cache; it is reused for every graph built after this point.
cache = TripleCache()
# make_graph looks entities up in `cache`, so building record 2 first
# registers its entities before those of any other record.
g1 = all_records[2].make_graph(cache)

In [52]:
# One graph per sample, all looking entity ids up in the same `cache`.
graphs = list(map(lambda r: r.make_graph(cache), all_records))

In [16]:
# Two small 2 x 3 x 3 integer tensors for checking elementwise products and sums.
x = torch.as_tensor([
    [[1, 2, 3], [3, 2, 1], [10, 20, 30]],
    [[2, 3, 4], [4, 3, 2], [20, 30, 40]],
])
y = torch.as_tensor([
    [[4, 5, 6], [3, 2, 1], [6, 5, 4]],
    [[5, 6, 7], [7, 6, 5], [8, 9, 10]],
])


In [17]:
# Elementwise product of the two tensors.
torch.mul(x, y)

tensor([[[  4,  10,  18],
         [  9,   4,   1],
         [ 60, 100, 120]],

        [[ 10,  18,  28],
         [ 28,  18,  10],
         [160, 270, 400]]])

In [30]:
# Elementwise product, then summed over the first axis.
torch.sum(torch.mul(x, y), dim=0)

tensor([[ 14,  28,  46],
        [ 37,  22,  11],
        [220, 370, 520]])

tensor([[ 3,  1, 30],
        [ 4,  2, 40]])