In [1]:
#!/usr/bin/env python3

import pandas as pd


TRAINING_FLOP = 13675914923606016
LABELED_DATAPOINTS = 2000
FLOP_PER_REQUEST_T5 = round(205962234101.76)
TOKENS_PER_REQUEST_T5 = 23840 / 4000
FLOP_PER_TOKEN_GPT3 = 350e9
TOKENS_PER_REQUEST_GPT4 = 5.5065

# Starting at 10 TFLOP per forward pass
gpt4_flop_per_request_estimates = [
    (int)(TOKENS_PER_REQUEST_GPT4 * FLOP_PER_TOKEN_GPT3 * (2 ** x)) for x in range(4)
]
# gpt4_flop_per_request_estimates = [(int)(TOKENS_PER_REQUEST_GPT4*560e12), 0, 0, 0]

print("Estimates for GPT-4 FLOP per request:")
print(*[f"{x:.2e}" for x in gpt4_flop_per_request_estimates], sep='\n')
print("Total fine-tuning FLOP for T4:")
print(*[f"{TRAINING_FLOP + LABELED_DATAPOINTS * x:.2e}" for x in gpt4_flop_per_request_estimates], sep='\n')
print("FLOP per request for T5:")
print(f"{FLOP_PER_REQUEST_T5:.2e}")

num_requests = [(int)(10 ** (x / 2)) for x in range(4, 25)]


def total_flop_t5(requests: int, flop_per_request_gpt4) -> int:
    fine_tuning_flop = TRAINING_FLOP + LABELED_DATAPOINTS * flop_per_request_gpt4
    # print fine_tuning_flop in scientific notation
    inference_flop = requests * FLOP_PER_REQUEST_T5
    return fine_tuning_flop + inference_flop


def total_flop_gpt4(requests: int, flop_per_request_gpt4) -> int:
    return requests * flop_per_request_gpt4


labels = ["10 TFLOP", "100 TFLOP", "1 PFLOP", "10 PFLOP"]
df_t5 = pd.DataFrame(num_requests, columns=["Requests"])
df_gpt4 = pd.DataFrame(num_requests, columns=["Requests"])

for label, flop_per_request_gpt4 in zip(labels, gpt4_flop_per_request_estimates):
    df_t5[label] = df_t5["Requests"].apply(
        lambda x: total_flop_t5(x, flop_per_request_gpt4)
    )
    df_gpt4[label] = df_gpt4["Requests"].apply(
        lambda x: total_flop_gpt4(x, flop_per_request_gpt4)
    )

df_diff = df_t5 - df_gpt4
df_diff["Requests"] = df_t5["Requests"]

print(df_t5)
print(df_gpt4)
print(df_diff)

#plot for 100 TFLOP, df_t5 and df_gpt4
# df_t5.plot(x="Requests", y="100 TFLOP", title="T5", logx=True, logy=True)
# df_gpt4.plot(x="Requests", y="100 TFLOP", title="GPT-4", logx=True, logy=True)


Estimates for GPT-4 FLOP per request:
1.93e+12
3.85e+12
7.71e+12
1.54e+13
Total fine-tuning FLOP for T4:
1.75e+16
2.14e+16
2.91e+16
4.45e+16
FLOP per request for T5:
2.06e+11
         Requests                  10 TFLOP                 100 TFLOP  \
0             100         17551061147016216         21405611147016216   
1             316         17595548989582248         21450098989582248   
2            1000         17736427157708016         21590977157708016   
3            3162         18181717507836540         22036267507836540   
4           10000         19590087264626016         23444637264626016   
5           31622         24043402690379460         27897952690379460   
6          100000         38126688333806016         41981238333806016   
7          316227         82661284326979170         86515834326979170   
8         1000000        223492699025606016        227347249025606016   
9         3162277        668840100692976270        672694650692976270   
10       10000000     

In [2]:
import numpy as np
#compute break-even point for 10 TFLOP, 100 TFLOP, 1 PFLOP, 10 PFLOP


def cost_t5(gpt_4_estimate: int, num_requests) -> int:
    estimate_t5 = TRAINING_FLOP + (LABELED_DATAPOINTS * gpt_4_estimate) + (num_requests * FLOP_PER_REQUEST_T5)
    return estimate_t5

def cost_gpt4(gpt_4_estimate: int, num_requests) -> int:
    estimate_gpt4 = num_requests * gpt_4_estimate 
    return estimate_gpt4


def diff_t5_gpt_4(gpt_4_estimate: int, num_requests: int):
    return cost_t5(gpt_4_estimate, num_requests) - cost_gpt4(gpt_4_estimate, num_requests)

