In [17]:
import os

if os.path.exists( ".env" ):
    from dotenv import load_dotenv
    load_dotenv()

In [18]:
from google import genai
from google.genai import types
from chromadb import Documents, EmbeddingFunction, Embeddings

class GenAIEmbeddingFunction( EmbeddingFunction[ Documents ] ):

    def __init__( self, api_key: str = None, 
                  model_name: str = "gemini-embedding-exp-03-07", 
                  task_type = "SEMANTIC_SIMILARITY" ) -> None:
        self.api_key = api_key if api_key is not None else os.getenv( "GEMINI_API_KEY" )
        self.client = genai.Client( api_key = self.api_key )
        self.model_name = model_name
        self.task_type = task_type

    def __call__( self, input: Documents ) -> Embeddings:
       
        result = self.client.models.embed_content( model = self.model_name,
                                                   contents = input,
                                                   config = types.EmbedContentConfig( task_type = self.task_type )
                                    )
       
        return [ embedding.values for embedding in result.embeddings ]

dataset

In [22]:
import pandas as pd

qa_text_df = pd.read_csv( "./dataset/qa_texts.csv" )

qa_texts = qa_text_df[ "document" ].tolist()
qa_ids = qa_text_df[ "id" ].tolist()

assert len( qa_texts ) == len( qa_ids ), "The length of dataset and ids must be the same."
qa_text_df.head( 2 )

Unnamed: 0,id,document,category
0,qa_1,使用者: 你們有哪些線上服務\n心肝寶貝健康諮詢小助手: 主要提供的線上服務為心臟病和肝病的...,qa
1,qa_2,使用者: 預測結果代表什麼?\n心肝寶貝健康諮詢小助手: 僅代表是否有潛在風險，如有進一步醫...,qa


embeddings

In [None]:
import time

embedding_function = GenAIEmbeddingFunction( api_key = os.getenv( "GEMINI_API_KEY" ), task_type = "QUESTION_ANSWERING" )

current = 0
offset = 5
# total = 10
total = len( qa_texts )
qa_embeddings = []
while len( qa_embeddings ) < total:
    
    try:
        result = embedding_function( qa_texts[ current:( current + offset ) ] )
        if isinstance( result, list ):
            qa_embeddings.extend( result )

        if current % 10 == 0:
            print( f"current progress: { current }" )
        current = current + offset
        
    except Exception as e:
        print( f"Error: { e }" )
    
    time.sleep( 3 )

print( "len:", len( qa_embeddings ) )
print( "last index:", current )

assert len( qa_embeddings ) == total

qa_data to file

In [24]:
import numpy as np
import json

n = len( qa_embeddings )
qa_df = pd.DataFrame( {
    "id": qa_ids[ :n ],
    "document": qa_texts[ :n ],
    "embedding": qa_embeddings[ :n ],
    "category": "qa",
} )
qa_df.to_parquet( "qa_data.parquet", engine = "pyarrow", compression = "snappy" )
print( qa_df.shape )
qa_df.head( 2)

(71, 4)


Unnamed: 0,id,document,embedding,category
0,qa_1,使用者: 你們有哪些線上服務\n心肝寶貝健康諮詢小助手: 主要提供的線上服務為心臟病和肝病的...,"[0.010236396, -0.000117234405, 0.007943524, -0...",qa
1,qa_2,使用者: 預測結果代表什麼?\n心肝寶貝健康諮詢小助手: 僅代表是否有潛在風險，如有進一步醫...,"[0.011000515, 0.014900229, -0.0049431445, -0.0...",qa


### chromaDb

In [42]:
import chromadb

genai_embedding_function = GenAIEmbeddingFunction()

CHROMADB_COLLECTION_NAME = "gad245-g1-chromadb-embedding"
# chroma_client = chromadb.Client() # Ephemeral Client
chroma_client = chromadb.PersistentClient( path = "./chroma" )
collection = chroma_client.get_or_create_collection( 
                                name = CHROMADB_COLLECTION_NAME,
                                embedding_function = genai_embedding_function, #  Chroma will use sentence transformer as a default. 
                           )

In [43]:
import hashlib

# def hash_ids( sentences ):
#     hash_ids = []
#     for sentence in sentences:
#         bytes = sentence.encode()
#         hashed = hashlib.sha256( bytes ).hexdigest()
    
