In [1]:
import os

import pandas as pd
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

In [2]:
# Restore the previously saved tokenizer and classification model from disk.
tokenizer = DistilBertTokenizer.from_pretrained(
    "categorizationModel_1/tokenizer_distilbert"
)
model = DistilBertForSequenceClassification.from_pretrained(
    "categorizationModel_1/model_distilbert"
)

In [3]:
# Put the model in evaluation mode before inference (e.g. turns off dropout).
# Left as the cell's last expression so the notebook displays the architecture.
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [4]:
# Load the dataset to categorize; later cells read its 'Title' column.
prod = pd.read_csv("data/Amazon_data.csv")

In [5]:
def predict_category(title, unknown_label=6):
    """Predict the numeric category label for a single product title.

    Parameters
    ----------
    title : str, or a missing value (None / NaN)
        Product title to classify. Non-string values are converted with
        ``str`` so the tokenizer does not reject them.
    unknown_label : int, default 6
        Label returned when there is nothing meaningful to classify.

    Returns
    -------
    int
        ``unknown_label`` when the title is missing, blank, or the literal
        "Unknown"; otherwise the index of the largest logit produced by the
        module-level ``model`` for the tokenized title.
    """
    # Missing values carry no text to classify.
    if pd.isnull(title):
        return unknown_label

    # Titles read from a CSV can come back as numbers; the tokenizer only
    # accepts text, so normalise to a stripped string first.
    text = title if isinstance(title, str) else str(title)
    text = text.strip()

    # Blank titles and the "Unknown" placeholder get the sentinel label
    # instead of an arbitrary model prediction.
    if not text or text == "Unknown":
        return unknown_label

    # A single sequence needs no padding; truncation keeps over-long titles
    # within the model's maximum input length.
    encoded_input = tokenizer(text, truncation=True, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**encoded_input)
    return torch.argmax(output.logits, dim=1).item()

In [6]:
# Lookup from the numeric label returned by the classifier to its category name.
categories = {
    0: 'Baby',
    1: 'Beauty',
    2: 'Food',
    3: 'Health and personal care',
    4: 'Pet supplies',
    5: 'Toys and games',
    6: "Unknown",
}

In [7]:
# Run the classifier on every value of the Title column and keep the
# predicted numeric label in a new Num_Label column.
prod['Num_Label'] = prod['Title'].map(predict_category)

In [8]:
# Translate each numeric label into its category name via the categories dict.
prod['Category'] = prod['Num_Label'].map(categories)

In [9]:
# Spot-check 20 random rows; the fixed seed makes the displayed sample
# reproducible when the notebook is re-run.
prod.sample(20, random_state=42)

Unnamed: 0,Title,Text,Num_Label,Category
116448,Hot Wheels Shark Park,"What can I say. We love Hot Wheels cars, and t...",5,Toys and games
140769,Alpha Hydrox Foaming Face Wash -- 6 fl oz,I had high hopes for this product because of a...,1,Beauty
64471,Slim Jim Snack Sticks,I LOVE SLIM JIMS. THEY TASTE GREAT AND THEY AR...,2,Food
25892,Body Back Buddy,I purchased this product specifically to targe...,3,Health and personal care
122231,IGIA AT956 Epil Sport for Men Hair Removal System,"If you don't have a sensitive skin, I guess th...",3,Health and personal care
49663,GLOW For Women By J. LO Eau De Toilette Spray,I get several compliments a day when wearing t...,1,Beauty
40775,"Home Health Almond Glow Lotion Unscented,8 ounces",I actually started using this as lubricant at ...,3,Health and personal care
38454,Fisher-Price Sensory Selections Bouncer,our 6 month likes this item but get it jumping...,0,Baby
14390,Teenage Mutant Ninja Turtles - Turtle Playset,I hate this thing!! It NEVER stays together an...,5,Toys and games
26728,Graco Bumper Jumper,"This jumper has great, sturdy construction, is...",0,Baby


In [10]:
# Sanity check of the results: summary statistics grouped by predicted category.
# NOTE(review): presumably each category should show a single Num_Label value
# (std 0) and the counts give the class distribution — confirm.
prod.groupby('Category').describe()

Unnamed: 0_level_0,Num_Label,Num_Label,Num_Label,Num_Label,Num_Label,Num_Label,Num_Label,Num_Label
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max
Category,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Baby,18494.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Beauty,23485.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0
Food,13791.0,2.0,0.0,2.0,2.0,2.0,2.0,2.0
Health and personal care,38550.0,3.0,0.0,3.0,3.0,3.0,3.0,3.0
Pet supplies,18480.0,4.0,0.0,4.0,4.0,4.0,4.0,4.0
Toys and games,37155.0,5.0,0.0,5.0,5.0,5.0,5.0,5.0
Unknown,45.0,6.0,0.0,6.0,6.0,6.0,6.0,6.0


In [13]:
# Save the dataset enriched with the predicted categories.
# Create the output folder first: to_csv does not create missing directories,
# and failing here would throw away the predictions computed above.
os.makedirs('result', exist_ok=True)
prod.to_csv('result/AmazonData_Categories_DistrilBERT.csv', index=False)