# 2. Node Embeddings with TransE 

## 2.1 Warmup: Why the Comparative Loss?
We could have a graph with 4 entities and 2 relationships:
```
(Higher Brothers) ---(from)--->   (China)  
                                     |  
                              (to the north of)    
                                     |  
                                     v  
   (Rich Brian)   ---(from)---> (Indonesia)  
```
There is a trival solution than could minimize the loss function to 0, that is:  
$\mathbf e = [0, 1]^T$ for all entities and $\boldsymbol \ell = [0, 0]^T$ for all relationships. This is a completely useless embedding. 

Note that is does not mean that the algorithm will always end up with this embedding. 

## 2.2 The Purpose of the Margin
With the same embeddings in 2.1, we again minimize the loss function to 0. 

## 2.3 Why are Entity Embeddings Normalized?
The paper explains: 
> This constraint is important for our model, as it is for previous embedding-based methods, because it prevents the training process to trivially minimize $\mathcal L$ by artificially increasing entity embeddings norms. 

Removing the norm constraint, the margin may sometimes be useless. Assuming that the loss function could be minimized to zero with a small non-zero margin $\gamma_1$ (which means that we get a perfect embedding). Then, even if we increase the margin (from $\gamma_1$ to $\gamma_2$), the model could cheat the loss function by just increase the norms (while keeping the direction) by $\frac{\gamma_2}{\gamma_1}$. The model minimizes the loss to zero, but does not improve the embeddings. 

Note that this is not always the case. If we cannot get a perfect embedding, the model will not always increase the norm otherwise the loss will increase finally. 

## 2.4 Where TransE fails
Still the same graph. If we have:  
$\mathbf u_0 + \boldsymbol \ell_0 = \mathbf u_1, \mathbf u_2 + \boldsymbol \ell_1 = \mathbf u_0, 
\mathbf u_3 + \boldsymbol \ell_1 = \mathbf u_1$. 

Then we got
$\mathbf u_2 + \boldsymbol \ell_0 = \mathbf u_3$
However, $(u_2, \ell_0, v_3) \notin S$

## Sample Implementation of TransE

In [1]:
import numpy as np
import matplotlib.pyplot as plt

### Derive Gradients

Simple
$$
\begin{aligned}
\mathcal{L}_{\text{simple}}=\sum_{(h, \ell, t) \in S} d(\mathbf{h}+\boldsymbol{\ell}, \mathbf{t})
\end{aligned}
$$
For each triplet $(h, \ell, t)$: 
$$
\begin{aligned}
\mathbf{d} &= d(\mathbf{h}+\boldsymbol{\ell}, \mathbf{t}) \\
d &= \sqrt{\mathbf{d}^T \mathbf{d}} \\
\end{aligned}
$$
Therefore: 
$$
\begin{aligned}
\frac{\partial d}{\partial \mathbf{d}} &= \frac{2 \mathbf{d}}{2 \sqrt{\mathbf{d}^T \mathbf{d}}} = \frac{\mathbf{d}}{d} \\
\frac{\partial d}{\partial \mathbf{h}} &= \frac{\partial d}{\partial \boldsymbol{\ell}} = \frac{\mathbf{d}}{d} \\
\frac{\partial d}{\partial \mathbf{t}} &= -\frac{\mathbf{d}}{d}
\end{aligned}
$$
Then for a batch, we just need to sum up the gradient for each sample. 

In [2]:
# simple gradient only consider the existed triplets
def gradient_simple(h, l, t):
    d = h + l - t
    dh = d / np.linalg.norm(d)
    dl = dh
    dt = - dh
    return dh, dl, dt, np.linalg.norm(d)

# comparative gradient, with or without margin
def gradient_comparative(h, l, t, hp, tp, margin=0.0):
    dh1, dl1, dt1, d1 = gradient_simple(h, l, t)
    dh2, dl2, dt2, d2 = gradient_simple(hp, l, tp)
    if margin + d1 - d2 <= 0:  # correct
        return np.zeros_like(dh1), np.zeros_like(dl1), np.zeros_like(dt1), 0
    else:  # wrong
        return dh1 - dh2, dl1 - dl2, dt1 - dt2, margin + d1 - d2

In [3]:
def random_sample(h_i, t_i, n_e):
    p_i = np.random.choice(n_e)
    if np.random.choice(2):  # resample t
        return h_i, p_i
    else:  # resample h
        return p_i, t_i

