## Scaling Laws

In [1]:
import pandas as pd
import numpy as np
import math

In [2]:
scaling_data = []
# from chincilla 
d_model = 8192           # initialized from chinchilla n_dim
while True:
    n_layer = d_model // 100 # roughly the same as chinchilla: exactly the same aspect ratio
    d_attn  = d_model // 75  # in the aspect ratio  
    d_ff    = d_model * 10   # roughly the right aspect ratio 
    #width   = 12 * n_layer * np.exp(2*5.039) * np.exp(2*5.553e-2*n_layer)
    N       = 2*d_model * n_layer * (2*d_attn + d_ff)
    n_heads =  d_model // 20
    #print(f'{N/1e12:,} T params')
    scaling_data.append([d_model, 
                         n_layer, 
                         d_attn, 
                         d_ff, 
                         #width,
                         n_heads,
                         N, 
                         N/1e12
                        ])
    if N > 1e14: 
        break

    d_model *= 1.1
    
scaling_data = pd.DataFrame(scaling_data, 
                            columns=[
                                'd_model',
                                'n_layer', 
                                'd_attn',
                                'd_ff', 
                                #'width',
                                'n_heads',
                                'N', 
                                'N (T)'
                            ],
                           )

In [3]:
scaling_data

Unnamed: 0,d_model,n_layer,d_attn,d_ff,n_heads,N,N (T)
0,8192.0,81.0,109.0,81920.0,409.0,109005700000.0,0.109006
1,9011.2,90.0,120.0,90112.0,450.0,146552400000.0,0.146552
2,9912.32,99.0,132.0,99123.2,495.0,195061200000.0,0.195061
3,10903.552,109.0,145.0,109035.52,545.0,259864000000.0,0.259864
4,11993.9072,119.0,159.0,119939.072,599.0,343279800000.0,0.34328
5,13193.29792,131.0,175.0,131932.9792,659.0,457255200000.0,0.457255
6,14512.627712,145.0,193.0,145126.27712,725.0,612412000000.0,0.612412
7,15963.890483,159.0,212.0,159638.904832,798.0,812562100000.0,0.812562
8,17560.279532,175.0,234.0,175602.795315,878.0,1082148000000.0,1.082148
9,19316.307485,193.0,257.0,193163.074847,965.0,1444075000000.0,1.444075


In [4]:
def get_closest_to_order(scale): 
    return scaling_data.loc[np.argmin(np.abs(scaling_data['N (T)'] - scale)), :]

In [5]:
scales = pd.DataFrame([get_closest_to_order(s) for s in [1, 10, 100]])

In [6]:
scales

Unnamed: 0,d_model,n_layer,d_attn,d_ff,n_heads,N,N (T)
8,17560.279532,175.0,234.0,175602.795315,878.0,1082148000000.0,1.082148
16,37642.018704,376.0,501.0,376420.187042,1882.0,10683610000000.0,10.683614
24,80689.01008,806.0,1075.0,806890.100802,4034.0,105232400000000.0,105.232399


From here we can apply the scaling laws in section 1.1 of the deep mind paper to get the data and compute requirements given N. 

We first apply the first scaling law in `loss_from_size`, and then rearrange the second two scaling laws to get the data and cost requirements from this loss. 

The derivations are here: 

To get `data_from_loss`
\begin{align}
L(D) &= \frac{D_c}{D}^{a_D} \\
\ln L(D) &= a_D \ln \frac{D_c}{D} \\
\frac{\ln L(D)}{a_d} &= \ln D_c - \ln D \\
\ln D &=  \ln D_c - \frac{\ln L(D)}{a_d} \\
D &= \exp\left[\ln D_c - \frac{\ln L(D)}{a_d}\right]
\end{align}

