# Product category prediction

### Links

- [Text Classification with BERT in PyTorch](https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f)
- [Best models](https://huggingface.co/models)

In [3]:
%load_ext autoreload
%autoreload 2

In [159]:
import sys
sys.path.append('../../src')

from matplotlib import pyplot as plt
import seaborn as sns

import numpy  as np
import pandas as pd

import torch
from   torch import nn
from   torch.optim import Adam
    
import logging

import random

import data  as dt
import model as ml
import util  as ut
import pytorch_common.util as pu

## Setup

We configure the default logger to write to the console and to log messages at level INFO and above.

In [160]:
pu.LoggerBuilder().on_console().build()

We use the GPU by default. If that hardware is not available, the fallback is the CPU:

In [161]:
torch.cuda.is_available()

In [162]:
torch.__version__

In [163]:
pu.set_device_name('gpu')

logging.info(pu.get_device())

2022-09-19 12:34:41,852 - INFO - cuda:0
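
The GPU-with-CPU-fallback rule described above can be sketched in plain PyTorch as follows. This is only an illustration, not the implementation behind `pu.set_device_name`; `pick_device` is a hypothetical helper, and the import is guarded so the sketch also runs where PyTorch is not installed.

```python
def pick_device(prefer_gpu: bool = True) -> str:
    """Return 'cuda:0' when a CUDA GPU is preferred and usable, else 'cpu'."""
    if not prefer_gpu:
        return 'cpu'
    try:
        import torch  # imported lazily so the sketch runs without PyTorch too
    except ImportError:
        return 'cpu'
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'

device_name = pick_device()
print(device_name)
```

The returned name can be passed to `torch.device(...)` to move tensors and modules.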


In [164]:
torch.cuda.get_arch_list()

## Helper functions

In [165]:
def set_seed(value):
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)

## Parameters

We fix the random seed for all libraries:

In [166]:
set_seed(42)
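
As a quick sanity check of what fixing the seed buys us, the sketch below reseeds a generator and confirms that the same sequence comes back. It uses only the standard-library `random` module so that it runs anywhere; `draw_sample` is a hypothetical helper, not part of the notebook. The same idea should carry over to `np.random.seed` and `torch.manual_seed`.

```python
import random

def draw_sample(seed, n=5):
    """Reseed the stdlib generator and draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first  = draw_sample(42)
second = draw_sample(42)
other  = draw_sample(43)

# Same seed, same sequence; a different seed gives a different sequence.
print(first == second)
print(first == other)
```

Seeding alone may not make GPU training bit-for-bit reproducible, since some CUDA operations are non-deterministic.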

We define the pretrained BERT model to use. It acts as one more layer/module of our model.

In [167]:
BERT_MODEL ='bert-base-cased'
# BERT_MODEL ='distilbert-base-cased'
# BERT_MODEL ='distilbert-base-uncased'
# BERT_MODEL ='albert-base-v2'

We define the paths for the dataset and the model weights:

In [168]:
DATASET_PATH       = '../../datasets/fashion-outfits'
IMAGES_PATH        = '../../datasets/fashion-outfits/images'
WEIGHTS_PATH       = '../../weights'
MODEL_WEIGHTS_PATH = f'{WEIGHTS_PATH}/model_weights.h5'

In [169]:
TRAIN = False

In [170]:
!mkdir -p {WEIGHTS_PATH}
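
`!mkdir -p` relies on a Unix shell being available. A portable equivalent can be sketched with `pathlib`; here a temporary directory stands in for the project root, so the names `base` and `weights_path` are illustrative only.

```python
import tempfile
from pathlib import Path

base = Path(tempfile.mkdtemp())   # stand-in for the project root
weights_path = base / 'weights'

# parents=True creates missing parent directories and exist_ok=True makes the
# call idempotent, which together mirror `mkdir -p`.
weights_path.mkdir(parents=True, exist_ok=True)
weights_path.mkdir(parents=True, exist_ok=True)   # second call is a no-op

print(weights_path.is_dir())
```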

## Load the dataset

In [171]:
all_set   = pd.read_csv(f'{DATASET_PATH}/all.csv')
train_set = pd.read_csv(f'{DATASET_PATH}/train.csv')
val_set   = pd.read_csv(f'{DATASET_PATH}/val.csv')
test_set  = pd.read_csv(f'{DATASET_PATH}/test.csv')

Note: for some reason, some values in `description` appear not to be plain strings. To be reviewed.

In [172]:
all_set.head()

| Unnamed: 0 | id | family | category | sub_category | description | highlights | brand | gender | materials | branch | tokens_count | image_uri |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17073270 | clothing | knitwear | cardi-coats | x The Beatles intarsia-knit coat | [black/white, organic cotton, mix print, intar... | stella mccartney | unisex | ['Cotton'] | `clothing~knitwear~cardi-coats` | 5 | 17/07/32/70/17073270.jpg |
| 1 | 17674562 | clothing | knitwear | cardi-coats | cashmere-blend long belted cardigan | [plum purple, cashmere blend, wrap design, sle... | extreme cashmere | unisex | ['Spandex/Elastane' 'Nylon' 'Cashmere'] | `clothing~knitwear~cardi-coats` | 4 | 17/67/45/62/17674562.jpg |
| 2 | 17678603 | clothing | knitwear | cardi-coats | cashmere-blend long belted cardigan | [blue, cashmere blend, wrap design, sleeveless... | extreme cashmere | unisex | ['Spandex/Elastane' 'Cashmere' 'Nylon'] | `clothing~knitwear~cardi-coats` | 4 | 17/67/86/03/17678603.jpg |
| 3 | 17179699 | clothing | knitwear | cardi-coats | long cashmere cardigan | [light pink, stretch-cashmere blend, fine knit... | extreme cashmere | unisex | ['Nylon' 'Spandex/Elastane' 'Cashmere'] | `clothing~knitwear~cardi-coats` | 3 | 17/17/96/99/17179699.jpg |
| 4 | 15907453 | clothing | sweaters & knitwear | cardigans | tie-dye print cashmere cardigan | [multicolour, cashmere, tie-dye print, knitted... | the elder statesman | men | ['Cashmere'] | `clothing~sweaters & knitwear~cardigans` | 4 | 15/90/74/53/15907453.jpg |


In [173]:
train_set['description'] = train_set['description'].apply(str)
val_set  ['description'] = val_set  ['description'].apply(str)
test_set ['description'] = test_set ['description'].apply(str)
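
One caveat with `apply(str)` is that a missing value (which pandas reads as a float `NaN`) becomes the literal text `'nan'`, and the tokenizer would then treat it as a real description. A minimal sketch of a stricter conversion follows. `clean_description` is a hypothetical helper, not part of the notebook, and mapping missing values to an empty string is only one possible policy.

```python
import math

def clean_description(value):
    """Convert a raw cell to text, mapping None/NaN to an empty string."""
    if value is None:
        return ''
    if isinstance(value, float) and math.isnan(value):
        return ''
    return str(value).strip()

raw = ['long cashmere cardigan', float('nan'), 2009, None]
print([str(v) for v in raw])                # plain str(): NaN turns into 'nan'
print([clean_description(v) for v in raw])  # missing values become ''
```

It could be applied in the same way as above, for example `train_set['description'].apply(clean_description)`.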

Maximum length of the input sequence:

In [174]:
max_length = 2 + train_set['tokens_count'].max()
max_length
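
The `+ 2` presumably reserves room for the two special tokens that BERT tokenizers add around each sequence: `[CLS]` at the start and `[SEP]` at the end. The sketch below illustrates the resulting layout with plain lists. It assumes `tokens_count` counts the tokens of a description without those special tokens, which the notebook does not state; `pad_tokens` is a hypothetical helper.

```python
def pad_tokens(tokens, max_length):
    """Wrap tokens in [CLS]/[SEP] and right-pad with [PAD] up to max_length."""
    sequence = ['[CLS]'] + list(tokens) + ['[SEP]']
    if len(sequence) > max_length:
        # Truncate the body but keep the closing [SEP].
        sequence = sequence[:max_length - 1] + ['[SEP]']
    return sequence + ['[PAD]'] * (max_length - len(sequence))

descriptions = [
    ['long', 'cashmere', 'cardigan'],
    ['tie-dye', 'print', 'cashmere', 'cardigan'],
]
max_length = 2 + max(len(tokens) for tokens in descriptions)

for tokens in descriptions:
    print(pad_tokens(tokens, max_length))
```

If `tokens_count` counts words rather than WordPiece tokens, a word split into several sub-tokens could push a sequence past `max_length` and cause truncation.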

We define the tokenizer and the datasets for training, validation and test:

In [175]:
tokenizer     = ml.Tokenizer(BERT_MODEL, padding = 'max_length', max_length = max_length)

In [None]:
train_dataset = dt.BertDataset(train_set, feature_col = 'description', target_col = 'branch_seq', tokenizer = tokenizer)
val_dataset   = dt.BertDataset(val_set,   feature_col = 'description', target_col = 'branch_seq', tokenizer = tokenizer)
test_dataset  = dt.BertDataset(test_set,  feature_col = 'description', target_col = 'branch_seq', tokenizer = tokenizer)

## Preparing the model

Number of classes to predict:

In [None]:
len(train_set['branch_seq'].unique())

In [None]:
n_classes = train_set['branch_seq'].max() + 1
n_classes
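
The two cells above can disagree: `max() + 1` equals the number of distinct labels only when the labels are the contiguous integers `0..n-1`. `nn.CrossEntropyLoss` expects class-index targets in the range `[0, C-1]`, so sizing the output layer with `max() + 1` is the safe choice either way. A minimal sketch of the check, assuming `branch_seq` holds non-negative integer codes (`label_stats` is a hypothetical helper):

```python
def label_stats(labels):
    """Return (n_outputs, n_distinct, is_contiguous) for integer class labels."""
    distinct = set(labels)
    n_outputs = max(distinct) + 1   # output size needed to index every label
    n_distinct = len(distinct)      # labels actually present in the data
    return n_outputs, n_distinct, n_outputs == n_distinct

print(label_stats([0, 1, 2, 2, 1]))  # contiguous labels
print(label_stats([0, 1, 4]))        # gaps: some outputs never appear as targets
```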

In [179]:
classifier = ml.BertClassifier(output_dim = n_classes, model = BERT_MODEL, dropout = 0)

## Training

In [180]:
model = ml.BertModel(classifier, batch_size = 70, criterion = nn.CrossEntropyLoss())

In [181]:
LR     = 0.00001
EPOCHS = 5

if TRAIN:
    model.fit(
        train_dataset, 
        val_dataset = val_dataset,
        optimizer   = Adam(classifier.parameters(), lr = LR),
        epochs      = EPOCHS
    )

In [182]:
if TRAIN:
    classifier.save(MODEL_WEIGHTS_PATH)

## Evaluation

In [183]:
classifier.load(MODEL_WEIGHTS_PATH)

In [184]:
summary = model.validate(test_dataset)

In [185]:
summary.show()

Accuracy: 75.24%, Loss: 0.012566
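
For reference, accuracy is the fraction of test samples whose predicted class equals the true class. A minimal sketch with plain lists follows. It is not necessarily how `summary.show()` computes the figure internally, and the sample labels are made up for illustration.

```python
def accuracy(targets, predictions):
    """Fraction of positions where the prediction matches the target."""
    if len(targets) != len(predictions):
        raise ValueError('targets and predictions must have the same length')
    if not targets:
        raise ValueError('cannot compute accuracy of an empty set')
    hits = sum(t == p for t, p in zip(targets, predictions))
    return hits / len(targets)

targets     = [3, 1, 0, 2, 2, 1, 0, 3]   # illustrative values only
predictions = [3, 1, 0, 2, 1, 1, 0, 0]

print(f'Accuracy: {accuracy(targets, predictions):.2%}')
```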


In [None]:
# summary.show_sample_metrics(0)
# summary.show_sample_metrics(1)
# summary.show_metrics()

In [194]:
report_generator = ml.FailReportGenerator(tokenizer, all_set, test_set, test_dataset, summary.targets, summary.predictions, IMAGES_PATH)

report = report_generator()
report.to_csv(f'{DATASET_PATH}/bet-model-fail-report.csv', index=False)

Total Fails: 0.25%


| Unnamed: 0 | id | description | true_class | true_image | pred_class | pred_image |
|---|---|---|---|---|---|---|
| 0 | 17783943 | Portofino lace - up sneakers | `shoes~trainers~n/d` | | `shoes~trainers~low-tops` | |
| 1 | 17116415 | balloon - sleeve knitted jumper | `clothing~knitwear~jumpers` | | `clothing~sweaters & knitwear~jumpers` | |
| 2 | 16509543 | Emoji - print track shorts | `clothing~shorts~short shorts` | | `clothing~shorts~track & running shorts` | |
| 3 | 18206712 | Est. 2009 logo T - shirt | `clothing~t-shirts & vests~t-shirts` | | `pre-owned~tops~n/d` | |
| 4 | 16942906 | box - pleat wide - leg trousers | `clothing~trousers~wide-leg trousers` | | `clothing~trousers~high-waisted trousers` | |
| 5 | 16509543 | side logo - print shorts | `clothing~shorts~short shorts` | | `clothing~shorts~bermuda shorts` | |
| 6 | 16509543 | all - over star - print shorts | `clothing~shorts~short shorts` | | `clothing~shorts~bermuda shorts` | |
| 7 | 16653345 | reversible lightweight windbreaker | `clothing~jackets~lightweight jackets` | | `clothing~jackets~sport jackets & windbreakers` | |
| 8 | 16724249 | crystal - embellished maxi dress | `clothing~dresses~evening dresses` | | `clothing~dresses~cocktail & party dresses` | |
| 9 | 16425361 | plissé - effect open - front jacket | `clothing~jackets~fitted jackets` | | `clothing~jackets~lightweight jackets` | |
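
The report above lists test samples whose predicted class differs from the true class. Since the reported accuracy is 75.24%, the failure rate should be about 24.76%, so the `0.25` printed as `Total Fails` is presumably a fraction rather than a percentage. The sketch below shows one way such a report could be assembled. It is not the implementation of `ml.FailReportGenerator`; `build_fail_report`, the `classes` lookup and all sample values are hypothetical.

```python
def build_fail_report(ids, descriptions, targets, predictions, classes):
    """Return (rows, fail_rate) for the samples whose prediction is wrong."""
    rows = []
    for id_, text, target, pred in zip(ids, descriptions, targets, predictions):
        if target != pred:
            rows.append({
                'id': id_,
                'description': text,
                'true_class': classes[target],
                'pred_class': classes[pred],
            })
    fail_rate = len(rows) / len(targets) if targets else 0.0
    return rows, fail_rate

classes = ['shoes~trainers~n/d', 'shoes~trainers~low-tops']  # illustrative subset
rows, fail_rate = build_fail_report(
    ids          = [1, 2, 3, 4],          # made-up ids
    descriptions = ['a', 'b', 'c', 'd'],  # made-up descriptions
    targets      = [0, 1, 0, 1],
    predictions  = [0, 1, 1, 1],
    classes      = classes,
)
print(f'Total fails: {fail_rate:.2%}')
print(rows)
```

Formatting the rate with `:.2%` multiplies the fraction by 100 before appending the percent sign, which avoids the fraction-versus-percentage ambiguity.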
