# PyTorch Tensor

In [1]:
import torch

## Tensor Allocation

In [3]:
# 2x2 tensor of 32-bit floats (torch.FloatTensor); the integer literals are
# stored as floats, as the "1." entries in the output show.
ft = torch.FloatTensor([
    [1, 2],
    [3, 4],
])
ft

tensor([[1., 2.],
        [3., 4.]])

In [4]:
# .shape is an attribute (no call) holding the tensor's dimensions as a torch.Size.
ft.shape

torch.Size([2, 2])

In [5]:
# 2x2 tensor of 64-bit integers (torch.LongTensor); no decimal points in the output.
lt = torch.LongTensor([
    [1, 2],
    [3, 4],
])
lt

tensor([[1, 2],
        [3, 4]])

In [7]:
# .size() is the method form of .shape; with no argument it returns the full torch.Size.
lt.size()

torch.Size([2, 2])

In [8]:
# 2x2 tensor of unsigned 8-bit integers (torch.ByteTensor); its repr
# reports dtype=torch.uint8 explicitly.
bt = torch.ByteTensor([
    [1, 0],
    [0, 1],
])
bt

tensor([[1, 0],
        [0, 1]], dtype=torch.uint8)

In [9]:
# Shape of the byte (uint8) tensor, read through the size() method.
bt.size()

torch.Size([2, 2])

In [14]:
# Passing sizes (not data) to the legacy constructor allocates a 4x2 float
# tensor WITHOUT initializing it: the values are whatever was already in
# memory, so the output below is arbitrary and will differ between runs.
# Use torch.zeros(4, 2) when well-defined starting values are needed.
x = torch.FloatTensor(4, 2)
x

tensor([[0.0000e+00, 0.0000e+00],
        [5.6052e-45, 0.0000e+00],
        [1.4013e-45, 0.0000e+00],
        [2.3136e-36, 1.4013e-45]])

---

## NumPy compatibility

In [15]:
import numpy as np

# Define a 2x2 NumPy integer array to convert to and from a tensor below.
x = np.array([
    [1, 2],
    [3, 4],
])
print(x, type(x))

[[1 2]
 [3 4]] <class 'numpy.ndarray'>


In [16]:
# numpy array => PyTorch Tensor
# torch.from_numpy() does not copy: the tensor shares memory with the ndarray.
# It accepts only ndarrays, and this line rebinds x to the tensor, so
# re-running this cell on its own raises a TypeError.
x = torch.from_numpy(x)
print(x, type(x))

tensor([[1, 2],
        [3, 4]]) <class 'torch.Tensor'>


In [17]:
# PyTorch Tensor => numpy array
# Tensor.numpy() returns an ndarray that shares storage with the (CPU) tensor.
# This rebinds x back to an ndarray, so the cell is not safe to re-run alone.
x = x.numpy()
print(x, type(x))

[[1 2]
 [3 4]] <class 'numpy.ndarray'>


---
## Tensor Type-casting

In [19]:
# Float Tensor => Long Tensor
# .long() returns a new int64 (LongTensor) tensor; ft itself is not modified.

ft.long()

tensor([[1, 2],
        [3, 4]])

In [21]:
# Long Tensor => Float Tensor
# .float() returns a new float32 (FloatTensor) tensor; lt itself is not modified.

lt.float()

tensor([[1., 2.],
        [3., 4.]])

In [22]:
# Float Tensor => Byte Tensor: build a float tensor of 0/1 values, then
# cast it to uint8 with .byte() (the cast result is what the cell displays).
torch.FloatTensor([
    [1, 0],
    [1, 0],
]).byte()

tensor([[1, 0],
        [1, 0]], dtype=torch.uint8)

---
## Get Shape

In [23]:
# A 3x2x2 float tensor (three stacked 2x2 matrices), inspected by the
# shape cells that follow.
x = torch.FloatTensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]],
                       [[9, 10],
                        [11, 12]]])

# End with an expression so the cell actually displays the shape it reports;
# an assignment alone produces no output.
x.size()

torch.Size([3, 2, 2])

In [24]:
# Tensor shape, two equivalent ways: the size() method and the shape attribute
# both yield the same torch.Size (printed one per line).
print(x.size(), x.shape, sep="\n")

torch.Size([3, 2, 2])
torch.Size([3, 2, 2])


In [25]:
# Number of dimensions: dim() agrees with the length of the torch.Size.
print(x.dim(), len(x.size()), sep="\n")

3
3


In [26]:
# Length of one dimension (here dimension 1): size(i) and shape[i] are interchangeable.
print(x.size(1), x.shape[1], sep="\n")

2
2


In [27]:
# Length of the last dimension: a negative index counts from the end, for
# size(-1) and shape[-1] alike.
print(x.size(-1), x.shape[-1], sep="\n")

2
2
