In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import brunoflow as bf
from brunoflow.ad.utils import check_node_equals_tensor, check_node_allclose_tensor
from jax import numpy as jnp, random

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters: PRNG seed, embedding table size, embedding width and pad index.
random_key_val = 42
num_embeddings, embedding_dim, padding_idx = 5, 3, 1
# JAX PRNG key derived from the seed above.
random_key = random.PRNGKey(random_key_val)


In [3]:
# Create and inspect a torch embedding.
# Built from the hyperparameters declared earlier (rather than repeating the
# literals) so every embedding in this notebook is guaranteed the same shape.
import torch
emb = torch.nn.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
# The row at padding_idx is initialised to zeros by torch.
emb.weight

Parameter containing:
tensor([[-0.2964,  0.1638, -1.2453],
        [ 0.0000,  0.0000,  0.0000],
        [-0.3869, -1.0141,  0.5715],
        [-1.1094, -0.8435,  0.0700],
        [-0.8782, -1.0346,  0.3617]], requires_grad=True)

In [4]:
# Create a bf Embedding from the shared hyperparameters, then backprop a single
# non-pad token (index 2) through a matmul with an all-ones matrix.
emb_bf = bf.net.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
# The ones matrix needs embedding_dim rows to be multiplied with one embedding row.
bf.matmul(emb_bf(jnp.array([2])), jnp.ones(shape=(embedding_dim, 4))).backprop()
# Only the row of the looked-up token (row 2) should receive a gradient.
print(emb_bf.weight.grad)

