In [61]:
import random


def select_non_contiguous_indexes(n, available_indexes: list):
    """
    Select n non-contiguous indexes from a list of available indexes.
    """

    if len(available_indexes) != len(set(available_indexes)):
        raise ValueError("Available indexes must be unique.")

    if not available_indexes:
        raise ValueError("Available indexes must not be empty.")

    selected = []
    while len(selected) < n and available_indexes:
        selected_element = random.choice(available_indexes)
        selected.append(selected_element)

        # Update the available list to maintain the non-contiguity constraint
        new_available = []
        for element in available_indexes:
            if abs(selected_element - element) > 1:
                new_available.append(element)
        available_indexes = new_available

    if len(selected) < n:
        raise ValueError("Not enough non-contiguous indexes available.")

    return selected

In [71]:
select_non_contiguous_indexes(3, [1, 5, 6, 7])

ValueError: Not enough non-contiguous indexes available.

In [22]:
select_non_contiguous_indexes(3, [])

ValueError: Available indexes must not be empty.

In [21]:
select_non_contiguous_indexes(3, [1, 2, 4, 5, 6, 1])

ValueError: Available indexes must be unique.

In [79]:
import numpy as np

arr = np.array([[10, 20, 10], [30, 40, 50], [10, 60, 70], [80, 90, 10]])
print(np.where(arr == 10))

(array([0, 0, 2, 3]), array([0, 2, 0, 2]))


In [80]:
np.where(arr == 10)

(array([0, 0, 2, 3]), array([0, 2, 0, 2]))

In [84]:
(arr == 10).nonzero()

(array([0, 0, 2, 3]), array([0, 2, 0, 2]))

In [85]:
arr == 10

array([[ True, False,  True],
       [False, False, False],
       [ True, False, False],
       [False, False,  True]])

In [89]:
result = [np.where(sequence == 10)[0] for sequence in arr]

In [91]:
result

array([], dtype=int64)

In [93]:
np.where(np.array([10, 60, 70]) == 10)

(array([0]),)

In [95]:
import torch

ok = torch.tensor(
    [
        [
            3015,
            329,
            607,
            45581,
            290,
            7101,
            606,
            286,
            607,
            18410,
            11,
            2739,
            5537,
            4139,
            1423,
            8704,
            22081,
            13,
            4389,
            88,
            15927,
            318,
            11643,
            284,
            8414,
            422,
            37399,
            14225,
            11,
            3461,
            324,
            17325,
        ],
        [
            1327,
            640,
            27074,
            616,
            1182,
            1088,
            262,
            2126,
            326,
            11,
            379,
            617,
            966,
            11,
            661,
            655,
            2245,
            4673,
            13,
            198,
            10374,
            1806,
            617,
            1611,
            286,
            27886,
            393,
            584,
            5110,
            8967,
            11,
            340,
        ],
        [
            25,
            1578,
            7526,
            13,
            198,
            32,
            2092,
            6538,
            2648,
            3038,
            7791,
            7160,
            329,
            1402,
            2706,
            287,
            262,
            3482,
            11,
            1900,
            355,
            262,
            14973,
            8549,
            554,
            1087,
            425,
            1169,
            3620,
            40,
            13,
            383,
        ],
        [
            522,
            9382,
            319,
            1938,
            290,
            37924,
            11,
            355,
            880,
            355,
            257,
            23714,
            9552,
            326,
            1838,
            340,
            4622,
            7069,
            284,
            4776,
            4661,
            1377,
            262,
            6678,
            16608,
            319,
            262,
            24169,
            2650,
            13,
            3244,
            345,
        ],
        [
            33801,
            378,
            616,
            3159,
            475,
            4724,
            644,
            13,
            632,
            318,
            477,
            991,
            257,
            4724,
            13,
            921,
            423,
            645,
            2126,
            644,
            1243,
            481,
            804,
            588,
            1566,
            484,
            389,
            1682,
            319,
            3348,
            13,
            198,
        ],
        [
            517,
            6851,
            588,
            428,
            13,
            198,
            2025,
            8036,
            2648,
            11,
            314,
            2391,
            1813,
            428,
            4291,
            257,
            16008,
            508,
            373,
            1804,
            6454,
            3781,
            319,
            428,
            13,
            843,
            339,
            287,
            3872,
            8155,
            502,
            12607,
        ],
        [
            447,
            247,
            82,
            617,
            5986,
            286,
            1486,
            4213,
            329,
            534,
            1363,
            11087,
            1486,
            3519,
            284,
            1588,
            16610,
            17423,
            3084,
            13,
            775,
            7723,
            262,
            4263,
            422,
            2972,
            4237,
            284,
            2148,
            360,
            3191,
            10096,
        ],
        [
            9009,
            293,
            290,
            18340,
            351,
            9580,
            12983,
            13,
            16374,
            1257,
            1027,
            33173,
            351,
            326,
            698,
            278,
            4813,
            1842,
            13,
            2905,
            33173,
            1282,
            287,
            40003,
            2042,
            47750,
            2612,
            6979,
            3091,
            0,
            50256,
            464,
        ],
    ]
)

