# PyTorch Useful Methods

In [1]:
import torch

### Expand

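 => Returns a view of the tensor in which singleton dimensions are expanded to a larger size. No data is copied; the result only looks like the tensor repeated along those dimensions.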

In [2]:
# Shape (2, 1, 2): the middle axis is a singleton, which is what `expand` can grow.
x = torch.tensor([[[1., 2.]], [[3., 4.]]], dtype=torch.float32)
print(x.shape)

torch.Size([2, 1, 2])


In [4]:
# expand returns a broadcast view: the singleton dim 1 is stretched to size 3
# without copying the underlying data.
y = x.expand(2, 3, 2)

print(y)
print(y.shape)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


In [9]:
## Reproduce the expand result with cat: join three copies of x along dim 1.
# Unlike expand, cat allocates a new tensor that holds real copies of the data.
y = torch.cat([x] * 3, dim=1)

print(y)
print(y.shape)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


---

## randperm

    => Returns a random permutation of the integers 0 to n-1.

In [11]:
# A random ordering of the integers 0..9, each appearing exactly once.
# NOTE: no seed is set, so the printed permutation differs between runs.
x = torch.randperm(10)

print(x)
print(x.shape)

tensor([7, 0, 4, 1, 8, 6, 2, 5, 3, 9])
torch.Size([10])


---
## argmax
    => Returns the indices of the maximum values along a given dimension.

In [21]:
# The 27 distinct integers 0..26 in random order, laid out as a (3, 3, 3) cube.
# NOTE: unseeded, so the values (and the outputs below) change on every run.
x = torch.randperm(3 ** 3).view(3, 3, 3)

print(x)
print(x.shape)

tensor([[[10,  5, 14],
         [ 2, 25, 13],
         [23, 20,  3]],

        [[ 8, 22, 21],
         [ 7, 17, 12],
         [19,  0,  4]],

        [[18,  9, 15],
         [24,  1, 26],
         [11,  6, 16]]])
torch.Size([3, 3, 3])


In [23]:
# Position of the largest entry along the last axis. That axis is removed,
# so the (3, 3, 3) input above yields a (3, 3) result.
y = torch.argmax(x, dim=-1)

print(y)
print(y.shape)

tensor([[2, 1, 0],
        [1, 1, 0],
        [0, 2, 2]])
torch.Size([3, 3])


---
## topk
    => Returns a (values, indices) tuple holding the k largest (or smallest) elements along a given dimension.

In [29]:
# The single largest entry (k=1) along the last axis, as a (values, indices) pair.
values, indices = x.topk(1, dim=-1)

print(values)
print(values.shape)
print(indices)
print(indices.shape)
## Note: unlike argmax, topk keeps the searched dimension (now of size k),
## even when k == 1, hence the trailing 1 in the shapes.

tensor([[[14],
         [25],
         [23]],

        [[22],
         [17],
         [19]],

        [[18],
         [26],
         [16]]])
torch.Size([3, 3, 1])
tensor([[[2],
         [1],
         [0]],

        [[1],
         [1],
         [0]],

        [[0],
         [2],
         [2]]])
torch.Size([3, 3, 1])


In [31]:
# Drop the trailing size-1 axis left by topk so the shape matches argmax's output.
print(values.squeeze(dim=-1))
print(indices.squeeze(dim=-1))

tensor([[14, 25, 23],
        [22, 17, 19],
        [18, 26, 16]])
tensor([[2, 1, 0],
        [1, 1, 0],
        [0, 2, 2]])


In [33]:
# Element-wise check that topk(k=1) picks the same positions as argmax.
print(torch.eq(x.argmax(dim=-1), indices.squeeze(-1)))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


In [38]:
# With k=2 the first slot along the last axis holds the maximum, so its index
# must agree with argmax.
_, indices = x.topk(2, dim=-1)

print(indices.shape)
print(indices[..., 0] == x.argmax(dim=-1))

torch.Size([3, 3, 2])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


---
### Sort using topk

In [39]:
# Redisplay the (3, 3, 3) tensor for reference before sorting it with topk.
x

tensor([[[10,  5, 14],
         [ 2, 25, 13],
         [23, 20,  3]],

        [[ 8, 22, 21],
         [ 7, 17, 12],
         [19,  0,  4]],

        [[18,  9, 15],
         [24,  1, 26],
         [11,  6, 16]]])

In [48]:
# Sorting is topk with k equal to the full length of the target dimension;
# largest=True puts the values in descending order.
# `dim=target_dim` must be passed explicitly. Without it, topk always searches
# the last dimension, so changing `target_dim` would silently have no effect.
target_dim = -1
values, indices = torch.topk(x,
                             k=x.size(target_dim),
                             dim=target_dim,
                             largest=True)

values

tensor([[[14, 10,  5],
         [25, 13,  2],
         [23, 20,  3]],

        [[22, 21,  8],
         [17, 12,  7],
         [19,  4,  0]],

        [[18, 15,  9],
         [26, 24,  1],
         [16, 11,  6]]])

---
## Topk using sort

In [54]:
# Emulate topk(k) by sorting in descending order along the last dimension and
# keeping the first k entries. The slice uses `k` (it previously hard-coded 1,
# leaving `k` unused) and `...`, so it works for any tensor rank.
k = 1

values, indices = torch.sort(x, dim=-1, descending=True)
values, indices = values[..., :k], indices[..., :k]

# squeeze(-1) only removes the last axis when k == 1; for larger k it is a no-op.
print(values.squeeze(-1))
print(indices.squeeze(-1))

tensor([[14, 25, 23],
        [22, 17, 19],
        [18, 26, 16]])
tensor([[2, 1, 0],
        [1, 1, 0],
        [0, 2, 2]])


---
## masked_fill
    => Fills elements with `value` wherever the corresponding element of the mask is True.

In [58]:
# The values 0..8 as float32, arranged row-major into a 3x3 grid.
x = torch.arange(3 ** 2, dtype=torch.float32).view(3, 3)

print(x)
print(x.shape)

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])


In [60]:
# Boolean tensor with x's shape: True wherever the element is strictly greater than 4.
mask = x.gt(4)

print(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])


In [62]:
# Out-of-place fill: returns a new tensor with -1 where mask is True.
# The in-place variant would be masked_fill_.
y = x.masked_fill(mask=mask, value=-1)

print(y)

tensor([[ 0.,  1.,  2.],
        [ 3.,  4., -1.],
        [-1., -1., -1.]])


---
## Ones and Zeros

In [64]:
# Factory functions: fresh 2x3 tensors of the default float dtype, filled with 1s and 0s.
print(torch.ones((2, 3)))
print(torch.zeros((2, 3)))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [65]:
# A 2x4 float32 tensor used below as the shape/dtype template for the *_like factories.
x = torch.tensor(
    [[1, 2, 3, 4],
     [3, 4, 5, 6]],
    dtype=torch.float32,
)

print(x.shape)

torch.Size([2, 4])


In [66]:
# *_like factories build a new tensor with the same shape and dtype as their
# argument, so there is no need to spell out x's size by hand.
ones = torch.ones_like
zeros = torch.zeros_like
print(ones(x))
print(zeros(x))
del ones, zeros

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])
