In [1]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
import random
from typing import List, Union, Dict
from tqdm.autonotebook import trange
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
from datasets import load_dataset

  from tqdm.autonotebook import trange


In [2]:
def _text_length(text: Union[List[int], List[List[int]]]):
    if isinstance(text, dict):  # {key: value} case
        return len(next(iter(text.values())))
    elif not hasattr(text, "__len__"):  # Object has no len() method
        return 1
    elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
        return len(text)
    else:
        return sum([len(t) for t in text])  # Sum of length of individual strings
    
def inference(tokenizer: AutoTokenizer, model: ORTModelForFeatureExtraction,
              sentences: List[str], batch_size: int = 16, verbose: bool = False,
              max_length: int = 128):
    """Embed ``sentences`` and return one L2-normalised vector per sentence.

    Sentences are processed longest-first in batches of ``batch_size``. Each batch is
    tokenized (padded to its longest item, truncated to ``max_length`` tokens), passed
    to ``model``, averaged over dimension 1 of ``last_hidden_state`` and L2-normalised.
    The rows of the result are returned in the same order as ``sentences``.

    Args:
        tokenizer: Called as ``tokenizer(batch, padding=True, truncation=True,
            max_length=..., return_tensors='pt')``; the result must support ``.to(device)``
            and ``**`` unpacking.
        model: Called with the unpacked tokenizer output; the result is indexed with
            ``'last_hidden_state'``.
        sentences: Texts to embed.
        batch_size: Number of sentences per model call.
        verbose: When true, show a progress bar that counts processed sentences.
        max_length: Token limit handed to the tokenizer (defaults to the previously
            hard-coded 128).

    Returns:
        A NumPy array with one row per sentence, or an empty 1-D array when
        ``sentences`` is empty.
    """
    # Longest texts first, so every batch groups inputs of similar length.
    length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences])
    sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

    cpu = torch.device('cpu')
    batch_embeddings = []

    with tqdm(total=len(sentences), desc="Batches", disable=not verbose) as pbar:
        for start in range(0, len(sentences), batch_size):
            batch = sentences_sorted[start:start + batch_size]
            encoded_inputs = tokenizer(batch, padding=True, truncation=True,
                                       max_length=max_length, return_tensors='pt').to(cpu)
            with torch.no_grad():
                hidden = model(**encoded_inputs)['last_hidden_state'].detach()
                # NOTE(review): assumes hidden is (batch, tokens, dim). A plain mean over
                # dim=1 also averages padded positions, so a vector can depend on the
                # other items in its batch - confirm this matches how the model was
                # trained before switching to attention-mask-weighted pooling.
                pooled = torch.mean(hidden, dim=1)
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1).to(cpu)
            batch_embeddings.append(pooled)
            pbar.update(len(batch))

    if not batch_embeddings:
        # Nothing was embedded; keep the original empty-array result.
        return np.asarray([])

    # argsort of a permutation yields its inverse: sorted rows -> original input order.
    restore_order = np.argsort(length_sorted_idx)
    return torch.cat(batch_embeddings, dim=0).numpy()[restore_order]

In [3]:
# --- Settings ------------------------------------------------------------------
category_dataset = "clw8998/Shopee-Categories"
p_name_dataset = 'clw8998/Coupang-Product-Set-1M'
model_id = 'clw8998/Product-Classification-Model-Distilled'
batch_size = 16
num_p_names = 30   # how many product names are drawn for this demo
sample_seed = 42   # fixed seed so a fresh "Restart & Run All" draws the same products

# --- Model and tokenizer ---------------------------------------------------------
# export=False is kept from the original call; presumably it loads an existing ONNX
# export instead of converting the checkpoint - TODO confirm against the optimum docs.
model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=False)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# --- Data ------------------------------------------------------------------------
# list(...) materialises the column: depending on the installed `datasets` version,
# column access may not return a plain list, while the code below relies on integer
# indexing and random sampling.
categories = list(load_dataset(category_dataset, split="train")['category'])

# A dedicated Random instance makes the sample reproducible and re-runnable without
# touching the global random state. The full column is not kept in a notebook variable.
p_names = random.Random(sample_seed).sample(
    list(load_dataset(p_name_dataset, split="train")['product_name']),
    num_p_names,
)

In [4]:
# Sanity check: number of candidate categories and number of sampled product names.
print(f"Num of category: {len(categories)}, Num of p_name: {len(p_names)}")

