## Useful PyTorch Methods

In [1]:
import torch

### expand: repeat a tensor along its size-1 dimensions to a desired size (returns a view; no data is copied).

In [2]:
# A (2, 1, 2) float32 tensor: dim 1 has size 1, so it can be expanded below.
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor, which current PyTorch discourages.
x = torch.tensor([[[1, 2]],
                  [[3, 4]]], dtype=torch.float32)
print(x.size())

torch.Size([2, 1, 2])


In [4]:
# Broadcast the size-1 dim 1 out to length 3; dims 0 and 2 keep their sizes.
# expand returns a view over x's storage rather than copying data.
y = x.expand(2, 3, 2)

print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


#### Implementing expand with cat (same values, but cat allocates a new tensor and copies the data).

In [7]:
# Concatenate three references to x along dim 1 -- the same values expand
# produced, but materialized in freshly allocated memory.
y = torch.cat([x] * 3, dim=1)

print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


In [11]:
# randperm: a random permutation of the integers 0..n-1.
# A dedicated, seeded generator makes the result reproducible on
# Restart & Run All without modifying the global RNG state.
x = torch.randperm(10, generator=torch.Generator().manual_seed(0))

print(x)
print(x.size())

tensor([5, 3, 6, 0, 1, 7, 9, 4, 8, 2])
torch.Size([10])


### argmax: return the indices of the maximum values along a dimension.

In [15]:
# A random permutation of 0..26 arranged as a (3, 3, 3) tensor, so every
# entry is distinct and each max/top-k is unambiguous.
# A dedicated, seeded generator keeps the output reproducible on re-run.
x = torch.randperm(3**3, generator=torch.Generator().manual_seed(0)).reshape(3, 3, -1)

print(x)
print(x.size())

tensor([[[20,  5, 18],
         [19, 26,  3],
         [ 0,  6, 23]],

        [[ 8, 22,  1],
         [12,  7, 25],
         [10,  2, 15]],

        [[ 4, 11, 14],
         [17, 16, 21],
         [13, 24,  9]]])
torch.Size([3, 3, 3])


In [21]:
# Position of the largest entry along the last dim; that dim is reduced
# away, so the (3, 3, 3) input yields a (3, 3) result.
y = torch.argmax(x, dim=-1)

print(y)
print(y.size())

tensor([[0, 1, 2],
        [1, 2, 2],
        [2, 2, 1]])
torch.Size([3, 3])


### topk: return a `(values, indices)` tuple of the k largest elements along a dimension.

In [30]:
# With k=1 the reduced dim is kept (with size 1), unlike argmax which drops it.
values, indices = x.topk(1, dim=-1)

print(values.size())
print(indices.size())

torch.Size([3, 3, 1])
torch.Size([3, 3, 1])


In [31]:
# Remove the trailing size-1 "k" dim to compare with argmax's (3, 3) shape.
print(values.squeeze(dim=-1))

tensor([[20, 26, 23],
        [22, 25, 15],
        [14, 21, 24]])


In [32]:
# After dropping the size-1 "k" dim these indices match argmax(dim=-1).
print(indices.squeeze(dim=-1))

tensor([[0, 1, 2],
        [1, 2, 2],
        [2, 2, 1]])


### Sort by using topk (take k equal to the full size of the target dimension)

In [33]:
# With k equal to the full length of a dim, topk returns every element of
# that dim; since topk sorts its output by default, this is a descending
# sort along target_dim.
target_dim = -1
values, indices = torch.topk(x,
                             k=x.size(target_dim),
                             # Pass the dim explicitly: otherwise topk falls back
                             # to its default dim=-1 regardless of target_dim.
                             dim=target_dim,
                             largest=True)
print(values)

tensor([[[20, 18,  5],
         [26, 19,  3],
         [23,  6,  0]],

        [[22,  8,  1],
         [25, 12,  7],
         [15, 10,  2]],

        [[14, 11,  4],
         [21, 17, 16],
         [24, 13,  9]]])


### masked_fill: fill elements with a given value wherever the mask is True.


In [37]:
# The values 0., 1., ..., 8. as a (3, 3) float32 matrix. arange builds the
# sequence directly instead of going through a Python list and the legacy
# torch.FloatTensor constructor.
x = torch.arange(3**2, dtype=torch.float32).reshape(3, -1)

print(x)
print(x.size())

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])


In [38]:
# Element-wise comparison: a bool tensor, True where x is greater than 4.
mask = x.gt(4)

print(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])


In [39]:
# Out-of-place: returns a new tensor with -1 wherever mask is True; x itself
# is left unchanged (masked_fill_ is the in-place variant).
y = x.masked_fill(mask, -1)

In [40]:
# Bare last expression: Jupyter renders y's repr as this cell's output.
y

tensor([[ 0.,  1.,  2.],
        [ 3.,  4., -1.],
        [-1., -1., -1.]])

### Ones and Zeros

In [42]:
# The shape can be given as separate ints or, as here, as a single tuple.
print(torch.ones((2, 3)))
print(torch.zeros((2, 3)))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [43]:
# A (2, 3) float32 reference tensor for the *_like factories below.
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]], dtype=torch.float32)
print(x.size())

torch.Size([2, 3])


In [44]:
# A tensor of ones with the same shape and dtype as the input.
print(torch.ones_like(input=x))

tensor([[1., 1., 1.],
        [1., 1., 1.]])


In [45]:
# A tensor of zeros with the same shape and dtype as the input.
print(torch.zeros_like(input=x))

tensor([[0., 0., 0.],
        [0., 0., 0.]])
