Équation générale :
$$h_i^{(l+1)} = \sigma\left( \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_i^r} \frac{1}{c_{i,r}} W_r^{(l)}h_j^{(l)} + W_0^{(l)} h_i^{(l)} \right)$$
avec $h_i^{(l)} \in \mathbb{R}^{d^{(l)}}$ the hidden state of node $v_i$

*basis decomposition*:
$$W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)} V_b^{(l)}$$

and $V_b^{(l)} \in \mathbb{R}^{d^{(l+1)} \times d^{(l)}}$ 

notons $N=|\mathcal{E}|$, on a :

au layer $(l)$, on a une matrice $H^{(l)}$ de dimension $N \times d^{(l)}$

In [41]:
triples = [
    (0, 0, 1),
    (1, 0, 2),
    (2, 0, 0),
    (0, 1, 1),
    (1, 1, 0),
    (0, 2, 2),
    (1, 2, 2)
]

In [45]:
import torch.nn as nn
import torch


class RGCNLayer(nn.Module):
    def __init__(self, T, B, dim_in, dim_out, init="random"):
        """
        T: adjacency tensor (Nr * Ne * Ne)
        B: number of basis functions
        dim_in: dimension of input feature vectors
        dim_out: dimension of output feature vectors
        """
        super().__init__()
        Nr, Ne, _ = T.shape()
        self.V = self.init(B, dim_in, dim_out, how=init)
        self.A = self.init(Nr, B, how=init)
        self.T = T
        
    def init(self, *size, how="random", fill_value=1.):
        if how == "random":
            return torch.rand(*size)
        elif how == "constant":
            return torch.full(size, fill_value)
        else:
            raise ValueError(f"Unsupported initialization method '{how}'")
            
    def forward(self, H):
        # Input: N * d_in
        W = torch.einsum("rb,bio->rio", self.A, self.V) # -> "R * d_in * d_out"
        HW = torch.einsum("ni,rio->rno", H, W)
        H = torch.einsum("rmn,rno->mo", self.T.to_dense(), HW)
        return H
    
class RGCN(nn.Module):
    def __init__(self, T, n_classes):
        self.conv1 = RGCNLayer(T, 10, 16, 32)
        self.conv2 = RGCNLayer(T, 10, 32, n_classes)
        self.softmax = nn.Softmax()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.softmax(x)
        return x

Ne = 3
Nr = 3

T = get_adjacency(triples, Nr, Ne)

B = 1
din = 1
dout = 1

H = torch.Tensor([[0], [1], [-1]])

layer = RGCNLayer(T, Nr, B, din, dout, init="constant")

print(H)
for _ in range(5):
    H = layer.forward(H)
    print(H)

tensor([[ 0.],
        [ 1.],
        [-1.]])
tensor([[0.0000],
        [0.0000],
        [1.5000]])
tensor([[1.5000],
        [0.0000],
        [0.0000]])
tensor([[0.0000],
        [3.0000],
        [0.7500]])
tensor([[3.7500],
        [0.0000],
        [4.5000]])
tensor([[4.5000],
        [7.5000],
        [1.8750]])


In [30]:
A = torch.Tensor([[0, 10, 1], [0, -10, -1]])
print(A.shape, V.shape)

torch.Size([2, 3]) torch.Size([3, 2, 4])


In [31]:
torch.einsum("rb,bio->rio", A, V)

tensor([[[ 10.1924,  10.6328,  10.6583,  10.8340],
         [ 10.9752,  10.5400,  10.9157,  10.2023]],

        [[-10.1924, -10.6328, -10.6583, -10.8340],
         [-10.9752, -10.5400, -10.9157, -10.2023]]])

Tenseur d'adjacence :
- dimensions $R \times N \times N$
- une slice $k$ = la matrice d'adjacence pour la relation $k$. *i.e* $T_{ij}^{(k)} = 1$ ssi $(e_i, r_k, e_j)$ est dans le graphe
- pour une slice $(r)$ donnée, la somme sur une ligne doit sommer à 1

In [39]:
# tenseur d'adjacence
from collections import defaultdict

def get_adjacency(triples, n_rels, n_ents):
    d = defaultdict(lambda:defaultdict(set))
    for h, r, t in triples:
        d[r][t].add(h)
    i = []
    v = []
    for r in d:
        for t in d[r]:
            n = 1 / len(d[r][t])
            for h in d[r][t]:
                i.append([r, t, h])
                v.append(n)
    i = torch.LongTensor(i)
    v = torch.FloatTensor(v)
    return torch.sparse.FloatTensor(i.t(), v, torch.Size([n_rels, n_ents, n_ents]))

# nb d'entités
Ne = 3
Nr = 3
triples = [
    (0, 0, 1),
    (1, 0, 2),
    (2, 0, 0),
    (0, 1, 1),
    (1, 1, 0),
    (0, 2, 2),
    (1, 2, 2)
]

T = get_adjacency(triples, 3, 3)
T.to_dense()

tensor([[[0.0000, 0.0000, 1.0000],
         [1.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000]],

        [[0.0000, 1.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000]]])

In [49]:
import os

dirname = r"C:\Users\felix\Downloads\FB15K-237.2\Release"
def readf(fname):
    with open(os.path.join(dirname, fname), "r") as f:
        triples = [line.split() for line in f]
    return triples

triples = readf("test.txt")
triples[:10]

