In [1]:
from dataclasses import dataclass, field

import torch


# ===== CORE MODEL CONFIG ======
@dataclass
class Config:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class EmbeddingConfig(Config):
    model_name: str = "keepitreal/vietnamese-sbert"


@dataclass
class DatasetConfig:
    data_root: str = "./db"
    corpus_dir: str = f"{data_root}/corpus.csv"
    train_dir: str = f"{data_root}/train.csv"
    vector_src_dir: str = f"{data_root}/vector_db_src.csv"
    public_test_dir: str = f"{data_root}/public_test.csv"


# ===== DATA SETUP CONFIG =====
@dataclass
class MilvusDBConfig:
    data_root: str = "./db"

    db_name: str = (
        f"{data_root}/bkai_milvus.db"  # Change this to place the db to where you want
    )
    collection_name: str = "bkai_vectordb"
    limit: int = 30  # This is top_k results
    output_fields: list = field(default_factory=lambda: ["text", "cid"])
    metric_type: str = (
        "COSINE"  # Possible values are IP, L2, COSINE, JACCARD, and HAMMING
    )

    # More details at: https://milvus.io/api-reference/pymilvus/v2.4.x/MilvusClient/Vector/search.md
    params: dict = field(default_factory=lambda: {})

    # Dataset config: this part requires dataset's EDA.
    dimension: int = 768  # Length of the embedding vector (1, embed_len)
    primary_field_name: str = "id"
    id_type: str = "int"
    vector_field_name: str = "embeddings"
    auto_id: bool = False


In [2]:
import os
import ast
from dotenv import load_dotenv

import pandas as pd

# This script is highly recommended for individuals that have limit computatation resources.
def process_data_in_batches(df, batch_size=10000):
    """Processes data in batches.

    Args:
      df: The Pandas DataFrame to process.
      batch_size: The number of rows to process in each batch.

    Yields:
      A generator that yields batches of the DataFrame.
    """
    for i in range(0, len(df), batch_size):
        yield df[i : i + batch_size]

def get_bkai_result_format(input_csv_path, output_txt_path):
    "Return results data from pandas into the format: qid cid1 cid2 ..."
    # Load data
    result_df = pd.read_csv(input_csv_path, index_col=False)

    # Parse 'cid' strings to lists
    result_df["cid"] = result_df["cid"].apply(ast.literal_eval)

    # Build the output lines
    lines = [
        f"{q} " + " ".join(map(str, cid)) + "\n"
        for q, cid in zip(result_df["qid"], result_df["cid"])
    ]

    # Write all lines to file at once
    with open(output_txt_path, "w") as file:
        file.writelines(lines)

# Setup Milvus

In [3]:
"Connet to MilvusDB"

import pymilvus
from pymilvus import MilvusClient


class MilvusDBConnection:
    "Establish Connection to Milvus"

    def __init__(self, config):
        self.config = config
        self.connect()

    def connect(self):
        try:
            self.client = MilvusClient(self.config.db_name)
        except Exception as e:
            raise pymilvus.exceptions.ConnectError() from e

    def check_collection(self) -> bool:
        return self.client.has_collection(collection_name=self.config.collection_name)

    def create_collection(self):
        try:
            self.client.create_collection(
                collection_name=self.config.collection_name,
                dimension=self.config.dimension,
                primary_field_name=self.config.primary_field_name,
                id_type=self.config.id_type,
                vector_field_name=self.config.vector_field_name,
                auto_id=self.config.auto_id,
                metric_type=self.config.metric_type,
            )
        except Exception as e:
            raise e

    def drop_collection(self):
        try:
            if self.check_collection():
                self.client.drop_collection(collection_name=self.config.collection_name)
        except Exception as e:
            raise pymilvus.exceptions.CollectionNotExistException() from e


In [9]:
embedding_conf=EmbeddingConfig()
ds_conf=DatasetConfig()

In [6]:
def _convert_string_to_float_df( sample):
    string_ = sample[2:-2] 
    float_strings = string_.replace('\n ', ' ').split(' ')
    float_list = [float(s) for s in float_strings if s != '']
    return float_list

print("Convert embedding value from string to float")
vector_df = pd.read_csv(ds_conf.vector_src_dir)

# vector_df = pd.read_csv(self.ds_conf.corpus_dir, index_col=0)
vector_df["embeddings"] = vector_df["embeddings"].apply(
    lambda x: _convert_string_to_float_df(x)
)

