In [None]:
import pickle
import pandas as pd
from glob import glob

# Directory holding the pre-extracted ad tables (machine-specific; adjust when running elsewhere).
DATA_DIR = '/media/sj-archimedes/data/share/OddAI_Library_practice/data08'


def load_image_creatives(path: str) -> pd.DataFrame:
    """Load a pickled creatives table and keep only image ads, reduced to their first media item.

    ``creative_media_url`` / ``creative_media_hash`` are indexed per row (assumed list-like —
    TODO confirm against the pickle producer); only element 0 is kept. ``.str[0]`` returns NaN
    for an empty container instead of raising IndexError like ``x[0]`` would.
    """
    return (
        pd.read_pickle(path)
        .query('creative_type == "image"')
        .assign(
            creative_media_url=lambda d: d['creative_media_url'].str[0],
            creative_media_hash=lambda d: d['creative_media_hash'].str[0],
        )
    )


lap_df = load_image_creatives(f'{DATA_DIR}/lap_trainval_20221001-20231031.pkl')
fba_df = load_image_creatives(f'{DATA_DIR}/fba_trainval_20221001-20231031.pkl')

# The two sources carry independent indexes; ignore_index avoids duplicate row labels in the union.
df = pd.concat([lap_df, fba_df], ignore_index=True)

In [None]:
def load_merged_pickles(pattern: str) -> dict:
    """Merge every pickled dict matching ``pattern`` into a single dict.

    Files are processed in sorted path order so that, when two files share a key, the one
    whose path sorts last deterministically wins (bare ``glob`` order is filesystem-dependent).
    Returns an empty dict when nothing matches.

    NOTE: ``pickle.load`` executes arbitrary code — only point this at trusted, locally
    generated output directories.
    """
    merged = {}
    for path in sorted(glob(pattern)):
        with open(path, 'rb') as fh:
            merged |= pickle.load(fh)
    return merged


# Generated texts keyed by creative_media_hash: plain captions and appeal-point captions.
all_caption_dict = {
    'caption': load_merged_pickles('./output/*.pkl'),
    'appeal': load_merged_pickles('./output_appeal_caption/*.pkl'),
}

In [None]:
# Attach the generated caption / appeal text to each row by media hash (None when no text exists).
df['caption'] = df['creative_media_hash'].map(all_caption_dict['caption'].get)
df['appeal'] = df['creative_media_hash'].map(all_caption_dict['appeal'].get)
# Both columns are fed to the DPR context encoder later, which cannot tokenize None,
# so keep only rows that have both texts.
df = df.dropna(subset=['caption', 'appeal'])

In [None]:
# Restrict to the columns used for retrieval and display, then collapse repeated creatives.
retrieval_columns = ['creative_media_hash', 'creative_media_url', 'caption', 'appeal']
df = df.loc[:, retrieval_columns].drop_duplicates()
df.shape

In [None]:
from datasets import Dataset

# df's index is non-contiguous after the filtering / dedup above; with the default settings
# from_pandas would store it as an extra `__index_level_0__` column, so drop it.
dataset = Dataset.from_pandas(df, preserve_index=False)
dataset

In [None]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch

# Inference only: turn autograd off globally so encoder calls do not build graphs.
torch.set_grad_enabled(False)

CTX_ENCODER_ID = "facebook/dpr-ctx_encoder-single-nq-base"
ctx_encoder = DPRContextEncoder.from_pretrained(CTX_ENCODER_ID)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(CTX_ENCODER_ID)
# Trailing ';' suppresses the model repr that .to() would otherwise display as cell output.
ctx_encoder.to('cuda:2');

In [None]:
def embed_context(text: str):
    """Encode one passage with the DPR context encoder; return its pooled vector as a NumPy array.

    ``truncation=True`` makes the 512-token cap explicit (passing only ``max_length`` relies on
    the tokenizer's warning-emitting fallback).
    """
    inputs = ctx_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    # Follow whatever device the encoder lives on rather than repeating the device string here.
    inputs = inputs.to(ctx_encoder.device)
    return ctx_encoder(**inputs).pooler_output[0].cpu().numpy()


dataset = dataset.map(lambda example: {'embeddings_caption': embed_context(example['caption'])})
dataset = dataset.map(lambda example: {'embeddings_appeal': embed_context(example['appeal'])})

In [None]:
import faiss

