In [1]:
import openai
import re

import pandas as pd

from langchain import PromptTemplate
from langchain.llms import OpenAI

from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np

- 好きなアニメを取得したい


In [2]:
def get_prompt(template, input_text):
    """Render ``template`` with ``input_text`` substituted for its ``{input}`` slot.

    Args:
        template: Prompt text containing a single ``{input}`` placeholder.
        input_text: Value formatted into that placeholder.

    Returns:
        Whatever ``PromptTemplate.format`` produces for the filled-in template.
    """
    # The template is declared with exactly one variable, named "input".
    return PromptTemplate(template=template, input_variables=["input"]).format(
        input=input_text
    )


def get_anime_name(anime_name):
    """Split newline-separated text into a list of titles.

    Each line is stripped of surrounding whitespace (including a trailing
    ``\\r`` from CRLF text) and lines that end up empty are dropped, so blank
    or whitespace-only lines never become list entries.

    Args:
        anime_name: Text with one title per line.

    Returns:
        The non-empty, stripped lines in their original order.
    """
    stripped_lines = (line.strip() for line in anime_name.split("\n"))
    return [title for title in stripped_lines if title]


def create_index_data(
    input_path="../data/anime_list.csv", output_path="../data/index_data.csv"
):
    """Write the lower-cased title/genre columns of the anime list to a new CSV.

    Keeps only ``MAL_ID``, ``Name``, ``English name``, ``Japanese name`` and
    ``Genres``, lower-cases the four text columns, renames ``MAL_ID`` to
    ``anime_id`` and writes the result without the DataFrame index.

    Args:
        input_path: CSV to read. Defaults to the path that was previously
            hard-coded.
        output_path: CSV to write. Defaults to the path that was previously
            hard-coded.

    Returns:
        None. The result is only written to ``output_path``.
    """
    text_columns = ["Name", "English name", "Japanese name", "Genres"]

    anime_raw = pd.read_csv(input_path)
    # Work on an independent copy so the column assignments below never touch
    # (or warn about) the frame that was read from disk.
    index_data = anime_raw[["MAL_ID"] + text_columns].copy()

    for column in text_columns:
        index_data[column] = index_data[column].str.lower()

    index_data = index_data.rename(columns={"MAL_ID": "anime_id"})
    index_data.to_csv(output_path, index=False)


def create_db():
    """Build a FAISS index from ``../data/index_data.csv`` and save it to disk.

    The CSV is loaded through ``CSVLoader``, the resulting documents are passed
    through a ``RecursiveCharacterTextSplitter`` (``chunk_size=1000``,
    ``chunk_overlap=0``), embedded with ``OpenAIEmbeddings`` and stored under
    ``../data/faiss_index``.

    Returns:
        None. The index is only written to disk.
    """
    rows = CSVLoader(file_path="../data/index_data.csv").load()

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_documents(rows)

    index = FAISS.from_documents(chunks, OpenAIEmbeddings())
    index.save_local("../data/faiss_index")


def check_similarity_item(db, input):
    """Print the closest indexed entry for ``input`` so it can be checked by eye.

    The query is lower-cased and the single nearest hit is fetched. One line is
    printed: the id parsed from the hit's first line (``nan`` when that line
    has no ``anime_id: <digits>`` field), the hit's second line, and its score.

    Args:
        db: Object exposing ``similarity_search_with_score(query=..., k=...)``
            that returns ``(document, score)`` pairs.
        input: Title to look up.
    """
    hits = db.similarity_search_with_score(query=input.lower(), k=1)
    document, score = hits[0][0], hits[0][1]
    content_lines = document.page_content.split("\n")

    id_match = re.search(r"anime_id:\s*(\d+)", content_lines[0])
    parsed_id = np.nan if not id_match else int(id_match.group(1))

    print(parsed_id, content_lines[1], score)


def get_similarity_item(db, input, threshold=0.3):
    """Return the anime_id of the closest indexed entry, or None if it is too far.

    The query is lower-cased and the single nearest hit is fetched. The score
    is treated as a distance: a hit whose score is greater than or equal to
    ``threshold`` is rejected.

    Args:
        db: Object exposing ``similarity_search_with_score(query=..., k=...)``
            that returns ``(document, score)`` pairs. The first line of the
            document's ``page_content`` is expected to look like
            ``anime_id: <digits>``.
        input: Title to look up.
        threshold: Rejection cutoff for the score. Defaults to 0.3, the value
            that was previously hard-coded.

    Returns:
        The matched id as ``int``; ``np.nan`` when the accepted hit's first
        line carries no ``anime_id: <digits>`` field; ``None`` when the search
        returns no hit or the best score is at or above ``threshold``.
    """
    hits = db.similarity_search_with_score(query=input.lower(), k=1)
    if not hits:
        # An empty result used to raise IndexError; treat it as "no match".
        return None

    document, score = hits[0][0], hits[0][1]
    if score >= threshold:
        return None

    # Only the first line is needed; reading the name line as well would fail
    # on single-line content without contributing to the result.
    first_line = document.page_content.split("\n")[0]
    id_match = re.search(r"anime_id:\s*(\d+)", first_line)
    return int(id_match.group(1)) if id_match else np.nan

