Skip to content

Commit

Permalink
Add Azure support for transcription (#7)
Browse files Browse the repository at this point in the history
* Azure support

* Refactor settings module to include new audio and chat model names

* Fix syntax

* add setting

* fix settings

* add config validations

* bump version and add release notes

---------

Co-authored-by: Aakash Singh <mail@singhaakash.dev>
  • Loading branch information
Ashesh3 and sainak committed May 17, 2024
1 parent 2d833a4 commit 5d1ef9b
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 17 deletions.
6 changes: 6 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
History
=======

0.2.0 (2024-05-18)
------------------

* Add support for Azure OpenAI API.


0.1.1 (2024-04-16)
------------------

Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ scribe_plug = Plug(
version="@master",
configs={
"TRANSCRIBE_SERVICE_PROVIDER_API_KEY": "secret",
"API_PROVIDER": "openai", # or "azure"
"AZURE_API_VERSION": "", # required if API_PROVIDER is "azure"
"AZURE_ENDPOINT": "", # required if API_PROVIDER is "azure"
"AUDIO_MODEL_NAME": "", # model name for OpenAI or custom deployment name for Azure
"CHAT_MODEL_NAME": "", # model name for OpenAI or custom deployment name for Azure
},
)
plugs = [scribe_plug]
Expand All @@ -40,6 +45,11 @@ plugs = [scribe_plug]
The following configuration variables are available for Care Scribe:

- `TRANSCRIBE_SERVICE_PROVIDER_API_KEY`: API key for the transcription service provider (OpenAI Whisper or Google Speech-to-Text)
- `API_PROVIDER`: The API provider to use for transcription. Can be either "openai" or "azure".
- `AZURE_API_VERSION`: The version of the Azure API to use. This is required if `API_PROVIDER` is set to "azure".
- `AZURE_ENDPOINT`: The endpoint for the Azure API. This is required if `API_PROVIDER` is set to "azure".
- `AUDIO_MODEL_NAME`: The model name for OpenAI or the custom deployment name for Azure.
- `CHAT_MODEL_NAME`: The model name for OpenAI or the custom deployment name for Azure.

The plugin will try to find the API key from the config first and then from the environment variable.

Expand Down
2 changes: 1 addition & 1 deletion care_scribe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__author__ = """Open Healthcare Network"""
__email__ = "info@ohc.network"
__version__ = "0.1.1"
__version__ = "0.2.0"
31 changes: 27 additions & 4 deletions care_scribe/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def validate(self) -> None:
f'Please set the "{setting}" in the environment or the {PLUGIN_NAME} plugin config.'
)

if getattr(self, "API_PROVIDER") not in ("openai", "azure"):
raise ImproperlyConfigured(
'Invalid value for "API_PROVIDER". '
'Please set the "API_PROVIDER" to "openai" or "azure".'
)

if getattr(self, "API_PROVIDER") == "azure":
for setting in ("AZURE_API_VERSION", "AZURE_ENDPOINT"):
if not getattr(self, setting):
raise ImproperlyConfigured(
f'The "{setting}" setting is required when using Azure API. '
f'Please set the "{setting}" in the environment or the {PLUGIN_NAME} plugin config.'
)

def reload(self) -> None:
"""
Deletes the cached attributes so they will be recomputed next time they are accessed.
Expand All @@ -97,18 +111,27 @@ def reload(self) -> None:
delattr(self, "_user_settings")


# Name of the setting that holds the transcription provider's API key.
TSP_API_KEY = "TRANSCRIBE_SERVICE_PROVIDER_API_KEY"

# Setting names handed to PluginSettings as ``required_settings``.
# The Azure-specific settings are intentionally absent: they only have to be
# set when API_PROVIDER is "azure".
REQUIRED_SETTINGS = {
    TSP_API_KEY,
    "AUDIO_MODEL_NAME",
    "CHAT_MODEL_NAME",
    "API_PROVIDER",
}

# Fallback values handed to PluginSettings as ``defaults``.
# Defined once; the model names default to OpenAI model identifiers and the
# provider defaults to "openai".
DEFAULTS = {
    TSP_API_KEY: "",
    "AUDIO_MODEL_NAME": "whisper-1",
    "CHAT_MODEL_NAME": "gpt-4-turbo",
    "API_PROVIDER": "openai",
    "AZURE_API_VERSION": "",
    "AZURE_ENDPOINT": "",
}

plugin_settings = PluginSettings(
PLUGIN_NAME, defaults=DEFAULTS, required_settings=REQUIRED_SETTINGS
)


@receiver(setting_changed)
def reload_plugin_settings(*args, **kwargs) -> None:
setting = kwargs["setting"]
Expand Down
29 changes: 19 additions & 10 deletions care_scribe/tasks/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,33 @@

import requests
from celery import shared_task
from openai import OpenAI
from openai import OpenAI, AzureOpenAI

from care_scribe.models.scribe import Scribe
from care_scribe.models.scribe_file import ScribeFile
from care_scribe.settings import plugin_settings

logger = logging.getLogger(__name__)

OpenAIClient = None
AiClient = None


def get_openai_client():
    """Return the shared AI client, creating it on first use.

    The client is cached in the module-level ``AiClient`` global, so it is
    built at most once per process and reused by later calls.  The
    ``API_PROVIDER`` plugin setting selects which client class is built:

    * ``"azure"``  -> ``AzureOpenAI``, configured with the API key,
      ``AZURE_API_VERSION`` and ``AZURE_ENDPOINT`` settings.
    * ``"openai"`` -> ``OpenAI``, configured with the API key only.

    Raises:
        Exception: if ``API_PROVIDER`` is neither ``"azure"`` nor ``"openai"``.
    """
    global AiClient

    # Fast path: a client was already built by an earlier call.
    if AiClient is not None:
        return AiClient

    if plugin_settings.API_PROVIDER == 'azure':
        AiClient = AzureOpenAI(
            api_key=plugin_settings.TRANSCRIBE_SERVICE_PROVIDER_API_KEY,
            api_version=plugin_settings.AZURE_API_VERSION,
            azure_endpoint=plugin_settings.AZURE_ENDPOINT,
        )
    elif plugin_settings.API_PROVIDER == 'openai':
        AiClient = OpenAI(
            api_key=plugin_settings.TRANSCRIBE_SERVICE_PROVIDER_API_KEY,
        )
    else:
        # Nothing is cached on failure, so a later call re-reads the setting.
        raise Exception('Invalid API_PROVIDER in plugin_settings')

    return AiClient


prompt_1 = """
Expand Down Expand Up @@ -86,7 +95,7 @@ def process_ai_form_fill(external_id):
buffer.name = "file.mp3"

transcription = get_openai_client().audio.transcriptions.create(
model="whisper-1", file=buffer
model=plugin_settings.AUDIO_MODEL_NAME, file=buffer # This can be the model name (OPENAI) or the custom deployment name (AZURE)
)
transcript += transcription.text
logger.info(f"Transcript: {transcript}")
Expand All @@ -103,7 +112,7 @@ def process_ai_form_fill(external_id):

# Process the transcript with Ayushma
ai_response = get_openai_client().chat.completions.create(
model="gpt-4-turbo-preview",
model=plugin_settings.CHAT_MODEL_NAME, # This can be the model name (OPENAI) or the custom deployment name (AZURE)
response_format={"type": "json_object"},
max_tokens=4096,
temperature=0,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.1
current_version = 0.2.0
commit = True
tag = True

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/coronasafe/care_scribe",
version="0.1.1",
version="0.2.0",
zip_safe=False,
)

0 comments on commit 5d1ef9b

Please sign in to comment.