# 自己练习torch中tensor的相关操作

In [1]:
import torch
import numpy as np

## Tensor 创建

In [2]:
# Allocate a 2x3 tensor of the default float dtype WITHOUT initialising it, so the
# printed values are leftover memory. This is what the legacy `torch.Tensor(2, 3)`
# constructor does; `torch.empty` is the modern spelling and makes that explicit.
torch.empty(2, 3)

tensor([[2.4758e-12, 1.5271e-04, 1.3567e-19],
        [7.8447e+17, 1.3556e-19, 1.3563e-19]])

In [3]:
# Uninitialised 2x3 int64 tensor (the large numbers printed are garbage memory).
# This is what the legacy `torch.LongTensor(2, 3)` does; `torch.empty` with an
# explicit dtype is the modern equivalent.
torch.empty(2, 3, dtype=torch.long)

tensor([[94384486530416, 94384486536624,             80],
        [            64,              0,              0]])

In [4]:
# Draw a 2x3 sample from the standard normal distribution N(0, 1).
a = torch.randn((2, 3))
a

tensor([[-1.5161, -0.2515,  0.4894],
        [-0.9457,  1.1382, -0.8518]])

In [5]:
# A zero tensor with the same shape and dtype as `a`.
torch.zeros_like(input=a)

tensor([[0., 0., 0.],
        [0., 0., 0.]])

In [6]:
# Random integers in [low, high): `high` is EXCLUSIVE, so randint(0, 1, ...) can
# only ever return 0. Use high=2 to actually draw a 2x3 tensor of 0s and 1s.
torch.randint(0, 2, (2, 3))

tensor([[0, 0, 0],
        [0, 0, 0]])

In [7]:
# Uninitialised 2x1 tensor of the default float dtype; its contents are whatever
# happened to be in memory (zeros in the captured output only by chance).
torch.empty(2, 1)

tensor([[0.],
        [0.]])

## Tensor 维度转换、操作
### reshape, view, transpose, permute

In [8]:
# Square a fresh 2x3 standard-normal sample element-wise so every entry is >= 0.
a = torch.randn(2, 3)
a = a.mul(a)
a

tensor([[0.0091, 0.9968, 0.2921],
        [0.0026, 0.0402, 0.7185]])

In [9]:
# Softmax along dim 0: each column of `a` is normalised so its entries sum to 1.
b = torch.nn.Softmax(dim=0)
b(input=a)

tensor([[0.5016, 0.7224, 0.3950],
        [0.4984, 0.2776, 0.6050]])

In [10]:
# Build a (2, 1) column vector from data. `torch.tensor` is the recommended
# factory for data (the legacy `torch.Tensor(...)` constructor is discouraged);
# the float literals keep the dtype at the default float type, as before.
w = torch.tensor([[1.0], [0.0]])
w.size()

torch.Size([2, 1])

In [11]:
# `.shape` is the attribute form of `.size()`; both give the same torch.Size.
a.shape

torch.Size([2, 3])

In [12]:
# Values 0..23 laid out as a 2 x 3 x 4 tensor. `torch.arange` builds them
# directly, instead of round-tripping through a Python range, a NumPy int64 array
# and the legacy `torch.Tensor` constructor. The explicit default float dtype
# keeps the result float, as that constructor produced.
temp = torch.arange(24, dtype=torch.get_default_dtype()).view(2, 3, 4)
temp

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [13]:
# Reduce over dim 1, i.e. add up the 3 rows of each 3x4 slab -> shape (2, 4).
c = temp.sum(dim=1)
c

tensor([[12., 15., 18., 21.],
        [48., 51., 54., 57.]])

In [14]:
# A random (2, 1) column drawn from N(0, 1).
fm = torch.randn(2, 1)
fm

tensor([[-1.5050],
        [ 0.1857]])

In [15]:
# A 1-D tensor of 3 values drawn from N(0, 1), shape (3,).
# (Parentheses around a bare 3 do not make a tuple, so passing 3 directly is the same.)
prev_hidden_state = torch.randn(3)
prev_hidden_state

tensor([-1.9094, -1.1655, -0.3349])

### repeat


In [16]:
# repeat with more factors than dims treats the (3,) vector as (1, 3), then
# tiles it 2x along dim 0 and 1x along dim 1 -> (2, 3), both rows identical copies.
prev_hidden_state = prev_hidden_state.repeat((2, 1))
prev_hidden_state

