RDISCROWD-5843 GIGWork GPT Helper Component MVP (v1) - Backend #840

Merged · 3 commits · Apr 26, 2023
72 changes: 72 additions & 0 deletions pybossa/api/__init__.py
@@ -847,3 +847,75 @@ def user_has_partial_answer(short_name=None):
    task_id_map = get_user_saved_partial_tasks(sentinel, project_id, current_user.id, task_repo)
    response = {"has_answer": bool(task_id_map)}
    return Response(json.dumps(response), status=200, mimetype="application/json")


@jsonpify
@csrf.exempt
@blueprint.route('/llm', defaults={'model_name': None}, methods=['POST'])
@blueprint.route('/llm/<model_name>', methods=['POST'])
@ratelimit(limit=ratelimits.get('LIMIT'), per=ratelimits.get('PER'))
def large_language_model(model_name):
    """Large language model endpoint.

    The POST data format can be:
    data = {
        "instances": [
            {
                "context": "Identify the company name: Microsoft will release Windows 20 next year.",
                "temperature": 1.0,
                "seed": 12345,
                "repetition_penalty": 1.05,
                "num_beams": 1,
            }
        ]
    }
    or
    data = {
        "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
    }
    """
    if model_name is None:
        model_name = 'flan-ul2'
    endpoints = current_app.config.get('LLM_ENDPOINTS')
    model_endpoint = endpoints.get(model_name.lower())

    if not model_endpoint:
        return abort(403, f'Model with name {model_name} does not exist.')

    proxies = current_app.config.get('PROXIES')
    cert = current_app.config.get('CA_CERT', False)

    try:
        data = request.get_json(force=True)
    except Exception:
        return abort(400, "Invalid JSON data")

    if "prompts" not in data and "instances" not in data:
        return abort(400, "The JSON should have either 'prompts' or 'instances'")

    if "prompts" in data:
        prompts = data.get("prompts")
        if not prompts:
            return abort(400, 'prompts should not be empty')
        if isinstance(prompts, list):
            prompts = prompts[0]  # Batch request temporarily NOT supported
        if not isinstance(prompts, str):
            return abort(400, 'prompts should be a string or a list of strings')
        data = {
            "instances": [
                {
                    "context": prompts + ' ',
                    "temperature": 1.0,
                    "seed": 12345,
                    "repetition_penalty": 1.05,
                    "num_beams": 1,
                }
            ]
        }
    data = json.dumps(data)

    r = requests.post(model_endpoint, data=data, proxies=proxies, verify=cert)
    out = json.loads(r.text)
    predictions = out["predictions"][0]["output"]
    response = {"Model: ": model_name, "predictions: ": predictions}

    return Response(json.dumps(response), status=r.status_code, mimetype="application/json")
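
For anyone trying the new endpoint by hand, here is a minimal client sketch. The base URL and the /api prefix are assumptions about a local dev setup, not part of this change:

import requests

# Assumption: a local PYBOSSA instance running this branch, with the
# API blueprint mounted under /api, so the routes above resolve to /api/llm.
BASE_URL = 'http://localhost:5000/api'

# Short form: a bare prompt; the endpoint wraps it in the default
# "instances" payload before forwarding it to the model endpoint.
r = requests.post(f'{BASE_URL}/llm', json={
    'prompts': 'Identify the company name: Microsoft will release Windows 20 next year.'
})
print(r.json())

# Full form: pass "instances" directly to control the generation
# parameters, and pick the model in the URL.
r = requests.post(f'{BASE_URL}/llm/flan-ul2', json={
    'instances': [{
        'context': 'Identify the company name: Microsoft will release Windows 20 next year.',
        'temperature': 1.0,
        'seed': 12345,
        'repetition_penalty': 1.05,
        'num_beams': 1,
    }]
})
print(r.json())
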
129 changes: 129 additions & 0 deletions test/test_api/__init__.py
@@ -15,11 +15,16 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
import json
import unittest
from datetime import datetime
from unittest.mock import patch, MagicMock

from dateutil.parser import parse
from werkzeug.http import parse_cookie

from pybossa.api import large_language_model
from pybossa.core import create_app
from test import Test


@@ -47,3 +52,127 @@ class TestAPI(Test):

    endpoints = ['project', 'task', 'taskrun', 'user']

class TestLargeLanguageModel(unittest.TestCase):
    def setUp(self):
        self.app = create_app(run_as_server=False)
        self.app.config['LLM_ENDPOINTS'] = {
            'flan-ul2': 'http://localhost:5000/llm'
        }
        self.client = self.app.test_client()

    @patch('requests.post')
    def test_valid_request(self, mock_post):
        response_data = {
            "predictions": [{
                "output": "Microsoft"
            }]
        }
        mock_post.return_value = MagicMock(status_code=200, text=json.dumps(response_data))
        with self.app.test_request_context('/', json={
            "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
        }):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 200)
            self.assertIn('Model: ', response.json)
            self.assertIn('predictions: ', response.json)

    @patch('requests.post')
    def test_valid_request_with_list_of_prompts(self, mock_post):
        response_data = {
            "predictions": [{
                "output": "Microsoft"
            }]
        }
        mock_post.return_value = MagicMock(status_code=200,
                                           text=json.dumps(response_data))
        with self.app.test_request_context('/', json={
            "prompts": ["Identify the company name: Microsoft will release Windows 20 next year.", "test"]
        }):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 200)
            self.assertIn('Model: ', response.json)
            self.assertIn('predictions: ', response.json)

    @patch('requests.post')
    def test_valid_request_with_instances_key_in_json(self, mock_post):
        response_data = {
            "predictions": [{
                "output": "Microsoft"
            }]
        }
        mock_post.return_value = MagicMock(status_code=200,
                                           text=json.dumps(response_data))
        with self.app.test_request_context('/', json={
            "instances": [
                {
                    "context": "Identify the company name: Microsoft will release Windows 20 next year.",
                    "temperature": 1.0,
                    "seed": 12345,
                    "repetition_penalty": 1.05,
                    "num_beams": 1,
                }
            ]
        }):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 200)
            self.assertIn('Model: ', response.json)
            self.assertIn('predictions: ', response.json)

    @patch('requests.post')
    def test_invalid_model_name(self, mock_post):
        mock_post.return_value = MagicMock(status_code=403, text='{"error": "Model not found"}')
        with self.app.test_request_context('/', json={
            "prompts": "Identify the company name: Microsoft will release Windows 20 next year."
        }):
            response = large_language_model('invalid-model')
            self.assertEqual(response.status_code, 403)
            self.assertIn('Model with name', response.json.get('exception_msg'))

    @patch('requests.post')
    def test_invalid_json(self, mock_post):
        with self.app.test_request_context('/', data='invalid-json', content_type='application/json'):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 400)
            self.assertIn('Invalid JSON', response.json.get('exception_msg'))

    @patch('requests.post')
    def test_invalid_post_data(self, mock_post):
        response_data = {
            "predictions": [{
                "output": "Microsoft"
            }]
        }
        mock_post.return_value = MagicMock(status_code=200,
                                           text=json.dumps(response_data))
        with self.app.test_request_context('/', json={
            "invalid": [
                {
                    "context": "Identify the company name: Microsoft will release Windows 20 next year.",
                    "temperature": 1.0,
                    "seed": 12345,
                    "repetition_penalty": 1.05,
                    "num_beams": 1,
                }
            ]
        }):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 400)
            self.assertIn('The JSON should have', response.json.get('exception_msg'))

    @patch('requests.post')
    def test_empty_prompts(self, mock_post):
        with self.app.test_request_context('/', json={
            "prompts": ""
        }):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 400)
            self.assertIn('prompts should not be empty', response.json.get('exception_msg'))

    @patch('requests.post')
    def test_invalid_prompts_type(self, mock_post):
        with self.app.test_request_context('/', json={
            "prompts": 123
        }):
            response = large_language_model('flan-ul2')
            self.assertEqual(response.status_code, 400)
            self.assertIn('prompts should be a string', response.json.get('exception_msg'))
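
The endpoint reads three config values (LLM_ENDPOINTS, PROXIES, CA_CERT). A minimal settings sketch with placeholder values, assuming the usual PYBOSSA settings module; the real values are deployment-specific:

# Hypothetical settings_local.py excerpt -- placeholder values only.
# LLM_ENDPOINTS maps a lowercase model name to its model-serving URL.
LLM_ENDPOINTS = {
    'flan-ul2': 'http://localhost:5000/llm',
}
# Proxies mapping passed straight to requests.post; None means no proxy.
PROXIES = None
# CA bundle path for TLS verification, or False to skip verification.
CA_CERT = False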