### Embeddings

In [1]:
import numpy as np
import torch

In [36]:
# A 2x4 float32 matrix with easily recognisable entries, wrapped in a
# Parameter; the repr of a Parameter reports requires_grad=True.
weights = torch.tensor(
    [[0, 1, 2, 3],
     [10, 11, 12, 13]],
    dtype=torch.float32,
)

m = torch.nn.Parameter(weights)

m

Parameter containing:
tensor([[ 0.,  1.,  2.,  3.],
        [10., 11., 12., 13.]], requires_grad=True)

In [116]:
# Four weight rows of length 3; row i is 10 * i + [0, 1, 2].
weights = torch.tensor([
    [0, 1, 2],
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
])

# Two binary rows of length 4, used by later cells as a mask over the rows of
# `weights`.
fields = torch.tensor([
    [1, 0, 1, 1],
    [0, 1, 0, 1],
])

# Dot product of every pair of weight rows -> symmetric 4x4 matrix.
full_coupling_matrix = weights @ weights.T

full_coupling_matrix

tensor([[   5,   35,   65,   95],
        [  35,  365,  695, 1025],
        [  65,  695, 1325, 1955],
        [  95, 1025, 1955, 2885]])

In [117]:
# Broadcast each 0/1 row of `fields` over the weight rows:
# (2, 4, 1) * (4, 3) -> (2, 4, 3). Rows whose field entry is 0 become all-zero.
masked_weights = fields[:, :, None] * weights

masked_weights

tensor([[[ 0,  1,  2],
         [ 0,  0,  0],
         [20, 21, 22],
         [30, 31, 32]],

        [[ 0,  0,  0],
         [10, 11, 12],
         [ 0,  0,  0],
         [30, 31, 32]]])

In [118]:
# Batched product of each (4, 3) slice with its own transpose, giving one 4x4
# matrix of pairwise row dot products per sample.
coupling_matrix = torch.matmul(masked_weights, masked_weights.transpose(-2, -1))

coupling_matrix

tensor([[[   5,    0,   65,   95],
         [   0,    0,    0,    0],
         [  65,    0, 1325, 1955],
         [  95,    0, 1955, 2885]],

        [[   0,    0,    0,    0],
         [   0,  365,    0, 1025],
         [   0,    0,    0,    0],
         [   0, 1025,    0, 2885]]])

In [119]:
# Keep only the entries strictly above the diagonal of every 4x4 slice, so each
# unordered pair (i < j) is kept once and the i == j self-products are dropped.
masked_coupling_matrix = torch.triu(coupling_matrix, diagonal=1)

masked_coupling_matrix

tensor([[[   0,    0,   65,   95],
         [   0,    0,    0,    0],
         [   0,    0,    0, 1955],
         [   0,    0,    0,    0]],

        [[   0,    0,    0,    0],
         [   0,    0,    0, 1025],
         [   0,    0,    0,    0],
         [   0,    0,    0,    0]]])

In [120]:
# Collapse both matrix axes: one total of the pairwise products per sample.
torch.sum(masked_coupling_matrix, dim=(1, 2))

tensor([2115, 1025])

In [121]:
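# Optimization: compute the same per-sample pairwise sum as 0.5 * ((sum of rows)^2 - sum of squared rows), without building the 4x4 coupling matrix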

In [122]:
# Show the (2, 4, 3) masked weights again; they are the input to the
# square-of-sum / sum-of-squares computation in the following cells.
masked_weights

tensor([[[ 0,  1,  2],
         [ 0,  0,  0],
         [20, 21, 22],
         [30, 31, 32]],

        [[ 0,  0,  0],
         [10, 11, 12],
         [ 0,  0,  0],
         [30, 31, 32]]])

In [126]:
# Add up the masked rows of every sample (axis 1), then square element-wise.
square_of_sum = torch.sum(masked_weights, dim=1).pow(2)

square_of_sum

tensor([[2500, 2809, 3136],
        [1600, 1764, 1936]])

In [127]:
# Square element-wise first, then add up over the rows of every sample (axis 1).
sum_of_squares = torch.sum(masked_weights.pow(2), dim=1)

sum_of_squares

tensor([[1300, 1403, 1512],
        [1000, 1082, 1168]])

In [130]:
# (sum)^2 - sum of squares counts every ordered pair i != j, so half of it,
# summed over the last axis, is the per-sample pairwise total. The factor 0.5
# promotes the integer tensor to floating point.
prediction = torch.sub(square_of_sum, sum_of_squares).sum(dim=1).mul(0.5)

prediction

tensor([2115., 1025.])

In [None]:
# Linear term: a linear layer mapping the 4 field indicators to a single output (with bias)

In [148]:
# nn.Linear draws its initial weight and bias from the global RNG. Seed it
# first so the layer, and every value computed from it, is identical after a
# kernel restart.
torch.manual_seed(0)

# 4 input features -> 1 output feature, bias enabled (the default).
linear = torch.nn.Linear(in_features=4, out_features=1)

linear

Linear(in_features=4, out_features=1, bias=True)

In [149]:
# `fields` is an integer tensor; cast it to float32 before passing it through
# the layer. The result has one row per sample and a single column.
linear(fields.to(torch.float32))

tensor([[0.2202],
        [0.0043]], grad_fn=<AddmmBackward0>)

In [4]:
# `a` is not assigned by any cell visible in this notebook, so this cell only
# ran thanks to leftover kernel state. Define it explicitly: this is the 0/1
# vector implied by the printed outer product (it equals the first row of
# `fields`).
a = torch.tensor([1, 0, 1, 1])

# c[i, j] is 1 exactly when both a[i] and a[j] are 1.
c = torch.outer(a, a)

c

tensor([[1, 0, 1, 1],
        [0, 0, 0, 0],
        [1, 0, 1, 1],
        [1, 0, 1, 1]])

In [5]:
# Strict upper triangle of the pair mask: every unordered pair (i < j) once,
# diagonal excluded.
c_triu = c.triu(diagonal=1)

c_triu

tensor([[0, 0, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 0]])

In [6]:
# `b` is never assigned in the visible notebook. The printed result of this
# cell equals the 4x4 matrix of pairwise weight dot products multiplied
# element-wise by `c`, so bind `b` to that matrix explicitly instead of relying
# on leftover kernel state.
b = full_coupling_matrix

# Zero every entry whose row or column is switched off in the pair mask `c`.
d = b * c

d

tensor([[   5,    0,   65,   95],
        [   0,    0,    0,    0],
        [  65,    0, 1325, 1955],
        [  95,    0, 1955, 2885]])

In [7]:
# d is symmetric, so halving its total counts each pair (i, j) once only after
# the diagonal (the i == j self-products) has been removed. The plain
# d.sum() / 2 also kept half of the diagonal and returned 4222.5 instead of the
# pairwise total.
(d.sum() - d.diagonal().sum()) / 2

tensor(2115.)

In [8]:
# Seed the global RNG before any random initialisation so the embedding
# weights, and every value derived from them, are reproducible after a kernel
# restart.
torch.manual_seed(0)

# Lookup table with 4 rows of 3-dimensional vectors.
e = torch.nn.Embedding(4, 3)

# Re-initialise in place with Xavier/Glorot uniform values. The init functions
# already run without gradient tracking, so pass the Parameter itself rather
# than reaching through the legacy `.data` attribute.
torch.nn.init.xavier_uniform_(e.weight)

e.weight

Parameter containing:
tensor([[-0.2159, -0.1193, -0.5109],
        [-0.8812,  0.3255, -0.1884],
        [ 0.6601, -0.8312,  0.6566],
        [ 0.5007, -0.5679,  0.1766]], requires_grad=True)

In [9]:
# Look up a (2, 2) grid of row ids; the result stacks the matching rows of
# e.weight into a (2, 2, 3) tensor.
e(
    torch.tensor([[0, 1],
                  [1, 2]])
)

tensor([[[-0.2159, -0.1193, -0.5109],
         [-0.8812,  0.3255, -0.1884]],

        [[-0.8812,  0.3255, -0.1884],
         [ 0.6601, -0.8312,  0.6566]]], grad_fn=<EmbeddingBackward0>)

In [10]:
# Dot product of every pair of embedding rows -> symmetric 4x4 matrix.
# Computed under no_grad so the result carries no autograd history. This
# replaces the legacy `.data` access and drops the first assignment, which was
# overwritten immediately and never read.
with torch.no_grad():
    coupling_matrix = e.weight @ e.weight.T

coupling_matrix

tensor([[ 0.3218,  0.2477, -0.3788, -0.1306],
        [ 0.2477,  0.9179, -0.9759, -0.6593],
        [-0.3788, -0.9759,  1.5577,  0.9184],
        [-0.1306, -0.6593,  0.9184,  0.6043]])

In [11]:
# Element-wise product with the 0/1 strict-upper-triangular pair mask: only the
# couplings of the selected pairs (i < j) survive, everything else becomes 0.
masked_weights = torch.mul(coupling_matrix, c_triu)

masked_weights

tensor([[ 0.0000,  0.0000, -0.3788, -0.1306],
        [ 0.0000,  0.0000, -0.0000, -0.0000],
        [-0.0000, -0.0000,  0.0000,  0.9184],
        [-0.0000, -0.0000,  0.0000,  0.0000]])

In [12]:
# Total of all surviving pairwise couplings (a 0-d tensor).
torch.sum(masked_weights)

tensor(0.4091)

In [13]:
# torch.nonzero returns one row per non-zero entry and one column per input
# dimension. Squeeze only that column axis: a bare .squeeze() would also drop
# the row axis when exactly one entry is non-zero, turning the index vector
# into a 0-d tensor.
indices = torch.nonzero(a).squeeze(dim=1)

indices

tensor([0, 2, 3])

In [14]:
# Gather the embedding rows at the positions listed in `indices`
# (one 3-dimensional row per index).
selected_weights = e(input=indices)

selected_weights

tensor([[-0.2159, -0.1193, -0.5109],
        [ 0.6601, -0.8312,  0.6566],
        [ 0.5007, -0.5679,  0.1766]], grad_fn=<EmbeddingBackward0>)

In [15]:
# Batched variant: two 0/1 rows stacked into a (2, 4) tensor.
a_alt = torch.tensor([
    [1, 0, 1, 1],
    [1, 0, 0, 0],
])

# For a 2-D input, nonzero yields one (row, column) pair per non-zero entry.
indices_alt = a_alt.nonzero()

indices_alt

tensor([[0, 0],
        [0, 2],
        [0, 3],
        [1, 0]])

In [16]:
# Look up an embedding row for every entry of indices_alt; with 4 index pairs
# the result has shape (4, 2, 3).
# NOTE(review): indices_alt holds (row, column) pairs, so its first column (the
# row position inside a_alt) is also looked up as an embedding id; the printed
# output accordingly starts most pairs with the row-0 vector. This looks
# unintended — confirm. A per-sample mask such as a_alt.unsqueeze(-1) * e.weight
# may be what was meant, but it changes the shape the following cells receive.
selected_weights = e(indices_alt)

selected_weights

tensor([[[-0.2159, -0.1193, -0.5109],
         [-0.2159, -0.1193, -0.5109]],

        [[-0.2159, -0.1193, -0.5109],
         [ 0.6601, -0.8312,  0.6566]],

        [[-0.2159, -0.1193, -0.5109],
         [ 0.5007, -0.5679,  0.1766]],

        [[-0.8812,  0.3255, -0.1884],
         [-0.2159, -0.1193, -0.5109]]], grad_fn=<EmbeddingBackward0>)

In [17]:
# Batched product of every (2, 3) slice with its own transpose -> one 2x2
# matrix of row dot products per slice.
coupling_matrix_alt = selected_weights @ selected_weights.transpose(-2, -1)

coupling_matrix_alt

tensor([[[ 0.3218,  0.3218],
         [ 0.3218,  0.3218]],

        [[ 0.3218, -0.3788],
         [-0.3788,  1.5577]],

        [[ 0.3218, -0.1306],
         [-0.1306,  0.6043]],

        [[ 0.9179,  0.2477],
         [ 0.2477,  0.3218]]], grad_fn=<UnsafeViewBackward0>)

In [18]:
# Keep only the strictly-upper-triangular entry of each 2x2 slice, i.e. the
# single off-diagonal dot product.
masked_weights_alt = coupling_matrix_alt.triu(diagonal=1)

masked_weights_alt

tensor([[[ 0.0000,  0.3218],
         [ 0.0000,  0.0000]],

        [[ 0.0000, -0.3788],
         [ 0.0000,  0.0000]],

        [[ 0.0000, -0.1306],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.2477],
         [ 0.0000,  0.0000]]], grad_fn=<TriuBackward0>)

In [19]:
# Collapse both matrix axes, leaving one value per leading-axis slice.
torch.sum(masked_weights_alt, dim=(1, 2))

tensor([ 0.3218, -0.3788, -0.1306,  0.2477], grad_fn=<SumBackward1>)