From b57c912ce733597de8d62990f0fdba808c7f32c0 Mon Sep 17 00:00:00 2001 From: rafilgalimzanov Date: Mon, 25 Dec 2023 02:08:07 +0300 Subject: [PATCH 1/5] Add gigachat --- .../universal_prompted_assistant/dev.yml | 6 ++ .../docker-compose.override.yml | 21 +++- .../universal_ru_prompted_assistant/dev.yml | 6 ++ .../docker-compose.override.yml | 19 ++++ common/generative_configs/gigachat.json | 7 ++ components.tsv | 2 +- components/0bBDINLSJDnjn1pzf8sdA.yml | 25 +++++ services/gigachat_api_lm/Dockerfile | 14 +++ services/gigachat_api_lm/README.md | 0 services/gigachat_api_lm/requirements.txt | 9 ++ services/gigachat_api_lm/server.py | 99 +++++++++++++++++++ .../gigachat-api/environment.yml | 4 + .../service_configs/gigachat-api/service.yml | 30 ++++++ .../scenario/response.py | 1 + 14 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 common/generative_configs/gigachat.json create mode 100644 components/0bBDINLSJDnjn1pzf8sdA.yml create mode 100644 services/gigachat_api_lm/Dockerfile create mode 100644 services/gigachat_api_lm/README.md create mode 100644 services/gigachat_api_lm/requirements.txt create mode 100644 services/gigachat_api_lm/server.py create mode 100644 services/gigachat_api_lm/service_configs/gigachat-api/environment.yml create mode 100644 services/gigachat_api_lm/service_configs/gigachat-api/service.yml diff --git a/assistant_dists/universal_prompted_assistant/dev.yml b/assistant_dists/universal_prompted_assistant/dev.yml index ffd35ff272..37fb53b628 100644 --- a/assistant_dists/universal_prompted_assistant/dev.yml +++ b/assistant_dists/universal_prompted_assistant/dev.yml @@ -71,6 +71,12 @@ services: - "./common:/src/common" ports: - 8180:8180 + gigachat-api: + volumes: + - "./services/gigachat_api_lm:/src" + - "./common:/src/common" + ports: + - 8187:8187 anthropic-api-claude-v1: volumes: - "./services/anthropic_api_lm:/src" diff --git a/assistant_dists/universal_prompted_assistant/docker-compose.override.yml b/assistant_dists/universal_prompted_assistant/docker-compose.override.yml index 4040774ec8..ed0a649680 100644 --- a/assistant_dists/universal_prompted_assistant/docker-compose.override.yml +++ b/assistant_dists/universal_prompted_assistant/docker-compose.override.yml @@ -6,7 +6,7 @@ services: sentence-ranker:8128, transformers-lm-gptjt:8161, openai-api-chatgpt:8145, openai-api-davinci3:8131, openai-api-gpt4:8159, openai-api-gpt4-32k:8160, openai-api-chatgpt-16k:8167, - openai-api-gpt4-turbo:8180, dff-universal-prompted-skill:8147" + openai-api-gpt4-turbo:8180, gigachat-api:8187, dff-universal-prompted-skill:8147" WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-1000} sentseg: @@ -219,6 +219,25 @@ services: reservations: memory: 100M + gigachat-api: + env_file: [ .env ] + build: + args: + SERVICE_PORT: 8187 + SERVICE_NAME: gigachat_api + PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 + context: . + dockerfile: ./services/gigachat_api_lm/Dockerfile + command: flask run -h 0.0.0.0 -p 8187 + environment: + - FLASK_APP=server + deploy: + resources: + limits: + memory: 500M + reservations: + memory: 100M + anthropic-api-claude-v1: env_file: [ .env ] build: diff --git a/assistant_dists/universal_ru_prompted_assistant/dev.yml b/assistant_dists/universal_ru_prompted_assistant/dev.yml index 9f451d8eac..ffd66d8ea1 100644 --- a/assistant_dists/universal_ru_prompted_assistant/dev.yml +++ b/assistant_dists/universal_ru_prompted_assistant/dev.yml @@ -67,6 +67,12 @@ services: - "./common:/src/common" ports: - 8180:8180 + gigachat-api: + volumes: + - "./services/gigachat_api_lm:/src" + - "./common:/src/common" + ports: + - 8187:8187 dff-universal-ru-prompted-skill: volumes: - "./skills/dff_universal_prompted_skill:/src" diff --git a/assistant_dists/universal_ru_prompted_assistant/docker-compose.override.yml b/assistant_dists/universal_ru_prompted_assistant/docker-compose.override.yml index 8b11265496..6eb49e5503 100644 --- a/assistant_dists/universal_ru_prompted_assistant/docker-compose.override.yml +++ b/assistant_dists/universal_ru_prompted_assistant/docker-compose.override.yml @@ -213,6 +213,25 @@ services: reservations: memory: 100M + gigachat-api: + env_file: [ .env ] + build: + args: + SERVICE_PORT: 8187 + SERVICE_NAME: gigachat_api + PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 + context: . + dockerfile: ./services/gigachat_api_lm/Dockerfile + command: flask run -h 0.0.0.0 -p 8187 + environment: + - FLASK_APP=server + deploy: + resources: + limits: + memory: 500M + reservations: + memory: 100M + dff-universal-ru-prompted-skill: env_file: [ .env_ru ] build: diff --git a/common/generative_configs/gigachat.json b/common/generative_configs/gigachat.json new file mode 100644 index 0000000000..7e23841d5f --- /dev/null +++ b/common/generative_configs/gigachat.json @@ -0,0 +1,7 @@ +{ + "max_tokens": 256, + "temperature": 0.4, + "top_p": 1.0, + "frequency_penalty": 0, + "presence_penalty": 0 +} \ No newline at end of file diff --git a/components.tsv b/components.tsv index 8e0620f53d..4b27d3b41a 100644 --- a/components.tsv +++ b/components.tsv @@ -190,7 +190,7 @@ 8184 external-fake-server 8185 transformers-mistral-7b-128k 8186 dff-document-qa-transformers-llm-skill -8187 +8187 gigachat-api 8188 8189 8190 \ No newline at end of file diff --git a/components/0bBDINLSJDnjn1pzf8sdA.yml b/components/0bBDINLSJDnjn1pzf8sdA.yml new file mode 100644 index 0000000000..b5808bc50f --- /dev/null +++ b/components/0bBDINLSJDnjn1pzf8sdA.yml @@ -0,0 +1,25 @@ +name: gigachat +display_name: GigaChat +component_type: Generative +model_type: NN-based +is_customizable: false +author: publisher@deeppavlov.ai +description: GigaChat is a service that can interact with the user in a dialogue format, write code, + create texts and pictures at the user’s request. At the same time, + GigaChat considers controversial issues or provocations. +ram_usage: 100M +gpu_usage: null +group: services +connector: + protocol: http + timeout: 120.0 + url: http://gigachat-api:8187/respond +dialog_formatter: null +response_formatter: null +previous_services: null +required_previous_services: null +state_manager_method: null +tags: null +endpoint: respond +service: services/gigachat_api_lm/service_configs/gigachat-api +date_created: '2023-12-25T09:45:32' diff --git a/services/gigachat_api_lm/Dockerfile b/services/gigachat_api_lm/Dockerfile new file mode 100644 index 0000000000..387f3b4bae --- /dev/null +++ b/services/gigachat_api_lm/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.10 + +WORKDIR /src + +COPY ./services/gigachat_api_lm/requirements.txt /src/requirements.txt +RUN pip install -r /src/requirements.txt + +ARG PRETRAINED_MODEL_NAME_OR_PATH +ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH} + +COPY services/gigachat_api_lm /src +COPY common /src/common + +CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300 diff --git a/services/gigachat_api_lm/README.md b/services/gigachat_api_lm/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/services/gigachat_api_lm/requirements.txt b/services/gigachat_api_lm/requirements.txt new file mode 100644 index 0000000000..68e442d433 --- /dev/null +++ b/services/gigachat_api_lm/requirements.txt @@ -0,0 +1,9 @@ +flask==1.1.1 +itsdangerous==2.0.1 +gunicorn==19.9.0 +requests==2.22.0 +sentry-sdk[flask]==0.14.1 +healthcheck==1.3.3 +jinja2<=3.0.3 +Werkzeug<=2.0.3 +gigachat==0.1.10 diff --git a/services/gigachat_api_lm/server.py b/services/gigachat_api_lm/server.py new file mode 100644 index 0000000000..291cb50242 --- /dev/null +++ b/services/gigachat_api_lm/server.py @@ -0,0 +1,99 @@ +import json +import logging +import os +import time +import sentry_sdk + +from gigachat import GigaChat +from gigachat.models import Chat +from flask import Flask, request, jsonify +from sentry_sdk.integrations.flask import FlaskIntegration + + +sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) + + +logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH") +logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}") +GIGACHAT_ROLES = ["assistant", "user"] + +app = Flask(__name__) +logging.getLogger("werkzeug").setLevel("WARNING") +DEFAULT_CONFIGS = { + "GigaChat:1.3.23.1": json.load(open("common/generative_configs/gigachat.json", "r")), +} + + +def generate_responses(context, gigachat_api_key, gigachat_org, prompt, generation_params, continue_last_uttr=False): + + assert gigachat_api_key, logger.error("Error: GigaChat API key is not specified in env") + giga = GigaChat(credentials=gigachat_api_key, verify_ssl_certs=False) + + s = len(context) % 2 + messages = [ + {"role": "system", "content": prompt}, + ] + messages += [ + { + "role": f"{GIGACHAT_ROLES[(s + uttr_id) % 2]}", + "content": uttr, + } + for uttr_id, uttr in enumerate(context) + ] + logger.info(f"context inside generate_responses seen as: {messages}") + + payload = Chat(messages=messages, scope=gigachat_org, **generation_params) + response = giga.chat(payload) + + outputs = [resp.message.content.strip() for resp in response.choices] + + return outputs + + +@app.route("/ping", methods=["POST"]) +def ping(): + return "pong" + + +@app.route("/respond", methods=["POST"]) +def respond(): + st_time = time.time() + contexts = request.json.get("dialog_contexts", []) + prompts = request.json.get("prompts", []) + configs = request.json.get("configs", None) + configs = [None] * len(prompts) if configs is None else configs + configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs] + + if len(contexts) > 0 and len(prompts) == 0: + prompts = [""] * len(contexts) + + gigachat_api_keys = request.json.get("gigachat_credentials", []) + gigachat_orgs = request.json.get("gigachat_scopes", None) + gigachat_orgs = [None] * len(contexts) if gigachat_orgs is None else gigachat_orgs + + try: + responses = [] + for context, gigachat_api_key, gigachat_org, prompt, config in zip( + contexts, gigachat_api_keys, gigachat_orgs, prompts, configs + ): + curr_responses = [] + outputs = generate_responses(context, gigachat_api_key, gigachat_org, prompt, config) + for response in outputs: + if len(response) >= 2: + curr_responses += [response] + else: + curr_responses += [""] + responses += [curr_responses] + + except Exception as exc: + logger.exception(exc) + sentry_sdk.capture_exception(exc) + responses = [[""]] * len(contexts) + + logger.info(f"gigachat-api result: {responses}") + total_time = time.time() - st_time + logger.info(f"gigachat-api exec time: {total_time:.3f}s") + return jsonify(responses) diff --git a/services/gigachat_api_lm/service_configs/gigachat-api/environment.yml b/services/gigachat_api_lm/service_configs/gigachat-api/environment.yml new file mode 100644 index 0000000000..be0a791e3b --- /dev/null +++ b/services/gigachat_api_lm/service_configs/gigachat-api/environment.yml @@ -0,0 +1,4 @@ +SERVICE_PORT: 8187 +SERVICE_NAME: gigachat_api +PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 +FLASK_APP: server diff --git a/services/gigachat_api_lm/service_configs/gigachat-api/service.yml b/services/gigachat_api_lm/service_configs/gigachat-api/service.yml new file mode 100644 index 0000000000..5fcb1240a4 --- /dev/null +++ b/services/gigachat_api_lm/service_configs/gigachat-api/service.yml @@ -0,0 +1,30 @@ +name: gigachat-api +endpoints: +- respond +- generate_goals +compose: + env_file: + - .env + build: + args: + SERVICE_PORT: 8187 + SERVICE_NAME: gigachat_api + PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 + FLASK_APP: server + context: . + dockerfile: ./services/gigachat_api_lm/Dockerfile + command: flask run -h 0.0.0.0 -p 8187 + environment: + - FLASK_APP=server + deploy: + resources: + limits: + memory: 100M + reservations: + memory: 100M + volumes: + - ./services/gigachat_api_lm:/src + - ./common:/src/common + ports: + - 8187:8187 +proxy: null diff --git a/skills/dff_universal_prompted_skill/scenario/response.py b/skills/dff_universal_prompted_skill/scenario/response.py index be614b33f8..2c2dc84e1c 100644 --- a/skills/dff_universal_prompted_skill/scenario/response.py +++ b/skills/dff_universal_prompted_skill/scenario/response.py @@ -41,6 +41,7 @@ "http://transformers-lm-vicuna13b:8168/respond": [], "http://transformers-lm-ruxglm:8171/respond": [], "http://transformers-lm-rugpt35:8178/respond": [], + "http://gigachat-api:8187/respond": ["GIGACHAT_CREDENTIAL", "GIGACHAT_SCOPE"], } From c31a4353a9c8023769108256ddf0129cf0fbd5bd Mon Sep 17 00:00:00 2001 From: rafilgalimzanov Date: Mon, 25 Dec 2023 11:27:03 +0300 Subject: [PATCH 2/5] Edited code style --- services/gigachat_api_lm/server.py | 31 +++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/services/gigachat_api_lm/server.py b/services/gigachat_api_lm/server.py index 291cb50242..5d628a8030 100644 --- a/services/gigachat_api_lm/server.py +++ b/services/gigachat_api_lm/server.py @@ -13,7 +13,9 @@ sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) -logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) logger = logging.getLogger(__name__) PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH") @@ -23,13 +25,23 @@ app = Flask(__name__) logging.getLogger("werkzeug").setLevel("WARNING") DEFAULT_CONFIGS = { - "GigaChat:1.3.23.1": json.load(open("common/generative_configs/gigachat.json", "r")), + "GigaChat:1.3.23.1": json.load( + open("common/generative_configs/gigachat.json", "r") + ), } -def generate_responses(context, gigachat_api_key, gigachat_org, prompt, generation_params, continue_last_uttr=False): - - assert gigachat_api_key, logger.error("Error: GigaChat API key is not specified in env") +def generate_responses( + context, + gigachat_api_key, + gigachat_org, + prompt, + generation_params, + continue_last_uttr=False, +): + assert gigachat_api_key, logger.error( + "Error: GigaChat API key is not specified in env" + ) giga = GigaChat(credentials=gigachat_api_key, verify_ssl_certs=False) s = len(context) % 2 @@ -65,7 +77,10 @@ def respond(): prompts = request.json.get("prompts", []) configs = request.json.get("configs", None) configs = [None] * len(prompts) if configs is None else configs - configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs] + configs = [ + DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el + for el in configs + ] if len(contexts) > 0 and len(prompts) == 0: prompts = [""] * len(contexts) @@ -80,7 +95,9 @@ def respond(): contexts, gigachat_api_keys, gigachat_orgs, prompts, configs ): curr_responses = [] - outputs = generate_responses(context, gigachat_api_key, gigachat_org, prompt, config) + outputs = generate_responses( + context, gigachat_api_key, gigachat_org, prompt, config + ) for response in outputs: if len(response) >= 2: curr_responses += [response] From 1d8def216fb87061a2d743c79a123497e7acbf42 Mon Sep 17 00:00:00 2001 From: rafilgalimzanov Date: Fri, 29 Dec 2023 10:08:49 +0300 Subject: [PATCH 3/5] fixed comments --- services/gigachat_api_lm/server.py | 21 +++++-------------- .../service_configs/gigachat-api/service.yml | 1 - 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/services/gigachat_api_lm/server.py b/services/gigachat_api_lm/server.py index 5d628a8030..ec7c96b074 100644 --- a/services/gigachat_api_lm/server.py +++ b/services/gigachat_api_lm/server.py @@ -13,9 +13,7 @@ sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) +logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH") @@ -25,9 +23,7 @@ app = Flask(__name__) logging.getLogger("werkzeug").setLevel("WARNING") DEFAULT_CONFIGS = { - "GigaChat:1.3.23.1": json.load( - open("common/generative_configs/gigachat.json", "r") - ), + "GigaChat:1.3.23.1": json.load(open("common/generative_configs/gigachat.json", "r")), } @@ -39,9 +35,7 @@ def generate_responses( generation_params, continue_last_uttr=False, ): - assert gigachat_api_key, logger.error( - "Error: GigaChat API key is not specified in env" - ) + assert gigachat_api_key, logger.error("Error: GigaChat API key is not specified in env") giga = GigaChat(credentials=gigachat_api_key, verify_ssl_certs=False) s = len(context) % 2 @@ -77,10 +71,7 @@ def respond(): prompts = request.json.get("prompts", []) configs = request.json.get("configs", None) configs = [None] * len(prompts) if configs is None else configs - configs = [ - DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el - for el in configs - ] + configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs] if len(contexts) > 0 and len(prompts) == 0: prompts = [""] * len(contexts) @@ -95,9 +86,7 @@ def respond(): contexts, gigachat_api_keys, gigachat_orgs, prompts, configs ): curr_responses = [] - outputs = generate_responses( - context, gigachat_api_key, gigachat_org, prompt, config - ) + outputs = generate_responses(context, gigachat_api_key, gigachat_org, prompt, config) for response in outputs: if len(response) >= 2: curr_responses += [response] diff --git a/services/gigachat_api_lm/service_configs/gigachat-api/service.yml b/services/gigachat_api_lm/service_configs/gigachat-api/service.yml index 5fcb1240a4..897241b96b 100644 --- a/services/gigachat_api_lm/service_configs/gigachat-api/service.yml +++ b/services/gigachat_api_lm/service_configs/gigachat-api/service.yml @@ -10,7 +10,6 @@ compose: SERVICE_PORT: 8187 SERVICE_NAME: gigachat_api PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1 - FLASK_APP: server context: . dockerfile: ./services/gigachat_api_lm/Dockerfile command: flask run -h 0.0.0.0 -p 8187 From 249fa60b15805a9279ebc60b199a9c2a9049b8b1 Mon Sep 17 00:00:00 2001 From: rafilgalimzanov Date: Thu, 18 Jan 2024 13:07:30 +0300 Subject: [PATCH 4/5] Create generate goals api --- common/prompts.py | 3 +++ common/prompts/goals_for_prompts_ru.json | 3 +++ services/gigachat_api_lm/server.py | 33 ++++++++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 common/prompts/goals_for_prompts_ru.json diff --git a/common/prompts.py b/common/prompts.py index 4eb1542004..f74cadd79e 100644 --- a/common/prompts.py +++ b/common/prompts.py @@ -10,6 +10,9 @@ with open("common/prompts/goals_for_prompts.json", "r") as f: META_GOALS_PROMPT = json.load(f)["prompt"] +with open("common/prompts/goals_for_prompts_ru.json", "r") as f: + META_GOALS_PROMPT_RU = json.load(f)["prompt"] + def send_request_to_prompted_generative_service(dialog_context, prompt, url, config, timeout, sending_variables): response = requests.post( diff --git a/common/prompts/goals_for_prompts_ru.json b/common/prompts/goals_for_prompts_ru.json new file mode 100644 index 0000000000..0f5ad0b927 --- /dev/null +++ b/common/prompts/goals_for_prompts_ru.json @@ -0,0 +1,3 @@ +{ + "prompt": "Сформулируй очень краткое описание целей ассистента, которые определены в данном запросе.\n\nПример:\nЗапрос: '''ЗАДАНИЕ: Твоё имя - Ассистент по Жизненному Коучингу. Тебя создала компания Rhoades & Co. Помоги человеку поставить цель в жизни и определить, как достичь её шаг за шагом. Не обсуждай другие темы. Отвечай с сочувствием. Задавай открытые вопросы, чтобы помочь человеку лучше понять себя.\nИНСТРУКЦИЯ: Человек входит в разговор. Представься кратко. Помоги ему поставить цель и достичь её. Ты можешь узнать о его жизненных приоритетах и предпочтительных областях концентрации и предложить полезные идеи. Ты должен задать ОДИН вопрос или НИ ОДНОГО вопроса, НЕ два или три. Остановись после того, как задашь первый вопрос.'''\nРезультат: Помогает пользователю поставить и достичь жизненных целей." +} diff --git a/services/gigachat_api_lm/server.py b/services/gigachat_api_lm/server.py index ec7c96b074..0783b426c3 100644 --- a/services/gigachat_api_lm/server.py +++ b/services/gigachat_api_lm/server.py @@ -8,6 +8,7 @@ from gigachat.models import Chat from flask import Flask, request, jsonify from sentry_sdk.integrations.flask import FlaskIntegration +from common.prompts import META_GOALS_PROMPT_RU sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) @@ -103,3 +104,35 @@ def respond(): total_time = time.time() - st_time logger.info(f"gigachat-api exec time: {total_time:.3f}s") return jsonify(responses) + + +@app.route("/generate_goals", methods=["POST"]) +def generate_goals(): + st_time = time.time() + + prompts = request.json.get("prompts", None) + prompts = [] if prompts is None else prompts + configs = request.json.get("configs", None) + configs = [None] * len(prompts) if configs is None else configs + configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs] + + gigachat_api_keys = request.json.get("gigachat_credentials", []) + gigachat_orgs = request.json.get("gigachat_scopes", None) + gigachat_orgs = [None] * len(prompts) if gigachat_orgs is None else gigachat_orgs + + try: + responses = [] + for gigachat_api_key, gigachat_org, prompt, config in zip(gigachat_api_keys, gigachat_orgs, prompts, configs): + context = ["Привет", META_GOALS_PROMPT_RU + f"\nПромпт: '''{prompt}'''\nРезультат:"] + goals_for_prompt = generate_responses(context, gigachat_api_key, gigachat_org, "", config)[0] + logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`") + responses += [goals_for_prompt] + + except Exception as exc: + logger.info(exc) + sentry_sdk.capture_exception(exc) + responses = [""] * len(prompts) + + total_time = time.time() - st_time + logger.info(f"gigachat-api generate_goals exec time: {total_time:.3f}s") + return jsonify(responses) From fc82a9febc60bd175584bb86714efa4712d8e727 Mon Sep 17 00:00:00 2001 From: rafilgalimzanov Date: Thu, 18 Jan 2024 17:09:34 +0300 Subject: [PATCH 5/5] Add tests --- services/gigachat_api_lm/test.py | 44 ++++++++++++++++++++++++++++++++ services/gigachat_api_lm/test.sh | 3 +++ 2 files changed, 47 insertions(+) create mode 100644 services/gigachat_api_lm/test.py create mode 100644 services/gigachat_api_lm/test.sh diff --git a/services/gigachat_api_lm/test.py b/services/gigachat_api_lm/test.py new file mode 100644 index 0000000000..93a8d8e4ec --- /dev/null +++ b/services/gigachat_api_lm/test.py @@ -0,0 +1,44 @@ +import requests +from os import getenv + + +# ATTENTION!!! This test is only working if you assign `GIGACHAT_CREDENTIALS` env variable +GIGACHAT_CREDENTIALS = getenv("GIGACHAT_CREDENTIALS", None) +GIGACHAT_SCOPE = getenv("GIGACHAT_SCOPE", None) +assert GIGACHAT_CREDENTIALS, print("No GigaChat credentials is given in env vars") +DEFAULT_CONFIG = {"max_tokens": 64, "temperature": 0.4, "top_p": 1.0, "frequency_penalty": 0, "presence_penalty": 0} +SERVICE_PORT = int(getenv("SERVICE_PORT")) + + +def test_respond(): + url = f"http://0.0.0.0:{SERVICE_PORT}/respond" + contexts = [ + [ + "Привет! Я Маркус. Как ты сегодня?", + "Привет, Маркус! Я в порядке. Как у тебя?", + "У меня все отлично. Какие у тебя планы на сегодня?", + ], + ["Привет, Маркус! Я в порядке. Как у тебя?", "У меня все отлично. Какие у тебя планы на сегодня?"], + ] + prompts = [ + "Отвечай как дружелюбный чатбот.", + "Отвечай как дружелюбный чатбот.", + ] + result = requests.post( + url, + json={ + "dialog_contexts": contexts, + "prompts": prompts, + "configs": [DEFAULT_CONFIG] * len(contexts), + "gigachat_credentials": [GIGACHAT_CREDENTIALS] * len(contexts), + "gigachat_scopes": [GIGACHAT_SCOPE] * len(contexts), + }, + ).json() + print(result) + + assert len(result) and [all(len(sample[0]) > 0 for sample in result)], f"Got\n{result}\n, something is wrong" + print("Success!") + + +if __name__ == "__main__": + test_respond() diff --git a/services/gigachat_api_lm/test.sh b/services/gigachat_api_lm/test.sh new file mode 100644 index 0000000000..468a5a38fc --- /dev/null +++ b/services/gigachat_api_lm/test.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python test.py \ No newline at end of file