### Create DB


In [3]:
# Derive the index CSV (id plus lower-cased title/genre columns) from the raw anime list.
create_index_data()

In [4]:
# Embed the index CSV with OpenAIEmbeddings and save the FAISS index to disk.
create_db()

In [5]:
# Reload the FAISS index written by create_db(); the embeddings object is handed
# to the store so later queries can be embedded.
embeddings = OpenAIEmbeddings()
db = FAISS.load_local("../data/faiss_index", embeddings)

In [6]:
# LLM wrapper constructed with its default settings; called below to extract titles.
llm = OpenAI()

In [7]:
# Retrieval QA chain over the index, retrieving a single document per query.
# NOTE(review): `chain` is not referenced by any other cell visible here — confirm it is still needed.
chain = RetrievalQA.from_chain_type(
    llm, retriever=db.as_retriever(search_kwargs={"k": 1})
)

In [12]:
# Probe titles for the similarity search: English, Japanese and romanized spellings.
check_anime_list = [
    "Angel Beats",
    "魔法使いプリキュア",
    "Mahoutsukai Precure!",
    "狼と香辛料",
    "Ookami to Koushinryo",
    "鬼滅の刃",
    "Kimetsu no Yaiba",
]

In [13]:
# Print the nearest index entry (parsed id, name line, score) for every probe title.
for anime in check_anime_list:
    check_similarity_item(db, anime)

6547 Name: angel beats! 0.25655705
31884 Name: mahoutsukai precure! 0.31024
31884 Name: mahoutsukai precure! 0.20275107
2966 Name: ookami to koushinryou 0.34811774
6007 Name: ookami to koushinryou ii: ookami to kohakuiro no yuuutsu 0.26174277
30679 Name: queen's blade: grimoire 0.33600473
38000 Name: kimetsu no yaiba 0.1943137


### Prompt


In [14]:
# Sample user request (in Japanese): asks for recommendations and lists three
# favorite titles, one per line.
user_input = """おすすめのアニメを教えてください。私が好きなアニメは以下の通りです。

Angel Beats!
Mahoutsukai Precure!
Ookami to Koushinryou
"""

In [15]:
# Extraction prompt (in Japanese): states that the text lists the user's favorite
# anime and instructs the model to extract and output the names. `{input}` is the
# placeholder filled by get_prompt.
template_user_input = """以下の文章は、ユーザが好きなアニメの名前を列挙しています。
アニメの名前を抽出し、出力しなさい。

{input}
"""

In [16]:
# Fill the extraction template with the sample request.
user_prompt = get_prompt(template_user_input, user_input)

In [17]:
# Send the prompt to the LLM. NOTE: `anime_list` holds the raw completion here
# and is rebound to a list of titles in the next cell.
anime_list = llm(user_prompt)

In [18]:
# Split the completion into non-empty lines. This rebinds `anime_list`, so the
# cell is not idempotent: re-running it alone would pass a list to get_anime_name.
anime_list = get_anime_name(anime_list)

In [19]:
# Show the titles extracted by the LLM.
anime_list

['Angel Beats!', 'Mahoutsukai Precure!', 'Ookami to Koushinryou']

In [21]:
# Resolve every extracted title to an anime_id via the nearest index entry.
anime_ids = [get_similarity_item(db, item) for item in anime_list]

In [22]:
# Show the resolved ids.
# NOTE(review): in the recorded output "Ookami to Koushinryou" resolves to 6007,
# the "ii" entry, rather than 2966 — worth confirming whether a single nearest
# neighbour is good enough for titles that have sequels.
anime_ids

[6547, 31884, 6007]

In [23]:
# Load the index CSV so the resolved ids can be mapped back to their rows.
df = pd.read_csv("../data/index_data.csv")

In [24]:
# Display the index rows whose anime_id is among the resolved ids.
df[df["anime_id"].isin(anime_ids)]

Unnamed: 0,anime_id,Name,English name,Japanese name,Genres
2467,6007,ookami to koushinryou ii: ookami to kohakuiro ...,unknown,狼と香辛料ii 狼と琥珀色の憂鬱,"adventure, fantasy, historical, romance"
2565,6547,angel beats!,angel beats!,angel beats!（エンジェルビーツ）,"action, comedy, drama, school, supernatural"
4144,31884,mahoutsukai precure!,maho girls precure!,魔法つかいプリキュア！,"action, slice of life, magic, fantasy, school,..."
