In [1]:
import torch
from torch import nn
from copy import deepcopy
import numpy as np

## Set up dimensions
All different on purpose, so we can track which one goes where!

Here, the `intput_dim` is meant to be the dimension of the input from a model which we want to adapt to use the same embedding space that we've learned in our embedding layer.

In [3]:
input_dim = 3
embedding_size = 5
vocab_size = 7
batch_size = 2

## Construct a simple rotation matrix
Permutes the first 3 elements. 
For a column vector $e = (a, b, c, d, e)$, 
$$Qe = (c, a, b, d, e)$$

In [5]:
Q = torch.eye(embedding_size)
Q[0, 2] = 1
Q[1, 1] = 0
Q[1, 0] = 1
Q[2, 2] = 0
Q[2, 1] = 1
Q

tensor([[1., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])

## Construct an embedding layer and a linear layer which share (part of) an embedding matrix.

We construct a torch Embedding layer of with vocab size $v$ and embedding size $e$, whose embedding matrix $E$ will have dimension $(v, e)$.

We also take the first `input_dim` elements (rows) of the embedding matrix and use that $(i, e)$ sized sub-matrix to construct a linear layer with input dim $i$ and output dim $e$, the same as the embedding size.

In [7]:
def get_lin_from_embed(embedding_layer, input_dim, bias=True):
    lin_from_embedding = torch.nn.Linear(input_dim, embedding_layer.weight.shape[-1], bias=bias)
    embedding_matrix = embedding_layer.weight.data
    # Note: The shape of the weight matrix must be [embedding_size, input_dim], so we must transpose in the next line:
    W_ = embedding_matrix.T[:, 0:3]
    if bias:
        b_ = lin_from_embedding.bias.data
        # Transpose to broadcast bias over the columns of W, transpose back to retain the correct shape.
        lin_from_embedding.weight.data = (W_.T - b_).T
    else:
        lin_from_embedding.weight.data = W_
    return lin_from_embedding

In [8]:
embedding_layer = torch.nn.Embedding(vocab_size, embedding_size)
lin_from_embedding = get_lin_from_embed(embedding_layer, input_dim, bias=False)

To check that we have got things right so far, let's inspect the first 2 elements of the embedding matrix, along with the output of the linear layer applied to unit vectors in the x and y directions.

In [10]:
emb_idx = torch.tensor([0, 1])
embedding_layer(emb_idx)

tensor([[ 0.5562,  0.3084,  0.0092, -1.4830, -0.4561],
        [ 1.8453, -0.7494, -0.4230, -1.3185,  0.1440]],
       grad_fn=<EmbeddingBackward0>)

In [11]:
lin_input = torch.zeros([batch_size, input_dim])
lin_input[0, 0] = 1
lin_input[1, 1] = 1
lin_input  # Each input is a row vector.

tensor([[1., 0., 0.],
        [0., 1., 0.]])

In [12]:
lin_from_embedding(lin_input)

tensor([[ 0.5562,  0.3084,  0.0092, -1.4830, -0.4561],
        [ 1.8453, -0.7494, -0.4230, -1.3185,  0.1440]], grad_fn=<MmBackward0>)

## Rotating

In [14]:
def rotate_embedding(embedding_layer, Q):
    # Deepcopy to avoid mutating the original layer.
    rotated_embedding_layer = deepcopy(embedding_layer)
    W = rotated_embedding_layer.weight.data
    rotated_embedding_layer.weight.data = torch.matmul(W, Q)
    return rotated_embedding_layer

In [15]:
def rotate_linear(linear_layer: nn.Linear, Q: torch.Tensor):
    rotated_linear_layer = deepcopy(linear_layer)
    W_ = rotated_linear_layer.weight.data
    rotated_linear_layer.weight.data = torch.matmul(Q.T, W_)
    if linear_layer.bias is not None:
        b_ = rotated_linear_layer.bias.data
        # Either of the next 2 lines work identially:
        # rotated_linear_layer.bias.data = torch.matmul(b_, Q)
        rotated_linear_layer.bias.data = torch.matmul(Q.T, b_)
    return rotated_linear_layer
        

In [16]:
rotated_embedding_layer = rotate_embedding(embedding_layer, Q)
rotated_lin_from_embedding = rotate_linear(lin_from_embedding, Q)

Check again that these two outputs match each other, after rotation.

In [18]:
rotated_embedding_layer(emb_idx)

tensor([[ 0.8646,  0.0092,  0.5562, -1.4830, -0.4561],
        [ 1.0960, -0.4230,  1.8453, -1.3185,  0.1440]],
       grad_fn=<EmbeddingBackward0>)

In [19]:
rotated_lin_from_embedding(lin_input)

tensor([[ 0.8646,  0.0092,  0.5562, -1.4830, -0.4561],
        [ 1.0960, -0.4230,  1.8453, -1.3185,  0.1440]], grad_fn=<MmBackward0>)

# And now, with bias

In [21]:
embedding_layer = torch.nn.Embedding(vocab_size, embedding_size)
lin_from_embedding = get_lin_from_embed(embedding_layer, input_dim, bias=True)

Double check that the bias broadcasting has worked correctly in the linear layer. The next 2 cells should have identical outputs.

In [23]:
embedding_layer(emb_idx)

tensor([[ 1.2411, -0.6772, -0.6461,  0.5904, -1.2008],
        [-0.8815,  0.5253,  1.4727, -0.8439,  1.6534]],
       grad_fn=<EmbeddingBackward0>)

In [24]:
lin_from_embedding(lin_input)

tensor([[ 1.2411, -0.6772, -0.6461,  0.5904, -1.2008],
        [-0.8815,  0.5253,  1.4727, -0.8439,  1.6534]],
       grad_fn=<AddmmBackward0>)

But the weight in the linear layer should be different.

In [26]:
lin_from_embedding.weight.data.T[0:2, :]

tensor([[ 0.8945, -0.1401, -0.6745,  0.0688, -1.0219],
        [-1.2280,  1.0624,  1.4443, -1.3655,  1.8322]])

## Rotating
Rotate as we did before:

In [28]:
rotated_embedding_layer = rotate_embedding(embedding_layer, Q)
rotated_lin_from_embedding = rotate_linear(lin_from_embedding, Q)

If the rotation functions above are correct, the next 2 cells should be identical

In [30]:
rotated_embedding_layer(emb_idx)

tensor([[ 0.5639, -0.6461,  1.2411,  0.5904, -1.2008],
        [-0.3562,  1.4727, -0.8815, -0.8439,  1.6534]],
       grad_fn=<EmbeddingBackward0>)

In [31]:
rotated_lin_from_embedding(lin_input)

tensor([[ 0.5639, -0.6461,  1.2411,  0.5904, -1.2008],
        [-0.3562,  1.4727, -0.8815, -0.8439,  1.6534]],
       grad_fn=<AddmmBackward0>)

## Sanity check
Let's manually rotate the outputs of the original linear layer, and check them against the outputs of the rotated linear layer.

In [33]:
x = torch.randn([10, input_dim])
output = lin_from_embedding(x)
manually_rotated_output = torch.matmul(output, Q).detach().numpy()

rotated_output = rotated_lin_from_embedding(x).detach().numpy()

delta = rotated_output - manually_rotated_output

assert np.allclose(manually_rotated_output, rotated_output)
print(delta)

[[-4.4703484e-08  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [-4.7683716e-07  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 5.9604645e-08  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [-5.9604645e-08  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [-2.3841858e-07  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]]
