In [None]:
%cd "full/path/to/codebase"
from pathlib import Path

import polars as pl
from qdrant_client.models import PointStruct
from tqdm import tqdm

from src.data.dataset import LawCorpus, TrainDataset
from src.db.qdrant import Qdrant
from src.nn.text_embedding import TextEmbedding

In [None]:
# Text embedding model; the ingestion cell calls its `infer_single` on each
# formatted question to produce the "text_vector" stored in Qdrant.
text_embedding_model = TextEmbedding()

In [None]:
# Project wrapper around the vector store; the ingestion cell uses its
# `init_collection` method and its underlying `client` to upsert points.
qdrant = Qdrant()

# Locations of the law corpus and of the data set to index. To point these at
# another folder (e.g. the private test folder), the file names and image paths
# may need slight adjustments -- read src/data/dataset.py for the details.
law_corpus = LawCorpus(Path("../data/law_db"))
train_data = TrainDataset(Path("../data/train_data"), law_corpus)

# Precomputed image features. The ingestion cell looks rows up by "image_name"
# and reads their "general_feature" and "object_feature_list" columns.
image_feature_df = pl.read_json("assets/train_image_feature.json")

In [None]:
collection_name = "train_data"

# Length of the all-zero placeholder vector stored when an image has no object
# features. Presumably equals the dimensionality of each "object_feature"
# vector -- confirm against the collection configuration.
OBJECT_FEATURE_DIM = 2304


def _format_question_text(point):
    """Return the text that gets embedded for one sample.

    The result is the question, an "Options:" header and then one line per
    option: "<key>: <value>" for multiple-choice questions, or the two fixed
    answers for Yes/No questions. Any other question type gets the header with
    no options. The string always ends with a newline.
    """
    lines = ["Question: " + point["question"], "Options:"]
    if point["question_type"] == "Multiple choice":
        lines.extend(f"{k_choice}: {v_choice}" for k_choice, v_choice in point["choices"].items())
    elif point["question_type"] == "Yes/No":
        lines.extend(("Đúng", "Sai"))
    return "\n".join(lines) + "\n"


# Only populate the collection when `init_collection` returns a truthy value.
if qdrant.init_collection(collection_name):
    # Map each image name to the index of its FIRST row in the feature table so
    # the loop does an O(1) lookup instead of filtering the whole frame for
    # every sample. `setdefault` keeps the first occurrence, matching the
    # original `filter(...).row(0)` semantics when names are duplicated.
    image_row_index = {}
    for row_idx, name in enumerate(image_feature_df["image_name"].to_list()):
        image_row_index.setdefault(name, row_idx)

    # Index-based iteration is intentional: the sample index doubles as the
    # Qdrant point id.
    for i in tqdm(range(len(train_data))):
        # Fetch the sample once; indexing the dataset may be expensive.
        point = train_data[i]
        if point["__faulty__"]:
            continue

        image_name = Path(point["image_path"]).name
        row_idx = image_row_index.get(image_name)
        if row_idx is None:
            # Fail loudly (as before), but say which sample is the problem.
            raise KeyError(f"no image features found for {image_name!r} (sample index {i})")
        image_feature_point = image_feature_df.row(row_idx, named=True)

        fmt_text = _format_question_text(point)

        # Images without any object features get a single zero vector so the
        # multi-vector field is never empty.
        if image_feature_point["object_feature_list"]:
            image_object_feature_list_vector = [
                img_obj["object_feature"] for img_obj in image_feature_point["object_feature_list"]
            ]
        else:
            image_object_feature_list_vector = [[0.0] * OBJECT_FEATURE_DIM]

        point_struct = PointStruct(
            id=i,
            vector={
                "text_vector": text_embedding_model.infer_single(fmt_text),
                "image_general_feature_vector": image_feature_point["general_feature"],
                "image_object_feature_list_vector": image_object_feature_list_vector,
            },
            # Store a shallow copy of the whole sample as the payload.
            payload={**point},
        )
        qdrant.client.upsert(collection_name=collection_name, points=[point_struct])