In [1]:
import os, sys
import torch as t
import einops

In [4]:
from eindex import eindex

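# What got this import to resolve was going into the workspace settings.json
# (i.e. `.vscode/settings.json` in this folder) and adding the eindex path to the
# `python.analysis` settings - see the short note on paths below.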

## Short note on paths

I needed to go into `.vscode/settings.json` *in this directory*, and add the following:

```json
{
    "python.analysis.extraPaths": [
        "C:/Users/calsm/Documents/AI Alignment/my_eindex/eindex"
    ],
    "python.analysis.include": [
        "C:/Users/calsm/Documents/AI Alignment/my_eindex/eindex"
    ]
}
```

I neither know nor care which of these actually helped.

Also, this is the kind of structure I need:

```
.
├── demo.ipynb
└── utils/
    ├── __init__.py
    └── util.py
```

because then I can do things like `from utils import util` or `from utils.util import *` from within the demo notebook.



---

I've written a function `einops.index`, which has some sweet notation for indexing.

Not sure what the endgame here is - I assume I'll have it in my own repo, which I'll call `eindex`, rather than having it be a fork of einops? I don't know yet; ask Arthur.

# Test 1 - indexing correct logits

```
output[batch, seq] = logprobs[batch, seq, labels[batch, seq]]
```

In [6]:
# Problem sizes shared by the tests in this notebook.
BATCH_SIZE = 32
SEQ_LEN = 5
D_VOCAB = 100

# Random log-probabilities over the vocab, and one target token per (batch, seq) position.
logprobs = t.log_softmax(t.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB), dim=-1)
labels = t.randint(D_VOCAB, (BATCH_SIZE, SEQ_LEN))

# Baseline (the way I would normally write it): build explicit (batch, seq) index
# grids for the first two dims and let advanced indexing pick out the label entry.
batch_idx = einops.repeat(t.arange(BATCH_SIZE), "batch -> batch seq", seq=SEQ_LEN)
seq_idx = einops.repeat(t.arange(SEQ_LEN), "seq -> batch seq", batch=BATCH_SIZE)
output_1 = logprobs[batch_idx, seq_idx, labels]

# eindex version: the same lookup expressed as a single pattern string.
output_2 = eindex(logprobs, labels, "batch seq [batch seq]")

# Both approaches must agree.
assert t.allclose(output_1, output_2)

# Test 2 - same, but two `d_vocab_out` dims

```
output[batch, seq] = logprobs[batch, seq, labels[batch, seq, 0], labels[batch, seq, 1]]
```

In [7]:
# Sizes of the two output-vocab axes. The distribution is normalised jointly over
# all D_VOCAB_OUT_1 * D_VOCAB_OUT_2 pairs and then reshaped into two separate axes.
D_VOCAB_OUT_1 = 100
D_VOCAB_OUT_2 = 100
D_VOCAB_ALL = D_VOCAB_OUT_1 * D_VOCAB_OUT_2

logprobs_shape = BATCH_SIZE, SEQ_LEN, D_VOCAB_OUT_1, D_VOCAB_OUT_2
logprobs = t.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB_ALL).log_softmax(-1).reshape(logprobs_shape)

# Draw each label component from the range of the axis it actually indexes.
# Using D_VOCAB (from the first test) as the bound only worked because D_VOCAB,
# D_VOCAB_OUT_1 and D_VOCAB_OUT_2 are all 100; with different sizes it would
# either index out of bounds or never sample part of an axis.
labels = t.stack(
    [
        t.randint(0, D_VOCAB_OUT_1, (BATCH_SIZE, SEQ_LEN)),
        t.randint(0, D_VOCAB_OUT_2, (BATCH_SIZE, SEQ_LEN)),
    ],
    dim=-1,
)  # shape: (BATCH_SIZE, SEQ_LEN, 2)

# Normal method: explicit (batch, seq) index grids plus one label slice per vocab axis.

batch_idx = einops.repeat(t.arange(BATCH_SIZE), "b -> b s", s=SEQ_LEN)
seq_idx = einops.repeat(t.arange(SEQ_LEN), "s -> b s", b=BATCH_SIZE)

output_1 = logprobs[batch_idx, seq_idx, labels[:, :, 0], labels[:, :, 1]]

# New method, using eindex: the trailing 0 / 1 select which label component
# indexes each of the two vocab axes.

output_2 = eindex(logprobs, labels, "batch seq [batch seq 0] [batch seq 1]")

assert t.allclose(output_1, output_2)

# Test 3 - indexing logprobs along sequence dim

```
output[batch, d_vocab] = logprobs[batch, seq_indices[batch], d_vocab]
```

In [8]:
# Fresh 3-D log-probabilities, plus one sequence position to pick per batch element.
logprobs = t.log_softmax(t.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB), dim=-1)
seq_indices = t.randint(SEQ_LEN, (BATCH_SIZE,))

# Baseline: pair every batch index with its chosen sequence position; the
# trailing vocab axis is kept whole.
output_1 = logprobs[t.arange(BATCH_SIZE), seq_indices]

# eindex version of the same lookup.
output_2 = eindex(logprobs, seq_indices, "batch [batch] d_vocab")

# Both approaches must agree.
assert t.allclose(output_1, output_2)