<a href="https://colab.research.google.com/github/lyutyuh/genbmm/blob/master/genbmminside.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install git+https://github.com/lyutyuh/genbmm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/lyutyuh/genbmm
  Cloning https://github.com/lyutyuh/genbmm to /tmp/pip-req-build-vc88_tu8
  Running command git clone -q https://github.com/lyutyuh/genbmm /tmp/pip-req-build-vc88_tu8
Building wheels for collected packages: genbmm
  Building wheel for genbmm (setup.py) ... done
  Created wheel for genbmm: filename=genbmm-0.1-cp37-cp37m-linux_x86_64.whl size=2126111 sha256=97df70f02a4d93a2afeea366664ac06c07252de4c5cdda6f98a2bd85c5816d38
  Stored in directory: /tmp/pip-ephem-wheel-cache-a1odytd1/wheels/06/1a/23/e1223f7f8c9761cbd1e38c41fddacc6eaef55dc73e89000c44
Successfully built genbmm
Installing collected packages: genbmm
Successfully installed genbmm-0.1


In [2]:
import logging
import math
from typing import Any, Dict, List, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
assert torch.cuda.is_available(), "enable CUDA to use genbmm"

from genbmm import logbmminside, logbmminside_rule

In [3]:
def logsumexp(tensor: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    """Log-sum-exp of ``tensor`` along ``dim``, shifted by the per-slice maximum.

    Args:
        tensor: input log-scores.
        dim: dimension to reduce.
        keepdim: keep the reduced dimension with size 1 (same meaning as in
            ``torch.logsumexp``).

    Returns:
        ``log(sum(exp(tensor), dim))``. A slice whose maximum is ``-inf``
        (every entry ``-inf``) yields ``-inf``, and one whose maximum is
        ``+inf`` yields ``+inf``, instead of ``NaN``.

    NOTE(review): ``torch.logsumexp`` is already numerically stable on its own.
    The explicit shift appears to matter for autograd instead. ``max`` saves only
    the argmax indices and ``-`` saves nothing, so ``tensor`` itself is never
    saved for backward. That lets callers pass a view (e.g. a ``stripe`` of a
    chart) whose base is updated in place afterwards, as ``CKY.io`` does.
    ``torch.logsumexp(tensor)`` would save the view, and backward would then
    fail after such an update. Keep this structure.
    """
    max_score, _ = tensor.max(dim, keepdim=keepdim)
    # A non-finite maximum would make `tensor - max_score` compute inf - inf = NaN;
    # shift such slices by 0 instead (torch.logsumexp then handles the infinities).
    max_score = max_score.masked_fill(~torch.isfinite(max_score), 0.0)
    shift = max_score if keepdim else max_score.unsqueeze(dim)
    return max_score + (tensor - shift).logsumexp(dim, keepdim=keepdim)

def stripe(x, n, w, offset=(0, 0), horizontal=1):
    """Return an ``(n, w, *x.shape[2:])`` strided view of diagonal "stripes" of ``x``.

    With ``(r, c) = offset``::

        horizontal == 1:  out[i, j] = x[r + i,     c + i + j]
        otherwise:        out[i, j] = x[r + i + j, c + i]

    Row ``i`` of the result starts at the ``i``-th cell of the diagonal through
    ``offset``. It runs ``w`` cells to the right (horizontal) or downwards
    (vertical).

    Args:
        x: tensor whose first two dims index (start, end); any trailing dims
            are carried along unchanged.
        n: number of stripes (rows of the result).
        w: length of each stripe.
        offset: ``(row, col)`` of the first cell of stripe 0.
        horizontal: ``1`` for row-wise stripes, anything else for column-wise.

    Returns:
        A view sharing memory with ``x`` when ``x`` is already contiguous.
        Otherwise it is a view of a contiguous copy, and later in-place updates
        of ``x`` are not reflected.

    Note:
        No bounds checking beyond ``as_strided``'s storage-size check; the
        caller must keep every addressed cell inside ``x``.
    """
    x = x.contiguous()
    seq_len = x.size(1)
    numel = x[0, 0].numel()  # elements per (row, col) cell
    stride = list(x.stride())
    stride[0] = (seq_len + 1) * numel  # next stripe: one step down the diagonal
    stride[1] = (1 if horizontal == 1 else seq_len) * numel  # along the stripe
    # as_strided's storage_offset is absolute, so start from x's own offset
    # (non-zero when x is itself a slice of a larger tensor, e.g. big[1]).
    start = x.storage_offset() + (offset[0] * seq_len + offset[1]) * numel
    return x.as_strided(
        size=(n, w, *x.shape[2:]),
        stride=stride,
        storage_offset=start,
    )

In [4]:
# An example from https://github.com/lyutyuh/structured-span-selector
# Magnitude of the finite "minus infinity" used for masking: cells that should
# carry no weight get -LARGENUMBER added to their score (see `CKY.coolio` and
# the example input built further below). exp(-1e4) underflows to 0 in float32,
# so such cells contribute nothing to a log-sum-exp.
LARGENUMBER = 1e4
class CKY(torch.nn.Module):
    """Inside algorithm over binary bracketings of a sentence.

    Every span is labelled either "mention" (its score is added) or
    "non-mention" (adds 0). Two implementations are provided:

    * :meth:`io`     -- pure PyTorch. It runs one loop iteration per span width
      and uses ``stripe`` views to gather all split points at once.
    * :meth:`coolio` -- drives ``genbmm.logbmminside_rule`` (the notebook
      asserts CUDA is available before importing genbmm). It is expected to
      match :meth:`io`; the notebook compares the two numerically.

    Both return ``(Z, marginal)``. ``Z`` is the per-sentence log-partition and
    ``marginal`` is ``dZ/d(score)`` for every span, obtained with
    ``torch.autograd.grad``.
    """

    def __init__(self, max_span_width: int = 30):
        super().__init__()
        # Stored for callers; none of the methods below read it.
        self.max_span_width = max_span_width

    def forward(
        self,
        span_mention_score_matrix: torch.FloatTensor,
        sequence_lengths: torch.IntTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run :meth:`io` with gradient tracking forced on.

        The marginals come from ``torch.autograd.grad``, which needs a graph
        even when the caller is inside ``torch.no_grad()`` (e.g. at inference).
        """
        with torch.autograd.enable_grad():
            # Enable gradients during inference
            return self.io(span_mention_score_matrix, sequence_lengths)

    @staticmethod
    def _batch_lengths(
        span_mention_score_matrix: torch.Tensor, sequence_lengths: torch.Tensor
    ) -> torch.Tensor:
        """Flatten ``sequence_lengths`` to 1-D int64 on the scores' device.

        int64 (rather than the annotated int32 ``IntTensor``) and a matching
        device make the lengths safe to use as an advanced-indexing tensor.
        """
        return sequence_lengths.reshape(-1).to(
            device=span_mention_score_matrix.device, dtype=torch.long
        )

    @staticmethod
    def _span_marginals(
        Z: torch.Tensor, span_mention_score_matrix: torch.Tensor
    ) -> torch.Tensor:
        """Return ``d(sum Z) / d(span_mention_score_matrix)``.

        The result has shape ``(batch, N, N)`` when ``score_dim == 1``, or
        ``(batch, N, N, score_dim)`` otherwise. Only the trailing score axis is
        squeezed, so a batch of size 1 keeps its batch axis.
        ``create_graph=True`` keeps the marginals themselves differentiable.
        """
        (grad,) = torch.autograd.grad(
            Z.sum(),
            span_mention_score_matrix,
            create_graph=True,
            allow_unused=False,
        )
        return grad.squeeze(-1)

    def io(
        self,
        span_mention_score_matrix: torch.FloatTensor,
        sequence_lengths: torch.IntTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Pure-PyTorch inside pass.

            Parameters:
                span_mention_score_matrix: shape (batch_size, sent_len, sent_len, score_dim)
                    Score of each span being a span of interest. There are batch_size number
                    of sentences in this document. And the maximum length of sentence is
                    sent_len. Gradient tracking is switched on for this tensor in place.
                sequence_lengths: shape (batch_size, )
                    The actual length of each sentence (may be shorter than sent_len).

            Returns:
                Z: shape (batch_size, ), log-partition of each sentence.
                marginal: shape (batch_size, sent_len, sent_len), or with a trailing
                    score_dim axis when score_dim > 1.
        """
        span_mention_score_matrix.requires_grad_(True)

        batch_size, num_tokens, _, score_dim = span_mention_score_matrix.size()
        lengths = self._batch_lengths(span_mention_score_matrix, sequence_lengths)
        # Widths >= the longest sentence never reach any Z, so stop there.
        max_len = int(lengths.max())

        # Shape: (num_tokens, num_tokens, score_dim, batch_size)
        scores = span_mention_score_matrix.permute(1, 2, 3, 0)

        # Chart of inside log-scores. Channels [0, score_dim) hold the mention
        # alternative(s); the last channel holds the non-mention alternative.
        # Shape: (num_tokens, num_tokens, score_dim + 1, batch_size)
        inside_s = scores.new_zeros(num_tokens, num_tokens, score_dim + 1, batch_size)

        # Width-0 spans: mention channels take the raw score; the non-mention
        # channel stays 0 (= log 1).
        inside_s[:, :, :score_dim, :].diagonal(0).copy_(scores.diagonal(0))

        for width in range(1, max_len):
            n = num_tokens - width
            # [n, width, score_dim + 1, batch_size]: split_1[i, k] is the left
            # child (i, i + k), split_2[i, k] the right child (i + k + 1, i + width).
            split_1 = stripe(inside_s, n, width)
            split_2 = stripe(inside_s, n, width, (1, width), 0)

            # [n, width, batch_size]: sum over each child's labels, combine children.
            inside_s_span = logsumexp(split_1, 2) + logsumexp(split_2, 2)
            # [1, batch_size, n]: sum over split points.
            inside_s_span = logsumexp(inside_s_span, 1, keepdim=True).permute(1, 2, 0)

            inside_s.diagonal(width).copy_(
                torch.cat(
                    [inside_s_span + scores.diagonal(width),  # mention
                     inside_s_span],                          # non-mention
                    dim=0,
                )
            )

        # (num_tokens, num_tokens, batch_size, score_dim + 1)
        inside_s = inside_s.permute(0, 1, 3, 2)
        batch_idx = torch.arange(batch_size, device=inside_s.device)

        # Whole-sentence span (0, length - 1), summed over its labels.
        Z = logsumexp(inside_s[0, lengths - 1, batch_idx], dim=-1)  # (batch_size,)

        marginal = self._span_marginals(Z, span_mention_score_matrix)
        return (Z.view(-1), marginal)

    def coolio(
        self,
        span_mention_score_matrix: torch.FloatTensor,
        sequence_lengths: torch.IntTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Inside pass driven by ``genbmm.logbmminside_rule``.

            Parameters:
                span_mention_score_matrix: shape (batch_size, sent_len, sent_len, score_dim)
                    Score of each span being a span of interest. There are batch_size number
                    of sentences in this document. And the maximum length of sentence is
                    sent_len. Only score_dim == 1 is supported here (the trailing axis
                    is squeezed). Gradient tracking is switched on for this tensor in place.
                sequence_lengths: shape (batch_size, )
                    The actual length of each sentence (may be shorter than sent_len).

            Returns:
                Z: shape (batch_size, ), log-partition of each sentence.
                marginal: shape (batch_size, sent_len, sent_len).
        """
        span_mention_score_matrix.requires_grad_(True)

        batch_size, num_tokens, _, _ = span_mention_score_matrix.size()
        lengths = self._batch_lengths(span_mention_score_matrix, sequence_lengths)
        max_len = int(lengths.max())

        rules = span_mention_score_matrix
        # log(1 + exp(s)) = logsumexp(mention s, non-mention 0). softplus gives
        # the same value but does not overflow exp() for large scores.
        log1p_exp_rules = F.softplus(rules.squeeze(-1))

        # Cells strictly below the diagonal get -LARGENUMBER; the diagonal and
        # upper triangle get 0.
        zero_rules = rules.new_ones(num_tokens, num_tokens).tril(diagonal=-1) * (-LARGENUMBER)
        zero_rules = zero_rules.view(1, num_tokens, num_tokens, 1).expand(batch_size, -1, -1, -1)

        # Initial chart, shape (batch_size, num_tokens, num_tokens).
        inside_s = torch.cat([rules, zero_rules], dim=3).logsumexp(dim=3)

        # NOTE(review): assumes logbmminside_rule(chart, rules, w) fills in the
        # spans of width w and returns the updated chart -- confirm in genbmm.
        for width in range(1, max_len):
            inside_s = logbmminside_rule(inside_s, log1p_exp_rules, width)

        batch_idx = torch.arange(batch_size, device=inside_s.device)
        Z = inside_s[batch_idx, 0, lengths - 1]  # (batch_size, )

        marginal = self._span_marginals(Z, span_mention_score_matrix)
        return (Z.view(-1), marginal)  # marginal: (batch_size, seq_len, seq_len)

In [5]:
torch.manual_seed(0)  # fixed seed so the random scores and printed errors are reproducible

mp = CKY()
l = 128  # tokens per sentence
bs = 32  # sentences in the batch
device = "cuda:0"
lengthvec = torch.tensor([l] * bs, device=device)

# Add a large negative score (-LARGENUMBER) to
#   * cells below the diagonal (end < start), and
#   * spans with end - start > mp.max_span_width (triu(max_span_width + 1)).
ones = torch.ones(l, l, device=device)
invalid = (1 - ones.triu()) + ones.triu(mp.max_span_width + 1)
example = torch.randn(bs, l, l, 1, device=device) + (-LARGENUMBER * invalid).unsqueeze(0).unsqueeze(-1)

# The genbmm-based pass (coolio) and the pure-PyTorch reference (io) should agree.
r1 = mp.coolio(example, lengthvec)
r2 = mp.io(example, lengthvec)
print(torch.norm(r1[0] - r2[0]))
print(torch.norm(r1[1] - r2[1]))

tensor(5.2858e-05, device='cuda:0', grad_fn=<CopyBackwards>)
tensor(0.0003, device='cuda:0', grad_fn=<CopyBackwards>)


In [6]:
%%timeit
with torch.autocast(enabled=True, device_type="cuda"):
    r2 = mp.io(example, lengthvec)

1 loop, best of 5: 161 ms per loop


In [7]:
%%timeit
with torch.autocast(enabled=True, device_type="cuda"):
    r1 = mp.coolio(example, lengthvec)

10 loops, best of 5: 21.9 ms per loop
