Discussion on Discord:

_Did anyone else find that indexing into the weight matrix was slower than using the "redundant" one-hot encoded vectors? The perf difference was around 8x for me on GPU.
I'm guessing it's faster to let the GPU rip through the matmul than have to do many random memory lookups._

https://discord.com/channels/1020383067459821711/1029849849765564528/1056756317777305710

In [12]:
import torch
import torch.nn.functional as F

In [2]:
row_dimensions = 703 # Trigram word model input: (.., .a, [...], .z, aa, [...], zz)
col_dimensions = 27 # Trigram word model output: (., a, [...], z)
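
The 703 input rows can be sanity-checked by enumerating the two-character contexts. This is only a sketch: it assumes the context set is `..` (start of word), `.a` to `.z`, and `aa` to `zz`, and the names `contexts` and `stoi` are hypothetical, not taken from the original notebook.

```python
import string

letters = string.ascii_lowercase  # the 26 letters a..z

# Assumed context set: "..", then ".a" ... ".z", then "aa" ... "zz"
contexts = [".."] + ["." + c for c in letters] + [a + b for a in letters for b in letters]

# Hypothetical lookup table from context string to row index in W
stoi = {ctx: i for i, ctx in enumerate(contexts)}

num_contexts = len(contexts)      # 1 + 26 + 26 * 26
num_outputs = len("." + letters)  # "." plus the 26 letters
print(num_contexts, num_outputs)  # prints: 703 27
```

Under that assumption the counts match `row_dimensions` and `col_dimensions`.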

In [3]:
# Simulated weight matrix
W = torch.randn([row_dimensions, col_dimensions])

In [4]:
# Simulated X vector: 1000 random integers from 0 to 702 (high is exclusive), i.e. valid row indices into W
X = torch.randint(low=0, high = row_dimensions, size=(1000,))

In [5]:
# One hot encoded x_enc matrix
x_enc = F.one_hot(X, num_classes=row_dimensions).float()

In [6]:
# Alternative1: using matrix multiplication between one-hot encoded matrix and W
a1 = x_enc @ W

In [7]:
# Alternative2: directly indexing into W with original X vector
a2 = W[X]

In [8]:
# Make sure the results are the same
assert torch.allclose(a1, a2)
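
The two alternatives agree because multiplying a one-hot row by W simply selects one row of W. The sketch below shows this on a tiny matrix and also compares against `F.embedding`, which performs the same row lookup. The names `W_small` and `idx` are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W_small = torch.randn(5, 3)       # 5 rows to choose from, 3 columns each
idx = torch.tensor([4, 0, 4, 2])  # row indices to look up

# One-hot matmul: each one-hot row has a single 1, so it picks out one row of W_small
via_matmul = F.one_hot(idx, num_classes=5).float() @ W_small

# Direct indexing and the functional embedding lookup return the same rows
via_index = W_small[idx]
via_embedding = F.embedding(idx, W_small)

assert torch.allclose(via_matmul, via_index)
assert torch.equal(via_index, via_embedding)
```

This row lookup is also what an `nn.Embedding` layer does with its weight matrix, which is one reason to prefer indexing over an explicit one-hot matmul.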

In [9]:
%%timeit
x_enc @ W

416 µs ± 3.94 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
%%timeit
W[X]

33.3 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### Directly indexing into W is roughly 12 times faster than using matrix multiplication (416 vs 33 $\mu$s)

In [11]:
%%timeit
x_enc = F.one_hot(X, num_classes=row_dimensions).float()
x_enc @ W

1.35 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### If we also count the time needed to create the one-hot encoded matrix, the difference is even larger: about 1350 vs 33 $\mu$s, a factor of roughly 40
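
The tensors above are created without a `device` argument, so these timings are presumably CPU numbers, while the Discord comment refers to GPU. A minimal sketch for repeating the comparison on whichever device is available follows. The helper `time_op` and the `_d` variable names are hypothetical. CUDA kernels launch asynchronously, so the helper synchronizes before reading the clock.

```python
import time
import torch
import torch.nn.functional as F

def time_op(fn, device, repeats=100):
    """Return average wall-clock seconds per call of fn on the given device."""
    fn()  # warm-up call so one-time setup cost is not measured
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (time.perf_counter() - start) / repeats

# Use the GPU if one is present, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

W_d = torch.randn(703, 27, device=device)
X_d = torch.randint(low=0, high=703, size=(1000,), device=device)
x_enc_d = F.one_hot(X_d, num_classes=703).float()

t_matmul = time_op(lambda: x_enc_d @ W_d, device)
t_index = time_op(lambda: W_d[X_d], device)
print(f"{device.type}: matmul {t_matmul * 1e6:.1f} µs, indexing {t_index * 1e6:.1f} µs")
```

The numbers will vary with hardware, batch size, and PyTorch version, so the ratio measured above should not be assumed to carry over to a GPU.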