In [167]:
import torch
from torchhd.functional import bind,multibind,multiset,permute,random_hv,ngrams as ngrams_base

def ngrams_unbind(input, n=3):
    """Bundle all n-grams of a hypervector sequence, built via ``torch.unbind``.

    The sequence axis (second to last) is split into individual hypervectors.
    For every position ``i`` inside an n-gram, the hypervectors occupying that
    position are re-stacked into a window and permuted by ``n - 1 - i`` shifts
    (the last position is left unpermuted). The ``n`` aligned windows are
    combined with ``multibind`` and the resulting n-grams are combined with
    ``multiset``.

    Args:
        input: tensor whose second-to-last axis indexes the sequence and whose
            last axis is the hypervector dimension. Leading batch axes are
            allowed.
        n: number of consecutive hypervectors per n-gram; must be at least 1.

    Returns:
        The ``multiset`` of all n-grams, i.e. ``input`` with its sequence axis
        reduced away.

    Raises:
        ValueError: if ``n`` is smaller than 1.

    Note:
        ``torch.stack`` is called on each window, so a sequence shorter than
        ``n`` (an empty window) is expected to raise inside ``torch.stack``.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    unbinded = torch.unbind(input, -2)

    # Re-stack on dim=-2 rather than the default dim 0: with batched input the
    # default would place the window axis in front of the batch axes, and the
    # reductions below (which act on dim -2) would then collapse a batch axis
    # instead of the n-gram axis. For 2-D input both choices are identical.
    windows = [
        permute(torch.stack(unbinded[i:-(n - 1 - i)], dim=-2), shifts=n - 1 - i)
        for i in range(n - 1)
    ]
    # Last position of the n-gram: no permutation applied.
    windows.append(torch.stack(unbinded[n - 1:], dim=-2))

    # Shape (..., num_ngrams, n, dim) -> bind over n -> bundle over num_ngrams.
    n_gram = multibind(torch.stack(windows, dim=-2))
    return multiset(n_gram)

def ngrams_forloop(input, n=3):
    """Bundle all n-grams of a hypervector sequence using out-of-place binds.

    The window holding the first element of every n-gram is permuted by
    ``n - 1`` shifts; each following window is permuted by one shift less and
    bound into the running result with ``bind``; the last window is bound
    unpermuted. The resulting n-grams are combined with ``multiset``.

    Args:
        input: tensor whose second-to-last axis indexes the sequence and whose
            last axis is the hypervector dimension. Leading batch axes are
            allowed.
        n: number of consecutive hypervectors per n-gram; must be at least 1.

    Returns:
        The ``multiset`` of all n-grams of ``input``.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if n == 1:
        # The general path slices with `:-(n - 1)`, which for n == 1 is `:-0`
        # and selects nothing. A unigram is just the hypervector itself, so
        # bundle the sequence directly.
        return multiset(input)

    # First element of every n-gram, carrying the largest permutation.
    n_gram = permute(input[..., :-(n - 1), :], shifts=n - 1)

    # Middle elements: window i starts at offset i and is permuted n-1-i times.
    for i in range(1, n - 1):
        window = permute(input[..., i:-(n - 1 - i), :], shifts=n - 1 - i)
        n_gram = bind(n_gram, window)

    # Last element of every n-gram is bound without permutation.
    n_gram = bind(n_gram, input[..., n - 1:, :])
    return multiset(n_gram)

def ngrams_in_place(input, n=3):
    """Bundle all n-grams of a hypervector sequence using in-place multiplies.

    Same windowing scheme as the out-of-place loop variant, but each window is
    multiplied into the running result with ``Tensor.mul_`` instead of
    ``bind``, so no new tensor is allocated for the product at each step.

    Args:
        input: tensor whose second-to-last axis indexes the sequence and whose
            last axis is the hypervector dimension. Leading batch axes are
            allowed.
        n: number of consecutive hypervectors per n-gram; must be at least 1.

    Returns:
        The ``multiset`` of all n-grams of ``input``.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if n == 1:
        # The general path slices with `:-(n - 1)`, which for n == 1 is `:-0`
        # and selects nothing. A unigram is just the hypervector itself, so
        # bundle the sequence directly.
        return multiset(input)

    # The accumulator is the tensor returned by `permute`; the in-place
    # multiplies below write into it.
    # NOTE(review): assumes `permute` returns a new tensor rather than a view
    # of `input`; otherwise `mul_` would modify the caller's data — confirm.
    n_gram = permute(input[..., :-(n - 1), :], shifts=n - 1)

    # Middle elements: window i starts at offset i and is permuted n-1-i times.
    for i in range(1, n - 1):
        window = permute(input[..., i:-(n - 1 - i), :], shifts=n - 1 - i)
        n_gram.mul_(window)

    # Last element of every n-gram is multiplied in without permutation.
    n_gram.mul_(input[..., n - 1:, :])
    return multiset(n_gram)

def ngrams_index_select(input, n=3):
    """Bundle all n-grams of a hypervector sequence using ``torch.index_select``.

    Each window of the sequence axis is gathered with ``torch.index_select``
    rather than taken by slicing. The first window is permuted by ``n - 1``
    shifts, each following window by one shift less, and the windows are
    multiplied into the running result in place. The resulting n-grams are
    combined with ``multiset``.

    Args:
        input: tensor whose second-to-last axis indexes the sequence and whose
            last axis is the hypervector dimension. Leading batch axes are
            allowed.
        n: number of consecutive hypervectors per n-gram; must be at least 1.

    Returns:
        The ``multiset`` of all n-grams of ``input``.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if n == 1:
        # With n == 1 the first and last windows are both the full sequence,
        # so the general path would multiply the input by itself. A unigram is
        # just the hypervector itself, so bundle the sequence directly.
        return multiset(input)

    length = input.size(-2)

    def window(start, stop):
        # Build the index on the input's device: `index_select` requires the
        # index tensor to live on the same device as the tensor it indexes.
        positions = torch.arange(start, stop, device=input.device)
        return torch.index_select(input, -2, positions)

    # `index_select` gathers into a new tensor, so the in-place multiplies
    # below do not write into `input`.
    n_gram = permute(window(0, length - (n - 1)), shifts=n - 1)

    # Middle elements: window i starts at offset i and is permuted n-1-i times.
    for i in range(1, n - 1):
        n_gram.mul_(permute(window(i, length - (n - 1 - i)), shifts=n - 1 - i))

    # Last element of every n-gram is multiplied in without permutation.
    n_gram.mul_(window(n - 1, length))
    return multiset(n_gram)

In [168]:
# Benchmark input shared by every timing below.
# NOTE(review): presumably 10 random hypervectors of dimension 10,000 — confirm
# against the torchhd `random_hv` signature. No seed is set in this cell, so
# the values differ from run to run.
x = random_hv(10, 10000)
# Baseline: the library implementation imported as `ngrams_base`.
%timeit ngrams_base(x)
# Local variants, each called on the same input with its default n=3. The
# timings printed under the cell appear in the same order as these statements.
%timeit ngrams_unbind(x)
%timeit ngrams_forloop(x)
%timeit ngrams_in_place(x)
%timeit ngrams_index_select(x)

134 µs ± 4.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
262 µs ± 9.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
132 µs ± 3.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
122 µs ± 1.74 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
162 µs ± 2.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
