In [343]:
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from eight_mile.pytorch.layers import sequence_mask

In [344]:
B = 6
T = 4
H = 5

In [345]:
def make_batch(B, T, H):
    b = np.arange(1, B + 1).reshape(-1, 1)
    t = np.arange(1, T + 1).reshape(1, -1) / 10
    h = (np.arange(1, H + 1) / 100)
    batch = (np.expand_dims(b + t, -1) + h).astype(np.float32)
    lengths = make_lengths(B, T)
    #for i, l in enumerate(lengths):
    #    batch[i, l:] = 0
    return batch, lengths

def make_lengths(B, T):
    lengths = np.full((B,), T)
    idx = np.random.randint(0, B, size=(B // 2 + 1))
    lengths[idx] = np.random.randint(1, T, size=(idx.shape))
    return lengths

In [346]:
data, lengths = make_batch(B, T, H)

data = torch.from_numpy(data)
lengths = torch.from_numpy(lengths)

mask = torch.from_numpy(np.array(
    [
        [0, 1, 1, 0],
        [0, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
        [0, 1, 0, 0],
        [0, 1, 1, 1],
    ], dtype=np.bool
))

In [347]:
print(data)
print(mask)

tensor([[[1.1100, 1.1200, 1.1300, 1.1400, 1.1500],
         [1.2100, 1.2200, 1.2300, 1.2400, 1.2500],
         [1.3100, 1.3200, 1.3300, 1.3400, 1.3500],
         [1.4100, 1.4200, 1.4300, 1.4400, 1.4500]],

        [[2.1100, 2.1200, 2.1300, 2.1400, 2.1500],
         [2.2100, 2.2200, 2.2300, 2.2400, 2.2500],
         [2.3100, 2.3200, 2.3300, 2.3400, 2.3500],
         [2.4100, 2.4200, 2.4300, 2.4400, 2.4500]],

        [[3.1100, 3.1200, 3.1300, 3.1400, 3.1500],
         [3.2100, 3.2200, 3.2300, 3.2400, 3.2500],
         [3.3100, 3.3200, 3.3300, 3.3400, 3.3500],
         [3.4100, 3.4200, 3.4300, 3.4400, 3.4500]],

        [[4.1100, 4.1200, 4.1300, 4.1400, 4.1500],
         [4.2100, 4.2200, 4.2300, 4.2400, 4.2500],
         [4.3100, 4.3200, 4.3300, 4.3400, 4.3500],
         [4.4100, 4.4200, 4.4300, 4.4400, 4.4500]],

        [[5.1100, 5.1200, 5.1300, 5.1400, 5.1500],
         [5.2100, 5.2200, 5.2300, 5.2400, 5.2500],
         [5.3100, 5.3200, 5.3300, 5.3400, 5.3500],
         [5.4100, 5.420

In [389]:
@torch.jit.script
def to_dense_index(idx: torch.Tensor, T: int) -> torch.Tensor:
    out: torch.Tensor = torch.zeros((idx.size(0),)).to(torch.long)
    prev: torch.Tensor = torch.full((), -1).to(torch.long)
    j: int = 0
    for i in range(out.size(0)):
        if idx[i, 0] == prev:
            j += 1
        else:
            prev = idx[i, 0]
            j = 0
        out[i] = prev * T + j
    return out

@torch.jit.script
def span_select(tensor: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    B: int = tensor.size(0)
    T: int = tensor.size(1)
    H: int = tensor.size(2)
    
    src = torch.nonzero(mask)
    dense = to_dense_index(src, T)
    
    indices = torch.zeros((B * T, 2)).to(torch.long)
    dense = dense.view(src.size(0), 1).expand_as(src)
    indices = torch.scatter(indices, 0, dense, src)
    offset_mask = torch.arange(2) >= 1
    offset = torch.full(indices.shape, T).to(indices.dtype).masked_fill(offset_mask.unsqueeze(0), 1)
    indices = torch.sum(indices * offset, dim=1)
    
    tensor = tensor.view(B * T, -1)
    tensor = torch.gather(tensor.view(B * T, -1), 0, indices.view(-1, 1).expand_as(tensor))
    tensor = tensor.view(B, T, -1)
    
    span_lengths = torch.sum(mask, dim=1)
    mask = (torch.arange(0, T).to(span_lengths.dtype).view(1, T) < span_lengths.view(B, 1)).unsqueeze(2)
    
    tensor = tensor.masked_fill(mask == 0, 0)
    
    return tensor, span_lengths

In [390]:
t, l = span_select(data, mask)
print(t)
print(l)

tensor([[[1.2100, 1.2200, 1.2300, 1.2400, 1.2500],
         [1.3100, 1.3200, 1.3300, 1.3400, 1.3500],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[2.2100, 2.2200, 2.2300, 2.2400, 2.2500],
         [2.3100, 2.3200, 2.3300, 2.3400, 2.3500],
         [2.4100, 2.4200, 2.4300, 2.4400, 2.4500],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[3.1100, 3.1200, 3.1300, 3.1400, 3.1500],
         [3.2100, 3.2200, 3.2300, 3.2400, 3.2500],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[4.1100, 4.1200, 4.1300, 4.1400, 4.1500],
         [4.2100, 4.2200, 4.2300, 4.2400, 4.2500],
         [4.3100, 4.3200, 4.3300, 4.3400, 4.3500],
         [4.4100, 4.4200, 4.4300, 4.4400, 4.4500]],

        [[5.2100, 5.2200, 5.2300, 5.2400, 5.2500],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.000

In [391]:
class Example(nn.Module):
    def forward(self, input):
        tensor, mask = input
        return span_select(tensor, mask)

In [392]:
dummy_input = torch.rand(10, 20, 30)
dummy_mask = torch.randint(0, 2, (10, 20)).to(torch.bool)
input_names = ['tensor', 'mask']
output_names = ['tensor', 'lengths']
dynamic_axes = {
    'tensor': [0, 1],
    'mask': [0, 1]
}

model = Example()

In [393]:
model([data, mask])

(tensor([[[1.2100, 1.2200, 1.2300, 1.2400, 1.2500],
          [1.3100, 1.3200, 1.3300, 1.3400, 1.3500],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
         [[2.2100, 2.2200, 2.2300, 2.2400, 2.2500],
          [2.3100, 2.3200, 2.3300, 2.3400, 2.3500],
          [2.4100, 2.4200, 2.4300, 2.4400, 2.4500],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
         [[3.1100, 3.1200, 3.1300, 3.1400, 3.1500],
          [3.2100, 3.2200, 3.2300, 3.2400, 3.2500],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
 
         [[4.1100, 4.1200, 4.1300, 4.1400, 4.1500],
          [4.2100, 4.2200, 4.2300, 4.2400, 4.2500],
          [4.3100, 4.3200, 4.3300, 4.3400, 4.3500],
          [4.4100, 4.4200, 4.4300, 4.4400, 4.4500]],
 
         [[5.2100, 5.2200, 5.2300, 5.2400, 5.2500],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],


In [394]:
torch.onnx.export(
    model,
    [dummy_input, dummy_mask],
    'example.onnx',
    verbose=True,
    input_names=input_names,
    output_names=output_names
)

KeyError: 'prim_dtype'

In [395]:
import onnxruntime as ort

ort_session = ort.InferenceSession('example.onnx')

outputs = ort.sess.run(None, {'tensor': data, 'mask': mask})

Fail: [ONNXRuntimeError] : 1 : FAIL : Exception during loading: /onnxruntime_src/onnxruntime/core/graph/graph.cc:2486 onnxruntime::common::Status onnxruntime::Graph::SetGraphInputsOutputs() node_arg was false. Graph ctor should have created NodeArg for initializer.
