In [1]:
import pandas as pd

In [2]:
# Load the label of every image from the CSV file and show the number of rows
df = pd.read_csv(filepath_or_buffer='labels.csv')
df.shape[0]

1445

In [3]:
# Show five randomly chosen rows
df.sample(n=5)

Unnamed: 0,image_path,label
1327,datasets\floodnet\FloodNet-Supervised_v1.0\tra...,non-flooded
334,datasets\floodnet\FloodNet-Supervised_v1.0\tra...,non-flooded
1166,datasets\floodnet\FloodNet-Supervised_v1.0\tra...,non-flooded
767,datasets\floodnet\FloodNet-Supervised_v1.0\tra...,non-flooded
1059,datasets\floodnet\FloodNet-Supervised_v1.0\tra...,non-flooded


In [4]:
# Check how the labels are distributed
label_column = df['label']
label_column.value_counts()

label
non-flooded    1263
flooded         182
Name: count, dtype: int64

In [5]:
import cohere
import os
from dotenv import load_dotenv, find_dotenv
from PIL import Image
from io import BytesIO
import base64
from tqdm import tqdm

In [6]:
# Find the .env file and load the variables it defines
dotenv_file = find_dotenv()
_ = load_dotenv(dotenv_file)

In [7]:
# Read the Cohere API key and the embedding model id from the environment
api_key = os.environ.get("COHERE_API_KEY")
model_id = os.environ.get("COHERE_EMBED_MODEL_ID")
model_id

'embed-english-v3.0'

In [8]:
# Convert an image file to base64 and return it as a data URL
def image_to_base64_data_url(image_path):
    """Encode an image file as a base64 JPEG data URL.

    The image is opened with Pillow, re-encoded as JPEG in memory, and the
    resulting bytes are base64-encoded into a
    ``data:image/jpeg;base64,...`` string.

    Images whose mode the JPEG writer cannot store (for example ``RGBA`` or
    palette ``P`` images, which are common for PNG files) are converted to
    ``RGB`` first. Without that step ``save(..., format="JPEG")`` raises
    ``OSError``. The conversion discards any alpha channel.

    Args:
        image_path: Path of the image file to read.

    Returns:
        str: Data URL containing the JPEG-encoded image.
    """
    # Modes the Pillow JPEG writer accepts as-is; every other mode is converted.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    with Image.open(image_path) as img:
        jpeg_ready = img if img.mode in jpeg_modes else img.convert("RGB")
        buffered = BytesIO()
        jpeg_ready.save(buffered, format="JPEG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_base64}"

In [9]:
# Create the Cohere client with the API key read from the environment above
co = cohere.Client(api_key=api_key)

In [10]:
# Embeddings are appended to this CSV one row at a time, so an interrupted run
# keeps every image that was already embedded.
output_csv = 'image_embeddings.csv'

# Write the header only the first time. If the file already holds rows from an
# earlier (possibly interrupted) run, keep them and remember which images are
# done. Delete the file to force a full recomputation, e.g. after changing the
# embedding model.
if os.path.exists(output_csv) and os.path.getsize(output_csv) > 0:
    done_paths = set(pd.read_csv(output_csv, usecols=['image_path'])['image_path'])
else:
    pd.DataFrame(columns=['image_path', 'label', 'embedding']).to_csv(output_csv, index=False)
    done_paths = set()

# Only images without a stored embedding still need an API call.
pending_df = df[~df['image_path'].isin(done_paths)]

# Iterate over the remaining rows
for row in tqdm(pending_df.itertuples(index=False), total=len(pending_df)):
    image_path = row.image_path
    label = row.label

    # Each request carries a single image, encoded as a base64 data URL.
    data_url = image_to_base64_data_url(image_path)
    ret = co.embed(
        input_type="image",
        images=[data_url],
        model=model_id,
        embedding_types=["float"],
    )

    # Build a one-row DataFrame for this image
    embedding_df = pd.DataFrame([{
        'image_path': image_path,
        'label': label,
        'embedding': ret.embeddings.float[0]
    }])

    # Append the row (mode='a') without repeating the header
    embedding_df.to_csv(output_csv, mode='a', header=False, index=False)

100%|██████████| 1445/1445 [26:21<00:00,  1.09s/it]
