In [1]:
%load_ext autoreload
# Reload imported modules before each cell executes, so edits to the library code
# (e.g. pystruct.inference.batch_lbp) take effect without restarting the kernel.
%autoreload 2

In [2]:
import genops
from pystruct.utils import make_grid_edges, generate_binary_edges
from pystruct.inference.lbp import lbp_plus, compute_energy_plus
from pystruct.inference.batch_lbp import batch_lbp
from lunanlp import batch_pad, batch_mask_by_len, batch_lens, chain

Here we construct a batch of four sentences whose lengths are `(15, 12, 11, 13)` (the `sizes` tuple in the next cell).

For each sentence whose length is $n$:

- the number of unary potentials is $n$

- the number of binary potentials is $n-1$

In [3]:
# Batch configuration: one sentence per entry of `sizes`; each position takes one of `n_states` states.
n_states = 7
sizes = (15, 12, 11, 13)
bsz = len(sizes)
max_1_size = max(sizes)  # longest sentence; padded width of the unary tensors below
max_2_size = max_1_size - 1  # NOTE(review): not referenced by any cell shown here

# Use the PyTorch backend for genops and print with precision 2.
genops.set_backend(genops.TORCH)
genops.set_printoptions(precision=2)

# Unary potentials padded to `max_1_size`, plus a boolean mask built from the true lengths.
# NOTE(review): no random seed is set anywhere, so presumably these potentials (and the
# decoded outputs shown further down) change on every run.
b1_phis = genops.normal(shape=(bsz, max_1_size, n_states))
b1_masks = genops.tensor(batch_mask_by_len(sizes)).bool()

# Binary edges per sentence, padded with the dummy edge (0, 0); one n_states x n_states
# potential table per (padded) edge slot, and a boolean mask built from the real edge counts.
# Keep this statement order: the two `genops.normal` calls presumably draw from a shared RNG.
edges = [generate_binary_edges(size, 2) for size in sizes]
b2_edges = genops.tensor(batch_pad(edges, pad_ele=(0, 0)))
b2_phis = genops.normal(shape=(bsz, b2_edges.shape[1], n_states, n_states))
b2_masks = genops.tensor(batch_mask_by_len(batch_lens(edges))).bool()


In [7]:
def sequential_lbp():
    """Run ``lbp_plus`` separately on every sentence of the batch.

    This is the per-sentence reference (and timing baseline) for ``batch_lbp``.
    Each sentence's tensors are cut back to its true length using the masks,
    so no padded positions or padded edges are passed to ``lbp_plus``.

    Reads the notebook globals ``b1_phis``, ``b1_masks``, ``b2_phis``,
    ``b2_masks`` and ``b2_edges``.

    Returns:
        list: one ``lbp_plus`` result per sentence, in batch order.
    """
    # Reduce the masks to per-sentence lengths once, as plain ints. The previous
    # version re-summed the full mask tensors three times per sentence inside the
    # loop, and that overhead was charged to this baseline in the %%timeit comparison.
    unary_lens = b1_masks.sum(-1).tolist()
    binary_lens = b2_masks.sum(-1).tolist()
    return [
        lbp_plus(
            unary_potentials=unary_phis[:n_unary],
            binary_potentials=binary_phis[:n_binary],
            binary_edges=sent_edges[:n_binary],
        )
        for unary_phis, binary_phis, sent_edges, n_unary, n_binary
        in zip(b1_phis, b2_phis, b2_edges, unary_lens, binary_lens)
    ]

In [8]:
# Batched inference over the whole padded batch. `ret` holds one row of integer
# states per sentence; padded positions appear as -1 in the printed output.
_, ret = batch_lbp(
    bat_unary_potentials=b1_phis,
    bat_unary_masks=b1_masks,
    bat_binary_potentials=b2_phis,
    bat_binary_masks=b2_masks,
    bat_binary_edges=b2_edges)
print(ret)

seq_ret = sequential_lbp()
print(seq_ret)

# The batched result must match per-sentence LBP on every real (unpadded) position;
# check it instead of relying on eyeballing the two printouts.
for i, seq_labels in enumerate(seq_ret):
    n_valid = len(seq_labels)
    assert (ret[i][:n_valid] == seq_labels).all(), f"sentence {i}: batch_lbp disagrees with lbp_plus"

tensor([[ 4,  0,  1,  1,  5,  3,  0,  6,  3,  2,  0,  5,  6,  0,  6],
        [ 6,  4,  3,  1,  5,  2,  3,  2,  0,  4,  1,  0, -1, -1, -1],
        [ 6,  0,  5,  4,  6,  0,  0,  5,  3,  3,  6, -1, -1, -1, -1],
        [ 3,  2,  1,  0,  2,  0,  2,  0,  5,  4,  2,  3,  3, -1, -1]])
[tensor([4, 0, 1, 1, 5, 3, 0, 6, 3, 2, 0, 5, 6, 0, 6]), tensor([6, 4, 3, 1, 5, 2, 3, 2, 0, 4, 1, 0]), tensor([6, 0, 5, 4, 6, 0, 0, 5, 3, 3, 6]), tensor([3, 2, 1, 0, 2, 0, 2, 0, 5, 4, 2, 3, 3])]


In [25]:
%%timeit -n10
# Benchmark: a single batched LBP call over all padded sentences at once.
batch_lbp(
    bat_unary_potentials=b1_phis,
    bat_unary_masks=b1_masks,
    bat_binary_potentials=b2_phis,
    bat_binary_masks=b2_masks,
    bat_binary_edges=b2_edges
)

10.1 ms ± 1.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [26]:
%%timeit -n10
# Baseline: one lbp_plus call per sentence, for comparison with the batched timing above.
sequential_lbp()

21.7 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
