In [58]:
import math

In [59]:
from thop import profile

## Setup

In [60]:
# # original
# model_size = 431080 * 4 * 8
# param_size = round(model_size / 32)
# add_p = 0.0014
# add_time_per_million = 0.012
# mul_time_per_million = 0

# nr_layer = 4
# nr_worker = 3
# recv_ratio = (98.65)/100

# bandwidth_dict = {}
# bandwidth_dict["s2c"] = 47.2 * 1000 * 1000
# bandwidth_dict["c2c"] = 48.97 * 1000 * 1000
# bandwidth_dict["c2w"] = 465 * 1000 * 1000

# time_cmd = 0.002

# gis_real_value = 0.32
# emp_real_value = 2.74
# agg_real_value = 0.1
# dec_real_value = 0.1774

In [61]:
# Active configuration: ResNet-18 (the other configuration cells are commented out).

# Topology / protocol settings read by the *_prediction() functions.
nr_worker = 3
nr_layer = 5
recv_ratio = 91.55/100

# NOTE(review): units are not stated in this notebook. The "* 8" presumably
# converts a byte count to bits and "/ 32" presumably assumes 32-bit
# parameters -- confirm.
model_size = 46189135 * 8
param_size = round(model_size / 32)

# Link bandwidths keyed by link name; model_size is divided by these values.
# NOTE(review): presumably bits per second with decimal prefixes -- confirm.
bandwidth_dict = {
    "s2c": 47.2 * 1000 * 1000,
    "c2c": 48.97 * 1000 * 1000,
    "c2w": 465 * 1000 * 1000,
}

# Cost coefficients.
time_cmd = 0.002
add_p = 0.0014
# In the visible cells, add_time_per_million is only used by a commented-out
# agg_prediction variant.
add_time_per_million = 0.012
mul_time_per_million = 0

# Reference values that each prediction is divided by via ratio().
# NOTE(review): presumably measured timings in seconds -- confirm.
gis_real_value = 8.66
emp_real_value = 58
agg_real_value = 0.23
dec_real_value = 5.166

In [62]:
# # resnet 34
# model_size = 87563522 * 8
# param_size = round(model_size / 32)
# add_p = 0.0014
# mul_time_per_million = 0
# # mul_time_per_million = 0.636

# nr_layer = 5
# nr_worker = 3
# recv_ratio = 91.55/100

# bandwidth_dict = {}
# bandwidth_dict["s2c"] = 47.2 * 1000 * 1000
# bandwidth_dict["c2c"] = 48.97 * 1000 * 1000
# bandwidth_dict["c2w"] = 465 * 1000 * 1000

# time_cmd = 0.002

# gis_real_value = 16.4
# emp_real_value = 114.9
# agg_real_value = 0.35
# dec_real_value = 10.04

In [63]:
# # resnet 50
# model_size = 94329249 * 8
# param_size = round(model_size / 32)
# add_p = 0.0014
# mul_time_per_million = 0

# nr_layer = 5
# nr_worker = 3
# recv_ratio = 91.55/100
# # recv_ratio = (100-8.33)/100

# bandwidth_dict = {}
# bandwidth_dict["s2c"] = 47.2 * 1000 * 1000
# bandwidth_dict["c2c"] = 48.97 * 1000 * 1000
# bandwidth_dict["c2w"] = 465 * 1000 * 1000

# time_cmd = 0.002

# gis_real_value = 17.6
# emp_real_value = 127.7
# agg_real_value = 0.364
# dec_real_value = 10.78

In [64]:
def ratio(pred_value, real_value):
    """Return ``pred_value`` divided by ``real_value``.

    A result of 1.0 means both values are equal; above 1.0 means
    ``pred_value`` is the larger of the two (for positive inputs).
    """
    quotient = pred_value / real_value
    return quotient

In [65]:
def _count_elements(shape):
    """Return the product of all entries in ``shape`` (1 for an empty shape)."""
    total = 1
    for dim in shape:
        total = dim * total
    return total


