In [1]:
### Tensor Math & Comparison Operations

In [2]:
import torch

In [3]:
# Shared operands: two length-3 integer vectors used by the arithmetic cells.
x, y = torch.tensor([1, 2, 3]), torch.tensor([9, 8, 7])

In [5]:
# Addition: three equivalent spellings of x + y.

# (1) Write the result into a preallocated buffer. torch.empty(3) is a
#     floating-point tensor, so the integer sums are stored as floats.
z1 = torch.empty(3)
torch.add(x, y, out=z1)
print(z1)

# (2) Method form; returns a new tensor.
z2 = x.add(y)
print(z2)

# (3) Operator form.
z = x + y
print(z)

tensor([10., 10., 10.])
tensor([10, 10, 10])
tensor([10, 10, 10])


In [6]:
# Subtraction (elementwise); torch.sub is the function spelling of x - y.
z = torch.sub(x, y)
print(z)

tensor([-8, -6, -4])


In [7]:
# Division: elementwise true division, so integer inputs give a float result.
z = x.true_divide(y)
print(z)

tensor([0.1111, 0.2500, 0.4286])


In [10]:
# In-place operations: methods with a trailing underscore modify the tensor
# they are called on instead of returning a new one.
t = torch.zeros(3)
t.add_(x)
print(t)
t += x # also in-place: updates t itself, whereas `t = t + x` would build a new tensor
print(t)

tensor([1., 2., 3.])
tensor([2., 4., 6.])


In [12]:
# Exponentiation: elementwise square, via the method and then the operator.
for z in (x.pow(2), x ** 2):
    print(z)

tensor([1, 4, 9])
tensor([1, 4, 9])


In [14]:
# Simple comparison: elementwise, each result is a boolean tensor.
for z in (x > 0, x < 0):
    print(z)

tensor([True, True, True])
tensor([False, False, False])


In [15]:
# Matrix multiplication: (2, 5) @ (5, 3) -> (2, 3)
x1 = torch.rand((2, 5))
x2 = torch.rand((5, 3))
x3 = torch.mm(x1, x2)  # function form
# One print call with a newline separator; chaining print(...) calls with
# commas only builds a throwaway tuple of None values.
print(x1, x2, x3, sep="\n")
x3 = x1.mm(x2)  # method form, same result
print(x3)

tensor([[0.0461, 0.8025, 0.4515, 0.8971, 0.2115],
        [0.5237, 0.8451, 0.1771, 0.3182, 0.8047]])
tensor([[0.1598, 0.5945, 0.8470],
        [0.7692, 0.8331, 0.3614],
        [0.2823, 0.0902, 0.2565],
        [0.4686, 0.5653, 0.5304],
        [0.6478, 0.5833, 0.9962]])
tensor([[1.3095, 1.3672, 1.1314],
        [1.4541, 1.6807, 1.7648]])
tensor([[1.3095, 1.3672, 1.1314],
        [1.4541, 1.6807, 1.7648]])


In [18]:
# Matrix power: multiplies the matrix by itself (here, three factors).
# Not to be confused with the matrix exponential.
matrix_exp = torch.arange(4).reshape(2, 2)
print(matrix_exp)
matrix_exp_p = torch.matrix_power(matrix_exp, 3)
print(matrix_exp_p)

tensor([[0, 1],
        [2, 3]])
tensor([[ 6, 11],
        [22, 39]])


In [20]:
# Elementwise (Hadamard) product; torch.mul is the function spelling of x * y.
x = torch.tensor([1, 2, 3])
y = torch.tensor([9, 8, 7])
z = torch.mul(x, y)
print(z)

tensor([ 9, 16, 21])


In [21]:
# Dot product of the two vectors; the result is a 0-dim tensor.
z = x.dot(y)
print(z)

tensor(46)


In [22]:
# Batch matrix multiplication: for each of the `batch` items, multiply an
# (n, m) matrix by an (m, p) matrix to get an (n, p) matrix.
batch = 32
n = 10
m = 20
p = 30

tensor1 = torch.rand((batch, n, m))
tensor2 = torch.rand((batch, m, p))
out_bmm = torch.bmm(tensor1, tensor2)  # -> (batch, n, p)
# Single print call: a tuple of print(...) calls as the cell's last
# expression makes the notebook echo (None, None, None) as the cell result.
print(tensor1.shape, tensor2.shape, out_bmm.shape, sep="\n")

torch.Size([32, 10, 20])
torch.Size([32, 20, 30])
torch.Size([32, 10, 30])


(None, None, None)

In [23]:
# Broadcasting: the (1, 5) tensor is stretched along dim 0 to match (5, 5).
x1 = torch.rand((5, 5))
x2 = torch.rand((1, 5))
# One print call instead of a comma-chained tuple of print(...) calls.
print(x1.shape, x2.shape, sep="\n")

z = x1 - x2
print(z.shape)

z = x1 ** x2  # elementwise power; the exponent row is applied to every row of x1
print(z.shape)

torch.Size([5, 5])
torch.Size([1, 5])
torch.Size([5, 5])
torch.Size([5, 5])


In [31]:
# Broadcasting with a 1-D exponent: [1, 0] is applied to every row, so
# column 0 is unchanged (power 1) and column 1 becomes 1 (power 0).
x1 = torch.tensor([[1, 2], [3, 4]])
x2 = torch.tensor([1, 0])
# One print call instead of a comma-chained tuple of print(...) calls.
print(x1, x2, sep="\n")
print(x1 ** x2)

tensor([[1, 2],
        [3, 4]])
tensor([1, 0])
tensor([[1, 1],
        [3, 1]])


In [62]:
# Other useful tensor operations (function spellings in trailing comments)
x = torch.tensor([1, 2, 3])
y = torch.tensor([9, 8, 7])

sum_x = x.sum(dim=0)  # torch.sum(x, dim=0)
print(sum_x)

# max/min along a dim return a (values, indices) pair.
values, indices = x.max(dim=0)  # torch.max(x, dim=0)
print(values, indices)
values, indices = x.min(dim=0)  # torch.min(x, dim=0)
print(values, indices)

abs_x = x.abs()  # torch.abs(x)
print(abs_x)

# argmax/argmin return only the index of the extreme value.
z = x.argmax(dim=0)
print(z)
z = x.argmin(dim=0)
print(z)

# mean() requires a floating-point or complex input, so the integer
# tensor is cast to float first.
mean_x = x.float().mean(dim=0)
print(mean_x)

z = x.eq(y)  # elementwise equality
print(z)

# sort returns (sorted values, indices into the original tensor).
z_sort = y.sort(dim=0, descending=False)
print(*z_sort)

# clamp: values below min are set to min, values above max are set to max.
z = x.clamp(min=0, max=1)
print(z)

tensor(6)
tensor(3) tensor(2)
tensor(1) tensor(0)
tensor([1, 2, 3])
tensor(2)
tensor(0)
tensor(2.)
tensor([False, False, False])
tensor([7, 8, 9]) tensor([2, 1, 0])
tensor([1, 1, 1])


In [60]:
# Boolean reductions over a bool tensor
x = torch.tensor([1, 0, 1, 1, 1], dtype=torch.bool)
print(x)

z = x.any()  # True if at least one element is True
print(z)

z = x.all()  # True only if every element is True
print(z)

tensor([ True, False,  True,  True,  True])
tensor(True)
tensor(False)
