20190309-Understanding broadcasting in numpy #23

Open
iMchxx opened this issue Mar 9, 2019 · 0 comments


iMchxx commented Mar 9, 2019

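While learning numpy I found out that when two arrays of different shapes are combined in an operation, some of them may be broadcast. The explanation in the documentation was a bit hard for me to follow, but after reading other people's articles I have a rough understanding.
The broadcasting rules, as quoted in one of those articles: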

3. Broadcasting rules
All input arrays with ndim smaller than the input array of largest ndim, have 1’s prepended to their shapes.
The size in each dimension of the output shape is the maximum of all the input sizes in that dimension.
An input can be used in the calculation if its size in a particular dimension either matches the output size in that dimension, or has value exactly 1.
If an input has a dimension size of 1 in its shape, the first data entry in that dimension will be used for all calculations along that dimension. In other words, the stepping machinery of the ufunc will simply not step along that dimension (the stride will be 0 for that dimension).

The following code is used to help explain them:

import numpy as np

# x: a column of shape (3, 1)
x = np.arange(3).reshape(3, 1)
x
Out[2]: 
array([[0],
       [1],
       [2]])
x.shape
Out[3]: (3, 1)

# y: a 1-d array of shape (4,)
y = np.ones(4)
y
Out[7]: array([1., 1., 1., 1.])
y.shape
Out[8]: (4,)

# z: the sum, computed by broadcasting x and y
z = x + y
z
Out[13]: 
array([[1., 1., 1., 1.],
       [2., 2., 2., 2.],
       [3., 3., 3., 3.]])
z.shape
Out[14]: (3, 4)
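  1. Every input array is aligned with the array that has the longest shape; shapes that are too short are padded by prepending 1s.

Among the inputs, the longest shape is x.shape = (3, 1), so a 1 is prepended to y.shape, i.e. y is treated as having shape (1, 4).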
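The padding in rule 1 can also be done by hand. The sketch below assumes the same `x` and `y` as in the session above; `y2` is just an illustrative name. It suggests that adding the leading axis explicitly gives the same result as letting numpy do it:

```python
import numpy as np

x = np.arange(3).reshape(3, 1)   # shape (3, 1)
y = np.ones(4)                   # shape (4,)

# Prepend a length-1 axis by hand, which is what rule 1 does implicitly.
y2 = y[np.newaxis, :]            # shape (1, 4)

print(y2.shape)                        # (1, 4)
print(np.array_equal(x + y, x + y2))   # True
```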

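  2. The shape of the output array is the maximum, along each axis, of the input shapes.

The output shape of x + y is the per-axis maximum over all input shapes, i.e. (3, 4).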
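One way to check the output shape without computing the sum is `np.broadcast`. This is a minimal sketch; the by-hand computation assumes rule 1 has already padded `y.shape` to `(1, 4)`:

```python
import numpy as np

x = np.arange(3).reshape(3, 1)   # shape (3, 1)
y = np.ones(4)                   # shape (4,), treated as (1, 4)

# np.broadcast builds a broadcast object without computing x + y.
b = np.broadcast(x, y)
print(b.shape)                   # (3, 4)

# The same shape, computed by hand as the per-axis maximum.
by_hand = tuple(max(m, n) for m, n in zip(x.shape, (1,) + y.shape))
print(by_hand)                   # (3, 4)
```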

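  3. An input can be used in the computation if its length along each axis either equals the output's length along that axis or is 1; otherwise an error is raised.

For example: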

# The following are valid
Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3   # treated as 1 x 1 x 3
Result (3d array): 256 x 256 x 3

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5   # treated as 1 x 7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5

A      (2d array):  5 x 4
B      (1d array):      1   # treated as 1 x 1
Result (2d array):  5 x 4

A      (3d array):  15 x 3 x 5
B      (3d array):  15 x 1 x 5
Result (3d array):  15 x 3 x 5

# The following raise an error
A  (1d array):  3   # shape (3,)
B  (1d array):  4   # shape (4,); the trailing dimensions do not match

A  (2d array):      2 x 1   # treated as 1 x 2 x 1
B  (3d array):  8 x 4 x 3   # the second-to-last dimensions do not match (2 vs 4)
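The shape table above can be checked against numpy itself. In this sketch `result_shape` is a hypothetical helper written for this note, not a numpy function. It adds two zero arrays of the given shapes and reports the result shape, or `None` when numpy raises `ValueError`:

```python
import numpy as np

def result_shape(shape_a, shape_b):
    """Return the shape of a + b, or None if numpy refuses to broadcast."""
    a = np.zeros(shape_a)
    b = np.zeros(shape_b)
    try:
        return (a + b).shape
    except ValueError:
        # numpy raises ValueError when the shapes cannot be broadcast together
        return None

print(result_shape((256, 256, 3), (3,)))       # (256, 256, 3)
print(result_shape((8, 1, 6, 1), (7, 1, 5)))   # (8, 7, 6, 5)
print(result_shape((5, 4), (1,)))              # (5, 4)
print(result_shape((15, 3, 5), (15, 1, 5)))    # (15, 3, 5)
print(result_shape((3,), (4,)))                # None
print(result_shape((2, 1), (8, 4, 3)))         # None
```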

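  4. When an input has length 1 along an axis, its first (and only) entry along that axis is used for every computation along that axis.

During the computation, x and y behave as if they had been expanded as follows (numpy does not actually copy the data; the stride along that axis is 0):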

x      (2d array):  3 x 1
y      (1d array):      4   # treated as 1 x 4
Result (2d array):  3 x 4

# x goes from
array([[0],
       [1],
       [2]])

# to
array([[0, 0, 0, 0],
       [1, 1, 1, 1],
       [2, 2, 2, 2]])

# y goes from
array([1., 1., 1., 1.])

# to
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])

# so the sum is
x + y
Out[21]: 
array([[1., 1., 1., 1.],
       [2., 2., 2., 2.],
       [3., 3., 3., 3.]])
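The expanded forms of x and y can be inspected with `np.broadcast_arrays`, which returns views rather than copies. A minimal sketch (the names `xb` and `yb` are illustrative) that also shows the zero stride mentioned in the rules:

```python
import numpy as np

x = np.arange(3).reshape(3, 1)
y = np.ones(4)

# broadcast_arrays returns views of x and y with the common shape (3, 4).
xb, yb = np.broadcast_arrays(x, y)
print(xb)
print(yb)

# The broadcast axis has stride 0, so the same element is reused
# instead of being copied.
print(xb.strides[1], yb.strides[0])   # 0 0
```

Because these are views onto the original data, they are best treated as read-only.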

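Summary

Broadcasting can be summed up in two rules, stated here in their more general form.

When operating on two arrays, numpy compares their shapes (tuples) axis by axis, starting from the trailing axis. Two axis lengths are compatible only if:

  1. they are equal, or
  2. one of them is 1 (that axis is then stretched so that the shapes match).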

via
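The two rules in the summary can be written out in plain Python. This is only a sketch: `broadcast_shape` is a hypothetical helper for illustration and is not how numpy implements broadcasting internally.

```python
from itertools import zip_longest

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: apply the two rules to two shape tuples."""
    result = []
    # Walk both shapes from the trailing axis; a missing axis counts as 1.
    for m, n in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if m == n or n == 1:
            result.append(m)
        elif m == 1:
            result.append(n)
        else:
            raise ValueError(f"incompatible sizes {m} and {n}")
    return tuple(reversed(result))

print(broadcast_shape((3, 1), (4,)))             # (3, 4)
print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
```

Calling it with `(3,)` and `(4,)` raises `ValueError`, matching the failing example earlier in this note.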
