RDISCROWD-5843 GIGWork GPT Helper Component MVP (v1) - Backend
XiChenn committed Apr 26, 2023
1 parent 1ec80f0 commit b8263a5
Showing 2 changed files with 201 additions and 0 deletions.
72 changes: 72 additions & 0 deletions pybossa/api/__init__.py
@@ -847,3 +847,75 @@ def user_has_partial_answer(short_name=None):
task_id_map = get_user_saved_partial_tasks(sentinel, project_id, current_user.id, task_repo)
response = {"has_answer": bool(task_id_map)}
return Response(json.dumps(response), status=200, mimetype="application/json")


@jsonpify
@csrf.exempt
@blueprint.route('/llm', defaults={'model_name': None}, methods=['POST'])
@blueprint.route('/llm/<model_name>', methods=['POST'])
@ratelimit(limit=ratelimits.get('LIMIT'), per=ratelimits.get('PER'))
def large_language_model(model_name):
"""Large language model endpoint
The POST data format can be:
data = {
"instances": [
{
"context": "Identify the company name: Microsoft will release Windows 20 next year.",
"temperature": 1.0,
"seed": 12345,
"repetition_penalty": 1.05,
"num_beams": 1,
}
]
}
or
data = {
"prompts": "Identify the company name: Microsoft will release Windows 20 next year."
}
"""
if model_name is None:
model_name = 'flan-ul2'
    endpoints = current_app.config.get('LLM_ENDPOINTS') or {}  # tolerate missing config
model_endpoint = endpoints.get(model_name.lower())

if not model_endpoint:
return abort(403, f'Model with name {model_name} does not exist.')

proxies = current_app.config.get('PROXIES')
cert = current_app.config.get('CA_CERT', False)

    try:
        data = request.get_json(force=True)
    except Exception:
        return abort(400, "Invalid JSON data")

if "prompts" not in data and "instances" not in data:
return abort(400, "The JSON should have either 'prompts' or 'instances'")

if "prompts" in data:
prompts = data.get("prompts")
if not prompts:
return abort(400, 'prompts should not be empty')
        if isinstance(prompts, list):
            prompts = prompts[0]  # batch requests are not yet supported; use the first prompt only
        if not isinstance(prompts, str):
            return abort(400, 'prompts should be a string or a list of strings')
data = {
"instances": [
{
"context": prompts + ' ',
"temperature": 1.0,
"seed": 12345,
"repetition_penalty": 1.05,
"num_beams": 1,
}
]
}
data = json.dumps(data)

    # Forward the payload to the model-serving endpoint and unwrap the first prediction.
    r = requests.post(model_endpoint, data=data, proxies=proxies, verify=cert)
    out = json.loads(r.text)
    predictions = out["predictions"][0]["output"]
response = {"Model: ": model_name, "predictions: ": predictions}

return Response(json.dumps(response), status=r.status_code, mimetype="application/json")
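
For context, a minimal sketch of calling the new endpoint from Python. It assumes the API blueprint is mounted under /api on a locally running server and that an LLM_ENDPOINTS entry for 'flan-ul2' exists in the config; the URL and port below are illustrative, not part of this commit.

import requests

# POST to /api/llm; model_name defaults to 'flan-ul2' when omitted from the URL.
resp = requests.post(
    "http://localhost:5000/api/llm",
    json={"prompts": "Identify the company name: Microsoft will release Windows 20 next year."},
)
print(resp.status_code)
print(resp.json())  # e.g. {"Model: ": "flan-ul2", "predictions: ": "Microsoft"}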
129 changes: 129 additions & 0 deletions test/test_api/__init__.py
@@ -15,11 +15,16 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
import json
import unittest
from datetime import datetime
from unittest.mock import patch, MagicMock

from dateutil.parser import parse
from werkzeug.http import parse_cookie

from pybossa.api import large_language_model
from pybossa.core import create_app
from test import Test


@@ -47,3 +52,127 @@ class TestAPI(Test):

endpoints = ['project', 'task', 'taskrun', 'user']

class TestLargeLanguageModel(unittest.TestCase):
def setUp(self):
self.app = create_app(run_as_server=False)
self.app.config['LLM_ENDPOINTS'] = {
'flan-ul2': 'http://localhost:5000/llm'
}
self.client = self.app.test_client()

@patch('requests.post')
def test_valid_request(self, mock_post):
response_data = {
"predictions": [{
"output": "Microsoft"
}]
}
mock_post.return_value = MagicMock(status_code=200, text=json.dumps(response_data))
with self.app.test_request_context('/', json={
"prompts": "Identify the company name: Microsoft will release Windows 20 next year."
}):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 200)
self.assertIn('Model: ', response.json)
self.assertIn('predictions: ', response.json)

@patch('requests.post')
def test_valid_request_with_list_of_prompts(self, mock_post):
response_data = {
"predictions": [{
"output": "Microsoft"
}]
}
mock_post.return_value = MagicMock(status_code=200,
text=json.dumps(response_data))
with self.app.test_request_context('/', json={
"prompts": ["Identify the company name: Microsoft will release Windows 20 next year.", "test"]
}):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 200)
self.assertIn('Model: ', response.json)
self.assertIn('predictions: ', response.json)

@patch('requests.post')
def test_valid_request_with_instances_key_in_json(self, mock_post):
response_data = {
"predictions": [{
"output": "Microsoft"
}]
}
mock_post.return_value = MagicMock(status_code=200,
text=json.dumps(response_data))
with self.app.test_request_context('/', json={
"instances": [
{
"context": "Identify the company name: Microsoft will release Windows 20 next year.",
"temperature": 1.0,
"seed": 12345,
"repetition_penalty": 1.05,
"num_beams": 1,
}
]
}):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 200)
self.assertIn('Model: ', response.json)
self.assertIn('predictions: ', response.json)

@patch('requests.post')
def test_invalid_model_name(self, mock_post):
mock_post.return_value = MagicMock(status_code=403, text='{"error": "Model not found"}')
with self.app.test_request_context('/', json={
"prompts": "Identify the company name: Microsoft will release Windows 20 next year."
}):
response = large_language_model('invalid-model')
self.assertEqual(response.status_code, 403)
self.assertIn('Model with name', response.json.get('exception_msg'))

@patch('requests.post')
def test_invalid_json(self, mock_post):
with self.app.test_request_context('/', data='invalid-json', content_type='application/json'):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 400)
self.assertIn('Invalid JSON', response.json.get('exception_msg'))

@patch('requests.post')
def test_invalid_post_data(self, mock_post):
response_data = {
"predictions": [{
"output": "Microsoft"
}]
}
mock_post.return_value = MagicMock(status_code=200,
text=json.dumps(response_data))
with self.app.test_request_context('/', json={
"invalid": [
{
"context": "Identify the company name: Microsoft will release Windows 20 next year.",
"temperature": 1.0,
"seed": 12345,
"repetition_penalty": 1.05,
"num_beams": 1,
}
]
}):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 400)
self.assertIn('The JSON should have', response.json.get('exception_msg'))

@patch('requests.post')
def test_empty_prompts(self, mock_post):
with self.app.test_request_context('/', json={
"prompts": ""
}):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 400)
self.assertIn('prompts should not be empty', response.json.get('exception_msg'))

@patch('requests.post')
def test_invalid_prompts_type(self, mock_post):
with self.app.test_request_context('/', json={
"prompts": 123
}):
response = large_language_model('flan-ul2')
self.assertEqual(response.status_code, 400)
self.assertIn('prompts should be a string', response.json.get('exception_msg'))
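
The endpoint reads LLM_ENDPOINTS, PROXIES, and CA_CERT from the application config. A minimal settings sketch with illustrative values follows; only the key names come from the code above, and the URL and path are placeholders, not part of this commit.

# Illustrative pybossa settings for the /llm endpoint.
LLM_ENDPOINTS = {
    'flan-ul2': 'https://llm-serving.example.com/v1/models/flan-ul2:predict',
}
PROXIES = None  # or e.g. {'https': 'http://proxy.example.com:8080'}
CA_CERT = '/etc/ssl/certs/internal-ca.pem'  # or False to skip TLS verification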
