diff --git a/third_party/aipu/backend/aipu_torch_dev.cpp b/third_party/aipu/backend/aipu_torch_dev.cpp index 31036a944..5090b8264 100644 --- a/third_party/aipu/backend/aipu_torch_dev.cpp +++ b/third_party/aipu/backend/aipu_torch_dev.cpp @@ -398,12 +398,16 @@ struct _Device { int prev_idx = -1; }; +static std::unordered_map default_generators = { + {0, at::detail::getDefaultCPUGenerator()}}; + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("device_count", &aipu::device_count, "aipu device count"); m.def("is_available", &aipu::is_available, "aipu is available"); m.def("current_device", &aipu::current_device, "aipu current device"); m.def("_is_in_bad_fork", []() { return py::bool_(false); }); m.def("manual_seed_all", [](int seed) { std::srand(seed); }); + m.attr("default_generators") = &default_generators; py::class_<_DeviceGuard>(m, "_DeviceGuard", py::module_local()) .def(py::init(