In [None]:
!pip install faiss-gpu
!pip install sentence_transformers
!wget http://74.82.28.99:9000/cskg/data.tgz
!tar -zxvf data.tgz

: 

In [1]:

!pwd

/notebooks/cskg


If you want to re-index embeddings. run build_index_db()

In [2]:
from functools import partial
import pickle
import re
from secrets import randbelow
from typing import Callable, List, Tuple
import csv

import faiss
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

class Vocab:
    def __init__(self, words) -> None:
        self.idx_to_word = words
        self.word_to_idx = {word: idx for idx, word in enumerate(words)}

In [4]:
class CSKG_EMBEDDINGS():
    
    def __init__(self) -> None:
        print("init")
        #self.embedding_file = 'cskg_embeddings.csv'
        self.model = SentenceTransformer('all-mpnet-base-v2')
    
    def read_embedding_file(self, file) -> Tuple[Vocab, np.ndarray]:
        with open(file, 'r') as f:
            vector_dim = len(next(f).split(",\"")[1].split(','))
            file_len = len(f.readlines())
            shape = tuple([file_len + 1,vector_dim])
            print(shape)
            embeddings = np.zeros(shape, dtype=np.float32)
        with open(self.embedding_file, 'r') as d:
            words = []
            for i, line in tqdm(enumerate(d), total=shape[0]):
                embedding = line.split("text_embedding,")[1].split("\"")[1].split(',')
                word = line.split("text_embedding,")[0].split(",")[0]
                embedding = np.array([float(x) for x in embedding])
                words.append(word)
                embeddings[i] = embedding
        self.words = words
        vocab = Vocab(words)
        
        return (vocab, embeddings)


    def build_index_db(self, metric: str, embeddings: np.ndarray, vocab):

        if metric == 'cosine':
            index = faiss.IndexFlatIP(embeddings.shape[-1])
        elif metric == 'l2':
            index = faiss.IndexFlatL2(embeddings.shape[-1])
        else:
            raise ValueError(f'Bad metric: {metric}')
        index.add(embeddings)
        faiss.write_index(index,"data/vector.index")  # save the index to disk
        self.vocab = vocab
        self.index = index

        with open("data/vocab.pic", 'wb') as f:
                pickle.dump(vocab, f, protocol=4)
        csv.field_size_limit(512000)
        rel = []
        with open("data/cskg_sentences1.tsv") as file:
            tsv_file = csv.reader(file, delimiter="\t")
            for line in tsv_file:
                rel.append(line[3])
        with open("data/relations.pic", 'wb') as f:
            pickle.dump(rel, f, protocol=4)
    
    def load_index_db(self):
        with open("data/vocab.pic", 'rb') as f:
            self.vocab = pickle.load(f)
        with open("data/relations.pic", 'rb') as f:
            self.relations = pickle.load(f)
        self.index = faiss.read_index("data/vector.index")
        # print("VOCAB LOADED, size: ", len(self.vocab.))
        # print("RELATIONS LOADED, size: ", len(self.relations))
        print("INDEX Loaded")
    
    def parse_relations(self, rel):
        isa_relations = []
        description_relations = []
        property_values_relations = []
        for  relations in rel.split('+'):
            if 'isa' in relations:
                isa_relations = relations.split('->')[0].split('isa(')
            if 'description(' in relations:
                description_relations = relations.split('->')[0].split('description(')
            if 'property_values(' in relations:
                property_values_relations = relations.split('->')[0].split('property_values(')
            #print(relations)
            if (len(isa_relations) > 0):
                for isa in isa_relations[1].split(','):
                    print("----------->ISA ",isa)
            if (len(description_relations) > 0):
                for desc in description_relations[1].split(','):
                    print("------------>DR ",desc)
            if (len(property_values_relations) > 0):
                for pv in property_values_relations[1].split(','):
                    print("----------->PVR ",pv)
        #input()
    
    def query_for_events(self, query, topk):
        query = "It a at:" + query.replace(" ", " at:")
        print(query)
        query=np.array([self.model.encode(query)])
        #faiss.normalize_L2(query)
        scores, candidate_ids = self.index.search(query, topk)
        scores = scores.flatten()
        candidate_ids = candidate_ids.flatten()
        top_k_indices = np.argsort(scores)[:topk]
        scores = scores[top_k_indices]
        candidate_ids = candidate_ids[top_k_indices]
        
        for candidate_id, score in zip(np.nditer(candidate_ids), np.nditer(scores)):
            candidate = self.vocab.idx_to_word[candidate_id]
            relation = self.relations[candidate_id]
            # print(candidate, relation)
            # print(candidate[0:5])
            # if qtype == 'at' and candidate[0:3] == 'at:':
            #     print("Candidate: ",candidate, score)
            #     relation = relation.replace('\\\'','')
            #     cskg_emb.parse_relations(relation)
            #if qtype == 'cn' and candidate[0:5] == '/c/en':
            print("Candidate: ",candidate, score)
            relation = relation.replace('\\\'','')
            self.parse_relations(relation)

    def query_for_concepts(self, query, topk):
        query = "It is a /c/en/" + query.replace(" ", " /c/en/")
        print(query)
        query=np.array([self.model.encode(query)])
        #faiss.normalize_L2(query)
        scores, candidate_ids = self.index.search(query, topk)
        scores = scores.flatten()
        candidate_ids = candidate_ids.flatten()
        top_k_indices = np.argsort(scores)[:topk]
        scores = scores[top_k_indices]
        candidate_ids = candidate_ids[top_k_indices]
        
        for candidate_id, score in zip(np.nditer(candidate_ids), np.nditer(scores)):
            candidate = self.vocab.idx_to_word[candidate_id]
            relation = self.relations[candidate_id]
            # print(candidate, relation)
            # print(candidate[0:5])
            # if qtype == 'at' and candidate[0:3] == 'at:':
            #     print("Candidate: ",candidate, score)
            #     relation = relation.replace('\\\'','')
            #     cskg_emb.parse_relations(relation)
            #if qtype == 'cn' and candidate[0:5] == '/c/en':
            print("Candidate: ",candidate, score)
            relation = relation.replace('\\\'','')
            self.parse_relations(relation)

