# PyTorch Useful Methods 

In [1]:
import torch

### expand : copy the given tensor and concat those at desired dimension.

In [2]:
x = torch.FloatTensor([[[1,2]],
                       [[3,4]]])
print(x.size())

torch.Size([2, 1, 2])


In [4]:
y = x.expand(*[2,3,2])
print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


In [5]:
y = torch.cat([x,x,x], dim = 1)
print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


### randperm : Random Permutation 

In [6]:
x = torch.randperm(10)

print(x)
print(x.size())

tensor([6, 9, 0, 5, 2, 4, 3, 1, 7, 8])
torch.Size([10])


### argmax : Return index of maximum values 

In [7]:
x = torch.randperm(3**3).reshape(3,3,-1) # x : 3x3x3

print(x)
print(x.size())

tensor([[[ 7, 22, 16],
         [ 9, 10, 20],
         [ 5,  4,  3]],

        [[ 0,  2,  8],
         [15,  6, 13],
         [11, 24, 25]],

        [[19, 14, 17],
         [26, 21,  1],
         [12, 23, 18]]])
torch.Size([3, 3, 3])


In [8]:
y = x.argmax(dim = -1)

print(y)
print(y.size())

tensor([[1, 2, 0],
        [2, 0, 2],
        [0, 0, 1]])
torch.Size([3, 3])


### topk : Return tuple of top-k values and indices 

In [9]:
values, indices = torch.topk(x, k = 1, dim = -1)

print(values.size())
print(indices.size())

torch.Size([3, 3, 1])
torch.Size([3, 3, 1])


Note that topk didn't reduce the dimension, even in $k=1$ case.

In [10]:
print(values.squeeze(-1))
print(indices.squeeze(-1))

tensor([[22, 20,  5],
        [ 8, 15, 25],
        [19, 26, 23]])
tensor([[1, 2, 0],
        [2, 0, 2],
        [0, 0, 1]])


In [11]:
print(x.argmax(dim = -1) == indices.squeeze(-1))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


In [12]:
_, indices = torch.topk(x, k = 2, dim = -1)
print(indices.size())
print(x.argmax(dim = -1) == indices[:,:,0])

torch.Size([3, 3, 2])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


### Sort by using topk

In [13]:
target_dim = -1
values, indices = torch.topk(x,
                             k = x.size(target_dim),
                             largest = True)

print(values)

tensor([[[22, 16,  7],
         [20, 10,  9],
         [ 5,  4,  3]],

        [[ 8,  2,  0],
         [15, 13,  6],
         [25, 24, 11]],

        [[19, 17, 14],
         [26, 21,  1],
         [23, 18, 12]]])


### Topk by using sort

In [14]:
k = 1 
values, indices = torch.sort(x, dim = -1, descending = True)
values, indices = values[:,:,k], indices[:,:,k]

print(values.squeeze(-1))
print(indices.squeeze(-1))

tensor([[16, 10,  4],
        [ 2, 13, 24],
        [17, 21, 18]])
tensor([[2, 1, 1],
        [1, 2, 1],
        [2, 1, 2]])


In [19]:
values, indices = torch.sort(x, dim = -1, descending = True)
print(values)
print(values[:,:,1])

tensor([[[22, 16,  7],
         [20, 10,  9],
         [ 5,  4,  3]],

        [[ 8,  2,  0],
         [15, 13,  6],
         [25, 24, 11]],

        [[19, 17, 14],
         [26, 21,  1],
         [23, 18, 12]]])
tensor([[16, 10,  4],
        [ 2, 13, 24],
        [17, 21, 18]])


### masked_fill : fill the value if element of mask if True.

In [23]:
x = torch.FloatTensor([i for i in range(3**2)]).reshape(3,-1)

print(x)
print(x.size())

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])


In [24]:
mask = x > 4 
print(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])


In [26]:
y = x.masked_fill(mask, value = -1)
print(y)

tensor([[ 0.,  1.,  2.],
        [ 3.,  4., -1.],
        [-1., -1., -1.]])


### Ones and Zeros 

In [27]:
print(torch.ones(2, 3))
print(torch.zeros(2, 3))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [28]:
x = torch.FloatTensor([[1,2,3],
                       [4,5,6]])

print(x.size())

torch.Size([2, 3])


In [29]:
print(torch.ones_like(x))
print(torch.zeros_like(x))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
