In [1]:
import brunoflow as bf
from brunoflow.ad.utils import check_node_equals_tensor
from jax import numpy as jnp, random
import torch

In [None]:
# Shared configuration for the embeddings built in this notebook.
num_embeddings, embedding_dim = 5, 3  # embedding table shape: rows x columns
padding_idx = 1  # row index of the padding token

# JAX PRNG key derived from a fixed integer seed.
random_key_val = 42
random_key = random.PRNGKey(random_key_val)


In [None]:
# Create and inspect a torch embedding built from the shared hyperparameters.
# Seeding torch's global RNG first makes the randomly initialised weights
# identical on every Restart & Run All.
torch.manual_seed(random_key_val)
emb = torch.nn.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
emb.weight

Parameter containing:
tensor([[ 0.7230,  2.1811,  0.9185],
        [ 0.0000,  0.0000,  0.0000],
        [-0.4263, -0.4424, -0.3087],
        [-0.8633,  0.3375,  0.3774],
        [-0.5210,  0.2912, -0.6680]], requires_grad=True)

In [None]:
# Create a bf Embedding from the shared hyperparameters (same configuration as
# the torch embedding), then backprop through a matmul of the embedding of the
# non-pad token 2 with a ones matrix and print the resulting weight grad.
emb_bf = bf.net.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
# The ones matrix needs embedding_dim rows so it can be multiplied with the embedding.
bf.matmul(emb_bf([2]), jnp.ones(shape=(embedding_dim, 4))).backprop()
print(emb_bf.weight.grad)

AssertionError: 

In [None]:
# Inspect grad for a non-pad token - note that the grad is the same for bf and torch!
out = torch.matmul(
    emb(torch.tensor([2])),
    torch.ones(size=(embedding_dim, 4), requires_grad=True),
)
out.backward(gradient=torch.ones_like(out))
# Convert the torch grad to a numpy array before handing it to jnp, exactly as
# the other bf-vs-torch grad comparisons in this notebook do.
assert jnp.array_equal(emb_bf.weight.grad, emb.weight.grad.numpy())
emb.weight.grad


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


