## **Demo 0403: The Chain Rule for Matrix Derivatives**

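### **Problem statement:**
Machine learning algorithms frequently need derivatives with respect to matrices. For example, suppose we are given a matrix $X$ of shape $(N,D)$, a matrix $W$ of shape $(D,M)$, and the matrix $Y=XW$ of shape $(N,M)$.  
Let the scalar $L=\sum_{i=1}^{N}\sum_{j=1}^{M}f(y_{ij})$, where $y_{ij}$ is the $(i,j)$-th element of $Y$ and $f(y_{ij})$ is some function of $y_{ij}$. In other words, $L$ is the sum of the values $f(y_{ij})$ over all elements of $Y$.  
Goal: compute $\dfrac{\partial L}{\partial X}$ and $\dfrac{\partial L}{\partial W}$.  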
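Notes:
* Since $X$ has shape $(N,D)$ and $L$ is a scalar, $\dfrac{\partial L}{\partial X}$ can be read as the derivative of $L$ with respect to each element of $X$ separately, collected into an $(N,D)$ matrix
* Since $W$ has shape $(D,M)$ and $L$ is a scalar, $\dfrac{\partial L}{\partial W}$ is a $(D,M)$ matrix
* However, $L$ is not computed directly from $X$ and $W$ but through the intermediate variable $Y$. A direct application of the chain rule therefore needs both $\dfrac{\partial L}{\partial Y}$ and $\dfrac{\partial Y}{\partial X}$ in order to obtain $\dfrac{\partial L}{\partial X}$. The same holds for $\dfrac{\partial L}{\partial W}$
* $Y$ has shape $(N,M)$ and $X$ has shape $(N,D)$. To compute $\dfrac{\partial Y}{\partial X}$, every element of $Y$ must be differentiated with respect to every element of $X$, which yields $N \times M \times N \times D$ derivative values. These values make up the **Jacobian matrix** (of shape $(NM, ND)$ once $Y$ and $X$ are flattened).
* In machine learning, $N$, $D$ and $M$ can be large. For example, with $N=64,D=4096,M=4096$ the Jacobian has $N \times M \times N \times D=64 \times 1024 \times 1024 \times 1024$ elements. Stored as 32-bit floats (4 bytes each), that takes $ 4 \times 64 \times 1024 \times 1024 \times 1024 $ bytes, i.e. 256 GiB of memory, which is not practical
* We therefore need a way to obtain the partial derivatives without ever computing the Jacobian. Because $Y$ is a linear function of $X$ and of $W$, there is room for simplification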
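
The memory estimate in the list above can be reproduced with a few lines of arithmetic. This is a minimal sketch in plain Python; the variable names are illustrative and not part of the original text. For comparison it also computes the size of $\dfrac{\partial L}{\partial X}$ itself, which has the same shape as $X$.

```python
# Size of the full Jacobian dY/dX for the example sizes quoted above.
N, D, M = 64, 4096, 4096

# One entry for every pair (element of Y, element of X).
jacobian_elements = (N * M) * (N * D)
bytes_per_float32 = 4
jacobian_gib = jacobian_elements * bytes_per_float32 / 1024**3

# The gradient dL/dX only has as many entries as X does.
gradient_mib = N * D * bytes_per_float32 / 1024**2

print(jacobian_elements)  # 68719476736
print(jacobian_gib)       # 256.0
print(gradient_mib)       # 1.0
```

The result we actually need occupies about 1 MiB, while the intermediate Jacobian would occupy 256 GiB.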

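### **Experiment 1: Deriving the matrix chain rule formulas**
Given a matrix $X$ of shape $(N,D)$, a matrix $W$ of shape $(D,M)$ and the matrix $Y=XW$ of shape $(N,M)$, and given that $L$ is a function of $Y$, we have: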
* $ \dfrac{\partial L}{\partial X}=\dfrac{\partial L}{\partial Y}W^T $
* $ \dfrac{\partial L}{\partial W}=X^T\dfrac{\partial L}{\partial Y} $
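
Before deriving these two formulas, it can be reassuring to check them numerically. The sketch below uses only the Python standard library and compares both formulas with central finite differences. The choice $f(y)=y^2$, the random test data and all helper names (`matmul`, `transpose`, `numeric_grad`) are assumptions made for this illustration.

```python
import random

def matmul(A, B):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def f(y):                 # assumed example: f(y) = y^2
    return y * y

def df(y):                # its derivative f'(y) = 2y
    return 2.0 * y

def loss(X, W):
    """L = sum of f(y_ij) over all entries of Y = XW."""
    return sum(f(y) for row in matmul(X, W) for y in row)

def numeric_grad(fn, A, eps=1e-6):
    """Central-difference estimate of d fn() / dA, one entry of A at a time."""
    G = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for j in range(len(A[0])):
            old = A[i][j]
            A[i][j] = old + eps
            up = fn()
            A[i][j] = old - eps
            down = fn()
            A[i][j] = old          # restore the entry
            G[i][j] = (up - down) / (2 * eps)
    return G

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

random.seed(0)
N, D, M = 2, 4, 3
X = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(N)]
W = [[random.uniform(-1, 1) for _ in range(M)] for _ in range(D)]

dY = [[df(y) for y in row] for row in matmul(X, W)]   # dL/dY, shape (N, M)
dX = matmul(dY, transpose(W))                         # (N, M) x (M, D) -> (N, D)
dW = matmul(transpose(X), dY)                         # (D, N) x (N, M) -> (D, M)

err_X = max_abs_diff(dX, numeric_grad(lambda: loss(X, W), X))
err_W = max_abs_diff(dW, numeric_grad(lambda: loss(X, W), W))
print(err_X, err_W)       # both should be tiny (floating-point noise)
```

Note that the shapes alone already pin down the order of the factors: $\dfrac{\partial L}{\partial Y}W^T$ is the only way to multiply an $(N,M)$ matrix with $W$ and obtain an $(N,D)$ result.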

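A variant of the result above: given a matrix $W$ of shape $(N,D)$, a matrix $X$ of shape $(D,M)$ and the matrix $Y=WX$ of shape $(N,M)$, and given that $L$ is a function of $Y$, we have: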
* $ \dfrac{\partial L}{\partial W}=\dfrac{\partial L}{\partial Y}X^T $
* $ \dfrac{\partial L}{\partial X}=W^T\dfrac{\partial L}{\partial Y} $
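
Both pairs of formulas are instances of a single rule, which may be easier to remember. As a sketch, write the product with generic factor names $A$ (left factor) and $B$ (right factor); these two names are introduced here purely for illustration:  
$ \begin{aligned}
Y = AB \quad \Longrightarrow \quad
\dfrac{\partial L}{\partial A} = \dfrac{\partial L}{\partial Y}B^T, \qquad
\dfrac{\partial L}{\partial B} = A^T\dfrac{\partial L}{\partial Y}
\end{aligned} $  
Setting $A=X, B=W$ gives the first pair of formulas, and setting $A=W, B=X$ gives the variant. In both cases the gradient has the same shape as the matrix it is taken with respect to.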

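The steps below walk through a derivation of these formulas. To keep the presentation manageable, only the specific case $N=2, D=4, M=3$ is considered: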

> **Step 1: Set up the example**  
$ X=\left(\begin{matrix}
x_{11} & x_{12} & x_{13} & x_{14} \\
x_{21} & x_{22} & x_{23} & x_{24}
\end{matrix}\right) $  
$ W=\left(\begin{matrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33} \\
w_{41} & w_{42} & w_{43}
\end{matrix}\right) $  
$ \begin{aligned}
Y & = XW \\
& = \left(\begin{matrix}
x_{11}w_{11}+x_{12}w_{21}+x_{13}w_{31}+x_{14}w_{41} &
x_{11}w_{12}+x_{12}w_{22}+x_{13}w_{32}+x_{14}w_{42} &
x_{11}w_{13}+x_{12}w_{23}+x_{13}w_{33}+x_{14}w_{43} \\
x_{21}w_{11}+x_{22}w_{21}+x_{23}w_{31}+x_{24}w_{41} &
x_{21}w_{12}+x_{22}w_{22}+x_{23}w_{32}+x_{24}w_{42} &
x_{21}w_{13}+x_{22}w_{23}+x_{23}w_{33}+x_{24}w_{43}
\end{matrix}\right) \\
& = \left(\begin{matrix}
y_{11} & y_{12} & y_{13} \\
y_{21} & y_{22} & y_{23}
\end{matrix}\right)
\end{aligned} $  
$ L=f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23})=\sum_{i,j}f(y_{ij}) $  
Here $f(y_{ij})$ is a function of $y_{ij}$, and these function values are summed to form the objective function $L$.  
The task is to compute $\dfrac{\partial L}{\partial W}$ and $\dfrac{\partial L}{\partial X}$.
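
To make the symbols concrete, here is a small numeric instance of this setup. It is only a sketch: the entries of `X` and `W` are arbitrary values chosen for illustration, and $f(y)=y^2$ is an assumed example, since the text leaves $f$ unspecified.

```python
# Concrete instance with N=2, D=4, M=3 (all values are arbitrary illustrations).
X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
W = [[1, 0, 2],
     [0, 1, 0],
     [1, 1, 1],
     [2, 0, 1]]

N, D, M = 2, 4, 3

# y_ij = sum over k of x_ik * w_kj
Y = [[sum(X[i][k] * W[k][j] for k in range(D)) for j in range(M)]
     for i in range(N)]

def f(y):          # assumed example: f(y) = y^2
    return y * y

# L is the sum of f over all six entries of Y.
L = sum(f(y) for row in Y for y in row)

print(Y)  # [[12, 5, 9], [28, 13, 25]]
print(L)  # 1828
```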

>**Step 2: Examine $\dfrac{\partial L}{\partial Y}$**  
Since $L=f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23})$,  
it follows that $\dfrac{\partial L}{\partial f(y_{ij})}=1 $.  
$L$ is a scalar and $Y$ is a $(2,3)$ matrix, so the derivative is also a $(2,3)$ matrix, written as:  
$\begin{aligned}
\dfrac{\partial L}{\partial Y} & =
\left(\begin{matrix}
\dfrac{\partial L}{\partial y_{11}} &
\dfrac{\partial L}{\partial y_{12}} &
\dfrac{\partial L}{\partial y_{13}} \\
\dfrac{\partial L}{\partial y_{21}} &
\dfrac{\partial L}{\partial y_{22}} &
\dfrac{\partial L}{\partial y_{23}}
\end{matrix}\right) \\ \\
& =\left(\begin{matrix}
\dfrac{\partial L}{\partial f(y_{11})} \cdot \dfrac{\partial f(y_{11})}{\partial y_{11}} &
\dfrac{\partial L}{\partial f(y_{12})} \cdot \dfrac{\partial f(y_{12})}{\partial y_{12}} &
\dfrac{\partial L}{\partial f(y_{13})} \cdot \dfrac{\partial f(y_{13})}{\partial y_{13}} \\
\dfrac{\partial L}{\partial f(y_{21})} \cdot \dfrac{\partial f(y_{21})}{\partial y_{21}} &
\dfrac{\partial L}{\partial f(y_{22})} \cdot \dfrac{\partial f(y_{22})}{\partial y_{22}} &
\dfrac{\partial L}{\partial f(y_{23})} \cdot \dfrac{\partial f(y_{23})}{\partial y_{23}}
\end{matrix}\right) \\ \\
& =\left(\begin{matrix}
\dfrac{\partial f(y_{11})}{\partial y_{11}} &
\dfrac{\partial f(y_{12})}{\partial y_{12}} &
\dfrac{\partial f(y_{13})}{\partial y_{13}} \\
\dfrac{\partial f(y_{21})}{\partial y_{21}} &
\dfrac{\partial f(y_{22})}{\partial y_{22}} &
\dfrac{\partial f(y_{23})}{\partial y_{23}}
\end{matrix}\right)
\end{aligned}$
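
In other words, $\dfrac{\partial L}{\partial Y}$ is just $f'$ applied to every element of $Y$. The sketch below illustrates this for the assumed example $f(y)=y^2$, where the result is $2Y$, and compares it with a finite-difference estimate. The matrix values and helper names are illustrative only.

```python
# Illustrative (2, 3) matrix Y; f(y) = y^2 is an assumed example.
Y = [[12.0, 5.0, 9.0],
     [28.0, 13.0, 25.0]]

def f(y):
    return y * y

def df(y):
    return 2.0 * y

def total_loss(Y):
    return sum(f(y) for row in Y for y in row)

# Analytic result of Step 2: apply f' element by element.
dL_dY = [[df(y) for y in row] for row in Y]

# Numerical check: perturb one y_ij at a time (central differences).
eps = 1e-4
dL_dY_num = []
for i in range(len(Y)):
    num_row = []
    for j in range(len(Y[0])):
        Y_plus = [row[:] for row in Y]
        Y_minus = [row[:] for row in Y]
        Y_plus[i][j] += eps
        Y_minus[i][j] -= eps
        num_row.append((total_loss(Y_plus) - total_loss(Y_minus)) / (2 * eps))
    dL_dY_num.append(num_row)

print(dL_dY)  # [[24.0, 10.0, 18.0], [56.0, 26.0, 50.0]]
```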

>**Step 3: Examine $\dfrac{\partial L}{\partial x_{11}}$**  
$ \begin{aligned}
L = & f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23}) \\
= & f(x_{11}w_{11}+x_{12}w_{21}+x_{13}w_{31}+x_{14}w_{41}) +
f(x_{11}w_{12}+x_{12}w_{22}+x_{13}w_{32}+x_{14}w_{42}) + \\
& f(x_{11}w_{13}+x_{12}w_{23}+x_{13}w_{33}+x_{14}w_{43}) +
f(x_{21}w_{11}+x_{22}w_{21}+x_{23}w_{31}+x_{24}w_{41}) + \\
& f(x_{21}w_{12}+x_{22}w_{22}+x_{23}w_{32}+x_{24}w_{42}) +
f(x_{21}w_{13}+x_{22}w_{23}+x_{23}w_{33}+x_{24}w_{43})
\end{aligned} $  
$ \begin{aligned}
\dfrac{\partial L}{\partial x_{11}} = &
\dfrac{\partial{(f(y_{11})+f(y_{12})+f(y_{13})+f(y_{21})+f(y_{22})+f(y_{23}))}}{\partial x_{11}} \\
= & \dfrac{\partial{f(y_{11})}}{\partial x_{11}} + \dfrac{\partial{f(y_{12})}}{\partial x_{11}} + \dfrac{\partial{f(y_{13})}}{\partial x_{11}} + \dfrac{\partial{f(y_{21})}}{\partial x_{11}} + \dfrac{\partial{f(y_{22})}}{\partial x_{11}} + \dfrac{\partial{f(y_{23})}}{\partial x_{11}}
\end{aligned}$  
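By the chain rule (note that $y_{11}$ and $x_{11}$ are now both scalar variables, so the ordinary single-variable chain rule applies):  
$ \dfrac{\partial f(y_{11})}{\partial x_{11}}=\dfrac{\partial f(y_{11})}{\partial y_{11}} \cdot \dfrac{\partial y_{11}}{\partial x_{11}}=\dfrac{\partial L}{\partial y_{11}} \cdot \dfrac{\partial y_{11}}{\partial x_{11}} $ (note: since $\dfrac{\partial L}{\partial f(y_{ij})}=1 $, we have $\dfrac{\partial L}{\partial y_{11}}=\dfrac{\partial f(y_{11})}{\partial y_{11}}$)  
Therefore:  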
$\begin{aligned}
\dfrac{\partial L}{\partial x_{11}} = & \dfrac{\partial L}{\partial y_{11}} \cdot \dfrac{\partial y_{11}}{\partial x_{11}} + 
\dfrac{\partial L}{\partial y_{12}} \cdot \dfrac{\partial y_{12}}{\partial x_{11}} +
\dfrac{\partial L}{\partial y_{13}} \cdot \dfrac{\partial y_{13}}{\partial x_{11}} +
\dfrac{\partial L}{\partial y_{21}} \cdot \dfrac{\partial y_{21}}{\partial x_{11}} +
\dfrac{\partial L}{\partial y_{22}} \cdot \dfrac{\partial y_{22}}{\partial x_{11}} +
\dfrac{\partial L}{\partial y_{23}} \cdot \dfrac{\partial y_{23}}{\partial x_{11}} \\
= & \dfrac{\partial L}{\partial y_{11}} \cdot w_{11} +
\dfrac{\partial L}{\partial y_{12}} \cdot w_{12} +
\dfrac{\partial L}{\partial y_{13}} \cdot w_{13}
\end{aligned}$  
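Here the terms in $y_{21}$, $y_{22}$ and $y_{23}$ vanish because those entries do not depend on $x_{11}$, while $\dfrac{\partial y_{1j}}{\partial x_{11}}=w_{1j}$. Similarly:  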
$ \begin{aligned}
\dfrac{\partial L}{\partial x_{12}} = & \dfrac{\partial L}{\partial y_{11}} \cdot \dfrac{\partial y_{11}}{\partial x_{12}} +
\dfrac{\partial L}{\partial y_{12}} \cdot \dfrac{\partial y_{12}}{\partial x_{12}} +
\dfrac{\partial L}{\partial y_{13}} \cdot \dfrac{\partial y_{13}}{\partial x_{12}} +
\dfrac{\partial L}{\partial y_{21}} \cdot \dfrac{\partial y_{21}}{\partial x_{12}} +
\dfrac{\partial L}{\partial y_{22}} \cdot \dfrac{\partial y_{22}}{\partial x_{12}} +
\dfrac{\partial L}{\partial y_{23}} \cdot \dfrac{\partial y_{23}}{\partial x_{12}} \\ = & \dfrac{\partial L}{\partial y_{11}} \cdot w_{21} +
\dfrac{\partial L}{\partial y_{12}} \cdot w_{22} +
\dfrac{\partial L}{\partial y_{13}} \cdot w_{23}
\end{aligned}$  
$\begin{aligned}
\dfrac{\partial L}{\partial x_{21}} = & \dfrac{\partial L}{\partial y_{11}} \cdot \dfrac{\partial y_{11}}{\partial x_{21}} +
\dfrac{\partial L}{\partial y_{12}} \cdot \dfrac{\partial y_{12}}{\partial x_{21}} +
\dfrac{\partial L}{\partial y_{13}} \cdot \dfrac{\partial y_{13}}{\partial x_{21}} +
\dfrac{\partial L}{\partial y_{21}} \cdot \dfrac{\partial y_{21}}{\partial x_{21}} +
\dfrac{\partial L}{\partial y_{22}} \cdot \dfrac{\partial y_{22}}{\partial x_{21}} +
\dfrac{\partial L}{\partial y_{23}} \cdot \dfrac{\partial y_{23}}{\partial x_{21}} \\ = & \dfrac{\partial L}{\partial y_{21}} \cdot w_{11} +
\dfrac{\partial L}{\partial y_{22}} \cdot w_{12} +
\dfrac{\partial L}{\partial y_{23}} \cdot w_{13}
\end{aligned}$  
and so on for the remaining elements of $X$.
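
The three element-wise results can be checked on concrete numbers. The following sketch evaluates the sums from this step directly and compares them with central finite differences of $L$. The matrix entries, the choice $f(y)=y^2$ and the helper names are assumptions made for illustration.

```python
# Illustrative 2x4 X and 4x3 W; f(y) = y^2 is an assumed example.
X = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0],
     [1.0, 1.0, 1.0],
     [2.0, 0.0, 1.0]]

def forward(X, W):
    """Y = XW, written out as y_ij = sum_k x_ik * w_kj."""
    return [[sum(x * W[k][j] for k, x in enumerate(row)) for j in range(len(W[0]))]
            for row in X]

def loss(X, W):
    return sum(y * y for row in forward(X, W) for y in row)

Y = forward(X, W)
dL_dY = [[2.0 * y for y in row] for row in Y]       # f'(y) = 2y

# Step 3 formulas: only row i of Y depends on x_ik, and dy_ij/dx_ik = w_kj.
dL_dx11 = sum(dL_dY[0][j] * W[0][j] for j in range(3))
dL_dx12 = sum(dL_dY[0][j] * W[1][j] for j in range(3))
dL_dx21 = sum(dL_dY[1][j] * W[0][j] for j in range(3))

def numeric_dL_dx(i, k, eps=1e-4):
    """Central difference of L with respect to the single entry x_ik (0-based)."""
    X_plus = [row[:] for row in X]
    X_minus = [row[:] for row in X]
    X_plus[i][k] += eps
    X_minus[i][k] -= eps
    return (loss(X_plus, W) - loss(X_minus, W)) / (2 * eps)

print(dL_dx11, dL_dx12, dL_dx21)  # 60.0 10.0 156.0
```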

>**Step 4: Express the result in matrix form**  
$\begin{aligned}
\dfrac{\partial L}{\partial X} \\ = & \left(\begin{matrix}
\dfrac{\partial L}{\partial x_{11}}&\dfrac{\partial L}{\partial x_{12}}&
\dfrac{\partial L}{\partial x_{13}}&\dfrac{\partial L}{\partial x_{14}} \\
\dfrac{\partial L}{\partial x_{21}}&\dfrac{\partial L}{\partial x_{22}}&
\dfrac{\partial L}{\partial x_{23}}&\dfrac{\partial L}{\partial x_{24}}
\end{matrix}\right) \\ \\ = & \left(\begin{matrix}
\dfrac{\partial L}{\partial y_{11}} \cdot w_{11} +
\dfrac{\partial L}{\partial y_{12}} \cdot w_{12} +
\dfrac{\partial L}{\partial y_{13}} \cdot w_{13} &
\dfrac{\partial L}{\partial y_{11}} \cdot w_{21} +
\dfrac{\partial L}{\partial y_{12}} \cdot w_{22} +
\dfrac{\partial L}{\partial y_{13}} \cdot w_{23} \to \\
\to \dfrac{\partial L}{\partial y_{11}} \cdot w_{31} +
\dfrac{\partial L}{\partial y_{12}} \cdot w_{32} +
\dfrac{\partial L}{\partial y_{13}} \cdot w_{33} &
\dfrac{\partial L}{\partial y_{11}} \cdot w_{41} +
\dfrac{\partial L}{\partial y_{12}} \cdot w_{42} +
\dfrac{\partial L}{\partial y_{13}} \cdot w_{43} \\
\dfrac{\partial L}{\partial y_{21}} \cdot w_{11} +
\dfrac{\partial L}{\partial y_{22}} \cdot w_{12} +
\dfrac{\partial L}{\partial y_{23}} \cdot w_{13} &
\dfrac{\partial L}{\partial y_{21}} \cdot w_{21} +
\dfrac{\partial L}{\partial y_{22}} \cdot w_{22} +
\dfrac{\partial L}{\partial y_{23}} \cdot w_{23} \to \\
\to \dfrac{\partial L}{\partial y_{21}} \cdot w_{31} +
\dfrac{\partial L}{\partial y_{22}} \cdot w_{32} +
\dfrac{\partial L}{\partial y_{23}} \cdot w_{33} &
\dfrac{\partial L}{\partial y_{21}} \cdot w_{41} +
\dfrac{\partial L}{\partial y_{22}} \cdot w_{42} +
\dfrac{\partial L}{\partial y_{23}} \cdot w_{43}
\end{matrix}\right) \\ \\ = & \left(\begin{matrix}
\dfrac{\partial L}{\partial y_{11}} &
\dfrac{\partial L}{\partial y_{12}} &
\dfrac{\partial L}{\partial y_{13}} \\
\dfrac{\partial L}{\partial y_{21}} &
\dfrac{\partial L}{\partial y_{22}} &
\dfrac{\partial L}{\partial y_{23}}
\end{matrix}\right) \cdot
\left(\begin{matrix}
w_{11} & w_{21} & w_{31} & w_{41} \\
w_{12} & w_{22} & w_{32} & w_{42} \\
w_{13} & w_{23} & w_{33} & w_{43}
\end{matrix}\right) \\ \\ = & \dfrac{\partial L}{\partial Y}W^T
\end{aligned}$
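
The regrouping in this step can be mirrored in code: filling $\dfrac{\partial L}{\partial X}$ entry by entry with the sums from Step 3 gives the same matrix as the single product $\dfrac{\partial L}{\partial Y}W^T$. A minimal sketch, with arbitrary illustrative values for `dL_dY` and `W`:

```python
# Arbitrary illustrative values: dL/dY has shape (2, 3), W has shape (4, 3).
dL_dY = [[24.0, 10.0, 18.0],
         [56.0, 26.0, 50.0]]
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0],
     [1.0, 1.0, 1.0],
     [2.0, 0.0, 1.0]]
N, M, D = 2, 3, 4

# Entry by entry: dL/dx_ik = sum_j dL/dy_ij * w_kj
dX_entries = [[sum(dL_dY[i][j] * W[k][j] for j in range(M)) for k in range(D)]
              for i in range(N)]

# Matrix form: build W^T explicitly, then take the ordinary product dL/dY . W^T
W_T = [list(col) for col in zip(*W)]                 # shape (M, D)
dX_matrix = [[sum(dL_dY[i][j] * W_T[j][k] for j in range(M)) for k in range(D)]
             for i in range(N)]

print(dX_matrix)  # [[60.0, 10.0, 52.0, 66.0], [156.0, 26.0, 132.0, 162.0]]
```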

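>**Step 5: Conclusion**  
* In the same way one obtains: $ \dfrac{\partial L}{\partial W}=X^T\dfrac{\partial L}{\partial Y} $
* With these results, the partial derivatives can be computed conveniently without ever forming the Jacobian matrix. Although the derivation was carried out for $N=2, D=4, M=3$, the same element-by-element argument applies to arbitrary $N$, $D$ and $M$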
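
For completeness, here is a sketch of the same argument in index notation, which covers arbitrary $N$, $D$, $M$ and also yields the formula for $W$. The Kronecker delta $\delta_{ab}$ (equal to $1$ when $a=b$ and $0$ otherwise) is notation introduced here and is not used elsewhere in the text. Starting from $y_{ij}=\sum_{k=1}^{D}x_{ik}w_{kj}$:  
$ \begin{aligned}
\dfrac{\partial y_{ij}}{\partial x_{kl}} = & \delta_{ik}\,w_{lj}
& \Longrightarrow \quad
\dfrac{\partial L}{\partial x_{kl}} = & \sum_{i,j}\dfrac{\partial L}{\partial y_{ij}} \cdot \delta_{ik}\,w_{lj}
= \sum_{j=1}^{M}\dfrac{\partial L}{\partial y_{kj}} \cdot w_{lj}
= \left(\dfrac{\partial L}{\partial Y}W^T\right)_{kl} \\
\dfrac{\partial y_{ij}}{\partial w_{kl}} = & x_{ik}\,\delta_{jl}
& \Longrightarrow \quad
\dfrac{\partial L}{\partial w_{kl}} = & \sum_{i,j}\dfrac{\partial L}{\partial y_{ij}} \cdot x_{ik}\,\delta_{jl}
= \sum_{i=1}^{N}x_{ik} \cdot \dfrac{\partial L}{\partial y_{il}}
= \left(X^T\dfrac{\partial L}{\partial Y}\right)_{kl}
\end{aligned} $  
The first line reproduces Step 3 and Step 4 for general sizes (for example, $k=l=1$ gives the expression for $\dfrac{\partial L}{\partial x_{11}}$), and the second line is the corresponding derivation for $W$.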