In [60]:
# Fitted constants of the parametric loss L(N, D) = E + A/N**alpha + B/D**beta
# (parameters from the Chinchilla paper).
# NOTE(review): this fit is commonly quoted with E = 1.69 -- confirm that 1.62
# is intended. E only shifts loss(); it does not enter the optimal-N/D formulas.
A, B, E = 406.4, 410.7, 1.62
alpha, beta = 0.336, 0.283

# Exponents of the compute-optimal power laws: first the one used for N_opt,
# then the one used for D_opt. They sum to 1.
print(beta/(alpha+beta), alpha/(alpha+beta), sep="\n")

0.45718901453957994
0.5428109854604201


In [61]:
# Example operating point (N, D) at which the fitted loss is evaluated below.
N, D = 40e9, 1000e9

def loss(N, D):
    """Evaluate the parametric loss ``E + A/N**alpha + B/D**beta``.

    Args:
        N: Model size (the quantity scaled by ``alpha``).
        D: Dataset size (the quantity scaled by ``beta``).

    Returns:
        The predicted loss, using the module-level constants
        ``E``, ``A``, ``B``, ``alpha`` and ``beta``.
    """
    model_term = A/(N**alpha)
    data_term = B/(D**beta)
    return E + model_term + data_term
# Last expression of the cell: shows the fitted loss at the N and D set above.
loss(N, D)

1.8963581331216426

In [62]:
# Constant prefactor of the closed-form compute-optimal split: optimal_N uses
# G * (C/6)**a and optimal_D uses (1/G) * (C/6)**b, so N_opt * D_opt == C/6.
G = ((alpha*A)/(beta*B))**(1/(alpha+beta))
def optimal_N(C):
    """Return the compute-optimal model size for compute budget ``C``.

    Computes ``G * (C/6)**a`` with ``a = beta / (alpha + beta)``, using the
    module-level ``G``, ``alpha`` and ``beta``.
    """
    # The exponent is derived from the fitted alpha/beta; a hard-coded
    # alternative (a = 0.454) was previously left here, disabled.
    exponent = beta/(alpha+beta)
    return G*(C/6)**exponent

def optimal_D(C):
    """Return the compute-optimal dataset size for compute budget ``C``.

    Computes ``(1/G) * (C/6)**b`` with ``b = alpha / (alpha + beta)``, using
    the module-level ``G``, ``alpha`` and ``beta``.
    """
    # The exponent is derived from the fitted alpha/beta; a hard-coded
    # alternative (b = 0.543) was previously left here, disabled.
    exponent = alpha/(alpha+beta)
    inverse_G = 1/G
    return inverse_G*(C/6)**exponent
    
# Compute-optimal split for one example budget, printed in units of 1e9.
C = 1.62e+20
N_opt, D_opt = optimal_N(C), optimal_D(C)
print(N_opt/1e9, D_opt/1e9)

0.9928012883225557 27.19577453975649


In [63]:
# NOTE(review): these module-level values overwrite the earlier N and D but are
# not read by the code visible in this cell (calc_compute_increase binds its
# own local N and D) -- confirm whether they are still needed.
N, D = 1e9, 20e9

def calc_compute_increase(C, parameter_fraction):
    """Compute overhead of training a model that is not compute-optimal in size.

    Starting from the compute-optimal ``N`` and ``D`` for budget ``C``,
    multiply ``N`` by ``parameter_fraction`` and solve for the factor by which
    ``D`` must be multiplied so that the parametric loss stays the same.

    Args:
        C: Compute budget, related to model and data size by ``C = 6 * N * D``.
        parameter_fraction: Scalar multiplier applied to the optimal model
            size. Must be strictly positive.

    Returns:
        A tuple ``(c_d, C_increase)`` where ``c_d`` is the multiplier on the
        optimal dataset size and ``C_increase`` is the resulting compute
        budget divided by ``C``. Both are ``inf`` when the model is so small
        that no finite amount of data can reach the optimal loss.

    Raises:
        ValueError: If ``parameter_fraction`` is not strictly positive.
    """
    if parameter_fraction <= 0:
        raise ValueError(
            "parameter_fraction must be positive, got %r" % (parameter_fraction,)
        )

    N = optimal_N(C)
    D = optimal_D(C)

    factor = (parameter_fraction)**(-alpha)
    A_term = A/(N**alpha)
    B_term = B/(D**beta)

    # Equal loss requires
    #     A_term*factor + B_term*c_d**(-beta) == A_term + B_term
    # which rearranges to c_d**(-beta) == 1 - (factor - 1)*A_term/B_term.
    base = 1 - ((factor - 1)*A_term/B_term)
    if base <= 0:
        # The model-size term alone already exceeds the reducible loss at the
        # optimum. Without this guard the fractional power below would raise
        # ZeroDivisionError (base == 0) or silently return complex numbers
        # (base < 0).
        unattainable = float("inf")
        return unattainable, unattainable

    c_d = base**(1/-beta)
    new_C = 6*N*parameter_fraction*D*c_d
    C_increase = new_C/C
    return c_d, C_increase