tensor([[-1.9094, -1.1655, -0.3349],
        [-1.9094, -1.1655, -0.3349]])

In [17]:
# Concatenate side by side along dim 1: (2, 1) with (2, 3) -> (2, 4).
temp = torch.cat([fm, prev_hidden_state], dim=1)
temp

tensor([[-1.5050, -1.9094, -1.1655, -0.3349],
        [ 0.1857, -1.9094, -1.1655, -0.3349]])

In [18]:
# Sum each row of the (2, 4) tensor -> one value per row, shape (2,).
temp.sum(dim=1)

tensor([-4.9150, -3.2242])

In [19]:
# Values 0..23 as a 2 x 3 x 4 float tensor, built directly with `torch.arange`
# instead of a Python range -> NumPy array -> legacy `torch.Tensor` round trip.
# The default float dtype matches what `torch.Tensor` produced.
x = torch.arange(24, dtype=torch.get_default_dtype()).view(2, 3, 4)
x

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [20]:
# Swap the last two dims: (2, 3, 4) -> (2, 4, 3), i.e. transpose each 3x4 slab.
x.permute((0, 2, 1))

tensor([[[ 0.,  4.,  8.],
         [ 1.,  5.,  9.],
         [ 2.,  6., 10.],
         [ 3.,  7., 11.]],

        [[12., 16., 20.],
         [13., 17., 21.],
         [14., 18., 22.],
         [15., 19., 23.]]])

In [21]:
# Reverse the dim order: (2, 3, 4) -> (4, 3, 2).
x.permute((2, 1, 0))

tensor([[[ 0., 12.],
         [ 4., 16.],
         [ 8., 20.]],

        [[ 1., 13.],
         [ 5., 17.],
         [ 9., 21.]],

        [[ 2., 14.],
         [ 6., 18.],
         [10., 22.]],

        [[ 3., 15.],
         [ 7., 19.],
         [11., 23.]]])

In [22]:
# Build a 1-D float tensor from a Python list. `torch.tensor` (which copies the
# data) is the recommended factory; without an explicit dtype, integer input
# would become int64, so request the default float dtype that the legacy
# `torch.Tensor(t)` constructor gave.
t = [1, 5, 6, 3, 4, 7, 9]
c = torch.tensor(t, dtype=torch.get_default_dtype())
c

tensor([1., 5., 6., 3., 4., 7., 9.])

In [23]:
# Ascending sort: returns the sorted values and, for each position, the index of
# that value in the original `c`. `c` itself is not modified.
sort_c, index = c.sort()
sort_c, index

(tensor([1., 3., 4., 5., 6., 7., 9.]), tensor([0, 3, 4, 1, 2, 5, 6]))

In [24]:
# Show `c` again: torch.sort returned new tensors, so `c` keeps its original order.
c

tensor([1., 5., 6., 3., 4., 7., 9.])

### index_select, gather

In [25]:
# Integers 0..48 (int64) arranged row by row into a 7x7 matrix.
captions = torch.arange(49).reshape(7, 7)
captions

tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13],
        [14, 15, 16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39, 40, 41],
        [42, 43, 44, 45, 46, 47, 48]])

In [26]:
# Pick rows of `captions` along dim 0 in the order given by `index` (the sort
# permutation): row i of the result is captions[index[i]].
captions.index_select(0, index)

tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34],
        [ 7,  8,  9, 10, 11, 12, 13],
        [14, 15, 16, 17, 18, 19, 20],
        [35, 36, 37, 38, 39, 40, 41],
        [42, 43, 44, 45, 46, 47, 48]])

### narrow

In [27]:
# A fresh 3x4 standard-normal sample, used by the narrow examples.
a = torch.randn((3, 4))
a

tensor([[-1.3634,  0.8036,  2.2136,  0.1945],
        [ 0.7901,  0.1566, -0.0502,  0.2545],
        [-1.1526,  0.7794, -0.6937, -0.2144]])

In [28]:
# Along dim 0, take 2 rows starting at row 1 (rows 1 and 2); the result shares
# storage with `a`.
torch.narrow(a, 0, 1, 2)

tensor([[ 0.7901,  0.1566, -0.0502,  0.2545],
        [-1.1526,  0.7794, -0.6937, -0.2144]])

In [29]:
# Along dim 1, take 1 column starting at column 1 -> shape (3, 1).
torch.narrow(a, 1, 1, 1)

tensor([[0.8036],
        [0.1566],
        [0.7794]])