break_even_points = []
#loop until break-even point for all estimates
for label, flop_per_request_gpt4 in zip(labels, gpt4_flop_per_request_estimates):
    num_requests = 1
    while True:
        diff = diff_t5_gpt_4(flop_per_request_gpt4, num_requests)
        if diff < 0:
            break_even_points.append(num_requests)
            break
        num_requests += 1

#compute training cost for all estimates in PFLOP
for label, flop_per_request_gpt4 in zip(labels, gpt4_flop_per_request_estimates):
    #print(f"Training only cost for {label}: {TRAINING_FLOP / 10**15}")
    #print(f"Training only cost for {label} in PFLOP: {LABELED_DATAPOINTS * flop_per_request_gpt4 / 10**15}")
    #print(f"Training cost for {label}: {(TRAINING_FLOP + (LABELED_DATAPOINTS * flop_per_request_gpt4)) / 10**15}")
    #print(f"Inference cost T5 for {label}: {FLOP_PER_REQUEST_T5 / 10**15}")
    #print(f"Inference cost GPT 4 for {label}: {flop_per_request_gpt4 / 10**15}")
    print(f"Savings for {label} at 10k: {diff_t5_gpt_4(flop_per_request_gpt4, 10000) / 10**15}")
    print(f"Break-even point for {label}: {break_even_points.pop(0)}")


Savings for 10 TFLOP at 10k: 0.317337264626016
Break-even point for 10 TFLOP: 10185
Savings for 100 TFLOP at 10k: -15.100862735373983
Break-even point for 100 TFLOP: 5862
Savings for 1 PFLOP at 10k: -45.937262735373984
Break-even point for 1 PFLOP: 3878
Savings for 10 PFLOP at 10k: -107.61006273537399
Break-even point for 10 PFLOP: 2927


In [3]:
TRAINING_FLOP = 13675914923606016
FLOP_PER_REQUEST_T5 = round(205962234101.76)
TOKENS_PER_REQUEST_T5 = 23840 / 4000

# If we estimate that GPT-4 with 100 TFLOP per token
GPT4_per_request = 10**12 * TOKENS_PER_REQUEST_T5

#compute to TFLOP

#compute to PFLOP
pflop = TRAINING_FLOP / 10**15
pflop_per_request_t5 = FLOP_PER_REQUEST_T5 / 10**15
pflop_training_data = LABELED_DATAPOINTS * GPT4_per_request / 10**15
pflop_per_request_gpt4 = GPT4_per_request / 10**15
print(f"Request FLOP: {pflop_per_request_t5} PFLOP")
print(f"Training FLOP: {pflop} PFLOP")
print(f"Training Data FLOP: {pflop_training_data} PFLOP")
print(f"Request FLOP GPT-4: {pflop_per_request_gpt4} PFLOP")

def total_flop_t5(requests: int) -> int:
    fine_tuning_flop = TRAINING_FLOP + LABELED_DATAPOINTS * GPT4_per_request
    inference_flop = requests * FLOP_PER_REQUEST_T5
    return fine_tuning_flop + inference_flop

def total_flop_gpt4(requests: int) -> int:
    return requests * GPT4_per_request

num_requests = [(int)(10 ** (x / 2)) for x in range(4, 25)]

cost_t5 = [total_flop_t5(x) for x in num_requests]
cost_gpt4 = [total_flop_gpt4(x) for x in num_requests]

df = pd.DataFrame(num_requests, columns=["Requests"])
df["T5"] = cost_t5
df["GPT-4"] = cost_gpt4

print(df)


Request FLOP: 0.000205962234102 PFLOP
Training FLOP: 13.675914923606015 PFLOP
Training Data FLOP: 11.92 PFLOP
Request FLOP GPT-4: 0.00596 PFLOP
         Requests            T5         GPT-4
0             100  2.561651e+16  5.960000e+14
1             316  2.566100e+16  1.883360e+15
2            1000  2.580188e+16  5.960000e+15
3            3162  2.624717e+16  1.884552e+16
4           10000  2.765554e+16  5.960000e+16
5           31622  3.210885e+16  1.884671e+17
6          100000  4.619214e+16  5.960000e+17
7          316227  9.072673e+16  1.884713e+18
8         1000000  2.315581e+17  5.960000e+18
9         3162277  6.769056e+17  1.884717e+19
10       10000000  2.085218e+18  5.960000e+19
11       31622776  6.538694e+18  1.884717e+20
12      100000000  2.062182e+19  5.960000e+20
13      316227766  6.515657e+19  1.884717e+21
14     1000000000  2.059878e+20  5.960000e+21
15     3162277660  6.513354e+20  1.884717e+22
16    10000000000  2.059648e+21  5.960000e+22
17    31622776601  6.513123e