def local_training_predictor(model_flops, data_kargs, fb_kargs, cmp_kargs):
    """Predict the time of one pass over the local data set, batch by batch.

    Args:
        model_flops: FLOP count multiplied by the batch size to obtain the
            work per batch.
        data_kargs: dict with the keys
            * ``"data"``: ``{"shape": list of ints, "type_bits": int}``
            * ``"label"``: ``{"shape": list of ints, "type_bits": int}``
            * ``"data_size"``: number of samples
            * ``"batch_size"``: samples per batch
        fb_kargs: batch-fetch model, ``{"mbps": float, "constant": float}``.
        cmp_kargs: compute model, ``{"gflops": float, "constant": float}``.

    Returns:
        Tuple ``(fetch_batch_time, comp_batch_time, predict_value)`` where
        ``predict_value`` is ``ceil(data_size / batch_size)`` times the sum of
        the two per-batch times.

    Raises:
        ZeroDivisionError: if ``batch_size``, ``mbps`` or ``gflops`` is 0.
    """
    nr_data_bits = _count_elements(data_kargs["data"]["shape"]) * data_kargs["data"]["type_bits"]
    nr_label_bits = _count_elements(data_kargs["label"]["shape"]) * data_kargs["label"]["type_bits"]

    batch_size = data_kargs["batch_size"]
    nr_batch_bits = batch_size * (nr_label_bits + nr_data_bits)
    # A trailing partial batch is counted as a full batch.
    nr_batch = math.ceil(data_kargs["data_size"] / batch_size)

    # NOTE(review): the rates are scaled with binary prefixes (1024-based),
    # unlike the 1000-based bandwidths in the Setup cell. Left unchanged
    # because the constants were presumably calibrated against this scale.
    fetch_batch_time = nr_batch_bits / (fb_kargs["mbps"] * 1024 * 1024) + fb_kargs["constant"]
    comp_batch_time = (batch_size * model_flops) / (cmp_kargs["gflops"] * 1024 * 1024 * 1024) + cmp_kargs["constant"]

    predict_value = nr_batch * (fetch_batch_time + comp_batch_time)

    return fetch_batch_time, comp_batch_time, predict_value

## Local training (LT) prediction

In [66]:
# Key selecting which entry of the two lookup tables below is evaluated.
label_name = "resnet_18"

# FLOP figure per model.
# NOTE(review): the "* 2 * 3" factors are not explained in this notebook --
# confirm what they account for.
model_flops_dict = dict(
    original=2307720 * 2 * 3,
    resnet_18=13412720 * 2 * 3,
    resnet_34=23747720 * 2 * 3,
    resnet_50=25437720 * 2 * 3,
)

# Reference value per model that the local-training prediction is compared to.
lt_real_value_dict = dict(
    original=13.57,
    resnet_18=26.76,
    resnet_34=42.67,
    resnet_50=45.167,
)

model_flops = model_flops_dict[label_name]
lt_real_value = lt_real_value_dict[label_name]

# Data-set description consumed by local_training_predictor().
data_kargs = {
    "data": {"shape": [28, 28], "type_bits": 32},
    "label": {"shape": [1], "type_bits": 64},
    "data_size": 24754,
    "batch_size": 64,
}

# Batch-fetch model: rate plus a constant per-batch term.
fb_kargs = {"mbps": 186.615, "constant": 0.0138}

# Compute model: rate plus a constant per-batch term.
cmp_kargs = {"gflops": 94.86, "constant": 0.016}

In [67]:
model_flops

80476320

In [68]:
# Keep only the total predicted time; the per-batch fetch and compute times
# (first two tuple entries) are discarded.
_, _, pred_time = local_training_predictor(model_flops, data_kargs, fb_kargs, cmp_kargs)

In [69]:
pred_time

34.28552358073673

In [70]:
# Predicted-to-reference ratio for local training, computed with the same
# ratio() helper the other phases use.
ratio(pred_time, lt_real_value)

1.2812228542876205

## GIS Prediction

In [71]:
def gis_prediction():
    """Predict the GIS phase value from the module-level settings.

    The result is one transfer of ``model_size`` over the ``"s2c"`` bandwidth,
    plus that same transfer time scaled by ``(1 - recv_ratio)`` for each of
    the ``nr_worker - 1`` other workers.
    """
    single_transfer = model_size / bandwidth_dict["s2c"]
    remaining_ratio = 1 - recv_ratio
    other_workers = nr_worker - 1
    return single_transfer + single_transfer * remaining_ratio * other_workers

In [72]:
gis_prediction()

9.151711663559322

In [73]:
# GIS prediction divided by its reference value (1.0 would be an exact match).
ratio(gis_prediction(), gis_real_value)

1.056779637824402

## Local Training Prediction

## EMP

In [74]:
def emp_prediction():
    """Predict the EMP phase value from the module-level settings.

    Sums two terms: a transfer term, ``(nr_worker - 1) * (1 + recv_ratio)``
    times the time to move ``2 * model_size`` over the ``"c2c"`` bandwidth,
    and a multiplication term, ``mul_time_per_million`` per million parameters.
    """
    doubled_model_time = 2 * model_size / bandwidth_dict["c2c"]
    transfer_term = (nr_worker - 1) * (1 + recv_ratio) * doubled_model_time
    multiply_term = mul_time_per_million * (param_size / 1000000)
    return transfer_term + multiply_term

In [75]:
emp_prediction()

57.81517702593424

In [76]:
emp_real_value

58

In [77]:
# EMP prediction divided by its reference value (1.0 would be an exact match).
ratio(emp_prediction(), emp_real_value)

0.9968133969988662

## Aggregation prediction