[[0. 0. 0.]
 [0. 0. 0.]
 [4. 4. 4.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [5]:
# Inspect grad for a non-pad token - note that the grad is the same for bf and torch!
out = torch.matmul(emb(torch.tensor([2])), torch.ones(size=(embedding_dim, 4), requires_grad=True))
# `out` is not a scalar, so the upstream gradient has to be supplied explicitly.
out.backward(gradient=torch.ones_like(out))
# Convert the torch gradient to a NumPy array before handing it to jnp, instead
# of relying on jnp implicitly coercing a torch.Tensor.
assert jnp.array_equal(emb_bf.weight.grad, emb.weight.grad.numpy())
emb.weight.grad


tensor([[0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [6]:
# Inspect grad for a pad token for torch - note that the grad DID NOT change
# from before in the pad token index row.
# The pad token is looked up via padding_idx rather than a repeated literal.
out = torch.matmul(emb(torch.tensor([padding_idx])), torch.ones(size=(embedding_dim, 4), requires_grad=True))
out.backward(gradient=torch.ones_like(out))
emb.weight.grad

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [7]:
# Inspect grad for a pad token for bf - note that the grad DID NOT change from
# before in the pad token index row. Same behavior as torch.
# The pad token is looked up via padding_idx rather than a repeated literal.
bf.matmul(emb_bf(jnp.array([padding_idx])), jnp.ones(shape=(embedding_dim, 4))).backprop()
print(emb_bf.weight.grad)
assert jnp.array_equal(emb_bf.weight.grad, emb.weight.grad.numpy())

[[0. 0. 0.]
 [0. 0. 0.]
 [4. 4. 4.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [8]:
# List the modules reported by the bf embedding and by the torch embedding so
# the two listings can be compared side by side.
print("bf modules:", list(emb_bf.modules()))
print("torch modules:", list(emb.modules()))

bf modules: [Embedding(5, 3, padding_idx=1)]
torch modules: [Embedding(5, 3, padding_idx=1)]


In [9]:
# Both embeddings expose their parameters through .parameters(); for torch the
# gradient is printed next to each parameter.
print(list(emb_bf.parameters()))
print([(param, param.grad) for param in emb.parameters()])

[node(name: emb weights (5, 3), val: [[-0.716899   -0.20865498 -2.5713923 ]
 [ 0.          0.          0.        ]
 [-0.8396519   0.3010434   0.1421263 ]
 [-1.7631724  -1.6755074   0.31390068]
 [ 0.5912831   0.5325395  -0.9133108 ]], grad: [[0. 0. 0.]
 [0. 0. 0.]
 [4. 4. 4.]
 [0. 0. 0.]
 [0. 0. 0.]])]
[(Parameter containing:
tensor([[-0.2964,  0.1638, -1.2453],
        [ 0.0000,  0.0000,  0.0000],
        [-0.3869, -1.0141,  0.5715],
        [-1.1094, -0.8435,  0.0700],
        [-0.8782, -1.0346,  0.3617]], requires_grad=True), tensor([[0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [0., 0., 0.],
        [0., 0., 0.]]))]


In [10]:
# Print the buffers of the bf embedding and of the torch embedding for comparison.
print(list(emb_bf.buffers()))
print(list(emb.buffers()))

[]
[]


### Saving a torch embedding state dict and reloading it as bf

In [11]:
# Write the torch embedding's state dict to disk so it can be reloaded later.
save_torch_path = "emb.pt"
torch.save(obj=emb.state_dict(), f=save_torch_path)

In [12]:
# Rebuild a torch embedding and restore the saved weights into it.
# The shape comes from the shared hyperparameters: it has to match the saved
# embedding exactly for load_state_dict to succeed.
emb_loaded = torch.nn.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
# The cell's displayed value is load_state_dict's key-matching report.
emb_loaded.load_state_dict(torch.load(save_torch_path))


<All keys matched successfully>

In [13]:
# Check that a bf embedding loaded from a torch state dict ends up with the
# same values as the torch embedding loaded from the same file.
# The shape comes from the shared hyperparameters so it always matches the
# embedding that was saved.
emb_bf_loaded = bf.net.Embedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    padding_idx=padding_idx,
)
emb_bf_loaded.load_state_dict(torch.load(save_torch_path))
print(emb_bf_loaded.weight)
print(emb_loaded.weight)
assert check_node_equals_tensor(emb_bf_loaded.weight, emb_loaded.weight)

node(name: emb weights (5, 3), val: [[-0.2964264   0.1638094  -1.2453288 ]
 [ 0.          0.          0.        ]
 [-0.38688654 -1.0140668   0.5715288 ]
 [-1.1094499  -0.84354985  0.0699667 ]
 [-0.8781804  -1.0346296   0.3616861 ]], grad: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]])
Parameter containing:
tensor([[-0.2964,  0.1638, -1.2453],
        [ 0.0000,  0.0000,  0.0000],
        [-0.3869, -1.0141,  0.5715],
        [-1.1094, -0.8435,  0.0700],
        [-0.8782, -1.0346,  0.3617]], requires_grad=True)


In [15]:
# Index the bf weight node with the same row index repeated three times.
emb_bf.weight[jnp.array([2] * 3)]

node(name: (getitem emb weights (5, 3) arg=[2 2 2]), val: [[-0.8396519  0.3010434  0.1421263]
 [-0.8396519  0.3010434  0.1421263]
 [-0.8396519  0.3010434  0.1421263]], grad: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]])

### BF and Torch equality tests

In [None]:
### 1D ARRAY
# Feed one flat list of indices to both loaded embeddings and verify that the
# summed forward output and the resulting weight gradients agree.
inds = [0, 1, 2, 2, 3, 2]

# Clear any gradients accumulated on either embedding before this check.
emb_loaded.zero_grad()
emb_bf_loaded.zero_grad()

# Forward pass, reduced to a scalar on both sides.
embs_bf_out = bf.reduce_sum(emb_bf_loaded(bf.Node(jnp.array(inds), name="INDICES")))
embs_out = emb_loaded(torch.tensor(inds)).sum()

print("bf output of emb:", embs_bf_out)
print("torch output of emb:", embs_out)
assert check_node_allclose_tensor(embs_bf_out, embs_out)

# Backward pass from the scalar on both sides.
embs_bf_out.backprop(values_to_compute=("grad",), verbose=False)
embs_out.backward(gradient=torch.ones_like(embs_out))

print("\nbf grad:", emb_bf_loaded.weight.grad)
print("torch grad:", emb_loaded.weight.grad.numpy())
assert jnp.array_equal(emb_bf_loaded.weight.grad, emb_loaded.weight.grad.numpy())
print("bf loaded and original grads:", emb_bf_loaded.weight.grad, emb_loaded.weight.grad)

In [None]:
### 2D ARRAY
# Feed a 2 x 6 grid of indices to both loaded embeddings and verify that the
# summed forward output and the resulting weight gradients agree.
inds = [
    [0, 1, 2, 2, 2, 3],
    [2, 3, 4, 1, 1, 1],
]

# Clear any gradients accumulated on either embedding before this check.
emb_loaded.zero_grad()
emb_bf_loaded.zero_grad()

# Forward pass, reduced to a scalar on both sides.
embs_bf_out = bf.reduce_sum(emb_bf_loaded(bf.Node(jnp.array(inds), name="INDICES")))
embs_out = emb_loaded(torch.tensor(inds)).sum()

print("bf output of emb:", embs_bf_out)
print("torch output of emb:", embs_out)
assert check_node_allclose_tensor(embs_bf_out, embs_out)

# Backward pass from the scalar on both sides.
embs_bf_out.backprop(values_to_compute=("grad",), verbose=False)
embs_out.backward(gradient=torch.ones_like(embs_out))

print("\nbf grad:", emb_bf_loaded.weight.grad)
print("torch grad:", emb_loaded.weight.grad.numpy())
assert jnp.array_equal(emb_bf_loaded.weight.grad, emb_loaded.weight.grad.numpy())
print("bf loaded and original grads:", emb_bf_loaded.weight.grad, emb_loaded.weight.grad)

In [20]:
### 2D ARRAY where one dim value is 1
# Feed a single-row (1 x 6) grid of indices to both loaded embeddings and verify
# that the summed forward output and the resulting weight gradients agree.
inds = [[0, 1, 2, 2, 3, 2]]

# Clear any gradients accumulated on either embedding before this check.
emb_loaded.zero_grad()
emb_bf_loaded.zero_grad()

# Forward pass, reduced to a scalar on both sides.
embs_bf_out = bf.reduce_sum(emb_bf_loaded(bf.Node(jnp.array(inds), name="INDICES")))
embs_out = emb_loaded(torch.tensor(inds)).sum()

print("bf output of emb:", embs_bf_out)
print("torch output of emb:", embs_out)
assert check_node_allclose_tensor(embs_bf_out, embs_out)

# Backward pass from the scalar on both sides.
embs_bf_out.backprop(values_to_compute=("grad",), verbose=False)
embs_out.backward(gradient=torch.ones_like(embs_out))

print("\nbf grad:", emb_bf_loaded.weight.grad)
print("torch grad:", emb_loaded.weight.grad.numpy())
assert jnp.array_equal(emb_bf_loaded.weight.grad, emb_loaded.weight.grad.numpy())
print("bf loaded and original grads:", emb_bf_loaded.weight.grad, emb_loaded.weight.grad)

bf output of emb: node(name: (sum get_embedding axis=None keepdims=None), val: -5.749251842498779, grad: 0.0)
torch output of emb: tensor(-5.7493, grad_fn=<SumBackward0>)

bf grad: [[1. 1. 1.]
 [0. 0. 0.]
 [3. 3. 3.]
 [1. 1. 1.]
 [0. 0. 0.]]
torch grad: [[1. 1. 1.]
 [0. 0. 0.]
 [3. 3. 3.]
 [1. 1. 1.]
 [0. 0. 0.]]
bf loaded and original grads: [[1. 1. 1.]
 [0. 0. 0.]
 [3. 3. 3.]
 [1. 1. 1.]
 [0. 0. 0.]] tensor([[1., 1., 1.],
        [0., 0., 0.],
        [3., 3., 3.],
        [1., 1., 1.],
        [0., 0., 0.]])


In [22]:
### 2D ARRAY where other dim value is 1
# Check forward and backward passes to make sure the outputs match.
# Column layout, shape (6, 1): one index per row. Built as a plain nested list
# (the same input type the other test cells use) instead of transposing a JAX
# array and calling the __array__() dunder directly.
inds = [[i] for i in [0, 1, 2, 2, 3, 2]]
emb_bf_loaded.zero_grad()
emb_loaded.zero_grad()
embs_bf_out = emb_bf_loaded(bf.Node(jnp.array(inds), name="INDICES"))
embs_out = emb_loaded(torch.tensor(inds))

# Reduce both outputs to a scalar so they can be compared and backpropagated.
embs_bf_out = bf.reduce_sum(embs_bf_out)
embs_out = torch.sum(embs_out)

print("bf output of emb:", embs_bf_out)
print("torch output of emb:", embs_out)
assert check_node_allclose_tensor(embs_bf_out, embs_out)

embs_bf_out.backprop(values_to_compute=("grad",), verbose=False)
embs_out.backward(gradient=torch.ones_like(embs_out))
print("\nbf grad:", emb_bf_loaded.weight.grad)
print("torch grad:", emb_loaded.weight.grad.numpy())
assert jnp.array_equal(emb_bf_loaded.weight.grad, emb_loaded.weight.grad.numpy())
print("bf loaded and original grads:", emb_bf_loaded.weight.grad, emb_loaded.weight.grad)

bf output of emb: node(name: (sum get_embedding axis=None keepdims=None), val: -5.749251842498779, grad: 0.0)
torch output of emb: tensor(-5.7493, grad_fn=<SumBackward0>)

bf grad: [[1. 1. 1.]
 [0. 0. 0.]
 [3. 3. 3.]
 [1. 1. 1.]
 [0. 0. 0.]]
torch grad: [[1. 1. 1.]
 [0. 0. 0.]
 [3. 3. 3.]
 [1. 1. 1.]
 [0. 0. 0.]]
bf loaded and original grads: [[1. 1. 1.]
 [0. 0. 0.]
 [3. 3. 3.]
 [1. 1. 1.]
 [0. 0. 0.]] tensor([[1., 1., 1.],
        [0., 0., 0.],
        [3., 3., 3.],
        [1., 1., 1.],
        [0., 0., 0.]])


In [23]:
### 3D ARRAY
# Check forward and backward passes to make sure the outputs match.
# The indices are nested three levels deep, shape (3, 2, 3), so this case really
# exercises a 3D input (a list of three flat rows would only be 2D).
# Token counts: 0 -> 2, 2 -> 5, 3 -> 2, 4 -> 1; the pad index 1 occurs 8 times
# and must still end up with a zero gradient.
# NOTE(review): assumes bf.net.Embedding accepts index arrays of rank 3 - confirm.
inds = [
    [[0, 1, 2], [2, 2, 3]],
    [[2, 3, 4], [1, 1, 1]],
    [[1, 0, 1], [1, 2, 1]],
]
emb_bf_loaded.zero_grad()
emb_loaded.zero_grad()
embs_bf_out = emb_bf_loaded(bf.Node(jnp.array(inds), name="INDICES"))
embs_out = emb_loaded(torch.tensor(inds))

# Reduce both outputs to a scalar so they can be compared and backpropagated.
embs_bf_out = bf.reduce_sum(embs_bf_out)
embs_out = torch.sum(embs_out)

print("bf output of emb:", embs_bf_out)
print("torch output of emb:", embs_out)
assert check_node_allclose_tensor(embs_bf_out, embs_out)

embs_bf_out.backprop(values_to_compute=("grad",), verbose=False)
embs_out.backward(gradient=torch.ones_like(embs_out))
print("\nbf grad:", emb_bf_loaded.weight.grad)
print("torch grad:", emb_loaded.weight.grad.numpy())
assert jnp.array_equal(emb_bf_loaded.weight.grad, emb_loaded.weight.grad.numpy())
print("bf loaded and original grads:", emb_bf_loaded.weight.grad, emb_loaded.weight.grad)

bf output of emb: node(name: (sum get_embedding axis=None keepdims=None), val: -12.22020435333252, grad: 0.0)
torch output of emb: tensor(-12.2202, grad_fn=<SumBackward0>)

bf grad: [[2. 2. 2.]
 [0. 0. 0.]
 [5. 5. 5.]
 [2. 2. 2.]
 [1. 1. 1.]]
torch grad: [[2. 2. 2.]
 [0. 0. 0.]
 [5. 5. 5.]
 [2. 2. 2.]
 [1. 1. 1.]]
bf loaded and original grads: [[2. 2. 2.]
 [0. 0. 0.]
 [5. 5. 5.]
 [2. 2. 2.]
 [1. 1. 1.]] tensor([[2., 2., 2.],
        [0., 0., 0.],
        [5., 5., 5.],
        [2., 2., 2.],
        [1., 1., 1.]])
