Skip to content

Commit

Permalink
Added support for synchrous Reka using the OpenAI SDK format.
Browse files Browse the repository at this point in the history
  • Loading branch information
TootyFrooties committed Apr 23, 2024
1 parent 339c22e commit b1f4f1e
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 0 deletions.
46 changes: 46 additions & 0 deletions docs/examples/reka.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Structured Outputs using Reka
You can now also use Reka models for inference by using from_reka.

The examples are using reka core. For more detailed Reka documentation visit [Reka docs](https://docs.reka.ai/index.html)

## Reka API
To use Reka you need to obtain a Reka API key.
Goto [Reka AI](https://reka.ai/) click on API Access and login. Select API Keys from the left menu and then select
Create API key to create a new key. You need to fund your account before use.

Currently Reka does not support async

## Use example
Some pip packages need to be installed to use the example:
```
pip install instructor reka-api pydantic
```

```
An example:
```python
import os
from pydantic import BaseModel, Field
from typing import List
import reka
from instructor import from_reka, Mode
class UserDetails(BaseModel):
name: str
age: int
# enables `response_model` in chat call
client = from_reka(api_key=os.environ.get("REKA_API_KEY"))
resp = client.chat.completions.create(
response_model=UserDetails,
messages=[{"role": "user", "content": "Extract John Doe is 30 years old."}],
temperature=0,
)
print(user_info.name)
#> John Doe
print(user_info.age)
#> 30
22 changes: 22 additions & 0 deletions examples/reka/reka.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import instructor
from pydantic import BaseModel
import os


client = instructor.from_reka(api_key = os.environ.get("REKA_API_KEY"))
class UserInfo(BaseModel):
name: str
age: int


user_info = client.chat.completions.create(
model="reka-core",
temperature='0.2',
response_model=UserInfo,
messages=[{"role": "user", "content": "Extract John Doe is 30 years old."}],
)

print(user_info.name)
#> John Doe
print(user_info.age)
#> 30
5 changes: 5 additions & 0 deletions instructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,8 @@
from .client_cohere import from_cohere

__all__ += ["from_cohere"]

if importlib.util.find_spec("reka") is not None:
from .client_reka import from_reka

__all__ += ["from_reka"]
2 changes: 2 additions & 0 deletions instructor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def create_partial(
**kwargs,
) -> Generator[T, None, None]:
assert self.provider != Provider.ANTHROPIC, "Anthropic doesn't support partial"
assert self.provider != Provider.REKA, "Reka doesn't support partial"

kwargs["stream"] = True

Expand All @@ -111,6 +112,7 @@ def create_iterable(
**kwargs,
) -> Iterable[T]:
assert self.provider != Provider.ANTHROPIC, "Anthropic doesn't support iterable"
assert self.provider != Provider.REKA, "Reka doesn't support iterable"

kwargs["stream"] = True
kwargs = self.handle_kwargs(kwargs)
Expand Down
123 changes: 123 additions & 0 deletions instructor/client_reka.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Future imports to ensure compatibility with Python 3.9
from __future__ import annotations

import reka
import instructor
from typing import overload, Optional
import inspect
from datetime import datetime
from openai.types.chat import ChatCompletion
import logging


class RekaClient:
def __init__(self, api_key):
reka.API_KEY = api_key


@overload
def from_reka(

mode: instructor.Mode = instructor.Mode.MD_JSON,
**kwargs,
) -> instructor.Instructor: ...


def from_reka(api_key: Optional[str] =None,
mode: instructor.Mode = instructor.Mode.MD_JSON, model='reka-flash',
**kwargs,
) -> instructor.Instructor | instructor.AsyncInstructor:
client = RekaClient(api_key)
assert mode in {
instructor.Mode.MD_JSON,
}, "Mode be one of {instructor.Mode.MD_JSON}"

assert isinstance(
client, (RekaClient, reka.chat)
), "Client must be an instance of reka.chat or reka.completion"
client.default_model = model
return instructor.Instructor(
client=client,
create=instructor.patch(create=lambda **kw: reka_chat_wrapper(client.default_model, **kw), mode=mode),
provider=instructor.Provider.REKA,
mode=mode,
**kwargs,
)


def reka_chat_wrapper(default_model, **kwargs):
model = kwargs.pop('model',default_model)
kwargs['model_name']=model
kwargs = reformat_openai_request_as_reka(kwargs)

try:
response = reka.chat(**kwargs)
completion = reformat_reka_resp_as_chat_completion(response, kwargs.get('model_name', 'reka-flash'))
return completion
except TypeError as e:
logging.error(f"TypeError encountered: {e}")
raise


def reformat_openai_request_as_reka(kwargs):
messages = kwargs.pop('messages', [])
conversation_history = kwargs.pop('conversation_history', [])

# Process messages
new_messages = []
for msg in messages:
msg_type = 'model' if msg.get('role') == "assistant" else 'human'
text = " ".join(str(part) for part in msg['content']) if isinstance(msg['content'], list) else msg['content']
new_messages.append({'type': msg_type, 'text': text})

# Update conversation history
if new_messages:
conversation_history.append(new_messages[0])
for current_msg in new_messages[1:]:
last_msg = conversation_history[-1]
if last_msg['type'] == current_msg['type']:
last_msg['text'] += "\n" + current_msg['text']
else:
conversation_history.append(current_msg)
kwargs['conversation_history'] = conversation_history

# Adjust OpenAI-specific parameters for Reka API
param_mapping = {'max_tokens': 'request_output_length', 'seed': 'random_seed', 'stop': 'stop_words', 'top_p': 'runtime_top_p'}
for openai_arg, reka_arg in param_mapping.items():
if openai_arg in kwargs:
kwargs[reka_arg] = kwargs.pop(openai_arg)

# Validate against Reka's API signature
chat_params = inspect.signature(reka.chat).parameters
allowed_args = set(chat_params.keys())
kwargs = {key: value for key, value in kwargs.items() if key in allowed_args}
return kwargs
def reformat_reka_resp_as_chat_completion(response, model_name):
finish_reason = response.get("finish_reason", "unknown")
content = response.get("text", "")
generated_tokens = response.get("metadata", {}).get("generated_tokens", 0)
input_tokens = response.get("metadata", {}).get("input_tokens", 0)
total_tokens = input_tokens + generated_tokens

chat_completion_data = {
"choices": [{
"finish_reason": finish_reason,
"index": 0,
"message": {
"content": content,
"role": "assistant"
},
"logprobs": None
}],
"created": int(datetime.now().timestamp()),
"id": f"reka-{datetime.now().timestamp()}",
"model": model_name or "reka-flash",
"object": "chat.completion",
"usage": {
"completion_tokens": generated_tokens,
"prompt_tokens": input_tokens,
"total_tokens": total_tokens
}
}
completion_instance = ChatCompletion(**chat_completion_data)
return completion_instance
3 changes: 3 additions & 0 deletions instructor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Provider(Enum):
GROQ = "groq"
MISTRAL = "mistral"
COHERE = "cohere"
REKA = "reka"
UNKNOWN = "unknown"


Expand All @@ -43,6 +44,8 @@ def get_provider(base_url: str) -> Provider:
return Provider.MISTRAL
elif "cohere" in str(base_url):
return Provider.COHERE
elif "reka" in str(base_url):
return Provider.REKA
return Provider.UNKNOWN


Expand Down
12 changes: 12 additions & 0 deletions tests/test_new_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,15 @@ def test_client_mistral_response():
)
assert user.name == "Jason"
assert user.age == 10


def test_client_reka_response():
client = instructor.from_reka(api_key=os.getenv("REKA_API_KEY"))

user = client.chat.completions.create(
response_model=User,
messages=[{"role": "user", "content": "Jason is 10"}],
temperature=0,
)
assert user.name == "Jason"
assert user.age == 10

0 comments on commit b1f4f1e

Please sign in to comment.