From b8263a5d9d6ac8c12cfaaed5f30234a9f08e84de Mon Sep 17 00:00:00 2001
From: Xi Chen
Date: Wed, 26 Apr 2023 10:13:54 -0400
Subject: [PATCH 1/3] RDISCROWD-5843 GIGWork GPT Helper Component MVP (v1) - Backend

---
 pybossa/api/__init__.py   |  72 +++++++++++++++++++++
 test/test_api/__init__.py | 129 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 201 insertions(+)

diff --git a/pybossa/api/__init__.py b/pybossa/api/__init__.py
index d2f27b55a8..5eedf257b6 100644
--- a/pybossa/api/__init__.py
+++ b/pybossa/api/__init__.py
@@ -847,3 +847,75 @@ def user_has_partial_answer(short_name=None):
     task_id_map = get_user_saved_partial_tasks(sentinel, project_id, current_user.id, task_repo)
     response = {"has_answer": bool(task_id_map)}
     return Response(json.dumps(response), status=200, mimetype="application/json")
+
+
+@jsonpify
+@csrf.exempt
+@blueprint.route('/llm', defaults={'model_name': None}, methods=['POST'])
+@blueprint.route('/llm/<model_name>', methods=['POST'])
+@ratelimit(limit=ratelimits.get('LIMIT'), per=ratelimits.get('PER'))
+def large_language_model(model_name):
+    """Large language model endpoint
+    The POST data format can be:
+    data = {
+        "instances": [
+            {
+                "context": "Identify the company name: Microsoft will release Windows 20 next year.",
+                "temperature": 1.0,
+                "seed": 12345,
+                "repetition_penalty": 1.05,
+                "num_beams": 1,
+            }
+        ]
+    }
+    or
+    data = {
+        "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
+    }
+    """
+    if model_name is None:
+        model_name = 'flan-ul2'
+    endpoints = current_app.config.get('LLM_ENDPOINTS')
+    model_endpoint = endpoints.get(model_name.lower())
+
+    if not model_endpoint:
+        return abort(403, f'Model with name {model_name} does not exist.')
+
+    proxies = current_app.config.get('PROXIES')
+    cert = current_app.config.get('CA_CERT', False)
+
+    try:
+        data = request.get_json(force=True)
+    except Exception:
+        return abort(400, "Invalid JSON data")
+
+    if "prompts" not in data and "instances" not in data:
+        return abort(400, "The JSON should have either 'prompts' or 'instances'")
+
+    if "prompts" in data:
+        prompts = data.get("prompts")
+        if not prompts:
+            return abort(400, 'prompts should not be empty')
+        if isinstance(prompts, list):
+            prompts = prompts[0]  # Batch request temporarily NOT supported
+        if not isinstance(prompts, str):
+            return abort(400, f'prompts should be a string a list of strings')
+        data = {
+            "instances": [
+                {
+                    "context": prompts + ' ',
+                    "temperature": 1.0,
+                    "seed": 12345,
+                    "repetition_penalty": 1.05,
+                    "num_beams": 1,
+                }
+            ]
+        }
+        data = json.dumps(data)
+
+    r = requests.post(model_endpoint, data=data, proxies=proxies, verify=cert)
+    out = json.loads(r.text)
+    predictions = out["predictions"][0]["output"]
+    response = {"Model: ": model_name, "predictions: ": predictions}
+
+    return Response(json.dumps(response), status=r.status_code, mimetype="application/json")
diff --git a/test/test_api/__init__.py b/test/test_api/__init__.py
index 6bae0ab383..22f78007a3 100644
--- a/test/test_api/__init__.py
+++ b/test/test_api/__init__.py
@@ -15,11 +15,16 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with PYBOSSA. If not, see <https://www.gnu.org/licenses/>.
+import json
+import unittest
 from datetime import datetime
+from unittest.mock import patch, MagicMock
 
 from dateutil.parser import parse
 from werkzeug.http import parse_cookie
 
+from pybossa.api import large_language_model
+from pybossa.core import create_app
 from test import Test
 
 
@@ -47,3 +52,127 @@ class TestAPI(Test):
     endpoints = ['project', 'task', 'taskrun', 'user']
 
 
+class TestLargeLanguageModel(unittest.TestCase):
+    def setUp(self):
+        self.app = create_app(run_as_server=False)
+        self.app.config['LLM_ENDPOINTS'] = {
+            'flan-ul2': 'http://localhost:5000/llm'
+        }
+        self.client = self.app.test_client()
+
+    @patch('requests.post')
+    def test_valid_request(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200, text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('Model: ', response.json)
+            self.assertIn('predictions: ', response.json)
+
+    @patch('requests.post')
+    def test_valid_request_with_list_of_prompts(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200,
+                                           text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "prompts": ["Identify the company name: Microsoft will release Windows 20 next year.", "test"]
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('Model: ', response.json)
+            self.assertIn('predictions: ', response.json)
+
+    @patch('requests.post')
+    def test_valid_request_with_instances_key_in_json(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200,
+                                           text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "instances": [
+                {
+                    "context": "Identify the company name: Microsoft will release Windows 20 next year.",
+                    "temperature": 1.0,
+                    "seed": 12345,
+                    "repetition_penalty": 1.05,
+                    "num_beams": 1,
+                }
+            ]
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 200)
+            self.assertIn('Model: ', response.json)
+            self.assertIn('predictions: ', response.json)
+
+    @patch('requests.post')
+    def test_invalid_model_name(self, mock_post):
+        mock_post.return_value = MagicMock(status_code=403, text='{"error": "Model not found"}')
+        with self.app.test_request_context('/', json={
+            "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
+        }):
+            response = large_language_model('invalid-model')
+            self.assertEqual(response.status_code, 403)
+            self.assertIn('Model with name', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_invalid_json(self, mock_post):
+        with self.app.test_request_context('/', data='invalid-json', content_type='application/json'):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('Invalid JSON', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_invalid_post_data(self, mock_post):
+        response_data = {
+            "predictions": [{
+                "output": "Microsoft"
+            }]
+        }
+        mock_post.return_value = MagicMock(status_code=200,
+                                           text=json.dumps(response_data))
+        with self.app.test_request_context('/', json={
+            "invalid": [
+                {
+                    "context": "Identify the company name: Microsoft will release Windows 20 next year.",
+                    "temperature": 1.0,
+                    "seed": 12345,
+                    "repetition_penalty": 1.05,
+                    "num_beams": 1,
+                }
+            ]
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('The JSON should have', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_empty_prompts(self, mock_post):
+        with self.app.test_request_context('/', json={
+            "prompts": ""
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('prompts should not be empty', response.json.get('exception_msg'))
+
+    @patch('requests.post')
+    def test_invalid_prompts_type(self, mock_post):
+        with self.app.test_request_context('/', json={
+            "prompts": 123
+        }):
+            response = large_language_model('flan-ul2')
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('prompts should be a string', response.json.get('exception_msg'))

From 4c1fc67d9e242403949e75d83fc0ef06282d509d Mon Sep 17 00:00:00 2001
From: Xi Chen
Date: Wed, 26 Apr 2023 11:41:38 -0400
Subject: [PATCH 2/3] update comment

---
 pybossa/api/__init__.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pybossa/api/__init__.py b/pybossa/api/__init__.py
index 5eedf257b6..9b2f104a71 100644
--- a/pybossa/api/__init__.py
+++ b/pybossa/api/__init__.py
@@ -856,8 +856,8 @@ def user_has_partial_answer(short_name=None):
 @ratelimit(limit=ratelimits.get('LIMIT'), per=ratelimits.get('PER'))
 def large_language_model(model_name):
     """Large language model endpoint
-    The POST data format can be:
-    data = {
+    The JSON data in the POST request can be one of the following:
+    {
         "instances": [
             {
                 "context": "Identify the company name: Microsoft will release Windows 20 next year.",
@@ -869,7 +869,7 @@ def large_language_model(model_name):
         ]
     }
     or
-    data = {
+    {
         "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
     }
     """

From 17a99c75f7a1044cd2445080e041a624a4b7fe12 Mon Sep 17 00:00:00 2001
From: Xi Chen
Date: Wed, 26 Apr 2023 12:06:05 -0400
Subject: [PATCH 3/3] Address code review comments

---
 pybossa/api/__init__.py   | 4 ++--
 test/test_api/__init__.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pybossa/api/__init__.py b/pybossa/api/__init__.py
index 9b2f104a71..07d612ff56 100644
--- a/pybossa/api/__init__.py
+++ b/pybossa/api/__init__.py
@@ -879,7 +879,7 @@ def large_language_model(model_name):
     model_endpoint = endpoints.get(model_name.lower())
 
     if not model_endpoint:
-        return abort(403, f'Model with name {model_name} does not exist.')
+        return abort(400, f'{model_name} LLM is unsupported on this platform.')
 
     proxies = current_app.config.get('PROXIES')
     cert = current_app.config.get('CA_CERT', False)
@@ -899,7 +899,7 @@ def large_language_model(model_name):
         if isinstance(prompts, list):
             prompts = prompts[0]  # Batch request temporarily NOT supported
         if not isinstance(prompts, str):
-            return abort(400, f'prompts should be a string a list of strings')
+            return abort(400, 'prompts should be a string or a list of strings')
         data = {
             "instances": [
                 {
diff --git a/test/test_api/__init__.py b/test/test_api/__init__.py
index 22f78007a3..d939b2d6e2 100644
--- a/test/test_api/__init__.py
+++ b/test/test_api/__init__.py
@@ -125,8 +125,8 @@ def test_invalid_model_name(self, mock_post):
             "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
         }):
             response = large_language_model('invalid-model')
-            self.assertEqual(response.status_code, 403)
-            self.assertIn('Model with name', response.json.get('exception_msg'))
+            self.assertEqual(response.status_code, 400)
+            self.assertIn('LLM is unsupported', response.json.get('exception_msg'))
 
     @patch('requests.post')
     def test_invalid_json(self, mock_post):
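
Usage sketch: with the patches above applied, the endpoint can be exercised
from any HTTP client. The snippet below is illustrative rather than part of
the patches: it assumes a PyBossa instance at http://localhost:5000 with the
API blueprint mounted under /api and a 'flan-ul2' entry present in the
server's LLM_ENDPOINTS config; adjust host, mount point, and model name to
match your deployment.

    import requests

    BASE_URL = "http://localhost:5000"  # assumed local PyBossa instance

    # Simple form: a single prompt string. The server wraps it into an
    # "instances" payload with default decoding parameters before forwarding
    # it to the configured model endpoint. Omitting the model name from the
    # URL selects the default model, flan-ul2.
    resp = requests.post(
        BASE_URL + "/api/llm",
        json={"prompts": "Identify the company name: "
                         "Microsoft will release Windows 20 next year."})
    print(resp.status_code, resp.json())

    # Full form: pass "instances" directly to control the decoding
    # parameters, addressing a specific configured model in the URL.
    payload = {
        "instances": [{
            "context": "Identify the company name: Microsoft will release Windows 20 next year.",
            "temperature": 1.0,
            "seed": 12345,
            "repetition_penalty": 1.05,
            "num_beams": 1,
        }]
    }
    resp = requests.post(BASE_URL + "/api/llm/flan-ul2", json=payload)
    print(resp.status_code, resp.json())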