# PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing

Prepare target tensor.

In [2]:
# Target tensor: 3 slices along dim 0, each a 2x2 float32 matrix holding 1..12.
x = torch.tensor(
    [
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
        [[9, 10], [11, 12]],
    ],
    dtype=torch.float32,
)
print(x.size())

torch.Size([3, 2, 2])


Access a single index along a dimension (integer indexing removes that dimension).

In [3]:
# x[0], x[0, :] and x[0, :, :] all select the first slice along dim 0;
# trailing dimensions that are not indexed are taken in full.
print(x[0], x[0, :], x[0, :, :], sep="\n")

tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


In [4]:
# Negative indices count from the end: these three expressions all pick the
# last slice along dim 0, with the remaining dimensions taken in full.
print(x[-1], x[-1, :], x[-1, :, :], sep="\n")

tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])


In [5]:
# Row 0 (index 0 along dim 1) of every slice; dim 1 is dropped, so the
# (3, 2, 2) tensor becomes (3, 2).
print(x[:, 0, :])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


In [6]:
# Slice 2 along dim 0, row 1 along dim 1, all of dim 2 -> a 1-D tensor.
print(x[2, 1, :])

tensor([11., 12.])


Access by range (slicing). Note that, unlike integer indexing, slicing does not change the number of dimensions.

In [13]:
# Slices keep every dimension and only shrink the sliced extent.
# Along dim 1 (size 2), both :1 and :-1 keep exactly one row.
print(
    x[1:3, :, :].size(),
    x[:, :1, :].size(),
    x[:, :-1, :].size(),
    sep="\n",
)

torch.Size([2, 2, 2])
torch.Size([3, 1, 2])
torch.Size([3, 1, 2])


### split: split a tensor into pieces of a given size

In [18]:
# Deterministic 10x4 float32 input. The legacy constructor
# torch.FloatTensor(10, 4) returns uninitialized memory (arbitrary, possibly
# NaN, values that change between runs), which breaks reproducibility.
x = torch.arange(40, dtype=torch.float32).reshape(10, 4)

In [21]:
# Cut dim 0 into pieces of 4 rows each; the final piece holds whatever rows
# remain (10 rows -> 4 + 4 + 2).
splits = torch.split(x, 4, dim=0)

for piece in splits:
    print(piece.size())

torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([2, 4])


### chunk: split a tensor into a given number of chunks

In [25]:
# Deterministic 8x4 float32 input; torch.FloatTensor(8, 4) would hand back
# uninitialized memory with arbitrary, non-reproducible values.
x = torch.arange(32, dtype=torch.float32).reshape(8, 4)

# Ask for 4 chunks along dim 0: 8 rows / 4 chunks = 2 rows per chunk.
chunks = x.chunk(4, dim=0)
for c in chunks:
    print(c.size())

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


### index_select: select entries along a dimension by index

In [31]:
x = torch.FloatTensor([[[1, 1],
                        [2, 2]],
                       [[3, 3],
                        [4, 4]],
                       [[5, 5],
                        [6, 6]]])

# Indices are used along dim 1, which has size 2, so only 0 and 1 are valid.
# [1, 0] swaps the two rows of every slice, which is what the recorded
# index_select output shows. The previous [2, 1] contained an out-of-range
# index and raises an error on a fresh Restart & Run All.
indices = torch.LongTensor([1, 0])
assert int(indices.max()) < x.size(1), "indices out of range for dim 1"

print(x.size())

torch.Size([3, 2, 2])


In [32]:
# Reorder / pick rows along dim 1 of every slice according to `indices`.
y = torch.index_select(x, dim=1, index=indices)

print(y, y.size(), sep="\n")

tensor([[[2., 2.],
         [1., 1.]],

        [[4., 4.],
         [3., 3.]],

        [[6., 6.],
         [5., 5.]]])
torch.Size([3, 2, 2])


In [29]:
# An int64 ("long") tensor, the dtype index_select expects for its index.
torch.tensor([2, 1], dtype=torch.long)

tensor([2, 1])

### cat: concatenate a list of tensors along an existing dimension

In [33]:
# Two 3x3 float32 matrices: x holds 1..9 and y holds 10..18, row by row.
x = torch.tensor(
    [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]],
    dtype=torch.float32,
)
y = torch.tensor(
    [[10, 11, 12],
     [13, 14, 15],
     [16, 17, 18]],
    dtype=torch.float32,
)
print(x.size(), y.size())

torch.Size([3, 3]) torch.Size([3, 3])


In [34]:
# Concatenating along dim 0 appends y's rows below x's: (3, 3) + (3, 3) -> (6, 3).
z = torch.cat((x, y), dim=0)
print(z, z.size(), sep="\n")

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


In [35]:
# dim=-1 is the last dimension, so y's columns go to the right of x's:
# (3, 3) + (3, 3) -> (3, 6).
z = torch.cat((x, y), dim=-1)
print(z, z.size(), sep="\n")

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


### stack: stack a list of tensors along a new dimension

In [36]:
# stack inserts a brand-new dimension (dim=0 is the default) and places the
# inputs along it: two (3, 3) tensors -> (2, 3, 3).
z = torch.stack((x, y), dim=0)
print(z, z.size(), sep="\n")

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


In [37]:
# Stacking on a new last dimension pairs up elements at the same position
# in x and y: two (3, 3) tensors -> (3, 3, 2).
z = torch.stack((x, y), dim=-1)
print(z, z.size(), sep="\n")

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])


### Implement 'stack' function by using 'cat'

In [38]:
# Equivalent to torch.stack([x, y]): give each tensor a new leading
# dimension of size 1, then concatenate along that new dimension.
z = torch.cat([t.unsqueeze(0) for t in (x, y)], dim=0)
print(z, z.size(), sep="\n")

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


### Useful Trick : Merge results from iterations

In [39]:
# Collect each iteration's tensor in a Python list and stack once at the end.
# This avoids building a new, growing tensor on every iteration, which
# repeated torch.cat inside the loop would do.
step_outputs = []
for i in range(5):
    # Deterministic stand-in for a real per-step result: slice i is filled
    # with the value i, so the stacked order is easy to verify. (The previous
    # torch.FloatTensor(2, 2) returned uninitialized memory and also
    # overwrote the notebook-global `x`.)
    step_out = torch.full((2, 2), float(i))
    step_outputs.append(step_out)

result = torch.stack(step_outputs)
result.size()

torch.Size([5, 2, 2])