## Useful PyTorch Methods

In [1]:
import torch

### expand: repeat a tensor along its size-1 dimensions (returns a view without copying memory; the values match concatenating copies along that dimension)

In [3]:
# Prefer torch.tensor with an explicit dtype over the legacy torch.FloatTensor constructor.
x = torch.tensor([[[1, 2]], [[3, 4]]], dtype=torch.float32)
print(x.size())

torch.Size([2, 1, 2])


In [4]:
# Target sizes are passed directly; the size-1 dim 1 of x is broadcast to 3.
y = x.expand(2, 3, 2)
# expand returns a view: the broadcast dimension has stride 0, so no data is copied.
assert y.stride(1) == 0

print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


In [5]:
# Same values as the expand above, but cat allocates a new tensor holding three real copies of x.
y = torch.cat([x] * 3, dim=1)

print(y)
print(y.shape)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


### randperm: random permutation of the integers 0 to n-1

In [6]:
# No seed is set, so the permutation (and the output shown below) differs on every run.
x = torch.randperm(10)

print(x)
print(x.shape)

tensor([3, 6, 7, 9, 5, 4, 2, 0, 1, 8])
torch.Size([10])


### argmax: return the indices of the maximum values along a dimension

In [8]:
# 27 distinct integers (no ties) arranged as a 3x3x3 cube; values change on every run.
x = torch.randperm(3 ** 3).view(3, 3, 3)

print(x)
print(x.shape)

tensor([[[22, 20,  1],
         [10, 19,  9],
         [ 4, 21, 13]],

        [[ 0,  8, 17],
         [16,  3,  5],
         [24, 14,  2]],

        [[ 7, 25, 11],
         [12, 23, 15],
         [ 6, 18, 26]]])
torch.Size([3, 3, 3])


In [9]:
# Index of the largest entry along the last dimension; that dimension is reduced away.
y = torch.argmax(x, dim=-1)

print(y)
print(y.shape)

tensor([[0, 1, 1],
        [2, 0, 0],
        [1, 1, 2]])
torch.Size([3, 3])


### topk: return a tuple of the top-k values and their indices along a dimension

In [10]:
# Unlike argmax, topk keeps the reduced dimension (with size k) in both outputs.
values, indices = x.topk(1, dim=-1)

print(values.shape)
print(indices.shape)

torch.Size([3, 3, 1])
torch.Size([3, 3, 1])


In [11]:
# Squeeze only the k=1 topk dimension; a bare squeeze() would also drop any other size-1 dims of x.
print(values.squeeze(-1))
print(indices.squeeze(-1))

tensor([[22, 19, 21],
        [17, 16, 24],
        [25, 23, 26]])
tensor([[0, 1, 1],
        [2, 0, 0],
        [1, 1, 2]])


In [12]:
# x comes from randperm (no ties), so argmax and the top-1 indices must agree everywhere.
print(x.argmax(dim=-1) == indices.squeeze(-1))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


In [13]:
# With k=2 the indices come back in descending order of value, so entry 0 is the argmax.
_, indices = x.topk(2, dim=-1)
print(indices.shape)

print(x.argmax(dim=-1) == indices[..., 0])

torch.Size([3, 3, 2])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


### sort by using topk (k = size of the target dimension)

In [14]:
# A full descending sort is topk with k equal to the length of the target dimension.
# dim must be passed explicitly: otherwise topk always works on the last dim,
# even if target_dim is changed.
target_dim = -1
values, indices = torch.topk(
    x, k=x.size(target_dim), dim=target_dim, largest=True, sorted=True
)
print(values)

tensor([[[22, 20,  1],
         [19, 10,  9],
         [21, 13,  4]],

        [[17,  8,  0],
         [16,  5,  3],
         [24, 14,  2]],

        [[25, 11,  7],
         [23, 15, 12],
         [26, 18,  6]]])


### topk by using sort (keep the first k elements of a descending sort)

In [16]:
# topk emulated with a descending sort: keep the first k entries along target_dim.
k = 1
target_dim = -1
values, indices = torch.sort(x, dim=target_dim, descending=True)
# narrow slices along target_dim for any tensor rank (unlike a hard-coded [:, :, :k]).
values = values.narrow(target_dim, 0, k)
indices = indices.narrow(target_dim, 0, k)

print(values.squeeze(target_dim))
print(indices.squeeze(target_dim))

tensor([[22, 19, 21],
        [17, 16, 24],
        [25, 23, 26]])
tensor([[0, 1, 1],
        [2, 0, 0],
        [1, 1, 2]])


### masked_fill: fill elements with a value wherever the mask is True

In [17]:
# 0..8 as float32 in a 3x3 grid; arange replaces the legacy FloatTensor + list comprehension.
x = torch.arange(3 ** 2, dtype=torch.float32).reshape(3, -1)

In [18]:
# Show the values, then the shape, on separate lines.
print(x, x.size(), sep="\n")

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])


In [19]:
# Boolean tensor, True where the element is strictly greater than 4.
mask = x.gt(4)
print(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])


In [20]:
# Out-of-place variant (no trailing underscore): returns a new tensor and leaves x unchanged.
y = x.masked_fill(mask, -1)

In [21]:
# Positions where mask was True now hold -1.
print(str(y))

tensor([[ 0.,  1.,  2.],
        [ 3.,  4., -1.],
        [-1., -1., -1.]])


### Ones and Zeros

In [22]:
# The shape may be given as varargs or as a single tuple.
print(torch.ones((2, 3)), torch.zeros((2, 3)), sep="\n")

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [23]:
# Prefer torch.tensor with an explicit dtype over the legacy torch.FloatTensor constructor.
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)

print(x.size())

torch.Size([2, 3])


In [24]:
# *_like factories copy shape, dtype and device from x; only the fill value differs.
print(torch.ones_like(x), torch.zeros_like(x), sep="\n")

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
