In [4]:
import gzip

def extract_first_n_lines(source_path, dest_path, n=1000000):
    with gzip.open(source_path, 'rt') as source_file:
        lines = [source_file.readline() for _ in range(n)]
    
    with gzip.open(dest_path, 'wt') as dest_file:
        dest_file.writelines(lines)

# Example usage
source_path = '../data/wikidata5m/wikidata5m_transductive.tar.gz'
dest_path = '../data/wikidata5m/wikidata5m_transductive_1m.tar.gz'
extract_first_n_lines(source_path, dest_path)



In [1]:
import gzip

In [2]:
with gzip.open('../data/wikidata5m/wikidata5m_transductive.tar.gz', 'rt') as f:
    transductive = f.readlines()
    print(transductive[1])
    print(len(transductive))

Q6719921	P31	Q11446

20624576


In [5]:
print(transductive[0])
print(transductive[-1])

wikidata5m_transductive_test.txt                                                                    0000644 0601751 0601751 00000314623 13756763533 016457  0                                                                                                    ustar   兆成                                                                                                                                                                                                                                                 Q7965079	P27	Q16

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          

In [3]:
with gzip.open('../data/wikidata5m/wikidata5m_transductive_100k.tar.gz', 'rt') as f:
    transductive = f.readlines()
    print(transductive[1])
    print(len(transductive))

Q6719921	P31	Q11446

100000


In [3]:
with gzip.open('../data/wikidata5m/wikidata5m_alias.tar.gz', 'rt', encoding='latin-1') as f:
    aliases = f.readlines()
    print(aliases[100])
    print(aliases[-100])
    print(len(aliases))
    
    # save first 1000 aliases to a file for testing
    # with open('../data/wikidata5m/wikidata5m_alias_1000000.txt', 'w') as f:
    #     for alias in aliases[:1000000]:
    #         f.write(alias)

Q5816691	shadiabad, kerman	Shadiabad, Kerman

P2348	time period	era	historic era	epoch	historical period	sports season	theatre season	legislative period	historic period

4814317


In [5]:
print(aliases[-2])

P2821	by-product	byproduct



In [6]:
ents = aliases[1:4813490]
edges = aliases[4813492:-2]

In [7]:
print(len(ents))
print(len(edges))

4813489
823


In [None]:
ent_ids = []
# get first and last entries that start with 'Q'



In [2]:
with gzip.open('../data/wikidata5m/wikidata5m_text.txt.gz', 'rt', encoding='utf-8') as f:
    corpus = f.readlines()
    num_tokens = 0
    for line in corpus:
        num_tokens += len(line.split())
    print(corpus[1])
    print(len(corpus))
    print(num_tokens)

Q1640233	The Monumentum Ancyranum (Latin 'Monument of Ancyra') or Temple of Augustus and Rome in Ancyra is an Augusteum in Ankara (ancient Ancyra), Turkey. The text of the Res Gestae Divi Augusti ("Deeds of the Divine Augustus") is inscribed on its walls, and is the most complete copy of that text. The temple is adjacent to the Hadji Bairam Mosque in the Ulus quarter.

4815483
388737145


In [3]:
for text in corpus[]

'Q12253838\tArantxa Iturbe Maiz, (born 9 June 1964) is a Basque journalist, announcer and writer in the Basque language.\n'

In [11]:
n_docs = None
n_docs = n_docs or len(corpus)
print(n_docs)

4815483


In [6]:
# print number of documents with more than 128 tokens
num_long_docs = 0
num_tokens = 0
for line in corpus:
    tokens = len(line.split())
    if tokens > 128:
        num_long_docs += 1
        num_tokens += tokens
print(num_long_docs)
print(num_tokens)

831897
188403345


In [7]:
import os
import multiprocessing
import joblib
import marisa_trie
import numpy as np
import torch
# import faiss
from transformers import AutoTokenizer, AutoModel

