diff --git a/src/controllers/usecase.js b/src/controllers/usecase.js index 0acfc4a..7ec8df0 100644 --- a/src/controllers/usecase.js +++ b/src/controllers/usecase.js @@ -469,6 +469,38 @@ module.exports.getExplainerResponse = async (req, res) => { } } +module.exports.getModelPredictResponse = async (req, res) => { + try { + let data = new FormData(); + + data.append('id', req.params.id); + data.append('instance', JSON.stringify(req.body.instance)); + + const usecase = await Usecase.findById(req.params.id) + + let config = { + method: 'post', + url: MODELAPI_URL + 'predict', + headers: { + ...data.getHeaders() + }, + data: data + }; + + const response = await axios(config); + const model_attributes = JSON.parse(usecase.model.attributes) + let output = response.data.predictions[0]; + let target_values = model_attributes.target_values[0] + + let d = {} + for (var i = 0; i < target_values.length; i++){ + d[target_values[i]] = output[i]; + } + res.json(d); + } catch (error) { + res.status(500).json({ message: error }); + } +} function generateRandom(maxLimit = 100) { let rand = Math.random() * maxLimit; diff --git a/src/routes/usecases.js b/src/routes/usecases.js index be37ec6..9bdf07e 100644 --- a/src/routes/usecases.js +++ b/src/routes/usecases.js @@ -22,6 +22,8 @@ router.get('/:id', [isCompanyUsecase], usecasectrl.get); router.get('/:id/casestructure', [isCompanyUsecase], usecasectrl.getCaseStructure); router.get('/:id/sampleDataInstance', [isCompanyUsecase], usecasectrl.getRandomDataInstance); router.post('/:id/explainerResponse', [isCompanyUsecase], usecasectrl.getExplainerResponse); +router.post('/:id/predictResponse', [isCompanyUsecase], usecasectrl.getModelPredictResponse); + // Get all router.get('/', usecasectrl.list);