-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add gigachat * Edited code style * fixed comments * Create generate goals api * Add tests
- Loading branch information
1 parent
bc61e2e
commit 141a8db
Showing
18 changed files
with
332 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"max_tokens": 256, | ||
"temperature": 0.4, | ||
"top_p": 1.0, | ||
"frequency_penalty": 0, | ||
"presence_penalty": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"prompt": "Сформулируй очень краткое описание целей ассистента, которые определены в данном запросе.\n\nПример:\nЗапрос: '''ЗАДАНИЕ: Твоё имя - Ассистент по Жизненному Коучингу. Тебя создала компания Rhoades & Co. Помоги человеку поставить цель в жизни и определить, как достичь её шаг за шагом. Не обсуждай другие темы. Отвечай с сочувствием. Задавай открытые вопросы, чтобы помочь человеку лучше понять себя.\nИНСТРУКЦИЯ: Человек входит в разговор. Представься кратко. Помоги ему поставить цель и достичь её. Ты можешь узнать о его жизненных приоритетах и предпочтительных областях концентрации и предложить полезные идеи. Ты должен задать ОДИН вопрос или НИ ОДНОГО вопроса, НЕ два или три. Остановись после того, как задашь первый вопрос.'''\nРезультат: Помогает пользователю поставить и достичь жизненных целей." | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
name: gigachat | ||
display_name: GigaChat | ||
component_type: Generative | ||
model_type: NN-based | ||
is_customizable: false | ||
author: publisher@deeppavlov.ai | ||
description: GigaChat is a service that can interact with the user in a dialogue format, write code, | ||
create texts and pictures at the user’s request. At the same time, | ||
GigaChat considers controversial issues or provocations. | ||
ram_usage: 100M | ||
gpu_usage: null | ||
group: services | ||
connector: | ||
protocol: http | ||
timeout: 120.0 | ||
url: http://gigachat-api:8187/respond | ||
dialog_formatter: null | ||
response_formatter: null | ||
previous_services: null | ||
required_previous_services: null | ||
state_manager_method: null | ||
tags: null | ||
endpoint: respond | ||
service: services/gigachat_api_lm/service_configs/gigachat-api | ||
date_created: '2023-12-25T09:45:32' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
FROM python:3.10 | ||
|
||
WORKDIR /src | ||
|
||
COPY ./services/gigachat_api_lm/requirements.txt /src/requirements.txt | ||
RUN pip install -r /src/requirements.txt | ||
|
||
ARG PRETRAINED_MODEL_NAME_OR_PATH | ||
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH} | ||
|
||
COPY services/gigachat_api_lm /src | ||
COPY common /src/common | ||
|
||
CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
flask==1.1.1 | ||
itsdangerous==2.0.1 | ||
gunicorn==19.9.0 | ||
requests==2.22.0 | ||
sentry-sdk[flask]==0.14.1 | ||
healthcheck==1.3.3 | ||
jinja2<=3.0.3 | ||
Werkzeug<=2.0.3 | ||
gigachat==0.1.10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import json | ||
import logging | ||
import os | ||
import time | ||
import sentry_sdk | ||
|
||
from gigachat import GigaChat | ||
from gigachat.models import Chat | ||
from flask import Flask, request, jsonify | ||
from sentry_sdk.integrations.flask import FlaskIntegration | ||
from common.prompts import META_GOALS_PROMPT_RU | ||
|
||
|
||
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) | ||
|
||
|
||
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH") | ||
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}") | ||
GIGACHAT_ROLES = ["assistant", "user"] | ||
|
||
app = Flask(__name__) | ||
logging.getLogger("werkzeug").setLevel("WARNING") | ||
DEFAULT_CONFIGS = { | ||
"GigaChat:1.3.23.1": json.load(open("common/generative_configs/gigachat.json", "r")), | ||
} | ||
|
||
|
||
def generate_responses( | ||
context, | ||
gigachat_api_key, | ||
gigachat_org, | ||
prompt, | ||
generation_params, | ||
continue_last_uttr=False, | ||
): | ||
assert gigachat_api_key, logger.error("Error: GigaChat API key is not specified in env") | ||
giga = GigaChat(credentials=gigachat_api_key, verify_ssl_certs=False) | ||
|
||
s = len(context) % 2 | ||
messages = [ | ||
{"role": "system", "content": prompt}, | ||
] | ||
messages += [ | ||
{ | ||
"role": f"{GIGACHAT_ROLES[(s + uttr_id) % 2]}", | ||
"content": uttr, | ||
} | ||
for uttr_id, uttr in enumerate(context) | ||
] | ||
logger.info(f"context inside generate_responses seen as: {messages}") | ||
|
||
payload = Chat(messages=messages, scope=gigachat_org, **generation_params) | ||
response = giga.chat(payload) | ||
|
||
outputs = [resp.message.content.strip() for resp in response.choices] | ||
|
||
return outputs | ||
|
||
|
||
@app.route("/ping", methods=["POST"]) | ||
def ping(): | ||
return "pong" | ||
|
||
|
||
@app.route("/respond", methods=["POST"]) | ||
def respond(): | ||
st_time = time.time() | ||
contexts = request.json.get("dialog_contexts", []) | ||
prompts = request.json.get("prompts", []) | ||
configs = request.json.get("configs", None) | ||
configs = [None] * len(prompts) if configs is None else configs | ||
configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs] | ||
|
||
if len(contexts) > 0 and len(prompts) == 0: | ||
prompts = [""] * len(contexts) | ||
|
||
gigachat_api_keys = request.json.get("gigachat_credentials", []) | ||
gigachat_orgs = request.json.get("gigachat_scopes", None) | ||
gigachat_orgs = [None] * len(contexts) if gigachat_orgs is None else gigachat_orgs | ||
|
||
try: | ||
responses = [] | ||
for context, gigachat_api_key, gigachat_org, prompt, config in zip( | ||
contexts, gigachat_api_keys, gigachat_orgs, prompts, configs | ||
): | ||
curr_responses = [] | ||
outputs = generate_responses(context, gigachat_api_key, gigachat_org, prompt, config) | ||
for response in outputs: | ||
if len(response) >= 2: | ||
curr_responses += [response] | ||
else: | ||
curr_responses += [""] | ||
responses += [curr_responses] | ||
|
||
except Exception as exc: | ||
logger.exception(exc) | ||
sentry_sdk.capture_exception(exc) | ||
responses = [[""]] * len(contexts) | ||
|
||
logger.info(f"gigachat-api result: {responses}") | ||
total_time = time.time() - st_time | ||
logger.info(f"gigachat-api exec time: {total_time:.3f}s") | ||
return jsonify(responses) | ||
|
||
|
||
@app.route("/generate_goals", methods=["POST"]) | ||
def generate_goals(): | ||
st_time = time.time() | ||
|
||
prompts = request.json.get("prompts", None) | ||
prompts = [] if prompts is None else prompts | ||
configs = request.json.get("configs", None) | ||
configs = [None] * len(prompts) if configs is None else configs | ||
configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs] | ||
|
||
gigachat_api_keys = request.json.get("gigachat_credentials", []) | ||
gigachat_orgs = request.json.get("gigachat_scopes", None) | ||
gigachat_orgs = [None] * len(prompts) if gigachat_orgs is None else gigachat_orgs | ||
|
||
try: | ||
responses = [] | ||
for gigachat_api_key, gigachat_org, prompt, config in zip(gigachat_api_keys, gigachat_orgs, prompts, configs): | ||
context = ["Привет", META_GOALS_PROMPT_RU + f"\nПромпт: '''{prompt}'''\nРезультат:"] | ||
goals_for_prompt = generate_responses(context, gigachat_api_key, gigachat_org, "", config)[0] | ||
logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`") | ||
responses += [goals_for_prompt] | ||
|
||
except Exception as exc: | ||
logger.info(exc) | ||
sentry_sdk.capture_exception(exc) | ||
responses = [""] * len(prompts) | ||
|
||
total_time = time.time() - st_time | ||
logger.info(f"gigachat-api generate_goals exec time: {total_time:.3f}s") | ||
return jsonify(responses) |
4 changes: 4 additions & 0 deletions
4
services/gigachat_api_lm/service_configs/gigachat-api/environment.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
SERVICE_PORT: 8187 | ||
SERVICE_NAME: gigachat_api | ||
PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 | ||
FLASK_APP: server |
29 changes: 29 additions & 0 deletions
29
services/gigachat_api_lm/service_configs/gigachat-api/service.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
name: gigachat-api | ||
endpoints: | ||
- respond | ||
- generate_goals | ||
compose: | ||
env_file: | ||
- .env | ||
build: | ||
args: | ||
SERVICE_PORT: 8187 | ||
SERVICE_NAME: gigachat_api | ||
PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 | ||
context: . | ||
dockerfile: ./services/gigachat_api_lm/Dockerfile | ||
command: flask run -h 0.0.0.0 -p 8187 | ||
environment: | ||
- FLASK_APP=server | ||
deploy: | ||
resources: | ||
limits: | ||
memory: 100M | ||
reservations: | ||
memory: 100M | ||
volumes: | ||
- ./services/gigachat_api_lm:/src | ||
- ./common:/src/common | ||
ports: | ||
- 8187:8187 | ||
proxy: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import requests | ||
from os import getenv | ||
|
||
|
||
# ATTENTION!!! This test is only working if you assign `GIGACHAT_CREDENTIALS` env variable | ||
GIGACHAT_CREDENTIALS = getenv("GIGACHAT_CREDENTIALS", None) | ||
GIGACHAT_SCOPE = getenv("GIGACHAT_SCOPE", None) | ||
assert GIGACHAT_CREDENTIALS, print("No GigaChat credentials is given in env vars") | ||
DEFAULT_CONFIG = {"max_tokens": 64, "temperature": 0.4, "top_p": 1.0, "frequency_penalty": 0, "presence_penalty": 0} | ||
SERVICE_PORT = int(getenv("SERVICE_PORT")) | ||
|
||
|
||
def test_respond(): | ||
url = f"http://0.0.0.0:{SERVICE_PORT}/respond" | ||
contexts = [ | ||
[ | ||
"Привет! Я Маркус. Как ты сегодня?", | ||
"Привет, Маркус! Я в порядке. Как у тебя?", | ||
"У меня все отлично. Какие у тебя планы на сегодня?", | ||
], | ||
["Привет, Маркус! Я в порядке. Как у тебя?", "У меня все отлично. Какие у тебя планы на сегодня?"], | ||
] | ||
prompts = [ | ||
"Отвечай как дружелюбный чатбот.", | ||
"Отвечай как дружелюбный чатбот.", | ||
] | ||
result = requests.post( | ||
url, | ||
json={ | ||
"dialog_contexts": contexts, | ||
"prompts": prompts, | ||
"configs": [DEFAULT_CONFIG] * len(contexts), | ||
"gigachat_credentials": [GIGACHAT_CREDENTIALS] * len(contexts), | ||
"gigachat_scopes": [GIGACHAT_SCOPE] * len(contexts), | ||
}, | ||
).json() | ||
print(result) | ||
|
||
assert len(result) and [all(len(sample[0]) > 0 for sample in result)], f"Got\n{result}\n, something is wrong" | ||
print("Success!") | ||
|
||
|
||
if __name__ == "__main__": | ||
test_respond() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
python test.py |
Oops, something went wrong.