
Commit

Add gigachat (#604)
* Add gigachat

* Edited code style

* Fixed comments

* Create generate_goals API

* Add tests
RafilGalimzyanov committed Jan 18, 2024
1 parent bc61e2e commit 141a8db
Showing 18 changed files with 332 additions and 2 deletions.
6 changes: 6 additions & 0 deletions assistant_dists/universal_prompted_assistant/dev.yml
@@ -71,6 +71,12 @@ services:
- "./common:/src/common"
ports:
- 8180:8180
gigachat-api:
volumes:
- "./services/gigachat_api_lm:/src"
- "./common:/src/common"
ports:
- 8187:8187
anthropic-api-claude-v1:
volumes:
- "./services/anthropic_api_lm:/src"
@@ -6,7 +6,7 @@ services:
        sentence-ranker:8128,
        transformers-lm-gptjt:8161, openai-api-chatgpt:8145, openai-api-davinci3:8131,
        openai-api-gpt4:8159, openai-api-gpt4-32k:8160, openai-api-chatgpt-16k:8167,
        openai-api-gpt4-turbo:8180, dff-universal-prompted-skill:8147"
        openai-api-gpt4-turbo:8180, gigachat-api:8187, dff-universal-prompted-skill:8147"
      WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-1000}

sentseg:
@@ -219,6 +219,25 @@ services:
        reservations:
          memory: 100M

  gigachat-api:
    env_file: [ .env ]
    build:
      args:
        SERVICE_PORT: 8187
        SERVICE_NAME: gigachat_api
        PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1
      context: .
      dockerfile: ./services/gigachat_api_lm/Dockerfile
    command: flask run -h 0.0.0.0 -p 8187
    environment:
      - FLASK_APP=server
    deploy:
      resources:
        limits:
          memory: 500M
        reservations:
          memory: 100M

  anthropic-api-claude-v1:
    env_file: [ .env ]
    build:
6 changes: 6 additions & 0 deletions assistant_dists/universal_ru_prompted_assistant/dev.yml
@@ -67,6 +67,12 @@ services:
- "./common:/src/common"
ports:
- 8180:8180
gigachat-api:
volumes:
- "./services/gigachat_api_lm:/src"
- "./common:/src/common"
ports:
- 8187:8187
dff-universal-ru-prompted-skill:
volumes:
- "./skills/dff_universal_prompted_skill:/src"
Expand Up @@ -213,6 +213,25 @@ services:
        reservations:
          memory: 100M

  gigachat-api:
    env_file: [ .env ]
    build:
      args:
        SERVICE_PORT: 8187
        SERVICE_NAME: gigachat_api
        PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1
      context: .
      dockerfile: ./services/gigachat_api_lm/Dockerfile
    command: flask run -h 0.0.0.0 -p 8187
    environment:
      - FLASK_APP=server
    deploy:
      resources:
        limits:
          memory: 500M
        reservations:
          memory: 100M

  dff-universal-ru-prompted-skill:
    env_file: [ .env_ru ]
    build:
7 changes: 7 additions & 0 deletions common/generative_configs/gigachat.json
@@ -0,0 +1,7 @@
{
"max_tokens": 256,
"temperature": 0.4,
"top_p": 1.0,
"frequency_penalty": 0,
"presence_penalty": 0
}
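
These defaults mirror the generation parameters passed through to the GigaChat API; the server added in this commit loads this file into DEFAULT_CONFIGS and falls back to it whenever a request supplies no config of its own. A minimal sketch of overriding the defaults per request, assuming the service is reachable locally on port 8187 (the credential values below are placeholders, not real keys):

import requests

GIGACHAT_CREDENTIALS = "..."  # placeholder: set a real GigaChat authorization key
GIGACHAT_SCOPE = "GIGACHAT_API_PERS"  # placeholder scope value

custom_config = {"max_tokens": 64, "temperature": 0.9, "top_p": 1.0, "frequency_penalty": 0, "presence_penalty": 0}

result = requests.post(
    "http://0.0.0.0:8187/respond",
    json={
        "dialog_contexts": [["Привет! Как дела?"]],
        "prompts": ["Отвечай как дружелюбный чатбот."],
        # A None entry here would fall back to common/generative_configs/gigachat.json
        "configs": [custom_config],
        "gigachat_credentials": [GIGACHAT_CREDENTIALS],
        "gigachat_scopes": [GIGACHAT_SCOPE],
    },
    timeout=120,
).json()
print(result)  # one list of candidate replies per dialog context
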
3 changes: 3 additions & 0 deletions common/prompts.py
@@ -10,6 +10,9 @@
with open("common/prompts/goals_for_prompts.json", "r") as f:
META_GOALS_PROMPT = json.load(f)["prompt"]

with open("common/prompts/goals_for_prompts_ru.json", "r") as f:
META_GOALS_PROMPT_RU = json.load(f)["prompt"]


def send_request_to_prompted_generative_service(dialog_context, prompt, url, config, timeout, sending_variables):
    response = requests.post(
3 changes: 3 additions & 0 deletions common/prompts/goals_for_prompts_ru.json
@@ -0,0 +1,3 @@
{
"prompt": "Сформулируй очень краткое описание целей ассистента, которые определены в данном запросе.\n\nПример:\nЗапрос: '''ЗАДАНИЕ: Твоё имя - Ассистент по Жизненному Коучингу. Тебя создала компания Rhoades & Co. Помоги человеку поставить цель в жизни и определить, как достичь её шаг за шагом. Не обсуждай другие темы. Отвечай с сочувствием. Задавай открытые вопросы, чтобы помочь человеку лучше понять себя.\nИНСТРУКЦИЯ: Человек входит в разговор. Представься кратко. Помоги ему поставить цель и достичь её. Ты можешь узнать о его жизненных приоритетах и предпочтительных областях концентрации и предложить полезные идеи. Ты должен задать ОДИН вопрос или НИ ОДНОГО вопроса, НЕ два или три. Остановись после того, как задашь первый вопрос.'''\nРезультат: Помогает пользователю поставить и достичь жизненных целей."
}
2 changes: 1 addition & 1 deletion components.tsv
@@ -190,7 +190,7 @@
8184 external-fake-server
8185 transformers-mistral-7b-128k
8186 dff-document-qa-transformers-llm-skill
8187
8187 gigachat-api
8188
8189
8190
25 changes: 25 additions & 0 deletions components/0bBDINLSJDnjn1pzf8sdA.yml
@@ -0,0 +1,25 @@
name: gigachat
display_name: GigaChat
component_type: Generative
model_type: NN-based
is_customizable: false
author: publisher@deeppavlov.ai
description: GigaChat is a service that can interact with the user in a dialogue format, write code,
  and create texts and pictures at the user's request. At the same time, GigaChat
  handles controversial questions and provocations.
ram_usage: 100M
gpu_usage: null
group: services
connector:
  protocol: http
  timeout: 120.0
  url: http://gigachat-api:8187/respond
dialog_formatter: null
response_formatter: null
previous_services: null
required_previous_services: null
state_manager_method: null
tags: null
endpoint: respond
service: services/gigachat_api_lm/service_configs/gigachat-api
date_created: '2023-12-25T09:45:32'
14 changes: 14 additions & 0 deletions services/gigachat_api_lm/Dockerfile
@@ -0,0 +1,14 @@
FROM python:3.10

WORKDIR /src

COPY ./services/gigachat_api_lm/requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

ARG PRETRAINED_MODEL_NAME_OR_PATH
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}

COPY services/gigachat_api_lm /src
COPY common /src/common

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
Empty file.
9 changes: 9 additions & 0 deletions services/gigachat_api_lm/requirements.txt
@@ -0,0 +1,9 @@
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3
gigachat==0.1.10
138 changes: 138 additions & 0 deletions services/gigachat_api_lm/server.py
@@ -0,0 +1,138 @@
import json
import logging
import os
import time
import sentry_sdk

from gigachat import GigaChat
from gigachat.models import Chat
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from common.prompts import META_GOALS_PROMPT_RU


sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
GIGACHAT_ROLES = ["assistant", "user"]

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")
DEFAULT_CONFIGS = {
"GigaChat:1.3.23.1": json.load(open("common/generative_configs/gigachat.json", "r")),
}


def generate_responses(
    context,
    gigachat_api_key,
    gigachat_org,
    prompt,
    generation_params,
    continue_last_uttr=False,
):
    assert gigachat_api_key, logger.error("Error: GigaChat API key is not specified in env")
    giga = GigaChat(credentials=gigachat_api_key, verify_ssl_certs=False)

    s = len(context) % 2
    messages = [
        {"role": "system", "content": prompt},
    ]
    messages += [
        {
            "role": f"{GIGACHAT_ROLES[(s + uttr_id) % 2]}",
            "content": uttr,
        }
        for uttr_id, uttr in enumerate(context)
    ]
    logger.info(f"context inside generate_responses seen as: {messages}")

    payload = Chat(messages=messages, scope=gigachat_org, **generation_params)
    response = giga.chat(payload)

    outputs = [resp.message.content.strip() for resp in response.choices]

    return outputs


@app.route("/ping", methods=["POST"])
def ping():
return "pong"


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("dialog_contexts", [])
    prompts = request.json.get("prompts", [])
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs]

    if len(contexts) > 0 and len(prompts) == 0:
        prompts = [""] * len(contexts)

    gigachat_api_keys = request.json.get("gigachat_credentials", [])
    gigachat_orgs = request.json.get("gigachat_scopes", None)
    gigachat_orgs = [None] * len(contexts) if gigachat_orgs is None else gigachat_orgs

    try:
        responses = []
        for context, gigachat_api_key, gigachat_org, prompt, config in zip(
            contexts, gigachat_api_keys, gigachat_orgs, prompts, configs
        ):
            curr_responses = []
            outputs = generate_responses(context, gigachat_api_key, gigachat_org, prompt, config)
            for response in outputs:
                if len(response) >= 2:
                    curr_responses += [response]
                else:
                    curr_responses += [""]
            responses += [curr_responses]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)

    logger.info(f"gigachat-api result: {responses}")
    total_time = time.time() - st_time
    logger.info(f"gigachat-api exec time: {total_time:.3f}s")
    return jsonify(responses)


@app.route("/generate_goals", methods=["POST"])
def generate_goals():
    st_time = time.time()

    prompts = request.json.get("prompts", None)
    prompts = [] if prompts is None else prompts
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs]

    gigachat_api_keys = request.json.get("gigachat_credentials", [])
    gigachat_orgs = request.json.get("gigachat_scopes", None)
    gigachat_orgs = [None] * len(prompts) if gigachat_orgs is None else gigachat_orgs

    try:
        responses = []
        for gigachat_api_key, gigachat_org, prompt, config in zip(gigachat_api_keys, gigachat_orgs, prompts, configs):
            context = ["Привет", META_GOALS_PROMPT_RU + f"\nПромпт: '''{prompt}'''\nРезультат:"]
            goals_for_prompt = generate_responses(context, gigachat_api_key, gigachat_org, "", config)[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses += [goals_for_prompt]

    except Exception as exc:
        logger.info(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    logger.info(f"gigachat-api generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)
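
The new /generate_goals endpoint wraps META_GOALS_PROMPT_RU around each incoming skill prompt and returns a short description of the assistant's goals, one per prompt. A minimal sketch of calling it, assuming the service is reachable locally on port 8187 and real GigaChat credentials are supplied (the values below are placeholders):

import requests

GIGACHAT_CREDENTIALS = "..."  # placeholder: set a real GigaChat authorization key
GIGACHAT_SCOPE = "GIGACHAT_API_PERS"  # placeholder scope value

prompts = ["Отвечай как дружелюбный чатбот."]
goals = requests.post(
    "http://0.0.0.0:8187/generate_goals",
    json={
        "prompts": prompts,
        "configs": [None] * len(prompts),  # None falls back to the default gigachat.json config
        "gigachat_credentials": [GIGACHAT_CREDENTIALS] * len(prompts),
        "gigachat_scopes": [GIGACHAT_SCOPE] * len(prompts),
    },
    timeout=120,
).json()
print(goals)  # one short goal description per prompt
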
@@ -0,0 +1,4 @@
SERVICE_PORT: 8187
SERVICE_NAME: gigachat_api
PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1
FLASK_APP: server
29 changes: 29 additions & 0 deletions services/gigachat_api_lm/service_configs/gigachat-api/service.yml
@@ -0,0 +1,29 @@
name: gigachat-api
endpoints:
- respond
- generate_goals
compose:
  env_file:
  - .env
  build:
    args:
      SERVICE_PORT: 8187
      SERVICE_NAME: gigachat_api
      PRETRAINED_MODEL_NAME_OR_PATH: GigaChat:1.3.23.1
    context: .
    dockerfile: ./services/gigachat_api_lm/Dockerfile
  command: flask run -h 0.0.0.0 -p 8187
  environment:
  - FLASK_APP=server
  deploy:
    resources:
      limits:
        memory: 100M
      reservations:
        memory: 100M
  volumes:
  - ./services/gigachat_api_lm:/src
  - ./common:/src/common
  ports:
  - 8187:8187
proxy: null
44 changes: 44 additions & 0 deletions services/gigachat_api_lm/test.py
@@ -0,0 +1,44 @@
import requests
from os import getenv


# ATTENTION!!! This test only works if you set the `GIGACHAT_CREDENTIALS` env variable
GIGACHAT_CREDENTIALS = getenv("GIGACHAT_CREDENTIALS", None)
GIGACHAT_SCOPE = getenv("GIGACHAT_SCOPE", None)
assert GIGACHAT_CREDENTIALS, print("No GigaChat credentials are given in env vars")
DEFAULT_CONFIG = {"max_tokens": 64, "temperature": 0.4, "top_p": 1.0, "frequency_penalty": 0, "presence_penalty": 0}
SERVICE_PORT = int(getenv("SERVICE_PORT"))


def test_respond():
    url = f"http://0.0.0.0:{SERVICE_PORT}/respond"
    contexts = [
        [
            "Привет! Я Маркус. Как ты сегодня?",
            "Привет, Маркус! Я в порядке. Как у тебя?",
            "У меня все отлично. Какие у тебя планы на сегодня?",
        ],
        ["Привет, Маркус! Я в порядке. Как у тебя?", "У меня все отлично. Какие у тебя планы на сегодня?"],
    ]
    prompts = [
        "Отвечай как дружелюбный чатбот.",
        "Отвечай как дружелюбный чатбот.",
    ]
    result = requests.post(
        url,
        json={
            "dialog_contexts": contexts,
            "prompts": prompts,
            "configs": [DEFAULT_CONFIG] * len(contexts),
            "gigachat_credentials": [GIGACHAT_CREDENTIALS] * len(contexts),
            "gigachat_scopes": [GIGACHAT_SCOPE] * len(contexts),
        },
    ).json()
    print(result)

    assert len(result) and all(len(sample[0]) > 0 for sample in result), f"Got\n{result}\n, something is wrong"
    print("Success!")


if __name__ == "__main__":
    test_respond()
3 changes: 3 additions & 0 deletions services/gigachat_api_lm/test.sh
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py