In [1]:
from pathlib import Path
import polars as pl
import duckdb
from datasets import Dataset, concatenate_datasets, load_from_disk
import numpy as np
from autofaiss import build_index
import sys
sys.path.extend(['/home/kevin/PycharmProjects/item_matching'])

from src.item_matching.build_index.func_img import PipelineImage
from src.item_matching.build_index.model import Model
from func import draw_images

## 1) Data 

Prepare two datasets: Database (the items to search in) and Query (the items to match; here a renamed copy of the database).

In [2]:
path = Path('/home/kevin/Downloads/cb')
path_db = path / 'cb_2024-03-07.parquet'

# db
# The query keeps every parquet column and derives `image_url` from the first
# entry of the comma-separated `images` field.
query = f"""
select *
,concat('http://f.shopee.vn/file/', UNNEST(array_slice(string_split(images, ','), 1, 1))) image_url
from parquet_scan('{str(path_db)}')
order by item_id, images
"""
df_db = (
    duckdb.sql(query).pl()
    # Prefix every column with 'db_' so database and query columns stay distinguishable.
    .select(pl.all().name.prefix('db_'))
    .head(10_000)
)
pipe = PipelineImage(path, col_image='image_url')
df_img_db = pipe.load_images('db')
df_db = (
    # After prefixing, the raw image list is named 'db_images'. The previous
    # `drop(['images'])` named a column that no longer exists, so it either did
    # nothing or raised, depending on how the polars version treats missing columns.
    df_db.drop(['db_images'])
    .join(df_img_db, on='db_image_url', how='left')
    # Keep only rows flagged by `db_exists`; rows where the left join left it null are dropped too.
    .filter(pl.col('db_exists'))
)


# q
# The query set is a copy of the database set with the leading 'db_' swapped for 'q_'.
# `removeprefix` strips only a leading 'db_', so names containing 'db_' elsewhere stay
# intact, and the mixed quotes keep the f-string valid on Python versions before 3.12.
df_q = df_db.rename({col: f"q_{col.removeprefix('db_')}" for col in df_db.columns})
df_q.head()

## 2) Embeddings

Use Hugging Face `datasets` and CLIP to transform the images into embedding vectors.

In [3]:
# Fetch the image model and its processor (returned as a pair) for the chosen CLIP checkpoint.
clip_model_id = 'openai/clip-vit-base-patch32'
img_model, img_processor = Model().get_img_model(model_id=clip_model_id)

In [4]:
# Wrap the frame as a `datasets.Dataset` and run the image pre-processing/embedding
# step in batches. `pp_img` is expected to add the `img_embed` column used below.
dataset = Dataset.from_pandas(df_db.to_pandas())
fn_kwargs = {'col': 'db_file_path', 'processor': img_processor, 'model': img_model}
dataset = dataset.map(Model().pp_img, batched=True, batch_size=128, fn_kwargs=fn_kwargs)
# Return `img_embed` as numpy while still exposing the remaining columns.
dataset.set_format(type='numpy', columns=['img_embed'], output_all_columns=True)

# save to disk
path_tmp_array = Path('tmp/array')
path_tmp_ds = Path('tmp/ds')
# np.save does not create missing parent directories, so make sure the target
# folder exists; otherwise a fresh checkout fails with FileNotFoundError.
path_tmp_array.mkdir(parents=True, exist_ok=True)
np.save(path_tmp_array / 'array.npy', dataset['img_embed'])
dataset.save_to_disk(path_tmp_ds / 'ds')

## 3) Indexing

Build a FAISS index (inner-product metric) over the saved embeddings so that items can be searched.

In [5]:
# Destination folder plus the two artefacts written there: the index and its metadata.
path_index = Path('tmp/index')
index_file = path_index / 'ip.index'
index_infos_file = path_index / 'index.json'

# Build the index from the embeddings folder written earlier; metric_type='ip' selects inner product.
build_index(
    str(path_tmp_array),
    index_path=str(index_file),
    index_infos_path=str(index_infos_file),
    save_on_disk=True,
    metric_type='ip',
    verbose=30,
)

Load the saved dataset from disk and attach the FAISS index to it.

In [6]:
# Reload every entry found under `path_tmp_ds` (in sorted order) and concatenate them into one dataset.
saved_parts = [load_from_disk(str(part_dir)) for part_dir in sorted(path_tmp_ds.glob('*'))]
dataset_db = concatenate_datasets(saved_parts)

# add index
# Attach the index file built in the previous step under the name 'img_embed'.
dataset_db.load_faiss_index(
    'img_embed',
    path_index / 'ip.index',
)

## 4) Retrieve

Batch-search the top-k nearest neighbours of every query embedding in the indexed dataset.

In [7]:
# The database embeddings themselves are used as the queries, one search per row, keeping the 5 nearest.
query_vectors = np.asarray(dataset_db['img_embed'])
score, result = dataset_db.get_nearest_examples_batch('img_embed', query_vectors, k=5)

# One list of scores per query row.
dict_ = {'score_img': list(map(list, score))}
df_score = pl.DataFrame(dict_)

# Retrieved examples as a frame; the embedding column is not needed downstream.
df_result = pl.DataFrame(result)
df_result = df_result.drop(['img_embed'])

## 5) Post process

In [8]:
# Query rows, retrieved neighbours and scores are aligned purely by position,
# so a height mismatch would silently pair the wrong rows.
assert df_q.height == df_result.height == df_score.height, \
    'query, result and score frames must have the same number of rows'
df_match = pl.concat([df_q, df_result, df_score], how='horizontal')

# Pick the retrieved columns by their 'db_' prefix. A substring test ('db' in name)
# would also catch query columns that merely contain 'db' (e.g. 'q_feedback').
col_explode = [i for i in df_match.columns if i.startswith('db_')] + ['score_img']
# One output row per (query, retrieved neighbour) pair.
df_match = df_match.explode(col_explode)

In [9]:
# Display the exploded match table (query columns, retrieved db columns and score_img).
df_match

In [10]:
# Visual spot check of one entry; the number is presumably an item id present in df_match (confirm in func.draw_images).
lookup_id = 2999787165
draw_images(df_match, lookup_id)

In [11]:
# Visual spot check of one entry; the number is presumably an item id present in df_match (confirm in func.draw_images).
lookup_id = 3099789245
draw_images(df_match, lookup_id)

In [12]:
# Visual spot check of one entry; the number is presumably an item id present in df_match (confirm in func.draw_images).
lookup_id = 2999838844
draw_images(df_match, lookup_id)

In [13]:
# Visual spot check of one entry; the number is presumably an item id present in df_match (confirm in func.draw_images).
lookup_id = 3099458499
draw_images(df_match, lookup_id)

In [None]:
# Optional: uncomment to export the matches to CSV next to the input data.
# df_match.write_csv(path / 'match.csv')