# Installations

In [1]:
# package installer
import importlib.util
import subprocess

def install_if_missing(package_name, import_name=None):
    """Install ``package_name`` with pip unless it is already importable.

    Args:
        package_name: Distribution name passed to ``pip install``.
        import_name: Module name used to check for an existing install.
            Defaults to ``package_name``; set it when the two differ
            (e.g. ``scikit-learn`` is imported as ``sklearn``).

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    import sys

    module_name = import_name or package_name
    if importlib.util.find_spec(module_name) is not None:
        print(f"{package_name} is already installed.")
        return
    # Run pip through the *current* interpreter so the package lands in the
    # environment this kernel/script imports from; a bare "pip" on PATH may
    # belong to a different Python installation.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
    # Let the import system see the freshly installed package without a restart.
    importlib.invalidate_caches()
    print(f"{package_name} installed.")
# Make sure tensorboardX is importable before the rest of the notebook runs.
install_if_missing("tensorboardX")

tensorboardX installed.


# Dataloader

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch

from torch.utils.data import Dataset
import numpy as np
import random
import torch
import time


def list2tuple(l):
    """Recursively convert a (possibly nested) list into nested tuples.

    Only elements whose type is exactly ``list`` are converted; every other
    element (including tuples and list subclasses) is kept as-is.
    """
    converted = []
    for item in l:
        converted.append(list2tuple(item) if type(item) is list else item)
    return tuple(converted)


def tuple2list(t):
    """Recursively convert a (possibly nested) tuple into nested lists.

    Only elements whose type is exactly ``tuple`` are converted; lists and
    tuple subclasses (e.g. namedtuples) are left untouched.
    """
    return [tuple2list(item) if type(item) is tuple else item for item in t]


def flatten(l):
    """Flatten a nested tuple into a flat list of its non-tuple leaves.

    A non-tuple argument is returned wrapped in a one-element list, so
    ``flatten(x) == [x]``. Lists and other containers count as leaves and
    are not descended into. Leaves are collected into one list, which is
    linear in their number; the previous ``sum(..., [])`` form re-copied the
    accumulator on every concatenation.
    """
    leaves = []

    def _collect(node):
        # Depth-first, left-to-right, so leaf order matches the nesting order.
        if isinstance(node, tuple):
            for child in node:
                _collect(child)
        else:
            leaves.append(node)

    _collect(l)
    return leaves


def parse_time():
    """Return the current local time formatted as ``YYYY.MM.DD-HH:MM:SS``."""
    now = time.localtime()
    return time.strftime("%Y.%m.%d-%H:%M:%S", now)


def set_global_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    ``torch.cuda.manual_seed`` seeds the current CUDA device; PyTorch
    documents it as safe to call (silently ignored) when CUDA is unavailable.
    """
    # The RNGs below are independent, so the seeding order does not matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Ask cuDNN to use deterministic algorithms for reproducible runs.
    torch.backends.cudnn.deterministic = True


def eval_tuple(arg_return):
    """Evaluate a tuple string into a tuple.

    - A ``tuple`` argument is returned unchanged.
    - A string starting with ``(`` or ``[`` has its outer brackets removed,
      is split on commas, and each stripped item is ``eval``-ed. Items that
      fail to evaluate are kept as their stripped string. Blank items (e.g.
      from a trailing comma in ``"(1, )"``) and items evaluating to ``""``
      are skipped. Nested brackets are not supported.
    - Any other string is ``eval``-ed as a whole.

    WARNING: this uses ``eval``; only pass trusted input (e.g. your own CLI
    arguments), never external data.
    """
    if isinstance(arg_return, tuple):
        return arg_return
    if arg_return[0] not in ["(", "["]:
        return eval(arg_return)
    items = []
    for raw in arg_return[1:-1].split(","):
        token = raw.strip()
        if not token:
            # Trailing comma / empty slot: previously leaked as " " elements.
            continue
        try:
            item = eval(token)
        except Exception:
            # Not a Python expression (e.g. a bare word): keep it as text.
            item = token
        if item == "":
            continue
        items.append(item)
    return tuple(items)


def flatten_query(queries):
    """Pair every query with its query structure.

    Args:
        queries: mapping ``query_structure -> iterable of queries``.

    Returns:
        A flat list of ``(query, query_structure)`` tuples, grouped by
        structure in the mapping's iteration order.
    """
    return [
        (query, query_structure)
        for query_structure, structure_queries in queries.items()
        for query in structure_queries
    ]


class TestDataset(Dataset):
    """Evaluation dataset that scores each query against every entity.

    Args:
        queries: list of ``(query, query_structure)`` pairs.
        nentity: number of entities; each item's candidate set is
            ``0 .. nentity - 1``.
        nrelation: number of relations (stored, not used here).
    """

    def __init__(self, queries, nentity, nrelation):
        self.queries = queries
        self.nentity = nentity
        self.nrelation = nrelation
        self.len = len(queries)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(all_entity_ids, flat_query, query, query_structure)``."""
        entry = self.queries[idx]
        query, query_structure = entry[0], entry[1]
        candidates = torch.arange(self.nentity, dtype=torch.long)
        return candidates, flatten(query), query, query_structure

    @staticmethod
    def collate_fn(data):
        """Stack the candidate tensors; keep the other fields as per-item lists."""
        candidates = torch.stack([item[0] for item in data], dim=0)
        flat_queries = [item[1] for item in data]
        raw_queries = [item[2] for item in data]
        structures = [item[3] for item in data]
        return candidates, flat_queries, raw_queries, structures


class TrainDataset(Dataset):
    """Training dataset: one sampled answer plus filtered negatives per query.

    Args:
        queries: list of ``(query, query_structure)`` pairs.
        nentity: number of entities; negatives are drawn from ``[0, nentity)``.
        nrelation: number of relations (stored, not used here).
        negative_sample_size: number of negative entity ids per item.
        answer: mapping from query to its answer entity ids.
    """

    def __init__(self, queries, nentity, nrelation, negative_sample_size, answer):
        self.len = len(queries)
        self.queries = queries
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.count = self.count_frequency(queries, answer)
        self.answer = answer

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(positive, negatives, subsampling_weight, flat_query, structure)``."""
        query = self.queries[idx][0]
        query_structure = self.queries[idx][1]
        answers = list(self.answer[query])
        tail = np.random.choice(answers)
        # weight = 1 / sqrt(count), with count = 4 + |answers| (see count_frequency).
        subsampling_weight = torch.sqrt(1 / torch.Tensor([self.count[query]]))
        negative_sample = torch.from_numpy(self._sample_negatives(answers))
        positive_sample = torch.LongTensor([tail])
        return positive_sample, negative_sample, subsampling_weight, flatten(query), query_structure

    def _sample_negatives(self, answers):
        """Draw ``negative_sample_size`` uniform entity ids that are not answers.

        ``answers`` may be any iterable (set, list, array). It is converted to
        an array first: NumPy's membership tests treat a Python ``set`` as a
        single opaque object, so passing the set directly filtered nothing.
        The candidates can contain duplicates, so ``assume_unique`` must not
        be used. Loops forever if every entity id is an answer.
        """
        answer_ids = np.asarray(list(answers))
        chunks = []
        collected = 0
        while collected < self.negative_sample_size:
            # Oversample 2x so usually one round is enough after filtering.
            candidates = np.random.randint(self.nentity, size=self.negative_sample_size * 2)
            candidates = candidates[np.isin(candidates, answer_ids, invert=True)]
            chunks.append(candidates)
            collected += candidates.size
        if not chunks:
            # negative_sample_size == 0: np.concatenate([]) would raise.
            return np.empty(0, dtype=np.int64)
        return np.concatenate(chunks)[:self.negative_sample_size]

    @staticmethod
    def collate_fn(data):
        """Concatenate positives/weights, stack negatives, keep the rest as lists."""
        positive_sample = torch.cat([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        subsample_weight = torch.cat([_[2] for _ in data], dim=0)
        query = [_[3] for _ in data]
        query_structure = [_[4] for _ in data]
        return positive_sample, negative_sample, subsample_weight, query, query_structure

    @staticmethod
    def count_frequency(queries, answer, start=4):
        """Map each query to ``start + number of its answers``."""
        return {query: start + len(answer[query]) for query, _ in queries}


class SingledirectionalOneShotIterator(object):
    """Endless iterator that cycles over ``dataloader`` pass after pass.

    ``step`` counts how many batches have been requested via ``next()``.
    """

    def __init__(self, dataloader):
        self.iterator = self.one_shot_iterator(dataloader)
        self.step = 0

    def __iter__(self):
        # Complete the iterator protocol so the object also works in ``for``
        # loops and itertools helpers, not only with ``next()``.
        return self

    def __next__(self):
        self.step += 1
        return next(self.iterator)

    @staticmethod
    def one_shot_iterator(dataloader):
        """Yield items from ``dataloader`` forever, restarting it after each pass.

        If a full pass yields nothing, the generator stops (the caller gets
        ``StopIteration``) instead of spinning forever on an empty loader.
        """
        while True:
            produced = False
            for data in dataloader:
                produced = True
                yield data
            if not produced:
                return

# Models

### **Formulating Logical Operations Using the Laplace Distribution with Reparameterization**  

Since we are using **Laplace-distributed embeddings** with the **Reparameterization Trick**, we need to define how **basic logical operations (conjunction, disjunction, negation, and projection)** work in this probabilistic embedding space.  

Each entity and relation is represented as a **Laplace-distributed random variable**:  

$$E = \mathcal{L}(\mu, b)$$  
where:  
- $\mu$ is the **mean** embedding.  
- $b$ is the **scale (uncertainty)**.  
- We sample as:  
  $$z = -\text{sign}(u) \cdot \log(1 - 2|u|), \quad u \sim \text{Uniform}(-0.5, 0.5)$$  
  $$x = \mu + b \cdot z$$  




## **4. Negation Operation (NOT Operation)**  
To negate an entity embedding $\mathcal{L}(E) = (\mu, b)$, we define:

$$\mathcal{L}(\neg E) = \mathcal{L}(-\mu, b + \lambda)$$  

🔹 **Intuition**:  
- The **mean is negated** to push it in the opposite direction in embedding space.  
- The **uncertainty increases by $\lambda$** (a tunable hyperparameter; the implementation below uses a fixed $\lambda = 0.07$) to reflect the fact that negation introduces **more ambiguity**.  

📌 **Example Use Case**:  
- **"Not a city"** should **move the embedding away** from the "city" space while increasing uncertainty.  

---

## **5. Difference Operation (Set Difference)**  
If we want to express **"A but not B"**, we can define:

$$\mathcal{L}(E_1 - E_2) = \mathcal{L}(\mu_1, b_1 + b_2)$$  

🔹 **Intuition**:  
- The **mean stays the same**, but the **uncertainty increases**, since we are removing some knowledge.  

📌 **Example Use Case**:  
- **"Cities in Europe but NOT in France"** should **increase uncertainty** about which cities qualify.  

---

## **6. Comparison with Other Embedding Approaches**  
| **Operation** | **Laplace-based Formulation** | **GammaE (Gamma Distribution)** | **BetaE (Beta Distribution)** |
|--------------|----------------------------|--------------------------|--------------------------|
| **Projection** | $\mu_h + \mu_r, b_h + b_r$ | $\alpha_r \cdot \alpha_h, \beta_r \cdot \beta_h$ | Learned MLP over the concatenated $(\alpha, \beta)$ and relation embedding |
| **Conjunction (AND)** | Weighted mean, reduced uncertainty | Min of Gamma parameters | Attention-weighted average of $\alpha$ and $\beta$ (a weighted product of Beta PDFs) |
| **Disjunction (OR)** | Mean of means, increased uncertainty | Max of Gamma parameters | Via De Morgan's law, or by rewriting the query into disjunctive normal form (DNF) |
| **Negation (NOT)** | $-\mu, b + \lambda$ | Inversion of shape parameters | Reciprocal of the parameters: $(1/\alpha, 1/\beta)$ |

📌 **Why Laplace?**  
- More interpretable than **Gamma/Beta**, since it models **absolute differences**.  
- **Smooth reparameterization** avoids numerical issues.  
- **Handles logical queries better** than traditional deterministic embeddings.

---



In [3]:
# imports 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
import pickle
import math
import collections
import itertools
import time
from tqdm import tqdm
import os

## **Conjunction Operation (Intersection of Two Entities)**  
For two Laplace embeddings $\mathcal{L}(E_1) = (\mu_1, b_1)$ and $\mathcal{L}(E_2) = (\mu_2, b_2)$, their **conjunction (AND operation)** is:

$$\mathcal{L}(E_1 \cap E_2) = \mathcal{L}\left(\frac{b_1 \mu_2 + b_2 \mu_1}{b_1 + b_2}, \frac{b_1 b_2}{b_1 + b_2}\right)$$  

🔹 **Intuition**:  
- The **new mean** is a **weighted average** based on uncertainty (more confident entities have more influence).  
- The **new scale (uncertainty) decreases**, meaning intersection leads to **more confident knowledge**.  

📌 **Example Use Case**:  
- **"Paris is in France"** $\cap$ **"Paris is the capital of a country"**  
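- The result is a **more certain** embedding, concentrated on the entities that satisfy both conditions (here, **Paris**).

Note: the `LaplaceIntersection` module below learns the per-dimension weights for the means and scales with attention, instead of using the closed-form weights above.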


In [4]:
class LaplaceIntersection(nn.Module):
    """Attention-based intersection (AND) of several Laplace embeddings.

    Two independent 2-layer ReLU MLPs score each conjunct from its
    concatenated ``(mu, b)`` features. A softmax over the conjunct axis
    (dim 0) turns the scores into per-dimension weights. The output mean and
    scale are the weighted sums of the input means and scales, and the scale
    is clamped to ``[1e-4, 1.0]``.
    """

    def __init__(self, dim):
        super(LaplaceIntersection, self).__init__()
        self.dim = dim
        # Creation order (mu1, mu2, b1, b2) is fixed so seeded runs initialise identically.
        self.layer_mu1 = nn.Linear(2 * dim, dim)
        self.layer_mu2 = nn.Linear(dim, dim)
        self.layer_b1 = nn.Linear(2 * dim, dim)
        self.layer_b2 = nn.Linear(dim, dim)
        for layer in (self.layer_mu1, self.layer_mu2, self.layer_b1, self.layer_b2):
            nn.init.xavier_uniform_(layer.weight)

    @staticmethod
    def _attention(hidden_layer, score_layer, features):
        """Per-dimension weights from a ReLU MLP, normalised over conjuncts (dim 0)."""
        return F.softmax(score_layer(F.relu(hidden_layer(features))), dim=0)

    def forward(self, mu_embeddings, b_embeddings):
        """
        Intersect the stacked conjuncts.

        Inputs:
          - mu_embeddings: (num_conj, batch_size, dim) means.
          - b_embeddings: (num_conj, batch_size, dim) scales.

        Outputs:
          - mu_out: (batch_size, dim) attention-weighted sum of the means.
          - b_out: (batch_size, dim) attention-weighted sum of the scales,
            clamped to [1e-4, 1.0].
        """
        features = torch.cat([mu_embeddings, b_embeddings], dim=-1)
        weights_mu = self._attention(self.layer_mu1, self.layer_mu2, features)
        weights_b = self._attention(self.layer_b1, self.layer_b2, features)

        mu_out = (weights_mu * mu_embeddings).sum(dim=0)
        b_out = (weights_b * b_embeddings).sum(dim=0).clamp(min=1e-4, max=1.0)
        return mu_out, b_out


## **Disjunction Operation (Union of Two Entities)**  
For **OR operation** (choosing between two possible entities):

$$\mathcal{L}(E_1 \cup E_2) = \mathcal{L}\left(\frac{\mu_1 + \mu_2}{2}, b_1 + b_2\right)$$  

🔹 **Intuition**:  
- The **new mean is the midpoint** between the two possibilities.  
- The **uncertainty increases**, since the union introduces more ambiguity.  

📌 **Example Use Case**:  
- **"The capital of Canada is Ottawa OR Toronto"**  
- The model represents both possibilities with **higher uncertainty**.

In [5]:
class LaplaceUnion(nn.Module):
    """Union (OR) of several Laplace embeddings.

    The output mean is an attention-weighted sum of the input means. The
    weights come from a 3-layer ReLU MLP on the concatenated ``(mu, b)``
    features, with dropout on the scores and a softmax over the disjunct
    axis (dim 0). The output scale is the plain sum of the input scales,
    clamped to ``[1e-4, 1.0]``.
    """

    def __init__(self, dim, projection_regularizer, drop):
        super(LaplaceUnion, self).__init__()
        self.dim = dim
        # Layers for the mean (mu) attention.
        self.layer_mu1 = nn.Linear(self.dim * 2, self.dim)
        self.layer_mu2 = nn.Linear(self.dim, self.dim // 2)
        self.layer_mu3 = nn.Linear(self.dim // 2, self.dim)
        # The scale (b) layers no longer take part in forward() because the
        # union scale is a plain sum. They are kept so existing checkpoints
        # (state_dict keys) still load and seeded initialisation is unchanged.
        self.layer_b1 = nn.Linear(self.dim * 2, self.dim)
        self.layer_b2 = nn.Linear(self.dim, self.dim // 2)
        self.layer_b3 = nn.Linear(self.dim // 2, self.dim)

        # Stored for interface compatibility; not used by forward().
        self.projection_regularizer = projection_regularizer
        self.drop = nn.Dropout(p=drop)

        nn.init.xavier_uniform_(self.layer_mu1.weight)
        nn.init.xavier_uniform_(self.layer_mu2.weight)
        nn.init.xavier_uniform_(self.layer_mu3.weight)
        nn.init.xavier_uniform_(self.layer_b1.weight)
        nn.init.xavier_uniform_(self.layer_b2.weight)
        nn.init.xavier_uniform_(self.layer_b3.weight)

    def forward(self, mu_embeddings, b_embeddings):
        """
        Implements the union (OR) operation on Laplace embeddings.

        Inputs:
          - mu_embeddings: (num_disj, batch_size, dim) means.
          - b_embeddings: (num_disj, batch_size, dim) scales.

        Outputs:
          - mu_out: attention-weighted sum of the means.
          - b_out: sum of the scales, clamped to [1e-4, 1.0].
        """
        all_embeddings = torch.cat([mu_embeddings, b_embeddings], dim=-1)

        hidden = F.relu(self.layer_mu1(all_embeddings))
        hidden = F.relu(self.layer_mu2(hidden))
        attention_mu = F.softmax(self.drop(self.layer_mu3(hidden)), dim=0)
        mu_out = torch.sum(attention_mu * mu_embeddings, dim=0)

        # Union widens uncertainty: add the scales of all disjuncts. (The
        # previous code also computed an unused attention over b here.)
        b_out = torch.clamp(torch.sum(b_embeddings, dim=0), min=1e-4, max=1.0)
        return mu_out, b_out


## **Projection Operation (Entity + Relation → New Entity)**  
Given a **head entity** $h$ and a **relation** $r$, the new projected embedding $t$ is:

$$\mathcal{L}(h + r) = \mathcal{L}(\mu_h + \mu_r, b_h + b_r)$$  

🔹 **Intuition**:  
- The **mean vectors add** because relations shift entity embeddings.  
- The **uncertainty (scale) also adds**, ensuring that uncertainty accumulates through multiple reasoning steps.  
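- This allows multi-hop reasoning while **keeping track of uncertainty**.
- In the `LaplaceProjection` module below, this additive rule is generalized by learned MLPs applied separately to the means and to the scales.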




In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LaplaceProjection(nn.Module):
    """Relation projection for Laplace embeddings via two separate MLPs.

    The mean branch sees ``[mu_entity, mu_relation]`` and the scale branch
    sees ``[b_entity, b_relation]``. Each branch is ``num_layers`` ReLU
    layers followed by a linear output layer back to ``entity_dim``, then
    ``projection_regularizer``. The scale is finally clamped to ``[1e-4, 1.0]``.

    NOTE(review): in this notebook the regularizer is ``Regularizer(1, 0.15, 1e9)``,
    which also forces the projected *means* to be >= 0.15, while negation
    produces negative means. Confirm this is intended.
    """

    def __init__(self, entity_dim, relation_dim, hidden_dim, projection_regularizer, num_layers):
        super(LaplaceProjection, self).__init__()
        self.entity_dim = entity_dim
        self.relation_dim = relation_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        in_dim = entity_dim + relation_dim
        # Creation order (mu1, mu0, b1, b0, then mu2/b2, mu3/b3, ...) is fixed
        # so seeded initialisation is reproducible.
        self.layer_mu1 = nn.Linear(in_dim, hidden_dim)
        self.layer_mu0 = nn.Linear(hidden_dim, entity_dim)
        self.layer_b1 = nn.Linear(in_dim, hidden_dim)
        self.layer_b0 = nn.Linear(hidden_dim, entity_dim)
        for depth in range(2, num_layers + 1):
            setattr(self, "layer_mu%d" % depth, nn.Linear(hidden_dim, hidden_dim))
            setattr(self, "layer_b%d" % depth, nn.Linear(hidden_dim, hidden_dim))

        # Xavier init covers the hidden stack only; the output layers
        # layer_mu0 / layer_b0 keep nn.Linear's default initialisation.
        for depth in range(1, num_layers + 1):
            for prefix in ("layer_mu", "layer_b"):
                nn.init.xavier_uniform_(getattr(self, prefix + str(depth)).weight)

        self.projection_regularizer = projection_regularizer

    def _run_branch(self, prefix, x):
        """Apply hidden layers ``prefix1..prefixN``, output layer ``prefix0``, then the regularizer."""
        for depth in range(1, self.num_layers + 1):
            x = F.relu(getattr(self, prefix + str(depth))(x))
        return self.projection_regularizer(getattr(self, prefix + "0")(x))

    def forward(self, mu_embedding, b_embedding, mu_embedding_r, b_embedding_r):
        """
        Project entity distributions through a relation.

        Inputs:
          - mu_embedding, b_embedding: entity mean / scale.
          - mu_embedding_r, b_embedding_r: relation mean / scale.

        Outputs:
          - mu_out: projected mean.
          - b_out: projected scale, clamped to [1e-4, 1.0].
        """
        mu_out = self._run_branch("layer_mu", torch.cat([mu_embedding, mu_embedding_r], dim=-1))
        b_out = self._run_branch("layer_b", torch.cat([b_embedding, b_embedding_r], dim=-1))
        return mu_out, torch.clamp(b_out, min=1e-4, max=1.0)


## Regularizer

In [7]:
class Regularizer:
    """Shift a tensor by ``base_add`` and clamp it into ``[min_val, max_val]``."""

    def __init__(self, base_add, min_val, max_val):
        self.base_add = base_add
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, entity_embedding):
        shifted = entity_embedding + self.base_add
        return shifted.clamp(self.min_val, self.max_val)

# KGReasoning

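The distance between a query and an entity embedding is computed per dimension as $|\mu_1 - \mu_2| + |b_1 - b_2|$ and summed over dimensions. Per dimension, this is an **upper bound** on the Wasserstein-1 distance between the two Laplace distributions. It follows from the triangle inequality, since $\mathbb{E}|Z| = 1$ for $Z \sim \mathcal{L}(0, 1)$. The bound is exact when either the means or the scales coincide:
$$
W_1(\mathcal{L}_1, \mathcal{L}_2) \le |\mu_1 - \mu_2| + |b_1 - b_2|
$$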


In [None]:
class KGReasoningLapE(nn.Module):
    def __init__(self, nentity, nrelation, hidden_dim, gamma,
                 geo, test_batch_size=1,
                 box_mode=None, use_cuda=False,
                 query_name_dict=None, beta_mode=None, gamma_mode=None, drop=0.):
        super(KGReasoningLapE, self).__init__()

        self.nentity = nentity
        self.nrelation = nrelation
        self.hidden_dim = hidden_dim
        self.epsilon = 2.0
        self.geo = geo
        self.is_u = False
        self.use_cuda = use_cuda
        self.batch_entity_range = (
            torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1).cuda()
            if self.use_cuda else torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
        )
        self.query_name_dict = query_name_dict

        self.gamma = nn.Parameter(torch.Tensor([gamma]), requires_grad=False)
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), requires_grad=False
        )

        self.entity_dim = hidden_dim
        self.relation_dim = hidden_dim

        # Each entity embedding is represented as a Laplace distribution:
        # first half is μ (mean) and second half is b (scale/uncertainty)
        self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim * 2))
        self.entity_regularizer = Regularizer(1, 0.15, 1e9)
        self.projection_regularizer = Regularizer(1, 0.15, 1e9)

        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-3 * self.embedding_range.item(),
            b=3 * self.embedding_range.item()
        )

        # Standard relation embeddings remain for additional relation information
        self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim))
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-3 * self.embedding_range.item(),
            b=3 * self.embedding_range.item()
        )

        # For relations, we now maintain separate parameters for μ and b
        self.mu_relation = nn.Parameter(torch.zeros(nrelation, self.relation_dim), requires_grad=True)
        nn.init.uniform_(
            tensor=self.mu_relation,
            a=-3 * self.embedding_range.item(),
            b=3 * self.embedding_range.item()
        )
        self.b_relation = nn.Parameter(torch.zeros(nrelation, self.relation_dim), requires_grad=True)
        nn.init.uniform_(
            tensor=self.b_relation,
            a=-3 * self.embedding_range.item(),
            b=3 * self.embedding_range.item()
        )

        self.modulus = nn.Parameter(torch.Tensor([1 * self.embedding_range.item()]), requires_grad=True)

        # gamma_mode returns (hidden_dim, num_layers)
        hidden_dim, num_layers = gamma_mode
        # Use our Laplace-based modules for logical operations:
        self.center_net = LaplaceIntersection(self.entity_dim)
        self.projection_net = LaplaceProjection(self.entity_dim,
                                                 self.relation_dim,
                                                 hidden_dim,
                                                 self.projection_regularizer,
                                                 num_layers)
        self.union_net = LaplaceUnion(self.entity_dim, self.projection_regularizer, drop)
    def sample_laplace(self, mu, b):
        """Draw a reparameterised sample from Laplace(mu, b), elementwise.

        Uses inverse-CDF sampling: with u ~ Uniform(-0.5, 0.5),
        z = -sign(u) * log(1 - 2|u|) is a standard Laplace draw, and
        mu + b * z stays differentiable w.r.t. mu and b. The 1e-12 guards
        against log(0) when |u| reaches 0.5.
        """
        offset = torch.rand_like(mu) - 0.5
        magnitude = torch.abs(offset)
        standard = -torch.sign(offset) * torch.log(1 - 2 * magnitude + 1e-12)
        return mu + b * standard
    def embed_query_lape(self, queries, query_structure, idx):
        '''
        Iteratively embeds a batch of queries with the same structure using Laplace-based embeddings.
        Each entity is represented as a Laplace distribution: (mu, b).
        '''
        # Special case handling (if needed)
        if query_structure == ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)):
            aa = 1  # (dummy code, as in original)
        
        # Determine if the current query structure is purely relational (only 'r' and 'n')
        all_relation_flag = True
        for ele in query_structure[-1]:
            if ele not in ['r', 'n']:
                all_relation_flag = False
                break

        if all_relation_flag:
            # Base case: query structure starts with an entity
            if query_structure[0] == 'e':
                ent_embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx])
                # Split the entity embedding into mean (mu) and scale (b)
                mu_embedding, b_embedding = torch.chunk(ent_embedding, 2, dim=-1)
                mu_embedding = self.sample_laplace(mu_embedding, b_embedding)
                idx += 1
            else:
                mu_embedding, b_embedding, idx = self.embed_query_lape(queries, query_structure[0], idx)
            
            # Process each relation (or negation) in the current branch
            for i in range(len(query_structure[-1])):
                if query_structure[-1][i] == 'n':
                # Negation: flip the sign of the mean and increase the scale
                # Here we add a constant (0.07) to represent increased uncertainty
                    mu_embedding = -mu_embedding
                    b_embedding = b_embedding + 0.07
                else:
                # For relation traversal, use the LapE relation embeddings:
                # Use self.mu_relation and self.b_relation instead of self.alpha_embedding and self.beta_embedding
                    mu_r_embedding = torch.index_select(self.mu_relation, dim=0, index=queries[:, idx])
                    b_r_embedding = torch.index_select(self.b_relation, dim=0, index=queries[:, idx])
                    # Apply projection operation
                    mu_embedding, b_embedding = self.projection_net(mu_embedding, b_embedding,
                                                                 mu_r_embedding, b_r_embedding)
                idx += 1

        else:
        # If not all relations, then we are dealing with a multi-branch (e.g., union or intersection) query.
            if self.is_u:
                mu_embedding_list = []
                b_embedding_list = []
                for i in range(len(query_structure)):
                    mu_emb, b_emb, idx = self.embed_query_lape(queries, query_structure[i], idx)
                    mu_embedding_list.append(mu_emb)
                    b_embedding_list.append(b_emb)
                mu_embedding, b_embedding = self.union_net(torch.stack(mu_embedding_list),
                                                        torch.stack(b_embedding_list))
            else:
                mu_embedding_list = []
                b_embedding_list = []
                for i in range(len(query_structure)):
                    mu_emb, b_emb, idx = self.embed_query_lape(queries, query_structure[i], idx)
                    mu_embedding_list.append(mu_emb)
                    b_embedding_list.append(b_emb)
                mu_embedding, b_embedding = self.center_net(torch.stack(mu_embedding_list),
                                                         torch.stack(b_embedding_list))

        return mu_embedding, b_embedding, idx

    def cal_logit_lape(self, entity_embedding, query_dist):
        """
        Compute the logit for a query based on Laplace-distributed embeddings.
        
        Args:
          entity_embedding: Tensor of shape (..., 2 * dim), where the first half is μ (mean)
                            and the second half is b (scale).
          query_dist: A tuple (query_mu, query_b) representing the Laplace distribution
                      of the query. distance computed using Wasserstein-1 distance between Laplace distributions.
        
        Returns:
          logit: The computed logit value, where higher values indicate higher similarity.
        """
        # Split the entity embedding into mean (mu) and scale (b)
        mu_embedding, b_embedding = torch.chunk(entity_embedding, 2, dim=-1)
        # Unpack the query distribution (assumed to be in the same format: (mu, b))
        query_mu, query_b = query_dist
        query_mu = self.sample_laplace(query_mu, query_b)
        # Compute the Wasserstein-1 distance between the Laplace distributions:
        # Compute elementwise absolute differences for both μ and b, then sum over the embedding dimensions.
        distance = torch.abs(mu_embedding - query_mu) + torch.abs(b_embedding - query_b)
        distance = torch.sum(distance, dim=-1)
        
        # Compute the logit by subtracting the distance from the margin parameter gamma.
        logit = self.gamma - distance
    
        return logit

    def forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict):
        all_idxs, all_mu_embeddings, all_b_embeddings = [], [], []
        all_union_idxs, all_union_mu_embeddings, all_union_b_embeddings = [], [], []
    
        # Loop over each query structure in the batch.
        for query_structure in batch_queries_dict:
            if 'u' in self.query_name_dict[query_structure] and 'DNF' in self.query_name_dict[query_structure]:
                self.is_u = True
                # For union queries, transform the query and then embed it.
                mu_embedding, b_embedding, _ = \
                    self.embed_query_lape(self.transform_union_query(batch_queries_dict[query_structure],
                                                                      query_structure),
                                           self.transform_union_structure(query_structure),
                                           0)
                all_union_idxs.extend(batch_idxs_dict[query_structure])
                all_union_mu_embeddings.append(mu_embedding)
                all_union_b_embeddings.append(b_embedding)
            else:
                self.is_u = False
                mu_embedding, b_embedding, _ = self.embed_query_lape(batch_queries_dict[query_structure],
                                                                     query_structure,
                                                                     0)
                all_idxs.extend(batch_idxs_dict[query_structure])
                all_mu_embeddings.append(mu_embedding)
                all_b_embeddings.append(b_embedding)
    
        # Form the Laplace distributions for non-union queries as (mu, b) tuples.
        if len(all_mu_embeddings) > 0:
            all_mu_embeddings = torch.cat(all_mu_embeddings, dim=0).unsqueeze(1)
            all_b_embeddings = torch.cat(all_b_embeddings, dim=0).unsqueeze(1)
            all_dists = (all_mu_embeddings, all_b_embeddings)
        # For union queries.
        if len(all_union_mu_embeddings) > 0:
            all_union_mu_embeddings = torch.cat(all_union_mu_embeddings, dim=0).unsqueeze(1)
            all_union_b_embeddings = torch.cat(all_union_b_embeddings, dim=0).unsqueeze(1)
            all_union_dists = (all_union_mu_embeddings, all_union_b_embeddings)
            
        if subsampling_weight is not None:
            subsampling_weight = subsampling_weight[all_idxs + all_union_idxs]
    
        # Process positive samples.
        if positive_sample is not None:
            if len(all_mu_embeddings) > 0:
                positive_sample_regular = positive_sample[all_idxs]
                positive_embedding = self.entity_regularizer(
                    torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1))
                positive_logit = self.cal_logit_lape(positive_embedding, all_dists)
            else:
                positive_logit = torch.Tensor([]).to(self.entity_embedding.device)
    
            if len(all_union_mu_embeddings) > 0:
                positive_sample_union = positive_sample[all_union_idxs]
                positive_embedding = self.entity_regularizer(
                    torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1))
                positive_union_logit = self.cal_logit_lape(positive_embedding, all_union_dists)
            else:
                positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device)
            positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0)
        else:
            positive_logit = None
    
        # Process negative samples.
        if negative_sample is not None:
            if len(all_mu_embeddings) > 0:
                negative_sample_regular = negative_sample[all_idxs]
                batch_size, negative_size = negative_sample_regular.shape
                negative_embedding = self.entity_regularizer(
                    torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1))
                    .view(batch_size, negative_size, -1)
                )
                negative_logit = self.cal_logit_lape(negative_embedding, all_dists)
            else:
                negative_logit = torch.Tensor([]).to(self.entity_embedding.device)
    
            if len(all_union_mu_embeddings) > 0:
                negative_sample_union = negative_sample[all_union_idxs]
                batch_size, negative_size = negative_sample_union.shape
                negative_embedding = self.entity_regularizer(
                    torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1))
                    .view(batch_size, negative_size, -1)
                )
                negative_union_logit = self.cal_logit_lape(negative_embedding, all_union_dists)
            else:
                negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device)
            negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0)
        else:
            negative_logit = None
    
        return positive_logit, negative_logit, subsampling_weight, all_idxs + all_union_idxs
    def transform_union_query(self, queries, query_structure):
        """Remove the union-marker column from flattened DNF union query rows.

        The transformation is the same whether Gamma or Laplace embeddings are used.

        Args:
            queries: 2-D tensor whose rows are flattened queries of `query_structure`.
            query_structure: key into `self.query_name_dict`.

        Returns:
            For '2u-DNF', `queries` without its last column.
            For 'up-DNF', columns 0-3 followed by column 5 (column 4 is dropped).
            For every other structure, `queries` itself, unchanged.
        """
        union_kind = self.query_name_dict[query_structure]
        if union_kind == '2u-DNF':
            # Flattened layout (e, r, e, r, u): drop the trailing slot of the ('u',) marker.
            return queries[:, :-1]
        if union_kind == 'up-DNF':
            # Flattened layout (e, r, e, r, u, r): keep both branches and the final relation,
            # skipping the ('u',) marker slot at index 4.
            branches = queries[:, :4]
            final_relation = queries[:, 5:6]
            return torch.cat([branches, final_relation], dim=1)
        return queries
    
    def transform_union_structure(self, query_structure):
        """Return the query structure a DNF union query is evaluated as.

        The union structure mapping is identical for LapE.

        Args:
            query_structure: key into `self.query_name_dict`.

        Returns:
            The union-free structure for '2u-DNF' and 'up-DNF' queries, or None for any other
            query structure.
        """
        union_kind = self.query_name_dict[query_structure]
        union_free_structures = {
            '2u-DNF': (('e', ('r',)), ('e', ('r',))),
            'up-DNF': ((('e', ('r',)), ('e', ('r',))), ('r',)),
        }
        return union_free_structures.get(union_kind)

    @staticmethod
    def train_step(model, optimizer, train_iterator, args, step):
        model.train()
        optimizer.zero_grad()
    
        positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator)
        batch_queries_dict = collections.defaultdict(list)
        batch_idxs_dict = collections.defaultdict(list)
        for i, query in enumerate(batch_queries):
            batch_queries_dict[query_structures[i]].append(query)
            batch_idxs_dict[query_structures[i]].append(i)
        for query_structure in batch_queries_dict:
            if args.cuda:
                batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda()
            else:
                batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure])
        if args.cuda:
            positive_sample = positive_sample.cuda()
            negative_sample = negative_sample.cuda()
            subsampling_weight = subsampling_weight.cuda()
    
        # Call the LapE model's forward function, which returns positive and negative logits computed via Laplace embeddings.
        positive_logit, negative_logit, subsampling_weight, _ = model(
            positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict
        )
    
        # Compute the ranking loss using logsigmoid. Note that a higher logit means a better match.
        negative_score = F.logsigmoid(-negative_logit).mean(dim=1)
        positive_score = F.logsigmoid(positive_logit).squeeze(dim=1)
        positive_sample_loss = - (subsampling_weight * positive_score).sum()
        negative_sample_loss = - (subsampling_weight * negative_score).sum()
        positive_sample_loss /= subsampling_weight.sum()
        negative_sample_loss /= subsampling_weight.sum()
    
        loss = (positive_sample_loss + negative_sample_loss) / 2
        loss.backward()
        optimizer.step()
        log = {
            'positive_sample_loss': positive_sample_loss.item(),
            'negative_sample_loss': negative_sample_loss.item(),
            'loss': loss.item(),
        }
        return log
    
    @staticmethod
    def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, save_result=False,
                  save_str="", save_empty=False):
        """Evaluate `model` on `test_dataloader` with filtered ranking metrics.

        For every query, all entities are ranked by the logits the model returns (higher logit,
        better rank). Each hard answer's rank is then "filtered": the number of other answers
        (easy or hard) ranked above it is subtracted. MRR and Hits@1/3/10 are averaged over the
        query's hard answers.

        Args:
            model: LapE model; called as model(None, negative_sample, None, queries_dict,
                idxs_dict) and expected to return (_, negative_logit, _, idxs). Also reads
                `model.nentity` and, for full batches, `model.batch_entity_range`.
            easy_answers, hard_answers: map each unflattened query to a set of entity ids; the
                two sets of a query must be disjoint (asserted).
            args: reads `cuda`, `print_on_screen`, `test_batch_size` and `test_log_steps`.
            test_dataloader: yields (negative_sample, queries, queries_unflatten,
                query_structures).
            query_name_dict, save_result, save_str, save_empty: not used by this method.

        Returns:
            defaultdict keyed by query structure. Each value maps 'MRR', 'HITS1', 'HITS3' and
            'HITS10' to the mean over that structure's queries, and 'num_queries' to their count.
        """
        model.eval()
        n_batches = len(test_dataloader)
        per_structure_logs = collections.defaultdict(list)

        with torch.no_grad():
            batches = tqdm(test_dataloader, disable=not args.print_on_screen)
            for batch_no, (negative_sample, queries, queries_unflatten, query_structures) in enumerate(batches):
                # Group the queries (and their batch positions) by query structure.
                grouped_queries = collections.defaultdict(list)
                grouped_idxs = collections.defaultdict(list)
                for position, query in enumerate(queries):
                    structure = query_structures[position]
                    grouped_queries[structure].append(query)
                    grouped_idxs[structure].append(position)
                for structure in list(grouped_queries):
                    as_tensor = torch.LongTensor(grouped_queries[structure])
                    grouped_queries[structure] = as_tensor.cuda() if args.cuda else as_tensor
                if args.cuda:
                    negative_sample = negative_sample.cuda()

                # Rows of negative_logit follow `idxs`, so reorder the per-query metadata to match.
                _, negative_logit, _, idxs = model(None, negative_sample, None, grouped_queries, grouped_idxs)
                ordered_unflat = [queries_unflatten[k] for k in idxs]
                ordered_structs = [query_structures[k] for k in idxs]

                # ranking[row, entity] = 0-based position of `entity` in that row's descending sort.
                argsort = torch.argsort(negative_logit, dim=1, descending=True)
                if len(argsort) == args.test_batch_size:
                    # Full batch: reuse the precomputed position tensor.
                    positions = model.batch_entity_range
                else:
                    positions = torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1)
                    if args.cuda:
                        positions = positions.cuda()
                ranking = argsort.clone().to(torch.float)
                ranking.scatter_(1, argsort, positions)

                for row, query, query_structure in zip(range(argsort.shape[0]), ordered_unflat, ordered_structs):
                    hard = hard_answers[query]
                    easy = easy_answers[query]
                    n_hard, n_easy = len(hard), len(easy)
                    assert len(hard.intersection(easy)) == 0

                    # Sort the ranks of all answers; sorted entries whose original index is
                    # >= n_easy belong to hard answers.
                    answer_ranks, order = torch.sort(ranking[row, list(easy) + list(hard)])
                    is_hard = order >= n_easy
                    offsets = torch.arange(n_hard + n_easy).to(torch.float)
                    if args.cuda:
                        offsets = offsets.cuda()
                    # Filtered setting: subtract the answers ranked ahead, convert to 1-based,
                    # then keep only the hard answers.
                    filtered = (answer_ranks - offsets + 1)[is_hard]

                    per_structure_logs[query_structure].append({
                        'MRR': torch.mean(1. / filtered).item(),
                        'HITS1': torch.mean((filtered <= 1).to(torch.float)).item(),
                        'HITS3': torch.mean((filtered <= 3).to(torch.float)).item(),
                        'HITS10': torch.mean((filtered <= 10).to(torch.float)).item(),
                        'num_hard_answer': n_hard,
                    })

                if batch_no % args.test_log_steps == 0:
                    logging.info('Evaluating the model... (%d/%d)' % (batch_no, n_batches))

        # Average every per-query metric (except the answer count) within each query structure.
        metrics = collections.defaultdict(lambda: collections.defaultdict(int))
        for query_structure, entries in per_structure_logs.items():
            for metric in entries[0]:
                if metric == 'num_hard_answer':
                    continue
                metrics[query_structure][metric] = sum(entry[metric] for entry in entries) / len(entries)
            metrics[query_structure]['num_queries'] = len(entries)

        return metrics
    
    

# Main runner

In [9]:
# imports
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import _thread
import argparse
import json
import logging
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import time
import pickle
from collections import defaultdict
from tqdm import tqdm

In [10]:
# Query structure (BetaE-style nested tuples built from 'e', 'r', 'n' and 'u' markers) -> short task name.
# Union task names carry the union evaluation scheme as a suffix ('-DNF' or '-DM'); load_data relies on
# 'u' appearing only in union task names.
query_name_dict = {('e', ('r',)): '1p',
                   ('e', ('r', 'r')): '2p',
                   ('e', ('r', 'r', 'r')): '3p',
                   (('e', ('r',)), ('e', ('r',))): '2i',
                   (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                   ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                   (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                   (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                   (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                   ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                   (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                   (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                   (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                   ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                   ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                   ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                   }
# Inverse mapping: task name -> query structure.
name_query_dict = {value: key for key, value in query_name_dict.items()}
# All task names in definition order:
# ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni',
#  '2u-DNF', 'up-DNF', '2u-DM', 'up-DM']
all_tasks = list(name_query_dict)


class Args:
    """Attribute container standing in for the command-line options of the BetaE/GammaE runner.

    Every option in `_DEFAULTS` becomes an instance attribute. A keyword argument with the same
    name overrides the default; keyword arguments that do not name a known option are ignored.
    Attributes are set in table order, so `vars(args)` lists them in that order.
    """

    _DEFAULTS = (
        # Basic settings
        ('cuda', True),
        ('do_train', True),
        ('do_valid', True),
        ('do_test', True),
        # Data path and dataset parameters
        ('data_path', '/path/to/data'),
        ('negative_sample_size', 128),
        ('nentity', 0),    # overwritten by main() from stats.txt
        ('nrelation', 0),  # overwritten by main() from stats.txt
        # Embedding and model parameters
        ('hidden_dim', 800),
        ('gamma', 60.0),   # passed to the model as `gamma` (presumably the ranking margin)
        ('geo', 'gamma'),  # set to 'laplace' for LapE
        # Batch and learning parameters
        ('batch_size', 512),
        ('test_batch_size', 4),
        ('learning_rate', 0.0001),
        ('cpu_num', 3),
        # Checkpoint and save settings
        ('save_path', './checkpoints'),
        ('max_steps', 450001),
        ('warm_up_steps', None),
        ('drop', 0.1),
        ('save_checkpoint_steps', 50000),
        ('valid_steps', 30000),
        ('log_steps', 100),
        ('test_log_steps', 1000),
        # Query and task settings
        ('tasks', '1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up'),
        ('evaluate_union', "DNF"),
        ('print_on_screen', True),
        # Seed and mode settings
        ('seed', 42),
        ('beta_mode', "(1600,2)"),
        ('gamma_mode', "(1600,4)"),
        ('box_mode', "(none,0.02)"),
        ('prefix', None),
        ('checkpoint_path', None),
    )

    def __init__(self, **kwargs):
        for option_name, default_value in self._DEFAULTS:
            setattr(self, option_name, kwargs.get(option_name, default_value))

        
def save_model(model, optimizer, save_variable_list, args):
    """
    Persist a training snapshot into ``args.save_path``. Works for the LapE model.

    Writes two files:
      * ``config.json``: ``vars(args)`` dumped as JSON.
      * ``checkpoint``: a ``torch.save`` dict holding every entry of ``save_variable_list``
        (e.g. step, current_learning_rate, warm_up_steps) plus ``model_state_dict`` and
        ``optimizer_state_dict``.
    """
    config_path = os.path.join(args.save_path, 'config.json')
    with open(config_path, 'w') as config_file:
        json.dump(vars(args), config_file)

    snapshot = dict(save_variable_list)
    snapshot['model_state_dict'] = model.state_dict()
    snapshot['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(snapshot, os.path.join(args.save_path, 'checkpoint'))


def set_logger(args):
    """
    Configure the root logger to write to a log file and, optionally, to the console.
    Works for the LapE model.

    Records go to ``<args.save_path>/train.log`` when ``args.do_train`` is true and to
    ``<args.save_path>/test.log`` otherwise; the file is opened in append mode. When
    ``args.print_on_screen`` is true, an INFO-level console handler is added as well.

    Safe to call repeatedly in one process, e.g. when ``main`` is re-run in a notebook kernel.
    Handlers installed by an earlier call are removed and closed first. ``logging.basicConfig``
    alone does nothing once the root logger has handlers, so a second run would keep writing to
    the first run's file, and each call would add another console handler and duplicate lines.
    """
    log_file = os.path.join(args.save_path, 'train.log' if args.do_train else 'test.log')
    log_format = '%(asctime)s %(levelname)-8s %(message)s'
    root_logger = logging.getLogger('')

    # Drop handlers left over from a previous call so this run gets its own file and a single
    # console stream.
    for handler in list(root_logger.handlers):
        if getattr(handler, '_installed_by_set_logger', False):
            root_logger.removeHandler(handler)
            handler.close()

    # Same settings as basicConfig(filename=..., filemode='a+', level=INFO, ...), but applied even
    # when the root logger already has handlers.
    file_handler = logging.FileHandler(log_file, mode='a+')
    file_handler.setFormatter(logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M:%S'))
    file_handler._installed_by_set_logger = True
    root_logger.addHandler(file_handler)
    root_logger.setLevel(logging.INFO)

    if args.print_on_screen:
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter(log_format))
        console._installed_by_set_logger = True
        root_logger.addHandler(console)


def log_metrics(mode, step, metrics):
    """
    Emit one INFO record per metric, formatted as
    "<mode> <metric name> at step <step>: <value>" with the value rendered via %f.
    Generic; works for the LapE model as well.
    """
    for metric_name, metric_value in metrics.items():
        message = '%s %s at step %d: %f' % (mode, metric_name, step, metric_value)
        logging.info(message)


def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer):
    '''
    Evaluate the queries in `dataloader` with the LapE model and log/aggregate the results.

    Calls `model.test_step` to get per-query-structure metrics (MRR, Hits@K, num_queries). Each
    metric is logged and written to `writer` under "<mode>_<query name>_<metric>". Unweighted
    averages over query structures (excluding 'num_queries') are logged under
    "<mode>_average_<metric>".

    Args:
        tp_answers: forwarded as test_step's `easy_answers` argument.
        fn_answers: forwarded as test_step's `hard_answers` argument.
        mode: label prefix such as 'Valid' or 'Test'.
        step: global step used for logging and for writer.add_scalar.
        writer: object with an add_scalar(tag, value, step) method (a SummaryWriter in main).

    Returns:
        defaultdict(float) mapping "<query name>_<metric>" (including "<query name>_num_queries")
        and "average_<metric>" to their values.
    '''
    average_metrics = defaultdict(float)
    all_metrics = defaultdict(float)

    metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    for query_structure, structure_metrics in metrics.items():
        query_name = query_name_dict[query_structure]
        log_metrics(mode + " " + query_name, step, structure_metrics)
        for metric, value in structure_metrics.items():
            writer.add_scalar("_".join([mode, query_name, metric]), value, step)
            all_metrics["_".join([query_name, metric])] = value
            if metric != 'num_queries':
                average_metrics[metric] += value
        num_query_structures += 1

    # average_metrics is only non-empty when at least one structure was seen, so no division by zero.
    for metric in average_metrics:
        average_metrics[metric] /= num_query_structures
        writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step)
        all_metrics["_".join(["average", metric])] = average_metrics[metric]
    log_metrics('%s average' % mode, step, average_metrics)

    return all_metrics

def load_data(args, tasks):
    '''
    Load the query/answer pickles from `args.data_path` and drop query structures not selected.

    This function does not depend on the embedding distribution (Gamma, Laplace, ...) and is
    used unchanged for the LapE model.

    A query structure is removed from the train/valid/test query dicts when its task name is not
    in `tasks`, or when it is a union task ('<name>-<scheme>') whose scheme differs from
    `args.evaluate_union`. Answer dicts are returned unfiltered.

    Returns:
        (train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers,
         test_queries, test_hard_answers, test_easy_answers)
    '''
    logging.info("loading data")

    def _load(file_name):
        # Close each file after reading. NOTE: pickle.load can execute arbitrary code, so
        # args.data_path must point at trusted data.
        with open(os.path.join(args.data_path, file_name), 'rb') as handle:
            return pickle.load(handle)

    train_queries = _load("train-queries.pkl")
    train_answers = _load("train-answers.pkl")
    valid_queries = _load("valid-queries.pkl")
    valid_hard_answers = _load("valid-hard-answers.pkl")
    valid_easy_answers = _load("valid-easy-answers.pkl")
    test_queries = _load("test-queries.pkl")
    test_hard_answers = _load("test-hard-answers.pkl")
    test_easy_answers = _load("test-easy-answers.pkl")

    # Remove tasks not in args.tasks (and union tasks using a different evaluation scheme).
    for task_name in all_tasks:
        if 'u' in task_name:
            name, evaluate_union = task_name.split('-')
        else:
            name, evaluate_union = task_name, args.evaluate_union
        if name not in tasks or evaluate_union != args.evaluate_union:
            query_structure = name_query_dict[task_name]
            for split_queries in (train_queries, valid_queries, test_queries):
                split_queries.pop(query_structure, None)

    return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers

def main(args):
    """
    Run the full LapE pipeline: seed, set up output/logging, load data, build the model, train
    (with periodic checkpointing, validation and testing), then do a final test evaluation.

    `args` is an `Args` instance (or an object with the same attributes) and is mutated in place:
    `save_path`, `nentity` and `nrelation` are overwritten, and `valid_steps` is multiplied by 4
    once the training loop reaches two thirds of `max_steps`.

    Learning-rate schedule: each time `step` reaches `warm_up_steps`, the LR is divided by 5, the
    Adam optimizer is re-created (which resets its moment estimates), and `warm_up_steps` is
    multiplied by 1.5. `warm_up_steps` starts at `args.warm_up_steps` when that is set (truthy),
    otherwise at `args.max_steps // 2`.
    """
    set_global_seed(args.seed)
    tasks = args.tasks.split('.')
    for task in tasks:
        if 'n' in task and args.geo in ['box', 'vec']:
            assert False, "Q2B and GQE cannot handle queries with negation"
    # For union evaluation, our LapE model uses the same scheme as GammaE: De Morgan ('DM')
    # unions are only accepted together with geo == 'gamma'.
    if args.evaluate_union == 'DM':
        assert args.geo == 'gamma', "only GammaE supports modeling union using De Morgan's Laws"

    cur_time = parse_time()
    prefix = args.prefix if args.prefix is not None else 'logs'

    print("overwriting args.save_path")
    args.save_path = os.path.join(prefix, os.path.basename(args.data_path), args.tasks, args.geo)
    if args.geo in ['box']:
        tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode)
    elif args.geo in ['vec']:
        tmp_str = "g-{}".format(args.gamma)
    elif args.geo == 'beta':
        tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)
    elif args.geo == 'gamma':
        tmp_str = "g-{}-mode-{}".format(args.gamma, args.gamma_mode)
    else:
        # LapE runs typically set args.geo = 'laplace'; the geo name itself is used as the mode tag.
        tmp_str = "g-{}-mode-{}".format(args.gamma, args.geo)

    if args.checkpoint_path is not None:
        args.save_path = args.checkpoint_path
    else:
        args.save_path = os.path.join(args.save_path, tmp_str, cur_time)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    print("logging to", args.save_path)
    writer = SummaryWriter(args.save_path) if args.do_train else SummaryWriter('./logs-debug/unused-tb')
    set_logger(args)

    with open(os.path.join(args.data_path, 'stats.txt')) as f:
        entrel = f.readlines()
        nentity = int(entrel[0].split(' ')[-1])
        nrelation = int(entrel[1].split(' ')[-1])
    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('-------------------------------' * 3)
    logging.info('Geo: %s' % args.geo)
    logging.info('seed: %d' % args.seed)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    logging.info('#max steps: %d' % args.max_steps)
    logging.info('Evaluate unions using: %s' % args.evaluate_union)

    train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, \
        test_queries, test_hard_answers, test_easy_answers = load_data(args, tasks)
    # Merge union queries from valid into train queries.
    # NOTE(review): together with the answer merge below, this trains on the validation union
    # queries, so the later 'Valid' 2u/up metrics are not held-out numbers. Confirm this is
    # intended. These lookups presumably rely on the pickled query dicts being defaultdicts
    # when the union tasks were filtered out by load_data. TODO confirm.
    train_queries[(('e', ('r',)), ('e', ('r',)), ('u',))].update(valid_queries[(('e', ('r',)), ('e', ('r',)), ('u',))])
    train_queries[((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',))].update(
        valid_queries[((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',))])
    # for key in valid_easy_answers.keys():
    #     valid_easy_answers[key] = valid_easy_answers[key].union(valid_hard_answers[key])
    # NOTE(review): only the *hard* validation answers are merged into train_answers, so easy
    # answers of the merged valid queries are not part of their training answer sets.
    for key, valid_set in valid_hard_answers.items():
        if key in train_answers:
            train_answers[key] = train_answers[key].union(valid_set)
        else:
            train_answers[key] = valid_set
    logging.info("Training info:")
    if args.do_train:
        for query_structure in train_queries:
            logging.info(query_name_dict[query_structure] + ": " + str(len(train_queries[query_structure])))
        # Path queries (1p/2p/3p) and all other structures are sampled from separate iterators.
        train_path_queries = defaultdict(set)
        train_other_queries = defaultdict(set)
        path_list = ['1p', '2p', '3p']
        for query_structure in train_queries:
            if query_name_dict[query_structure] in path_list:
                train_path_queries[query_structure] = train_queries[query_structure]
            else:
                train_other_queries[query_structure] = train_queries[query_structure]
        train_path_queries = flatten_query(train_path_queries)
        train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
            TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.cpu_num,
            collate_fn=TrainDataset.collate_fn
        ))
        if len(train_other_queries) > 0:
            train_other_queries = flatten_query(train_other_queries)
            train_other_iterator = SingledirectionalOneShotIterator(DataLoader(
                TrainDataset(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.cpu_num,
                collate_fn=TrainDataset.collate_fn
            ))
        else:
            train_other_iterator = None

    logging.info("Validation info:")
    if args.do_valid:
        for query_structure in valid_queries:
            logging.info(query_name_dict[query_structure] + ": " + str(len(valid_queries[query_structure])))
        valid_queries = flatten_query(valid_queries)
        valid_dataloader = DataLoader(
            TestDataset(valid_queries, args.nentity, args.nrelation),
            batch_size=args.test_batch_size,
            num_workers=args.cpu_num,
            collate_fn=TestDataset.collate_fn
        )

    logging.info("Test info:")
    if args.do_test:
        for query_structure in test_queries:
            logging.info(query_name_dict[query_structure] + ": " + str(len(test_queries[query_structure])))
        test_queries = flatten_query(test_queries)
        test_dataloader = DataLoader(
            TestDataset(test_queries, args.nentity, args.nrelation),
            batch_size=args.test_batch_size,
            num_workers=args.cpu_num,
            collate_fn=TestDataset.collate_fn
        )

    # Initialize our LapE model (KGReasoningLapE) instead of GammaE.
    model = KGReasoningLapE(
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        geo=args.geo,
        use_cuda=args.cuda,
        box_mode=eval_tuple(args.box_mode),
        beta_mode=eval_tuple(args.beta_mode),
        gamma_mode=eval_tuple(args.gamma_mode),
        test_batch_size=args.test_batch_size,
        query_name_dict=query_name_dict,
        drop=args.drop
    )

    logging.info('Model Parameter Configuration:')
    num_params = 0
    for name, param in model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))
        if param.requires_grad:
            num_params += np.prod(param.size())
    logging.info('Parameter Number: %d' % num_params)

    if args.cuda:
        model = model.cuda()

    if args.do_train:
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=current_learning_rate
        )
        # Honour an explicit warm-up setting; otherwise decay the LR for the first time halfway through.
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.checkpoint_path is not None:
        logging.info('Loading checkpoint %s...' % args.checkpoint_path)
        checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.geo)
        init_step = 0

    step = init_step
    # For LapE (any geo not listed below) a generic message is logged instead of a mode string.
    if args.geo == 'box':
        logging.info('box mode = %s' % args.box_mode)
    elif args.geo == 'beta':
        logging.info('beta mode = %s' % args.beta_mode)
    elif args.geo == 'gamma':
        logging.info('gamma mode = %s' % args.gamma_mode)
    else:
        logging.info('Using Laplace-based embeddings for KG reasoning.')
    logging.info('tasks = %s' % args.tasks)
    logging.info('init_step = %d' % init_step)
    if args.do_train:
        logging.info('Start Training...')
        logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)

    if args.do_train:
        training_logs = []
        # Training loop.
        for step in range(init_step, args.max_steps):
            if step == 2 * args.max_steps // 3:
                args.valid_steps *= 4

            # One path batch, then (if present) one batch of other structures and another path
            # batch; only the last log of each step is kept for the running average.
            log = model.train_step(model, optimizer, train_path_iterator, args, step)
            for metric in log:
                writer.add_scalar('path_' + metric, log[metric], step)
            if train_other_iterator is not None:
                log = model.train_step(model, optimizer, train_other_iterator, args, step)
                for metric in log:
                    writer.add_scalar('other_' + metric, log[metric], step)
                log = model.train_step(model, optimizer, train_path_iterator, args, step)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 5
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 1.5

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(model, optimizer, save_variable_list, args)

            if step % args.valid_steps == 0 and step > 0:
                if args.do_valid:
                    logging.info('Evaluating on Valid Dataset...')
                    valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader,
                                                 query_name_dict, 'Valid', step, writer)
                if args.do_test:
                    logging.info('Evaluating on Test Dataset...')
                    test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader,
                                                query_name_dict, 'Test', step, writer)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(model, optimizer, save_variable_list, args)

    # `step` is always bound here (init_step above, updated by the training loop), so no guard is needed.
    print(step)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict,
                                    'Test', step, writer)

    logging.info("Training finished!!")
    # Flush pending TensorBoard events and release the event file.
    writer.close()


# Converting training steps to epochs
Training length is set by the number of optimization steps rather than by the number of data points seen, which ties it to training time. To relate the two, we estimate how many epochs a given number of steps corresponds to.
To calculate the number of **epochs**, we need:  
1. **Dataset size** (number of training triples in FB15k-237).  
2. **Batch size** (`batch_size = 512`, i.e. `-b 512` on the original command line).  
3. **Total steps** (`max_steps = 450001`, the default in `Args`).  

### **Step 1: Get Dataset Size**  
The **FB15k-237** dataset has **272,115 training triples**.

### **Step 2: Compute Steps per Epoch**  
Each epoch means processing **all** training triples once.  
$$
\text{Steps per epoch} = \frac{\text{Training triples}}{\text{Batch size}} = \frac{272115}{512} \approx 532 \text{ steps per epoch}
$$

### **Step 3: Compute Total Epochs**  
Total epochs =  
$$
\frac{\text{max\_steps}}{\text{steps per epoch}} = \frac{450001}{532} \approx 846 \text{ epochs}
$$

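### **Final Answer:**  
With `max_steps = 450001`, the model will train for **~846 epochs** on FB15k-237.  
This is only a rough estimate. The training loop samples batches of *queries*, not triples; the log below shows, for example, 149,689 training queries per path type. Each step also performs up to three `train_step` updates (path, other, path) when non-path query types are present. The true number of passes over each query set therefore differs from this triple-based figure.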



# Training

In [None]:
# Settings for a short LapE run on the BetaE-format FB15k-237 data.
# Options not listed here keep their Args defaults.
run_config = dict(
    cuda=True,
    do_train=True,
    do_test=True,
    data_path="/kaggle/input/kg-data/FB15k-237-betae",
    negative_sample_size=128,
    batch_size=512,
    hidden_dim=800,
    gamma=60.0,
    learning_rate=0.0001,
    max_steps=1000,
    cpu_num=3,
    test_batch_size=4,
    geo='laplace',          # 'laplace' selects LapE instead of GammaE ('gamma')
    drop=0.1,
    valid_steps=500,
    gamma_mode="(1600,4)",  # GammaE setting kept as a reference value
    seed=42,
)
args = Args(**run_config)

main(args)

2025-03-20 15:27:47,470 INFO     ---------------------------------------------------------------------------------------------
2025-03-20 15:27:47,470 INFO     Geo: laplace
2025-03-20 15:27:47,472 INFO     seed: 42
2025-03-20 15:27:47,472 INFO     Data Path: /kaggle/input/kg-data/FB15k-237-betae
2025-03-20 15:27:47,473 INFO     #entity: 14505
2025-03-20 15:27:47,474 INFO     #relation: 474
2025-03-20 15:27:47,474 INFO     #max steps: 1000
2025-03-20 15:27:47,475 INFO     Evaluate unions using: DNF
2025-03-20 15:27:47,476 INFO     loading data


overwriting args.save_path
logging to logs/FB15k-237-betae/1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up/laplace/g-60.0-mode-laplace/2025.03.20-15:27:47


2025-03-20 15:28:23,604 INFO     Training info:
2025-03-20 15:28:23,605 INFO     1p: 149689
2025-03-20 15:28:23,605 INFO     2p: 149689
2025-03-20 15:28:23,606 INFO     3p: 149689
2025-03-20 15:28:23,607 INFO     2i: 149689
2025-03-20 15:28:23,608 INFO     3i: 149689
2025-03-20 15:28:23,609 INFO     2in: 14968
2025-03-20 15:28:23,610 INFO     3in: 14968
2025-03-20 15:28:23,611 INFO     inp: 14968
2025-03-20 15:28:23,612 INFO     pin: 14968
2025-03-20 15:28:23,612 INFO     pni: 14968
2025-03-20 15:28:23,613 INFO     2u-DNF: 5000
2025-03-20 15:28:23,613 INFO     up-DNF: 5000
2025-03-20 15:28:25,110 INFO     Validation info:
2025-03-20 15:28:25,111 INFO     1p: 20094
2025-03-20 15:28:25,111 INFO     2p: 5000
2025-03-20 15:28:25,112 INFO     3p: 5000
2025-03-20 15:28:25,113 INFO     2i: 5000
2025-03-20 15:28:25,113 INFO     3i: 5000
2025-03-20 15:28:25,114 INFO     ip: 5000
2025-03-20 15:28:25,115 INFO     pi: 5000
2025-03-20 15:28:25,115 INFO     2in: 5000
2025-03-20 15:28:25,116 INFO    