In [1]:
import os

import numpy as np
import networkx as nx
import torch
import pandas as pd
import seaborn as sns

# Work from the project root so the relative CSV / output paths below resolve.
# The root is resolved once and cached in PROJECT_ROOT: a bare os.chdir('../')
# would climb one more directory every time this cell is re-run.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('../')
os.chdir(PROJECT_ROOT)

sns.set()

In [5]:
def flatten(x):
    """Concatenate the items of every inner iterable of ``x`` into one list.

    Removes exactly one level of nesting, e.g. ``[[1, 2], [3]] -> [1, 2, 3]``.
    """
    flat = []
    for inner in x:
        flat.extend(inner)
    return flat

In [2]:
# Image records: each row pairs an image_key with its category_path and the
# category_id used as its label below. Only these three columns are loaded.
data = pd.read_csv(
    'full_data_v2.csv',
    usecols=['image_key', 'category_path', 'category_id'],
)
data.head()

Unnamed: 0,image_key,category_path,category_id
0,bqtrg37rl2o1m1t69om0,"[10012, 10165, 11239, 11738]",11738
1,bq3fgcath9dctj4eavig,"[10017, 10837, 13117]",13117
2,bq3fgfupskc4erbm4hi0,"[10017, 10837, 13117]",13117
3,bq3fgdufrsa0ancfuk7g,"[10017, 10837, 13117]",13117
4,bq3fgbepskc4erbm4hhg,"[10017, 10837, 13117]",13117


In [10]:
# Product category table: one dot-separated category path per product.
# Columns are renamed to the names used by the cells below.
data_cats = pd.read_csv('product_categories_v3.csv')
data_cats = data_cats.rename(columns={'id': 'product_id', 'path': 'categories'})
data_cats

Unnamed: 0,product_id,categories
0,1,1.10010.10717.12927
1,2,1.10014.10052.11405.12553
2,3,1.10004.10025.10200.12236
3,4,1.10014.10116.10597.14242
4,5,1.10020.10044.12690
...,...,...
435647,442784,1.10010.10541.14173
435648,442785,1.10013.10155.11133.13142
435649,442786,1.10007.10162.14187
435650,442787,1.10018.10219.12731


In [11]:
def parse_category_path(path):
    """Split an 'a.b.c' category path into ``[a, b, c]`` as ints.

    A missing path (read by pandas as float NaN) becomes an empty list.
    """
    if isinstance(path, float):
        return []
    return [int(part) for part in path.split('.')]


# Every step builds a new frame with `assign` instead of writing columns into
# the boolean-filtered slice, which raised SettingWithCopyWarning.
# NOTE: this cell consumes the raw string paths, so re-running it requires
# re-running the loading cell first.
parsed_cats = data_cats.assign(categories=data_cats.categories.apply(parse_category_path))
parsed_cats = parsed_cats.assign(num_classes=parsed_cats.categories.apply(len))

# Drop the first path element (the `1` every path shown above starts with) and
# keep only products that still have at least one category after dropping it.
data_cats = (
    parsed_cats.loc[parsed_cats.num_classes > 1]
    .assign(
        categories=lambda df: df.categories.apply(lambda path: path[1:]),
        num_classes=lambda df: df.num_classes - 1,
    )
)
data_cats

Unnamed: 0,product_id,categories,num_classes
0,1,"[10010, 10717, 12927]",3
1,2,"[10014, 10052, 11405, 12553]",4
2,3,"[10004, 10025, 10200, 12236]",4
3,4,"[10014, 10116, 10597, 14242]",4
4,5,"[10020, 10044, 12690]",3
...,...,...,...
435647,442784,"[10010, 10541, 14173]",3
435648,442785,"[10013, 10155, 11133, 13142]",4
435649,442786,"[10007, 10162, 14187]",3
435650,442787,"[10018, 10219, 12731]",3


In [13]:
# Distinct category paths (sorted, each as a list) observed in data_cats.
unique_labels = pd.Series([list(path) for path in sorted(set(map(tuple, data_cats.categories)))])

# Outliers = category ids used by images that have no path in the category
# tree. They must be computed against the tree ids BEFORE the union: checking
# data ids against a set that already contains them is always empty.
# Outliers are still kept as labels (having no children, they become leaves).
tree_labels = set(flatten(unique_labels))
data_labels = set(data.category_id)
outliers = data_labels - tree_labels
labels = tree_labels | data_labels
print('Num outliers:', len(outliers))

