In [1]:
from pathlib import Path
import polars as pl
from datasets import Dataset
import numpy as np
from autofaiss import build_index
from FlagEmbedding import BGEM3FlagModel
from core_pro.ultilities import make_dir
import os
import sys

# `notebooks.benchmark` lives in a local checkout of the item_matching project
# (presumably not pip-installed), so its root is put on sys.path before the
# import below. Set ITEM_MATCHING_ROOT when the checkout lives somewhere other
# than ~/PycharmProjects/item_matching.
PROJECT_ROOT = os.environ.get(
    'ITEM_MATCHING_ROOT', str(Path.home() / 'PycharmProjects/item_matching')
)
# Guard so re-running this cell does not append a duplicate entry each time.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from notebooks.benchmark.data_load import load

In [2]:
# Load the benchmark item table. Only `df` is referenced by the cells below;
# `col` and `path` are unpacked because load() returns a 3-tuple.
df, col, path = load()
df.head(5)

Data Shape: (72110, 6)


id,q_item_id,q_level1_global_be_category,q_item_name,q_link_first_image,q_item_name_clean
u32,i64,str,str,str,str
0,21383090719,"""Men Shoes""","""Dép nam, nữ unisex 5 màu VAC đ…","""https://cf.shopee.sg/file/vn-1…","""dép nam, nữ unisex 5 màu vac đ…"
1,20951659760,"""Men Shoes""","""Giày NB crt 300 2.0 Fullbox, G…","""https://cf.shopee.sg/file/dbc7…","""giày nb crt 300 2.0 fullbox, g…"
2,19930263099,"""Men Shoes""","""GIÀY BẢO HỘ LAO ĐỘNG ĐẾ KẾP, Đ…","""https://cf.shopee.sg/file/23a6…","""giày bảo hộ lao động đế kếp, đ…"
3,24673915700,"""Men Shoes""","""Dép tông nam nữ BBR chữ đế 2 l…","""https://cf.shopee.sg/file/vn-1…","""dép tông nam nữ bbr chữ đế 2 l…"
4,15496457499,"""Men Shoes""","""(𝗖𝗵𝗶́𝗻𝗵 𝗵𝗮̃𝗻𝗴) Dép ADIDAS ADIL…","""https://cf.shopee.sg/file/vn-1…","""dép adidas adilette aqua chống…"


In [None]:
# Working directories for intermediate artefacts. `tmp/ds` is created here but
# is not written to by the cells shown in this notebook.
path_tmp_array = Path('tmp/array')
path_tmp_ds = Path('tmp/ds')
make_dir(path_tmp_ds)
make_dir(path_tmp_array)

# Encoder settings (BGE-M3 dense vectors only).
EMBED_MODEL = 'BAAI/bge-m3'
EMBED_BATCH_SIZE = 8
EMBED_MAX_LENGTH = 80  # max tokens per item title passed to the encoder

# NOTE: the build_index cell below reads the whole `tmp/array` directory, so
# only the embedding file(s) intended for the index should live there.
file_embed = path_tmp_array / 'embed.npy'
embeddings = np.load(file_embed) if file_embed.exists() else None

# Re-encode when there is no cache, or when the cached matrix no longer has one
# row per item in `df` (e.g. the benchmark data changed). A stale cache would
# otherwise only fail later, when the vectors are attached as a column.
if embeddings is None or embeddings.shape[0] != len(df):
    model = BGEM3FlagModel(EMBED_MODEL, use_fp16=False)
    embeddings = model.encode(
        df['q_item_name_clean'].to_list(),
        batch_size=EMBED_BATCH_SIZE,
        max_length=EMBED_MAX_LENGTH,
        return_dense=True,
        return_sparse=False,
        return_colbert_vecs=False,
    )['dense_vecs']
    # FAISS indexes float32 vectors; cast so the cached file is index-ready.
    embeddings = np.asarray(embeddings, dtype=np.float32)
    np.save(file_embed, embeddings)
print(embeddings.shape)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [None]:
# Attach each item's dense vector as an `embed` column, then wrap the frame as a
# Hugging Face Dataset. `embed` is returned as numpy arrays; all other columns
# are still returned, in their default formatting.
embed_col = pl.Series(name='embed', values=embeddings)
df = df.with_columns(embed_col)
dataset = Dataset.from_polars(df)
dataset.set_format(
    type='numpy',
    columns=['embed'],
    output_all_columns=True,
)

In [None]:
# Build a FAISS index with autofaiss from the embeddings saved in `tmp/array`
# and write the index plus its metadata JSON under `tmp/index`.
path_index = Path('tmp/index')
# Unlike the tmp dirs above, this one was never created; make sure it exists.
path_index.mkdir(parents=True, exist_ok=True)
build_index(
    str(path_tmp_array),
    index_path=str(path_index / 'ip.index'),
    index_infos_path=str(path_index / 'index.json'),
    save_on_disk=True,
    # Inner product; equals cosine similarity only if the vectors are
    # unit-normalised (assumed for BGE-M3 dense output — TODO confirm).
    metric_type='ip',
    verbose=30,  # logging level for autofaiss output (30 == logging.WARNING)
)

In [None]:
# Attach the on-disk FAISS index built above to the dataset's `embed` column so
# nearest-neighbour queries can be run against it.
index_file = path_index / 'ip.index'
dataset.load_faiss_index('embed', index_file)

In [None]:
# Query the index with every item's own vector and keep the top 5 hits. Every
# query item is also in the index, so presumably its first hit is itself.
query_vectors = np.asarray(dataset['embed'])
score, result = dataset.get_nearest_examples_batch('embed', query_vectors, k=5)

# One row per query: `score` holds the k scores, and `df_result` holds the k
# neighbours' column values (minus their vectors).
df_score = pl.DataFrame({'score': [list(row) for row in score]})
df_result = pl.DataFrame(result).drop(['embed'])

In [None]:
# Pair each query row with its k neighbours and scores, then explode to one row
# per (query, neighbour) pair.
# The neighbour columns come back with the same names as the query columns
# (`id`, `q_item_id`, ...), which a horizontal concat rejects as duplicates, so
# they are prefixed with `db_`. The query's `embed` vector is dropped so it is
# not repeated k times in the exploded result.
df_db = df_result.rename({c: f'db_{c}' for c in df_result.columns})
df_match = pl.concat(
    [df.select(pl.exclude('embed')), df_db, df_score],
    how='horizontal',
)
# Explode exactly the neighbour list columns plus their scores, which all hold
# k elements per row.
col_explode = df_db.columns + ['score']
df_match = df_match.explode(col_explode)

In [None]:
# Display the (query, neighbour) pairs; polars' repr truncates long frames by
# default, so this shows a preview rather than every row.
df_match