# Spot check: halve the optimal model size at C = 1e22 and print the resulting
# (data multiplier, compute multiplier) pair.
print(calc_compute_increase(1e22, 0.5))

(2.4156451150039597, 1.2078225575019799)


In [64]:
import plotly.express as px
import pandas

C = 1.23e+18
data = []

# Model-size multipliers k_N to evaluate (1 is the compute-optimal size).
# 0.3052 is presumably chosen to land on ~100% overhead, matching the
# "Critical model size" annotation below -- TODO confirm.
factors = [0.175, 0.25, 0.3052, 0.4, 0.5, 0.75, 1, 1.5, 2, 3, 4, 5]
for factor in factors:
    data_fraction, C_increase = calc_compute_increase(C, factor)
    # Express the compute multiplier as a percentage above the optimal budget.
    data.append(dict(model_fraction=factor, compute_overhead=(C_increase - 1)*100))

df = pandas.DataFrame(data)
fig = px.line(df, x="model_fraction", y="compute_overhead")
fig.update_layout(
    yaxis_title="Compute overhead (%)",
    xaxis_title="Fraction of optimal model size k_N",
    xaxis_range=[0, 5.1],
    yaxis_range=[-30, 850],
    yaxis=dict(tickmode='linear', tick0=0, dtick=100),
)
fig.update_traces(mode='markers+lines')


def _callout(x, y, label_y, text, **extra):
    """Build an annotation whose arrow points at (x, y) from a label placed
    directly above it at (x, label_y), both in data coordinates."""
    return {
        "x": x,
        "y": y,
        "ayref": "y",
        "ay": label_y,
        "axref": "x",
        "ax": x,
        "xanchor": "left",
        "text": text,
        "arrowhead": 1,
        **extra,
    }


# Annotation positions are hard-coded here, not computed from df.
fig.add_annotation(_callout(1, 0, 80, "Chinchilla"))
fig.add_annotation(_callout(0.6, 11, 160, "LLaMA-7B"))
fig.add_annotation(_callout(0.46, 26, 240, "SantaCoder"))
fig.add_annotation(_callout(0.31, 100, 300, "Critical model size",
                            font=dict(size=16, color="red")))
fig.show()

In [65]:
def get_N_D(C, fraction):
    """Return ``(N, D)`` for a model ``fraction`` times the optimal size.

    ``N`` is ``fraction`` times the compute-optimal model size for budget
    ``C``. ``D`` is the compute-optimal dataset size multiplied by the data
    factor that ``calc_compute_increase`` reports for that ``fraction``.
    """
    # Only the data multiplier is needed; the compute multiplier is discarded.
    data_multiplier, _ = calc_compute_increase(C, fraction)
    scaled_N = fraction * optimal_N(C)
    scaled_D = data_multiplier*optimal_D(C)
    return scaled_N, scaled_D

# SantaCoder analysis: model size N, dataset size D, and the budget C = 6*N*D.
N_santa, D_santa = 1.1e9, 236e9
C_santa = N_santa*D_santa*6

# Sweep the model-size fraction from 0.30 in steps of 0.01 (50 values) and
# print the fraction with the predicted N and D in units of 1e9.
for i, fraction in enumerate(0.3 + step*0.01 for step in range(50)):
    pred_N, pred_D = get_N_D(C_santa, fraction)
    print(fraction, pred_N/1e9, pred_D/1e9)