In [4]:
def TransE(Entities, Relationships, Edges, k=2, loss='comparative', alpha=0.01, epochs=100, sample_rate=1.0, margin=1.0):
    n_e = len(Entities)  # number of entities
    n_l = len(Relationships)  # number of relationships
    m = len(Edges)  # number of edges
    b = int(sample_rate * m)  # number of edges sampled in each epoch
    
    # Initialize Embeddings
    # Relationships Embeddings
    L = np.random.uniform(low=-6/np.sqrt(k), high=6/np.sqrt(k), size=(n_l, k))
    L = L / np.linalg.norm(L, axis=1, keepdims=True)
    # Entities Embeddings
    E = np.random.uniform(low=-6/np.sqrt(k), high=6/np.sqrt(k), size=(n_e, k))
    
    # Initialize their gradients
    dL = np.zeros_like(L)
    dE = np.zeros_like(E)
    
    for epoch in range(epochs):
        # Normalize Entity Embeddings
        # E = E / np.linalg.norm(E, axis=1, keepdims=True)
        # Clear gradients
        dL.fill(0.0)
        dE.fill(0.0)
        sum_loss = 0
        # Sample batch of size b
        batch = np.random.choice(m, b, replace=False)
        for index in batch:
            h_i, l_i, t_i = Edges[index]  # positive triplets
            hp_i, tp_i = random_sample(h_i, t_i, n_e)  # corrupted triplets
            if loss == 'simple':
                h, l, t = E[h_i], L[l_i], E[t_i]
                dh, dl, dt, d = gradient_simple(h, l, t)
                # accumulate gradients
                dE[h_i] = dE[h_i] + dh
                dL[l_i] = dL[l_i] + dl
                dE[t_i] = dE[t_i] + dt
                sum_loss = sum_loss + d
            elif loss == 'comparative':
                h, l, t = E[h_i], L[l_i], E[t_i]
                hp, tp = E[hp_i], E[tp_i]
                dh, dl, dt, d = gradient_comparative(h, l, t, hp, tp, margin=2.0)
                # accumulate gradients
                dE[h_i] = dE[h_i] + dh
                dL[l_i] = dL[l_i] + dl
                dE[t_i] = dE[t_i] + dt
                sum_loss = sum_loss + d
            else:
                raise NotImplementedError
        E = E - alpha * dE 
        L = L - alpha * dL
        if epoch % 50 == 0:
            print('Epoch', epoch, ':', sum_loss)
        
    # E = E / np.linalg.norm(E, axis=1, keepdims=True)
    return E, L

In [5]:
# Simple Graph
Entities = {0: 'China', 1: 'Indonesia', 2: 'Higher Brothers', 3: 'Rich Brian'}
Relationships = {0: 'to the north of', 1: 'from'}
Edges = [(0, 0, 1), (2, 1, 0), (3, 1, 1)]
for h_i, l_i, t_i in Edges:
    print(Entities[h_i], Relationships[l_i], Entities[t_i])

China to the north of Indonesia
Higher Brothers from China
Rich Brian from Indonesia


In [6]:
E, L = TransE(Entities, Relationships, Edges, k=2, loss='comparative', alpha=0.005, epochs=1000, sample_rate=1.0, margin=1.0)

Epoch 0 : 10.332187792774103
Epoch 50 : 13.425523543179683
Epoch 100 : 10.620780422533668
Epoch 150 : 6.661666409476698
Epoch 200 : 8.25368831727983
Epoch 250 : 7.6690717218380335
Epoch 300 : 10.380212460687357
Epoch 350 : 6.39088840592918
Epoch 400 : 5.25310525184958
Epoch 450 : 6.095685087100951
Epoch 500 : 7.172016782605932
Epoch 550 : 6.7771743831426985
Epoch 600 : 4.478209156781671
Epoch 650 : 8.714961684796366
Epoch 700 : 4.429065031929193
Epoch 750 : 4.128495337528484
Epoch 800 : 5.3099327007962005
Epoch 850 : 5.870670534028997
Epoch 900 : 5.256516900549345
Epoch 950 : 2.5849615817430704


In [7]:
E

array([[-1.58017788,  0.90863433],
       [-1.11345653,  1.72820179],
       [ 0.64104502, -1.03172   ],
       [-0.80101233,  2.09512848]])

In [8]:
L

array([[ 0.91502093,  0.86074989],
       [-1.58774833,  1.82029599]])