In [1]:
import ast
from collections import defaultdict

import pandas as pd
from tqdm import tqdm

In [2]:
# Load the tab-separated dump. The file has no header row, so the two columns
# are named explicitly: the query text and the raw string payload.
df = pd.read_csv(
    'sdt_limited_test.tsv',
    sep='\t',
    header=None,
    names=['query', 'cat_wt_tcat_twt'],
)

In [3]:
# Number of rows loaded.
df.shape[0]

4986542

In [4]:
# Peek at the first five rows as plain dicts so the payload strings are shown in full.
df.head(5).to_dict(orient='records')

[{'query': 'big gift boxes for presents',
  'cat_wt_tcat_twt': "[['2499', '2520', '2501'], [0.034763362258672, 0.692795872688293, 0.209890782833099], ['2479'], [0.692795872688293]]"},
 {'query': 'women s plus size coats',
  'cat_wt_tcat_twt': "[['6079', '6047', '6044'], [0.017222201451659, 0.9199814200401301, 0.033585455268621], ['6028', '6026'], [0.017222201451659, 0.9199814200401301]]"},
 {'query': 'foundation maybelline',
  'cat_wt_tcat_twt': "[['998', '1009'], [0.056512609124183, 0.8071712851524351], ['998'], [0.8071712851524351]]"},
 {'query': 'hasbul',
  'cat_wt_tcat_twt': "[['5901', '5904', '5913'], [0.057797316461801, 0.083189412951469, 0.34847527742385803], ['5901'], [0.34847527742385803]]"},
 {'query': 'h d kristal',
  'cat_wt_tcat_twt': "[['2685', '4083', '4089'], [0.49266675114631603, 0.065896727144718, 0.21316447854042], ['2649', '4081'], [0.49266675114631603, 0.21316447854042]]"}]

In [5]:
# Taxonomy dump, read as JSON Lines (one JSON object per line).
df_tax = pd.read_json(
    '../taxonomy/wish_newtax_02222023.json',
    lines=True,
)

In [6]:
# Two-way lookup between category ids (as strings) and their category paths.
# Taxonomy rows with an empty category_path are left out of both tables.
id2tax = {}
tax2id = {}
for record in df_tax.to_dict('records'):
    path = record['category_path']
    if len(path) == 0:
        continue
    cat_id = str(record['id'])
    id2tax[cat_id] = path
    tax2id[path] = cat_id

In [7]:
# Number of leading ' > '-separated category_path levels kept when
# validate_truncation rolls a category up to its truncated ancestor.
TRUNCATE_DEPTH = 2

In [8]:
def validate_truncation(x):
    """Check a record's truncated categories against a recomputation.

    Each id in ``cats`` is mapped to its ancestor at ``TRUNCATE_DEPTH``: the
    first ``TRUNCATE_DEPTH`` components of its ``' > '``-separated path in
    ``id2tax``, converted back to an id through ``tax2id``. Every ancestor
    keeps the largest weight among the categories that roll up to it, and the
    result is compared with the ``tcats``/``twts`` pairs given in the record.

    Parameters
    ----------
    x : sequence of four sequences
        ``(cats, wts, tcats, twts)``: category ids, their weights, the
        truncated category ids and the truncated weights.

    Returns
    -------
    tuple
        ``(is_valid, (computed, given))`` where ``computed`` maps each
        recomputed truncated id to its maximum weight and ``given`` is the
        mapping built from ``tcats``/``twts``. ``is_valid`` is False when the
        two mappings differ or when either id/weight pair of lists has
        mismatched lengths.

    Raises
    ------
    KeyError
        If a category id is missing from ``id2tax`` or a truncated path is
        missing from ``tax2id``.
    """
    cats, wts, tcats, twts = x
    tcat_dicts_given = dict(zip(tcats, twts))
    # defaultdict(int) is kept so the returned mapping has the same type as
    # before; the membership test below means the default 0 is never used as
    # a seed, so weights below zero are not clamped to 0.
    tcats_dict = defaultdict(int)
    for cat, wt in zip(cats, wts):
        tcat = ' > '.join(id2tax[cat].split(' > ')[:TRUNCATE_DEPTH])
        tcat_id = tax2id[tcat]
        if tcat_id not in tcats_dict or wt > tcats_dict[tcat_id]:
            tcats_dict[tcat_id] = wt
    # zip() stops at the shorter input, so an id without a weight (or a weight
    # without an id) would otherwise be dropped silently and the record could
    # pass validation despite being malformed.
    lengths_ok = len(cats) == len(wts) and len(tcats) == len(twts)
    is_valid = lengths_ok and set(tcats_dict.items()) == set(tcat_dicts_given.items())
    return is_valid, (tcats_dict, tcat_dicts_given)

In [9]:
# Spot-check on a single hard-coded record: parse its string form and validate it.
validate_truncation(eval("[[u'6079', u'6047', u'6044'], [0.017222201451659, 0.9199814200401301, 0.033585455268621], [u'6028', u'6026'], [0.017222201451659, 0.9199814200401301]]"))

(True,
 (defaultdict(int, {'6028': 0.017222201451659, '6026': 0.9199814200401301}),
  {'6028': 0.017222201451659, '6026': 0.9199814200401301}))

In [10]:
# Edge case: a record whose four lists are all empty.
validate_truncation(eval('[[], [], [], []]'))

(True, (defaultdict(int, {}), {}))

In [11]:
# Validate every row and collect the failures as (record, details) pairs.
# The payload strings come from a data file, so they are parsed with
# ast.literal_eval, which only accepts Python literals, instead of eval,
# which would execute any expression found in the file.
errors = []
for record in tqdm(df.to_dict('records')):
    is_correct, res = validate_truncation(ast.literal_eval(record['cat_wt_tcat_twt']))
    if not is_correct:
        errors.append((record, res))

100%|██████████| 4986542/4986542 [03:47<00:00, 21888.78it/s]


In [12]:
# Fraction of rows that failed validation.
len(errors) / df.shape[0]

0.0

In [14]:
# Show five random rows; the fixed seed makes the same rows appear on every
# Restart & Run All.
df.sample(n=5, random_state=42).to_dict('records')

[{'query': 'wand licht schlafzimmer',
  'cat_wt_tcat_twt': "[['3915'], [0.22929243743419603], ['3897'], [0.22929243743419603]]"},
 {'query': 'tooth fairy gifts for girls',
  'cat_wt_tcat_twt': "[['5850', '1089', '5904'], [0.11144375056028301, 0.06003725528717001, 0.071317106485366], ['1085', '5901', '5848'], [0.06003725528717001, 0.071317106485366, 0.11144375056028301]]"},
 {'query': 'onn party speaker',
  'cat_wt_tcat_twt': "[['1755', '1754', '1761'], [0.469256103038787, 0.34622895717620805, 0.146192356944084], ['1754'], [0.469256103038787]]"},
 {'query': 'camas bebes',
  'cat_wt_tcat_twt': "[['4273', '4236', '4233'], [0.48028725385665805, 0.06607480347156501, 0.11679942160844801], ['4272', '4231'], [0.48028725385665805, 0.11679942160844801]]"},
 {'query': 'leggings de drainage',
  'cat_wt_tcat_twt': "[['5959', '6084'], [0.28475460410118103, 0.21063022315502103], ['6028', '6142'], [0.21063022315502103, 0.28475460410118103]]"}]