vector_df.head()

Convert embedding value from string to float


Unnamed: 0,text,cid,embeddings
0,"Thông tư này hướng dẫn tuần tra, canh gác bảo ...",0,"[0.251511246, -0.239485502, 0.508561134, -0.15..."
1,"1. Hàng năm trước mùa mưa, lũ, Ủy ban nhân dân...",1,"[0.07124101, -0.1602871, 0.27494153, 0.1891961..."
2,Tiêu chuẩn của các thành viên thuộc lực lượng ...,2,"[0.110181659, -0.107379489, 0.125140026, 0.246..."
3,"Nhiệm vụ của lực lượng tuần tra, canh gác đê\n...",3,"[0.0841644704, -0.00109320623, 0.287990421, 0...."
4,"Phù hiệu của lực lượng tuần tra, canh gác đê\n...",4,"[0.285764992, 0.000564881775, 0.371830344, 0.1..."


In [7]:
# ===== DATA SETUP CONFIG =====
@dataclass
class MilvusDBConfig:
    data_root: str = "./db"

    db_name: str = (
        f"{data_root}/bkai_milvus.db"  # Change this to place the db to where you want
    )
    collection_name: str = "bkai_vectordb"
    limit: int = 30  # This is top_k results
    output_fields: list = field(default_factory=lambda: ["text", "cid"])
    metric_type: str = (
        "COSINE"  # Possible values are IP, L2, COSINE, JACCARD, and HAMMING
    )

    # More details at: https://milvus.io/api-reference/pymilvus/v2.4.x/MilvusClient/Vector/search.md
    params: dict = field(default_factory=lambda: {})

    # Dataset config: this part requires dataset's EDA.
    dimension: int = 768  # Length of the embedding vector (1, embed_len)
    primary_field_name: str = "cid"
    id_type: str = "int"
    vector_field_name: str = "embeddings"
    auto_id: bool = False

mlv_conf=MilvusDBConfig()

# Setup MilvusDB Connection
connection = MilvusDBConnection(config=mlv_conf)
connection.create_collection()

client = connection.client

In [8]:
print("Insert data by batch")
for batch in process_data_in_batches(vector_df, batch_size=1000):
    data = [batch.iloc[idx].to_dict() for idx in range(len(batch))]

    # Insert records
    res = client.insert(
        collection_name=mlv_conf.collection_name, data=data
    )

    print(res)

