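### MLX Parameter Initialization

Parameter initialization is an essential part of training a neural network. This document summarizes the initialization methods commonly used in MLX and walks through them with examples, including user-defined initialization functions.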

In [1]:
import mlx.core as mx
import mlx.nn as nn
import mlx
from mlx.utils import tree_flatten

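#### Overview of Parameter Initialization Methods

A suitable weight initialization can speed up convergence and improve model performance. MLX provides several common initializers in nn.init, such as Xavier (Glorot) initialization.
In most cases, parameters in MLX can be initialized in one of two ways.

##### 1. Iterating over the network directly
For a simple network with no nested sub-networks, we can index or iterate over the layers that need initializing and apply a built-in initialization function to each of them.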

In [2]:
# Define the network
net = nn.Sequential(nn.Linear(4, 8),
                    nn.ReLU(),
                    nn.Linear(8, 1))

print(net)

# Initialize parameters by indexing into the layers
uniform_fn = nn.init.glorot_uniform()
const_fn = nn.init.constant(42.0)
net.layers[0].weight = uniform_fn(net.layers[0].weight)
net.layers[2].weight = const_fn(net.layers[2].weight)
print(net.layers[0].weight)
print(net.layers[2].weight)

Sequential(
  (layers.0): Linear(input_dims=4, output_dims=8, bias=True)
  (layers.1): ReLU()
  (layers.2): Linear(input_dims=8, output_dims=1, bias=True)
)
array([[0.422567, 0.130482, 0.59768, -0.0286205],
       [-0.634305, -0.440786, -0.23919, -0.407694],
       [0.571524, 0.254098, -0.443719, -0.519819],
       ...,
       [-0.0761608, -0.38404, -0.225917, 0.321595],
       [-0.117003, -0.095284, -0.260199, -0.0796262],
       [-0.427804, -0.493432, -0.617128, -0.46343]], dtype=float32)
array([[42, 42, 42, ..., 42, 42, 42]], dtype=float32)


#### Overview of Initialization Functions

The cells below apply the built-in initializers from nn.init (constant, normal, and Glorot uniform) to the network defined above.

##### Constant initialization

In [3]:
weight_fn = nn.init.constant(1.0)
bias_fn = nn.init.constant(0.0)
for layer in net.layers:
    if isinstance(layer, nn.Linear):
        layer.weight = weight_fn(layer.weight)
        layer.bias = bias_fn(layer.bias)

net.layers[0].weight[0], net.layers[0].bias[0]

(array([1, 1, 1, 1], dtype=float32), array(0, dtype=float32))

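##### Uniform/Normal initialization

In [4]:
# Define the network
net = nn.Sequential(nn.Linear(4, 8),
                    nn.ReLU(),
                    nn.Linear(8, 1))

# nn.init.uniform(low, high) can be applied in the same way as nn.init.normal.
# Only the weights are re-initialized here; the biases keep their default values.
weight_fn = nn.init.normal(mean=0.0, std=0.01)
for layer in net.layers:
    if isinstance(layer, nn.Linear):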
        layer.weight = weight_fn(layer.weight)

net.layers[0].weight[0], net.layers[0].bias[0]

(array([-0.0146943, 0.00767959, -0.000236169, 0.000649628], dtype=float32),
 array(0.440777, dtype=float32))

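##### Xavier initialization
The Xavier (Glorot) initializers only accept arrays with more than one dimension (ndim > 1), because they derive fan-in and fan-out from the array shape. They therefore cannot be used to initialize a bias vector.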

In [5]:
uniform_fn = nn.init.glorot_uniform()
const_fn = nn.init.constant(42.0)
net.layers[0].weight = uniform_fn(net.layers[0].weight)
net.layers[2].weight = const_fn(net.layers[2].weight)
print(net.layers[0].weight[0])
print(net.layers[2].weight)

array([0.0795656, 0.355053, -0.316831, 0.596846], dtype=float32)
array([[42, 42, 42, ..., 42, 42, 42]], dtype=float32)
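
To make the ndim > 1 requirement concrete, the following sketch recomputes the sampling bound of Glorot uniform initialization in plain Python. It assumes the standard formula `limit = gain * sqrt(6 / (fan_in + fan_out))` and assumes that fan-in and fan-out are read from the last and first axes of the weight shape, with any middle axes counted as the receptive field. `glorot_uniform_limit` is a hypothetical helper written for this illustration, not an MLX function.

```python
import math


def glorot_uniform_limit(shape, gain=1.0):
    """Bound of the range [-limit, limit] used by Glorot (Xavier) uniform init.

    Assumption: the shape is laid out as (out, ..., in), as for nn.Linear weights.
    """
    if len(shape) < 2:
        # A 1-D bias has no separate fan-in and fan-out, so the formula is undefined.
        raise ValueError("Glorot initialization needs at least 2 dimensions")
    receptive_field = math.prod(shape[1:-1])  # equals 1 for a plain 2-D weight
    fan_in = shape[-1] * receptive_field
    fan_out = shape[0] * receptive_field
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


# nn.Linear(4, 8) stores its weight with shape (8, 4).
limit = glorot_uniform_limit((8, 4))
print(round(limit, 4))  # 0.7071
```

Every entry of the layers[0] weight printed above lies inside [-0.7071, 0.7071], which is consistent with this bound.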


#### 2. Traversing all layers with modules()
For a complex network with nested sub-networks, we can use the built-in module.update method to initialize the parameters.

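#### Related functions
Initialization code mostly relies on module.update and module.update_modules.

module.update:
- Takes a (possibly nested) dict of parameters as input and writes every parameter it contains onto the current module.
- In most cases it is used together with mlx.utils.tree_map.
- mlx.utils.tree_map takes two arguments: a map function and the tree to apply it to (here, a nested parameter dict). It returns a new tree with the same structure in which every leaf has been replaced by the result of the map function.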
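
The mapping behavior described above can be illustrated without MLX. The function below is a simplified pure-Python sketch, not the library implementation: `tree_map_sketch` is a hypothetical name, and plain floats stand in for parameter arrays.

```python
def tree_map_sketch(fn, tree):
    """Apply fn to every leaf of a nested dict/list/tuple, keeping the structure."""
    if isinstance(tree, dict):
        return {key: tree_map_sketch(fn, child) for key, child in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map_sketch(fn, child) for child in tree)
    return fn(tree)  # anything else is treated as a leaf


# Toy stand-in for module.parameters(): the empty dict plays the role of a
# parameter-free layer such as ReLU.
params = {"layers": [{"weight": 0.3, "bias": 0.5}, {}, {"weight": -0.2, "bias": 0.1}]}

zeroed = tree_map_sketch(lambda leaf: 0.0, params)
print(zeroed)
# {'layers': [{'weight': 0.0, 'bias': 0.0}, {}, {'weight': 0.0, 'bias': 0.0}]}
```

The result is a new tree; the input dict is left unchanged, which is why the mapped tree still has to be passed to module.update before the module sees the new values.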

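module.update_modules:
- Very similar to module.update.
- The difference is what gets replaced: module.update replaces parameter arrays of the current module, whereas module.update_modules replaces its child modules with the modules supplied in the dict. For example, passing a new nn.Linear under the key of an existing layer swaps out that layer as a whole.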

tree_flatten():
- Takes a Python tree (for example a network) or a dictionary (for example model.parameters()) as input.
- Returns the flattened structure as a list of (key, value) tuples.
- The keys use dot notation to describe trees of arbitrary depth and complexity.
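
The dot-notation keys can be reproduced with a few lines of plain Python. This is a simplified sketch under the assumption that only dicts, lists, and tuples count as containers; `tree_flatten_sketch` is a hypothetical name, and floats stand in for parameter arrays.

```python
def tree_flatten_sketch(tree, prefix=""):
    """Flatten a nested dict/list/tuple into a list of (dotted_key, leaf) tuples."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return [(prefix, tree)]  # reached a leaf
    flat = []
    for key, child in items:
        child_prefix = f"{prefix}.{key}" if prefix else str(key)
        flat.extend(tree_flatten_sketch(child, child_prefix))
    return flat


# Toy stand-in for the parameters of a Sequential nested inside a Sequential.
nested = {"layers": [{"layers": [{"weight": 1.0, "bias": 0.0}, {}]},
                     {"weight": 2.0, "bias": 0.5}]}

for name, value in tree_flatten_sketch(nested):
    print(name, value)
# layers.0.layers.0.weight 1.0
# layers.0.layers.0.bias 0.0
# layers.1.weight 2.0
# layers.1.bias 0.5
```

List indices become key segments just like dict keys, and the empty dict of a parameter-free layer contributes no entries.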

In [6]:
def block1():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
                         nn.Linear(8, 4), nn.ReLU())

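def block3():
    # Note: these sizes do not chain (3 outputs feed a layer expecting 1 input),
    # so this block only illustrates nesting and cannot run a forward pass.
    return nn.Sequential(nn.Linear(2, 3), nn.ReLU(),
                         nn.Linear(1, 4), nn.ReLU())

def block2():
    net = []
    for i in range(2):
        # Nest a block1 sub-network here
        net.append(block1())
    for i in range(2):
        # Nest a block3 sub-network here
        net.append(block3())
    return net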

net = nn.Sequential(*block2(), nn.Linear(4, 1))


In [7]:
print(net)
# print shows the nested structure, but not the parameter values of each layer

Sequential(
  (layers.0): Sequential(
    (layers.0): Linear(input_dims=4, output_dims=8, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=8, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.1): Sequential(
    (layers.0): Linear(input_dims=4, output_dims=8, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=8, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.2): Sequential(
    (layers.0): Linear(input_dims=2, output_dims=3, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=1, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.3): Sequential(
    (layers.0): Linear(input_dims=2, output_dims=3, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=1, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.4): Linear(input_dims=4, output_dims=1, bias=True)
)


### Layers can be accessed by walking the hierarchy

In [8]:
net.layers[1].layers[2].weight

net.layers[1].layers[2].weight = const_fn(net.layers[1].layers[2].weight)

net.layers[1].layers[2].weight

array([[42, 42, 42, ..., 42, 42, 42],
       [42, 42, 42, ..., 42, 42, 42],
       [42, 42, 42, ..., 42, 42, 42],
       [42, 42, 42, ..., 42, 42, 42]], dtype=float32)

In [9]:
print(net.modules())
# modules() traverses all layers depth-first (DFS), including nested ones

[Sequential(
  (layers.0): Sequential(
    (layers.0): Linear(input_dims=4, output_dims=8, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=8, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.1): Sequential(
    (layers.0): Linear(input_dims=4, output_dims=8, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=8, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.2): Sequential(
    (layers.0): Linear(input_dims=2, output_dims=3, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=1, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.3): Sequential(
    (layers.0): Linear(input_dims=2, output_dims=3, bias=True)
    (layers.1): ReLU()
    (layers.2): Linear(input_dims=1, output_dims=4, bias=True)
    (layers.3): ReLU()
  )
  (layers.4): Linear(input_dims=4, output_dims=1, bias=True)
), Linear(input_dims=4, output_dims=1, bias=True), Sequential(
  (layers.0): Linear(input_dims=2, output_dims=3, bias=

In [10]:
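# Iterating over net.layers only reaches the outermost level, so layers nested inside a Sequential are never initialized
weight_fn = nn.init.constant(1.0)
bias_fn = nn.init.constant(0.0)
for layer in net.layers:
    if isinstance(layer, nn.Linear):
        layer.weight = weight_fn(layer.weight)
        layer.bias = bias_fn(layer.bias)

# net.layers[0] is a Sequential, not a Linear, so this lookup raises AttributeError
net.layers[0].weight[0], net.layers[0].bias[0]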

AttributeError: 'Sequential' object has no attribute 'weight'

In [11]:
# Define an initialization function for a single parameter array
def init_weights(array):
    if array.ndim > 1:
        # Weights (2-D or higher) are set to the constant 10
        weight_fn = nn.init.constant(10)
        array = weight_fn(array)
    else:
        # Biases (1-D) are set to the constant 1
        bias_fn = nn.init.constant(1)
        array = bias_fn(array)
    return array

# Walk every level of the network depth-first with modules() and update the parameters of each matching layer with module.update()
def apply_initialization(models, init_fn):
    for module in models.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.update(mlx.utils.tree_map(init_fn, module.parameters()))

print("before_init:", tree_flatten(net.parameters()) )
apply_initialization(net, init_weights)
print("after init:", tree_flatten(net.parameters()))

before_init: [('layers.0.layers.0.weight', array([[0.211076, -0.0322273, -0.171056, 0.462113],
       [-0.462399, 0.399972, -0.42979, -0.42679],
       [0.337287, -0.0173589, 0.442565, -0.209292],
       ...,
       [0.384921, 0.381924, 0.303233, 0.486555],
       [0.0861257, -0.25014, 0.278124, -0.434768],
       [-0.243735, 0.183608, -0.0420352, -0.391986]], dtype=float32)), ('layers.0.layers.0.bias', array([-0.252933, -0.285083, 0.299472, ..., -0.246823, -0.460196, 0.431308], dtype=float32)), ('layers.0.layers.2.weight', array([[0.286649, -0.279438, -0.0777029, ..., 0.104014, -0.26662, 0.342589],
       [-0.262362, 0.301246, -0.0589984, ..., -0.0613111, -0.123374, -0.263305],
       [-0.191364, 0.0707401, 0.317304, ..., -0.262486, -0.270887, -0.315851],
       [0.215008, 0.346379, 0.146354, ..., -0.0982642, -0.0868911, -0.188508]], dtype=float32)), ('layers.0.layers.2.bias', array([-0.140886, 0.053999, 0.313642, -0.300036], dtype=float32)), ('layers.1.layers.0.weight', array([[-0.26

In [15]:
tree_flatten(net.parameters())

[('layers.0.layers.0.weight',
  array([[10, 10, 10, 10],
         [10, 10, 10, 10],
         [10, 10, 10, 10],
         ...,
         [10, 10, 10, 10],
         [10, 10, 10, 10],
         [10, 10, 10, 10]], dtype=float32)),
 ('layers.0.layers.0.bias', array([1, 1, 1, ..., 1, 1, 1], dtype=float32)),
 ('layers.0.layers.2.weight',
  array([[10, 10, 10, ..., 10, 10, 10],
         [10, 10, 10, ..., 10, 10, 10],
         [10, 10, 10, ..., 10, 10, 10],
         [10, 10, 10, ..., 10, 10, 10]], dtype=float32)),
 ('layers.0.layers.2.bias', array([1, 1, 1, 1], dtype=float32)),
 ('layers.1.layers.0.weight',
  array([[10, 10, 10, 10],
         [10, 10, 10, 10],
         [10, 10, 10, 10],
         ...,
         [10, 10, 10, 10],
         [10, 10, 10, 10],
         [10, 10, 10, 10]], dtype=float32)),
 ('layers.1.layers.0.bias', array([1, 1, 1, ..., 1, 1, 1], dtype=float32)),
 ('layers.1.layers.2.weight',
  array([[10, 10, 10, ..., 10, 10, 10],
         [10, 10, 10, ..., 10, 10, 10],
         [10, 10,

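### Conclusion

Parameter initialization plays an important role in neural network training, and a suitable method can noticeably improve model performance. MLX provides several initialization tools in nn.init. Choose among them according to the task at hand, and write a custom initialization function when finer control over individual layers is needed.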
