PyTorch is optimized for operations on large tensors. Running many operations on small tensors is inefficient, because each call carries a fixed overhead. So, whenever possible, you should rewrite your computations in batch form to reduce overhead and improve performance. If you can't batch your operations manually, TorchScript may improve your code's performance. TorchScript is a statically typed subset of Python that PyTorch can compile. PyTorch can automatically optimize your TorchScript code using its just-in-time (JIT) compiler and reduce some of this overhead.
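As a tiny illustration of what "batch form" means, the sketch below computes row sums twice: once with one small operation per row, and once with a single call over the whole tensor. The variable names are made up for this illustration.

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Many small operations: one Python-level call per row.
row_sums_loop = torch.stack([x[i].sum() for i in range(x.size(0))])

# One batched operation over the whole tensor.
row_sums_batched = x.sum(dim=1)

# Both approaches produce the same values.
assert torch.equal(row_sums_loop, row_sums_batched)
print(row_sums_batched)
```

Both versions give the same result; the batched one simply issues far fewer calls into PyTorch.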

Let's look at an example. A very common operation in ML applications is "batch gather". This operation can simply be written as `output[i] = input[i, index[i]]`.

Suppose you have a tensor `input` of shape `(3, 4)`:

```
input = [[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]]
```

And you have an index tensor `index` of shape `(3, 2)`:

```
index = [[0, 2],
         [1, 3],
         [2, 0]]
```

Now, let's say you want to gather elements from `input` according to the indices specified in `index`. 

For example, for the first row of `input`, you want to gather elements at indices `0` and `2`, for the second row, you want to gather elements at indices `1` and `3`, and so on.

Here's how you can use `torch.gather` to achieve this:

```python
import torch

input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
index = torch.tensor([[0, 2], [1, 3], [2, 0]])

output = torch.gather(input, 1, index)
print(output)
```

This will give you the following output:

```
tensor([[ 1,  3],
        [ 6,  8],
        [11,  9]])
```

Explanation:
- For the first row of `input`, `torch.gather` selects elements at indices `0` and `2`, which are `1` and `3`.
- For the second row of `input`, it selects elements at indices `1` and `3`, which are `6` and `8`.
- For the third row of `input`, it selects elements at indices `2` and `0`, which are `11` and `9`.

`torch.gather` allows you to gather specific elements from a tensor along a specified dimension according to the indices provided in another tensor.
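To make the indexing rule concrete, here is a minimal sketch that recomputes the same result with explicit Python loops, using the rule `output[i][j] = input[i][index[i][j]]` that `torch.gather` follows for `dim=1` on 2-D tensors. The helper name `gather_dim1_reference` is hypothetical and exists only for this illustration.

```python
import torch

def gather_dim1_reference(inp, idx):
    # Hypothetical reference helper (not part of PyTorch):
    # applies out[i][j] = inp[i][idx[i][j]] element by element.
    out = torch.empty(idx.shape, dtype=inp.dtype)
    for i in range(idx.size(0)):
        for j in range(idx.size(1)):
            out[i, j] = inp[i, idx[i, j]]
    return out

input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
index = torch.tensor([[0, 2], [1, 3], [2, 0]])

reference = gather_dim1_reference(input, index)

# The loop-based reference agrees with torch.gather along dim=1.
assert torch.equal(reference, torch.gather(input, 1, index))
```

The loops are only there to spell out the rule; in practice the single `torch.gather` call is the one to use.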


In essence, batch gather lets you select elements from each item in a batch using that item's own indices. This is useful in various scenarios, such as when you have a batch of sequences (e.g., sequences of words in natural language processing tasks) and you want to pick specific positions from each sequence.

Batch gather can be implemented in PyTorch with a simple loop as follows:

```python
import torch

def batch_gather(tensor, indices):
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)
```

To implement the same function using TorchScript, simply use the `torch.jit.script` decorator:

```python
@torch.jit.script
def batch_gather_jit(tensor, indices):
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)
```

In my tests, this is about 10% faster.

But nothing beats manually batching your operations. In my tests, a vectorized implementation is about 100 times faster:

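```python
def batch_gather_vec(tensor, indices):
    shape = list(tensor.shape)
    # Merge the first two dimensions, so row i starts at position i * shape[1].
    flat_first = torch.reshape(
        tensor, [shape[0] * shape[1]] + shape[2:])
    # Per-row offsets, created on the same device as `indices` and shaped
    # to broadcast against them.
    offset = torch.reshape(
        torch.arange(shape[0], device=indices.device) * shape[1],
        [shape[0]] + [1] * (len(indices.shape) - 1))
    output = flat_first[indices + offset]
    return output
```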
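
As a sanity check on the index arithmetic, the sketch below walks through the flattened-offset trick on the `(3, 4)` example from earlier and compares it with the loop version and with `torch.gather`. It is self-contained, so it redefines both functions. The vectorized one here builds the offsets on `indices.device`, an assumption made so the snippet also runs on CPU. The `timeit` calls at the end are only a rough way to compare the two on your own machine; the numbers will vary and no particular speedup is implied.

```python
import timeit
import torch

def batch_gather(tensor, indices):
    # Loop version: one small indexing operation per batch row.
    return torch.stack([tensor[i][indices[i]] for i in range(tensor.size(0))])

def batch_gather_vec(tensor, indices):
    # Vectorized version: flatten the first two dims, then shift each row's
    # indices by row * row_length so they address the flattened tensor.
    shape = list(tensor.shape)
    flat_first = tensor.reshape([shape[0] * shape[1]] + shape[2:])
    offset = (torch.arange(shape[0], device=indices.device) * shape[1]).reshape(
        [shape[0]] + [1] * (indices.dim() - 1))
    return flat_first[indices + offset]

input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
index = torch.tensor([[0, 2], [1, 3], [2, 0]])

# Row offsets are [0, 4, 8], so the flat indices become
# [[0, 2], [5, 7], [10, 8]].
flat_index = index + torch.arange(3).reshape(3, 1) * 4
print(flat_index)

loop_out = batch_gather(input, index)
vec_out = batch_gather_vec(input, index)
assert torch.equal(loop_out, vec_out)
assert torch.equal(vec_out, torch.gather(input, 1, index))

# Rough timing on a larger random batch (results depend on machine and device).
big = torch.randn(512, 64)
big_index = torch.randint(0, 64, (512, 8))
t_loop = timeit.timeit(lambda: batch_gather(big, big_index), number=10)
t_vec = timeit.timeit(lambda: batch_gather_vec(big, big_index), number=10)
print(f"loop: {t_loop:.4f}s  vectorized: {t_vec:.4f}s")
```

For this 2-D case the built-in `torch.gather(input, 1, index)` returns the same result, so it is a convenient reference when testing a hand-written batched version.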