In [5]:
cskg_emb = CSKG_EMBEDDINGS()
cskg_emb.load_index_db()

init
INDEX Loaded


Run "Events" Query

In [12]:
print(cskg_emb.query_for_events('boat in the sea', 7))

It a at:boat at:in at:the at:sea
Candidate:  at:personx_drops_anchor 0.7336569
----------->PVR  at:xAttr /c/en/active
----------->PVR  at:xAttr /c/en/clumsy
----------->PVR  at:xAttr /c/en/dominant
----------->PVR  at:xAttr /c/en/noisy
----------->PVR  at:xAttr /c/en/responsible
----------->PVR  at:xAttr /c/en/willful
----------->PVR  at:xEffect at:person_x_brings_anchor_back_up
----------->PVR  at:xEffect at:person_x_rests
----------->PVR  at:xEffect at:they_jump_onto_the_jetty_and_tie_the_boat_up
----------->PVR  at:xEffect at:they_steady_the_boat_to_get_off
----------->PVR  at:xEffect at:they_turn_the_boat_engine_off
----------->PVR  at:xEffect at:to_board_land
----------->PVR  at:xEffect at:to_secure_boat
----------->PVR  at:xIntent at:to_slow_a_boat
----------->PVR  at:xNeed at:go_signal_from_superior
----------->PVR  at:xNeed at:to_be_on_a_boat
----------->PVR  at:xNeed at:to_be_on_a_watercraft
----------->PVR  at:xNeed at:to_make_sure_it_is_safe
----------->PVR  at:xReact /c/en/

In [10]:
print(cskg_emb.query_for_concepts('boat_in_the_sea', 10))

It is a /c/en/boat_in_the_sea
Candidate:  /c/en/pinnace 0.4705486
----------->ISA  \/c/en/ships_boat/n\"
Candidate:  /c/en/dinghy 0.4705486
----------->ISA  \/c/en/ships_boat/n\"
Candidate:  /c/en/longboat 0.4705486
----------->ISA  \/c/en/ships_boat/n\"
Candidate:  /c/en/eight_person_inflatable_boat/n 0.531604
----------->ISA  /c/en/inflatable_boat/n
Candidate:  /c/en/three_person_inflatable_boat/n 0.531604
----------->ISA  /c/en/inflatable_boat/n
Candidate:  /c/en/zodiac_f470 0.5440282
----------->ISA  /c/en/powered_rigid_inflatable_boat
Candidate:  /c/en/sail_boat 0.54575664
----------->ISA  /c/en/boat
----------->ISA  /c/en/boat
----------->PVR  /r/HasFirstSubevent /c/en/get_in_boat
Candidate:  /c/en/german_submarine 0.54818875
----------->ISA  /c/en/u_boat
Candidate:  /c/en/harbor 0.56792694
----------->ISA  /c/en/ore/n
----------->ISA  /c/en/ore/n
----------->PVR  /r/AtLocation /c/en/boat
Candidate:  /c/en/canoeist/n/wn/person 0.5684298
----------->ISA  /c/en/boat/v/wn/navigation

Run "concepts" query 

In [9]:
print(cskg_emb.query_for_concepts('george_washington', 10))

It is a /c/en/george_washington
Candidate:  /c/en/first_us_president 0.37781304
----------->ISA  /c/en/george_washington
Candidate:  /c/en/interstate_highway_system 0.5384255
----------->ISA  \/c/en/dwight_eisenhowers_legacy\"
Candidate:  /c/en/national_academy_of_sciences 0.573247
----------->ISA  /c/en/historic_place/n
----------->ISA  /c/en/historic_place/n
----------->PVR  /r/AtLocation /c/en/washington_d.c
Candidate:  /c/en/washington_monument 0.6633189
----------->ISA  /c/en/historic_place/n
----------->ISA  /c/en/obelisk
----------->ISA  /c/en/historic_place/n
----------->ISA  /c/en/obelisk
----------->PVR  /r/AtLocation /c/en/washington_d.c
Candidate:  /c/en/abe_lincoln 0.6636475
----------->ISA  /c/en/president_of_us
Candidate:  /c/en/argentine_president/n 0.68949616
----------->ISA  /c/en/national_president/n
Candidate:  /c/en/battleground 0.69416463
----------->ISA  /c/en/national_cemetery_in_washington_d.c
Candidate:  /c/en/washington_state_song 0.70237666
----------->ISA  