In [1]:
import numpy as np
import torch

# Seed every RNG this notebook draws from (torch.rand / torch.randn and the
# global NumPy RNG behind np.random.rand) so "Restart & Run All" reproduces
# the same numbers. torch.empty is NOT affected: it shows uninitialized memory.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

print(torch.__version__)

1.1.0


## Tensors

In [2]:
# torch.empty only allocates storage without initializing it: the printed values
# are whatever was already in memory and differ from run to run.
print(torch.empty((2, 3)))

tensor([[5.2064e+22, 2.1685e-04, 5.3690e-05],
        [3.4180e-06, 1.6806e-04, 2.1006e+20]])


In [3]:
# Samples drawn uniformly from [0, 1).
print(torch.rand((2, 4)))

tensor([[0.8809, 0.8158, 0.7985, 0.9070],
        [0.2123, 0.5240, 0.7613, 0.1632]])


In [4]:
# Build a tensor directly from Python data (any sequence works, not just lists).
print(torch.tensor((2.1, 3)))

tensor([2.1000, 3.0000])


### dtype

In [5]:
# Mixed Python float/int data is inferred as the default floating dtype.
x = torch.tensor((2.1, 3))
print(x, x.dtype, sep="\n")

tensor([2.1000, 3.0000])
torch.float32


In [6]:
# torch.float is an alias of torch.float32.
x = torch.zeros((3, 2), dtype=torch.float32)
print(x, x.dtype, sep="\n")

tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
torch.float32


In [7]:
# torch.double is an alias of torch.float64; non-default dtypes show up in the repr.
x = torch.zeros((3, 2), dtype=torch.float64)
print(x, x.dtype, sep="\n")

tensor([[0., 0.],
        [0., 0.],
        [0., 0.]], dtype=torch.float64)
torch.float64


In [8]:
# Integer dtypes are requested the same way as floating ones.
x = torch.zeros((3, 2), dtype=torch.int32)
print(x, x.dtype, sep="\n")

tensor([[0, 0],
        [0, 0],
        [0, 0]], dtype=torch.int32)
torch.int32


### size

In [9]:
# .size() returns a torch.Size listing the length of every dimension.
# After the loop, x is bound to the last tensor (the 2 x 2 literal).
for x in (torch.zeros(2), torch.zeros((3, 2)), torch.tensor([[2, 1], [3, 1]])):
    print(x.size())

torch.Size([2])
torch.Size([3, 2])
torch.Size([2, 2])


## Operations

In [10]:
# Two small integer (int64) operands used by the arithmetic examples below.
x, y = torch.tensor([[1, 0], [0, 1]]), torch.tensor([[0, 2], [3, 1]])

### addition

In [11]:
# The operator form and the functional form compute the same sum.
print(x + y, torch.add(x, y), sep="\n")

tensor([[1, 2],
        [3, 2]])
tensor([[1, 2],
        [3, 2]])


In [12]:
# Preallocate a tensor with x's shape and dtype (contents uninitialized), then
# have torch.add write x + y into it. The call returns the `out` tensor itself,
# so what is printed is `result`.
result = torch.empty_like(x)
print(torch.add(x, y, out=result))

tensor([[1, 2],
        [3, 2]])


In [13]:
# in-place addition: methods ending in `_` modify the tensor itself and return it.
# Note: re-running this cell adds x to y again, so the printed result changes.
print(y.add_(x))

tensor([[1, 2],
        [3, 2]])


### slice

In [14]:
# A 3 x 2 tensor of uniform [0, 1) samples to slice below.
x = torch.rand((3, 2))
print(x)

tensor([[0.2063, 0.0156],
        [0.8561, 0.5622],
        [0.5152, 0.9742]])


In [15]:
# Column 1 of every row: `...` expands to all leading dimensions, so for a
# 2-D tensor this is the same as x[:, 1].
print(x[..., 1])

tensor([0.0156, 0.5622, 0.9742])


In [16]:
# First two rows; the start index 0 and the trailing full slice `, :` can be omitted.
print(x[:2])

tensor([[0.2063, 0.0156],
        [0.8561, 0.5622]])


### reshape

In [17]:
# A 4 x 3 tensor drawn from the standard normal distribution, reshaped below.
x = torch.randn((4, 3))
print(x)

tensor([[ 0.4647, -0.9568,  0.1959],
        [ 0.0976, -0.8237,  0.5047],
        [-0.7891, -1.6576,  0.4979],
        [ 0.5367,  2.0624,  1.7420]])


