src/models/debugging.cpp (40 changes: 26 additions & 14 deletions)
@@ -83,22 +83,34 @@ void DumpTensor(const Model& model, std::ostream& stream, OrtValue* value, bool
   stream << SGR::Fg_Green << " Location: " << SGR::Reset;

   const auto& memory_info = value->GetTensorMemoryInfo();
-  switch (memory_info.GetDeviceType()) {
-    case OrtMemoryInfoDeviceType_CPU:
-      stream << "CPU\r\n";
+  if (memory_info.GetDeviceType() == OrtMemoryInfoDeviceType_CPU) {
+    stream << "CPU\r\n";
     if (dump_value) {
       DumpValues(stream, type_info->GetElementType(), value->GetTensorRawData(), element_count);
     }
-      break;
-    case OrtMemoryInfoDeviceType_GPU: {
-      stream << "GPU\r\n";
-      auto type = type_info->GetElementType();
-      auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
-      auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
-      DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
-      break;
-    }
-    default:
-      stream << "Unhandled device type: " << static_cast<int>(memory_info.GetDeviceType()) << "\r\n";
-      break;
+  // Internally there are 5 device types defined in onnxruntime but only 3 are exposed in the public API
+  // https://github.com/microsoft/onnxruntime/blob/9dbfee91ca9c2ba2074d19805bb6dedccedbcfe3/include/onnxruntime/core/framework/ortdevice.h#L15
+  } else if (memory_info.GetDeviceType() < 5) {
+    switch (model.p_device_->GetType()) {
+      case DeviceType::CUDA:
[Inline review thread on the case DeviceType::CUDA: line]

@baijumeswani (Collaborator), Mar 27, 2025:
Shall we try to define a model.p_device_->GetType<std::string>() for each device type? We can avoid the switch statement here then.

    stream << model.p_device_->GetType<std::string>() << "\r\n";
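A minimal sketch of what this suggestion could look like, assuming an illustrative DeviceInterface that owns a DeviceType enum; the class shape, member names, and enum values here are hypothetical stand-ins, not the actual onnxruntime-genai API:

    #include <string>

    // Illustrative stand-ins for the real onnxruntime-genai types.
    enum class DeviceType { CPU, CUDA, DML, QNN };

    struct DeviceInterface {
      DeviceType type_{DeviceType::CPU};

      // Primary template mirrors the existing GetType(): returns the raw enum.
      template <typename T>
      T GetType() const { return type_; }
    };

    // Hypothetical specialization mapping each device type to a printable
    // name, so call sites can write GetType<std::string>() instead of a switch.
    template <>
    inline std::string DeviceInterface::GetType<std::string>() const {
      switch (type_) {
        case DeviceType::CPU:  return "CPU";
        case DeviceType::CUDA: return "CUDA";
        case DeviceType::DML:  return "DML";
        case DeviceType::QNN:  return "QNN";
      }
      return "Unknown";
    }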

Reply (Contributor):
I think we can fairly safely say that if the tensor is not on CPU, it's on the model.p_device_.

Reply (Contributor):
Oh, I see. We shouldn't use a switch here; use this:

    std::string to_string(DeviceType device_type);
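A sketch of that shape, reusing the illustrative DeviceType enum from the sketch above; the switch moves into one shared helper, and the dump site shrinks to a single stream insertion:

    #include <string>

    enum class DeviceType { CPU, CUDA, DML, QNN };  // illustrative, as above

    // Hypothetical free function: one shared place that knows how to name a
    // device, instead of an inline switch at every dump site.
    std::string to_string(DeviceType device_type) {
      switch (device_type) {
        case DeviceType::CPU:  return "CPU";
        case DeviceType::CUDA: return "CUDA";
        case DeviceType::DML:  return "DML";
        case DeviceType::QNN:  return "QNN";
      }
      return "Unknown";
    }

    // The dump site would then collapse to:
    //   stream << to_string(model.p_device_->GetType()) << "\r\n";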

stream << "CUDA\r\n";
break;
case DeviceType::DML:
stream << "DML\r\n";
break;
case DeviceType::QNN:
stream << "QNN\r\n";
break;
default:
stream << "Unknown\r\n";
break;
}
auto type = type_info->GetElementType();
auto tensor_span = std::span<uint8_t>{const_cast<OrtValue*>(value)->GetTensorMutableData<uint8_t>(), SizeOf(type) * element_count};
auto device_span = model.p_device_->WrapMemory<uint8_t>(tensor_span);
DumpValues(stream, type, device_span.CopyDeviceToCpu().data(), element_count);
} else {
stream << "Unhandled device type: " << static_cast<int>(memory_info.GetDeviceType()) << "\r\n";
}
}
