In [1]:
!pip install gtn

Collecting gtn
  Downloading gtn-0.0.0.tar.gz (45 kB)
     |████████████████████████████████| 45 kB 3.6 MB/s eta 0:00:01
Building wheels for collected packages: gtn
  Building wheel for gtn (setup.py) ... done
  Created wheel for gtn: filename=gtn-0.0.0-cp38-cp38-macosx_11_0_x86_64.whl size=517107 sha256=264868212547fc568be5ed19ab970117329faa0f60e29f877ec1fbda46f0f23a
  Stored in directory: /Users/r2q2/Library/Caches/pip/wheels/e1/8e/fa/f19e40c5750bc992a5214c96123a1c19a92082fe6d45605da2
Successfully built gtn
Installing collected packages: gtn
Successfully installed gtn-0.0.0


In [None]:
class GTNLossFunction(torch.autograd.Function):
    """
    A minimal example of adding a custom loss function built with GTN graphs to
    PyTorch.

    The example is a sequence criterion which computes a loss between a
    frame-level input and a token-level target. The tokens in the target can
    align to one or more frames in the input.
    """
    @staticmethod
    def forward(ctx, inputs, targets):
        B, T, C = inputs.shape
        losses = [None] * B
        emissions_graphs = [None] * B

        # Move data to the host as GTN operations run on the CPU:
        device = inputs.device
        inputs = inputs.cpu()
        targets = targets.cpu()

        # Compute the loss for the b-th example:
        def forward_single(b):
            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
            # *NB* A reference to the `data` should be held explicitly when
            # using `data_ptr()` otherwise the memory may be claimed before the
            # weights are set. For example, the following is undefined and will
            # likely cause serious issues:
            #   `emissions.set_weights(inputs[b].contiguous().data_ptr())`
            data = inputs[b].contiguous()
            emissions.set_weights(data.data_ptr())

            target = GTNLossFunction.make_target_graph(targets[b])

            # Score the target:
            target_score = gtn.forward_score(gtn.intersect(target, emissions))

            # Normalization term:
            norm = gtn.forward_score(emissions)

            # Compute the loss:
            loss = gtn.subtract(norm, target_score)

            # We need the save the `loss` graph to call `gtn.backward` and we
            # need the `emissions` graph to access the gradients:
            losses[b] = loss
            emissions_graphs[b] = emissions

        # Compute the loss in parallel over the batch:
        gtn.parallel_for(forward_single, range(B))

        # Save some graphs and other data for backward:
        ctx.auxiliary_data = (losses, emissions_graphs, inputs.shape)

        # Put losses back in a torch tensor and move them  back to the device:
        return torch.tensor([l.item() for l in losses]).to(device)