tensor([[0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [None]:
# Backprop a pad-token lookup through torch. Note that the grad in the padding
# row (index 1) DID NOT change from before.
pad_embedding = emb(torch.tensor([1]))
ones_matrix = torch.ones(size=(3, 4), requires_grad=True)
out = torch.matmul(pad_embedding, ones_matrix)
out.backward(gradient=torch.ones_like(out))
emb.weight.grad

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [None]:
# Backprop a pad-token lookup through bf. As with torch, the grad in the padding
# row DID NOT change from before, so both libraries still agree.
pad_lookup = emb_bf([1])
bf.matmul(pad_lookup, jnp.ones(shape=(3, 4))).backprop()
print(emb_bf.weight.grad)
torch_grad = emb.weight.grad.numpy()
assert jnp.array_equal(emb_bf.weight.grad, torch_grad)

[[0. 0. 0.]
 [0. 0. 0.]
 [4. 4. 4.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [None]:
# Print the modules reported by each embedding so they can be compared side by side.
for label, embedding in (("bf modules:", emb_bf), ("torch modules:", emb)):
    print(label, list(embedding.modules()))

bf modules: [Embedding(5, 3, padding_idx=1)]
torch modules: [Embedding(5, 3, padding_idx=1)]


In [None]:
# Both embeddings expose their parameters; the torch side is printed together
# with each parameter's grad.
print(list(emb_bf.parameters()))
print([(param, param.grad) for param in emb.parameters()])

[node(name: emb weights (5, 3), val: [[-0.716899   -0.20865498 -2.5713923 ]
 [ 0.          0.          0.        ]
 [-0.8396519   0.3010434   0.1421263 ]
 [-1.7631724  -1.6755073   0.31390068]
 [ 0.5912831   0.5325395  -0.9133108 ]], grad: [[0. 0. 0.]
 [0. 0. 0.]
 [4. 4. 4.]
 [0. 0. 0.]
 [0. 0. 0.]])]
[(Parameter containing:
tensor([[ 0.2225,  0.1144, -0.3670],
        [ 0.0000,  0.0000,  0.0000],
        [-0.8473, -0.0814, -0.0372],
        [ 0.0992,  0.0027,  0.1115],
        [-0.7348, -0.7790, -0.2488]], requires_grad=True), tensor([[0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [0., 0., 0.],
        [0., 0., 0.]]))]


In [None]:
# Print the buffers of the bf and torch embeddings for a side-by-side comparison.
print(list(emb_bf.buffers()))
print(list(emb.buffers()))

[]
[]


### Saving a torch embedding state dict and reloading it into a bf embedding

In [None]:
# Write the torch embedding's state dict to disk.
save_torch_path = "emb.pt"
torch_state_dict = emb.state_dict()
torch.save(torch_state_dict, save_torch_path)

In [None]:
# Rebuild a torch embedding from the shared hyperparameters and restore the
# saved state dict into it.
emb_loaded = torch.nn.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
emb_loaded.load_state_dict(torch.load(save_torch_path))


<All keys matched successfully>

In [None]:
# Check that a bf embedding loaded from a torch state dict ends up with the same
# values as the torch embedding loaded from that same file.
emb_bf_loaded = bf.net.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
emb_bf_loaded.load_state_dict(torch.load(save_torch_path))
print(emb_bf_loaded.weight)
print(emb_loaded.weight)
assert check_node_equals_tensor(emb_bf_loaded.weight, emb_loaded.weight)

node(name: emb weights (5, 3), val: [[ 0.22245939  0.11442263 -0.3670383 ]
 [ 0.          0.          0.        ]
 [-0.8473086  -0.08135143 -0.03718618]
 [ 0.09916912  0.00267494  0.11154094]
 [-0.7347803  -0.77901185 -0.24880375]], grad: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]])
Parameter containing:
tensor([[ 0.2225,  0.1144, -0.3670],
        [ 0.0000,  0.0000,  0.0000],
        [-0.8473, -0.0814, -0.0372],
        [ 0.0992,  0.0027,  0.1115],
        [-0.7348, -0.7790, -0.2488]], requires_grad=True)


In [None]:
# Forward pass: look up the same indices (pad token 1 plus tokens 2 and 3) in
# both loaded embeddings and make sure the outputs match.
inds = [1, 2, 3]
inds_bf, inds_torch = jnp.array(inds), torch.tensor(inds)
embs_bf_out = emb_bf_loaded(inds_bf)
embs_out = emb_loaded(inds_torch)
print("bf output of emb:", embs_bf_out)
print("torch output of emb:", embs_out)
assert check_node_equals_tensor(embs_bf_out, embs_out)

# Backward pass: backprop both outputs and make sure the weight grads match.
embs_bf_out.backprop()
embs_out.backward(gradient=torch.ones_like(embs_out))
grad_bf, grad_torch = emb_bf_loaded.weight.grad, emb_loaded.weight.grad
assert jnp.array_equal(grad_bf, grad_torch.numpy())
print(grad_bf, grad_torch)

bf output of emb: node(name: get_embedding, val: [[ 0.          0.          0.        ]
 [-0.8473086  -0.08135143 -0.03718618]
 [ 0.09916912  0.00267494  0.11154094]], grad: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]])
torch output of emb: tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.8473, -0.0814, -0.0372],
        [ 0.0992,  0.0027,  0.1115]], grad_fn=<EmbeddingBackward0>)
[[0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [0. 0. 0.]] tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.]])


In [None]:
# Call bf.get_embedding directly on freshly loaded weights with every index
# (including the padding index), backprop, and inspect the weight grad.
emb_bf_loaded = bf.net.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
emb_bf_loaded.load_state_dict(torch.load(save_torch_path))
all_indices = list(range(num_embeddings))
out = bf.get_embedding(emb_bf_loaded.weight, all_indices, padding_idx=padding_idx)
out.backprop()
print(emb_bf_loaded.weight.grad)

# Every row that was looked up receives a grad of ones except the padding row,
# which must stay at zero.
expected_grad = jnp.ones((num_embeddings, embedding_dim)).at[padding_idx].set(0.0)
assert jnp.array_equal(emb_bf_loaded.weight.grad, expected_grad)

[[1. 1. 1.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
