Based on the code and the question, I'll explain how to calculate the learning rate scaling factor to maintain constant RMS activation deltas across layers.

The mathematical formula for the learning rate scaling factor is:

$\text{lr\_scale} = \frac{1}{\sqrt{d_{in} \cdot d_{out}}}$

where:
- $d_{in}$ is the input dimension of the layer (number of input features)
- $d_{out}$ is the output dimension of the layer (number of output features)
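A minimal Python sketch of the formula, assuming layer shapes are known up front. The helper `lr_scale`, the value of `base_lr`, and the layer names and shapes are hypothetical and only illustrate turning one base learning rate into per-layer rates:

```python
import math

def lr_scale(d_in: int, d_out: int) -> float:
    """Return 1 / sqrt(d_in * d_out) for a weight matrix of shape [d_out, d_in]."""
    return 1.0 / math.sqrt(d_in * d_out)

# Hypothetical base learning rate and layer shapes (d_in, d_out), for illustration only.
base_lr = 1e-2
layers = {"fc1": (64, 256), "fc2": (256, 1024), "head": (1024, 10)}

# Per-layer learning rate = base_lr * lr_scale(d_in, d_out).
per_layer_lr = {
    name: base_lr * lr_scale(d_in, d_out) for name, (d_in, d_out) in layers.items()
}

for name, lr in per_layer_lr.items():
    print(f"{name}: lr = {lr:.3e}")
```

In a framework such as PyTorch, values like these could be supplied as the `lr` of separate optimizer parameter groups.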

This scaling is derived from the following reasoning:

1. For a layer with weight matrix $W$ of shape $[d_{out}, d_{in}]$:
   - The weights are typically initialized with $\text{Var}(W_{ij}) \propto \frac{1}{d_{in}}$
   - The gradient update has $d_{in} \cdot d_{out}$ entries, so its magnitude (Frobenius norm) grows with $\sqrt{d_{in} \cdot d_{out}}$

2. To keep activation deltas approximately constant across layers:
   - We need to counteract this growth in update magnitude
   - Therefore, we scale the learning rate by $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$
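The initialization assumption in step 1 can be checked numerically. This sketch assumes Gaussian weights and unit-variance random inputs (illustrative choices, not taken from the original code) and shows that with $\text{Var}(W_{ij}) = \frac{1}{d_{in}}$ the output RMS stays near the input RMS for every shape, leaving the update magnitude as the part the learning rate scaling is meant to control:

```python
import numpy as np

def rms(a: np.ndarray) -> float:
    """Root mean square over all entries."""
    return float(np.sqrt(np.mean(a ** 2)))

rng = np.random.default_rng(0)
batch_size = 512
output_rms = {}

for d_in, d_out in [(64, 256), (256, 1024), (1024, 64)]:
    # Initialization with Var(W_ij) = 1 / d_in, i.e. std = 1 / sqrt(d_in).
    W = rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_out, d_in))
    x = rng.normal(size=(batch_size, d_in))   # unit-variance inputs
    y = x @ W.T                               # shape [batch_size, d_out]
    output_rms[(d_in, d_out)] = rms(y)
    print(f"d_in={d_in:4d} d_out={d_out:4d}  RMS(x)={rms(x):.3f}  RMS(y)={rms(y):.3f}")
```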

This scaling ensures that:
- Wider layers get smaller learning rates
- Narrower layers get larger learning rates
- The RMS of activation changes remains approximately constant regardless of layer width

For example:
- For a layer with $d_{in}=64, d_{out}=256$:
  $\text{lr\_scale} = \frac{1}{\sqrt{64 \cdot 256}} = \frac{1}{\sqrt{16384}} = \frac{1}{128} \approx 0.0078$

- For a layer with $d_{in}=256, d_{out}=1024$:
  $\text{lr\_scale} = \frac{1}{\sqrt{256 \cdot 1024}} = \frac{1}{\sqrt{262144}} = \frac{1}{512} \approx 0.002$
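The arithmetic in these two examples can be reproduced exactly. This small sketch uses Python's `fractions.Fraction` and `math.isqrt`, relying on the fact that both products happen to be perfect squares:

```python
import math
from fractions import Fraction

results = {}
for d_in, d_out in [(64, 256), (256, 1024)]:
    product = d_in * d_out
    root = math.isqrt(product)          # exact integer square root
    assert root * root == product       # both products are perfect squares
    scale = Fraction(1, root)           # exact value of 1 / sqrt(d_in * d_out)
    results[(d_in, d_out)] = scale
    print(f"1/sqrt({product}) = {scale} ≈ {float(scale):.4f}")
```

It prints `1/sqrt(16384) = 1/128 ≈ 0.0078` and `1/sqrt(262144) = 1/512 ≈ 0.0020`. The second factor is exactly 4 times smaller, matching the second layer's 16 times larger parameter count.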

This scaling helps maintain stable training dynamics across different layer widths in deep neural networks.

This is a great question about the relationship between RMS norm and learning rate scaling! Let me explain why the scaling factor is $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$ rather than $\sqrt{\frac{d_{out}}{d_{in}}}$.

The key is to understand that:

1. The induced RMS norm formula you mentioned, $\sqrt{\frac{d_{in}}{d_{out}}}||W||_2$ (where $||W||_2$ is the spectral norm), measures the current state of the weights

2. However, the learning rate scaling is about controlling the magnitude of weight updates during training

Here's why the scaling differs:

1. For a layer with weight matrix $W$ of shape $[d_{out}, d_{in}]$:
   - The gradient $\frac{\partial L}{\partial W}$ has the same shape as $W$, i.e. $d_{out} \cdot d_{in}$ entries
   - Each entry is a sum over the batch of products of an upstream-gradient component and an input component, so its typical size does not depend on $d_{in}$ or $d_{out}$

2. This means:
   - Summing the squared entries across the $d_{in}$ columns contributes a factor of $\sqrt{d_{in}}$ to the update's Frobenius norm
   - Summing across the $d_{out}$ rows contributes a factor of $\sqrt{d_{out}}$
   - Total scaling of updates is proportional to $\sqrt{d_{in} \cdot d_{out}}$

3. Therefore:
   - To counteract this scaling and keep updates constant
   - We need to divide by $\sqrt{d_{in} \cdot d_{out}}$
   - Not multiply by $\sqrt{\frac{d_{out}}{d_{in}}}$
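To see the difference numerically, this sketch uses a random matrix with unit-variance entries as a stand-in for the gradient (an assumption matching the reasoning above, not a real training gradient) and applies each candidate factor to its Frobenius norm. Only $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$ gives a result that stays near 1 as the shape changes:

```python
import numpy as np

rng = np.random.default_rng(0)
scaled_inv_sqrt, scaled_ratio = [], []

for d_in, d_out in [(64, 256), (256, 1024), (1024, 1024)]:
    # Stand-in gradient: a d_out x d_in matrix with unit-variance (O(1)) entries.
    G = rng.normal(size=(d_out, d_in))
    frob = np.linalg.norm(G)                 # Frobenius norm, about sqrt(d_in * d_out)
    a = frob / np.sqrt(d_in * d_out)         # scaled by 1 / sqrt(d_in * d_out)
    b = frob * np.sqrt(d_out / d_in)         # scaled by sqrt(d_out / d_in)
    scaled_inv_sqrt.append(a)
    scaled_ratio.append(b)
    print(f"{d_in:5d} x {d_out:5d}: ||G||_F = {frob:8.1f}, "
          f"1/sqrt(d_in*d_out) -> {a:.3f}, sqrt(d_out/d_in) -> {b:8.1f}")
```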

In other words:
- The RMS norm formula $\sqrt{\frac{d_{in}}{d_{out}}}||W||_2$ is about measuring static properties
- The learning rate scaling $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$ is about controlling dynamic updates
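The static quantity can be made concrete with a short NumPy sketch (`rms` and `rms_to_rms_norm` are hypothetical helper names, and the random Gaussian $W$ is only an example). $\sqrt{\frac{d_{in}}{d_{out}}}||W||_2$ is the largest factor by which $W$ can scale the RMS of an input, and it depends only on the current weights, not on any update:

```python
import numpy as np

def rms(v: np.ndarray) -> float:
    """Root mean square of a vector's entries."""
    return float(np.sqrt(np.mean(v ** 2)))

def rms_to_rms_norm(W: np.ndarray) -> float:
    """sqrt(d_in / d_out) * ||W||_2 for W of shape [d_out, d_in]."""
    d_out, d_in = W.shape
    return float(np.sqrt(d_in / d_out) * np.linalg.norm(W, 2))  # ord=2: largest singular value

rng = np.random.default_rng(0)
d_in, d_out = 256, 1024
W = rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_out, d_in))
bound = rms_to_rms_norm(W)

# No input is amplified in RMS by more than the induced norm...
max_ratio = max(rms(W @ x) / rms(x) for x in rng.normal(size=(1000, d_in)))

# ...and the top right-singular vector attains it.
_, _, Vt = np.linalg.svd(W, full_matrices=False)
top_ratio = rms(W @ Vt[0]) / rms(Vt[0])

print(f"induced norm {bound:.3f}, max over random inputs {max_ratio:.3f}, "
      f"top singular vector {top_ratio:.3f}")
```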

This is why we use $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$ as the scaling factor - it properly counteracts the natural scaling of gradient updates in neural networks.

Let me break down why the gradient scales with $\sqrt{d_{in} \cdot d_{out}}$:

1. Consider a layer with weight matrix $W$ of shape $[d_{out}, d_{in}]$:
   - Input $x$ has shape $[batch\_size, d_{in}]$
   - Output $y = xW^T$ has shape $[batch\_size, d_{out}]$

2. During backpropagation:
   - Let $\frac{\partial L}{\partial y}$ be the gradient from the next layer (shape: $[batch\_size, d_{out}]$)
   - The gradient w.r.t. the weights is $\frac{\partial L}{\partial W} = \left(\frac{\partial L}{\partial y}\right)^T x$, which has shape $[d_{out}, d_{in}]$

3. Let's analyze the magnitude:
   - Each element is $\left(\frac{\partial L}{\partial W}\right)_{ij} = \sum_{b=1}^{batch\_size} \frac{\partial L}{\partial y_{bi}} x_{bj}$
   - Assume the elements of $\frac{\partial L}{\partial y}$ and $x$ have zero mean and unit variance and are mutually independent
   - Each product then has unit variance, and the variances of $batch\_size$ independent terms add:
     - Each element of $\frac{\partial L}{\partial W}$ has variance $\propto batch\_size$, with no dependence on $d_{in}$ or $d_{out}$

4. Now considering dimensions:
   - $\frac{\partial L}{\partial W}$ has $d_{out}$ rows and $d_{in}$ columns, all with elements of comparable size
   - Its Frobenius norm is the square root of the sum of all $d_{in} \cdot d_{out}$ squared elements
   - Summing over the $d_{in}$ columns contributes a factor of $\sqrt{d_{in}}$
   - Summing over the $d_{out}$ rows contributes a factor of $\sqrt{d_{out}}$

5. Therefore:
   - The total gradient magnitude scales with $\sqrt{d_{in}} \cdot \sqrt{d_{out}} = \sqrt{d_{in} \cdot d_{out}}$
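Steps 2 and 3 can be checked with random data. This sketch draws $x$ and a stand-in for $\frac{\partial L}{\partial y}$ as independent unit-variance Gaussians (the same assumption as step 3, not values from a real network), confirms that the matrix and element-wise forms of the gradient agree, and measures how the per-element variance depends on the batch size:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 256, 512
variance_over_batch = {}

for batch_size in (16, 64, 256):
    x = rng.normal(size=(batch_size, d_in))      # input, shape [batch_size, d_in]
    dy = rng.normal(size=(batch_size, d_out))    # stand-in dL/dy, independent of x

    # Matrix form from step 2: dL/dW = (dL/dy)^T x, same shape as W.
    dW = dy.T @ x
    assert dW.shape == (d_out, d_in)

    # Element form from step 3: (dL/dW)_ij = sum_b dL/dy_bi * x_bj.
    assert np.allclose(dW, np.einsum("bi,bj->ij", dy, x))

    # Per-element variance divided by batch_size should be close to 1.
    variance_over_batch[batch_size] = float(dW.var() / batch_size)

print(variance_over_batch)
```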

This is why we need the learning rate scaling factor $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$ to counteract this natural scaling of gradients with layer dimensions.
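The conclusion can be checked end to end under the same assumptions (unit-variance, mutually independent $x$ and $\frac{\partial L}{\partial y}$, drawn at random rather than taken from a real network). In this hedged sketch, `lr_scale` is a hypothetical helper: the gradient's Frobenius norm grows with the layer shape, but after multiplying by $\frac{1}{\sqrt{d_{in} \cdot d_{out}}}$ it settles near $\sqrt{batch\_size}$ for every shape:

```python
import math
import numpy as np

def lr_scale(d_in: int, d_out: int) -> float:
    """1 / sqrt(d_in * d_out) for a weight matrix of shape [d_out, d_in]."""
    return 1.0 / math.sqrt(d_in * d_out)

rng = np.random.default_rng(0)
batch_size = 64
normalized = {}

for d_in, d_out in [(64, 256), (256, 1024), (1024, 256)]:
    x = rng.normal(size=(batch_size, d_in))      # unit-variance inputs
    dy = rng.normal(size=(batch_size, d_out))    # unit-variance stand-in for dL/dy
    dW = dy.T @ x                                # dL/dW, shape [d_out, d_in]
    frob = np.linalg.norm(dW)                    # grows like sqrt(batch_size * d_in * d_out)
    normalized[(d_in, d_out)] = lr_scale(d_in, d_out) * frob
    print(f"{d_in:5d} x {d_out:5d}: ||dW||_F = {frob:9.1f}, "
          f"scaled = {normalized[(d_in, d_out)]:.2f}")

print("sqrt(batch_size) =", math.sqrt(batch_size))
```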

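The entrywise RMS of a matrix $P$ is calculated as $\sqrt{\frac{1}{n}\sum_{j=1}^n\left(\frac{1}{m}\sum_{i=1}^m p_{ij}^2\right)} = \frac{||P||_F}{\sqrt{mn}}$, where $m$ is the number of rows and $n$ is the number of columns. This is not the induced RMS norm $\sqrt{\frac{n}{m}}||P||_2$, which for $P = W$ (with $m = d_{out}$ and $n = d_{in}$) is the $\sqrt{\frac{d_{in}}{d_{out}}}||W||_2$ formula discussed above.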
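The double-sum formula can be sketched directly in NumPy (`entrywise_rms` is a hypothetical helper name and the random $P$ is only an example). It matches $\frac{||P||_F}{\sqrt{mn}}$, while the induced quantity $\sqrt{\frac{n}{m}}||P||_2$ comes out much larger for the same matrix, so the two should not be used interchangeably:

```python
import numpy as np

def entrywise_rms(P: np.ndarray) -> float:
    """sqrt((1/n) * sum_j ((1/m) * sum_i p_ij^2)) for P with m rows and n columns."""
    m, n = P.shape
    col_mean_sq = (P ** 2).sum(axis=0) / m          # inner average over rows i, one value per column j
    return float(np.sqrt(col_mean_sq.sum() / n))    # outer average over columns j

rng = np.random.default_rng(0)
m, n = 1024, 256                                    # e.g. d_out rows, d_in columns
P = rng.normal(0.0, 1.0 / np.sqrt(n), size=(m, n))

rms_entries = entrywise_rms(P)
frob_form = float(np.linalg.norm(P) / np.sqrt(m * n))      # same quantity via the Frobenius norm
induced = float(np.sqrt(n / m) * np.linalg.norm(P, 2))     # induced RMS->RMS operator norm

print(f"entrywise RMS {rms_entries:.4f}, Frobenius form {frob_form:.4f}, induced {induced:.3f}")
```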