### 1. Setup and Initialization

In [None]:
# Comment out these commands if you are running this notebook locally and your environment is already set up
!pip3 install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ onnxruntime-gpu==1.17.0
!pip3 install ultralytics
!pip3 install ortools

### 2. Model Loading and IR conversion

In [None]:
from ultralytics import YOLO
import onnx


# Load the YOLO11 model from the "yolo11n.pt" weights file.
# NOTE: the name `model` is rebound to the ONNX model in the next cell,
# so the YOLO wrapper is only reachable within this cell.
model = YOLO("yolo11n.pt")

# Print the detailed summary of the wrapped network (the `model` attribute
# of the YOLO object).
print(model.model)

# Export the model to ONNX format; the later cells load the result from
# 'yolo11n.onnx'.
model.export(format="onnx")  # creates 'yolo11n.onnx'

In [None]:
# Load the ONNX model exported by the previous cell.
# NOTE: this rebinds `model`, which previously held the YOLO wrapper.
model = onnx.load("yolo11n.onnx")

# Check the model's structure and ensure it's well-formed
# (the checker is expected to raise if the model is invalid).
onnx.checker.check_model(model)

# Print a human-readable representation of the model graph
print(onnx.helper.printable_graph(model.graph))

In [None]:
# Number of nodes in the ONNX graph. The node container supports len(),
# so no manual counting loop is needed.
nodes = len(model.graph.node)

print("Number of nodes: ", nodes)

# Distinct operator types used by the graph.
operation_types = {node.op_type for node in model.graph.node}

print("\nOperations: \n")
# A set has no stable iteration order; sort it so the listing is
# reproducible from one run of the notebook to the next.
print("\n".join(sorted(operation_types)))

### 3. Device and Performance Modeling

In [None]:
# Share of the TDP that is assumed to be drawn as static power.
_STATIC_POWER_SHARE = 0.1

# Order in which the per-device constants below are unpacked.
_PROFILE_FIELDS = (
    "gflops",
    "tdp",
    "static_power",
    "efficiency",
    "pcie_bandwidth",
    "memory_bandwidth",
)


def _make_device_profile(gflops, tdp, pcie_bandwidth, memory_bandwidth):
    """Assemble one DEVICE_PROFILE entry.

    gflops            -- peak throughput in GFLOPS
    tdp               -- TDP in Watts
    pcie_bandwidth    -- PCIe bandwidth in GB/s
    memory_bandwidth  -- memory bandwidth in GB/s

    The two remaining fields are derived: "efficiency" is GFLOPS per Watt
    (gflops / tdp) and "static_power" is a fixed share of the TDP, in Watts.
    """
    return {
        "gflops": gflops,
        "tdp": tdp,
        "efficiency": gflops / tdp,
        "static_power": tdp * _STATIC_POWER_SHARE,
        "pcie_bandwidth": pcie_bandwidth,
        "memory_bandwidth": memory_bandwidth,
    }


def _profile_values(device_name):
    """Return the fields of DEVICE_PROFILE[device_name] in _PROFILE_FIELDS order."""
    entry = DEVICE_PROFILE[device_name]
    return tuple(entry[field] for field in _PROFILE_FIELDS)


DEVICE_PROFILE = {
    # Peak GFLOPS from datasheet
    "nvidia_L40s": _make_device_profile(
        91600, 350, pcie_bandwidth=32, memory_bandwidth=864
    ),
    "nvidia_tesla_T4": _make_device_profile(
        8100, 70, pcie_bandwidth=16, memory_bandwidth=300
    ),
    "nvidia_A30": _make_device_profile(
        10300, 165, pcie_bandwidth=32, memory_bandwidth=933
    ),
    "alveo_U55C_2_DPU": _make_device_profile(
        (8 / 4) * 1000, 115, pcie_bandwidth=16, memory_bandwidth=460
    ),
    "alveo_U55C_1_DPU": _make_device_profile(
        (5 / 4) * 1000, 115, pcie_bandwidth=16, memory_bandwidth=460
    ),
    # Peak GFLOPS estimated as frequency (in GHz) * ipc * cores
    "bluefield_3_16C": _make_device_profile(
        2.1 * 6 * 16, 150, pcie_bandwidth=64, memory_bandwidth=42
    ),
    "bluefield_2_8C": _make_device_profile(
        2.5 * 3 * 8, 75, pcie_bandwidth=16, memory_bandwidth=26
    ),
}

# Device slots D1..D7 and the profile each one is bound to.
d1_model = "nvidia_L40s"
d2_model = "nvidia_tesla_T4"
d3_model = "nvidia_A30"
d4_model = "alveo_U55C_2_DPU"
d5_model = "alveo_U55C_1_DPU"
d6_model = "bluefield_3_16C"
d7_model = "bluefield_2_8C"

# Flat per-slot constants used by the profiling code.
(D1_GFLOPS, D1_TDP, D1_STATIC_POWER,
 D1_EFFICIENCY, D1_PCIE_BWDTH, D1_MEMORY_BWDTH) = _profile_values(d1_model)

(D2_GFLOPS, D2_TDP, D2_STATIC_POWER,
 D2_EFFICIENCY, D2_PCIE_BWDTH, D2_MEMORY_BWDTH) = _profile_values(d2_model)

(D3_GFLOPS, D3_TDP, D3_STATIC_POWER,
 D3_EFFICIENCY, D3_PCIE_BWDTH, D3_MEMORY_BWDTH) = _profile_values(d3_model)

(D4_GFLOPS, D4_TDP, D4_STATIC_POWER,
 D4_EFFICIENCY, D4_PCIE_BWDTH, D4_MEMORY_BWDTH) = _profile_values(d4_model)

(D5_GFLOPS, D5_TDP, D5_STATIC_POWER,
 D5_EFFICIENCY, D5_PCIE_BWDTH, D5_MEMORY_BWDTH) = _profile_values(d5_model)

(D6_GFLOPS, D6_TDP, D6_STATIC_POWER,
 D6_EFFICIENCY, D6_PCIE_BWDTH, D6_MEMORY_BWDTH) = _profile_values(d6_model)

(D7_GFLOPS, D7_TDP, D7_STATIC_POWER,
 D7_EFFICIENCY, D7_PCIE_BWDTH, D7_MEMORY_BWDTH) = _profile_values(d7_model)


In [None]:
import numpy as np
import onnx.shape_inference


# Accumulator for the per-node profiling results.
#   "computing_time" / "energy_cost": one list per device slot (D1..D7),
#       each list gaining one value per profiled node.
#   "output_size": one value per profiled node.
DAG_PROFILE = {
    metric: {f"D{slot}": [] for slot in range(1, 8)}
    for metric in ("computing_time", "energy_cost")
}
DAG_PROFILE["output_size"] = []

def get_input_shapes(model):
    """Build a lookup table from tensor name to shape.

    Parameters
    ----------
    model
        An ONNX model exposing ``graph.node``, ``graph.input``,
        ``graph.value_info`` and ``graph.initializer``.

    Returns
    -------
    dict
        Maps a tensor name to its shape as a list of ints, filled from three
        sources, in this order:

        1. Placeholder entries: when a node with one of three known names is
           present, a fixed tensor name is given the shape of the first graph
           input.
        2. Every ``graph.value_info`` entry, keyed by its name.
        3. ``'output'``, taken from the dims of the initializer named
           ``'/model.23/Constant_12_output_0'`` when that initializer exists.

    Raises
    ------
    IndexError
        If a placeholder is needed but the graph declares no inputs.
    """
    graph = model.graph

    def _dims(value):
        # Shape of a value-info-like object as a plain list of ints.
        return [dim.dim_value for dim in value.type.tensor_type.shape.dim]

    # Node name -> tensor name that receives the first graph input's shape.
    # NOTE(review): presumably these tensors have no value_info entry of their
    # own and are registered so that lookups by node input name succeed -
    # confirm against the exported graph.
    placeholder_for_node = {
        "/model.0/conv/Conv": 'images',
        "/model.23/Sub": '/model.23/Constant_9_output_0',
        "/model.23/Add_1": '/model.23/Constant_10_output_0',
    }

    input_shapes = {}

    # The first graph input's shape does not depend on the node, so it is read
    # once, and only when a node actually requires it.
    first_input_shape = None
    for node in graph.node:
        tensor_name = placeholder_for_node.get(node.name)
        if tensor_name is None:
            continue
        if first_input_shape is None:
            first_input_shape = _dims(graph.input[0])
        # Each entry gets its own list so callers may mutate one safely.
        input_shapes[tensor_name] = list(first_input_shape)

    for value_info in graph.value_info:
        input_shapes[value_info.name] = _dims(value_info)

    for initializer in graph.initializer:
        if initializer.name == '/model.23/Constant_12_output_0':
            input_shapes['output'] = list(initializer.dims)

    return input_shapes


def node_profilling(node, flops, data_size):
    """Estimate, record and print the cost of one node on every device slot.

    Parameters
    ----------
    node
        Name of the node being profiled.
    flops
        Computing complexity of the node, in GFLOPs.
    data_size
        Size of the node's output data, in GB.

    For each slot D1..D7 the total time (ms) is the computing time plus the
    memory read/write overhead; the two nodes "/model.0/conv/Conv" and
    "/model.23/Concat_5" additionally pay the PCIe transfer overhead. The
    energy cost (J) is the total time multiplied by the slot's TDP.

    Side effects: appends one value per slot to
    DAG_PROFILE["computing_time"] and DAG_PROFILE["energy_cost"], appends
    `data_size` to DAG_PROFILE["output_size"], and prints a report.
    Returns None.
    """
    # (slot label, model name, GFLOPS, TDP, PCIe GB/s, memory GB/s);
    # the module-level constants are read at call time.
    devices = (
        ("D1", d1_model, D1_GFLOPS, D1_TDP, D1_PCIE_BWDTH, D1_MEMORY_BWDTH),
        ("D2", d2_model, D2_GFLOPS, D2_TDP, D2_PCIE_BWDTH, D2_MEMORY_BWDTH),
        ("D3", d3_model, D3_GFLOPS, D3_TDP, D3_PCIE_BWDTH, D3_MEMORY_BWDTH),
        ("D4", d4_model, D4_GFLOPS, D4_TDP, D4_PCIE_BWDTH, D4_MEMORY_BWDTH),
        ("D5", d5_model, D5_GFLOPS, D5_TDP, D5_PCIE_BWDTH, D5_MEMORY_BWDTH),
        ("D6", d6_model, D6_GFLOPS, D6_TDP, D6_PCIE_BWDTH, D6_MEMORY_BWDTH),
        ("D7", d7_model, D7_GFLOPS, D7_TDP, D7_PCIE_BWDTH, D7_MEMORY_BWDTH),
    )

    # Only these two nodes are charged the PCIe transfer.
    pays_pcie = node in ("/model.0/conv/Conv", "/model.23/Concat_5")

    # Total time per slot, in milliseconds.
    total_times = []
    for _label, _model_name, gflops, _tdp, pcie_bwdth, memory_bwdth in devices:
        # Intra-device communication overhead (one read plus one write).
        memory_overhead = ((data_size / memory_bwdth) * 2) * 1000
        pcie_overhead = (data_size / pcie_bwdth) * 1000
        computing_time = (flops / gflops) * 1000

        total = computing_time + memory_overhead
        if pays_pcie:
            total = total + pcie_overhead
        total_times.append(total)

    for (label, *_unused), total in zip(devices, total_times):
        DAG_PROFILE["computing_time"][label].append(total)

    # Energy cost per slot, in Joules.
    energy_costs = [
        (total / 1000) * device[3]
        for device, total in zip(devices, total_times)
    ]

    for (label, *_unused), energy in zip(devices, energy_costs):
        DAG_PROFILE["energy_cost"][label].append(energy)

    DAG_PROFILE["output_size"].append(data_size)

    separator = "========================================================="

    print("\n" + separator)

    print("\nNode {0} computing complexity: {1:.4f} GFLOPs".format(node, flops))

    for (label, model_name, *_unused), energy in zip(devices, energy_costs):
        print("Energy cost ({2} - {0}): {1:.4f} J".format(model_name, energy, label))

    for (label, model_name, *_unused), total in zip(devices, total_times):
        print("Computing time ({2} - {0}): {1:.4} ms".format(model_name, total, label))

    print("\n" + separator + "\n")