In [78]:
def agg_prediction():
    """Predict the aggregation phase value from the module-level settings.

    Sums two terms: a command term, ``2 * nr_layer`` times
    ``(nr_worker - 1) * nr_worker * time_cmd``, and an addition term,
    ``(nr_worker - 1) * nr_worker`` times ``add_p`` per million parameters.
    """
    command_term = 2 * nr_layer * ((nr_worker - 1) * (nr_worker * time_cmd))
    addition_term = (nr_worker - 1) * nr_worker * (param_size / 1000000) * add_p
    return command_term + addition_term

In [79]:
# def agg_prediction():
#     return 2 * nr_layer * ((nr_worker - 1) * (nr_worker * time_cmd)) + (nr_worker - 1) *  (param_size/1000000) * add_p

In [80]:
# def agg_prediction():
#     return 2 * nr_layer * ((nr_worker - 1) * nr_worker * time_cmd) + (param_size/1000000) * add_time_per_million

In [81]:
agg_prediction()

0.21699718559999998

In [82]:
# Aggregation prediction divided by its reference value (1.0 would be an exact match).
ratio(agg_prediction(), agg_real_value)

0.9434660243478259

In [83]:
# Re-evaluates the same aggregation ratio as the preceding cell; the result is unchanged.
ratio(agg_prediction(), agg_real_value)

0.9434660243478259

## Decryption prediction

In [84]:
def dec_prediction():
    """Predict the decryption phase value from the module-level settings.

    Computes the time to move ``2 * model_size`` over the ``"c2w"`` bandwidth
    and multiplies it by ``nr_worker``.
    """
    per_worker_time = model_size * 2 / bandwidth_dict["c2w"]
    return nr_worker * per_worker_time

In [85]:
dec_prediction()

4.76791070967742

In [86]:
dec_real_value

5.166

In [87]:
# Decryption prediction divided by its reference value (1.0 would be an exact match).
ratio(dec_prediction(), dec_real_value)

0.9229405167784397

In [49]:
# Scratch check: 13794560 == 431080 * 4 * 8, the model_size of the commented-out
# "original" configuration. Twice that size is divided by 47 * 1024 * 1024
# (presumably a ~47 Mbps link using binary prefixes -- confirm).
(13794560 * 2) / (47 * 1024 * 1024)

0.559809258643617

In [50]:
# Scratch check: 0.5598 is the previous cell's result, rounded.
# NOTE(review): 0.9 presumably stands in for recv_ratio and the factor 2 for
# nr_worker - 1, mirroring emp_prediction() -- confirm.
(0.5598 + 0.5598 * 0.9) * 2

2.1272399999999996

In [51]:
# Same quantity as the previous cell, with 0.5598 factored out (1 + 0.9 = 1.9).
2 * 1.9 * 0.5598

2.1272399999999996

In [52]:
# Relative error of 2.127 (rounded result above) with respect to 2.6.
# NOTE(review): 2.6 is presumably a measured value; it does not match the
# emp_real_value (2.74) in the commented-out "original" configuration -- confirm.
(2.127 - 2.6) / 2.6

-0.18192307692307705

In [53]:
# Scratch check matching the first term of agg_prediction():
# 2 * nr_layer * ((nr_worker - 1) * (nr_worker * time_cmd)).
# NOTE(review): 8 presumably is 2 * nr_layer with nr_layer = 4 from the
# commented-out "original" configuration; 3 and 0.002 match nr_worker and
# time_cmd -- confirm.
8 * (2 * 3 * 0.002)

0.096

In [54]:
# Relative error of 0.096 (result above) with respect to 0.1, which equals
# agg_real_value in the commented-out "original" configuration.
(0.096 - 0.1) / 0.1

-0.040000000000000036

In [56]:
# Scratch check with the same shape as dec_prediction():
# nr_worker * (model_size * 2 / bandwidth), using the "original" model size
# (13794560) and a bandwidth of 441.39 * 1024 * 1024.
3 * (2 * 13794560 / (441.39 * 1024 * 1024))

0.17882848607523957

In [57]:
# Relative error of 0.1788 (rounded result above) with respect to 0.18.
# NOTE(review): 0.18 is presumably a rounded measured value (dec_real_value
# is 0.1774 in the commented-out "original" configuration) -- confirm.
(0.1788 - 0.18) / 0.18

-0.0066666666666667035

In [74]:
# Same dec_prediction()-shaped check, now with a bandwidth of 943 * 1024 * 1024.
3 * (2 * 13794560 / (943 * 1024 * 1024))

0.08370424758085895

In [75]:
# Relative error of 0.0837 (rounded result above) with respect to 0.18.
(0.0837 - 0.18) / 0.18

-0.535

In [77]:
# The 2 * 1.9 factor from the earlier scratch cell, applied to the transfer of
# twice the "original" model size over a bandwidth of 943 * 1024 * 1024.
2 * 1.9 * (2 * 13794560 / (943 * 1024 * 1024))

0.10602538026908802

In [78]:
# Relative error of 0.106 (rounded result above) with respect to 2.6.
(0.106 - 2.6) / 2.6

-0.9592307692307693