Num of category: 1425, Num of p_name: 30


In [5]:
def get_category(tokenizer: AutoTokenizer, model: ORTModelForFeatureExtraction, 
         p_names: List[str], categories: List[str], top_k: int = 3) -> Dict[str, List[str]]:
    """Return the ``top_k`` best-matching categories for every product name.

    Product names and categories are embedded with ``inference`` and compared with a
    dot product; because ``inference`` L2-normalises its rows, that score is the cosine
    similarity. For each product name the categories are ranked from highest to lowest
    score and the first ``top_k`` are kept.

    Note: the batch size is read from the module-level ``batch_size`` variable, which
    must be defined before this function is called.

    Args:
        tokenizer: Tokenizer forwarded to ``inference``.
        model: Model forwarded to ``inference``.
        p_names: Product names to classify.
        categories: Candidate category labels.
        top_k: How many categories to keep per product name (0 gives empty lists).

    Returns:
        Dict mapping each product name to its ranked category list. A product name that
        occurs more than once gets a single entry (the last occurrence wins).

    Raises:
        ValueError: If ``top_k`` is negative (a negative slice bound would silently
            return "all but the worst" categories instead of the best ones).
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    # Empty inputs would otherwise fail in np.dot with a shape mismatch.
    if len(p_names) == 0:
        return {}
    if len(categories) == 0:
        return {name: [] for name in p_names}

    p_names_embeddings = inference(tokenizer, model, p_names, batch_size, verbose=True)  # (num_p_names, dim)
    categories_embeddings = inference(tokenizer, model, categories, batch_size, verbose=True)  # (num_categories, dim)

    # One row per product name, one column per category.
    scores_matrix = np.dot(p_names_embeddings, categories_embeddings.T)

    result = {}
    for name, scores in zip(p_names, scores_matrix):
        # argsort sorts ascending, so negate the scores to rank best-first.
        top_k_indices = np.argsort(-scores)[:top_k]
        result[name] = [categories[index] for index in top_k_indices]

    return result


In [6]:
# Number of best-matching categories to keep per product name; the cell that builds
# the DataFrame reads this value again to name its columns.
top_k = 1
# Maps every sampled product name to its top_k category labels.
result = get_category(tokenizer, model, p_names, categories, top_k=top_k)

Batches: 100%|██████████| 30/30 [00:00<00:00, 738.21it/s]
Batches: 100%|██████████| 1425/1425 [00:00<00:00, 1613.79it/s]


In [7]:
# Build one row per product name with columns p_name, category_1, category_2, ...
# from_dict(orient='index') labels the value columns 0, 1, 2, ...; the names are
# derived from those positions, so the frame stays valid when fewer categories than
# requested were returned or when `result` is empty.
df = (
    pd.DataFrame.from_dict(result, orient='index')
    .rename(columns=lambda position: f'category_{position + 1}')
    .reset_index()
    .rename(columns={'index': 'p_name'})
)
df

Unnamed: 0,p_name,category_1
0,"withshyan 60秒美甲, M144 櫻桃糖漿, 9ml, 6個",美食、伴手禮_飲料、沖泡品_果汁、果醋
1,"蓋曼光澤珍珠指甲貼, NO358, 3個",美妝保健_手足保養、美甲_指甲貼紙
2,KKAKKUNGNORITER 中性款打底褲 2款組,女生衣著_長褲_緊身褲 / 內搭褲
3,"Jooyon 2024年好作品D1酷睿i5英特爾第12代, Free DOS, 256GB,...",3C與筆電_電腦零組件_CPU
4,"Welluv 優質右肩姿勢帶, 2個",運動/健身_健身運動器材_健身及運動護具
5,"Deolilly 攜帶式奶瓶瀝水架, 綠色",文創商品_設計雜物_水杯瓶、餐廚配件
6,"ABM 提手紙購物袋冰淇淋 10p, 混色",居家生活_日用品_垃圾袋
7,"wooriga 夜關門茶葉, 600g, 2包",美食、伴手禮_飲料、沖泡品_茶葉、茶包
8,"DIAMOND 鑽石牌 鋁箔紙 小, 3盒",女生配件/黃金_其他_紋身貼紙
9,"NUART 造型手持風扇, 貼片-S5000B, 泰迪熊（棕色）",手機平板與周邊_手機周邊配件_USB風扇/ 手持風扇