#         hash_ids.append( hashed )

#     return hash_ids

qa_data_df = pd.read_parquet( "qa_data.parquet", engine = "pyarrow" )
n = qa_data_df.shape[ 0 ]
print( "n =", n )

embeddings = qa_data_df[ "embedding" ].tolist()[ :n ]
documents = qa_data_df[ "document" ].tolist()[ :n ]
ids = qa_data_df[ "id" ].tolist()[ :n ]
metadatas = qa_data_df.drop( columns = [ "id", "document", "embedding" ] ).to_dict( orient = "records" )[  :n ]

collection.add(
    documents = documents,
    embeddings = embeddings,
    ids = ids,
    metadatas = metadatas,
)
collection.peek()

n = 71


{'ids': ['qa_1',
  'qa_2',
  'qa_3',
  'qa_4',
  'qa_5',
  'qa_6',
  'qa_7',
  'qa_8',
  'qa_9',
  'qa_10'],
 'embeddings': array([[ 0.0102364 , -0.00011723,  0.00794352, ..., -0.00188833,
         -0.00200063, -0.01164419],
        [ 0.01100051,  0.01490023, -0.00494314, ..., -0.00420213,
         -0.00407544, -0.01837455],
        [-0.00413251,  0.00743793,  0.00969152, ...,  0.00238361,
          0.00298788, -0.01452713],
        ...,
        [-0.00623671,  0.04030967,  0.03029461, ..., -0.02568531,
         -0.00308101, -0.03016951],
        [ 0.00953058,  0.01997827,  0.00854848, ..., -0.00755414,
         -0.00787648,  0.0063565 ],
        [ 0.00675113,  0.01490591,  0.00011731, ..., -0.00375697,
          0.01389408, -0.02812044]], shape=(10, 3072)),
 'documents': ['使用者: 你們有哪些線上服務\n心肝寶貝健康諮詢小助手: 主要提供的線上服務為心臟病和肝病的風險預測',
  '使用者: 預測結果代表什麼?\n心肝寶貝健康諮詢小助手: 僅代表是否有潛在風險，如有進一步醫療問題，請務必諮詢專業醫師的建議並遵照醫囑。',
  '使用者: 你們網站叫什麼名字\n心肝寶貝健康諮詢小助手: 心肝寶貝疾病預測線上服務',
  '使用者: 用怎樣方式做預測\n心肝寶貝健康諮詢小助手: 最近流行的AI 機器學

query

In [44]:
results = collection.query(
    query_texts = [ "你們有什麼服務" ], 
    n_results = 2, # how many results to return
    where = { "category": "qa" },
    # include=[ "documents" ]
)
results

{'ids': [['qa_1', 'qa_33']],
 'embeddings': None,
 'documents': [['使用者: 你們有哪些線上服務\n心肝寶貝健康諮詢小助手: 主要提供的線上服務為心臟病和肝病的風險預測',
   '使用者: 你們的線上服務有什麼?\n心肝寶貝健康諮詢小助手: 我們是「心肝寶貝疾病預測」，提供心臟病和肝病的風險預測，還有一個健康小助手可以線上問問題，幫你關注健康！']],
 'uris': None,
 'data': None,
 'metadatas': [[{'category': 'qa'}, {'category': 'qa'}]],
 'distances': [[0.6169126067353479, 0.6231232824808209]],
 'included': [<IncludeEnum.distances: 'distances'>,
  <IncludeEnum.documents: 'documents'>,
  <IncludeEnum.metadatas: 'metadatas'>]}

In [45]:
results = collection.get( limit = 1 )  # 只取前 3 筆
print( results )
chroma_client.list_collections()

{'ids': ['qa_1'], 'embeddings': None, 'documents': ['使用者: 你們有哪些線上服務\n心肝寶貝健康諮詢小助手: 主要提供的線上服務為心臟病和肝病的風險預測'], 'uris': None, 'data': None, 'metadatas': [{'category': 'qa'}], 'included': [<IncludeEnum.documents: 'documents'>, <IncludeEnum.metadatas: 'metadatas'>]}


['gad245-g1-chromadb-embedding']

In [41]:
chroma_client.delete_collection( CHROMADB_COLLECTION_NAME ) # delete the collection