class EntityLinker:
    def __init__(self, alias_file, use_embeddings=False, entity_embeddings=None):
        self.trie = self.build_trie_from_aliases(alias_file)
        self.use_embeddings = use_embeddings
        if use_embeddings:
            # Setup device and move model to GPU if available
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            
            if entity_embeddings:
                self.entity_embeddings = entity_embeddings
            else:
                self.entity_embeddings = self.compute_entity_embeddings(alias_file)
                
            # Build Faiss index for efficient similarity search
            self.index = self.build_faiss_index()
            
    def process_aliases_batch(self, batch):
        aliases = []
        for line in batch:
            parts = line.strip().split('\t')
            aliases.extend(parts[1:])
        return aliases
    
    
    def build_trie_from_aliases(self, alias_file, save_dir=None):
        # Check if trie is already present on disk
        if save_dir:
            trie_path = os.path.join(save_dir, "trie.marisa")
            if os.path.exists(trie_path):
                return marisa_trie.Trie().load(trie_path)

        # Read the aliases file and split it into batches for parallel processing
        with gzip.open(alias_file, 'rt', encoding='latin-1') as file:
            lines = file.readlines()
        n_processes = multiprocessing.cpu_count()
        batch_size = len(lines) // n_processes
        batches = [lines[i:i + batch_size] for i in range(0, len(lines), batch_size)]

        # Process batches in parallel
        with multiprocessing.Pool(processes=n_processes) as pool:
            results = pool.map(self.process_aliases_batch, batches)
        
        # Merge results from all batches
        all_aliases = [alias for sublist in results for alias in sublist]

        # Build the marisa-trie from the merged list of aliases
        trie = marisa_trie.Trie(all_aliases)

        # Save the trie to disk
        if save_dir:
            trie.save(trie_path)

        return trie
    
    
    def compute_entity_embeddings(self, alias_file, save_dir=None):
        if save_dir:
            embeddings_path = os.path.join(save_dir, "embeddings.pkl")
            # Check if embeddings are already present on disk
            if os.path.exists(embeddings_path):
                with open(embeddings_path, 'rb') as f:
                    return joblib.load(f)

        entity_embeddings = {}
        with open(alias_file, 'r') as file:
            for line in file:
                parts = line.strip().split('\t')
                entity_id = parts[0]
                all_embeddings = []
                for alias in parts[1:]:
                    sentence = alias  # Since aliases are often short, we treat them as sentences
                    inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
                    with torch.no_grad():
                        outputs = self.model(**inputs)
                    alias_embedding = outputs.last_hidden_state[0].mean(dim=0).cpu().numpy()
                    all_embeddings.append(alias_embedding)
                entity_embeddings[entity_id] = np.mean(all_embeddings, axis=0)
                
        if save_dir:
            with open(embeddings_path, 'wb') as f:
                joblib.dump(entity_embeddings, f)
                
        return entity_embeddings
    
    
    def build_faiss_index(self):
        embedding_matrix = np.array(list(self.entity_embeddings.values())).astype('float32')
        dimension = embedding_matrix.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embedding_matrix)
        return index
    

    def compute_contextual_embedding(self, mention, full_text, start_pos, end_pos):
        # Extract the sentence containing the mention
        sentence_start = full_text.rfind('.', 0, start_pos) + 1  # Start after the last period before the mention
        sentence_end = full_text.find('.', end_pos)  # End at the next period after the mention
        sentence = full_text[sentence_start:sentence_end].strip()

        # Tokenize both the mention and the sentence
        mention_tokens = self.tokenizer.tokenize(mention)
        sentence_tokens = self.tokenizer.tokenize(sentence)

        # Find the start and end position of the mention in the sentence tokens
        mention_start_pos = None
        mention_end_pos = None
        for i in range(len(sentence_tokens) - len(mention_tokens) + 1):
            if sentence_tokens[i:i+len(mention_tokens)] == mention_tokens:
                mention_start_pos = i
                mention_end_pos = i + len(mention_tokens) - 1
                break

        # Convert tokens to inputs and pass through the model
        inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state[0]

        # Extract the embeddings corresponding to the mention and average them
        mention_embedding = torch.mean(embeddings[mention_start_pos:mention_end_pos+1], dim=0)

        return mention_embedding.cpu().numpy()
    

    def find_best_entity_match(self, mention_embedding, k=1):
        distances, indices = self.index.search(mention_embedding.reshape(1, -1), k)
        closest_index = indices[0][0]
        entity_id = list(self.entity_embeddings.keys())[closest_index]
        return entity_id, distances[0][0]
    

    def link_entities(self, document):
        tokens = document.split()  # Basic whitespace-based tokenization
        i = 0
        entities_found = []

        while i < len(tokens):
            longest_match = None
            longest_match_length = 0
            current_string = tokens[i].lower()

            # Check for entity match starting from the current token
            if current_string in self.trie:
                entity_id = self.trie[current_string]
                longest_match = (entity_id, i, i)
                longest_match_length = 1

            # Extend the match to subsequent tokens to find longer matches
            j = i + 1
            while j < len(tokens) and current_string in self.trie:
                current_string += " " + tokens[j].lower()
                if current_string in self.trie:
                    entity_id = self.trie[current_string]
                    longest_match = (entity_id, i, j)
                    longest_match_length = j - i + 1
                j += 1

            # If a match was found, add it to entities_found and skip the matched tokens
            if longest_match:
                entities_found.append(longest_match)
                i += longest_match_length
            else:
                i += 1

        # Convert token positions to character positions
        char_entities_found = []
        char_pos = 0
        token_index = 0
        for entity in entities_found:
            while token_index < entity[1]:
                char_pos += len(tokens[token_index]) + 1  # +1 for space
                token_index += 1
            start_pos = char_pos
            while token_index <= entity[2]:
                char_pos += len(tokens[token_index]) + 1
                token_index += 1
            end_pos = char_pos - 2  # -1 for last space, -1 to get the end of the word
            char_entities_found.append((entity[0], start_pos, end_pos))

        if self.use_embeddings:
            for idx, (entity_id, start, end) in enumerate(char_entities_found):
                mention = document[start:end+1]
                mention_embedding = self.compute_contextual_embedding(mention)
                best_match, confidence = self.find_best_entity_match(mention_embedding)
                # Replace entity ID based on embedding match and add confidence
                char_entities_found[idx] = (best_match, start, end, confidence)

        return char_entities_found


    def process_chunk(self, chunk):
        results = {}
        for document_id, document in chunk:
            entities = self.link_entities(document)
            results[document_id] = entities
        return results


    def link_entities_for_corpus(self, corpus, n_processes=None):
        if n_processes is None:
            n_processes = multiprocessing.cpu_count()

        with multiprocessing.Pool(processes=n_processes) as pool:
            chunks = self.chunk_corpus(corpus, n_processes)
            results_list = pool.map(self.process_chunk, chunks)

        # Combine results from all processes
        combined_results = {}
        for result in results_list:
            combined_results.update(result)

        return combined_results


    def chunk_corpus(self, corpus, n):
        """Divide the corpus into n chunks"""
        avg_len = len(corpus) // n
        chunks = []
        for i in range(0, len(corpus), avg_len):
            chunks.append(corpus[i:i + avg_len])
        return chunks

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
entity_linker = EntityLinker('../data/wikidata5m/wikidata5m_alias.tar.gz')