def estimate_flops(model):
    model = onnx.shape_inference.infer_shapes(model)
    input_shapes = get_input_shapes(model)
    
    total_flops = 0
    total_data = 0
    
    node_flops = {}

    for node in model.graph.node:
        print("Operation type: {0}".format(node.op_type))

        input_shape = input_shapes[node.input[0]]
        output_shape = input_shapes[node.output[0]]

        if node.op_type == "Conv":            
            c_in = input_shape[1]
            
            kernel_height = [attr.ints for attr in node.attribute if attr.name == 'kernel_shape'][0][0]
            kernel_width = [attr.ints for attr in node.attribute if attr.name == 'kernel_shape'][0][1]
            
            c_out = output_shape[1]
            h_out = output_shape[2]
            w_out = output_shape[3]

            macs = (c_out * h_out * w_out) * (c_in * kernel_height * kernel_width)
            flops = 2 * macs / (10 ** 9)

            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
            print("c_in: {0}, c_out: {1}, kernel_height: {2}".format(c_in, c_out, kernel_height))
            print(f"intermediate result data size: {data_size} GB")
            
            node_profilling(node.name, flops, data_size)
        elif node.op_type == "Sigmoid":
            if len(output_shape) > 3:                
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]

                flops = 4 * (c_out * h_out * w_out) / (10 ** 9)

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size

                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = 4 * (channels * length) / (10 ** 9)

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Mul":
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]

                flops = c_out * h_out * w_out / (10 ** 9)
                
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = channels * length / (10 ** 9)

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Add":
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]

                flops = c_out * h_out * w_out / (10 ** 9)
                
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = channels * length / (10 ** 9)

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")

                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Sub":
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]

                flops = c_out * h_out * w_out / (10 ** 9)
                
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = channels * length / (10 ** 9)

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Div":
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]

                flops = c_out * h_out * w_out / (10 ** 9)
                
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = channels * length / (10 ** 9)

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "MatMul":
            c_out = output_shape[1]
            h_out = output_shape[2]
            w_out = output_shape[3]

            macs = c_out * h_out * w_out
            flops = 2 * macs / (10 ** 9)

            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input, node.output[0], flops))
            print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
            print(f"intermediate result data size: {data_size} GB")
            
            node_profilling(node.name, flops, data_size)
        elif node.op_type == "MaxPool":
            c_out = output_shape[1]
            h_out = output_shape[2]
            w_out = output_shape[3]

            kernel_height = [attr.ints for attr in node.attribute if attr.name == 'kernel_shape'][0][0]
            kernel_width = [attr.ints for attr in node.attribute if attr.name == 'kernel_shape'][0][1]

            flops = (c_out * h_out * w_out) * (kernel_height * kernel_width - 1) / (10 ** 9)
            
            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
            print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
            print(f"intermediate result data size: {data_size} GB")
            
            node_profilling(node.name, flops, data_size)
        elif node.op_type == "Softmax":
            c_out = output_shape[1]
            h_out = output_shape[2]
            w_out = output_shape[3]

            output_elements = c_out * h_out * w_out
            elements_in_softmax_axis = output_shape[-1]
            num_slices = output_elements // elements_in_softmax_axis

            # 3 FLOPs per element in the softmax vector (exp, add, div)
            flops_per_slice = 3 * elements_in_softmax_axis
            
            flops = num_slices * flops_per_slice / (10 ** 9)

            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = output_elements * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
            print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
            print(f"intermediate result data size: {data_size} GB")
            
            node_profilling(node.name, flops, data_size)
        elif node.op_type == "Resize":
            for attr in node.attribute:
                if attr.name == "mode":
                    print(f"Resize mode: {attr.s.decode('utf-8')}")

            c_out = output_shape[1]
            h_out = output_shape[2]
            w_out = output_shape[3]

            output_elements = c_out * h_out * w_out

            # 1 FLOP per element (nearest neighbor interpolation mode)
            flops = 1 * output_elements / (10 ** 9)

            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = output_elements * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
            print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
            print(f"intermediate result data size: {data_size} GB")
            
            node_profilling(node.name, flops, data_size)
        elif node.op_type == "Split":            
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]
    
                flops = 0
    
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                
                data_size = data_size * 2 # Considering the data size of both splits
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = 0

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                
                data_size = data_size * 2 # Considering the data size of both splits
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Slice":
            channels = output_shape[1]
            length = output_shape[2]

            flops = 0

            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = (channels * length) * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
            print("c_out: {0}, W': {1}".format(channels, length))
            print(f"intermediate result data size: {data_size} GB")

            node_profilling(node.name, flops, data_size)
        elif node.op_type == "Concat":
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]
    
                flops = 0
    
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = 0

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Reshape":
            if len(output_shape) > 3:
                c_out = output_shape[1]
                h_out = output_shape[2]
                w_out = output_shape[3]
    
                flops = 0
    
                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
            else:
                channels = output_shape[1]
                length = output_shape[2]

                flops = 0

                total_flops += flops
                node_flops[node.name] = flops

                # Output elements multiplied by the numerical preicsion byte size (FP32)
                data_size = (channels * length) * 4 / (10 ** 9) # GB
                total_data += data_size
    
                print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
                print("c_out: {0}, W': {1}".format(channels, length))
                print(f"intermediate result data size: {data_size} GB")
                
                node_profilling(node.name, flops, data_size)
        elif node.op_type == "Transpose":
            c_out = output_shape[1]
            h_out = output_shape[2]
            w_out = output_shape[3]

            flops = 0

            total_flops += flops
            node_flops[node.name] = flops

            # Output elements multiplied by the numerical preicsion byte size (FP32)
            data_size = (c_out * h_out * w_out) * 4 / (10 ** 9) # GB
            total_data += data_size

            print("node: {0}, input: {1}, output: {2}, GFLOPs: {3:.4f}".format(node.name, node.input[0], node.output[0], flops))
            print("c_out: {0}, H': {1}, W': {2}".format(c_out, h_out, w_out))
            print(f"intermediate result data size: {data_size} GB")
            
            node_profilling(node.name, flops, data_size)

    print("\n=========================================================")
    
    print("\nDAG computing complexity: {0:.4f} GFLOPs".format(total_flops))

    print("Energy cost (D1 - {0}): {1:.4} J".format(d1_model, sum(DAG_PROFILE["energy_cost"]["D1"])))
    print("Energy cost (D2 - {0}): {1:.4} J".format(d2_model, sum(DAG_PROFILE["energy_cost"]["D2"])))
    print("Energy cost (D3 - {0}): {1:.4} J".format(d3_model, sum(DAG_PROFILE["energy_cost"]["D3"])))
    print("Energy cost (D4 - {0}): {1:.4} J".format(d4_model, sum(DAG_PROFILE["energy_cost"]["D4"])))
    print("Energy cost (D5 - {0}): {1:.4} J".format(d5_model, sum(DAG_PROFILE["energy_cost"]["D5"])))
    print("Energy cost (D6 - {0}): {1:.4} J".format(d6_model, sum(DAG_PROFILE["energy_cost"]["D6"])))
    print("Energy cost (D7 - {0}): {1:.4} J".format(d7_model, sum(DAG_PROFILE["energy_cost"]["D7"])))

    print("Computing time (D1 - {0}): {1:.4} ms".format(d1_model, sum(DAG_PROFILE["computing_time"]["D1"])))
    print("Computing time (D2 - {0}): {1:.4} ms".format(d2_model, sum(DAG_PROFILE["computing_time"]["D2"])))
    print("Computing time (D3 - {0}): {1:.4} ms".format(d3_model, sum(DAG_PROFILE["computing_time"]["D3"])))
    print("Computing time (D4 - {0}): {1:.4} ms".format(d4_model, sum(DAG_PROFILE["computing_time"]["D4"])))
    print("Computing time (D5 - {0}): {1:.4} ms".format(d5_model, sum(DAG_PROFILE["computing_time"]["D5"])))
    print("Computing time (D6 - {0}): {1:.4} ms".format(d6_model, sum(DAG_PROFILE["computing_time"]["D6"])))
    print("Computing time (D7 - {0}): {1:.4} ms".format(d7_model, sum(DAG_PROFILE["computing_time"]["D7"])))
    
    print("\n=========================================================\n")

    return node_flops
        
        
# Run the estimation over the loaded ONNX model; besides printing the per-device
# energy / computing-time totals, the call returns the per-node GFLOPs mapping.
node_flops = estimate_flops(model)

