In [30]:
import json

import numpy as np
import pandas as pd

from retrain_bert import settings
from retrain_bert.preprocessor import get_labels_conf, load_labels

In [4]:
preds = pd.read_csv(settings.DATA_DIR / 'preds' / 'BERT_Missing_Category__predictions.csv')

In [23]:
preds.head()

Unnamed: 0,OcrValue,OcrValueId,level_0,level_1,level_2,level_3,level_4
0,1 KAJAL LAPIE 0308,60897720,"[0.12473943829536438, 0.5767314434051514, 0.11...","[0.17285484075546265, 0.16256852447986603, 0.0...","[0.10994251072406769, 0.02743549272418022, 0.0...","[0.5283539295196533, 0.07703693956136703, 0.05...","[0.17845898866653442, 0.13360297679901123, 0.0..."
1,GREFUTUBO BBQ125 GR,3279063,"[0.9956769347190857, 0.003556867828592658, 0.0...","[0.9858474135398865, 0.000280908978311345, 7.8...","[5.924932793277549e-06, 0.004577833227813244, ...","[0.3125729560852051, 0.5827136635780334, 0.013...","[0.30687618255615234, 0.06241997331380844, 0.0..."
2,CEXAC SUAVE 1L 4X,64656851,"[0.8442399501800537, 0.1554601490497589, 0.000...","[0.056602928787469864, 0.02390812523663044, 0....","[0.056338243186473846, 0.7022128105163574, 0.0...","[0.19311988353729248, 0.08673679083585739, 0.0...","[0.37947818636894226, 0.1729581505060196, 0.10..."
3,BASE % IVA IMPORTE TOTAL,5179620,"[0.9884185194969177, 0.0067498586140573025, 0....","[0.9979379773139954, 0.00019949785200878978, 0...","[0.02187095768749714, 0.002626574831083417, 0....","[0.963013231754303, 0.018383320420980453, 0.01...","[0.7434091567993164, 0.15277265012264252, 0.03..."
4,MANZANA GOLDEN A,583912,"[0.00801509153097868, 0.9791175127029419, 0.00...","[0.03212112933397293, 0.0021935675758868456, 0...","[0.5109317302703857, 0.22153709828853607, 0.04...","[0.0943668931722641, 0.003984348848462105, 0.0...","[0.5845780968666077, 0.028550414368510246, 0.0..."


In [26]:
def to_cats(row):
    return row.apply(lambda x: np.argmax(json.loads(x)))

cats = preds.drop(columns=["OcrValue", "OcrValueId"]).apply(to_cats, axis=0)
cats

Unnamed: 0,level_0,level_1,level_2,level_3,level_4
0,1,4,12,0,19
1,0,0,18,1,19
2,0,3,1,0,0
3,0,0,3,0,0
4,1,5,0,3,0
...,...,...,...,...,...
4341607,0,3,1,9,2
4341608,3,3,5,3,0
4341609,2,0,5,3,19
4341610,0,0,3,0,0


In [31]:
labels = load_labels()
labels_conf = get_labels_conf(labels)

for col, level_conf in zip(cats.columns, labels_conf):
    cats.loc[cats[col] == level_conf['num_classes'] - 1, col] = pd.NA
cats.isna().mean()

level_0    0.000039
level_1    0.000918
level_2    0.000911
level_3    0.012034
level_4    0.521351
dtype: float64

In [27]:
def to_probas(row):
    return row.apply(lambda x: np.max(json.loads(x)))

probas = preds.drop(columns=["OcrValue", "OcrValueId"]).apply(to_probas, axis=0)
probas

Unnamed: 0,level_0,level_1,level_2,level_3,level_4
0,0.576731,0.213502,0.145851,0.528354,0.352492
1,0.995677,0.985847,0.907150,0.582714,0.512136
2,0.844240,0.760885,0.702213,0.193120,0.379478
3,0.988419,0.997938,0.955073,0.963013,0.743409
4,0.979118,0.860438,0.510932,0.683915,0.584578
...,...,...,...,...,...
4341607,0.999600,0.995405,0.973583,0.565775,0.720419
4341608,0.999611,0.670222,0.465100,0.998489,0.965914
4341609,0.997169,0.990682,0.982648,0.775127,0.590813
4341610,0.998877,0.999959,0.589314,0.999938,0.996974


In [32]:
useful_threshold = 0.5

useful = cats.copy()
useful[~(probas > useful_threshold)] = pd.NA

useful.dropna(how="all", axis=0, inplace=True)

useful.isna().mean()

level_0    0.041563
level_1    0.220014
level_2    0.404629
level_3    0.325691
level_4    0.627277
dtype: float64

In [33]:
na_mask = np.zeros(len(useful), dtype=bool)
for col in useful.columns:
    useful.loc[na_mask, col] = np.nan
    na_mask = useful[col].isna().values

useful.isna().mean()

level_0    0.041563
level_1    0.243724
level_2    0.453287
level_3    0.552801
level_4    0.733788
dtype: float64

In [35]:
def to_category_code(row):
    row = row.dropna()
    row += 1
    return "".join(f"{int(x):02d}" for x in row)

codes = useful.apply(to_category_code, axis=1)
codes

0                  02
1            01011902
2              010402
3          0101040101
4          0206010401
              ...    
4341607    0104021003
4341608          0404
4341609      03010604
4341610    0101040101
4341611            02
Length: 4098415, dtype: object

In [37]:
codes.name = "CategoryCode"

In [38]:
category_codes = pd.merge(preds["OcrValueId"], codes, left_index=True, right_index=True, how="inner")

In [None]:
category_codes.to_csv(settings.DATA_DIR / 'preds' / 'BERT_Missing_Category__category_codes.csv', index=False)