In [1]:
import numpy as np

In [2]:
def get_hit_pos(relation):
    print(relation.shape)
    if len(relation.shape) == 3:
        res = []
        for ib in range(relation.shape[0]):
            one_res = [
                relation[ib, :, ic].nonzero()[0][0]
                for ic in range(relation.shape[2])
            ]
            res.append(one_res)
    else:
        res = [
            relation[:, ic].nonzero()[0][0]
            for ic in range(relation.shape[1])
        ]
    return np.array(res, dtype=np.long)

## Prepare dummy data

In [3]:
Ri = np.array([
    [1, 1, 0, 0, 0, 1, 0],
    [0, 0, 1, 1, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 1],
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0]], dtype=np.float32
)
print(Ri.shape)
Ri= Ri[None, :]
Ri_b = np.tile(Ri, (2, 1, 1))
print(Ri_b.shape)

(5, 7)
(2, 5, 7)


In [4]:
Ro = np.array([
    [0, 0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 1, 1, 0],
    [0, 0, 0, 1, 0, 0, 1]
], dtype=np.float32)
print(Ro.shape)
Ro = Ro[None, :]
Ro_b = np.tile(Ro, (2, 1, 1))
print(Ro_b.shape)

(5, 7)
(2, 5, 7)


In [5]:
X = np.array([
    [0, 0.1, 0],
    [1, 0.2, 1],
    [1, 0.3, 1],
    [2, 0.2, 2],
    [2, 0.3, 2]
], dtype=np.float32)
X = X[None, :]
X_b = np.tile(X, (2, 1, 1))
print(X_b.shape)

(2, 5, 3)


In [6]:
E = np.array([
    0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7
], dtype=np.float32)
E = E[None, :]
E_b = np.tile(E, (2, 1))
print(E_b.shape)

(2, 7)


In [7]:
n_features = 3
n_edges = 7
n_hits = 5
n_batch = 2

## Check EdgeNetWork

In [8]:
import torch
print(torch.__version__)

1.0.0a0+0b862fa


In [9]:
Ro_t = torch.tensor(Ro_b)
Ri_t = torch.tensor(Ri_b)
X_t  = torch.tensor(X_b)
E_t  = torch.tensor(E_b)

In [10]:
Ro_pos_t = torch.from_numpy(get_hit_pos(Ro_t.numpy()))
Ri_pos_t = torch.from_numpy(get_hit_pos(Ri_t.numpy()))

(2, 5, 7)
(2, 5, 7)


In [11]:
bo1 = torch.bmm(Ro_t.transpose(1, 2,), X_t)
print(bo1.size())
print(bo1)

torch.Size([2, 7, 3])
tensor([[[1.0000, 0.2000, 1.0000],
         [1.0000, 0.3000, 1.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000]],

        [[1.0000, 0.2000, 1.0000],
         [1.0000, 0.3000, 1.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000]]])


In [12]:
x1 = X_t.reshape(-1, n_features)
bo2 = x1[Ro_pos_t.long().flatten()].reshape(n_batch, -1, n_features)
print(bo2.size())
print(bo2)

torch.Size([2, 7, 3])
tensor([[[1.0000, 0.2000, 1.0000],
         [1.0000, 0.3000, 1.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000]],

        [[1.0000, 0.2000, 1.0000],
         [1.0000, 0.3000, 1.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.2000, 2.0000],
         [2.0000, 0.3000, 2.0000]]])


## Now check the NodeNetWork

In [13]:
e_ex = E_t[:, None].repeat(1, n_hits, 1)
print(e_ex.size())
print(e_ex)

torch.Size([2, 5, 7])
tensor([[[0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000]],

        [[0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
         [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000]]])


In [14]:
Rwi1 = Ri_t * E_t[:, None]
print(Rwi1.size())
print(Rwi1)

torch.Size([2, 5, 7])
tensor([[[0.1000, 0.2000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000],
         [0.0000, 0.0000, 0.3000, 0.4000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.7000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.1000, 0.2000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000],
         [0.0000, 0.0000, 0.3000, 0.4000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.7000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])


In [37]:
Ri_pos_t

tensor([[0, 0, 1, 1, 2, 0, 2],
        [0, 0, 1, 1, 2, 0, 2]])

In [15]:
mi1 = torch.bmm(Rwi1, bo1)
print(mi1.size())
print(mi1)

torch.Size([2, 5, 3])
tensor([[[1.5000, 0.2000, 1.5000],
         [1.4000, 0.1800, 1.4000],
         [2.4000, 0.3100, 2.4000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[1.5000, 0.2000, 1.5000],
         [1.4000, 0.1800, 1.4000],
         [2.4000, 0.3100, 2.4000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])


In [16]:
new_E_t = torch.zeros(e_ex.size())
for ib in range(n_batch):
    for ic in range(new_E_t.shape[2]):
        new_E_t[ib, Ri_pos_t[ib, ic] , ic] = e_ex[ib, Ri_pos_t[ib, ic] , ic]

print(new_E_t)

tensor([[[0.1000, 0.2000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000],
         [0.0000, 0.0000, 0.3000, 0.4000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.7000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.1000, 0.2000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000],
         [0.0000, 0.0000, 0.3000, 0.4000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.7000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])


In [39]:
new_E_t = torch.zeros(e_ex.size())
for ib in range(n_batch):
    n_c = new_E_t.shape[2]
    new_E_t[[ib]*n_c, Ri_pos_t[ib] , range(n_c)] = e_ex[[ib]*n_c, Ri_pos_t[ib] , range(n_c)]
print(new_E_t)

tensor([[[0.1000, 0.2000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000],
         [0.0000, 0.0000, 0.3000, 0.4000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.7000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.1000, 0.2000, 0.0000, 0.0000, 0.0000, 0.6000, 0.0000],
         [0.0000, 0.0000, 0.3000, 0.4000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.7000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
