In [3]:
import os, json
from dotenv import load_dotenv, find_dotenv; _ = load_dotenv(find_dotenv())

In [4]:
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough

In [5]:
topic_type = 'foods'
topic_file = f'{topic_type}.txt'

In [6]:
with open(f'config/{topic_file}', 'r') as f:
    topic_list = f.read().splitlines()
topic_list

['じゃがいも', '玉ねぎ', 'にんじん', 'なす', 'トマト', 'ピーマン']

## Summarizing chain

In [7]:
from pydantic import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function

class GenerateSummary(BaseModel):
    """Generate long and short summary texts."""
    long_summary: str = Field(...)
    short_summary: str = Field(...)

summarize_func = convert_to_openai_function(function=GenerateSummary)
summarize_func

{'name': 'GenerateSummary',
 'description': 'Generate long and short summary texts.',
 'parameters': {'type': 'object',
  'properties': {'long_summary': {'type': 'string'},
   'short_summary': {'type': 'string'}},
  'required': ['long_summary', 'short_summary']}}

In [8]:
template_str = """\
ユーザーが入力したテキストについて要約を生成してください。
要約は元のテキストの情報を全て保持するようにしてください。
"""

In [9]:
prompt_template = ChatPromptTemplate.from_messages([
    ('system', template_str),
    ('user', '{text}')
])

chat = ChatOpenAI(model='gpt-3.5-turbo-0125', temperature=0., max_tokens=4096)

chat_with_func = chat.bind(
    functions=[summarize_func],
    function_call={'name':'GenerateSummary'}
)

summarizing_chain = (
    {'text': RunnablePassthrough()}
    | prompt_template
    | chat_with_func
)

### Test

In [10]:
# topic = topic_list[0]
# with open(f'./data/{topic_type}/{topic}/ja-prose-full.txt', 'r') as f:
#     text = f.read()
# # print(text)

# res = summarizing_chain.invoke(text)
# print(res)

In [11]:
# args = json.loads(res.additional_kwargs['function_call']['arguments'])
# print(len(text), text)
# print(len(args['long_summary']), args['long_summary'])
# print(len(args['short_summary']), args['short_summary'])

### Batch

In [12]:
texts = [
    open(f'./data/{topic_type}/{topic}/ja-prose-full.txt', 'r').read()
    for topic in topic_list
]
res = summarizing_chain.batch(texts)

In [13]:
short_summaries = []
long_summaries = []
for r in res:
    args = json.loads(r.additional_kwargs['function_call']['arguments'])
    short_summaries.append(args['short_summary'])
    long_summaries.append(args['long_summary'])

In [14]:
for (topic,full), (short, long) in zip(zip(topic_list, texts), zip(short_summaries, long_summaries)):
    print(topic, len(full), len(long), len(short))
    with open(f'./data/{topic_type}/{topic}/ja-prose-short_sum.txt', 'w') as f:
        f.write(short)
    with open(f'./data/{topic_type}/{topic}/ja-prose-long_sum.txt', 'w') as f:
        f.write(long)

じゃがいも 606 327 155
玉ねぎ 689 677 152
にんじん 605 347 125
なす 457 319 108
トマト 532 275 102
ピーマン 592 396 159
