In [None]:
import os
import json
import requests
import wandb
import pandas as pd
from time import time
from rouge import Rouge

In [None]:
# Load the `dotenv` IPython extension and invoke its line magic.
# NOTE(review): presumably this reads a local .env file into os.environ so the
# credentials below are available — confirm against the python-dotenv docs.
%load_ext dotenv
%dotenv

In [None]:
# Credentials are read from the environment; a missing variable raises KeyError
# right here instead of failing later inside an API call.
# YANDEX_ID_CATALOG is interpolated into the model URI and YANDEX_API_KEY into
# the Authorization header built by prediction_history().
YANDEX_ID_CATALOG = os.environ["YANDEX_ID_CATALOG"]
YANDEX_API_KEY = os.environ["YANDEX_API_KEY"]

In [None]:
# Call the YandexGPT 3 Pro completion API and return the generated story
def prediction_history(message, timeout=120):
    """Generate a story with YandexGPT for the given user prompt.

    Sends a single non-streaming completion request built from a fixed system
    prompt and ``message`` as the user turn. The catalog id and API key come
    from the module-level ``YANDEX_ID_CATALOG`` / ``YANDEX_API_KEY``.

    Args:
        message: Text of the user prompt.
        timeout: Value passed to ``requests.post(timeout=...)`` (seconds).
            Without it a stalled connection would block the notebook forever.

    Returns:
        The text of the first alternative in the API response.

    Raises:
        requests.HTTPError: The API answered with a 4xx/5xx status.
        requests.Timeout: The server did not respond within ``timeout``.
        KeyError, IndexError: The response JSON lacks the expected
            ``result.alternatives[0].message.text`` path.
    """
    url = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Api-Key {YANDEX_API_KEY}",
    }
    payload = {
        "modelUri": f"gpt://{YANDEX_ID_CATALOG}/yandexgpt/latest",  # YandexGPT 3 Pro
        "completionOptions": {"stream": False, "temperature": 0.6, "maxTokens": "2000"},
        "messages": [
            {
                "role": "system",
                "text": "Ты русский писатель детских рассказов. Всегда возвращаешь текст только на русском.",
            },
            {
                "role": "user",
                "text": message,
            },
        ],
    }

    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Fail loudly on auth/quota/server errors; otherwise the error body would
    # only show up as an opaque KeyError on "result" below.
    response.raise_for_status()

    body = response.json()
    return body["result"]["alternatives"][0]["message"]["text"]

In [None]:
# Compute averaged ROUGE metrics
def metric_rouge(generated_history, reference_history):
    """Average ROUGE-1, ROUGE-2 and ROUGE-L scores over paired texts.

    Texts are paired positionally with ``zip``; if the two inputs differ in
    length, the extra items of the longer one are ignored and the average is
    taken over the pairs that were actually scored.

    Args:
        generated_history: Iterable of generated (hypothesis) texts.
        reference_history: Iterable of reference texts.

    Returns:
        A dict keyed by "rouge-1", "rouge-2" and "rouge-l", each mapping
        "f", "p" and "r" to the mean of that value over all scored pairs.
        All values are 0.0 when there are no pairs to score.
    """
    metric_names = ("rouge-1", "rouge-2", "rouge-l")
    stat_names = ("f", "p", "r")
    totals = {metric: {stat: 0.0 for stat in stat_names} for metric in metric_names}

    rouge = Rouge()
    scored_pairs = 0

    for ref_text, gen_text in zip(reference_history, generated_history):
        scores = rouge.get_scores(gen_text, ref_text)[0]
        for metric in metric_names:
            for stat in stat_names:
                totals[metric][stat] += scores[metric][stat]
        scored_pairs += 1

    # Nothing was scored: return zeros instead of dividing by zero.
    if scored_pairs == 0:
        return totals

    # Divide by the number of pairs that contributed to the sums, not by the
    # length of one input, so mismatched lengths do not skew the mean.
    return {
        metric: {stat: totals[metric][stat] / scored_pairs for stat in stat_names}
        for metric in metric_names
    }

In [None]:
# Load the synthetic dataset used to evaluate LLM quality
sin_ds = pd.read_csv("datasets/history_sin/history_sin.csv")

# Preview the first two rows as a quick check of how the CSV was parsed.
sin_ds.head(2)

In [None]:
# Generate a story for every dataset row with the LLM (through the API) and
# log quality/speed metrics to wandb
predicted_history = []
reference_history = []

count_captions = 0

# Extra prompt requirement; the same for every row, so it is built once.
setup_input = "В тексте не использовать слова: фотография, изображение, затем"
# Characters removed from the serialized caption list before counting items.
caption_markup = str.maketrans("", "", '[]"')

# The run is used as a context manager so it is always closed, and is marked
# as failed if an API call raises part-way through the loop.
with wandb.init(project="child_diary", group="yandexgpt-3-pro", job_type="base") as run:
    start_time = time()

    for row in sin_ds.itertuples(index=False):
        photo_captions = row.img_input
        # Number of captions = number of ", "-separated items once the list
        # brackets and double quotes are stripped.
        len_photo_captions = len(photo_captions.translate(caption_markup).split(", "))
        photo_description = row.text_input
        # The whitespace inside this literal is part of the prompt text sent
        # to the model; keep it unchanged.
        message = f"""
        Напиши развернутое описание от первого лица происходящего на {len_photo_captions} фотографиях объединив в сюжет.
        Подписи к фотографиям: {photo_captions}.
        Описание событий, происходящих на фотографиях: {photo_description}.
        Дополнительные требования: {setup_input}.
        """

        predicted_history.append(prediction_history(message))
        reference_history.append(row.target)
        count_captions += len_photo_captions

    end_time = time()

    run.log(
        {
            "ROUGE-L": metric_rouge(predicted_history, reference_history)["rouge-l"]["f"],
            # Wall-clock seconds of generation per photo caption.
            "Speed 1 image": (end_time - start_time) / count_captions,
            # NOTE(review): hard-coded value, not computed in this cell — confirm its source.
            "Save conversion": 0.8,
        }
    )