## PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing

Prepare the target tensor.

In [2]:
x = torch.FloatTensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]],
                       [[9, 10],
                        [11, 12]]])
print(x.size())

torch.Size([3, 2, 2])


![image](https://user-images.githubusercontent.com/105966480/209548178-76d3b652-b9ba-4bbc-a18d-38cbcacdad45.png)
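`torch.FloatTensor` is a legacy constructor. The sketch below assumes a reasonably recent PyTorch and builds the same `[3, 2, 2]` tensor in two other ways: with `torch.tensor` and an explicit dtype, and with `torch.arange` reshaped by `view`.

```python
import torch

# Explicit values with an explicit dtype (a modern alternative to torch.FloatTensor)
a = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]], dtype=torch.float32)

# The same values generated as 1..12 and reshaped into [3, 2, 2]
b = torch.arange(1, 13, dtype=torch.float32).view(3, 2, 2)

print(torch.equal(a, b))  # True
print(a.size())           # torch.Size([3, 2, 2])
```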

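Access a specific index along the first dimension. An integer index removes that dimension, so all three expressions below return the same [2, 2] tensor.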

In [3]:
print(x[0])
print(x[0, :])
print(x[0, :, :])

tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


![image](https://user-images.githubusercontent.com/105966480/209548221-1a556492-8d78-4d4a-ae23-b431f89af494.png)
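Because an integer index drops the indexed dimension, `x[0]` has one dimension fewer than `x`. The following minimal sketch shows that `x[0, ...]` (Ellipsis fills in the remaining dimensions) and `Tensor.select(dim, index)` give the same result. As an assumption, `x` is rebuilt with `torch.arange` so the block runs on its own.

```python
import torch

# Assumption: x is rebuilt from arange so the block is self-contained
x = torch.arange(1, 13, dtype=torch.float32).view(3, 2, 2)

first = x[0]                     # integer index on dim 0 -> shape [2, 2]
same_with_ellipsis = x[0, ...]   # "..." stands for all remaining dimensions
same_with_select = x.select(0, 0)  # select(dim, index) also drops that dimension

print(first.dim())  # 2
print(torch.equal(first, same_with_ellipsis), torch.equal(first, same_with_select))  # True True
```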

In [4]:
print(x[-1])
print(x[-1, :])
print(x[-1, :, :])

tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])


![image](https://user-images.githubusercontent.com/105966480/209548315-6ba0e167-5fec-4ba5-894e-babdacb1fcc1.png)
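Negative indices count from the end, so `-1` is shorthand for `size - 1` along that dimension. The sketch below makes this concrete and combines negative indices across two dimensions. It assumes the same `arange`-based `x` as before.

```python
import torch

# Assumption: same values as the notebook's x
x = torch.arange(1, 13, dtype=torch.float32).view(3, 2, 2)

# -1 refers to the last index, i.e. x.size(0) - 1
last = x[-1]
also_last = x[x.size(0) - 1]
print(torch.equal(last, also_last))  # True

# Second-to-last block, last row
print(x[-2, -1])  # tensor([7., 8.])
```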

In [5]:
print(x[:, 0, :])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


![image](https://user-images.githubusercontent.com/105966480/209548284-829e6e81-0177-48a8-baee-7ec420533c82.png)
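`x[:, 0, :]` takes row 0 from every block along the first dimension. This sketch shows two equivalent spellings. It also shows that basic indexing like this returns a view that shares memory with `x`, so writing through the result changes `x`. As before, `x` is rebuilt with `arange` (an assumption).

```python
import torch

# Assumption: x is rebuilt from arange so the block runs on its own
x = torch.arange(1, 13, dtype=torch.float32).view(3, 2, 2)

a = x[:, 0, :]      # row 0 of every block -> shape [3, 2]
b = x[:, 0]         # the trailing ":" can be omitted
c = x.select(1, 0)  # the same selection via select(dim=1, index=0)

print(a.size())                                 # torch.Size([3, 2])
print(torch.equal(a, b) and torch.equal(a, c))  # True

# Basic indexing returns a view, so writing through it modifies x as well
a[0, 0] = 100.0
print(x[0, 0, 0].item())  # 100.0
```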

Access by range (slicing). Unlike an integer index, a range does not reduce the number of dimensions; it only changes the size of the sliced dimension.

In [6]:
print(x[1:3, :, :].size())
print(x[:, :1, :].size())
print(x[:, :-1, :].size())

torch.Size([2, 2, 2])
torch.Size([3, 1, 2])
torch.Size([3, 1, 2])


![image](https://user-images.githubusercontent.com/105966480/209548528-a5a65016-e904-4233-835b-997e5eca437e.png)
![image](https://user-images.githubusercontent.com/105966480/209548549-38a6a2c9-5223-4253-a883-d9a18ee764be.png)
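The difference between an integer index and a length-1 range appears only in the shape. A minimal sketch, assuming the same `arange`-based `x`, shows that `squeeze` on the kept size-1 dimension recovers the integer-indexed result.

```python
import torch

# Assumption: same values as the notebook's x
x = torch.arange(1, 13, dtype=torch.float32).view(3, 2, 2)

by_int = x[:, 0, :]      # integer index: dim 1 is removed -> [3, 2]
by_range = x[:, 0:1, :]  # length-1 range: dim 1 is kept    -> [3, 1, 2]

print(by_int.size(), by_range.size())  # torch.Size([3, 2]) torch.Size([3, 1, 2])

# squeeze(1) removes the size-1 dimension, recovering the integer-indexed result
print(torch.equal(by_range.squeeze(1), by_int))  # True
```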

### split: Split a tensor into chunks of a given size.

In [7]:
x = torch.FloatTensor(10, 4)  # uninitialized 10x4 tensor; only its shape matters here

![image](https://user-images.githubusercontent.com/105966480/209548777-2ddcbc28-439f-4534-838c-fa5dee74faa7.png)

In [8]:
splits = x.split(4, dim=0)  # split along dim 0 into chunks of size 4

# 10 = 4+4+2

for s in splits:
    print(s.size())

torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([2, 4])


![image](https://user-images.githubusercontent.com/105966480/209548859-ba731658-97b6-4865-878d-fe5f8dc55d3c.png)
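`split` also accepts a list of sizes instead of a single chunk size, and the list must sum to the length of the split dimension. The sketch below uses an `arange`-filled tensor instead of an uninitialized one (an assumption, so the values are deterministic). It also shows that `cat` along the same dimension reassembles the original.

```python
import torch

# Assumption: deterministic values instead of an uninitialized FloatTensor(10, 4)
x = torch.arange(40, dtype=torch.float32).view(10, 4)

# An int gives equal-sized pieces, and the last may be smaller: 10 = 4 + 4 + 2
print([s.size(0) for s in x.split(4, dim=0)])  # [4, 4, 2]

# A list gives explicit sizes, which must sum to x.size(0)
parts = x.split([3, 7], dim=0)
print([p.size() for p in parts])  # [torch.Size([3, 4]), torch.Size([7, 4])]

# Concatenating the pieces along the same dimension recovers x
print(torch.equal(torch.cat(parts, dim=0), x))  # True
```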

### chunk: Split a tensor into a given number of chunks.

In [9]:
x = torch.FloatTensor(8, 4)  # uninitialized 8x4 tensor; only its shape matters here

![image](https://user-images.githubusercontent.com/105966480/209548951-96595a29-8b53-4093-9be3-32085cbecb4a.png)

In [10]:
chunks = x.chunk(3, dim=0) # 8 = 3+3+2

for c in chunks:
    print(c.size())

torch.Size([3, 4])
torch.Size([3, 4])
torch.Size([2, 4])


![image](https://user-images.githubusercontent.com/105966480/209549043-7ba235fd-4511-4317-89f6-74f8bd9c1bb9.png)
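`chunk` computes a chunk size of `ceil(n / chunks)`, so it can return fewer chunks than requested. This small sketch shows that case. It also shows that the notebook's 8-row example lines up with `split(3)`.

```python
import torch

x = torch.arange(6)

# Chunk size is ceil(6 / 4) = 2, so only 3 chunks come back instead of 4
chunks = x.chunk(4, dim=0)
print(len(chunks))                   # 3
print([c.tolist() for c in chunks])  # [[0, 1], [2, 3], [4, 5]]

# For 8 rows, chunk(3) uses size ceil(8 / 3) = 3, the same as split(3)
y = torch.arange(32).view(8, 4)
print([c.size(0) for c in y.chunk(3)] == [s.size(0) for s in y.split(3)])  # True
```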

### index_select: Select elements along a dimension by index.

In [11]:
x = torch.FloatTensor([[[1, 1],
                        [2, 2]],
                       [[3, 3],
                        [4, 4]],
                       [[5, 5],
                        [6, 6]]]) #[3,2,2]
indice = torch.LongTensor([2, 1])

print(x.size())

torch.Size([3, 2, 2])


In [12]:
y = x.index_select(dim=0, index=indice)

print(y)
print(y.size())

tensor([[[5., 5.],
         [6., 6.]],

        [[3., 3.],
         [4., 4.]]])
torch.Size([2, 2, 2])


![image](https://user-images.githubusercontent.com/105966480/209549609-f4fbef31-d713-4bce-bcfb-d31a676bf535.png)
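Along dim 0, `index_select` gives the same result as indexing with an integer index tensor. It works along any other dimension too, and it returns a new tensor rather than a view. The sketch below rebuilds the notebook's `x` with `torch.tensor`. The index tensor names `idx` and the row order `[1, 0]` are illustrative choices.

```python
import torch

x = torch.tensor([[[1, 1], [2, 2]],
                  [[3, 3], [4, 4]],
                  [[5, 5], [6, 6]]], dtype=torch.float32)  # [3, 2, 2]
idx = torch.tensor([2, 1])  # int64 (LongTensor) by default

# Along dim 0, index_select matches advanced indexing with an index tensor
print(torch.equal(x.index_select(0, idx), x[idx]))  # True

# Along dim 1 (rows within each block), the order of the index is respected
rows = x.index_select(1, torch.tensor([1, 0]))
print(rows[0].tolist())  # [[2.0, 2.0], [1.0, 1.0]]
```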

### cat: Concatenate a list of tensors along an existing dimension.

In [13]:
x = torch.FloatTensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
y = torch.FloatTensor([[10, 11, 12],
                       [13, 14, 15],
                       [16, 17, 18]])

print(x.size(), y.size())

torch.Size([3, 3]) torch.Size([3, 3])


In [14]:
z = torch.cat([x, y], dim=0)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


![image](https://user-images.githubusercontent.com/105966480/209549818-657af460-5195-4b6c-b7b3-b68d9d4a1590.png)
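`cat` allows inputs to differ in size only along the concatenation dimension. This sketch uses small tensors of zeros and ones (illustrative assumptions) to show one valid case along each dimension and one mismatch that raises a `RuntimeError`.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(1, 3)
c = torch.ones(2, 4)

# Sizes may differ only along the concatenation dimension
print(torch.cat([a, b], dim=0).size())  # torch.Size([3, 3])
print(torch.cat([a, c], dim=1).size())  # torch.Size([2, 7])

# Along dim 0, a and c differ in dim 1 (3 vs 4), so cat refuses
mismatch_raised = False
try:
    torch.cat([a, c], dim=0)
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```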

In [15]:
z = torch.cat([x, y], dim=-1)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


![image](https://user-images.githubusercontent.com/105966480/209549853-5ebc5e57-883e-401b-bb2b-13b49cd67e7a.png)
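A negative `dim` counts from the last dimension, so for 2-D inputs `dim=-1` is the same as `dim=1`. The output size along that dimension is the sum of the input sizes. This sketch rebuilds the notebook's `x` and `y` with `arange` (an assumption).

```python
import torch

# Assumption: same values as the notebook's x (1..9) and y (10..18)
x = torch.arange(1, 10, dtype=torch.float32).view(3, 3)
y = torch.arange(10, 19, dtype=torch.float32).view(3, 3)

# For 2-D inputs, dim=-1 is the last dimension, i.e. dim=1
print(torch.equal(torch.cat([x, y], dim=-1), torch.cat([x, y], dim=1)))  # True

# The concatenated size is the sum of the input sizes along that dimension
print(torch.cat([x, y, x], dim=1).size())  # torch.Size([3, 9])
```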

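### stack: Stack a list of tensors along a new dimension.

The number of dimensions increases by one.

Whereas cat simply concatenates along an existing dimension, stack first unsqueezes each tensor at the chosen dimension and then concatenates them.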

In [16]:
z = torch.stack([x, y])
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


![image](https://user-images.githubusercontent.com/105966480/209550227-2eed9648-1444-444b-af65-79637d0c6f2d.png)
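Each input appears intact as one slice along the new dimension. Unlike `cat`, `stack` requires all inputs to have exactly the same shape. A minimal sketch with the `arange`-based `x` and `y` (an assumption) and a deliberately mismatched `[2, 3]` tensor:

```python
import torch

# Assumption: same values as the notebook's x and y
x = torch.arange(1, 10, dtype=torch.float32).view(3, 3)
y = torch.arange(10, 19, dtype=torch.float32).view(3, 3)

z = torch.stack([x, y])  # new leading dim of size 2 -> [2, 3, 3]
print(torch.equal(z[0], x), torch.equal(z[1], y))  # True True

# stack requires every input to have exactly the same shape
stack_raised = False
try:
    torch.stack([x, torch.zeros(2, 3)])
except RuntimeError:
    stack_raised = True
print(stack_raised)  # True
```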

You can also specify the dimension to stack along; the default is 0.

In [17]:
z = torch.stack([x, y], dim=-1)
print(z)
print(z.size())

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])


![image](https://user-images.githubusercontent.com/105966480/209550316-479af5fc-b039-4646-8f3c-124a209a9d92.png)
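The `dim` argument decides where the new size-2 dimension is inserted. This sketch prints the resulting shape for each position. It then shows that with `dim=-1` the original tensors can be read back from the last axis. `x` and `y` are rebuilt with `arange` (an assumption).

```python
import torch

# Assumption: same values as the notebook's x and y
x = torch.arange(1, 10, dtype=torch.float32).view(3, 3)
y = torch.arange(10, 19, dtype=torch.float32).view(3, 3)

for d in (0, 1, 2):
    print(d, tuple(torch.stack([x, y], dim=d).size()))
# 0 (2, 3, 3)
# 1 (3, 2, 3)
# 2 (3, 3, 2)

# With dim=-1 (= 2 here), each input sits at one position of the last axis
z = torch.stack([x, y], dim=-1)
print(torch.equal(z[..., 0], x), torch.equal(z[..., 1], y))  # True True
```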

### Implement the 'stack' function using 'cat'.

In [18]:
# equivalent to: z = torch.stack([x, y])
z = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])
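The same idea generalizes to any `dim`: unsqueeze every input at `dim`, then concatenate along `dim`. The helper `stack_via_cat` below is a hypothetical name used only for illustration. This minimal sketch checks it against `torch.stack` for every valid dimension of the 3-D result, including negative ones.

```python
import torch

def stack_via_cat(tensors, dim=0):
    """Hypothetical helper: unsqueeze each tensor at `dim`, then cat along `dim`."""
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

# Assumption: same values as the notebook's x and y
x = torch.arange(1, 10, dtype=torch.float32).view(3, 3)
y = torch.arange(10, 19, dtype=torch.float32).view(3, 3)

# Compare against torch.stack for every valid dim of the 3-D result
for d in range(-3, 3):
    assert torch.equal(stack_via_cat([x, y], dim=d), torch.stack([x, y], dim=d))
print("matches torch.stack for dims -3..2")
```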


### Useful Trick: Merge results from iterations

![image](https://user-images.githubusercontent.com/105966480/209550516-fb8a89f7-49d6-434c-850c-b6c3878c0957.png)

In [19]:
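result = []
for i in range(5):
    x = torch.FloatTensor(2, 2)  # uninitialized 2x2 tensor standing in for one iteration's result
    result.append(x)

result = torch.stack(result)  # [5, 2, 2]: one slice per iteration
result.size()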

torch.Size([5, 2, 2])

![image](https://user-images.githubusercontent.com/105966480/209550594-c0892dab-f8de-4f70-b641-6a7a1c405b84.png)
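When each iteration already produces a batch, you can choose to add an "iteration" dimension with `stack` or to merge along the existing batch dimension with `cat`. Either way, collecting results in a list and calling `stack` or `cat` once at the end avoids the repeated copying that calling `torch.cat` inside the loop would typically cause. The per-iteration shape `[2, 4]` and the constant fill values are illustrative assumptions.

```python
import torch

# Hypothetical per-iteration outputs: each step yields 2 rows of 4 features, filled with i
batches = [torch.full((2, 4), float(i)) for i in range(5)]

# stack adds a new "iteration" dimension: [5, 2, 4]
stacked = torch.stack(batches)
print(stacked.size())

# cat merges along the existing batch dimension instead: [10, 4]
merged = torch.cat(batches, dim=0)
print(merged.size())
```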