Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 2 additions & 100 deletions src/backends/ezkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,22 +396,8 @@ def circuitization_pipeline(self, model_path, output_path, input_file_path=None,
logger.warning("Failed to calibrate settings")
circuitization_data["calibrate-settings_error"] = err
else:
# If no input file, create dummy calibration
try:
logger.info("No input file provided, creating dummy calibration data")
onnx_model = onnx.load(model_path)
calibration_path = os.path.join(output_path, f"{model_name}_calibration.json")
self._create_dummy_calibration(onnx_model, calibration_path, segment_details)
circuitization_data["calibration"] = calibration_path
ok, err = self.calibrate_settings(model_path=model_path, settings_path=settings_path, data_path=calibration_path)
if not ok:
logger.warning("Failed to calibrate settings")
circuitization_data["calibrate-settings_error"] = err
except Exception as e:
error_msg = f"Failed to create dummy calibration: {e}"
logger.warning(error_msg)
logger.warning("Skipping calibration step")
circuitization_data["calibration"] = error_msg
# If no input file, log and skip calibration
logger.info("No input file provided, skipping calibration step")

# Step 4: Compile circuit
logger.info(f"Compiling circuit for {model_path}")
Expand All @@ -436,90 +422,6 @@ def circuitization_pipeline(self, model_path, output_path, input_file_path=None,

return circuitization_data

@staticmethod
def _create_dummy_calibration(onnx_model, output_path, segment_details=None):
    """
    Create a dummy calibration file for an ONNX model, handling multiple
    inputs if needed.

    Random values are generated for every real graph input (initializers,
    i.e. weights and biases, are skipped) and written as one flat list in
    the JSON structure EZKL expects: ``{"input_data": [[...floats...]]}``.

    Args:
        onnx_model: Loaded ONNX model (onnx.ModelProto).
        output_path: Path where to save the calibration JSON file.
        segment_details: Optional segment metadata; when it carries
            ``segment_details["shape"]["tensor_shape"]["input"]``, those
            shapes refine the ones read from the ONNX graph.

    Raises:
        Exception: Re-raised if writing the calibration file fails
            (the failure is logged first).
    """
    # Collect input shapes from the ONNX model, excluding initializers —
    # some exporters list weights/biases in graph.input as well.
    initializers = {init.name for init in onnx_model.graph.initializer}
    input_shapes = []
    for input_info in onnx_model.graph.input:
        if input_info.name in initializers:  # Skip weights and biases
            continue

        dim_shape = []
        for dim in input_info.type.tensor_type.shape.dim:
            if dim.dim_param:
                dim_shape.append(1)  # Replace named (symbolic) dimensions with 1
            else:
                dim_shape.append(dim.dim_value if dim.dim_value != 0 else 1)  # Replace 0 with 1

        if dim_shape:  # Only add non-empty shapes
            input_shapes.append((input_info.name, dim_shape))

    logger.info(f"Found {len(input_shapes)} inputs in ONNX model: {input_shapes}")

    # If we have metadata, use it to enhance our understanding of the shapes.
    if segment_details and "shape" in segment_details and "tensor_shape" in segment_details["shape"]:
        tensor_shape = segment_details["shape"]["tensor_shape"]
        if "input" in tensor_shape and len(tensor_shape["input"]) > 0:
            # Try to map each ONNX input to the corresponding metadata shape.
            for i, (input_name, shape) in enumerate(input_shapes):
                for meta_shape in tensor_shape["input"]:
                    # Shapes containing string dimensions are likely actual
                    # inputs (named/batch dims), not weights.
                    if any(isinstance(dim, str) for dim in meta_shape):
                        enhanced_shape = [1 if isinstance(dim, str) else dim for dim in meta_shape]

                        # Only update if the rank matches, or — as a
                        # best-effort fallback — for the last input.
                        if len(enhanced_shape) == len(shape) or i == len(input_shapes) - 1:
                            input_shapes[i] = (input_name, enhanced_shape)
                            logger.info(f"Enhanced shape for {input_name}: {enhanced_shape}")
                        # Stop at the first string-dim metadata shape for this
                        # input, whether or not it was applied.
                        break

    # Generate random data for each input and combine into a single flat array.
    all_flat_data = []
    for input_name, shape in input_shapes:
        # Total element count for this input (product of its dimensions).
        total_elements = 1
        for dim in shape:
            total_elements *= dim

        # Random data, consistent with model_circuitizer.py's approach.
        input_data = [random.random() for _ in range(total_elements)]
        all_flat_data.extend(input_data)

        logger.info(f"Generated {len(input_data)} random values for input {input_name} with shape {shape}")

    # If no inputs were found, create a default dummy input.
    if not all_flat_data:
        logger.warning("No inputs found, creating default dummy input")
        all_flat_data = [random.random() for _ in range(10)]

    # The calibration data JSON structure that EZKL expects.
    calibration_data = {"input_data": [all_flat_data]}

    try:
        with open(output_path, 'w') as f:
            json.dump(calibration_data, f)
        logger.info(f"Created dummy calibration file at {output_path} with {len(all_flat_data)} total values")
    except Exception as e:
        logger.error(f"Failed to create dummy calibration file: {str(e)}")
        raise


@staticmethod
def process_witness_output(witness_data):
"""
Expand Down
51 changes: 48 additions & 3 deletions src/circuitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import os
import json
import logging
from pathlib import Path
from typing import Optional, Dict, Any

from src.backends.ezkl import EZKL
from src.utils.utils import Utils
from src.runner import Runner

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -163,28 +165,71 @@ def _circuitize_slices(self, dir_path: str, input_file_path: Optional[str] = Non
segment_output_path = None
circuitized_count = 0
skipped_count = 0

# Phase 1: Run ONNX inference chain if we have input file
current_input = input_file_path
if current_input and os.path.exists(current_input):
logger.info("Running ONNX inference chain to generate calibration files")
for idx, segment in enumerate(segments):
segment_path = segment.get('path')
if not segment_path or not os.path.exists(segment_path):
logger.warning(f"Segment file not found for index {idx}: {segment_path}")
continue

segment_output_path = os.path.join(os.path.dirname(segment_path), "ezkl_circuitization")
os.makedirs(segment_output_path, exist_ok=True)

# Run ONNX inference to generate calibration data
output_tensor_path = os.path.join(segment_output_path, f"segment_{idx}_calibration.json")
logger.info(f"Running ONNX inference for segment {idx} with input file {current_input}")
success, tensor, exec_info = Runner._run_onnx_segment(
slice_info={"path": segment_path},
input_tensor_path=Path(current_input),
output_tensor_path=Path(output_tensor_path)
)

if not success:
logger.error(f"ONNX inference failed for segment {idx}: {exec_info.get('error', 'Unknown error')}")
return

current_input = output_tensor_path
logger.info(f"Generated calibration file: {output_tensor_path}")
else:
logger.warning("No input file provided, skipping ONNX inference chain")

# Phase 2: Circuitize selected layers
for idx, segment in enumerate(segments):
if layer_indices is not None and idx not in layer_indices:
logger.info(f"Skipping segment {idx} as it's not in the specified layers")
logger.info(f"Skipping circuitization for segment {idx} as it's not in the specified layers")
skipped_count += 1
continue

segment_path = segment.get('path')
if not segment_path or not os.path.exists(segment_path):
logger.warning(f"Segment file not found for index {idx}: {segment_path}")
continue

segment_output_path = os.path.join(os.path.dirname(segment_path), "ezkl_circuitization")
# Run pipeline and get data
os.makedirs(segment_output_path, exist_ok=True)

calibration_input = input_file_path if idx == 0 else os.path.join(
os.path.dirname(segments[idx-1].get('path')),
"ezkl_circuitization",
f"segment_{idx-1}_calibration.json"
)

logger.info(f"Circuitizing segment {idx} with calibration input file {calibration_input}")
circuitization_data = self.circuitizer_impl.circuitization_pipeline(
segment_path,
segment_output_path,
input_file_path=input_file_path,
input_file_path=calibration_input,
segment_details=segment
)
segment['ezkl_circuitization'] = circuitization_data
circuitized_count += 1
Utils.save_metadata_file(metadata, os.path.dirname(metadata_path), os.path.basename(metadata_path))


if segment_output_path:
output_dir = os.path.dirname(segment_output_path)
else:
Expand Down