### Useful PyTorch Methods

In [2]:
import torch

### expand : repeat the tensor along its size-1 dimensions to the desired size (returns a view; no data is copied)

In [3]:
# A (2, 1, 2) float32 tensor; its size-1 middle dimension is what expand can grow.
x = torch.tensor([[[1, 2]],
                  [[3, 4]]], dtype=torch.float32)

In [4]:
# Show the shape of x; `.shape` is an alias of `.size()`.
print(x.shape)

torch.Size([2, 1, 2])


In [5]:
# expand takes the target sizes directly (no need to unpack a list). Only the
# size-1 dimension (dim 1) grows to 3; the result is a view of x, not a copy.
y = x.expand(2, 3, 2)

In [6]:
# Display the expanded (2, 3, 2) tensor.
print(y)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])


#### Implement expand with cat.

In [7]:
# Reproduce x.expand(2, 3, 2) by concatenating three copies of x along dim 1.
# Unlike expand, cat materializes a new tensor.
y = torch.cat([x] * 3, dim=1)
print(y)
print(y.shape)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


### randperm : Random Permutation

randperm으로 만든 인덱스(indices)를 index_select 함수에 넣으면 shuffling을 수행할 수 있다 (예: `x.index_select(0, torch.randperm(x.size(0)))`).

In [8]:
x = torch.randperm(10)  # random permutation of 0..9 (differs on every run)

print(x)
print(x.shape)

tensor([9, 2, 8, 7, 6, 3, 0, 5, 1, 4])
torch.Size([10])


### argmax : Return index of maximum values
최댓값이 위치한 '인덱스'를 반환한다.

In [9]:
# 27 distinct integers in random order, laid out as a 3 x 3 x 3 cube.
x = torch.randperm(3 ** 3).view(3, 3, 3)

print(x)
print(x.shape)

tensor([[[17, 15, 10],
         [ 4,  9, 26],
         [ 1, 13, 19]],

        [[11, 22, 20],
         [14,  5, 21],
         [16,  7,  0]],

        [[ 8,  6, 23],
         [ 2,  3, 18],
         [24, 25, 12]]])
torch.Size([3, 3, 3])


In [11]:
# For every (i, j), the index along the last dimension where the maximum sits.
y = torch.argmax(x, dim=-1)

print(y)

# The reduced dimension is dropped from the result: (3, 3, 3) -> (3, 3).
print(y.shape)

tensor([[0, 2, 2],
        [1, 2, 0],
        [2, 2, 1]])
torch.Size([3, 3])


### topk : Return tuple of top-k values and indices
상위 k개의 '값'과 '인덱스'를 모두 반환하며, 추출한 차원이 (k=1이어도) 그대로 유지된다.

In [19]:
# Largest single entry along the last dim; topk returns a (values, indices) pair.
values, indices = x.topk(1, dim=-1)

print(values.shape)
print(indices.shape)

torch.Size([3, 3, 1])
torch.Size([3, 3, 1])


Note that topk does not reduce the dimension, even when k=1.

In [20]:
# Drop the trailing size-1 dimension left by topk(k=1) for readability.
print(values.squeeze(dim=-1))
print(indices.squeeze(dim=-1))

tensor([[17, 26, 19],
        [22, 21, 16],
        [23, 18, 25]])
tensor([[0, 2, 2],
        [1, 2, 0],
        [2, 2, 1]])


In [21]:
# Element-wise check that argmax agrees with the k=1 topk indices.
print(torch.eq(x.argmax(dim=-1), indices.squeeze(-1)))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


In [14]:
# With k=2 the first column (the largest entry) still matches argmax,
# because topk returns its results sorted by default.
_, indices = x.topk(k=2, dim=-1)

print(indices.shape)
print(indices[..., 0] == x.argmax(dim=-1))

torch.Size([3, 3, 2])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


### Sort by using topk


In [16]:
# Sorting expressed as topk: take all k = size(target_dim) elements, largest first.
# dim must be passed explicitly; otherwise topk always works on the last
# dimension and changing target_dim would silently have no effect.
target_dim = -1
values, indices = torch.topk(
    x, k=x.size(target_dim), dim=target_dim, largest=True, sorted=True
)
print(values)

tensor([[[17, 15, 10],
         [26,  9,  4],
         [19, 13,  1]],

        [[22, 20, 11],
         [21, 14,  5],
         [16,  7,  0]],

        [[23,  8,  6],
         [18,  3,  2],
         [25, 24, 12]]])


### Topk by using sort

In [17]:
# topk expressed as sort: sort descending along the last dim, then keep the first k.
# `...` slices the last dimension for a tensor of any rank, not only 3-D.
k = 1
values, indices = torch.sort(x, dim=-1, descending=True)
values, indices = values[..., :k], indices[..., :k]

# squeeze(-1) removes the last dim only when k == 1; for k > 1 it is a no-op.
print(values.squeeze(-1))
print(indices.squeeze(-1))

tensor([[17, 26, 19],
        [22, 21, 16],
        [23, 18, 25]])
tensor([[0, 2, 2],
        [1, 2, 0],
        [2, 2, 1]])


### masked_fill : fill in the given value wherever the mask element is True

- 마스크가 True인 위치에 값을 채워 넣기

In [18]:
# The values 0..8 as float32 in a 3 x 3 grid; arange builds them directly
# instead of going through a Python list and a legacy type constructor.
x = torch.arange(3 ** 2, dtype=torch.float32).reshape(3, -1)

In [20]:
# Show the tensor, then its shape, each on its own line.
print(x, x.size(), sep="\n")

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])


In [23]:
# Boolean mask: True where the element is strictly greater than 4.
mask = x.gt(4)
print(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])


In [27]:
# Put -1 wherever mask is True. masked_fill returns a new tensor and leaves x
# untouched (the in-place variant is masked_fill_).
y = x.masked_fill(mask, -1)

print(y)

tensor([[ 0.,  1.,  2.],
        [ 3.,  4., -1.],
        [-1., -1., -1.]])


### Ones and Zeros

In [29]:
# Constant-filled 2 x 3 tensors in the default floating dtype: ones, then zeros.
for factory in (torch.ones, torch.zeros):
    print(factory(2, 3))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [30]:
# Reference tensor for the *_like constructors: they create new tensors with the
# same shape, dtype and device as x, so the results can be combined with x in
# arithmetic without any conversion.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]], dtype=torch.float32)

print(x.size())

torch.Size([2, 3])


In [32]:
# New tensors shaped like x (same dtype and device), filled with ones, then zeros.
for make_like in (torch.ones_like, torch.zeros_like):
    print(make_like(x))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
