In [84]:
import random
import torch

In [85]:
def generate_random_tokens(num_lists, min_length=10, max_length=256, min_value=0, max_value=32000):
    results = []
    for _ in range(num_lists):
        # Generate a random length
        list_length = random.randint(min_length, max_length)

        # Generate a list of random integers with the generated length
        random_tokens = [random.randint(min_value, max_value) for _ in range(list_length)]
        results.append(random_tokens)

    return results

In [86]:
max_gen_len = 10

# randomly sample a prompt to start a new episode
prompt_tokens = generate_random_tokens(2, 2, 5)

bsz = len(prompt_tokens)

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
assert max_prompt_len <= max_gen_len
total_len = min(max_gen_len, max_gen_len + max_prompt_len)

pad_id = -1
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id

In [87]:
print(input_text_mask)

tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False]])


In [88]:
for i, toks in enumerate(prompt_tokens):
    rand_len = random.randint(1, max_gen_len - len(toks))

    for j in range(rand_len):
        idx = len(toks) + j
        tokens[i, idx] = random.randint(1, 32000)

In [89]:
pad_mask = tokens == pad_id
print(pad_mask)

tensor([[False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True]])


In [90]:
completion_mask = input_text_mask.logical_or(pad_mask)
print(pad_mask)

tensor([[False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True]])


In [91]:
mask1 = torch.tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False]])

mask2 = torch.tensor([[False, False, False, False, False, False, False, False, False, True],
        [False, False, False, False, False, False, False, True,  True,  True]])

mask3 = mask1 | mask2  # Element-wise OR operation
print(mask3)

tensor([[ True,  True, False, False, False, False, False, False, False,  True],
        [ True,  True,  True,  True, False, False, False,  True,  True,  True]])


In [92]:
inverse_mask3 = ~mask3
print(inverse_mask3)

tensor([[False, False,  True,  True,  True,  True,  True,  True,  True, False],
        [False, False, False, False,  True,  True,  True, False, False, False]])


In [93]:
mask = torch.tensor([[False, False,  True,  True,  True,  True,  True,  True,  True, False],
        [False, False, False, False,  True,  True,  True, False, False, False],
        [False, False,  False,  False,  False,  True,  True,  True,  True, True]])

In [94]:
# Reverse the mask along the columns
reversed_mask = torch.flip(mask.float(), [1])

# Find the indices of the first True value along each row (in the reversed mask)
first_true_indices = torch.argmax(reversed_mask, dim=1)

# Calculate the indices of the last True value in the original mask
last_true_indices = mask.shape[1] - 1 - first_true_indices

print(last_true_indices)

tensor([8, 6, 9])


In [95]:
terminal_steps = torch.gather(mask, dim=1, index=last_true_indices.unsqueeze(1))
print(terminal_steps)

tensor([[True],
        [True],
        [True]])


In [113]:
tokens = torch.randint(0, 100, size=(3, 10))
print(tokens)

tensor([[81, 16, 59, 45, 79, 41, 39, 66, 91, 52],
        [ 6,  5, 99, 18, 59, 91, 28, 21, 99, 20],
        [50, 88, 85, 59, 59, 26, 28, 66, 23, 20]])


In [117]:
tokens = torch.tensor([[ 0, 97, 93, 45, 23, -1, 96, 95, -1, 61],
        [20, 10, 87, 64, 93, 11, 22, 75, 24, 75],
        [20, 10, 87, 64, 93, 11, 22, 75, 24, -1],
        [ 2,  1,  9, 59, 41,  1, -1, -1, 95, 98]])

In [118]:
value_to_find = -1

# Create a boolean mask where the tensor is equal to the value to find
mask = (tokens == value_to_find)


# Find the indices of the first occurrence of the value in each row
# indices = torch.argmax(mask.float(), dim=1) * mask.max(dim=1).values
indices = torch.argmax(mask.float(), dim=1)


print(indices)

tensor([5, 0, 9, 6])


In [123]:
completion_mask = torch.ones_like(tokens).to(dtype=bool)

print(completion_mask)

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])


In [124]:
for i, eos_idx in enumerate(indices.tolist()):
    if eos_idx == 0:
        continue
    completion_mask[i, :eos_idx] = False

print(completion_mask)

tensor([[False, False, False, False, False,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True]])
