# Learning Box Embeddings (PyTorch version) with Example

This tutorial outlines the main functionality available in the Box Embeddings package (PyTorch version).

### 1. Initialize a box tensor and check its parameters

#### Standard Box Tensor
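To represent a tensor as a box, we use the `BoxTensor` class. Its only required parameter is `data`, a tensor whose second-to-last dimension has size 2: index 0 along that dimension holds the lower corner `z`, and index 1 holds the upper corner `Z`.
Let's create a toy example: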

In [1]:
import torch
from box_embeddings.parameterizations.box_tensor import BoxTensor

# Three 2-D boxes: data_x[i, 0] is the lower corner z, data_x[i, 1] is the upper corner Z
data_x = torch.tensor([[[1,2],[-1,5]], [[0,2],[-2,3]], [[-3,3],[-2,4]]])
box_1 = BoxTensor(data_x)
box_1

AllenNLP not available. Registrable won't work.


BoxTensor(tensor([[[ 1,  2],
         [-1,  5]],

        [[ 0,  2],
         [-2,  3]],

        [[-3,  3],
         [-2,  4]]]))

The `BoxTensor` class exposes several properties for inspecting the parameters of our box:

In [2]:
from box_embeddings.parameterizations.box_tensor import BoxTensor
# Lower left coordinate
print(box_1.z)
# Top right coordinate
print(box_1.Z)
# Center coordinate
print(box_1.centre)

tensor([[ 1,  2],
        [ 0,  2],
        [-3,  3]])
tensor([[-1,  5],
        [-2,  3],
        [-2,  4]])
tensor([[ 0.0000,  3.5000],
        [-1.0000,  2.5000],
        [-2.5000,  3.5000]])
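
To make the layout concrete, here is a minimal pure-Python sketch of how the three properties relate to `data`. It assumes, consistently with the outputs above, that the second-to-last axis of `data` holds the pair `(z, Z)`. The helper `corners_and_centre` is a hypothetical name used only for illustration; it is not part of the package.

```python
def corners_and_centre(data):
    """Split a nested list of shape (n_boxes, 2, dim) into z, Z and centre."""
    z = [box[0] for box in data]  # lower corners
    Z = [box[1] for box in data]  # upper corners
    # The centre is the midpoint of the two corners in every dimension.
    centre = [
        [(lo + hi) / 2 for lo, hi in zip(z_i, Z_i)]
        for z_i, Z_i in zip(z, Z)
    ]
    return z, Z, centre


# Same toy data as in the cell above, written as plain lists.
data_x = [[[1, 2], [-1, 5]], [[0, 2], [-2, 3]], [[-3, 3], [-2, 4]]]
z, Z, centre = corners_and_centre(data_x)
print(z)
print(Z)
print(centre)
```

The printed centres agree with `box_1.centre` above. Note that `BoxTensor` does not check that `z <= Z`; in this toy data the first coordinate of every box has `z > Z`.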


Let's broadcast our box to a new shape. Broadcasting is needed for arithmetic operations between boxes of different
shapes. We use `broadcast()` from the `BoxTensor` class; its required parameter, `target_shape`, specifies the new
shape for the box, and the box is modified in place. This is similar to `numpy.broadcast_to()`.

In [3]:
from box_embeddings.parameterizations.box_tensor import BoxTensor
data = torch.tensor([[[1, 2, 3], [3, 4, 6]],
          [[5, 6, 8], [7, 9, 5]]])
box = BoxTensor(data)
print('previous shape:', box.box_shape)
box.broadcast(target_shape=(2, 1, 3))
print('after broadcasting:', box.box_shape)

previous shape: (2, 3)
after broadcasting: (2, 1, 3)


### 2. Box Volume
To calculate the volume of a box, use the all-in-one `Volume` class. By default it returns the log-scaled hard volume.
To return the regular (non-log) volume, set `log_scale=False`. To use the soft volume, set `volume_temperature`
to a non-zero value (the default is 0.0). The `HardVolume` and `SoftVolume` classes are also registered separately
for convenience.

In [12]:
from box_embeddings.modules.volume import Volume, SoftVolume, HardVolume

# Create data as tensors, and initialize a box
data = torch.tensor([[-2.0]*20, [0.0]*20])
box_1 = BoxTensor(data)

# Logged Hard volume using Volume class
print("Logged Hard volume:", Volume(volume_temperature=0.0)(box_1))

# Logged Hard volume using HardVolume method
print("Logged Hard volume:", HardVolume()(box_1))

# Regular Hard volume using Volume class
print("Regular Hard volume:", Volume(volume_temperature=0.0, log_scale=False)(box_1))
print("-----------")
# Logged Soft volume using Volume class
print("Logged Soft volume:", Volume(volume_temperature=1.0)(box_1))

# Logged Soft volume using SoftVolume method
print("Logged Soft volume:", SoftVolume()(box_1))

Logged Hard volume: tensor(13.8629)
Logged Hard volume: tensor(13.8629)
Regular Hard volume: tensor(1048576.)
-----------
Logged Soft volume: tensor(15.0936)
Logged Soft volume: tensor(15.0936)
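
The numbers above can be reproduced by hand. The sketch below is a from-scratch illustration in plain Python, not the package's implementation. It assumes the hard log-volume is the sum of the log side lengths, and that the soft volume replaces each side length `Z - z` with a softplus, `T * log(1 + exp((Z - z) / T))`, where `T` is the volume temperature. The function names and the `eps` clamp are assumptions made for this example.

```python
import math


def hard_log_volume(z, Z, eps=1e-23):
    """Sum of log side lengths; sides are clamped at eps to avoid log(0)."""
    return sum(math.log(max(hi - lo, eps)) for lo, hi in zip(z, Z))


def softplus(x, temperature):
    """Smooth approximation of max(x, 0): T * log(1 + exp(x / T))."""
    return temperature * math.log1p(math.exp(x / temperature))


def soft_log_volume(z, Z, temperature=1.0):
    """Like hard_log_volume, but each side length goes through softplus."""
    return sum(math.log(softplus(hi - lo, temperature)) for lo, hi in zip(z, Z))


# Same box as in the cell above: 20 dimensions, each side has length 2.
z, Z = [-2.0] * 20, [0.0] * 20

hard = hard_log_volume(z, Z)
soft = soft_log_volume(z, Z, temperature=1.0)
print("Logged Hard volume: %.4f" % hard)
print("Regular Hard volume: %d" % round(math.exp(hard)))
print("Logged Soft volume: %.4f" % soft)
```

Under these assumptions the values agree with the outputs above: 20 · log 2 ≈ 13.8629 for the hard volume, 2^20 = 1048576 without the log, and 20 · log(softplus(2)) ≈ 15.0936 for the soft volume at temperature 1.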


### 3. Box Intersection

To calculate the intersection of two boxes (which yields another box defined by a pair `(z, Z)`),
use the all-in-one `Intersection` module. By default, this module computes the
hard intersection. To use the Gumbel intersection, set `intersection_temperature` to a non-zero value (the default is 0.0).
The `HardIntersection` and `GumbelIntersection` classes are also registered separately for convenience.

In [11]:
from box_embeddings.modules.intersection import Intersection, HardIntersection, GumbelIntersection

# Create data as tensors, and initialize two boxes, box_1 and box_2
data_x = torch.tensor([[-2.0]*20, [0.0]*20])
box_1 = BoxTensor(data_x)

y_min = [1/n for n in range(1, 21)]
y_max = [1 - k for k in reversed(y_min)]
data_y = torch.tensor([y_min, y_max], requires_grad=True)
box_2 = BoxTensor(data_y)

# Hard intersection of box_1 and box_2 using the Intersection method
print("Hard intersection:", Intersection(intersection_temperature=0.0)(box_1, box_2))

# Hard intersection of box_1 and box_2 using the HardIntersection method
print("Hard intersection:", HardIntersection()(box_1, box_2))
print("-----------")
# Gumbel intersection of box_1 and box_2 using the Intersection method
print("Gumbel intersection:", Intersection(intersection_temperature=1.0)(box_1, box_2))

# Gumbel intersection of box_1 and box_2 using the GumbelIntersection method
print("Gumbel intersection:", GumbelIntersection()(box_1, box_2))

Hard intersection: BoxTensor(z=tensor([1.0000, 0.5000, 0.3333, 0.2500, 0.2000, 0.1667, 0.1429, 0.1250, 0.1111,
        0.1000, 0.0909, 0.0833, 0.0769, 0.0714, 0.0667, 0.0625, 0.0588, 0.0556,
        0.0526, 0.0500], grad_fn=<MaximumBackward>),
Z=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<MinimumBackward>))
Hard intersection: BoxTensor(z=tensor([1.0000, 0.5000, 0.3333, 0.2500, 0.2000, 0.1667, 0.1429, 0.1250, 0.1111,
        0.1000, 0.0909, 0.0833, 0.0769, 0.0714, 0.0667, 0.0625, 0.0588, 0.0556,
        0.0526, 0.0500], grad_fn=<MaximumBackward>),
Z=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<MinimumBackward>))
-----------
Gumbel intersection: BoxTensor(z=tensor([1.0486, 0.5789, 0.4259, 0.3502, 0.3051, 0.2751, 0.2538, 0.2378, 0.2254,
        0.2155, 0.2074, 0.2007, 0.1950, 0.1901, 0.1859, 0.1822, 0.1789, 0.1760,
        0.1734, 0.1711], grad_fn=<MulBackward0>),
Z=ten
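
To see where these values come from, the following plain-Python sketch recomputes both intersections without the package. It is an illustration under stated assumptions, not the library's code: the hard intersection is taken as the element-wise `max` of the lower corners and `min` of the upper corners, and the Gumbel intersection as their smooth counterparts, `T * logsumexp(z / T)` and `-T * logsumexp(-Z / T)`. All function names here are hypothetical.

```python
import math


def hard_intersection(z1, Z1, z2, Z2):
    """Element-wise max of lower corners and min of upper corners."""
    z = [max(a, b) for a, b in zip(z1, z2)]
    Z = [min(a, b) for a, b in zip(Z1, Z2)]
    return z, Z


def logsumexp2(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def gumbel_intersection(z1, Z1, z2, Z2, temperature=1.0):
    """Smooth max / min of the corners via logsumexp at the given temperature."""
    t = temperature
    z = [t * logsumexp2(a / t, b / t) for a, b in zip(z1, z2)]
    Z = [-t * logsumexp2(-a / t, -b / t) for a, b in zip(Z1, Z2)]
    return z, Z


# Same two boxes as in the cell above.
z1, Z1 = [-2.0] * 20, [0.0] * 20
z2 = [1 / n for n in range(1, 21)]
Z2 = [1 - k for k in reversed(z2)]

hard_z, hard_Z = hard_intersection(z1, Z1, z2, Z2)
soft_z, soft_Z = gumbel_intersection(z1, Z1, z2, Z2, temperature=1.0)
print([round(v, 4) for v in hard_z[:3]])
print([round(v, 4) for v in soft_z[:3]])
```

The first few Gumbel lower corners (about 1.0486, 0.5789, 0.4259) agree with the output above. In the hard intersection every lower corner is greater than the upper corner (`z > Z`), which means the two boxes do not overlap; the Gumbel version is what keeps gradients informative in that case.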

### 4. Box Containment Training
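In the following example, we train `box_2` so that it becomes contained in another box, `box_1`. The loss is the
log-volume of `box_2` minus the log-volume of its intersection with `box_1`, and both boxes are updated by gradient
descent. The training loop returns the best `box_1` and `box_2`.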
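
As a sanity check on the objective used in the training cell, the sketch below evaluates log vol(box_2) - log vol(box_1 ∩ box_2) by hand on two tiny 2-D cases. It is a simplified illustration: it uses the hard volume and hard intersection, whereas the training cell uses the soft volume and Gumbel intersection with small temperatures. The names `log_volume` and `containment_loss` are hypothetical.

```python
import math


def log_volume(z, Z, eps=1e-23):
    """Hard log-volume: sum of log side lengths, clamped to avoid log(0)."""
    return sum(math.log(max(hi - lo, eps)) for lo, hi in zip(z, Z))


def containment_loss(box_1, box_2):
    """log vol(box_2) - log vol(box_1 intersect box_2); 0 when box_2 is inside box_1."""
    (z1, Z1), (z2, Z2) = box_1, box_2
    zi = [max(a, b) for a, b in zip(z1, z2)]  # lower corner of the intersection
    Zi = [min(a, b) for a, b in zip(Z1, Z2)]  # upper corner of the intersection
    return log_volume(z2, Z2) - log_volume(zi, Zi)


outer = ([-2.0, -2.0], [2.0, 2.0])
inside = ([0.0, 0.0], [1.0, 1.0])      # fully contained in `outer`
straddling = ([1.0, 1.0], [3.0, 3.0])  # only a quarter of it lies in `outer`

loss_inside = containment_loss(outer, inside)
loss_straddling = containment_loss(outer, straddling)
print(loss_inside)
print(loss_straddling)
```

The loss is zero for the contained box and log 4 for the straddling one, since only a quarter of its volume lies inside `outer`. Minimising this quantity therefore pushes `box_2` inside `box_1`.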

In [16]:
import numpy
from box_embeddings.parameterizations.box_tensor import BoxTensor
from box_embeddings.modules.volume.volume import Volume
from box_embeddings.modules.intersection import Intersection

x_z = numpy.array([-2.0 for n in range(1, 21)])
x_Z = numpy.array([0.0 for k in (x_z)])
data_x = torch.tensor([x_z, x_Z], requires_grad=True)
box_1 = BoxTensor(data_x)

y_z = numpy.array([1/n for n in range(1, 21)])
y_Z = numpy.array([1 + k for k in reversed(y_z)])
data_y = torch.tensor([y_z, y_Z], requires_grad=True)
box_2 = BoxTensor(data_y)

# Training loop
learning_rate = 0.1
def train(box_1, box_2, optimizer, epochs=1):
    best_loss = float('inf')
    best_box_1 = None
    best_box_2 = None
    box_vol = Volume(volume_temperature=0.1, intersection_temperature=0.0001)
    box_int = Intersection(intersection_temperature=0.0001)
    for e in range(epochs):
        # log vol(box_2) - log vol(box_1 ∩ box_2): shrinks toward 0 as box_2 becomes contained in box_1
        loss = box_vol(box_2) - box_vol(box_int(box_1, box_2))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Track the lowest loss seen so far. These are references to the live
        # boxes (updated in place by the optimizer), not snapshots.
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_box_2 = box_2
            best_box_1 = box_1
        print('Iteration %d, loss = %.4f' % (e, loss.item()))
    return best_box_1, best_box_2

optimizer = torch.optim.SGD([data_x, data_y], lr=learning_rate)
contained_box1, contained_box2 = train(box_1, box_2, optimizer, epochs=20)

# Print the coordinates of the boxes after training
print(contained_box1)
print(contained_box2)

Iteration 0, loss = 82.9150
Iteration 1, loss = 2.5045
Iteration 2, loss = 1.4673
Iteration 3, loss = 0.8954
Iteration 4, loss = 0.6085
Iteration 5, loss = 0.4437
Iteration 6, loss = 0.3400
Iteration 7, loss = 0.2536
Iteration 8, loss = 0.2115
Iteration 9, loss = 0.1708
Iteration 10, loss = 0.1313
Iteration 11, loss = 0.0927
Iteration 12, loss = 0.0548
Iteration 13, loss = 0.0173
Iteration 14, loss = 0.0000
Iteration 15, loss = 0.0000
Iteration 16, loss = 0.0000
Iteration 17, loss = 0.0000
Iteration 18, loss = 0.0000
Iteration 19, loss = 0.0000
BoxTensor(tensor([[-2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000,
         -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000,
         -2.0000, -2.0000, -2.0000, -2.0000],
        [ 1.0000,  0.9967,  0.9827,  0.9616,  1.0032,  0.9825,  0.9645,  0.9491,
          0.9359,  0.9876,  0.9781,  0.9699,  0.9627,  0.9563,  1.0121,  1.0069,
          1.0021,  1.0561,  1.1607,  1.4783]], dtype=torch.float64,
 