In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from torch_random_fields.utils.misc import chain, make_grid_edges, generate_edges, batch_pad, batch_mask_by_len, batch_lens
from torch_random_fields.utils.loopy_belief_propagation import loopy_belief_propagation
from torch_random_fields.utils.batch_lbp import batch_lbp

Here we construct a batch of four sentences whose lengths are `(15, 12, 11, 13)` (the `sizes` tuple below), padded to the longest length.

For each sentence whose length is $n$:

- the number of unary potentials is $n$

- the number of binary potentials is $n-1$

In [3]:
SEED = 0  # fixed so the random potentials (and the decodings printed below) are reproducible
n_states = 7  # number of discrete states each node can take
sizes = (15, 12, 11, 13)  # length of each sentence in the batch
bsz = len(sizes)
max_1_size = max(sizes)  # padded length of the unary dimension
max_2_size = max_1_size - 1  # kept for reference; not used by the cells below

torch.manual_seed(SEED)
torch.set_printoptions(precision=2)

# Unary potentials: one n_states-vector per (sentence, position), padded to max_1_size.
# Positions beyond a sentence's length are padding; b1_masks is built from `sizes`
# (presumably True on real positions -- confirm against batch_mask_by_len).
b1_phis = torch.normal(mean=0.0, std=1.0, size=(bsz, max_1_size, n_states))
b1_masks = torch.tensor(batch_mask_by_len(sizes)).bool()

# Binary potentials: one (n_states x n_states) table per edge. Edge lists of shorter
# sentences are padded with the dummy edge (0, 0); b2_masks is built from the true edge
# counts so those padded edges can be ignored.
# NOTE(review): meaning of generate_edges' second argument (2) is not visible here -- confirm.
edges = [generate_edges(size, 2) for size in sizes]
b2_edges = torch.tensor(batch_pad(edges, pad_ele=(0, 0)))
b2_phis = torch.normal(mean=0.0, std=1.0, size=(bsz, b2_edges.shape[1], n_states, n_states))
b2_masks = torch.tensor(batch_mask_by_len(batch_lens(edges))).bool()


In [4]:
def non_batch_lbp(
    bat_unary_potentials=None,
    bat_unary_masks=None,
    bat_binary_potentials=None,
    bat_binary_masks=None,
    bat_binary_edges=None,
):
    """Run ``loopy_belief_propagation`` on each sentence separately, as a reference.

    Every argument is optional and falls back to the notebook-level tensor playing the
    same role (``b1_phis``, ``b1_masks``, ``b2_phis``, ``b2_masks``, ``b2_edges``), so
    ``non_batch_lbp()`` with no arguments behaves exactly as before. The keyword names
    mirror ``batch_lbp`` so both functions can be fed the same inputs.

    Each sentence's true length is the number of True entries in its mask row; the padded
    tail of the potentials/edges is sliced off before calling the per-sentence routine.

    Returns:
        A tensor with one row per sentence containing ``loopy_belief_propagation``'s
        per-node output (presumably decoded state indices), right-padded with -1 by
        ``batch_pad``.
    """
    unary = b1_phis if bat_unary_potentials is None else bat_unary_potentials
    unary_masks = b1_masks if bat_unary_masks is None else bat_unary_masks
    binary = b2_phis if bat_binary_potentials is None else bat_binary_potentials
    binary_masks = b2_masks if bat_binary_masks is None else bat_binary_masks
    edges = b2_edges if bat_binary_edges is None else bat_binary_edges

    # Loop-invariant: compute every sentence's length once, not on each iteration.
    unary_lens = unary_masks.sum(-1).tolist()
    binary_lens = binary_masks.sum(-1).tolist()

    decoded = []
    for i, (n_unary, n_binary) in enumerate(zip(unary_lens, binary_lens)):
        states = loopy_belief_propagation(
            unary_potentials=unary[i][:n_unary],
            binary_potentials=binary[i][:n_binary],
            binary_edges=edges[i][:n_binary],
        )
        decoded.append(states.tolist())
    return torch.tensor(batch_pad(decoded, -1))

In [5]:
# Decode the whole padded batch in one call; the first return value is not needed here.
_, ret1 = batch_lbp(
    bat_unary_potentials=b1_phis,
    bat_unary_masks=b1_masks,
    bat_binary_potentials=b2_phis,
    bat_binary_masks=b2_masks,
    bat_binary_edges=b2_edges)
print(ret1)

# Reference result: decode each sentence on its own, padded with -1.
ret2 = non_batch_lbp()
print(ret2)

# Count mismatching positions. Summing the signed difference (ret1 - ret2) is not a valid
# equality check because +1 and -1 differences cancel; compare element-wise instead.
assert ret1.shape == ret2.shape, (ret1.shape, ret2.shape)
n_mismatch = (ret1 != ret2).sum()
print(n_mismatch)
assert n_mismatch.item() == 0, "batched and per-sentence LBP disagree"

tensor([[ 3,  3,  4,  4,  6,  1,  5,  2,  4,  0,  4,  5,  0,  2,  0],
        [ 2,  6,  1,  1,  3,  1,  5,  3,  4,  0,  1,  0, -1, -1, -1],
        [ 3,  1,  5,  6,  6,  0,  1,  2,  3,  1,  1, -1, -1, -1, -1],
        [ 5,  0,  1,  4,  3,  2,  1,  3,  6,  2,  3,  6,  2, -1, -1]])
tensor([[ 3,  3,  4,  4,  6,  1,  5,  2,  4,  0,  4,  5,  0,  2,  0],
        [ 2,  6,  1,  1,  3,  1,  5,  3,  4,  0,  1,  0, -1, -1, -1],
        [ 3,  1,  5,  6,  6,  0,  1,  2,  3,  1,  1, -1, -1, -1, -1],
        [ 5,  0,  1,  4,  3,  2,  1,  3,  6,  2,  3,  6,  2, -1, -1]])
tensor(0)


In [6]:
%%timeit -n10
# Benchmark: one batched LBP call over the whole padded batch.
# -n10 runs the call 10 times per run; IPython reports the mean/std over its default 7 runs.
batch_lbp(
    bat_unary_potentials=b1_phis,
    bat_unary_masks=b1_masks,
    bat_binary_potentials=b2_phis,
    bat_binary_masks=b2_masks,
    bat_binary_edges=b2_edges
)

11.5 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%%timeit -n10
# Benchmark: the per-sentence reference loop, timed with the same -n10 setting as the batched cell.
# NOTE(review): this also times the Python-side .tolist() / batch_pad / torch.tensor
# conversions inside non_batch_lbp, which may have no counterpart inside batch_lbp, so the
# speedup is not purely LBP-vs-LBP.
non_batch_lbp()

30.3 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