In [18]:
# view() reinterprets the same storage with a new shape (no copy): here a flat 12-element tensor.
print(x.view((12,)))

tensor([ 0.4647, -0.9568,  0.1959,  0.0976, -0.8237,  0.5047, -0.7891, -1.6576,
         0.4979,  0.5367,  2.0624,  1.7420])


In [19]:
# A -1 dimension is inferred from the element count (12 elements / 6 columns = 2 rows).
print(x.view((-1, 6)))

tensor([[ 0.4647, -0.9568,  0.1959,  0.0976, -0.8237,  0.5047],
        [-0.7891, -1.6576,  0.4979,  0.5367,  2.0624,  1.7420]])


## NumPyとの関係

`torch.Tensor` と `numpy.ndarray` は相互に変換できる（`Tensor.numpy()` / `torch.from_numpy()`）。
変換結果はコピーではなく元と同じメモリを共有する（CPU上のテンソルの場合）ため、片方に in-place な変更を加えるともう片方にも反映される点に注意。

### from torch.Tensor to numpy.ndarray

In [20]:
# A CPU float32 tensor that will be converted to NumPy below.
x_torch = torch.rand((2, 4))
print(x_torch)

tensor([[0.4377, 0.1137, 0.6769, 0.3122],
        [0.7977, 0.7243, 0.9831, 0.5149]])


In [21]:
# Tensor.numpy() returns an ndarray that shares memory with the (CPU) tensor:
# no copy is made, so in-place changes to either one are visible in the other.
x_numpy = x_torch.numpy()
x_numpy  # bare last expression: Jupyter displays the array's repr (dtype=float32)

array([[0.4376794 , 0.11373186, 0.6768678 , 0.31223094],
       [0.797726  , 0.72429544, 0.9831159 , 0.5148524 ]], dtype=float32)

In [22]:
# Modify the tensor in place; x_numpy shows the change too because it shares x_torch's memory.
x_torch.add_(3)
print(x_torch, x_numpy, sep="\n")

tensor([[3.4377, 3.1137, 3.6769, 3.3122],
        [3.7977, 3.7243, 3.9831, 3.5149]])
[[3.4376793 3.1137319 3.6768677 3.312231 ]
 [3.797726  3.7242954 3.983116  3.5148525]]


In [23]:
# The sharing works in the other direction too: writing into the ndarray updates the tensor.
x_numpy[0, 0] = 100.0
print(x_torch, x_numpy, sep="\n")

tensor([[100.0000,   3.1137,   3.6769,   3.3122],
        [  3.7977,   3.7243,   3.9831,   3.5149]])
[[100.          3.1137319   3.6768677   3.312231 ]
 [  3.797726    3.7242954   3.983116    3.5148525]]


### from numpy.ndarray to torch.Tensor

In [24]:
# A float64 ndarray of uniform [0, 1) samples (np.random.rand(2, 4) is shorthand for this call).
x_numpy = np.random.random_sample((2, 4))
x_numpy

array([[0.92849014, 0.88663874, 0.24959687, 0.06786177],
       [0.18186754, 0.39224116, 0.02867458, 0.20462827]])

In [25]:
# torch.from_numpy wraps the existing ndarray buffer without copying, so the two
# share memory. The tensor keeps the array's dtype (float64 here, as printed).
x_torch = torch.from_numpy(x_numpy)
print(x_torch)

tensor([[0.9285, 0.8866, 0.2496, 0.0679],
        [0.1819, 0.3922, 0.0287, 0.2046]], dtype=torch.float64)


In [26]:
# In-place change on the tensor created by from_numpy is reflected in the source ndarray.
x_torch.add_(3)
print(x_torch, x_numpy, sep="\n")

tensor([[3.9285, 3.8866, 3.2496, 3.0679],
        [3.1819, 3.3922, 3.0287, 3.2046]], dtype=torch.float64)
[[3.92849014 3.88663874 3.24959687 3.06786177]
 [3.18186754 3.39224116 3.02867458 3.20462827]]


In [27]:
# Writing into the ndarray updates the tensor that wraps it.
x_numpy[0, 0] = 100.0
print(x_torch, x_numpy, sep="\n")

tensor([[100.0000,   3.8866,   3.2496,   3.0679],
        [  3.1819,   3.3922,   3.0287,   3.2046]], dtype=torch.float64)
[[100.           3.88663874   3.24959687   3.06786177]
 [  3.18186754   3.39224116   3.02867458   3.20462827]]