In [97]:
ok == 0

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, F

In [98]:
heh = ok == 0

In [99]:
heh

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, F

In [117]:
mhym = heh.nonzero(as_tuple=True)

In [110]:
mhym

tensor([[ 7, 29]])

In [113]:
ok.index_select(0, mhym)

IndexError: index_select(): Index is supposed to be a vector

In [115]:
torch.where(ok == 0)

(tensor([7]), tensor([29]))

In [118]:
ok[mhym]

tensor([0])

In [123]:
torch.take(ok, torch.tensor([33, 35, 65]))

tensor([ 640,  616, 1578])

In [2]:
import torch
batch_size, seq_len, dm = 4,8,16
x = torch.rand((batch_size, seq_len, dm))

In [3]:
ids_to_save = torch.tensor([1,3,5,7])
ids_to_reduce = torch.tensor([0,2,4,6])

x = x.view(-1, dm)

In [10]:
ik1 = torch.index_select(x, 0, ids_to_save)

In [9]:
ik2 =x[ids_to_save]

In [11]:
torch.equal(ik1, ik2)

True

In [12]:
reduced = x[ids_to_reduce]
next = x[ids_to_reduce+1]

In [21]:
lel = torch.stack((reduced, next))

In [45]:
oksa = gate(lel)

In [50]:
meg = oksa.softmax(dim=0)

In [51]:
meg

tensor([[[0.6465],
         [0.5172],
         [0.4786],
         [0.5441]],

        [[0.3535],
         [0.4828],
         [0.5214],
         [0.4559]]], grad_fn=<SoftmaxBackward0>)

In [54]:
lel

