In [2]:
import pickle
from collections import namedtuple
from typing import List
import numpy as np
import random
from tqdm import tqdm_notebook as tqdm

# One statement as a 5-tuple: subject, predicate, object, qualifier predicate
# (qp) and qualifier entity (qe), in that positional order.
Quint = namedtuple('Quint', ['s', 'p', 'o', 'qp', 'qe'])

In [3]:
# Load the parsed statements from disk.
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only open a file that was produced by a trusted run of the parsing step.
with open('./parsed_raw_data.pkl', 'rb') as f:
    raw_data = pickle.load(f)

# Collect the vocabularies directly in sets so duplicates are never stored.
entity_set, predicate_set = set(), set()

for quint in raw_data:
    # Positions: 0 = subject, 1 = predicate, 2 = object,
    # 3 = qualifier predicate, 4 = qualifier entity (3 and 4 may be empty).
    entity_set.update((quint[0], quint[2]))
    if quint[4]:
        entity_set.add(quint[4])

    predicate_set.add(quint[1])
    if quint[3]:
        predicate_set.add(quint[3])

# Sort so both lists have the same order on every kernel run; the iteration
# order of a set of strings changes between interpreter sessions, which would
# make seeded random.choice() draws irreproducible.
# NOTE(review): sorting assumes all ids are mutually comparable (e.g. all
# strings such as 'Q636' / 'P27') -- confirm against the parsed data.
entities = sorted(entity_set)
predicates = sorted(predicate_set)

print(len(entities), len(predicates))

174951 659


In [4]:
def _draw_replacement(pool, current, max_tries=8):
    """Return a random element of ``pool`` that differs from ``current``.

    Cheap rejection sampling is tried first.  If every try returns
    ``current`` (only likely for very small pools), the choice is made among
    the elements that differ.  When no such element exists, ``current`` is
    returned because there is nothing else to pick.
    """
    for _ in range(max_tries):
        candidate = random.choice(pool)
        if candidate != current:
            return candidate
    alternatives = [item for item in pool if item != current]
    return random.choice(alternatives) if alternatives else current


def sample_negatives(quint: Quint, probs: List[float]) -> Quint:
    """Corrupt one slot of ``quint`` to produce a negative sample.

    Args:
        quint: The statement to corrupt, indexable as
            (subject, predicate, object, qualifier predicate, qualifier entity).
        probs: ``[p(s), p(r), p(o), p(q)]`` -- the probabilities of corrupting
            the subject, the predicate, the object, or the qualifier.
            Must contain four values that sum to 1.

    Returns:
        A new ``Quint`` in which the chosen slot was replaced by a random
        element of the module-level ``entities`` / ``predicates`` lists.
        For the qualifier slot: if the statement has a qualifier predicate,
        either the qualifier predicate or the qualifier entity is replaced
        (chosen by a fair coin flip); otherwise both are filled with random
        values.

    Raises:
        AssertionError: If ``probs`` does not have four entries or does not
            sum to 1 (within floating-point tolerance).

    Notes:
        * The replacement is drawn so that it differs from the value it
          replaces, so the result is never the input statement itself.
        * The result is NOT checked against the full set of known statements,
          so it can still coincide with a different true statement.
        * Randomness comes from both ``random`` and ``np.random``; seed both
          for reproducible output.
    """
    # Check the length first so a malformed list yields the clearer message.
    assert len(probs) == 4, "probs must be [p(s), p(r), p(o), p(q)]"
    # Tolerant comparison: exact `== 1.0` rejects valid distributions whose
    # float sum is off by one ulp (e.g. 0.9999999999999999).
    assert np.isclose(np.sum(probs), 1.0, rtol=0.0, atol=1e-8), "probs must sum to 1"

    s, p, o, qp, qe = quint[0], quint[1], quint[2], quint[3], quint[4]
    slot = np.random.choice(["s", "p", "o", "q"], p=probs)

    if slot == "s":
        s = _draw_replacement(entities, s)
    elif slot == "p":
        p = _draw_replacement(predicates, p)
    elif slot == "o":
        o = _draw_replacement(entities, o)
    else:  # slot == "q"
        if qp:
            # Existing qualifier: corrupt exactly one of its two parts.
            if np.random.random() > 0.5:
                qp = _draw_replacement(predicates, qp)
            else:
                qe = _draw_replacement(entities, qe)
        else:
            # No qualifier present: attach a completely random one.
            qp = random.choice(predicates)
            qe = random.choice(entities)

    return Quint(s=s, p=p, o=o, qp=qp, qe=qe)
  

In [5]:
# Seed both generators so the sampled negatives are reproducible:
# sample_negatives draws from `random` as well as from `np.random`.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Corruption probabilities in the order [p(s), p(r), p(o), p(q)];
# 0.0 for the predicate means predicates are never corrupted.
probs = [0.3, 0.0, 0.3, 0.4]
q_neg = sample_negatives(raw_data[0], probs)
print(q_neg)

Quint(s='Q636', p='P27', o='Q145', qp='P483', qe='Q970873')


In [108]:
# Sanity check: draw 1000 slots with `probs` and tabulate how often each
# slot name comes up.
l = np.random.choice(["s", "p", "o", "q"], size=1000, p=probs)
print(l[0])
unique, counts = np.unique(l, return_counts=True)
{slot: occurrences for slot, occurrences in zip(unique, counts)}

s


{'o': 311, 'q': 390, 's': 299}

In [6]:
# Draw one corrupted statement per original statement, using `probs`.
negative_samples = [sample_negatives(positive, probs) for positive in tqdm(raw_data)]

HBox(children=(IntProgress(value=0, max=389937), HTML(value='')))




In [None]:
# Count sampled "negatives" that are in fact known statements.
# Build the lookup set once: testing `n in raw_data` against the list rescans
# up to len(raw_data) items for every sample, which is quadratic overall.
# tuple() normalises each item so a Quint hashes and compares equal to it.
known_quints = set(map(tuple, raw_data))

count = 0
for n in tqdm(negative_samples):
    if n in known_quints:
        print(n)
        count += 1

print(f"{count} / {len(raw_data)} are not unique negatives")

HBox(children=(IntProgress(value=0, max=389937), HTML(value='')))

Quint(s='Q35869', p='P2293', o='Q18027836', qp='P459', qe='Q1098876')
Quint(s='Q5608', p='P21', o='Q6581097', qp=None, qe=None)
Quint(s='Q296698', p='P21', o='Q6581097', qp=None, qe=None)
Quint(s='Q62547', p='P1411', o='Q131520', qp='P805', qe='Q740425')
Quint(s='Q714845', p='P106', o='Q486748', qp=None, qe=None)
Quint(s='Q489831', p='P106', o='Q8246794', qp=None, qe=None)
Quint(s='Q459384', p='P31', o='Q5', qp=None, qe=None)
Quint(s='Q3093', p='P150', o='Q283344', qp=None, qe=None)
Quint(s='Q19570', p='P161', o='Q242584', qp=None, qe=None)
Quint(s='Q19570', p='P161', o='Q1198897', qp=None, qe=None)
Quint(s='Q390120', p='P495', o='Q30', qp=None, qe=None)
Quint(s='Q80135', p='P166', o='Q2329480', qp='P1686', qe='Q261140')
Quint(s='Q2831', p='P106', o='Q28389', qp=None, qe=None)
Quint(s='Q310324', p='P31', o='Q5', qp=None, qe=None)
Quint(s='Q122113', p='P462', o='Q22006653', qp=None, qe=None)
Quint(s='Q108510', p='P1196', o='Q3739104', qp=None, qe=None)
Quint(s='Q525', p='P398', o='Q1364