# PyTorch Basic APIs

## 1. torch.max() 

> **torch.max(input) -> Tensor**

    Returns the maximum value of all elements in the input tensor.

**Example 1:**

```python
import torch

a = torch.randn(1, 3)
a
```

```
tensor([[ 0.1125,  0.7805, -1.0639]])
```

```python
torch.max(a)
```

```
tensor(0.7805)
```
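
As a small sketch, here is the same call on a fixed tensor. Its values are chosen arbitrarily instead of using `torch.randn`, so the result is reproducible. It shows that `torch.max(input)` returns a 0-dimensional tensor, which `.item()` converts to a plain Python number:

```python
import torch

# A fixed tensor instead of torch.randn, so the output is reproducible
a = torch.tensor([[0.5, 2.0, -1.0]])

m = torch.max(a)   # reduces over all elements
print(m)           # tensor(2.)
print(m.dim())     # 0: the result is a 0-dimensional (scalar) tensor
print(m.item())    # 2.0 as a plain Python float
```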

> **torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)**

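    Returns a namedtuple (values, indices), where values is the maximum value of each row of the input tensor in the given dimension dim, and indices is the index location of each maximum value found (argmax).
    If keepdim is True, the outputs have the same size as input except in dimension dim, where they have size 1; otherwise dim is removed from the output shape.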

**Example 2:**

```python
import torch

a = torch.randn(3, 5)
a
```

```
tensor([[ 0.1885, -0.9210, -0.6735,  0.8115, -1.8839],
        [-0.8532, -0.2641, -0.0475,  0.6702, -1.1544],
        [-0.6395,  0.3680, -2.6685, -0.1931,  0.9921]])
```

```python
# Maximum of each of the 3 rows (reduces over dim 1)
torch.max(a, 1)
```

```
torch.return_types.max(
values=tensor([0.8115, 0.6702, 0.9921]),
indices=tensor([3, 3, 4]))
```

```python
# Maximum of each of the 5 columns (reduces over dim 0)
torch.max(a, 0)
```

```
torch.return_types.max(
values=tensor([ 0.1885,  0.3680, -0.0475,  0.8115,  0.9921]),
indices=tensor([0, 2, 1, 0, 2]))
```

```python
# Maximum of each of the 3 rows; for a 2-D tensor dim=-1 is dim=1, so this equals torch.max(a, 1)
torch.max(a, -1)
```

```
torch.return_types.max(
values=tensor([0.8115, 0.6702, 0.9921]),
indices=tensor([3, 3, 4]))
```

```python
# Maximum of each of the 5 columns; for a 2-D tensor dim=-2 is dim=0, so this equals torch.max(a, 0)
torch.max(a, -2)
```

```
torch.return_types.max(
values=tensor([ 0.1885,  0.3680, -0.0475,  0.8115,  0.9921]),
indices=tensor([0, 2, 1, 0, 2]))
```
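
The examples above print the whole namedtuple. The following minimal sketch uses an arbitrary fixed tensor. It reads the `values` and `indices` fields by name and shows what `keepdim` does to the output shape:

```python
import torch

a = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# The result is a namedtuple, so its fields can be accessed by name
result = torch.max(a, dim=1)
print(result.values)   # tensor([5., 6.])
print(result.indices)  # tensor([1, 2])

# keepdim=False (default): the reduced dimension is removed -> shape (2,)
print(result.values.shape)  # torch.Size([2])

# keepdim=True: the reduced dimension is kept with size 1 -> shape (2, 1)
kept = torch.max(a, dim=1, keepdim=True)
print(kept.values.shape)    # torch.Size([2, 1])

# dim=-1 refers to the last dimension, which is dim=1 for a 2-D tensor
print(torch.equal(torch.max(a, -1).values, result.values))  # True
```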

**A common pattern in neural networks:**

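```python
_, pred = torch.max(a, 1)
# _ receives the maximum values (by convention, `_` marks a value to be ignored)
print(_)
# pred receives the column index of each row's maximum
print(pred)
```

```
tensor([0.8115, 0.6702, 0.9921])
tensor([3, 3, 4])
```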
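
In a classification network, the rows are typically per-sample scores (logits) and the columns are classes, so `pred` is the predicted class of each sample. Below is a minimal sketch of computing accuracy this way. The `logits` and `labels` values are made up for illustration:

```python
import torch

# Hypothetical scores for a batch of 3 samples over 4 classes
logits = torch.tensor([[0.1, 2.5, 0.3, -1.0],
                       [1.2, 0.2, 0.1,  0.0],
                       [0.0, 0.4, 0.3,  3.1]])
# Hypothetical ground-truth class index for each sample
labels = torch.tensor([1, 2, 3])

# Max over the class dimension (dim=1); only the indices are needed
_, pred = torch.max(logits, 1)
print(pred)  # tensor([1, 0, 3])

# Fraction of samples whose predicted class matches the label
correct = (pred == labels).sum().item()
accuracy = correct / labels.size(0)
print(accuracy)  # 2 of 3 predictions are correct
```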


## 2. torch.argmax()

> **torch.argmax(input) -> LongTensor**

    Returns the index of the maximum value of all elements in the input tensor.
    The index is counted as if the tensor were flattened into one dimension (torch.max(input) without dim returns only the value, not the index).

```python
import torch

a = torch.randn(3, 5)
a
```

```
tensor([[ 0.6611,  0.7773,  0.2094,  1.2663,  0.0294],
        [-2.0288,  0.1011,  0.6575,  1.0575,  1.9108],
        [ 1.4509, -1.5427,  0.7185, -0.4936, -0.3349]])
```

```python
torch.max(a)
```

```
tensor(1.9108)
```

```python
# 1.9108 is at flat index 9: row 1, column 4 -> 1 * 5 + 4 = 9
torch.argmax(a)
```

```
tensor(9)
```
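
The index from `torch.argmax(input)` refers to the flattened tensor, so recovering the row and column of a 2-D tensor takes one integer division. Here is a minimal sketch on an arbitrary fixed tensor, using Python's built-in `divmod`:

```python
import torch

a = torch.tensor([[0.3, 1.2, -0.5],
                  [2.7, 0.0,  1.9]])

flat_idx = torch.argmax(a)  # index into the flattened tensor
print(flat_idx)             # tensor(3)

# For a 2-D tensor: row = idx // ncols, col = idx % ncols
row, col = divmod(flat_idx.item(), a.size(1))
print(row, col)             # 1 0

# The element at (row, col) is the overall maximum
print(a[row, col].item() == torch.max(a).item())  # True
```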

> **torch.argmax(input, dim, keepdim=False) -> LongTensor**

    Returns the indices of the maximum values of a tensor across dimension dim.
    This is the second value (indices) returned by torch.max(input, dim).

```python
import torch

a = torch.randn(3, 5)
a
```

```
tensor([[-0.2378,  0.5017,  0.7894, -0.2582,  1.2080],
        [-1.5632,  0.2674,  1.6787, -0.1297, -1.1904],
        [ 0.2622, -1.8024, -0.8724,  1.1156, -0.2386]])
```

```python
torch.max(a, 1)
```

```
torch.return_types.max(
values=tensor([1.2080, 1.6787, 1.1156]),
indices=tensor([4, 2, 3]))
```

```python
# Index of the maximum in each row (reduces over dim 1)
torch.argmax(a, dim=1)
```

```
tensor([4, 2, 3])
```
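
Here is a short sketch on an arbitrary fixed tensor. It checks that `torch.argmax(a, dim=1)` agrees with the indices from `torch.max(a, 1)`. It then uses `keepdim=True` with `torch.gather` to look the maximum values back up (one possible use, not the only one):

```python
import torch

a = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0],
                  [9.0, 7.0, 8.0]])

idx = torch.argmax(a, dim=1)
print(idx)  # tensor([1, 2, 0])

# argmax(a, dim) matches the indices returned by torch.max(a, dim)
print(torch.equal(idx, torch.max(a, dim=1).indices))  # True

# With keepdim=True the result has shape (3, 1), the shape torch.gather expects here
idx_kept = torch.argmax(a, dim=1, keepdim=True)
print(idx_kept.shape)  # torch.Size([3, 1])

# Use the indices to pick out the maximum value of each row
row_max = torch.gather(a, 1, idx_kept).squeeze(1)
print(row_max)  # tensor([5., 6., 9.])
```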