### 4. Optimization Problem Formulation

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def verify_solution(nodes, devices, node_energy, node_energy_acc, node_latency, node_flops, x, overall_latency_requirement, node_assignments, plot_summary=False, exp=0):
    """Summarize a node-to-device assignment and check it against the latency budget.

    Prints, for every device, how many nodes it received, the GFLOPs it computes,
    and its accumulated computing time and energy, followed by the totals over
    all devices. Optionally draws one bar chart per metric.

    Args:
        nodes: Sequence of node identifiers; position ``i`` is the node index used
            in ``x`` and in the per-device lists, the identifier itself is the key
            into ``node_assignments`` and ``node_flops``.
        devices: Sequence of device keys; position ``j`` is the device index used
            in ``x``. Each key must exist in the D1..D7 model-name mapping below.
        node_energy: ``node_energy[device][i]`` energy of node ``i`` on ``device``
            (printed as J).
        node_energy_acc: ``node_energy_acc[device][i]`` energy figure that is
            accumulated for the energy bar chart.
        node_latency: ``node_latency[device][i]`` computing time of node ``i`` on
            ``device`` (printed as ms).
        node_flops: Mapping node identifier -> GFLOPs; nodes missing from it are
            skipped when accumulating GFLOPs.
        x: Mapping ``(i, j)`` -> solved binary variable exposing ``solution_value()``.
        overall_latency_requirement: Upper bound each device's accumulated
            computing time is compared against.
        node_assignments: Mapping node identifier -> assigned device key.
        plot_summary: When truthy, draw the summary bar charts via ``plot_chart``.
        exp: Experiment identifier forwarded to ``plot_chart``.

    Returns:
        None. All results are printed (and optionally plotted).
    """
    # MIP solvers report binary variables as floats (e.g. 1e-9 or 0.999999), so a
    # variable counts as "selected" only above 0.5; a plain `> 0` test could
    # attribute a node to a device it was not assigned to.
    selected_threshold = 0.5

    # Per-device accumulators
    device_node_count = {device: 0 for device in devices}
    flops_per_device = {device: 0 for device in devices}
    total_latency = {device: 0 for device in devices}
    total_energy = {device: 0 for device in devices}
    total_acc_energy = {device: 0 for device in devices}

    # Device key -> hardware model name (module-level d1_model .. d7_model)
    device_type_to_model = {
        "D1": d1_model,
        "D2": d2_model,
        "D3": d3_model,
        "D4": d4_model,
        "D5": d5_model,
        "D6": d6_model,
        "D7": d7_model
    }

    # Node counts follow node_assignments; time and energy follow the solved x variables
    for i, node in enumerate(nodes):
        assigned_device = node_assignments[node]
        device_node_count[assigned_device] += 1

        for j, device in enumerate(devices):
            if x[(i, j)].solution_value() > selected_threshold:
                total_latency[device] += node_latency[device][i]
                total_energy[device] += node_energy[device][i]
                total_acc_energy[device] += node_energy_acc[device][i]

    for node, device in node_assignments.items():
        if node in node_flops:
            flops_per_device[device] += node_flops[node]

    # Check latency requirements (per device, not summed over devices)
    meets_latency_requirements = all(latency <= overall_latency_requirement for latency in total_latency.values())

    # Print results
    print("Meets Latency Requirements:", meets_latency_requirements)

    print("\nSolution Verification Summary:\n")
    for device in devices:
        print(f" Device: {device_type_to_model[device]}")
        print(f" Number of Nodes Assigned: {device_node_count[device]}")
        print(f" GFLOPs computed: {flops_per_device[device]:.4f}")
        print(f" Total computing time: {total_latency[device]:.4f} ms")
        print(f" Total energy cost: {total_energy[device]:.4f} J")
        print("-" * 60)

    print("\nTotal Latency: {0:.4f} milliseconds".format(sum(total_latency.values())))
    print("Total Energy: {0:.4f} J".format(sum(total_energy.values())))

    if plot_summary:
        # One bar chart per metric
        plot_chart(flops_per_device, "FLOPs computed per device", "Devices", "FLOPs", exp)
        plot_chart(total_latency, "Total computing time per device", "Devices", "Computing time (ms)", exp)
        plot_chart(total_acc_energy, "Total computing energy per device", "Devices", "Energy (J)", exp)
    
    
