# Tensor indexing in PyTorch

In [1]:
import torch
# Seed the global RNG so the random tensors drawn in this notebook
# (and their printed values) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 10
features = 25
# A (batch_size, features) matrix of uniform random values.
x = torch.rand((batch_size, features))

# An integer index on the first axis selects a single row -> shape (features,).
print(x[0].shape)

torch.Size([25])


In [2]:
# Taking column 0 across all rows leaves one value per batch entry.
column_zero = x[:, 0]
print(column_zero.shape)

torch.Size([10])


In [3]:
# Row 2, first ten feature columns (the start index 0 is implicit).
print(x[2, :10])

tensor([4.3858e-01, 8.8003e-01, 1.9867e-01, 4.1061e-01, 6.8155e-01, 5.7185e-01,
        6.4947e-01, 8.7708e-04, 9.5885e-01, 2.8283e-01])


In [4]:
# Indexing both axes with integers addresses one element; assigning writes into x in place.
x[0, 0] = 100
corner = x[0, 0]
print(corner)

tensor(100.)


### Fancy indexing

In [5]:
x = torch.arange(10)
print(x)
# A Python list of positions gathers those elements, in order, into a new tensor.
indices = [2, 5, 8]
picked = x[indices]
print(picked)


tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([2, 5, 8])


In [6]:
x = torch.rand((3, 5))
print(x)
# Index tensors are paired position by position: result[i] = x[rows[i], cols[i]].
rows, cols = torch.tensor([1, 0]), torch.tensor([4, 0])
print(x[rows, cols])

tensor([[0.7268, 0.4033, 0.5628, 0.0978, 0.8711],
        [0.2708, 0.7218, 0.7818, 0.1173, 0.2985],
        [0.9839, 0.9996, 0.0686, 0.1075, 0.5857]])
tensor([0.2985, 0.7268])


In [7]:
# Three (row, col) pairs: (1, 4), (2, 0) and (0, 2).
rows, cols = torch.tensor([1, 2, 0]), torch.tensor([4, 0, 2])
print(x[rows, cols])

tensor([0.2985, 0.9839, 0.5628])


`rows` and `cols` are paired element by element: the i-th element of the result is `x[rows[i], cols[i]]`. The selected elements form a new 1-D tensor, which has size 3 in the example above.

In [8]:
# One output element per (row, col) index pair.
gathered = x[rows, cols]
print(gathered.shape)

torch.Size([3])


There are three elements, one per `(row, col)` pair.

### More advanced indexing: boolean masks

In [9]:
x = torch.arange(10)
print(x)
# Comparisons give element-wise boolean masks; `|` combines them with element-wise OR.
outside_2_to_8 = (x < 2) | (x > 8)
print(x[outside_2_to_8])

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 9])


In [10]:
# No value is both < 2 and > 8, so the AND-ed mask is all False and the selection is empty.
never_true = (x < 2) & (x > 8)
print(x[never_true])

tensor([], dtype=torch.int64)


In [11]:
# Keep the values strictly between 2 and 8.
strictly_between = (x > 2) & (x < 8)
print(x[strictly_between])

tensor([3, 4, 5, 6, 7])


In [12]:
# `%` on a tensor is element-wise remainder; a zero remainder marks the even values.
is_even = x % 2 == 0
print(x[is_even])

tensor([0, 2, 4, 6, 8])


### Useful operations

In [13]:
# torch.where(cond, a, b) picks a where cond is True and b elsewhere:
# values above 5 are kept, the rest are doubled.
doubled = x * 2
print(torch.where(x > 5, x, doubled))

tensor([ 0,  2,  4,  6,  8, 10,  6,  7,  8,  9])


In [14]:
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([6, 7, 8, 9, 10])

# Element-wise maximum written as a plain Python loop over paired elements.
print(torch.tensor([a if a > b else b for a, b in zip(x, y)]))

tensor([ 6,  7,  8,  9, 10])


In [15]:
x = torch.tensor([1, 8, 9, 4, 5])
y = torch.tensor([6, 7, 8, 9, 10])

# Vectorized element-wise maximum: take x where x > y, otherwise y.
result = torch.where(x > y, x, y)
print(result)

tensor([ 6,  8,  9,  9, 10])


In [17]:
x = torch.tensor([0, 2, 2, 2, 3, 4, 5, 5, 5, 6, 7, 8, 8, 9])
# Collapse repeated values to one occurrence each.
unique = torch.unique(x)
print(unique)

tensor([0, 2, 3, 4, 5, 6, 7, 8, 9])


In [18]:
# Number of dimensions (rank) of x; ndimension() is an alias of dim().
print(x.dim())

1


In [19]:
# Total element count, via the function form of numel.
print(torch.numel(x))

14


In [20]:
# Tensor.numel() accepts no arguments, so passing one raises TypeError.
# The error is the point of this cell, so catch and display it instead of
# letting it abort Restart & Run All.
try:
    print(x.numel(2))
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: TensorBase.numel() takes no arguments (1 given)

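From the above we can see that the `Tensor.numel()` method takes no arguments, so `x.numel(2)` raises a `TypeError`. `numel()` returns the total number of elements in the tensor (14 for `x` above). The function form, `torch.numel(x)`, instead takes the tensor as its single argument.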