# Any id that appears before the last position of some path has children.
branches = set(flatten(unique_labels.apply(lambda x: x[:-1])))
leaves = labels.difference(branches)
num_leaves = len(leaves)
print('Num leaves:', num_leaves)

# Class ids: leaves get 0..num_leaves-1 and branches follow. Both groups are
# enumerated in sorted order, so the mapping does not depend on set iteration
# order (the mapping itself is not saved; only graph.pth is).
id_leaf = dict(enumerate(sorted(leaves)))
leaf_id = {v: k for k, v in id_leaf.items()}

id_branch = dict(enumerate(sorted(branches), start=num_leaves))
branch_id = {v: k for k, v in id_branch.items()}

id_class = {**id_leaf, **id_branch}
class_id = {**leaf_id, **branch_id}
data_cats['labels'] = data_cats.categories.apply(lambda x: [class_id[i] for i in x])
data_cats['last_label'] = data_cats.labels.apply(lambda x: x[-1])
data['category_label'] = data.category_id.map(class_id)

# Directed graph over class ids: for each path, an edge from every category to
# every category that comes after it on that path (ancestor -> descendant).
leaf_graph = nx.DiGraph()
leaf_graph.add_nodes_from(sorted(class_id[i] for i in labels))
for path in unique_labels:
    path_ids = [class_id[c] for c in path]
    for pos, ancestor in enumerate(path_ids):
        for descendant in path_ids[pos + 1:]:
            leaf_graph.add_edge(ancestor, descendant)

# Object to pass as a class_graph in HierarchicalCrossEntropyLoss.
# NOTE: this pickles a networkx object; torch>=2.6 defaults torch.load to
# weights_only=True, so it must be loaded with weights_only=False.
torch.save(leaf_graph, 'graph.pth')

Num leaves: 2537


In [14]:
# Preview data_cats including the `labels` (class id of each path element) and
# `last_label` (class id of the path's final element) columns.
data_cats

Unnamed: 0,product_id,categories,num_classes,labels,last_label
0,1,"[10010, 10717, 12927]",3,"[2582, 2865, 1428]",1428
1,2,"[10014, 10052, 11405, 12553]",4,"[2586, 2619, 3014, 1136]",1136
2,3,"[10004, 10025, 10200, 12236]",4,"[2576, 2597, 2736, 872]",872
3,4,"[10014, 10116, 10597, 14242]",4,"[2586, 2670, 2839, 2500]",2500
4,5,"[10020, 10044, 12690]",3,"[2592, 2612, 1242]",1242
...,...,...,...,...,...
435647,442784,"[10010, 10541, 14173]",3,"[2582, 2827, 2448]",2448
435648,442785,"[10013, 10155, 11133, 13142]",4,"[2585, 2703, 2957, 1597]",1597
435649,442786,"[10007, 10162, 14187]",3,"[2579, 2710, 2460]",2460
435650,442787,"[10018, 10219, 12731]",3,"[2590, 2751, 1276]",1276


In [15]:
# Preview data including `category_label` (category_id mapped through class_id).
data

Unnamed: 0,image_key,category_path,category_id,category_label
0,bqtrg37rl2o1m1t69om0,"[10012, 10165, 11239, 11738]",11738,476
1,bq3fgcath9dctj4eavig,"[10017, 10837, 13117]",13117,1577
2,bq3fgfupskc4erbm4hi0,"[10017, 10837, 13117]",13117,1577
3,bq3fgdufrsa0ancfuk7g,"[10017, 10837, 13117]",13117,1577
4,bq3fgbepskc4erbm4hhg,"[10017, 10837, 13117]",13117,1577
...,...,...,...,...
1399499,bu8ucnqth9d4psjf05sg,"[10008, 10204, 10680, 11924]",11924,620
1399500,bu8ucojbkrvvp5btssj0,"[10008, 10204, 10680, 11924]",11924,620
1399501,c005dh9aof0ojr14abvg,"[10008, 10204, 10680, 11924]",11924,620
1399502,bu8ucoefrsad4ed9o1e0,"[10008, 10204, 10680, 11924]",11924,620
