# SHO Fitter Statistics Benchmarking

In [None]:
import sys
sys.path.append('../../')

In [None]:
import numpy as np
import statistics

## Statistics Calculations

In [None]:
# Log files to summarize; each one is opened relative to the current working
# directory by the statistics loop.
fnames = [
    'ADAM_SHO_speed.txt',
    'ADAHESSIAN_SHO_speed.txt',
]

In [None]:
def get_mins(arr, mins, name):
    """Collect the values of one metric at the given positions.

    Args:
        arr: Mapping from metric name to an indexable sequence of values.
        mins: Iterable of positions to pick out of ``arr[name]``.
        name: Key of the metric to read from ``arr``.

    Returns:
        A NumPy array holding ``arr[name][i]`` for every ``i`` in ``mins``,
        in the order the positions were given.
    """
    selected = []
    for position in mins:
        # Looked up per position, so an empty ``mins`` never touches ``arr``.
        selected.append(arr[name][position])
    return np.array(selected)

def print_stats(arr, name, bs):
    """Print the sample standard deviation and the mean of ``arr``.

    Two lines are written to standard output, each rounded to one decimal
    place and labelled with the metric name (an "s" is appended) and ``bs``.

    Args:
        arr: Sequence of numeric values to summarize.
        name: Label of the metric being reported.
        bs: Label of the group the values belong to (shown after "of").

    Raises:
        statistics.StatisticsError: If ``arr`` has fewer than two values
            (``statistics.stdev`` requires at least two data points).
    """
    # float() mirrors the coercion that a "%.1f" conversion performs.
    spread = float(statistics.stdev(arr))
    average = float(statistics.mean(arr))

    # Format in a single pass so that a literal '%' inside ``name`` or ``bs``
    # can never be misread as a conversion specifier.
    print(f"Standard Deviation of {name}s of {bs}: {spread:.1f} ")
    print(f"Mean of {name}s of {bs}: {average:.1f} ")

In [None]:
# For every log file: parse the per-run readings grouped by batch size, pick
# the five runs with the lowest 'params_mse' for each batch size, and print
# the standard deviation and mean of each metric over those runs.
for fname in fnames:
    # Batch sizes that are summarized for every log file.
    bs = [64, 128, 256, 512, 1024]
    # Readings parsed from this file: {batch_size: {metric_name: [value, ...]}}.
    bs_dict = {size: {} for size in bs}
    # Readings are filed under this batch size until a "Training ... batch"
    # line announces a different one.
    batch_size = 64

    with open(f'./{fname}') as file_in:
        for line in file_in:
            if 'Training' in line and 'batch' in line:
                # Fourth space-separated token with its first five characters
                # dropped.
                batch_size = int(line.split(' ')[3][5:])

            # Deliberately a separate `if`: a line that updates the batch size
            # may also carry a reading.
            if 'Training' in line and 'seconds' in line:
                metric = 'time'
                value = float(line.split(' ')[-2])
            elif 'Avg' in line:
                metric = 'inf_time'
                value = float(line.split(' ')[-1])
                # Algebraically the two steps below amount to 1000 / value.
                # NOTE(review): 1382400 is presumably a sample count and the
                # logged value a per-sample time in ms - confirm the units.
                value *= 1382400 / 1000
                value = 1382400 / value
            elif 'Reconstruction' in line:
                metric = 'loops_mse'
                value = float(line.split(' ')[-1])
            elif 'Total' in line:
                metric = 'params_mse'
                value = float(line.split(' ')[-1])
            else:
                continue

            # setdefault keeps a batch size outside `bs` from raising KeyError;
            # such readings are stored but not summarized below.
            bs_dict.setdefault(batch_size, {}).setdefault(metric, []).append(value)

    # For each batch size, the indices of the (up to) five runs with the
    # smallest 'params_mse', best first.
    bs_mins = []
    for size in bs:
        params = bs_dict[size]['params_mse']
        bs_mins.append(sorted(range(len(params)), key=params.__getitem__)[:5])

    print(fname)
    # The same run indices are applied to every metric, so this relies on the
    # per-metric lists being index-aligned (one entry per run, same order);
    # that alignment is not checked here.
    for name in ['time', 'inf_time', 'loops_mse', 'params_mse']:
        for size, mins in zip(bs, bs_mins):
            print_stats(get_mins(bs_dict[size], mins, name), name, size)
            print()