In [9]:
print(max(corpus[:10000], key=len))




In [None]:
corpus[-1]

In [10]:
entities = entity_linker.link_entities(max(corpus[:10000], key=len))

In [20]:
trump = entity_linker.link_entities("Donald Trump is the president of the United States.")
print(trump)

[(18455175, 0, 11), (1802, 13, 14), (6975261, 20, 28), (1894, 30, 31), (2958240, 37, 42)]


In [11]:
print(len(max(corpus[:10000], key=len).split()))
print(len(entities))
print(entities[:16])

1406
753
[(286143, 18, 21), (22131056, 27, 50), (2083738, 52, 57), (44965, 59, 61), (10949484, 63, 71), (17287804, 88, 103), (1894, 105, 106), (1613703, 119, 124), (2641965, 137, 142), (1894, 144, 145), (8912594, 147, 155), (1801, 178, 179), (7332280, 185, 193), (1894, 195, 196), (10602540, 198, 210), (2008177, 227, 232)]


In [45]:
import os
import tarfile
import multiprocessing
from functools import lru_cache
import igraph as ig

class GraphProcessor:
    def __init__(self, triples_path, save_dir):
        self.graph = self.build_graph_from_triples(triples_path)
        self.save_dir = save_dir

        # Load graph from disk if it exists, otherwise save it
        if save_dir is not None:
            if os.path.exists(os.path.join(self.save_dir, "full_graph.pickle")):
                self.graph = ig.Graph.Read_Pickle(os.path.join(self.save_dir, "full_graph.pickle"))
            else:
                self.graph.write_pickle(os.path.join(self.save_dir, "full_graph.pickle"))

    def build_graph_from_triples(self, tar_gz_path):
        """Build a graph from the provided triples file compressed as tar.gz."""

        triples = []

        # Open and extract the tar.gz file
        with tarfile.open(tar_gz_path, 'r:gz') as archive:
            for member in archive.getmembers():
                # Extract triples from each .txt file in the archive
                if member.name.endswith('.txt'):
                    with archive.extractfile(member) as file:
                        lines = file.readlines()
                        triples.extend([tuple(line.decode('utf-8').strip().split('\t')) for line in lines])

        # Extracting unique entities and relations from triples
        entities = list(set([s for s, _, _ in triples] + [t for _, _, t in triples]))

        # Create an empty directed graph
        g = ig.Graph(directed=True)

        # Add vertices to the graph
        g.add_vertices(entities)

        # Add edges to the graph
        for s, r, t in triples:
            g.add_edge(s, t, name=r)

        return g

    @lru_cache(maxsize=None) # unlimited cache size
    def get_k_hop_neighbors(self, node, k=2):
        """Retrieve k-hop neighbors for a given node."""
        neighbors = set()
        for hop in range(1, k+1):
            neighbors.update(self.graph.neighborhood(node, order=hop))
        return list(neighbors)

    def get_subgraph_for_entities(self, entities, k=2, max_nodes=256):
        """Retrieve a subgraph for a given list of entities."""
        all_neighbors = set()
        for entity in entities:
            all_neighbors.update(self.get_k_hop_neighbors(entity, k))
        subgraph = self.graph.subgraph(list(all_neighbors))
        pruned_subgraph = self.prune_by_degree(subgraph, k=k, max_nodes=max_nodes)
        return pruned_subgraph

    def prune_by_degree(self, subgraph, k=2, max_nodes=256):
        """Prune the subgraph based on node degree."""
        while len(subgraph.vs) > max_nodes:
            # Start with nodes at the outer boundary (i.e., k-hops away)
            to_delete_ids = [v.index for v in subgraph.vs if subgraph.degree(v) <= k]
            if not to_delete_ids:  # If we can't prune any more distant nodes, prune closer ones
                k -= 1
            if k == 0:  # If we can't prune based on degree, prune randomly
                to_delete_ids = [v.index for v in subgraph.vs][:len(subgraph.vs) - max_nodes]
            subgraph.delete_vertices(to_delete_ids)
        return subgraph
    
    def process_document_wrapper(self, args):
        return GraphProcessor.process_document(*args)

    def parallel_process_corpus(self, corpus, num_processes):
        num_processes = multiprocessing.cpu_count()
        with multiprocessing.Pool(num_processes) as pool:
            results = pool.map(self.process_document_wrapper, corpus)
        return results

In [46]:
kg_processor = GraphProcessor('../data/wikidata5m/wikidata5m_transductive.tar.gz', None)

: 

: 

In [32]:
test_entities = [line.strip().split('\t')[0] for line in transductive[1:32]]

In [33]:
subgraph = kg_processor.get_subgraph_for_entities(test_entities, k=2, max_nodes=256)

In [42]:
subg = subgraph.to_dict_list()
subg

([{'name': 'Q1223'},
  {'name': 'Q148'},
  {'name': 'Q5'},
  {'name': 'Q30'},
  {'name': 'Q11446'},
  {'name': 'Q3957'},
  {'name': 'Q183'},
  {'name': 'Q40348'},
  {'name': 'Q211005'}],
 [])

In [44]:
kg_processor.graph.get_adjacency().shape

(7957, 7957)