Skip to content

Commit

Permalink
[Graph] Minor graph benchmark fix (#313)
Browse files Browse the repository at this point in the history
Co-authored-by: Allan Lin <allan.lin@centml.ai>
  • Loading branch information
Aalanli and Allan Lin committed Jul 16, 2023
1 parent 0c18446 commit 75dc607
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import numpy as np

from hidet.runtime import CompiledModule
from hidet.runtime import CompiledTask
from hidet.graph.flow_graph import FlowGraph, Operator, Tensor, GraphForwardInstrument


Expand Down Expand Up @@ -57,10 +57,11 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso
if not self.benchmarking:
return

task_func: CompiledModule = op.compiled_task
task_func: CompiledTask = op.compiled_task
latency: List[float] = task_func.profile(
*inputs, *outputs, warmup=self.warmup, number=self.number, repeat=self.repeat
)

self.latency_list.append((op, float(np.median(latency)), float(np.std(latency))))

def after_graph(self, graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None:
Expand Down
7 changes: 7 additions & 0 deletions python/hidet/runtime/compiled_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ def run_async(self, inputs):

return outputs

def profile(self, *args, warmup=1, number=2, repeat=10):
num_outputs = len(self.meta_data.outputs)
inputs = args[:num_outputs]
outputs = args[num_outputs:]
candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
return candidate.profile(*args, warmup=warmup, number=number, repeat=repeat)


def load_compiled_task(compiled_task_dir: str) -> CompiledTask:
return CompiledTask(compiled_task_dir)
Expand Down

0 comments on commit 75dc607

Please sign in to comment.