To get `loss_from_data`
\begin{align}
L(C_{min}) &= \left[\frac{C_c^{min}}{C_{min}}\right]^{a_C^{min}} \\
\ln L(C_{min}) &= a_C^{min} \ln \left[\frac{C_c^{min}}{C_{min}}\right] \\
\frac{\ln L(C_{min})}{a_C^{min}} &= \ln C_c^{min} - \ln C_{min} \\
C_{min} &= \exp \left[ \ln C^{min}_c - \frac{\ln L(C_{min})}{a_C^{min}} \right]
\end{align}

I implement these in the math library to avoid overflows in numpy. If needed, one could implement these in a more numerically efficient manner. 

In [7]:
def loss_from_size(N): 
    return (8.8*1e13/N)**0.076

In [8]:
scales['loss'] = scales['N'].map(loss_from_size)

In [9]:
scales

Unnamed: 0,d_model,n_layer,d_attn,d_ff,n_heads,N,N (T),loss
8,17560.279532,175.0,234.0,175602.795315,878.0,1082148000000.0,1.082148,1.396931
16,37642.018704,376.0,501.0,376420.187042,1882.0,10683610000000.0,10.683614,1.173811
24,80689.01008,806.0,1075.0,806890.100802,4034.0,105232400000000.0,105.232399,0.986501


In [10]:
def data_from_loss(loss):
    """Data requirement (in tokens) given a transformer loss
    
    Rearranged from the second scaling law in the openAI paper.
    """
    a_d = 0.095
    D_c = 5.4e13
    L = math.log(D_c)
    R = math.log(loss) / a_d
    return math.exp(L - R)

In [11]:
scales['data'] = scales['loss'].map(data_from_loss)

In [12]:
scales

Unnamed: 0,d_model,n_layer,d_attn,d_ff,n_heads,N,N (T),loss,data
8,17560.279532,175.0,234.0,175602.795315,878.0,1082148000000.0,1.082148,1.396931,1600431000000.0
16,37642.018704,376.0,501.0,376420.187042,1882.0,10683610000000.0,10.683614,1.173811,9994985000000.0
24,80689.01008,806.0,1075.0,806890.100802,4034.0,105232400000000.0,105.232399,0.986501,62305620000000.0


In [13]:
def cost_from_loss(loss):
    """Computational cost (in petaflop-days) given a transformer loss
    
    Rearranged from the second scaling law in the openAI paper.
    """
    c_min_c = 3.1e8
    a_min_c = 0.05
    #left and right side of subtraction inside exp
    L = math.log(c_min_c) 
    R = math.log(loss) / a_min_c
    return math.exp(L-R)

In [14]:
scales['cost'] = scales['loss'].map(cost_from_loss)

In [15]:
scales

Unnamed: 0,d_model,n_layer,d_attn,d_ff,n_heads,N,N (T),loss,data,cost
8,17560.279532,175.0,234.0,175602.795315,878.0,1082148000000.0,1.082148,1.396931,1600431000000.0,387136.4
16,37642.018704,376.0,501.0,376420.187042,1882.0,10683610000000.0,10.683614,1.173811,9994985000000.0,12571860.0
24,80689.01008,806.0,1075.0,806890.100802,4034.0,105232400000000.0,105.232399,0.986501,62305620000000.0,406832100.0


In [16]:
def pf_days_to_20EF_time(c): 
    """return a tuple of time in days,weeks,months,years"""
    c = c / 20_000  # peta to 20exa
    out = [c]
    for factor in [7, 4, 52]: # to weeks, months, years
        c /= factor
        out.append(c)
    return out
    

In [17]:
cost = scales['cost'].map(pf_days_to_20EF_time).to_list()

Time on 20EF machine: 

In [18]:
pd.DataFrame(cost, columns=['days', 'weeks', 'months', 'years'])

Unnamed: 0,days,weeks,months,years
0,19.356818,2.76526,0.691315,0.013295
1,628.593174,89.799025,22.449756,0.431726
2,20341.605892,2905.943699,726.485925,13.970883


Wow, thats a lot of time! Better start prepping for the Zettascale computing project.