diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py index 5d328f62b904..38e7663fbf6e 100644 --- a/benchmarks/base_classes.py +++ b/benchmarks/base_classes.py @@ -162,6 +162,25 @@ def run_inference(self, pipe, args): guidance_scale=1.0, ) + def benchmark(self, args): + flush() + + print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n") + + time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds. + memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs. + benchmark_info = BenchmarkInfo(time=time, memory=memory) + + pipeline_class_name = str(self.pipe.__class__.__name__) + flush() + csv_dict = generate_csv_dict( + pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info + ) + filepath = self.get_result_filepath(args) + write_to_csv(filepath, csv_dict) + print(f"Logs written to: {filepath}") + flush() + class ImageToImageBenchmark(TextToImageBenchmark): pipeline_class = AutoPipelineForImage2Image