tensor([[[0.8170, 0.1776, 0.6311, 0.7077, 0.9408, 0.5163, 0.1086, 0.3799,
          0.3005, 0.1174, 0.5974, 0.8472, 0.8531, 0.9711, 0.5063, 0.9746],
         [0.7152, 0.7709, 0.4917, 0.1187, 0.0327, 0.2721, 0.8203, 0.6136,
          0.9451, 0.8556, 0.3962, 0.5608, 0.8932, 0.7803, 0.9784, 0.7406],
         [0.1707, 0.5920, 0.7726, 0.0346, 0.3443, 0.9938, 0.8569, 0.0538,
          0.0449, 0.9547, 0.9321, 0.9746, 0.5815, 0.9907, 0.2923, 0.5910],
         [0.9318, 0.1974, 0.9067, 0.1529, 0.8520, 0.9041, 0.7743, 0.2024,
          0.8877, 0.5271, 0.3772, 0.9178, 0.7792, 0.2886, 0.4124, 0.4090]],

        [[0.3197, 0.0022, 0.9051, 0.1985, 0.6937, 0.3038, 0.3414, 0.7831,
          0.9357, 0.6333, 0.1919, 0.0446, 0.1083, 0.6792, 0.3022, 0.7335],
         [0.3046, 0.2010, 0.6455, 0.8924, 0.7246, 0.7259, 0.5433, 0.8912,
          0.6476, 0.5763, 0.1685, 0.6228, 0.9594, 0.9984, 0.1007, 0.4396],
         [0.8872, 0.9353, 0.4492, 0.7079, 0.6097, 0.0914, 0.1573, 0.7766,
          0.5397, 0.2917, 0.15

In [56]:
yeah = lel * meg

In [57]:
yeah.shape

torch.Size([2, 4, 16])

In [59]:
yeah

tensor([[[0.5282, 0.1148, 0.4080, 0.4576, 0.6083, 0.3338, 0.0702, 0.2456,
          0.1943, 0.0759, 0.3862, 0.5477, 0.5516, 0.6279, 0.3273, 0.6301],
         [0.3699, 0.3987, 0.2543, 0.0614, 0.0169, 0.1407, 0.4243, 0.3174,
          0.4888, 0.4425, 0.2049, 0.2900, 0.4619, 0.4036, 0.5060, 0.3831],
         [0.0817, 0.2833, 0.3698, 0.0166, 0.1648, 0.4756, 0.4101, 0.0258,
          0.0215, 0.4569, 0.4461, 0.4664, 0.2783, 0.4741, 0.1399, 0.2829],
         [0.5069, 0.1074, 0.4933, 0.0832, 0.4636, 0.4919, 0.4213, 0.1101,
          0.4830, 0.2868, 0.2052, 0.4994, 0.4240, 0.1570, 0.2244, 0.2226]],

        [[0.1130, 0.0008, 0.3199, 0.0702, 0.2452, 0.1074, 0.1207, 0.2768,
          0.3308, 0.2238, 0.0678, 0.0158, 0.0383, 0.2401, 0.1068, 0.2593],
         [0.1471, 0.0970, 0.3117, 0.4309, 0.3498, 0.3505, 0.2623, 0.4303,
          0.3127, 0.2782, 0.0813, 0.3007, 0.4632, 0.4820, 0.0486, 0.2122],
         [0.4626, 0.4877, 0.2342, 0.3691, 0.3179, 0.0477, 0.0820, 0.4049,
          0.2814, 0.1521, 0.08

In [58]:
yeah.sum(dim=0)

tensor([[0.6412, 0.1156, 0.7280, 0.5277, 0.8535, 0.4412, 0.1909, 0.5224, 0.5251,
         0.2997, 0.4541, 0.5635, 0.5899, 0.8680, 0.4342, 0.8894],
        [0.5170, 0.4957, 0.5660, 0.4923, 0.3667, 0.4912, 0.6866, 0.7476, 0.8015,
         0.7207, 0.2863, 0.5907, 0.9251, 0.8856, 0.5546, 0.5953],
        [0.5443, 0.7710, 0.6040, 0.3857, 0.4827, 0.5233, 0.4921, 0.4307, 0.3029,
         0.6090, 0.5270, 0.8686, 0.5156, 0.6368, 0.4704, 0.3870],
        [0.6492, 0.2383, 0.9161, 0.5161, 0.8211, 0.6530, 0.4768, 0.5235, 0.8553,
         0.5456, 0.6423, 0.5264, 0.5993, 0.1865, 0.4746, 0.5544]],
       grad_fn=<SumBackward1>)

In [46]:
oks = gate(lel).squeeze()

In [28]:
oks.shape

torch.Size([2, 4])

In [32]:
oks

tensor([[ 0.4018,  0.2643,  0.0500,  0.0631],
        [-0.2020,  0.1955,  0.1357, -0.1137]], grad_fn=<SqueezeBackward0>)

In [34]:
dobra = oks.softmax(dim=0)

In [35]:
dobra

tensor([[0.6465, 0.5172, 0.4786, 0.5441],
        [0.3535, 0.4828, 0.5214, 0.4559]], grad_fn=<SoftmaxBackward0>)

In [41]:
lel.shape

torch.Size([2, 4, 16])

In [42]:
dobra.shape

torch.Size([2, 4])

In [43]:
dobra * lel

RuntimeError: The size of tensor a (4) must match the size of tensor b (16) at non-singleton dimension 2

In [36]:
dobra.shape

torch.Size([2, 4])

In [38]:
lel.shape

torch.Size([2, 4, 16])

In [29]:
softmax(oks, dim=0)

NameError: name 'softmax' is not defined

In [16]:
reduced.shape

torch.Size([4, 16])

In [17]:
gate = torch.nn.Linear(dm, 1)

gate(reduced)


tensor([[0.4018],
        [0.2643],
        [0.0500],
        [0.0631]], grad_fn=<AddmmBackward0>)