def plot_chart(data, title, xlabel, ylabel, exp):
    """Draw a labelled bar chart with one bar per entry of ``data``.

    Args:
        data: Mapping label -> numeric value; keys become the x tick labels and
            values the bar heights (insertion order is kept).
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        exp: Experiment identifier; only referenced by the disabled ``savefig``
            call below, so it currently has no effect.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    labels = list(data.keys())
    values = list(data.values())

    # One colour per bar, sampled evenly from the 'magma' colormap. The count is
    # taken from the data being plotted rather than from a global device list, so
    # the function does not depend on notebook state and always matches the bars.
    cmap = plt.get_cmap('magma')
    colors = list(cmap(np.linspace(0.1, 0.9, len(labels))))

    plt.figure(figsize=(12,6))

    rects = plt.bar(labels, values, color=colors)
    plt.bar_label(rects, padding=3)
    sns.despine()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.title(title, fontsize=16, pad=30)

    #plt.savefig("{0}_e{1}.pdf".format(title, exp), dpi=600, bbox_inches='tight', pad_inches=0.1)
    plt.show()

In [None]:
from ortools.linear_solver import pywraplp
from collections import defaultdict


def extract_dag_edges(model):
    """Return the producer -> consumer dependencies of a graph as index pairs.

    Args:
        model: Object exposing ``model.graph.node``, a sequence of nodes that each
            carry ``input`` and ``output`` name sequences.

    Returns:
        List of ``(parent, child)`` tuples of positions in ``model.graph.node``,
        ordered by child position and then by the child's input order. One tuple
        is emitted per matching input, so repeated inputs yield repeated edges.
        Inputs that no node produces are ignored.
    """
    graph_nodes = model.graph.node

    # Tensor name -> position of the node emitting it (a later producer of the
    # same name replaces an earlier one).
    producer_of = {
        tensor: position
        for position, producer in enumerate(graph_nodes)
        for tensor in producer.output
    }

    return [
        (producer_of[tensor], position)
        for position, consumer in enumerate(graph_nodes)
        for tensor in consumer.input
        if tensor in producer_of
    ]


def nested_dict():
    """Create an empty ``defaultdict`` whose missing keys default to a new ``dict``."""
    inner_factory = dict
    return defaultdict(inner_factory)


def estimate_communication_latencies(devices, edges):
    """Tabulate the transfer latency of every DAG edge for every ordered device pair.

    Args:
        devices: Device keys; each must be one of the keys of the D1..D7 mapping
            to the module-level model names.
        edges: Iterable of ``(parent, child)`` node-index pairs.

    Returns:
        Nested ``defaultdict`` indexed as ``table[device_from][device_to][parent][child]``.
        The value is 0 when both devices are the same, otherwise the parent's
        ``DAG_PROFILE["output_size"]`` entry divided by the smaller of the two
        devices' ``pcie_bandwidth`` and multiplied by 1000 (presumably GB over
        GB/s converted to milliseconds; confirm the unit of ``output_size``).
    """
    device_type_to_model = {
        "D1": d1_model, "D2": d2_model, "D3": d3_model, "D4": d4_model,
        "D5": d5_model, "D6": d6_model, "D7": d7_model,
    }

    # Link bandwidth of each requested device, resolved once up front
    link_bandwidth = {
        device: DEVICE_PROFILE[device_type_to_model[device]]["pcie_bandwidth"]
        for device in devices
    }

    comm_latency = defaultdict(lambda: defaultdict(nested_dict))

    for source in devices:
        for target in devices:
            # A transfer is limited by the slower endpoint
            bandwidth = min(link_bandwidth[source], link_bandwidth[target])
            same_device = source == target

            for parent, child in edges:
                payload = DAG_PROFILE["output_size"][parent]

                if same_device:
                    latency = 0
                else:
                    latency = (payload / bandwidth) * 1000

                comm_latency[source][target][parent][child] = latency

    return comm_latency


def graph_partition_S1(model, nodes, devices, node_energy, node_flops, node_latency, overall_latency_requirement, device_capacity, device_utilization, static_energy, plot_summary=True, exp=0):
    """Stage 1 partitioning: assign every node to one device at minimum total energy.

    Builds a mixed-integer program (SCIP through OR-Tools) with one binary
    variable ``x[(i, j)]`` per node/device pair and one continuous start time
    ``s[i]`` per node, subject to:

    * every node runs on exactly one device;
    * the GFLOPs placed on a device fit into its free capacity over the latency budget;
    * a child starts only after its parent finished, plus the transfer latency
      when both sit on different devices;
    * every node completes within ``overall_latency_requirement``;
    * two nodes with no dependency path between them never overlap on one device.

    Args:
        model: Graph handed to ``extract_dag_edges``. The returned edges are index
            pairs, so ``nodes`` is assumed to follow the order of
            ``model.graph.node`` (TODO confirm against the caller).
        nodes: Node names; position ``i`` indexes ``x``, ``s`` and the per-device
            lists, the name itself indexes ``node_flops``.
        devices: Device keys; position ``j`` indexes ``x``.
        node_energy: ``node_energy[device][i]`` dynamic energy of node ``i``.
        node_flops: Mapping node name -> GFLOPs.
        node_latency: ``node_latency[device][i]`` computing time in milliseconds
            (divided by 1000 to obtain seconds for the static energy term).
        overall_latency_requirement: Latency budget in the unit of
            ``node_latency``; also the upper bound of every start time.
        device_capacity: Mapping device -> capacity in GFLOPS.
        device_utilization: Mapping device -> fraction already in use; a value of
            1.0 raises ``ZeroDivisionError``.
        static_energy: Mapping device -> static power, multiplied by the node run
            time in seconds.
        plot_summary: Forwarded to ``verify_solution``.
        exp: Experiment identifier forwarded to ``verify_solution``.

    Returns:
        When the solver proves optimality, the tuple
        ``(objective_value, MAX_TIME, START_TIME, END_TIME, summary)`` where
        ``START_TIME`` / ``END_TIME`` map node index -> time, ``MAX_TIME`` is the
        latest completion time and ``summary`` is the list ``[nodes, devices,
        node_energy, node_energy_acc, node_latency, node_flops, x_list,
        overall_latency_requirement, node_assignments]``. Otherwise ``None``.

    Raises:
        RuntimeError: If OR-Tools cannot create the SCIP backend.
        AssertionError: If a post-solve sanity check finds a violated precedence
            or non-parallelism constraint.
    """
    solver = pywraplp.Solver.CreateSolver('SCIP')
    if solver is None:
        # CreateSolver returns None when the backend is not compiled in
        raise RuntimeError("OR-Tools could not create the 'SCIP' solver backend.")
    solver.EnableOutput()

    solver.set_time_limit(86400000)  # 24 hours, expressed in milliseconds

    # Numerical tolerance used by the post-solve sanity checks
    CHECK_TOL = 1e-4

    def is_selected(var):
        # Binary variables come back as floats (1e-9, 0.999999, ...); exact
        # comparisons against 0 or 1 would misclassify them.
        return var.solution_value() > 0.5

    available_capacity = {device: device_capacity[device] * (1.0 - device_utilization[device]) for device in devices}
    adjusted_node_latency = {}

    # A busy device is slower: stretch each latency by 1 / (1 - utilization)
    for device in devices:
        adjusted_node_latency[device] = [node_latency[device][i] / (1.0 - device_utilization[device]) for i in range(len(node_latency[device]))]

    T_s = overall_latency_requirement / 1000

    # Variables: x[i][j] represents if node i is assigned to device j (binary)
    x = {}

    for i in range(len(nodes)):
        for j in range(len(devices)):
            x[(i, j)] = solver.BoolVar(f'x_{i}_{j}')

    # Objective: Minimize total energy spent (dynamic energy + static energy)
    objective = solver.Objective()

    for i in range(len(nodes)):
        for j in range(len(devices)):
            # Scaling both dynamic and static energy by the device utilization.
            # NOTE(review): the ratio below divides by node_latency; a zero entry
            # would raise ZeroDivisionError.
            dynamic_energy = node_energy[devices[j]][i] * (adjusted_node_latency[devices[j]][i] / node_latency[devices[j]][i])
            static_energy_per_node = (adjusted_node_latency[devices[j]][i] / 1000) * static_energy[devices[j]]

            objective.SetCoefficient(x[(i, j)], dynamic_energy + static_energy_per_node)

    objective.SetMinimization()

    # Consistency constraint: each node must be assigned to exactly one device
    for i in range(len(nodes)):
        solver.Add(sum(x[(i, j)] for j in range(len(devices))) == 1)

    # Device capacity constraint: the workload assigned to each device (in GFLOPs) should
    # not surpass its computational capacity (in GFLOPS)
    for j, device in enumerate(devices):
        solver.Add(sum(x[(i, j)] * node_flops[node] for i, node in enumerate(nodes)) <= (available_capacity[device] * T_s))

    # Node precedence constraint: if there are dependencies between nodes, the child node
    # processing must not start before its parent node finishes

    # Defining start time variables
    s = [solver.NumVar(0, overall_latency_requirement, f"s_{i}") for i in range(len(nodes))]

    # Big-M: sum over nodes of the slowest per-node latency.
    # NOTE(review): this bound ignores transfer latencies and the start-time upper
    # bound (overall_latency_requirement); if either exceeds M, the "switched off"
    # constraints below still restrict the schedule. Confirm M is large enough.
    M = sum(max(adjusted_node_latency[device][i] for device in devices) for i in range(len(nodes)))

    # Node dependencies
    edges = extract_dag_edges(model)

    # Inter-device communication overhead
    comm_latency = estimate_communication_latencies(devices, edges)

    for (parent, child) in edges:
        for j, device_p in enumerate(devices):
            for k, device_c in enumerate(devices):
                # transfer penalty if parent -> child crosses devices
                transfer = (0 if device_p == device_c else comm_latency[device_p][device_c][parent][child])

                # activate only when parent -> device_p AND child -> device_c
                solver.Add(s[child] >= s[parent] + adjusted_node_latency[device_p][parent] + transfer - M * (2 - x[(parent, j)] - x[(child, k)]))

    # Global latency constraint: the total completion time (including intra- and inter-device comm. overhead)
    # should meet the overall latency requirement
    for i in range(len(nodes)):
        # completion = s[i] + its compute time on whichever device it's assigned
        completion_time = s[i] + sum(x[(i, j)] * adjusted_node_latency[devices[j]][i] for j in range(len(devices)))
        solver.Add(completion_time <= overall_latency_requirement)

    # Non-Parallelism constraint: each device can only process one node at time

    # Building a representation of the DAG reachability
    adj = {i: [] for i in range(len(nodes))}
    for parent, child in edges:
        adj[parent].append(child)

    # For each node, find all other nodes it can reach (its descendants).
    # The traversal order does not matter for the resulting set, so the frontier
    # is used as a stack (O(1) pop) instead of popping from the front of a list.
    reachable = {i: set() for i in range(len(nodes))}
    for i in range(len(nodes)):
        frontier = [i]
        visited = {i}
        while frontier:
            u = frontier.pop()
            for v in adj.get(u, []):
                if v not in visited:
                    visited.add(v)
                    reachable[i].add(v)
                    frontier.append(v)

    # Identifying pairs (i, j) that are independent
    independent_pairs = []
    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            # If i cannot reach j, AND j cannot reach i, they are independent.
            if j not in reachable[i] and i not in reachable[j]:
                independent_pairs.append((i, j))

    # Strict-separation margin between two nodes sharing a device
    EPS = 1e-6

    # Only the independent pairs need an explicit ordering decision; dependent
    # pairs are already ordered by the precedence constraints.
    for i, j in independent_pairs:
        for d_idx, device in enumerate(devices):
            li = adjusted_node_latency[device][i]
            lj = adjusted_node_latency[device][j]

            # This binary variable decides the order ONLY for this independent pair.
            b = solver.BoolVar(f"order_{i}_{j}_on_{device}")

            # Constraint 1: If i is before j (b=1)
            # The term (3 - x_i - x_j - b) becomes 0 only if x_i=1, x_j=1, and b=1.
            # Otherwise, the term is >= 1, making the constraint trivial.
            solver.Add(s[i] + li + EPS <= s[j] + M * (3 - x[(i, d_idx)] - x[(j, d_idx)] - b))

            # Constraint 2: If j is before i (b=0)
            # The term (2 - x_i - x_j + b) becomes 0 only if x_i=1, x_j=1, and b=0.
            # Otherwise, the term is >= 1, making the constraint trivial.
            solver.Add(s[j] + lj + EPS <= s[i] + M * (2 - x[(i, d_idx)] - x[(j, d_idx)] + b))

    # Solve the problem
    status = solver.Solve()

    START_TIME = {}
    END_TIME = {}
    MAX_TIME = 0

    # Variable values are only meaningful when the solver holds a solution
    has_solution = status in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE)

    if has_solution:
        for i in range(len(nodes)):
            node_completion = s[i].solution_value() + sum(x[(i, j)].solution_value() * adjusted_node_latency[devices[j]][i] for j in range(len(devices)))
            START_TIME[i] = s[i].solution_value()
            END_TIME[i] = node_completion

            if node_completion > MAX_TIME:
                MAX_TIME = node_completion

        print(f"\nNode processing MAX TIME:{MAX_TIME:.3f}")

    node_assignments = {}
    x_list = []

    if status == pywraplp.Solver.OPTIMAL:
        print('Solution:')

        for i in range(len(nodes)):
            for j in range(len(devices)):
                if is_selected(x[(i, j)]):
                    node_assignments[nodes[i]] = devices[j]
                    x_list.append((i, j))
                    print(f'Node {nodes[i]} assigned to {devices[j]}')

        print('Objective value =', objective.Value())

        print("\nPrecedence constraint sanity check:")
        for (parent, child) in edges:
            for j, device_p in enumerate(devices):
                for k, device_c in enumerate(devices):
                    if is_selected(x[(parent, j)]) and is_selected(x[(child, k)]):
                        start_p = s[parent].solution_value()
                        start_c = s[child].solution_value()

                        latency = adjusted_node_latency[device_p][parent]

                        transfer = 0 if device_p == device_c else comm_latency[device_p][device_c][parent][child]
                        expected_min_start = start_p + latency + transfer

                        print(f"\nEdge {parent} → {child} across {devices[j]} → {devices[k]}")
                        print(f"  s[{parent}] = {start_p:.3f}, s[{child}] = {start_c:.3f}, must be ≥ {expected_min_start:.3f}")

                        assert start_c + CHECK_TOL >= expected_min_start, "Constraint violated!"

        print("\nNon-parallelism constraint sanity check:")
        for i, j in independent_pairs:
            for d_idx, device in enumerate(devices):
                if is_selected(x[(i, d_idx)]) and is_selected(x[(j, d_idx)]):
                    s_i = s[i].solution_value()
                    s_j = s[j].solution_value()
                    li = adjusted_node_latency[device][i]
                    lj = adjusted_node_latency[device][j]

                    finish_i = s_i + li
                    finish_j = s_j + lj

                    # Same tolerance as the precedence check, so solver round-off
                    # does not trigger a spurious assertion failure
                    overlap = not (finish_i <= s_j + CHECK_TOL or finish_j <= s_i + CHECK_TOL)

                    print(f"\nDevice {device}: Node {i} [{s_i:.3f}, {finish_i:.3f}) "
                          f"vs Node {j} [{s_j:.3f}, {finish_j:.3f})")

                    if overlap:
                        print("Overlap detected!")
                    else:
                        print("No overlap — constraint respected")

                    assert not overlap, f"Nodes {i} and {j} overlap on device {device}"

        # Unscaled (utilization-free) energy per node and device. The D1..D7 keys
        # are always present; any other device key in `devices` is added so the
        # fill loop below cannot hit a KeyError.
        node_energy_acc = {f"D{k}": [] for k in range(1, 8)}
        for device in devices:
            node_energy_acc.setdefault(device, [])

        for i in range(len(nodes)):
            for j in range(len(devices)):
                dynamic_energy = node_energy[devices[j]][i]
                static_energy_per_node = (node_latency[devices[j]][i] / 1000) * static_energy[devices[j]]

                node_energy_acc[devices[j]].append(dynamic_energy + static_energy_per_node)

        print("\n=========================================================\n")
        verify_solution(nodes, devices, node_energy, node_energy_acc, node_latency,
                        node_flops, x, overall_latency_requirement, node_assignments, plot_summary, exp)
        print("\n=========================================================\n")

        return objective.Value(), MAX_TIME, START_TIME, END_TIME, [nodes, devices, node_energy, node_energy_acc, node_latency,
                                                                   node_flops, x_list, overall_latency_requirement, node_assignments]
    else:
        print('The problem does not have an optimal solution.')


In [None]:
def graph_partition_S2(model, nodes, devices, node_energy, node_flops, node_latency, overall_latency_requirement, device_capacity, device_utilization, static_energy, opt_obj, plot_summary=True, exp=0):
    """Stage 2 partitioning: minimise the completion time at a fixed energy level.

    Builds a mixed-integer program (SCIP through OR-Tools) with one binary
    variable ``x[(i, j)]`` per node/device pair, one continuous start time
    ``s[i]`` per node and a ``completion_time`` variable that bounds every node's
    finish time from above and is minimised. Constraints:

    * the total (dynamic + static) energy equals ``opt_obj``;
    * every node runs on exactly one device;
    * the GFLOPs placed on a device fit into its free capacity over the latency budget;
    * a child starts only after its parent finished, plus the transfer latency
      when both sit on different devices;
    * two nodes with no dependency path between them never overlap on one device.

    Args:
        model: Graph handed to ``extract_dag_edges``. The returned edges are index
            pairs, so ``nodes`` is assumed to follow the order of
            ``model.graph.node`` (TODO confirm against the caller).
        nodes: Node names; position ``i`` indexes ``x``, ``s`` and the per-device
            lists, the name itself indexes ``node_flops``.
        devices: Device keys; position ``j`` indexes ``x``.
        node_energy: ``node_energy[device][i]`` dynamic energy of node ``i``.
        node_flops: Mapping node name -> GFLOPs.
        node_latency: ``node_latency[device][i]`` computing time in milliseconds.
        overall_latency_requirement: Latency budget in the unit of
            ``node_latency``; used as the upper bound of every start time and in
            the capacity constraint.
        device_capacity: Mapping device -> capacity in GFLOPS.
        device_utilization: Mapping device -> fraction already in use; a value of
            1.0 raises ``ZeroDivisionError``.
        static_energy: Mapping device -> static power, multiplied by the node run
            time in seconds.
        opt_obj: Energy value the assignment must reproduce exactly (per the
            inline comment, the optimum found by Stage 1).
        plot_summary: Forwarded to ``verify_solution``.
        exp: Experiment identifier forwarded to ``verify_solution``.

    Returns:
        When the solver reports an optimal or feasible solution, the tuple
        ``(objective_value, MAX_TIME, START_TIME, END_TIME, summary)`` where
        ``START_TIME`` / ``END_TIME`` map node index -> time, ``MAX_TIME`` is the
        latest completion time and ``summary`` is the list ``[nodes, devices,
        node_energy, node_energy_acc, node_latency, node_flops, x_list,
        overall_latency_requirement, node_assignments]``. Otherwise ``None``.

    Raises:
        RuntimeError: If OR-Tools cannot create the SCIP backend.
        AssertionError: If a post-solve sanity check finds a violated precedence
            or non-parallelism constraint.
    """
    solver = pywraplp.Solver.CreateSolver('SCIP')
    if solver is None:
        # CreateSolver returns None when the backend is not compiled in
        raise RuntimeError("OR-Tools could not create the 'SCIP' solver backend.")
    solver.EnableOutput()

    solver.set_time_limit(86400000)  # 24 hours, expressed in milliseconds

    # Numerical tolerance used by the post-solve sanity checks
    CHECK_TOL = 1e-4

    def is_selected(var):
        # Binary variables come back as floats (1e-9, 0.999999, ...); exact
        # comparisons against 0 or 1 would misclassify them.
        return var.solution_value() > 0.5

    available_capacity = {device: device_capacity[device] * (1.0 - device_utilization[device]) for device in devices}
    adjusted_node_latency = {}

    # A busy device is slower: stretch each latency by 1 / (1 - utilization)
    for device in devices:
        adjusted_node_latency[device] = [node_latency[device][i] / (1.0 - device_utilization[device]) for i in range(len(node_latency[device]))]

    T_s = overall_latency_requirement / 1000

    # Variables: x[i][j] represents if node i is assigned to device j (binary)
    x = {}

    for i in range(len(nodes)):
        for j in range(len(devices)):
            x[(i, j)] = solver.BoolVar(f'x_{i}_{j}')

    # Defining start time variables
    s = [solver.NumVar(0, overall_latency_requirement, f"s_{i}") for i in range(len(nodes))]

    # Objective: Minimize completion time
    objective = solver.Objective()

    completion_time = solver.NumVar(0, solver.infinity(), 'completion_time')

    # completion_time dominates s[i] + latency on the chosen device; for the
    # devices not chosen (x = 0) the bound reduces to completion_time >= s[i].
    for i in range(len(nodes)):
        for j in range(len(devices)):
            expr = s[i] + x[(i, j)] * adjusted_node_latency[devices[j]][i]
            solver.Add(completion_time >= expr)

    objective.SetCoefficient(completion_time, 1)
    objective.SetMinimization()

    # Energy expression: total energy consumed across all node-device assignments
    energy_terms = []

    for i in range(len(nodes)):
        for j in range(len(devices)):
            # Scaling both dynamic and static energy by the device utilization.
            # NOTE(review): the ratio below divides by node_latency; a zero entry
            # would raise ZeroDivisionError.
            dynamic = node_energy[devices[j]][i] * (adjusted_node_latency[devices[j]][i] / node_latency[devices[j]][i])
            static_per_node = (adjusted_node_latency[devices[j]][i] / 1000) * static_energy[devices[j]] # Convert ms to seconds

            energy = x[(i, j)] * (dynamic + static_per_node)
            energy_terms.append(energy)

    # Constraint: total energy must match optimal value from Stage 1
    energy_expr = solver.Sum(energy_terms)
    solver.Add(energy_expr == opt_obj)

    # Consistency constraint: each node must be assigned to exactly one device
    for i in range(len(nodes)):
        solver.Add(sum(x[(i, j)] for j in range(len(devices))) == 1)

    # Device capacity constraint: the workload assigned to each device (in GFLOPs) should
    # not surpass its computational capacity (in GFLOPS)
    for j, device in enumerate(devices):
        solver.Add(sum(x[(i, j)] * node_flops[node] for i, node in enumerate(nodes)) <= (available_capacity[device] * T_s))

    # Node precedence constraint: if there are dependencies between nodes, the child node
    # processing must not start before its parent node finishes

    # Big-M: sum over nodes of the slowest per-node latency.
    # NOTE(review): this bound ignores transfer latencies and the start-time upper
    # bound (overall_latency_requirement); if either exceeds M, the "switched off"
    # constraints below still restrict the schedule. Confirm M is large enough.
    M = sum(max(adjusted_node_latency[device][i] for device in devices) for i in range(len(nodes)))

    # Node dependencies
    edges = extract_dag_edges(model)

    # Inter-device communication overhead
    comm_latency = estimate_communication_latencies(devices, edges)

    for (parent, child) in edges:
        for j, device_p in enumerate(devices):
            for k, device_c in enumerate(devices):
                # transfer penalty if parent -> child crosses devices
                transfer = (0 if device_p == device_c else comm_latency[device_p][device_c][parent][child])

                # activate only when parent -> device_p AND child -> device_c
                solver.Add(s[child] >= s[parent] + adjusted_node_latency[device_p][parent] + transfer - M * (2 - x[(parent, j)] - x[(child, k)]))

    # Non-Parallelism constraint: each device can only process one node at time

    # Building a representation of the DAG reachability
    adj = {i: [] for i in range(len(nodes))}
    for parent, child in edges:
        adj[parent].append(child)

    # For each node, find all other nodes it can reach (its descendants).
    # The traversal order does not matter for the resulting set, so the frontier
    # is used as a stack (O(1) pop) instead of popping from the front of a list.
    reachable = {i: set() for i in range(len(nodes))}
    for i in range(len(nodes)):
        frontier = [i]
        visited = {i}
        while frontier:
            u = frontier.pop()
            for v in adj.get(u, []):
                if v not in visited:
                    visited.add(v)
                    reachable[i].add(v)
                    frontier.append(v)

    # Identifying pairs (i, j) that are independent
    independent_pairs = []
    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            # If i cannot reach j, AND j cannot reach i, they are independent.
            if j not in reachable[i] and i not in reachable[j]:
                independent_pairs.append((i, j))

    # Strict-separation margin between two nodes sharing a device
    EPS = 1e-6

    # Only the independent pairs need an explicit ordering decision; dependent
    # pairs are already ordered by the precedence constraints.
    for i, j in independent_pairs:
        for d_idx, device in enumerate(devices):
            li = adjusted_node_latency[device][i]
            lj = adjusted_node_latency[device][j]

            # This binary variable decides the order ONLY for this independent pair.
            b = solver.BoolVar(f"order_{i}_{j}_on_{device}")

            # Constraint 1: If i is before j (b=1)
            # The term (3 - x_i - x_j - b) becomes 0 only if x_i=1, x_j=1, and b=1.
            # Otherwise, the term is >= 1, making the constraint trivial.
            solver.Add(s[i] + li + EPS <= s[j] + M * (3 - x[(i, d_idx)] - x[(j, d_idx)] - b))

            # Constraint 2: If j is before i (b=0)
            # The term (2 - x_i - x_j + b) becomes 0 only if x_i=1, x_j=1, and b=0.
            # Otherwise, the term is >= 1, making the constraint trivial.
            solver.Add(s[j] + lj + EPS <= s[i] + M * (2 - x[(i, d_idx)] - x[(j, d_idx)] + b))

    # Solve the problem
    status = solver.Solve()

    START_TIME = {}
    END_TIME = {}
    MAX_TIME = 0

    node_assignments = {}
    x_list = []

    # Variable values are only meaningful when the solver holds a solution
    if status == pywraplp.Solver.OPTIMAL or status == pywraplp.Solver.FEASIBLE:
        for i in range(len(nodes)):
            node_completion = s[i].solution_value() + sum(x[(i, j)].solution_value() * adjusted_node_latency[devices[j]][i] for j in range(len(devices)))
            START_TIME[i] = s[i].solution_value()
            END_TIME[i] = node_completion

            if node_completion > MAX_TIME:
                MAX_TIME = node_completion

        print(f"\nNode processing MAX TIME:{MAX_TIME:.3f}")

        print('Solution:')

        for i in range(len(nodes)):
            for j in range(len(devices)):
                if is_selected(x[(i, j)]):
                    x_list.append((i, j))
                    node_assignments[nodes[i]] = devices[j]
                    print(f'Node {nodes[i]} assigned to {devices[j]}')

        print('Objective value =', objective.Value())

        print("\nPrecedence constraint sanity check:")
        for (parent, child) in edges:
            for j, device_p in enumerate(devices):
                for k, device_c in enumerate(devices):
                    if is_selected(x[(parent, j)]) and is_selected(x[(child, k)]):
                        start_p = s[parent].solution_value()
                        start_c = s[child].solution_value()

                        latency = adjusted_node_latency[device_p][parent]

                        transfer = 0 if device_p == device_c else comm_latency[device_p][device_c][parent][child]
                        expected_min_start = start_p + latency + transfer

                        print(f"\nEdge {parent} → {child} across {devices[j]} → {devices[k]}")
                        print(f"  s[{parent}] = {start_p:.3f}, s[{child}] = {start_c:.3f}, must be ≥ {expected_min_start:.3f}")

                        assert start_c + CHECK_TOL >= expected_min_start, "Constraint violated!"

        print("\nNon-parallelism constraint sanity check:")
        for i, j in independent_pairs:
            for d_idx, device in enumerate(devices):
                if is_selected(x[(i, d_idx)]) and is_selected(x[(j, d_idx)]):
                    s_i = s[i].solution_value()
                    s_j = s[j].solution_value()
                    li = adjusted_node_latency[device][i]
                    lj = adjusted_node_latency[device][j]

                    finish_i = s_i + li
                    finish_j = s_j + lj

                    # Same tolerance as the precedence check, so solver round-off
                    # does not trigger a spurious assertion failure
                    overlap = not (finish_i <= s_j + CHECK_TOL or finish_j <= s_i + CHECK_TOL)

                    print(f"\nDevice {device}: Node {i} [{s_i:.3f}, {finish_i:.3f}) "
                          f"vs Node {j} [{s_j:.3f}, {finish_j:.3f})")

                    if overlap:
                        print("Overlap detected!")
                    else:
                        print("No overlap — constraint respected")

                    assert not overlap, f"Nodes {i} and {j} overlap on device {device}"

        # Unscaled (utilization-free) energy per node and device. The D1..D7 keys
        # are always present; any other device key in `devices` is added so the
        # fill loop below cannot hit a KeyError.
        node_energy_acc = {f"D{k}": [] for k in range(1, 8)}
        for device in devices:
            node_energy_acc.setdefault(device, [])

        for i in range(len(nodes)):
            for j in range(len(devices)):
                dynamic_energy = node_energy[devices[j]][i]
                static_energy_per_node = (node_latency[devices[j]][i] / 1000) * static_energy[devices[j]]

                node_energy_acc[devices[j]].append(dynamic_energy + static_energy_per_node)

        print("\n=========================================================\n")
        verify_solution(nodes, devices, node_energy, node_energy_acc, node_latency,
                        node_flops, x, overall_latency_requirement, node_assignments, plot_summary, exp)
        print("\n=========================================================\n")

        return objective.Value(), MAX_TIME, START_TIME, END_TIME, [nodes, devices, node_energy, node_energy_acc, node_latency,
                                                                   node_flops, x_list, overall_latency_requirement, node_assignments]
    else:
        print('The problem does not have an optimal solution.')


In [None]:
import copy
import random
import time

def heuristic1(model, nodes, devices, node_energy, node_flops, node_latency, overall_latency_requirement, device_capacity, device_utilization, static_energy):
    """Greedy, energy-first list scheduling of the model DAG onto the devices.

    Nodes are handled in waves: every node whose predecessors have all been
    scheduled is placed on the device where its estimated energy (dynamic part
    plus static share) is lowest. Timing is derived afterwards for the chosen
    device and never influences the choice.

    Args:
        model: Passed to ``extract_dag_edges`` to obtain the DAG edges.
        nodes: Sequence of node identifiers; the position of a node is the
            index used by every per-node list below.
        devices: Sequence of device identifiers.
        node_energy: ``node_energy[device][i]`` is the dynamic energy of node
            ``i`` on ``device``.
        node_flops: Not used for scheduling; forwarded in the returned solution.
        node_latency: ``node_latency[device][i]`` is the latency of node ``i``
            on an idle ``device`` (divided by 1000 before being multiplied by
            the static term, so presumably milliseconds -- TODO confirm).
        overall_latency_requirement: Not enforced here; forwarded in the
            returned solution.
        device_capacity: Unused; kept so the signature matches ``heuristic2``.
        device_utilization: ``device_utilization[device]`` is the busy fraction
            of the device; latencies are divided by ``1 - utilization``.
        static_energy: Per-device static term multiplied by latency/1000
            (presumably static power in W -- TODO confirm).

    Returns:
        tuple: ``(energy, completion_time, start_times, end_times, solution)``
        where ``start_times``/``end_times`` map node index to time and
        ``solution`` is the list ``[nodes, devices, node_energy,
        node_energy_acc, node_latency, node_flops, x_list,
        overall_latency_requirement, association]``.

    Raises:
        ValueError: If the dependency graph contains a cycle, so that no
            remaining node can ever become ready.
    """
    energy = 0
    makespan = 0

    # Latency each node would have on each device once the background
    # utilization of that device is taken into account.
    adjusted_node_latency = {
        device: [latency / (1.0 - device_utilization[device]) for latency in node_latency[device]]
        for device in devices
    }

    edges = extract_dag_edges(model)
    comm_latency = estimate_communication_latencies(devices, edges)

    # Fixed predecessor list of every node (duplicates in the edge list are ignored)
    dependencies = {i: [] for i in range(len(nodes))}
    for item in edges:
        if item[0] not in dependencies[item[1]]:
            dependencies[item[1]].append(item[0])

    # Start / end of execution for each node
    time_of_execution_start = {i: None for i in range(len(nodes))}
    time_of_execution_end = {i: None for i in range(len(nodes))}

    # Earliest time each device becomes available
    device_disponibility_time = {device: 0 for device in devices}

    # Working copy: predecessors that are still unscheduled, per unscheduled node
    pending_dependencies = copy.deepcopy(dependencies)

    # Node identifier -> device, and the matching (node index, device index) pairs
    association = {}
    x_list = []

    while pending_dependencies:
        # Nodes whose predecessors have all been scheduled
        ready_nodes = [k for k, v in pending_dependencies.items() if len(v) == 0]

        if not ready_nodes:
            # Without this guard a cyclic edge list would loop forever.
            raise ValueError(
                "Dependency graph contains a cycle; unschedulable nodes: "
                f"{sorted(pending_dependencies)}"
            )

        for i in ready_nodes:
            # Energy of node i on every device
            device_energy = {}
            for device in devices:
                dynamic_energy = node_energy[device][i] * (adjusted_node_latency[device][i] / node_latency[device][i])
                static_energy_per_node = (adjusted_node_latency[device][i] / 1000) * static_energy[device]
                device_energy[device] = dynamic_energy + static_energy_per_node

            # Cheapest device; the first one wins on ties (same as a stable sort)
            chosen_device_ind, chosen_device = min(
                enumerate(devices), key=lambda pair: device_energy[pair[1]]
            )

            max_inter_device_comm = 0
            latest_dependency_end = 0

            # Look at the original (persisted) predecessor list
            if dependencies[i]:
                # Latest completion among the predecessors
                latest_dependency_end = max(time_of_execution_end[x] for x in dependencies[i])

                # Longest transfer from any predecessor placed on another device
                for dep in dependencies[i]:
                    parent_device = association[nodes[dep]]
                    if parent_device != chosen_device:
                        inter_dev_time = comm_latency[parent_device][chosen_device][dep][i]
                    else:
                        # Same device: no communication delay
                        inter_dev_time = 0

                    if inter_dev_time > max_inter_device_comm:
                        max_inter_device_comm = inter_dev_time

            # Start: when both the device and all predecessors are done
            time_of_execution_start[i] = max(device_disponibility_time[chosen_device], latest_dependency_end)

            # End: start + processing time + inter-device communication time
            time_of_execution_end[i] = time_of_execution_start[i] + adjusted_node_latency[chosen_device][i] + max_inter_device_comm

            device_disponibility_time[chosen_device] = max(device_disponibility_time[chosen_device], time_of_execution_end[i])

            print("Node: ", i, " -- Device: ", chosen_device, " -- Start: ", time_of_execution_start[i], " -- End: ", time_of_execution_end[i], "\n\n")

            association[nodes[i]] = chosen_device
            x_list.append((i, chosen_device_ind))
            energy = energy + device_energy[chosen_device]
            makespan = max(makespan, time_of_execution_end[i])

            # Node i is scheduled: drop it and release its successors
            del pending_dependencies[i]
            for remaining in pending_dependencies.values():
                if i in remaining:
                    remaining.remove(i)

    # Per-device, per-node energy (dynamic + static share) at nominal latency,
    # built for whatever devices were passed in rather than a fixed D1..D7 set.
    node_energy_acc = {device: [] for device in devices}

    for i in range(len(nodes)):
        for device in devices:
            dynamic_energy = node_energy[device][i]
            static_energy_per_node = (node_latency[device][i] / 1000) * static_energy[device]
            node_energy_acc[device].append(dynamic_energy + static_energy_per_node)

    print(f"ENERGY = {energy:.3f}")
    print(f"TIME = {makespan:.3f}")

    return energy, makespan, time_of_execution_start, time_of_execution_end, [nodes, devices, node_energy, node_energy_acc, node_latency,
                                                                              node_flops, x_list, overall_latency_requirement, association]


In [None]:
def heuristic2(model, nodes, devices, node_energy, node_flops, node_latency, overall_latency_requirement, device_capacity, device_utilization, static_energy):
    """Greedy, completion-time-first list scheduling of the model DAG.

    Nodes are handled in waves: every node whose predecessors have all been
    scheduled is evaluated on every device, and is placed on the device where
    it would finish earliest (device availability, predecessor completion,
    processing time and inter-device communication all considered).

    Args:
        model: Passed to ``extract_dag_edges`` to obtain the DAG edges.
        nodes: Sequence of node identifiers; the position of a node is the
            index used by every per-node list below.
        devices: Sequence of device identifiers.
        node_energy: ``node_energy[device][i]`` is the dynamic energy of node
            ``i`` on ``device``.
        node_flops: Not used for scheduling; forwarded in the returned solution.
        node_latency: ``node_latency[device][i]`` is the latency of node ``i``
            on an idle ``device`` (divided by 1000 before being multiplied by
            the static term, so presumably milliseconds -- TODO confirm).
        overall_latency_requirement: Not enforced here; forwarded in the
            returned solution.
        device_capacity: Unused; kept so the signature matches ``heuristic1``.
        device_utilization: ``device_utilization[device]`` is the busy fraction
            of the device; latencies are divided by ``1 - utilization``.
        static_energy: Per-device static term multiplied by latency/1000
            (presumably static power in W -- TODO confirm).

    Returns:
        tuple: ``(energy, completion_time, start_times, end_times, solution)``
        where ``start_times``/``end_times`` map node index to time and
        ``solution`` is the list ``[nodes, devices, node_energy,
        node_energy_acc, node_latency, node_flops, x_list,
        overall_latency_requirement, association]``.

    Raises:
        ValueError: If the dependency graph contains a cycle, so that no
            remaining node can ever become ready.
    """
    energy = 0
    makespan = 0

    # Latency each node would have on each device once the background
    # utilization of that device is taken into account.
    adjusted_node_latency = {
        device: [latency / (1.0 - device_utilization[device]) for latency in node_latency[device]]
        for device in devices
    }

    edges = extract_dag_edges(model)
    comm_latency = estimate_communication_latencies(devices, edges)

    # Fixed predecessor list of every node (duplicates in the edge list are ignored)
    dependencies = {i: [] for i in range(len(nodes))}
    for item in edges:
        if item[0] not in dependencies[item[1]]:
            dependencies[item[1]].append(item[0])

    # Start / end of execution for each node
    time_of_execution_start = {i: None for i in range(len(nodes))}
    time_of_execution_end = {i: None for i in range(len(nodes))}

    # Earliest time each device becomes available
    device_disponibility_time = {device: 0 for device in devices}

    # Working copy: predecessors that are still unscheduled, per unscheduled node
    pending_dependencies = copy.deepcopy(dependencies)

    # Node identifier -> device, and the matching (node index, device index) pairs
    association = {}
    x_list = []

    while pending_dependencies:
        # Nodes whose predecessors have all been scheduled
        ready_nodes = [k for k, v in pending_dependencies.items() if len(v) == 0]

        if not ready_nodes:
            # Without this guard a cyclic edge list would loop forever.
            raise ValueError(
                "Dependency graph contains a cycle; unschedulable nodes: "
                f"{sorted(pending_dependencies)}"
            )

        for i in ready_nodes:
            # Latest completion among the predecessors; identical for every
            # candidate device, so it is computed once per node.
            latest_dependency_end = 0
            if dependencies[i]:
                latest_dependency_end = max(time_of_execution_end[x] for x in dependencies[i])

            device_energy = {}
            start_time_aux = {}
            end_time_aux = {}

            # Evaluate energy and finish time of node i on every candidate device
            for device in devices:
                dynamic_energy = node_energy[device][i] * (adjusted_node_latency[device][i] / node_latency[device][i])
                static_energy_per_node = (adjusted_node_latency[device][i] / 1000) * static_energy[device]
                device_energy[device] = dynamic_energy + static_energy_per_node

                # Longest transfer from any predecessor placed on a device other
                # than the candidate. The comparison must use the candidate
                # `device`, not the device picked for the previously scheduled
                # node, otherwise every candidate gets the same (wrong) delay.
                max_inter_device_comm = 0
                for dep in dependencies[i]:
                    parent_device = association[nodes[dep]]
                    if parent_device != device:
                        inter_dev_time = comm_latency[parent_device][device][dep][i]
                    else:
                        # Same device: no communication delay
                        inter_dev_time = 0

                    if inter_dev_time > max_inter_device_comm:
                        max_inter_device_comm = inter_dev_time

                start_time_aux[device] = max(device_disponibility_time[device], latest_dependency_end)
                end_time_aux[device] = start_time_aux[device] + adjusted_node_latency[device][i] + max_inter_device_comm

            # Device with the earliest finish time; the first one wins on ties
            # (same as a stable sort)
            chosen_device_ind, chosen_device = min(
                enumerate(devices), key=lambda pair: end_time_aux[pair[1]]
            )

            # Start: when both the device and all predecessors are done
            time_of_execution_start[i] = start_time_aux[chosen_device]

            # End: start + processing time + inter-device communication delay
            time_of_execution_end[i] = end_time_aux[chosen_device]

            device_disponibility_time[chosen_device] = max(device_disponibility_time[chosen_device], time_of_execution_end[i])

            print("Node: ", i, " -- Device: ", chosen_device, " -- Start: ", time_of_execution_start[i], " -- End: ", time_of_execution_end[i], "\n\n")

            association[nodes[i]] = chosen_device
            x_list.append((i, chosen_device_ind))
            energy = energy + device_energy[chosen_device]
            makespan = max(makespan, time_of_execution_end[i])

            # Node i is scheduled: drop it and release its successors
            del pending_dependencies[i]
            for remaining in pending_dependencies.values():
                if i in remaining:
                    remaining.remove(i)

    # Per-device, per-node energy (dynamic + static share) at nominal latency,
    # built for whatever devices were passed in rather than a fixed D1..D7 set.
    node_energy_acc = {device: [] for device in devices}

    for i in range(len(nodes)):
        for device in devices:
            dynamic_energy = node_energy[device][i]
            static_energy_per_node = (node_latency[device][i] / 1000) * static_energy[device]
            node_energy_acc[device].append(dynamic_energy + static_energy_per_node)

    print(f"ENERGY = {energy:.3f}")
    print(f"TIME = {makespan:.3f}")

    return energy, makespan, time_of_execution_start, time_of_execution_end, [nodes, devices, node_energy, node_energy_acc, node_latency,
                                                                              node_flops, x_list, overall_latency_requirement, association]


In [None]:
# Per-device compute capability, passed to the solvers/heuristics as `device_capacity`.
device_capabilities = {
    'D1': D1_GFLOPS,
    'D2': D2_GFLOPS,
    'D3': D3_GFLOPS,
    'D4': D4_GFLOPS,
    'D5': D5_GFLOPS,
    'D6': D6_GFLOPS,
    'D7': D7_GFLOPS
}

# Per-device efficiency figures.
device_efficiency = {
    'D1': D1_EFFICIENCY,
    'D2': D2_EFFICIENCY,
    'D3': D3_EFFICIENCY,
    'D4': D4_EFFICIENCY,
    'D5': D5_EFFICIENCY,
    'D6': D6_EFFICIENCY,
    'D7': D7_EFFICIENCY
}

# NOTE: despite the name, the values are the *_STATIC_POWER constants; the
# heuristics multiply them by (latency / 1000) to obtain an energy.
static_energy = {
    'D1': D1_STATIC_POWER,
    'D2': D2_STATIC_POWER,
    'D3': D3_STATIC_POWER,
    'D4': D4_STATIC_POWER,
    'D5': D5_STATIC_POWER,
    'D6': D6_STATIC_POWER,
    'D7': D7_STATIC_POWER
}

# Busy fraction of each device: D1-D3 are 70% loaded, D4-D7 are idle.
# The heuristics divide node latencies by (1 - utilization).
device_utilization = {
    'D1': 0.7,
    'D2': 0.7,
    'D3': 0.7,
    'D4': 0.0,
    'D5': 0.0,
    'D6': 0.0,
    'D7': 0.0
}

computing_time = DAG_PROFILE['computing_time']
energy_cost = DAG_PROFILE['energy_cost']

# Node identifiers in the iteration order of `node_flops`
nodes = list(node_flops)

# Latency budget handed to the solvers (presumably milliseconds -- TODO confirm)
overall_latency_requirement = 20

devices = ["D1", "D2", "D3", "D4", "D5", "D6", "D7"]


### 5. Solving and Analysis

In [None]:
import time

# Run graph_partition_S1 and record how long the call takes.
# perf_counter() is monotonic, so the measured interval cannot be distorted by
# system clock adjustments the way time.time() differences can.
start = time.perf_counter()
F1_opt_obj, F1_MAX_TIME, F1_START_TIME, F1_END_TIME, F1_sol = graph_partition_S1(model, nodes, 
                                            devices, energy_cost, node_flops, computing_time,
                                            overall_latency_requirement, device_capabilities, device_utilization, 
                                            static_energy, plot_summary=True, exp=5)
end = time.perf_counter()
F1_exe_time = end - start  # seconds

In [None]:
# Run graph_partition_S2 (which additionally receives the objective value
# obtained by graph_partition_S1) and record how long the call takes.
# perf_counter() is monotonic, so the measured interval cannot be distorted by
# system clock adjustments the way time.time() differences can.
start = time.perf_counter()
F2_opt_obj, F2_MAX_TIME, F2_START_TIME, F2_END_TIME, F2_sol = graph_partition_S2(model, nodes, 
                                                devices, energy_cost, node_flops, computing_time,
                                                overall_latency_requirement, device_capabilities, device_utilization, 
                                                static_energy, F1_opt_obj, plot_summary=True, exp=5)
end = time.perf_counter()
F2_exe_time = end - start  # seconds

In [None]:
# Run the energy-first heuristic and record how long the call takes.
# perf_counter() is monotonic, so the measured interval cannot be distorted by
# system clock adjustments the way time.time() differences can.
start = time.perf_counter()
HE1_opt_obj, HE1_MAX_TIME, HE1_START_TIME, HE1_END_TIME, HE1_sol = heuristic1(model, nodes, devices, energy_cost, node_flops, computing_time,
                overall_latency_requirement, device_capabilities, device_utilization, static_energy)
end = time.perf_counter()
HE1_exe_time = end - start  # seconds

In [None]:
# Run the completion-time-first heuristic and record how long the call takes.
# perf_counter() is monotonic, so the measured interval cannot be distorted by
# system clock adjustments the way time.time() differences can.
start = time.perf_counter()
HE2_opt_obj, HE2_MAX_TIME, HE2_START_TIME, HE2_END_TIME, HE2_sol = heuristic2(model, nodes, devices, energy_cost, node_flops, computing_time,
            overall_latency_requirement, device_capabilities, device_utilization, static_energy)
end = time.perf_counter()
HE2_exe_time = end - start  # seconds

### 6. Figure Generation

In [None]:
import pickle


# Uncomment to export results to a .pkl file
'''
Solution_Yolo = {}
Solution_Yolo['F1'] = [F1_opt_obj, F1_MAX_TIME, F1_START_TIME, F1_END_TIME, F1_sol, F1_exe_time]
Solution_Yolo['F2'] = [F2_opt_obj, F2_MAX_TIME, F2_START_TIME, F2_END_TIME, F2_sol, F2_exe_time]
Solution_Yolo['H1'] = [HE1_opt_obj, HE1_MAX_TIME, HE1_START_TIME, HE1_END_TIME, HE1_sol, HE1_exe_time]
Solution_Yolo['H2'] = [HE2_opt_obj, HE2_MAX_TIME, HE2_START_TIME, HE2_END_TIME, HE2_sol, HE2_exe_time]

dl_model = "yolo11n"

# Save to file
with open(dl_model + "_solution.pkl", "wb") as f:
    pickle.dump(Solution_Yolo, f)
'''

In [None]:
import pickle


# Uncomment to import results from a .pkl file
'''
dl_model = "yolo11n"

# Load from file
with open(dl_model + "_solution.pkl", "rb") as f:
    Solution_Yolo = pickle.load(f)

[F1_opt_obj, F1_MAX_TIME, F1_START_TIME, F1_END_TIME, F1_sol, F1_exe_time] = Solution_Yolo['F1']
[F2_opt_obj, F2_MAX_TIME, F2_START_TIME, F2_END_TIME, F2_sol, F2_exe_time] = Solution_Yolo['F2']
[HE1_opt_obj, HE1_MAX_TIME, HE1_START_TIME, HE1_END_TIME, HE1_sol, HE1_exe_time] = Solution_Yolo['H1']
[HE2_opt_obj, HE2_MAX_TIME, HE2_START_TIME, HE2_END_TIME, HE2_sol, HE2_exe_time]  = Solution_Yolo['H2']
'''

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.patches import Patch


def format_data(SOLUTION):
    """Aggregate a solution into per-device totals and print a summary.

    Args:
        SOLUTION: List laid out as ``[nodes, devices, node_energy,
            node_energy_acc, node_latency, node_flops, x_list,
            overall_latency_requirement, association]``:

            * ``nodes`` - node identifiers (position = node index),
            * ``devices`` - device identifiers (position = device index),
            * ``node_energy`` / ``node_energy_acc`` / ``node_latency`` -
              ``mapping[device][node_index]`` values,
            * ``node_flops`` - mapping node identifier -> FLOPs figure,
            * ``x_list`` - collection of ``(node_index, device_index)`` pairs
              describing the assignment,
            * ``overall_latency_requirement`` - bound each per-device latency
              sum is compared against,
            * ``association`` - mapping node identifier -> device identifier.

    Returns:
        tuple: ``(flops_per_device, total_latency, total_acc_energy,
        device_node_count)``; each is a dict keyed by device identifier with an
        extra ``"Total"`` entry holding the sum over all devices.
    """
    nodes, devices = SOLUTION[0], SOLUTION[1]
    node_energy, node_energy_acc, node_latency = SOLUTION[2], SOLUTION[3], SOLUTION[4]
    node_flops = SOLUTION[5]
    latency_requirement, association = SOLUTION[7], SOLUTION[8]

    # The membership test below runs once per (node, device) pair; a set makes
    # it O(1) instead of a scan over the whole assignment list.
    try:
        assigned_pairs = set(SOLUTION[6])
    except TypeError:
        # Unhashable entries: fall back to the container as given
        assigned_pairs = SOLUTION[6]

    # Initialize dictionaries to store counts and summaries
    device_node_count = {device: 0 for device in devices}
    flops_per_device = {device: 0 for device in devices}
    total_latency = {device: 0 for device in devices}
    total_energy = {device: 0 for device in devices}
    total_acc_energy = {device: 0 for device in devices}

    # Calculate and summarize information
    for i, node in enumerate(nodes):
        device_node_count[association[node]] += 1

        for j, device in enumerate(devices):
            if (i, j) in assigned_pairs:
                total_latency[device] += node_latency[device][i]
                total_energy[device] += node_energy[device][i]
                total_acc_energy[device] += node_energy_acc[device][i]

    # Nodes without a FLOPs entry contribute nothing
    for node, device in association.items():
        if node in node_flops:
            flops_per_device[device] += node_flops[node]

    # Check latency requirements (per-device latency sums against the bound)
    meets_latency_requirements = all(latency <= latency_requirement for latency in total_latency.values())

    # Print results
    print("Meets Latency Requirements:", meets_latency_requirements)

    print("\nSolution Verification Summary:\n")
    for device in devices:
        print(f" Device: {device}")
        print(f" Number of Nodes Assigned: {device_node_count[device]}")
        print(f" GFLOPs computed: {flops_per_device[device]:.4f}")
        print(f" Total computing time: {total_latency[device]:.4f} ms")
        print(f" Total energy cost: {total_energy[device]:.4f} J")
        print("-" * 60)

    print("Total Energy: {0:.4f} J".format(sum(total_energy.values())))

    # Append the grand totals after printing so they do not appear as a device
    flops_per_device["Total"] = sum(flops_per_device.values())
    total_latency["Total"] = sum(total_latency.values())
    total_acc_energy["Total"] = sum(total_acc_energy.values())
    device_node_count["Total"] = sum(device_node_count.values())

    return flops_per_device, total_latency, total_acc_energy, device_node_count


def result_analysis(F1, F2, H1, H2, F1_COMP_TIME, F2_COMP_TIME, H1_COMP_TIME, H2_COMP_TIME, F1_EXEC_TIME, F2_EXEC_TIME, H1_EXEC_TIME, H2_EXEC_TIME):
    """Summarize the four solutions and produce every comparison figure.

    Args:
        F1, F2, H1, H2: Solution lists accepted by ``format_data``.
        F1_COMP_TIME, F2_COMP_TIME, H1_COMP_TIME, H2_COMP_TIME: Values handed
            to ``completion_time_comparison``.
        F1_EXEC_TIME, F2_EXEC_TIME, H1_EXEC_TIME, H2_EXEC_TIME: Values handed
            to ``execution_time_comparison``.
    """
    # format_data prints a summary on every call; F2 is reported first.
    # Each summary is (flops, latency, accumulated energy, node count).
    f2_summary = format_data(F2)
    f1_summary = format_data(F1)
    h1_summary = format_data(H1)
    h2_summary = format_data(H2)

    # One bar chart per metric, in the order format_data returns them
    chart_specs = (
        ("GFLOPs computed per device", "GFLOPs", "1"),
        ("Total computing time per device", "Computing time (ms)", "2"),
        ("Total computing energy per device", "Energy (J)", "3"),
        ("Total device node count", "Number of nodes", "4"),
    )
    for metric, (chart_title, y_axis_label, file_stem) in enumerate(chart_specs):
        plot_results(
            f1_summary[metric],
            f2_summary[metric],
            h1_summary[metric],
            h2_summary[metric],
            chart_title,
            "Devices",
            y_axis_label,
            F1,
            file_stem,
        )

    completion_time_comparison(F1_COMP_TIME, F2_COMP_TIME, H1_COMP_TIME, H2_COMP_TIME)
    execution_time_comparison(F1_EXEC_TIME, F2_EXEC_TIME, H1_EXEC_TIME, H2_EXEC_TIME)


def plot_results(F1_data, F2_data, H1_data, H2_data, title, xlabel, ylabel, F1, plot_ind):
    """Draw a grouped bar chart comparing the four methods per device and save it.

    The figure is written to ``figs/<plot_ind>.pdf`` and
    ``figs/PNG/<plot_ind>.png``; both directories are created if missing.

    Args:
        F1_data, F2_data, H1_data, H2_data: Dicts mapping device key (plus
            ``"Total"``) to the value to plot. The x tick labels are taken
            from the keys of ``F1_data``.
        title: Accepted for interface compatibility; not drawn on the figure.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        F1: Accepted for interface compatibility; not used.
        plot_ind: File name stem of the saved figure.
    """
    import os

    DEVICES_NAME = {"D1": "NV-L40",
                    "D2": "NV-T4",
                    "D3": "NV-A30",
                    "D4": "AL-U2",
                    "D5": "AL-U1",
                    "D6": "BL3",
                    "D7": "BL2"}

    # Unknown device keys are shown as-is instead of raising KeyError
    labels = ["TOTAL" if key == "Total" else DEVICES_NAME.get(key, key) for key in F1_data.keys()]

    x_labels = np.arange(len(labels))

    F1_data = list(F1_data.values())
    F2_data = list(F2_data.values())
    H1_data = list(H1_data.values())
    H2_data = list(H2_data.values())

    plt.figure(figsize=(4,2))
    plt.rcParams["font.family"] = "Ubuntu"
    plt.xticks(x_labels, labels, rotation=30)

    F1_rects = plt.bar(x_labels-0.30, F1_data, 0.2, color="#b0d4ec", edgecolor='black', label='1-stage', hatch="/")
    F2_rects = plt.bar(x_labels-0.10, F2_data, 0.2, color='#006bb3', edgecolor='black', label='2-stage', hatch="\\")
    H1_rects = plt.bar(x_labels+0.10, H1_data, 0.2, color='#c0b1ec', edgecolor='black', label='heu_es', hatch="-")
    H2_rects = plt.bar(x_labels+0.30, H2_data, 0.2, color='#ecb1dd', edgecolor='black', label='heu_ct', hatch="|||||")

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(axis = 'y', linestyle='--', linewidth=0.6)

    plt.legend()

    # Leave 40% headroom above the tallest bar of ANY series, so that bars of
    # the other methods are not clipped when they exceed the 1-stage values.
    tallest = max(max(series, default=0) for series in (F1_data, F2_data, H1_data, H2_data))
    if tallest > 0:
        plt.ylim(0, tallest+(0.4*tallest))

    # savefig does not create missing directories
    os.makedirs("figs/PNG", exist_ok=True)
    plt.savefig("figs/" + plot_ind + ".pdf", dpi=600, bbox_inches='tight', pad_inches=0.1)
    plt.savefig("figs/PNG/" + plot_ind + ".png", dpi=600, bbox_inches='tight', pad_inches=0.1)


def completion_time_comparison(F1_COMP_TIME, F2_COMP_TIME, H1_COMP_TIME, H2_COMP_TIME): 
    """Plot the completion time of the four methods side by side and save it.

    The figure is written to ``figs/5.pdf`` and ``figs/PNG/5.png``; both
    directories are created if missing.

    Args:
        F1_COMP_TIME, F2_COMP_TIME, H1_COMP_TIME, H2_COMP_TIME: Lists of
            completion times, one list per method; values are rounded to three
            decimals for the bar labels.
    """
    import os

    F1_MAX_TIME = [round(x, 3) for x in F1_COMP_TIME]
    F2_MAX_TIME = [round(x, 3) for x in F2_COMP_TIME]
    H1_MAX_TIME = [round(x, 3) for x in H1_COMP_TIME]
    H2_MAX_TIME = [round(x, 3) for x in H2_COMP_TIME]

    x_labels = np.arange(len(F1_MAX_TIME))

    plt.figure(figsize=(4,1.2))
    plt.rcParams["font.family"] = "Ubuntu"

    F1_rects = plt.bar(x_labels-0.60, F1_MAX_TIME, 0.40, color = "#b0d4ec", edgecolor='black', label='1-stage', hatch="/")
    F2_rects = plt.bar(x_labels-0.20, F2_MAX_TIME, 0.40, color = '#006bb3', edgecolor='black', label='2-stage', hatch="\\")
    H1_rects = plt.bar(x_labels+0.20, H1_MAX_TIME, 0.40, color = '#c0b1ec', edgecolor='black', label='heu_es', hatch="-")
    H2_rects = plt.bar(x_labels+0.60, H2_MAX_TIME, 0.40, color = '#ecb1dd', edgecolor='black', label='heu_ct', hatch="|")

    plt.bar_label(F1_rects, fontsize=10, padding=3)
    plt.bar_label(F2_rects, fontsize=10, padding=3)
    plt.bar_label(H1_rects, fontsize=10, padding=3)
    plt.bar_label(H2_rects, fontsize=10, padding=3)

    plt.ylabel("Completion time (ms)      ")
    plt.xticks([-0.6, -0.2, +0.2, +0.6], ["1-stage", "2-stage", "heu_es", "heu_ct"])
    plt.grid(axis = 'y', linestyle='--', linewidth=0.6)

    # 25% headroom above the tallest bar of ANY method, so heuristic bars that
    # exceed the solver values (and their labels) are not clipped.
    peak = max(F1_MAX_TIME + F2_MAX_TIME + H1_MAX_TIME + H2_MAX_TIME, default=0)
    if peak > 0:
        plt.ylim(0, peak+(0.25*peak))

    # savefig does not create missing directories
    os.makedirs("figs/PNG", exist_ok=True)
    plt.savefig("figs/5.pdf", dpi=600, bbox_inches='tight', pad_inches=0.1)
    plt.savefig("figs/PNG/5.png", dpi=600, bbox_inches='tight', pad_inches=0.1)


def execution_time_comparison(F1_EXEC_TIME, F2_EXEC_TIME, H1_EXEC_TIME, H2_EXEC_TIME): 
    """Plot the run time of the four methods on a log-scaled axis and save it.

    The figure is written to ``figs/6.pdf`` and ``figs/PNG/6.png``; both
    directories are created if missing.

    Args:
        F1_EXEC_TIME, F2_EXEC_TIME, H1_EXEC_TIME, H2_EXEC_TIME: Lists of
            measured run times, one list per method; values are rounded to
            three decimals for the bar labels.
    """
    import os

    F1_MAX_TIME = [round(x, 3) for x in F1_EXEC_TIME]
    F2_MAX_TIME = [round(x, 3) for x in F2_EXEC_TIME]
    H1_MAX_TIME = [round(x, 3) for x in H1_EXEC_TIME]
    H2_MAX_TIME = [round(x, 3) for x in H2_EXEC_TIME]

    x_labels = np.arange(len(F1_MAX_TIME))

    plt.figure(figsize=(4,1.2))
    plt.rcParams["font.family"] = "Ubuntu"

    F1_rects = plt.bar(x_labels-0.60, F1_MAX_TIME, 0.40, color = "#b0d4ec", edgecolor='black', label='stage 1', hatch="/")
    F2_rects = plt.bar(x_labels-0.20, F2_MAX_TIME, 0.40, color = '#006bb3', edgecolor='black', label='stage 2', hatch="\\")
    H1_rects = plt.bar(x_labels+0.20, H1_MAX_TIME, 0.40, color = '#c0b1ec', edgecolor='black', label='heu_es', hatch="-")
    H2_rects = plt.bar(x_labels+0.60, H2_MAX_TIME, 0.40, color = '#ecb1dd', edgecolor='black', label='heu_ct', hatch="|")

    plt.bar_label(F1_rects, fontsize=10, padding=3)
    plt.bar_label(F2_rects, fontsize=10, padding=3)
    plt.bar_label(H1_rects, fontsize=10, padding=3)
    plt.bar_label(H2_rects, fontsize=10, padding=3)

    plt.ylabel("Execution time (s)         ")
    plt.xticks([-0.6, -0.2, +0.2, +0.6], ["stage 1", "stage 2", "heu_es", "heu_ct"])
    plt.grid(axis = 'y', linestyle='--', linewidth=0.6)
    plt.yscale('log')

    # A log axis cannot start at 0, so only the upper limit is set and the
    # lower one is left to autoscaling. The headroom is based on the largest
    # value of ANY method so that no bar or label is clipped.
    peak = max(F1_MAX_TIME + F2_MAX_TIME + H1_MAX_TIME + H2_MAX_TIME, default=0)
    if peak > 0:
        plt.ylim(top=peak+(40*peak))

    # savefig does not create missing directories
    os.makedirs("figs/PNG", exist_ok=True)
    plt.savefig("figs/6.pdf", dpi=600, bbox_inches='tight', pad_inches=0.1)
    plt.savefig("figs/PNG/6.png", dpi=600, bbox_inches='tight', pad_inches=0.1)


# Produce the summaries and all figures for the four methods. The scalar
# completion / run times are wrapped in one-element lists because the
# comparison helpers iterate over their inputs.
result_analysis(
    F1=F1_sol,
    F2=F2_sol,
    H1=HE1_sol,
    H2=HE2_sol,
    F1_COMP_TIME=[F1_MAX_TIME],
    F2_COMP_TIME=[F2_MAX_TIME],
    H1_COMP_TIME=[HE1_MAX_TIME],
    H2_COMP_TIME=[HE2_MAX_TIME],
    F1_EXEC_TIME=[F1_exe_time],
    F2_EXEC_TIME=[F2_exe_time],
    H1_EXEC_TIME=[HE1_exe_time],
    H2_EXEC_TIME=[HE2_exe_time],
)