0.3 0.8382496468967309 636.6199603958831
0.31 0.8661913017932886 585.7602207937458
0.32 0.8941329566898464 541.8340046725356
0.32999999999999996 0.9220746115864039 503.59823900048104
0.33999999999999997 0.9500162664829617 470.0789242212514
0.35 0.9779579213795193 440.50398036067816
0.36 1.0058995762760772 414.25486393369357
0.37 1.0338412311726348 390.8311957450166
0.38 1.0617828860691925 369.82455184741184
0.39 1.08972454096575 350.8988028984032
0.4 1.117666195862308 333.7751968187406
0.41 1.1456078507588656 318.22092034849936
0.42 1.1735495056554233 304.0402418228724
0.43 1.201491160551981 291.0675898301918
0.44 1.2294328154485388 279.16209838745715
0.44999999999999996 1.2573744703450964 268.2032735261029
0.45999999999999996 1.2853161252416538 258.08752494988335
0.47 1.3132577801382117 248.72537053912953
0.48 1.3411994350347696 240.03916825664368
0.49 1.3691410899313272 231.96126447555747
0.5 1.3970827448278849 224.43247337232143
0.51 1.4250243997244427 217.40082124081263
0.52 1.4529

In [58]:
# LLaMA 7B analysis: model size N, dataset size D, and the budget C = 6*N*D.
# Note that this rebinds the module-level C.
N_LLaMA_7B, D_LLaMA_7B = 6.9e9, 1000e9
C = 6*N_LLaMA_7B*D_LLaMA_7B

# Sweep the model-size fraction from 0.30 in steps of 0.01 (50 values) and
# print the fraction with the predicted N and D in units of 1e9.
for i, fraction in enumerate(0.3 + step*0.01 for step in range(50)):
    pred_N, pred_D = get_N_D(C, fraction)
    print(fraction, pred_N/1e9, pred_D/1e9)

0.3 3.7554279201144802 3776.9270985919716
0.31 3.8806088507849625 3475.187378380889
0.32 4.005789781455446 3214.5827377352557
0.32999999999999996 4.130970712125928 2987.7382960177965
0.33999999999999997 4.256151642796411 2788.875526717914
0.35 4.381332573466893 2613.413848078631
0.36 4.506513504137375 2457.683576779059
0.37 4.631694434807859 2318.7160724060786
0.38 4.756875365478342 2194.0882449374017
0.39 4.882056296148825 2081.805912441561
0.4 5.007237226819307 1980.2152997506334
0.41 5.132418157489789 1887.935176672317
0.42 5.257599088160272 1803.8043100143573
0.43 5.382780018830755 1726.8404007751772
0.44 5.507960949501238 1656.207721862381
0.44999999999999996 5.63314188017172 1591.1914088931283
0.45999999999999996 5.758322810842202 1531.176883278336
0.47 5.883503741512686 1475.633267157362
0.48 6.008684672183168 1424.0999272913434
0.49 6.133865602853651 1376.1754894970422
0.5 6.259046533524134 1331.5088172177634
0.51 6.384227464194616 1289.7915618133738
0.52 6.509408394865099 1250

In [67]:
import pandas
# Compute budgets to tabulate.
Cs = [2.21e+19, 1.62e+20, 2.46e+22, 1.71e+23]

# (model-size fraction k, N column name, D column name) for each scenario that
# is tabulated next to the compute-optimal split.
scenarios = [
    (0.5, 'N_{k=0.5}', 'D_{k=0.5}'),
    (0.305, 'N_{k=0.305}', 'D_{k=0.305}'),
]

data = []
for C in Cs:
    N_opt = optimal_N(C)
    D_opt = optimal_D(C)
    # All sizes are reported in units of 1e9, rounded to two decimals.
    row = {
        'C': round(C, 2),
        'N_opt': round(N_opt/1e9, 2),
        'D_opt': round(D_opt/1e9, 2),
    }
    for k, n_column, d_column in scenarios:
        # Shrink the model to k * N_opt and grow the data by the factor that
        # calc_compute_increase reports for that k.
        D_inc, _ = calc_compute_increase(C, k)
        row[n_column] = round(k*N_opt/1e9, 2)
        row[d_column] = round(D_inc*D_opt/1e9, 2)
    data.append(row)

print(pandas.DataFrame(data))

              C  N_opt    D_opt  N_{k=0.5}  D_{k=0.5}  N_{k=0.305}  \
0  2.210000e+19   0.40     9.22       0.20      22.28         0.12   
1  1.620000e+20   0.99    27.20       0.50      65.70         0.30   
2  2.460000e+22   9.87   415.53       4.93    1003.77         3.01   
3  1.710000e+23  23.94  1190.37      11.97    2875.50         7.30   

   D_{k=0.305}  
0        60.58  
1       178.63  
2      2729.25  
3      7818.50  
