11#include " model_upd_cmd.h"
22#include " httplib.h"
3- # include " json/json.h "
3+
44#include " server_start_cmd.h"
55#include " utils/file_manager_utils.h"
66#include " utils/logging_utils.h"
@@ -26,8 +26,7 @@ void ModelUpdCmd::Exec(
2626 Json::Value json_data;
2727 for (const auto & [key, value] : options) {
2828 if (!value.empty ()) {
29- json_data[key] = value;
30- CLI_LOG (" Updated " << key << " to: " << value);
29+ UpdateConfig (json_data, key, value);
3130 }
3231 }
3332 auto data_str = json_data.toStyledString ();
@@ -47,4 +46,287 @@ void ModelUpdCmd::Exec(
4746 return ;
4847 }
4948}
49+
50+ void ModelUpdCmd::UpdateConfig (Json::Value& data, const std::string& key,
51+ const std::string& value) {
52+ static const std::unordered_map<
53+ std::string,
54+ std::function<void (Json::Value &, const std::string&, const std::string&)>>
55+ updaters = {
56+ {" name" ,
57+ [](Json::Value &data, const std::string&, const std::string& v) {
58+ data[" name" ] = v;
59+ }},
60+ {" model" ,
61+ [](Json::Value &data, const std::string&, const std::string& v) {
62+ data[" model" ] = v;
63+ }},
64+ {" version" ,
65+ [](Json::Value &data, const std::string&, const std::string& v) {
66+ data[" version" ] = v;
67+ }},
68+ {" engine" ,
69+ [](Json::Value &data, const std::string&, const std::string& v) {
70+ data[" engine" ] = v;
71+ }},
72+ {" prompt_template" ,
73+ [](Json::Value &data, const std::string&, const std::string& v) {
74+ data[" prompt_template" ] = v;
75+ }},
76+ {" system_template" ,
77+ [](Json::Value &data, const std::string&, const std::string& v) {
78+ data[" system_template" ] = v;
79+ }},
80+ {" user_template" ,
81+ [](Json::Value &data, const std::string&, const std::string& v) {
82+ data[" user_template" ] = v;
83+ }},
84+ {" ai_template" ,
85+ [](Json::Value &data, const std::string&, const std::string& v) {
86+ data[" ai_template" ] = v;
87+ }},
88+ {" os" ,
89+ [](Json::Value &data, const std::string&, const std::string& v) {
90+ data[" os" ] = v;
91+ }},
92+ {" gpu_arch" ,
93+ [](Json::Value &data, const std::string&, const std::string& v) {
94+ data[" gpu_arch" ] = v;
95+ }},
96+ {" quantization_method" ,
97+ [](Json::Value &data, const std::string&, const std::string& v) {
98+ data[" quantization_method" ] = v;
99+ }},
100+ {" precision" ,
101+ [](Json::Value &data, const std::string&, const std::string& v) {
102+ data[" precision" ] = v;
103+ }},
104+ {" trtllm_version" ,
105+ [](Json::Value &data, const std::string&, const std::string& v) {
106+ data[" trtllm_version" ] = v;
107+ }},
108+ {" object" ,
109+ [](Json::Value &data, const std::string&, const std::string& v) {
110+ data[" object" ] = v;
111+ }},
112+ {" owned_by" ,
113+ [](Json::Value &data, const std::string&, const std::string& v) {
114+ data[" owned_by" ] = v;
115+ }},
116+ {" grammar" ,
117+ [](Json::Value &data, const std::string&, const std::string& v) {
118+ data[" grammar" ] = v;
119+ }},
120+ {" stop" , [this ](Json::Value &data, const std::string& k, const std::string& v) {
121+ UpdateVectorField (
122+ k, v, [&data](const std::vector<std::string>& stops) {
123+ Json::Value d (Json::arrayValue);
124+ for (auto const & s: stops) {
125+ d.append (s);
126+ }
127+ data[" stop" ] = d;
128+ });
129+ }},
130+ {" files" , [this ](Json::Value &data, const std::string& k, const std::string& v) {
131+ UpdateVectorField (
132+ k, v, [&data](const std::vector<std::string>& fs) {
133+ Json::Value d (Json::arrayValue);
134+ for (auto const & f: fs) {
135+ d.append (f);
136+ }
137+ data[" files" ] = d;
138+ });
139+ }},
140+ {" top_p" ,
141+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
142+ UpdateNumericField (
143+ k, v, [&data](float f) { data[" top_p" ] = f; });
144+ }},
145+ {" temperature" ,
146+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
147+ UpdateNumericField (k, v, [&data](float f) {
148+ data[" temperature" ] = f;
149+ });
150+ }},
151+ {" frequency_penalty" ,
152+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
153+ UpdateNumericField (k, v, [&data](float f) {
154+ data[" frequency_penalty" ] = f;
155+ });
156+ }},
157+ {" presence_penalty" ,
158+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
159+ UpdateNumericField (k, v, [&data](float f) {
160+ data[" presence_penalty" ] = f;
161+ });
162+ }},
163+ {" dynatemp_range" ,
164+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
165+ UpdateNumericField (k, v, [&data](float f) {
166+ data[" dynatemp_range" ] = f;
167+ });
168+ }},
169+ {" dynatemp_exponent" ,
170+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
171+ UpdateNumericField (k, v, [&data](float f) {
172+ data[" dynatemp_exponent" ] = f;
173+ });
174+ }},
175+ {" min_p" ,
176+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
177+ UpdateNumericField (
178+ k, v, [&data](float f) { data[" min_p" ] = f; });
179+ }},
180+ {" tfs_z" ,
181+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
182+ UpdateNumericField (
183+ k, v, [&data](float f) { data[" tfs_z" ] = f; });
184+ }},
185+ {" typ_p" ,
186+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
187+ UpdateNumericField (
188+ k, v, [&data](float f) { data[" typ_p" ] = f; });
189+ }},
190+ {" repeat_penalty" ,
191+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
192+ UpdateNumericField (k, v, [&data](float f) {
193+ data[" repeat_penalty" ] = f;
194+ });
195+ }},
196+ {" mirostat_tau" ,
197+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
198+ UpdateNumericField (k, v, [&data](float f) {
199+ data[" mirostat_tau" ] = f;
200+ });
201+ }},
202+ {" mirostat_eta" ,
203+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
204+ UpdateNumericField (k, v, [&data](float f) {
205+ data[" mirostat_eta" ] = f;
206+ });
207+ }},
208+ {" max_tokens" ,
209+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
210+ UpdateNumericField (k, v, [&data](float f) {
211+ data[" max_tokens" ] = static_cast <int >(f);
212+ });
213+ }},
214+ {" ngl" ,
215+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
216+ UpdateNumericField (k, v, [&data](float f) {
217+ data[" ngl" ] = static_cast <int >(f);
218+ });
219+ }},
220+ {" ctx_len" ,
221+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
222+ UpdateNumericField (k, v, [&data](float f) {
223+ data[" ctx_len" ] = static_cast <int >(f);
224+ });
225+ }},
226+ {" tp" ,
227+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
228+ UpdateNumericField (k, v, [&data](float f) {
229+ data[" tp" ] = static_cast <int >(f);
230+ });
231+ }},
232+ {" seed" ,
233+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
234+ UpdateNumericField (k, v, [&data](float f) {
235+ data[" seed" ] = static_cast <int >(f);
236+ });
237+ }},
238+ {" top_k" ,
239+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
240+ UpdateNumericField (k, v, [&data](float f) {
241+ data[" top_k" ] = static_cast <int >(f);
242+ });
243+ }},
244+ {" repeat_last_n" ,
245+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
246+ UpdateNumericField (k, v, [&data](float f) {
247+ data[" repeat_last_n" ] = static_cast <int >(f);
248+ });
249+ }},
250+ {" n_probs" ,
251+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
252+ UpdateNumericField (k, v, [&data](float f) {
253+ data[" n_probs" ] = static_cast <int >(f);
254+ });
255+ }},
256+ {" min_keep" ,
257+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
258+ UpdateNumericField (k, v, [&data](float f) {
259+ data[" min_keep" ] = static_cast <int >(f);
260+ });
261+ }},
262+ {" stream" ,
263+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
264+ UpdateBooleanField (
265+ k, v, [&data](bool b) { data[" stream" ] = b; });
266+ }},
267+ {" text_model" ,
268+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
269+ UpdateBooleanField (
270+ k, v, [&data](bool b) { data[" text_model" ] = b; });
271+ }},
272+ {" mirostat" ,
273+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
274+ UpdateBooleanField (
275+ k, v, [&data](bool b) { data[" mirostat" ] = b; });
276+ }},
277+ {" penalize_nl" ,
278+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
279+ UpdateBooleanField (
280+ k, v, [&data](bool b) { data[" penalize_nl" ] = b; });
281+ }},
282+ {" ignore_eos" ,
283+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
284+ UpdateBooleanField (
285+ k, v, [&data](bool b) { data[" ignore_eos" ] = b; });
286+ }},
287+ {" created" ,
288+ [this ](Json::Value &data, const std::string& k, const std::string& v) {
289+ UpdateNumericField (k, v, [&data](float f) {
290+ data[" created" ] = static_cast <std::size_t >(f);
291+ });
292+ }},
293+ };
294+
295+ if (auto it = updaters.find (key); it != updaters.end ()) {
296+ it->second (data, key, value);
297+ CLI_LOG (" Updated " << key << " to: " << value);
298+ } else {
299+ CLI_LOG (" Warning: Unknown configuration key '" << key << " ' ignored." );
300+ }
301+ }
302+
303+ void ModelUpdCmd::UpdateVectorField (
304+ const std::string& key, const std::string& value,
305+ std::function<void (const std::vector<std::string>&)> setter) {
306+ std::vector<std::string> tokens;
307+ std::istringstream iss (value);
308+ std::string token;
309+ while (std::getline (iss, token, ' ,' )) {
310+ tokens.push_back (token);
311+ }
312+ setter (tokens);
313+ }
314+
315+ void ModelUpdCmd::UpdateNumericField (const std::string& key,
316+ const std::string& value,
317+ std::function<void (float )> setter) {
318+ try {
319+ float numericValue = std::stof (value);
320+ setter (numericValue);
321+ } catch (const std::exception& e) {
322+ CLI_LOG (" Failed to parse numeric value for " << key << " : " << e.what ());
323+ }
324+ }
325+
326+ void ModelUpdCmd::UpdateBooleanField (const std::string& key,
327+ const std::string& value,
328+ std::function<void (bool )> setter) {
329+ bool boolValue = (value == " true" || value == " 1" );
330+ setter (boolValue);
331+ }
50332} // namespace commands
0 commit comments