## Creating the Knowledge graph

In [179]:
import numpy as np
import os
import torch
import pickle
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected, degree
from typing import List, Tuple, Optional, Dict

from dataclasses import dataclass

### The Wiki Data

In [180]:
# Raw WIKI dataset directory, and the destination for processed artifacts.
IN_PATH, OUT_PATH = "../data/raw/WIKI", "../data/processed/WIKI"

In [181]:
class WIKIGraphProcessor:
    """
    Reader for the raw WIKI files stored in `in_dir`.

    `out_dir` is stored for processed output; the methods below do not use it.
    """

    def __init__(self, in_dir: str, out_dir: str) -> None:
        self.in_dir = in_dir
        self.out_dir = out_dir

    def _load_quadruples(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load the quadruples stored in `file_name` (relative to `in_dir`).

        Each non-blank line holds whitespace-separated integers. The first four
        columns are (head, relation, tail, time). Any further columns are kept
        as-is in the returned array.

        Returns:
            quadruples: np.ndarray with one row per non-blank line.
            times: np.ndarray of the distinct values of column 3, sorted ascending.
        """
        file_path = os.path.join(self.in_dir, file_name)
        # `with` guarantees the handle is closed, even if parsing raises.
        with open(file_path, "r") as textfile:
            quadruples = [
                [int(value) for value in line.split()]
                for line in textfile
                if line.strip()  # skip blank lines (e.g. a trailing newline)
            ]

        timestamps = sorted({row[3] for row in quadruples})
        return np.array(quadruples), np.asarray(timestamps)

    def _get_network_stats(self) -> Tuple[int, int]:
        """
        Read the dataset sizes from `stat.txt` in `in_dir`.

        The first line holds whitespace-separated integers. The first two are
        returned; the rest are ignored. In this notebook they are consumed as
        (num_entities, num_relations).

        Returns:
            n_entities: int
            n_relations: int

        Raises:
            ValueError: if the first line has fewer than two fields.
        """
        STAT_FILENAME = "stat.txt"

        stat_path = os.path.join(self.in_dir, STAT_FILENAME)
        with open(stat_path, "r") as textfile:
            fields = textfile.readline().split()

        if len(fields) < 2:
            raise ValueError(
                f"{stat_path}: expected at least 2 fields on the first line, got {len(fields)}"
            )
        return int(fields[0]), int(fields[1])

In [182]:
""" 
# Initialize the processor
processor = WIKIGraphProcessor(IN_PATH, OUT_PATH)

# Load data
quadruples, times = processor._load_quadruples("train.txt")
n_entites, n_relations = processor._get_network_stats() 

"""

' \n# Initialize the processor\nprocessor = WIKIGraphProcessor(IN_PATH, OUT_PATH)\n\n# Load data\nquadruples, times = processor._load_quadruples("train.txt")\nn_entites, n_relations = processor._get_network_stats() \n\n'

In [183]:
@dataclass
class Quadruple:
    """A temporal fact: (source, relation, object) at a given timestamp."""

    source_id: int  # subject entity id
    relation_id: int  # relation id
    object_id: int  # object entity id
    timestamp: int  # time step of the fact


class DataLoader:
    """Reads quadruple files and dataset statistics from whitespace-separated text files."""

    def load_quadruples(self, file_name: str) -> Tuple[List[Quadruple], List[int]]:
        """
        Parse one quadruple per non-blank line of `file_name`.

        Each line holds whitespace-separated integers. The first four are
        (source_id, relation_id, object_id, timestamp); trailing columns are ignored.

        Returns:
            quadruples: the facts, in file order.
            times: the distinct timestamps, sorted ascending.

        Raises:
            ValueError: if a non-blank line has fewer than four fields or a
                non-integer field.
        """
        quadruples = []
        times = set()

        # `with` closes the file even if a line fails to parse.
        with open(file_name, "r") as textfile:
            for line_no, line in enumerate(textfile, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(fields) < 4:
                    raise ValueError(
                        f"{file_name}:{line_no}: expected at least 4 fields, got {len(fields)}"
                    )
                source_id, relation_id, object_id, timestamp = map(int, fields[:4])
                quadruples.append(
                    Quadruple(source_id, relation_id, object_id, timestamp)
                )
                times.add(timestamp)

        # Return a sorted list (as annotated) rather than a set, so callers
        # iterate timestamps in a deterministic, chronological order.
        return quadruples, sorted(times)

    def get_total_number(self, file_name: str) -> Tuple[int, int]:
        """
        Read (num_entities, num_relations) from the first line of `file_name`.

        The line must hold exactly three whitespace-separated integers; the
        third is ignored.
        """
        with open(file_name, "r") as textfile:
            num_entities, num_relations, _ = map(int, textfile.readline().split())
        return num_entities, num_relations

In [190]:
class GraphGenerator:
    """Builds a directed PyG `Data` snapshot graph from a list of quadruples."""

    def __init__(self, num_entities: int, num_relations: int):
        self.num_entities = num_entities
        # Stored for callers; not used when building graphs below.
        self.num_relations = num_relations

    def get_data_for_time_step(
        self, quadruples: List[Quadruple], time_step: int
    ) -> List[Quadruple]:
        """
        Return the quadruples whose timestamp equals `time_step`, in input order.

        This scans the full list on every call.
        """
        return [q for q in quadruples if q.timestamp == time_step]

    def generate_graph(self, data: List[Quadruple]) -> Data:
        """
        Build a directed graph with one edge per quadruple.

        Args:
            data: the quadruples of one snapshot (may be empty).

        Returns:
            Data with:
              edge_index: long tensor of shape (2, len(data)); column k is
                  (source_id, object_id) of data[k].
              edge_attr: float tensor of shape (len(data),) holding relation_id.
              num_nodes: self.num_entities, so isolated entities are included.
        """
        edge_pairs = [[q.source_id, q.object_id] for q in data]
        relation_ids = [q.relation_id for q in data]

        # reshape(-1, 2) keeps the (2, E) layout when `data` is empty; without
        # it an empty list would yield a 1-D tensor of shape (0,).
        edge_index = (
            torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
        )
        # NOTE(review): relation ids are categorical but stored as float;
        # kept as-is for compatibility with existing consumers.
        edge_attr = torch.tensor(relation_ids, dtype=torch.float)

        # Edges stay directed (source -> object). For an undirected graph, apply
        # `to_undirected` to edge_index/edge_attr; it does not take a Data object.
        return Data(
            edge_index=edge_index, edge_attr=edge_attr, num_nodes=self.num_entities
        )


@dataclass
class HistoricalRecord:
    """One past interaction of an entity, as kept by EntityHistoryCache."""

    relation_id: int  # relation of the interaction
    other_entity_id: int  # entity on the other side of the interaction
    timestamp: int  # time step at which it happened


class EntityHistoryCache:
    """
    Per-entity bounded history of HistoricalRecord entries.

    For each entity it keeps the newest `history_length` records (oldest first)
    and the timestamp of the most recently added record.
    """

    def __init__(self, history_length: int) -> None:
        self.history_length = history_length

        # entity_id -> newest `history_length` records, oldest first
        self.cache: Dict[int, List[HistoricalRecord]] = {}
        # entity_id -> timestamp of the latest record added for that entity
        self.timestamps: Dict[int, Optional[int]] = {}

        # Step counter advanced by increment_timestamp().
        self.current_timestamp = 0

    def add_to_history(
        self, entity_id: int, relation_id: int, other_entity_id: int, timestamp: int
    ) -> None:
        """Append an interaction for `entity_id`, dropping the oldest beyond `history_length`."""
        records = self.cache.setdefault(entity_id, [])
        records.append(HistoricalRecord(relation_id, other_entity_id, timestamp))
        self.timestamps[entity_id] = timestamp

        # Trim from the front so only the newest `history_length` records remain.
        excess = len(records) - max(self.history_length, 0)
        if excess > 0:
            del records[:excess]

    def increment_timestamp(self) -> None:
        """
        Advance the step counter by one.

        TODO: per-entity pruning relative to the new counter value was started
        here but never finished (it ended in an incomplete `while`). Specify and
        implement it; until then, only the counter advances.
        """
        self.current_timestamp += 1

    def get_history(self, entity_id: int) -> Tuple[List[HistoricalRecord], int]:
        """
        Return (records, latest_timestamp) for `entity_id`.

        Unknown entities yield ([], 0). The records list is a copy, so later
        add_to_history calls do not change a snapshot already returned.
        """
        return list(self.cache.get(entity_id, [])), self.timestamps.get(entity_id, 0)


class HistoryCache:
    """
    Bounded per-entity history of interactions, one snapshot per timestamp.

    Facts seen at the current timestamp are buffered with `add_to_cache`.
    `update_history` then moves every non-empty buffer into that entity's
    history as one snapshot, keeping at most `history_length` snapshots per
    entity (oldest dropped first). `get_history` only exposes committed
    snapshots, never the pending buffer.
    """

    def __init__(self, num_entities, history_length):
        self.num_entities = num_entities
        self.history_length = history_length
        # Committed snapshots per entity; each is an array of [relation_id, other_entity_id] rows.
        self.histories = [[] for _ in range(num_entities)]
        # Timestamp of each committed snapshot, parallel to `histories`.
        self.timestamps = [[] for _ in range(num_entities)]
        # Pending rows for the timestamp currently being processed.
        self.cache = [[] for _ in range(num_entities)]
        self.cache_timestamps = [None for _ in range(num_entities)]

    def update_history(self):
        """Commit every entity's pending buffer as a new snapshot and enforce `history_length`."""
        max_len = max(self.history_length, 0)
        for entity_id in range(self.num_entities):
            history = self.histories[entity_id]
            times = self.timestamps[entity_id]

            if len(self.cache[entity_id]) > 0:
                history.append(self.cache[entity_id].copy())
                times.append(self.cache_timestamps[entity_id])
                self.cache[entity_id] = []
                self.cache_timestamps[entity_id] = None

            # Appending first and trimming after avoids pop(0) on an empty list
            # (which crashed when history_length == 0).
            excess = len(history) - max_len
            if excess > 0:
                del history[:excess]
                del times[:excess]

    def add_to_cache(self, entity_id, relation_id, other_entity_id, timestamp):
        """Buffer one (relation_id, other_entity_id) row for `entity_id` at `timestamp`."""
        row = [[relation_id, other_entity_id]]
        if len(self.cache[entity_id]) == 0:
            self.cache[entity_id] = np.array(row)
        else:
            # concatenate builds a new array, so snapshots that were already
            # committed are never mutated in place.
            self.cache[entity_id] = np.concatenate((self.cache[entity_id], row), axis=0)
        self.cache_timestamps[entity_id] = timestamp

    def get_history(self, entity_id):
        """
        Return (snapshots, timestamps) committed for `entity_id`.

        Both lists are copies. The internal lists keep changing as later
        timestamps are committed; returning them directly would let stored
        results silently pick up future facts.
        """
        return list(self.histories[entity_id]), list(self.timestamps[entity_id])

In [185]:
# Input files inside IN_PATH.
STAT_FILENAME = "stat.txt"
TRAIN_FILENAME = "train.txt"

train_path, stat_path = (
    os.path.join(IN_PATH, file_name) for file_name in (TRAIN_FILENAME, STAT_FILENAME)
)

# Training quadruples with their timestamps, then the dataset sizes.
dataloader = DataLoader()
quadruples, times = dataloader.load_quadruples(train_path)
num_entities, num_relations = dataloader.get_total_number(stat_path)

In [186]:
# Snapshot builder, plus a per-entity history cache holding 10 snapshots.
# NOTE(review): `history_cache` appears unused in the cells shown here (the
# historical-context cell builds its own caches); verify before removing.
graph_generator = GraphGenerator(num_entities, num_relations)
history_cache = HistoryCache(num_entities, history_length=10)


# Create temporal snapshots
def _create_temporal_snapshots(
    quadruples: List[Quadruple], times: List[int], graph_generator: GraphGenerator
) -> Dict[int, Data]:
    """
    Build one graph per timestamp.

    Args:
        quadruples: all facts; grouped here by their `timestamp`.
        times: timestamps to build snapshots for (a list or a set); processed
            in ascending order.
        graph_generator: builds each snapshot via `generate_graph`.

    Returns:
        Mapping timestamp -> graph, with keys inserted in ascending order.
        A timestamp with no facts is built from an empty list.
    """
    # Group once in O(N) instead of rescanning every quadruple per timestamp
    # (as get_data_for_time_step does). Input order is preserved within each
    # group, so each snapshot's edge order is unchanged.
    by_time: Dict[int, List[Quadruple]] = {}
    for quadruple in quadruples:
        by_time.setdefault(quadruple.timestamp, []).append(quadruple)

    # Sorting makes the order deterministic even when `times` is a set.
    ordered_times = sorted(times)
    graph_dict = {}
    for idx, t in enumerate(ordered_times):
        if idx % 100 == 0:
            print(f"Processing {idx}/{len(ordered_times)}")
        graph_dict[t] = graph_generator.generate_graph(by_time.get(t, []))
    return graph_dict


# timestamp -> snapshot graph for the training split
temporal_snapshot_graphs = _create_temporal_snapshots(quadruples, times, graph_generator)

Processing 0/211
Processing 100/211
Processing 200/211


In [212]:
def _create_historical_context(
    quadruples: List[Quadruple], source_cache: HistoryCache, object_cache: HistoryCache
):
    """
    Collect, for every quadruple, the committed histories of its source and object.

    Quadruples are replayed in timestamp order. Whenever the timestamp changes,
    both caches commit the facts buffered for earlier timestamps. So with
    HistoryCache, a quadruple's history only contains facts from strictly
    earlier timestamps.

    Args:
        quadruples: facts, in any order.
        source_cache: cache keyed by source entity (rows: relation, object).
        object_cache: cache keyed by object entity (rows: relation, source).

    Returns:
        (s_history_data, s_history_data_t, o_history_data, o_history_data_t).
        These are four lists aligned with the input `quadruples`: element i is
        the snapshot list / timestamp list for quadruple i's source or object.
    """
    num_quadruples = len(quadruples)

    # Initialize history data containers
    s_history_data = [[] for _ in range(num_quadruples)]
    s_history_data_t = [[] for _ in range(num_quadruples)]
    o_history_data = [[] for _ in range(num_quadruples)]
    o_history_data_t = [[] for _ in range(num_quadruples)]

    # Visit in chronological order, but write results at the original index so
    # the outputs line up with the caller's list. sorted() is stable, so ties
    # keep their input order.
    order = sorted(range(num_quadruples), key=lambda idx: quadruples[idx].timestamp)

    latest_t = None
    for i in order:
        quadruple = quadruples[i]
        source_id = quadruple.source_id
        relation_id = quadruple.relation_id
        object_id = quadruple.object_id
        timestep = quadruple.timestamp

        if latest_t != timestep:
            print(f"New timestep: {timestep}, at index {i}")
            # Commit the facts buffered for the previous timestamp.
            # HistoryCache.update_history takes no arguments.
            source_cache.update_history()
            object_cache.update_history()
            latest_t = timestep

        source_cache.add_to_cache(source_id, relation_id, object_id, timestep)
        object_cache.add_to_cache(object_id, relation_id, source_id, timestep)

        # Copy the returned lists. The caches may keep mutating them as later
        # timestamps are committed, and each entry must reflect the history as
        # of this quadruple.
        s_hist, s_hist_t = source_cache.get_history(source_id)
        o_hist, o_hist_t = object_cache.get_history(object_id)
        s_history_data[i], s_history_data_t[i] = list(s_hist), list(s_hist_t)
        o_history_data[i], o_history_data_t[i] = list(o_hist), list(o_hist_t)

    return s_history_data, s_history_data_t, o_history_data, o_history_data_t


# One cache indexed by source entity and one by object entity, each keeping
# the last 10 committed snapshots.
source_cache, object_cache = (
    HistoryCache(num_entities=num_entities, history_length=10) for _ in range(2)
)

(
    s_history_data,
    s_history_data_t,
    o_history_data,
    o_history_data_t,
) = _create_historical_context(quadruples, source_cache, object_cache)

New timestep: 0, at index 0
New timestep: 1, at index 1513
New timestep: 2, at index 3026
New timestep: 3, at index 4538
New timestep: 4, at index 6055
New timestep: 5, at index 8146
New timestep: 6, at index 10239
New timestep: 7, at index 12333
New timestep: 8, at index 14433
New timestep: 9, at index 16524
New timestep: 10, at index 18620
New timestep: 11, at index 20718
New timestep: 12, at index 22822
New timestep: 13, at index 24914
New timestep: 14, at index 27010
New timestep: 15, at index 29106
New timestep: 16, at index 31211
New timestep: 17, at index 33321
New timestep: 18, at index 35429
New timestep: 19, at index 37561
New timestep: 20, at index 39689
New timestep: 21, at index 41813
New timestep: 22, at index 43947
New timestep: 23, at index 46078
New timestep: 24, at index 48210
New timestep: 25, at index 50342
New timestep: 26, at index 52474
New timestep: 27, at index 54609
New timestep: 28, at index 56747
New timestep: 29, at index 58882
New timestep: 30, at index 61