[['/m/08966',
  '/travel/travel_destination/climate./travel/travel_destination_monthly_climate/month',
  '/m/05lf_'],
 ['/m/01hww_',
  '/music/performance_role/regular_performances./music/group_membership/group',
  '/m/01q99h'],
 ['/m/09v3jyg',
  '/film/film/release_date_s./film/film_regional_release_date/film_release_region',
  '/m/0f8l9c'],
 ['/m/02jx1', '/location/location/contains', '/m/013t85'],
 ['/m/02jx1', '/location/location/contains', '/m/0m0bj'],
 ['/m/02bfmn', '/film/actor/film./film/performance/film', '/m/04ghz4m'],
 ['/m/05zrvfd',
  '/award/award_category/nominees./award/award_nomination/nominated_for',
  '/m/04y9mm8'],
 ['/m/060bp',
  '/government/government_office_category/officeholders./government/government_position_held/jurisdiction_of_office',
  '/m/04j53'],
 ['/m/07l450', '/film/film/genre', '/m/082gq'],
 ['/m/07h1h5',
  '/sports/pro_athlete/teams./sports/sports_team_roster/team',
  '/m/029q3k']]

In [64]:
from scipy.sparse import csr_matrix, diags

graph = [triple.split() for triple in """\
a r1 b
b r1 c
c r1 a
a r2 b
b r2 a
a r3 c
b r3 c
c r4 a
c r4 b
c r4 c""".split("\n")]

def adj(triples):
    hs, rs, ts = zip(*triples)
    entities = {e: i for i, e in enumerate(set(hs) | set(ts))}
    relations = {r: i for i, r in enumerate(set(rs))}
    nr, ne = len(relations), len(entities)
    sorted_triples = defaultdict(list)
    for h, r, t in triples:
        sorted_triples[relations[r]].append((entities[t], entities[h]))
    A = []
    for r, coords in sorted_triples.items():
        row_inds, col_inds = zip(*coords)
        data = [1] * len(coords)
        A.append(csr_matrix((data, (row_inds, col_inds)), shape=(ne, ne)))
    return A

A = adj(graph)
for a in A:
    print(a.todense())

[[0 1 0]
 [0 0 1]
 [1 0 0]]
[[0 0 0]
 [0 0 1]
 [0 1 0]]
[[0 1 1]
 [0 0 0]
 [0 0 0]]
[[1 0 0]
 [1 0 0]
 [1 0 0]]


In [95]:
from scipy import sparse

def normalize(a):
    d = np.array(a.sum(1)).squeeze()
    d = np.divide(1, d, where=d!=0)
    d = diags(d, format="csr")
    return d * a

Ah = [normalize(a) for a in A]

sparse.hstack(Ah).todense()

matrix([[0. , 1. , 0. , 0. , 0. , 0. , 0. , 0.5, 0.5, 1. , 0. , 0. ],
        [0. , 0. , 1. , 0. , 0. , 1. , 0. , 0. , 0. , 1. , 0. , 0. ],
        [1. , 0. , 0. , 0. , 1. , 0. , 0. , 0. , 0. , 1. , 0. , 0. ]])

In [77]:
import numpy as np

d = np.array(A[2].sum(1))


In [100]:
Nr = len(A)
Ne = A[0].shape[0]
di = 8
do = 16
H = torch.rand(Ne, di)
W = torch.rand(Nr, di, do)

HxW = torch.matmul(H, W)

In [102]:
A = torch.rand(Nr, Ne, Ne)

AxHxW = torch.matmul(A, HxW)
print(AxHxW.shape)

AxHxW.sum(axis=0).shape

torch.Size([4, 3, 16])


torch.Size([3, 16])

In [119]:
def adj(triples):
    hs, rs, ts = zip(*triples)
    entities = {e: i for i, e in enumerate(set(hs) | set(ts))}
    relations = {r: i for i, r in enumerate(set(rs))}
    nr, ne = len(relations), len(entities)
    i = torch.LongTensor([
        [relations[r] for r in rs],
        [entities[t] for t in ts],
        [entities[h] for h in hs]
    ])
    c = torch.ones(len(triples))
    A = torch.sparse.FloatTensor(i, c, torch.Size([nr, ne, ne]))
    return A

def adj(triples):
    hs, rs, ts = zip(*triples)
    entities = {e: i for i, e in enumerate(set(hs) | set(ts))}
    relations = {r: i for i, r in enumerate(set(rs))}
    nr, ne = len(relations), len(entities)
    sorted_triples = defaultdict(list)
    for h, r, t in triples:
        sorted_triples[relations[r]].append((entities[t], entities[h]))
    A = []
    for r, coords in sorted_triples.items():
        row_inds, col_inds = zip(*coords)
        data = [1] * len(coords)
        A.append(
            torch.sparse.FloatTensor(
                torch.LongTensor([row_inds, col_inds]),
                torch.ones(len(coords)),
                torch.Size([ne, ne])
            )
        )
    return A

# shape: Nr * Ne * d_out
HxW = torch.matmul(H, W)
# shape: Nr * Ne * Ne
A = adj(graph)

# expected shape: Nr * Ne * d_out
AxHxW = torch.stack([torch.sparse.mm(a, hw) for a, hw in zip(A, HxW)])
print(AxHxW.shape)

AxHxW.sum(axis=0).shape

torch.Size([4, 3, 16])


torch.Size([3, 16])

In [120]:
AxHxW.backward()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn