77#include " utils/archive_utils.h"
88#include " utils/system_info_utils.h"
99// clang-format on
10+ #include " utils/cuda_toolkit_utils.h"
1011#include " utils/engine_matcher_utils.h"
1112
1213namespace commands {
@@ -102,8 +103,8 @@ void EngineInitCmd::Exec() const {
102103 .type = DownloadType::Engine,
103104 .path = path,
104105 }}};
105-
106- DownloadService () .AddDownloadTask (
106+ DownloadService downloadService;
107+ downloadService .AddDownloadTask (
107108 downloadTask, [](const std::string& absolute_path) {
108109 // try to unzip the downloaded file
109110 std::filesystem::path downloadedEnginePath{absolute_path};
@@ -116,9 +117,62 @@ void EngineInitCmd::Exec() const {
116117 .parent_path ()
117118 .string ());
118119
119- // remove the downloaded file
120- std::filesystem::remove (absolute_path);
121- LOG_INFO << " Finished!" ;
120+ try {
121+ std::filesystem::remove (absolute_path);
122+ } catch (std::exception& e) {
123+ LOG_ERROR << " Error removing downloaded file: " << e.what ();
124+ }
125+ });
126+ if (system_info.os == " mac" || engineName_ == " cortex.onnx" ) {
127+ return ;
128+ }
129+ // download cuda toolkit
130+ const std::string jan_host = " https://catalog.jan.ai" ;
131+ const std::string cuda_toolkit_file_name = " cuda.tar.gz" ;
132+ const std::string download_id = " cuda" ;
133+
134+ auto gpu_driver_version = system_info_utils::GetDriverVersion ();
135+
136+ auto cuda_runtime_version =
137+ cuda_toolkit_utils::GetCompatibleCudaToolkitVersion (
138+ gpu_driver_version, system_info.os , engineName_);
139+
140+ std::ostringstream cuda_toolkit_path;
141+ cuda_toolkit_path << " dist/cuda-dependencies/" << 11.7 << " /"
142+ << system_info.os << " /"
143+ << cuda_toolkit_file_name;
144+
145+ LOG_DEBUG << " Cuda toolkit download url: " << jan_host
146+ << cuda_toolkit_path.str ();
147+
148+ auto downloadCudaToolkitTask = DownloadTask{
149+ .id = download_id,
150+ .type = DownloadType::CudaToolkit,
151+ .error = std::nullopt ,
152+ .items = {DownloadItem{
153+ .id = download_id,
154+ .host = jan_host,
155+ .fileName = cuda_toolkit_file_name,
156+ .type = DownloadType::CudaToolkit,
157+ .path = cuda_toolkit_path.str (),
158+ }},
159+ };
160+
161+ downloadService.AddDownloadTask (
162+ downloadCudaToolkitTask, [](const std::string& absolute_path) {
163+ LOG_DEBUG << " Downloaded cuda path: " << absolute_path;
164+ // try to unzip the downloaded file
165+ std::filesystem::path downloaded_path{absolute_path};
166+
167+ archive_utils::ExtractArchive (
168+ absolute_path,
169+ downloaded_path.parent_path ().parent_path ().string ());
170+
171+ try {
172+ std::filesystem::remove (absolute_path);
173+ } catch (std::exception& e) {
174+ LOG_ERROR << " Error removing downloaded file: " << e.what ();
175+ }
122176 });
123177
124178 return ;
0 commit comments