Custom providers (#122)
* Removed Python backend files that are no longer used (everything in `promptengine`)

* Added `providers` subdomain, with `CustomProviderProtocol`, `provider` decorator, and global singleton `ProviderRegistry` (a minimal usage sketch follows this list)

* Added a tab for custom providers, and a dropzone, in ChainForge global settings UI

* List custom providers in the Global Settings screen once added. 

* Added ability to remove custom providers by clicking X. 

* Make custom provider functions synchronous, but call them asynchronously.

* Add Cohere custom provider example in examples/

* Cache the custom provider scripts and load them upon page load

* Rebuild react and update package version

* Bug fix when custom provider is deleted and settings screen is opened on the deleted custom provider
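
For orientation, here is a minimal sketch of what a custom provider script looks like under the new API, modeled on the Cohere example added in this commit. The `Echo` provider, its model name, and the omission of `settings_schema` are illustrative assumptions, not details confirmed by this diff.

```python
# Hypothetical minimal provider script (illustrative only), modeled on
# chainforge/examples/custom_provider_cohere.py below. Assumes settings_schema
# may be omitted when a provider has no configurable settings.
from chainforge.providers import provider

@provider(name="Echo",
          emoji="🔁",
          models=["echo-1"],          # illustrative model name
          rate_limit="sequential")    # "sequential" = blocking calls
def EchoCompletion(prompt: str, model: str, **kwargs) -> str:
    # A real provider would call an LLM API here; this one just echoes the prompt.
    return f"[{model}] {prompt}"
```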
ianarawjo committed Aug 27, 2023
1 parent f43861f commit 0134dbf
Showing 39 changed files with 1,132 additions and 1,347 deletions.
55 changes: 55 additions & 0 deletions chainforge/examples/custom_provider_cohere.py
@@ -0,0 +1,55 @@
"""
A simple custom model provider to add to the ChainForge interface,
to support Cohere AI text completions through their Python API.
NOTE: You must have the `cohere` package installed and an API key.
"""
from chainforge.providers import provider
import cohere

# Init the Cohere client (replace with your Cohere API Key)
co = cohere.Client('<YOUR_API_KEY>')

# JSON schemas to pass react-jsonschema-form, one for this endpoints' settings and one to describe the settings UI.
COHERE_SETTINGS_SCHEMA = {
"settings": {
"temperature": {
"type": "number",
"title": "temperature",
"description": "Controls the 'creativity' or randomness of the response.",
"default": 0.75,
"minimum": 0,
"maximum": 5.0,
"multipleOf": 0.01,
},
"max_tokens": {
"type": "integer",
"title": "max_tokens",
"description": "Maximum number of tokens to generate in the response.",
"default": 100,
"minimum": 1,
"maximum": 1024,
},
},
"ui": {
"temperature": {
"ui:help": "Defaults to 1.0.",
"ui:widget": "range"
},
"max_tokens": {
"ui:help": "Defaults to 100.",
"ui:widget": "range"
},
}
}

# Our custom model provider for Cohere's text generation API.
@provider(name="Cohere",
emoji="🖇",
models=['command', 'command-nightly', 'command-light', 'command-light-nightly'],
rate_limit="sequential", # enter "sequential" for blocking; an integer N > 0 means N is the max mumber of requests per minute.
settings_schema=COHERE_SETTINGS_SCHEMA)
def CohereCompletion(prompt: str, model: str, temperature: float = 0.75, **kwargs) -> str:
print(f"Calling Cohere model {model} with prompt '{prompt}'...")
response = co.generate(model=model, prompt=prompt, temperature=temperature, **kwargs)
return response.generations[0].text
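
A custom provider script like the one above is registered by POSTing its source code to the new `/app/initCustomProvider` route added to `flask_app.py` below (the settings-screen dropzone does this automatically). The following is a hand-rolled sketch of that request, assuming the Flask server is running on the default `localhost:8000`.

```python
# Sketch: register the Cohere provider script by hand, mimicking the
# settings-screen dropzone. Assumes the server runs on the default localhost:8000.
import requests

with open("chainforge/examples/custom_provider_cohere.py", "r") as f:
    script_code = f.read()

resp = requests.post("http://localhost:8000/app/initCustomProvider",
                     json={"code": script_code}).json()

if "error" in resp:
    print("Registration failed:", resp["error"])
else:
    # Each entry is a provider spec (name, emoji, models, settings schema, ...)
    # with the callable 'func' stripped out server-side.
    print("Registered:", [p["name"] for p in resp["providers"]])
```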
226 changes: 210 additions & 16 deletions chainforge/flask_app.py
@@ -1,11 +1,12 @@
import json, os, asyncio, sys, traceback
import json, os, sys, asyncio, time
from dataclasses import dataclass
from enum import Enum
from typing import List
from statistics import mean, median, stdev
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from chainforge.promptengine.utils import LLM, call_dalai
from chainforge.providers.dalai import call_dalai
from chainforge.providers import ProviderRegistry
import requests as py_requests

""" =================
@@ -16,6 +17,7 @@
# Setup Flask app to serve static version of React front-end
HOSTNAME = "localhost"
PORT = 8000
# SESSION_TOKEN = secrets.token_hex(32)
BUILD_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'react-server', 'build')
STATIC_DIR = os.path.join(BUILD_DIR, 'static')
app = Flask(__name__, static_folder=STATIC_DIR, template_folder=BUILD_DIR)
@@ -27,10 +29,6 @@
CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cache')
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'examples')

LLM_NAME_MAP = {}
for model in LLM:
LLM_NAME_MAP[model.value] = model

class MetricType(Enum):
KeyValue = 0
KeyValue_Numeric = 1
@@ -221,6 +219,22 @@ def run_over_responses(eval_func, responses: list, scope: str) -> list:
}
return responses

async def make_sync_call_async(sync_method, *args, **params):
"""
Makes a blocking synchronous call asynchronous, so that it can be awaited.
NOTE: This is necessary for LLM APIs that do not yet support async (e.g. Google PaLM).
"""
loop = asyncio.get_running_loop()
method = sync_method
if len(params) > 0:
def partial_sync_meth(*a):
return sync_method(*a, **params)
method = partial_sync_meth
return await loop.run_in_executor(None, method, *args)

def exclude_key(d, key_to_exclude):
return {k: v for k, v in d.items() if k != key_to_exclude}
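
For clarity, a small usage sketch of `make_sync_call_async` (the `fake_completion` function below is illustrative): keyword arguments are bound into a local partial, and the call is then run on the default executor so it can be awaited without blocking the event loop.

```python
# Illustrative only: awaiting a blocking provider function via make_sync_call_async.
import asyncio

def fake_completion(prompt, temperature=0.5):
    return f"echo(temp={temperature}): {prompt}"   # stands in for a blocking API call

async def demo():
    # Positional args pass straight through; kwargs are wrapped into a partial first.
    text = await make_sync_call_async(fake_completion, "Hello", temperature=0.9)
    print(text)

asyncio.run(demo())
```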


""" ===================
FLASK SERVER ROUTES
@@ -261,7 +275,7 @@ def executepy():
data = request.get_json()

# Check that all required info is here:
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope'}):
if not set(data.keys()).issuperset({'id', 'code', 'responses', 'scope', 'token'}):
return jsonify({'error': 'POST data is improper format.'})
if not isinstance(data['id'], str) or len(data['id']) == 0:
return jsonify({'error': 'POST data id is improper format (length 0 or not a string).'})
@@ -273,7 +287,7 @@ def executepy():
if (isinstance(responses, str) or not isinstance(responses, list)) or (len(responses) > 0 and any([not isinstance(r, dict) for r in responses])):
return jsonify({'error': 'POST data responses is improper format.'})

# add the path to any scripts to the path:
# Add the path to any scripts to the path:
try:
if 'script_paths' in data:
for script_path in data['script_paths']:
@@ -367,7 +381,7 @@ def fetchOpenAIEval():
POST'd data should be in form:
{
name: <str> # The name of the eval to grab (without .cforge extension)
'name': <str> # The name of the eval to grab (without .cforge extension)
}
"""
# Verify post'd data
@@ -404,9 +418,8 @@ def fetchOpenAIEval():
return jsonify({'error': f"Error creating a new directory 'oaievals' at filepath {oaievals_cache_dir}: {str(e)}"})

# Download the preconverted OpenAI eval from the GitHub main branch for ChainForge
import requests
_url = f"https://raw.githubusercontent.com/ianarawjo/ChainForge/main/chainforge/oaievals/{evalname}.cforge"
response = requests.get(_url)
response = py_requests.get(_url)

# Check if the request was successful (status code 200)
if response.status_code == 200:
@@ -449,9 +462,9 @@ def makeFetchCall():
POST'd data should be in form:
{
url: <str> # the url to fetch from
headers: <dict> # a JSON object of the headers
body: <dict> # the request payload, as JSON
'url': <str> # the url to fetch from
'headers': <dict> # a JSON object of the headers
'body': <dict> # the request payload, as JSON
}
"""
# Verify post'd data
Expand Down Expand Up @@ -486,8 +499,6 @@ async def callDalai():
if not set(data.keys()).issuperset({'prompt', 'model', 'server', 'n', 'temperature'}):
return jsonify({'error': 'POST data is improper format.'})

data['model'] = LLM_NAME_MAP[data['model']]

try:
query, response = await call_dalai(**data)
except Exception as e:
@@ -498,6 +509,189 @@
return ret


@app.route('/app/initCustomProvider', methods=['POST'])
def initCustomProvider():
"""
Initializes custom model provider(s) defined in a Python script,
and returns specs for the front-end UI provider dropdown and the providers' settings window.
POST'd data should be in form:
{
'code': <str> # the Python script to save + execute,
}
"""
# Verify post'd data
data = request.get_json()
if 'code' not in data:
return jsonify({'error': 'POST data is improper format.'})

# Sanity check that the code actually registers a provider
if '@provider' not in data['code']:
return jsonify({'error': """Did not detect a @provider decorator. Custom provider scripts should register at least one @provider.
Do `from chainforge.providers import provider` and decorate your provider completion function with @provider."""})

# Establish the custom provider script cache directory
provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
if not os.path.isdir(provider_scripts_dir):
# Create the directory
try:
os.mkdir(provider_scripts_dir)
except Exception as e:
return jsonify({'error': f"Error creating a new directory 'provider_scripts' at filepath {provider_scripts_dir}: {str(e)}"})

# For keeping track of what script registered providers came from
script_id = str(round(time.time()*1000))
ProviderRegistry.set_curr_script_id(script_id)
ProviderRegistry.watch_next_registered()

# Attempt to run the Python script, in context
try:
exec(data['code'], globals(), None)

# This should have registered one or more new CustomModelProviders.
except Exception as e:
return jsonify({'error': f'Error while executing custom provider code:\n{str(e)}'})

# Check whether anything was updated, and what
new_registries = ProviderRegistry.last_registered()
if len(new_registries) == 0: # Determine whether there's at least one custom provider.
return jsonify({'error': 'Did not detect any custom providers added to the registry. Make sure you are registering your provider with @provider correctly.'})

# At least one provider was registered; detect if it had a past script id and remove those file(s) from the cache
if any((v is not None for v in new_registries.values())):
# For every registered provider that was overwritten, remove the cached script(s) associated with it:
past_script_ids = [v for v in new_registries.values() if v is not None]
for sid in past_script_ids:
past_script_path = os.path.join(provider_scripts_dir, f"{sid}.py")
try:
if os.path.isfile(past_script_path):
os.remove(past_script_path)
except Exception as e:
return jsonify({'error': f"Error removing cached custom provider script at filepath {past_script_path}: {str(e)}"})

# Get the names and specs of all currently registered CustomModelProviders,
# and pass that info to the front-end (excluding the func):
registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]

# Copy the passed Python script to a local file in the package directory
try:
with open(os.path.join(provider_scripts_dir, f"{script_id}.py"), 'w') as f:
f.write(data['code'])
except Exception as e:
return jsonify({'error': f"Error saving custom provider script to directory {provider_scripts_dir}: {str(e)}"})

# Return all loaded providers
return jsonify({'providers': registered_providers})


@app.route('/app/loadCachedCustomProviders', methods=['POST'])
def loadCachedCustomProviders():
"""
Initializes all custom model provider(s) in the local provider_scripts directory.
"""
provider_scripts_dir = os.path.join(CACHE_DIR, "provider_scripts")
if not os.path.isdir(provider_scripts_dir):
# No providers to load.
return jsonify({'providers': []})

try:
for file_name in os.listdir(provider_scripts_dir):
file_path = os.path.join(provider_scripts_dir, file_name)
if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == '.py':
# For keeping track of what script registered providers came from
ProviderRegistry.set_curr_script_id(os.path.splitext(file_name)[0])

# Read the Python script
with open(file_path, 'r') as f:
code = f.read()

# Try to execute it in the global context
try:
exec(code, globals(), None)
except Exception as code_exc:
# Remove the script file associated w the failed execution
os.remove(file_path)
raise code_exc
except Exception as e:
return jsonify({'error': f'Error while loading custom providers from cache: \n{str(e)}'})

# Get the names and specs of all currently registered CustomModelProviders,
# and pass that info to the front-end (excluding the func):
registered_providers = [exclude_key(d, 'func') for d in ProviderRegistry.get_all()]

return jsonify({'providers': registered_providers})


@app.route('/app/removeCustomProvider', methods=['POST'])
def removeCustomProvider():
"""
Removes a custom model provider, by name, from the `ProviderRegistry`,
and deletes any cached script associated with it.
POST'd data should be in form:
{
'name': <str> # a name that refers to the registered custom provider in the `ProviderRegistry`
}
"""
# Verify post'd data
data = request.get_json()
name = data.get('name')
if name is None:
return jsonify({'error': 'POST data is improper format.'})

if not ProviderRegistry.has(name):
return jsonify({'error': f'Could not find a custom provider named "{name}"'})

# Get the script id associated with the provider we're about to remove
script_id = ProviderRegistry.get(name).get('script_id')

# Remove the custom provider from the registry
ProviderRegistry.remove(name)

# Attempt to delete associated script from cache
if script_id:
script_path = os.path.join(CACHE_DIR, "provider_scripts", f"{script_id}.py")
if os.path.isfile(script_path):
os.remove(script_path)

return jsonify({'success': True})


@app.route('/app/callCustomProvider', methods=['POST'])
async def callCustomProvider():
"""
Calls a custom model provider and returns the response.
POST'd data should be in form:
{
'name': <str> # the name of the provider in the `ProviderRegistry`
'params': <dict> # the params (prompt, model, etc) to pass to the provider function.
}
"""
# Verify post'd data
data = request.get_json()
if not set(data.keys()).issuperset({'name', 'params'}):
return jsonify({'error': 'POST data is improper format.'})

# Load the name of the provider
name = data['name']
params = data['params']

# Double-check that the custom provider exists in the registry, and (if passed) a model with that name exists
provider_spec = ProviderRegistry.get(name)
if provider_spec is None:
return jsonify({'error': f'Could not find provider named {name}. Perhaps you need to import a custom provider script?'})

# Call + await the custom provider function, passing in the JSON payload as kwargs
try:
response = await make_sync_call_async(provider_spec.get('func'), **params)
except Exception as e:
return jsonify({'error': f'Error encountered while calling custom provider function: {str(e)}'})

# Return the response
return jsonify({'response': response})
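
A client-side sketch of exercising this route, assuming the Cohere provider from the example file above has already been registered and the server is running on the default `localhost:8000`. The keys inside `params` mirror the provider function's signature; the prompt and model values are illustrative.

```python
# Sketch: invoke a registered custom provider over HTTP. 'name' must match a
# provider in the ProviderRegistry; params are forwarded to its function as kwargs.
import requests

payload = {
    "name": "Cohere",
    "params": {
        "prompt": "Write a haiku about version control.",
        "model": "command",
        "temperature": 0.75,
    },
}
resp = requests.post("http://localhost:8000/app/callCustomProvider", json=payload).json()

if "error" in resp:
    print("Call failed:", resp["error"])
else:
    print(resp["response"])   # text returned by the provider function
```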


def run_server(host="", port=8000, cmd_args=None):
global HOSTNAME, PORT
HOSTNAME = host
1 change: 0 additions & 1 deletion chainforge/promptengine/__init__.py

This file was deleted.

