From 52154b8d6796464b7dfeb90f85b6b50a664a6843 Mon Sep 17 00:00:00 2001
From: Xi Chen
Date: Wed, 26 Apr 2023 14:09:54 -0400
Subject: [PATCH] RDISCROWD-5843 GIGWork GPT Helper Component MVP (v1) -
 Backend (#840)

* RDISCROWD-5843 GIGWork GPT Helper Component MVP (v1) - Backend

* update comment

* Address code review comments
---
 pybossa/api/__init__.py   |  72 +++++++++++++++++++++
 test/test_api/__init__.py | 129 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 201 insertions(+)

diff --git a/pybossa/api/__init__.py b/pybossa/api/__init__.py
index d2f27b55a8..07d612ff56 100644
--- a/pybossa/api/__init__.py
+++ b/pybossa/api/__init__.py
@@ -847,3 +847,75 @@ def user_has_partial_answer(short_name=None):
     task_id_map = get_user_saved_partial_tasks(sentinel, project_id, current_user.id, task_repo)
     response = {"has_answer": bool(task_id_map)}
     return Response(json.dumps(response), status=200, mimetype="application/json")
+
+
+@jsonpify
+@csrf.exempt
+@blueprint.route('/llm', defaults={'model_name': None}, methods=['POST'])
+@blueprint.route('/llm/<model_name>', methods=['POST'])
+@ratelimit(limit=ratelimits.get('LIMIT'), per=ratelimits.get('PER'))
+def large_language_model(model_name):
+    """Large language model endpoint
+    The JSON data in the POST request can be one of the following:
+    {
+        "instances": [
+            {
+                "context": "Identify the company name: Microsoft will release Windows 20 next year.",
+                "temperature": 1.0,
+                "seed": 12345,
+                "repetition_penalty": 1.05,
+                "num_beams": 1
+            }
+        ]
+    }
+    or
+    {
+        "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
+    }
+    """
+    if model_name is None:
+        model_name = 'flan-ul2'
+    endpoints = current_app.config.get('LLM_ENDPOINTS', {})
+    model_endpoint = endpoints.get(model_name.lower())
+
+    if not model_endpoint:
+        return abort(400, f'{model_name} LLM is unsupported on this platform.')
+
+    proxies = current_app.config.get('PROXIES')
+    cert = current_app.config.get('CA_CERT', False)
+
+    try:
+        data = request.get_json(force=True)
+    except Exception:
+        return abort(400, "Invalid JSON data")
+
+    if "prompts" not in data and "instances" not in data:
+        return abort(400, "The JSON should have either 'prompts' or 'instances'")
+
+    if "prompts" in data:
+        prompts = data.get("prompts")
+        if not prompts:
+            return abort(400, 'prompts should not be empty')
+        if isinstance(prompts, list):
+            prompts = prompts[0]  # Batch requests are temporarily NOT supported.
+        if not isinstance(prompts, str):
+            return abort(400, 'prompts should be a string or a list of strings')
+        data = {
+            "instances": [
+                {
+                    "context": prompts + ' ',
+                    "temperature": 1.0,
+                    "seed": 12345,
+                    "repetition_penalty": 1.05,
+                    "num_beams": 1,
+                }
+            ]
+        }
+    data = json.dumps(data)
+
+    r = requests.post(model_endpoint, data=data, proxies=proxies, verify=cert)
+    out = json.loads(r.text)
+    predictions = out["predictions"][0]["output"]
+    response = {"Model: ": model_name, "predictions: ": predictions}
+
+    return Response(json.dumps(response), status=r.status_code, mimetype="application/json")
diff --git a/test/test_api/__init__.py b/test/test_api/__init__.py
index 6bae0ab383..d939b2d6e2 100644
--- a/test/test_api/__init__.py
+++ b/test/test_api/__init__.py
@@ -15,11 +15,16 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
+import json
+import unittest
 from datetime import datetime
+from unittest.mock import patch, MagicMock
 
 from dateutil.parser import parse
 from werkzeug.http import parse_cookie
+from pybossa.api import large_language_model
+from pybossa.core import create_app
 from test import Test
 
 
@@ -47,3 +52,127 @@ class TestAPI(Test):
 
     endpoints = ['project', 'task', 'taskrun', 'user']
 
+
+class TestLargeLanguageModel(unittest.TestCase):
+    def setUp(self):
+        self.app = create_app(run_as_server=False)
+        self.app.config['LLM_ENDPOINTS'] = {
+            'flan-ul2': 'http://localhost:5000/llm'
+        }
+        self.client = self.app.test_client()
+
+    @patch('requests.post')
+    def test_valid_request(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200, text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('Model: ', response.json)
+            self.assertIn('predictions: ', response.json)
+
+    @patch('requests.post')
+    def test_valid_request_with_list_of_prompts(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200,
+                                           text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "prompts": ["Identify the company name: Microsoft will release Windows 20 next year.", "test"]
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('Model: ', response.json)
+            self.assertIn('predictions: ', response.json)
+
+    @patch('requests.post')
+    def test_valid_request_with_instances_key_in_json(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200,
+                                           text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "instances": [
+                {
+                    "context": "Identify the company name: Microsoft will release Windows 20 next year.",
+                    "temperature": 1.0,
+                    "seed": 12345,
+                    "repetition_penalty": 1.05,
+                    "num_beams": 1,
+                }
+            ]
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('Model: ', response.json)
+            self.assertIn('predictions: ', response.json)
+
+    @patch('requests.post')
+    def test_invalid_model_name(self, mock_post):
+        mock_post.return_value = MagicMock(status_code=403, text='{"error": "Model not found"}')
+        with self.app.test_request_context('/', json={
+            "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
+        }):
+            response = large_language_model('invalid-model')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('LLM is unsupported', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_invalid_json(self, mock_post):
+        with self.app.test_request_context('/', data='invalid-json', content_type='application/json'):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('Invalid JSON', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_invalid_post_data(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200,
+                                           text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "invalid": [
+                {
+                    "context": "Identify the company name: Microsoft will release Windows 20 next year.",
+                    "temperature": 1.0,
+                    "seed": 12345,
+                    "repetition_penalty": 1.05,
+                    "num_beams": 1,
+                }
+            ]
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('The JSON should have', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_empty_prompts(self, mock_post):
+        with self.app.test_request_context('/', json={
+            "prompts": ""
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('prompts should not be empty', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_invalid_prompts_type(self, mock_post):
+        with self.app.test_request_context('/', json={
+            "prompts": 123
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('prompts should be a string', response.json.get('exception_msg'))
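
A minimal usage sketch, not part of the patch: how a client might call the new
endpoint once this change is deployed. It assumes the API blueprint is mounted
at /api (the usual PyBOSSA layout), a server reachable at
http://localhost:5000, and a configured 'flan-ul2' backend; the host and any
authentication handling are assumptions, not part of this change.

    import requests

    # Simple form: send a bare prompt; the server wraps it into the
    # "instances" payload before forwarding it to the model endpoint.
    resp = requests.post(
        "http://localhost:5000/api/llm/flan-ul2",
        json={"prompts": "Identify the company name: Microsoft will "
                         "release Windows 20 next year."},
    )
    print(resp.status_code, resp.json())

    # Full form: send a fully specified "instances" payload; omitting the
    # model name from the URL falls back to the default 'flan-ul2'.
    resp = requests.post(
        "http://localhost:5000/api/llm",
        json={
            "instances": [{
                "context": "Identify the company name: Microsoft will "
                           "release Windows 20 next year.",
                "temperature": 1.0,
                "seed": 12345,
                "repetition_penalty": 1.05,
                "num_beams": 1
            }]
        },
    )
    print(resp.status_code, resp.json())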