# DPR ranks passages by the dot product of question and context vectors, so build the indexes
# with inner product instead of add_faiss_index's flat-L2 default (higher score = better match).
dataset.add_faiss_index(column='embeddings_caption', metric_type=faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index(column='embeddings_appeal', metric_type=faiss.METRIC_INNER_PRODUCT)

In [None]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Question-side DPR tokenizer/model used to embed search queries; it is never moved to a GPU,
# so queries are encoded on CPU.
QUESTION_ENCODER_ID = "facebook/dpr-question_encoder-single-nq-base"
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(QUESTION_ENCODER_ID)
q_encoder = DPRQuestionEncoder.from_pretrained(QUESTION_ENCODER_ID)

# 広告画像をqueryとする
実際には画像そのものではなく、その広告画像のcaptionとappeal captionをそれぞれqueryとしている。

In [None]:
import random
from src.utils import download_image_from_s3
import matplotlib.pyplot as plt

# Pick one indexed ad at random to act as the query. randrange excludes its upper bound, so idx
# is always a valid row however many rows survived the filtering above (randint is inclusive).
idx = random.randrange(len(dataset))

query_ds = dataset[idx]
img = download_image_from_s3(query_ds['creative_media_url'])
print('===== caption =====')
print(query_ds['caption'])
print('\n===== appeal =====')
print(query_ds['appeal'])
plt.figure(figsize=(5,5))
plt.imshow(img)
plt.show()

In [None]:
from src.utils import download_image_from_s3
import matplotlib.pyplot as plt


def encode_question(text: str):
    """Encode a query string with the DPR question encoder; return its pooled vector as a NumPy array.

    Truncates to 512 tokens, the same cap used when encoding the indexed passages; without it an
    over-long caption would exceed the encoder's input length and fail.
    """
    inputs = q_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    return q_encoder(**inputs).pooler_output[0].numpy()


# Query each index with the matching field of the sampled ad
# (caption -> caption index, appeal -> appeal index), keeping the 10 nearest hits.
query = query_ds['caption']
print(query)
question_embedding = encode_question(query)
caption_scores, caption_retrieved_examples = dataset.get_nearest_examples('embeddings_caption', question_embedding, k=10)

query = query_ds['appeal']
print(query)
question_embedding = encode_question(query)
appeal_scores, appeal_retrieved_examples = dataset.get_nearest_examples('embeddings_appeal', question_embedding, k=10)

In [None]:
# Show each hit from the caption index: its caption text, then the ad image.
media_hashes, s3_urls, captions = (
    caption_retrieved_examples[key]
    for key in ('creative_media_hash', 'creative_media_url', 'caption')
)

for hit_url, hit_caption in zip(s3_urls, captions):
    print(hit_caption)
    hit_image = download_image_from_s3(hit_url)
    plt.figure(figsize=(5, 5))
    plt.imshow(hit_image)
    plt.show()
    print('\n==============\n')

In [None]:
# Show each hit from the appeal index. The appeal text is what this index matched on, so print it
# (with the retrieval score) alongside the caption before the image.
s3_urls = appeal_retrieved_examples['creative_media_url']
appeals = appeal_retrieved_examples['appeal']
captions = appeal_retrieved_examples['caption']

for rank, (score, url, appeal, caption) in enumerate(zip(appeal_scores, s3_urls, appeals, captions), start=1):
    print(f'#{rank} score={score:.3f}')
    print('[appeal]', appeal)
    print('[caption]', caption)
    img = download_image_from_s3(url)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(img)
    ax.axis('off')
    plt.show()
    print('\n==============\n')

# ユーザが自然言語で検索する場合

In [None]:
from openai import OpenAI
from src.utils import download_image_from_s3
import matplotlib.pyplot as plt

# Free-text search: the user's query is first translated to English by an LLM, then embedded
# with the DPR question encoder and searched against both indexes.
# OpenAI() reads the key from the OPENAI_API_KEY environment variable — never hard-code it here.
user_query = "声優、限定"
prompt = f"次の文章を英語に翻訳してください。\n{user_query}"

client = OpenAI()
response = client.chat.completions.create(
  model="gpt-3.5-turbo",
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
  ]
)
query = response.choices[0].message.content
print(query)

# Truncate like the passage encoding does, so an unexpectedly long reply cannot exceed the input limit.
question_embedding = q_encoder(
    **q_tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
)[0][0].numpy()
caption_scores, caption_retrieved_examples = dataset.get_nearest_examples('embeddings_caption', question_embedding, k=10)
appeal_scores, appeal_retrieved_examples = dataset.get_nearest_examples('embeddings_appeal', question_embedding, k=10)

In [None]:
# Display the caption-index hits for the free-text query: caption text, then the ad image.
retrieved = caption_retrieved_examples
media_hashes = retrieved['creative_media_hash']
s3_urls = retrieved['creative_media_url']
captions = retrieved['caption']
separator = '\n==============\n'

for hit_url, hit_caption in zip(s3_urls, captions):
    print(hit_caption)
    hit_image = download_image_from_s3(hit_url)
    plt.figure(figsize=(5, 5))
    plt.imshow(hit_image)
    plt.show()
    print(separator)

In [None]:
# Show the appeal-index hits for the free-text query. The appeal text is what this index matched
# on, so print it (with the retrieval score) alongside the caption before the image.
s3_urls = appeal_retrieved_examples['creative_media_url']
appeals = appeal_retrieved_examples['appeal']
captions = appeal_retrieved_examples['caption']

for rank, (score, url, appeal, caption) in enumerate(zip(appeal_scores, s3_urls, appeals, captions), start=1):
    print(f'#{rank} score={score:.3f}')
    print('[appeal]', appeal)
    print('[caption]', caption)
    img = download_image_from_s3(url)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(img)
    ax.axis('off')
    plt.show()
    print('\n==============\n')