In [1]:
from flopper import Flopper

flops_list = {}

flops_list[('untrained')] = {}
flops = Flopper(512, num_steps_training=0, batch_size=1, use_lora=False)
flops_list[('untrained')]['inference'] = 4*flops.compute_inference()
flops_list[('untrained')]['total'] = flops_list[('untrained')]['inference']


flops_list[('default')] = {}
flops = Flopper(512, num_steps_training=3000, batch_size=4, use_lora=True, lora_rank=4)
flops_list[('default')]['training'] = flops.compute_flops()
flops_list[('default')]['validation'] = (3000//50) * flops.compute_validation()
flops_list[('default')]['inference'] = 4*flops.compute_inference()
flops_list[('default')]['total'] = flops_list[('default')]['training'] + flops_list[('default')]['validation'] + flops_list[('default')]['inference']

for rank in [2, 4, 8]:
    for lr in [1e-5, 5e-5, 1e-4]:
        flops_list[(rank, lr)] = {}
        flops = Flopper(512, num_steps_training=300, batch_size=4, use_lora=True, lora_rank=rank)
        flops_list[(rank, lr)]['training'] = flops.compute_flops()
        flops_list[(rank, lr)]['validation'] = (300//50) * flops.compute_validation()
        flops_list[(rank, lr)]['total'] = flops_list[(rank, lr)]['training'] + flops_list[(rank, lr)]['validation']

for ctx in [128, 512, 768]:
    flops_list[(ctx)] = {}
    flops = Flopper(ctx, num_steps_training=300, batch_size=4, use_lora=True, lora_rank=8)
    flops_list[(ctx)]['training'] = flops.compute_flops()
    flops_list[(ctx)]['validation'] = (300//50) * flops.compute_validation()
    flops_list[(ctx)]['total'] = flops_list[(ctx)]['training'] + flops_list[(ctx)]['validation']

flops_list[('final_run')] = {}
flops = Flopper(768, num_steps_training=15000, batch_size=4, use_lora=True, lora_rank=8)
flops_list[('final_run')]['training'] = flops.compute_flops()
flops_list[('final_run')]['validation'] = (15000//50) * flops.compute_validation()
flops_list[('final_run')]['inference'] = 4*flops.compute_inference()
flops_list[('final_run')]['total'] = flops_list[('final_run')]['training'] + flops_list[('final_run')]['validation'] + flops_list[('final_run')]['inference']

In [2]:
for key in ['inference', 'total']:
    print(f"Default {key}: {flops_list[('untrained')][key]:.3e}")

Default inference: 5.416e+12
Default total: 5.416e+12


In [3]:
for key in ['training', 'validation', 'inference', 'total']:
    print(f"Default {key}: {flops_list[('default')][key]:.3e}")

Default training: 9.812e+15
Default validation: 1.308e+14
Default inference: 5.416e+12
Default total: 9.948e+15


In [4]:
training = 0
validation = 0
total = 0
for ctx in [128, 512, 768]:
    training += flops_list[(ctx)]['training']
    validation += flops_list[(ctx)]['validation']
    total += flops_list[(ctx)]['total']

print(f"Context length training: {training:.3e}")
print(f"Context length validation: {validation:.3e}")
print(f"Context length total: {total:.3e}")

Context length training: 2.723e+15
Context length validation: 3.631e+13
Context length total: 2.759e+15


In [5]:
for key in ['training', 'validation', 'inference', 'total']:
    print(f"Final run {key}: {flops_list[('default')][key]:.3e}")

Final run training: 9.812e+15
Final run validation: 1.308e+14
Final run inference: 5.416e+12
Final run total: 9.948e+15


In [6]:
total_training = 0
total_validation = 0
total_inference = 0
total_total = 0
for stage in flops_list:
    total_training += flops_list[stage].get('training', 0)
    total_validation += flops_list[stage].get('validation', 0)
    total_inference += flops_list[stage].get('inference', 0)
    total_total += flops_list[stage].get('total')

print(f"Total training: {total_training:.3e}")
print(f"Total validation: {total_validation:.3e}")
print(f"Total inference: {total_inference:.3e}")
print(f"Total total: {total_total:.3e}")
print(f"Within budget: {total_total < 1e17}")

Total training: 9.659e+16
Total validation: 1.288e+15
Total inference: 1.625e+13
Total total: 9.790e+16
Within budget: True