Insert data by batch
{'insert_count': 1000, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 2

# Inference

In [12]:
from doculens.embedding import EmbeddingModel

embedding_model = EmbeddingModel(config=embedding_conf)



In [10]:
test_df = pd.read_csv(ds_conf.public_test_dir)
test_df

Unnamed: 0,question,qid
0,Hiệp hội Công nghiệp ghi âm Việt Nam hoạt động...,98440
1,Báo cáo nghiên cứu khả thi đầu tư xây dựng là ...,105737
2,Lịch khai giảng năm học 2022 - 2023 đối với họ...,106239
3,Số định danh cá nhân có được dùng thay thế các...,79491
4,Trợ cấp đối với Chủ tịch Hội cựu chiến binh cấ...,130557
...,...,...
9995,Đón trả hành khách trên đường cao tốc có bị gi...,42798
9996,"Các đơn vị được giao là đầu mối trao đổi, cung...",10533
9997,Ban Thường vụ Hội Hỗ trợ khắc phục hậu quả bom...,46794
9998,"Tài liệu thông tin, giáo dục, truyền thông về ...",112007


In [19]:
sample = test_df.iloc[:10]
sample

Unnamed: 0,question,qid
0,Hiệp hội Công nghiệp ghi âm Việt Nam hoạt động...,98440
1,Báo cáo nghiên cứu khả thi đầu tư xây dựng là ...,105737
2,Lịch khai giảng năm học 2022 - 2023 đối với họ...,106239
3,Số định danh cá nhân có được dùng thay thế các...,79491
4,Trợ cấp đối với Chủ tịch Hội cựu chiến binh cấ...,130557
5,Mẫu đơn ủy quyền nhận trợ cấp thất nghiệp theo...,109476
6,Sĩ quan quân đội có được nghỉ phép hàng tuần h...,64585
7,Cây xăng có được phép tuyển lao động nữ mang t...,5785
8,Nguồn kinh phí chi trả mức phụ cấp ưu đãi nghề...,159576
9,Khi thi nâng ngạch lên kế toán viên trung cấp ...,36563


In [32]:
search_params = {
"metric_type": mlv_conf.metric_type,
"params": mlv_conf.params,
}
result = ""
cids = []

for idx in range(len(test_df)): 
    question = test_df.iloc[idx].question
    qid = test_df.iloc[idx].qid
    result += f"{str(qid)} "
    sentence_embedding = embedding_model.invoke(question)

    search_result = client.search(
        collection_name=mlv_conf.collection_name,
        data=sentence_embedding,
        limit=mlv_conf.limit,
        output_fields=mlv_conf.output_fields,
        search_params=search_params,
    )

    for res in search_result[0]:
        res_entity = res["entity"]
        result += f"{res_entity["cid"]} "

    with open('predict.txt', 'a+') as writer: 
        print(f"{idx} question: {result}")
        result += '\n'
        writer.writelines(result)
    result = ""


0 question: 98440 158065 111380 96492 142150 246842 137977 132626 94594 249042 32358 165713 36319 93624 143990 157437 218723 115476 91287 159713 246317 163975 71663 234893 97812 489387 98619 138459 129441 73621 612231 
1 question: 105737 67238 567623 23804 136344 229317 76748 179897 87732 183328 608073 8584 182445 262192 23754 515144 160765 567627 567626 567548 574356 117853 608078 87731 467016 13253 618535 567624 631656 496569 108925 
2 question: 106239 214239 109771 184177 88399 518479 109770 471329 644272 445859 157808 637967 224056 472505 624396 171087 128456 598804 101075 85945 140766 127925 546599 143852 36142 87737 608468 614455 555159 119099 4553 
3 question: 79491 215763 452777 457021 508799 637475 575274 498758 113311 207895 231805 189984 132290 15161 64285 451567 18316 75564 22913 89007 451568 499881 498771 32618 94500 50701 572007 22912 457007 93077 183771 
4 question: 130557 199586 484399 626546 160754 626545 557044 593759 528875 3788 184968 605495 227054 241167 160752 534

In [13]:
sentence_embedding = embedding_model.invoke(sample.question)
sentence_embedding

array([[ 7.31994910e-03,  1.40865996e-01, -9.59775934e-04,
        -2.10210428e-01,  4.22635943e-01,  6.65427983e-01,
        -1.98055968e-01,  1.92148402e-01, -1.72215283e-01,
         2.17930466e-01,  1.59431070e-01,  1.43577099e-01,
         6.04931414e-01, -2.22751265e-03, -2.78937947e-02,
         2.19161227e-01,  3.12417269e-01, -8.97035934e-03,
         5.70143759e-02, -1.71511531e-01,  2.14563519e-01,
         4.49207008e-01,  2.71661192e-01,  1.73510775e-01,
        -2.17540666e-01,  2.85483450e-01,  3.42878044e-01,
         1.03309810e-01, -1.75041854e-01,  4.61441189e-01,
        -3.35438967e-01,  2.56961882e-01,  1.69644877e-01,
        -7.23919272e-01,  2.82167271e-03,  6.33383915e-02,
        -1.51719272e-01,  5.78777254e-01,  1.31913945e-01,
         2.10182101e-01,  1.37827754e-01, -2.39302337e-01,
         7.96478748e-01,  5.08418024e-01, -3.90665054e-01,
         2.97375828e-01,  4.39052135e-02, -5.29473543e-01,
         5.33862233e-01,  1.80093974e-01, -4.48313355e-0

In [24]:
search_params = {
    "metric_type": mlv_conf.metric_type,
    "params": mlv_conf.params,
}

print("Semantic search")
result = client.search(
    collection_name=mlv_conf.collection_name,
    data=sentence_embedding,
    limit=mlv_conf.limit,
    output_fields=mlv_conf.output_fields,
    search_params=search_params,
)

Semantic search


In [25]:
contexts = []
cids = []

for res in result[0]:
    res_entity = res["entity"]
    contexts.append(res_entity["text"])
    cids.append(res_entity["cid"])

In [26]:
cids

[158065,
 111380,
 96492,
 142150,
 246842,
 137977,
 132626,
 94594,
 249042,
 32358,
 165713,
 36319,
 93624,
 143990,
 157437,
 218723,
 115476,
 91287,
 159713,
 246317,
 163975,
 71663,
 234893,
 97812,
 489387,
 98619,
 138459,
 129441,
 73621,
 612231]