In [146]:
from transformers import (
    AutoConfig, BertConfig, BertTokenizer, 
    BertForSequenceClassification)
    
import sys   
sys.path.insert(1, "/home/sarahwooders_gmail_com/transformers/")
from src.transformers.modeling_bert import BertForRetrieval
import torch
import numpy as np

from tqdm import tqdm 

In [127]:
ckpt_dir = "/home/sarahwooders_gmail_com/transformers/checkpoints/checkpoint-28000"

In [128]:
config = AutoConfig.from_pretrained(ckpt_dir)
tokenizer = BertTokenizer.from_pretrained(
    ckpt_dir,
    do_lower_case=True,
)

model = BertForRetrieval.from_pretrained(
    ckpt_dir,
    from_tf=False,
    #config=config,
    #cache_dir='/home/sarahwooders_gmail_com/transformers/checkpoints'
)

In [129]:
use_gpu = False
if use_gpu:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(device)

cpu


In [130]:
model = model.to(device)

In [131]:
def embedding_distance(text_a, text_b):

    max_seq_length=128
    test_input_a = tokenizer.batch_encode_plus(
        [(text_a, None)],
        max_length=max_seq_length,
        pad_to_max_length=True,
    )

    test_input_b = tokenizer.batch_encode_plus(
        [(text_b, None)],
        max_length=max_seq_length,
        pad_to_max_length=True,
    )

    model_inp_a = dict(test_input_a)
    model_inp_b = dict(test_input_b)
    model_inp_a['labels'] = [0]
    model_inp_b['labels'] = [0]
    model_inp_a = {k: torch.tensor(v).to(device) for k, v in model_inp_a.items()}
    model_inp_b = {k: torch.tensor(v).to(device) for k, v in model_inp_b.items()}

    model_out = model(
        input_ids_a=model_inp_a['input_ids'], 
        attention_mask_a=model_inp_a['attention_mask'],
        token_type_ids_a=model_inp_a['token_type_ids'],
        input_ids_b=model_inp_b['input_ids'], 
        attention_mask_b=model_inp_b['attention_mask'],
        token_type_ids_b=model_inp_b['token_type_ids']
    )

    title_embedding = model_out[1].detach().cpu().numpy()[0]
    category_embedding = model_out[2].detach().cpu().numpy()[0]

    t = title_embedding/np.linalg.norm(title_embedding)
    c = category_embedding/np.linalg.norm(category_embedding)
    return np.dot(t, c)

# Mini test

In [154]:
test = [
"Juice Del Valle Frut Pet Peach 1 L Juice Del Valle Frut Pet Peach 1 L",
"Elseve Total 5 Extra Professional Serum 15 mL Elseve Total 5 Extra Professional Serum 15 mL",
"Corbalan Canned Mixed Skewer 500 g Corbalan Canned Mixed Skewer 500 g",
"Pedigree Dog Food Puppies Medium and Large Breeds 3 kg With Vitamins and Minerals that help you stay strong and healthy. With Natural Fibers for optimal digestion. With Calcium for growth. With protein for strong muscles. Textured grains that help reduce the formation of tartar and keep your teeth and gums healthy.",
"Kendall 7/8 BC Sock Without Tip M 1322 Kendall 7/8 BC Sock Without Tip M 1322",
"Johnsons Roma Soap 80 g Johnsons Roma Soap 80 g",
"Sensodyne Extra Fresh Rinse 250 ml Sensodyne Extra Fresh Rinse 250 ml",
"Kelloggs Granola And Honey Biscuit 120 Gr Kelloggs Granola And Honey Biscuit 120 Gr",
"Multivitaminico Vita Force AZ 120 Cáps - Voxx Multivitaminico Vita Force AZ 120 Cáps - Voxx",
"Bioext Queravit Megadose 15 mL Bioext Queravit Megadose 15 mL",
"Moo Yogurt Type Skyr Coco Without Lactose Moo Yogurt Type Skyr Coco Without Lactose",
"Beer 600 ml Choose your beer.",
"Bienn 4mg 10 Tablets Bienn 4Mg 10 Tablets",
"Dental Brush B Indicator Nº40 2 U Dental Brush Indicator Oral-B, White Teeth And Healthy Gums. It has bristles with rounded tips and with Indicator Plus system, which indicates the moment of brush change, its handle is rubberized with comfort grip, which provides more safety and comfort during brushing. Enjoy the Promotion Take 2, Pay 1. Product with colors and / or assorted prints. Shipping According to Stock Availability. (Key Words: Toothpaste, Toothpaste, Oral Health, Oral Hygiene, Anticaries, Toothbrush, Tooth Hiding)",
"Nestlé Flan Strawberry Dessert Set 200 g Nestlé Flan Strawberry Dessert Set 200 g (Key Words: Dairy / Morning)",
"Wickbold Scooby Doo Integral Tube 300 g The Wickbold Scooby Doo wholegrain biscuit was made with wholemeal flour especially for those who follow a slimming diet and cannot do without the daily bisnaguinha. Healthy, soft and tasty, ideal for any time of the day.",
"Orthodontic Success Interdental Kit x 4 Units - PLU: 33455",
"Dark Chocolate With Tiramisu Kopenhagen Filling Dark Chocolate With Tiramisu Kopenhagen Filling",
"Naturafrig Bovine Palette Bovine Naturafrig Bulk Palette",
"Palmolive Naturals Soap Secret Seductive 90 g Palmolive Naturals Soap Secret Seductive 90 g. (Key Words: Body Care, Beauty Care, Beauty, Hygiene, Deodorant, Anti Perspirant, Antiperspirant, Soap, Soap, Bar Soap)",
]

categories = [
    "food",
    "health and beauty", 
    "cleaning supplies", 
    "canned foods", 
    "pet supplies", 
    "fashion apparel"
]

In [156]:
for t in test:
    print(t)
    for c in categories:
        print("\t", embedding_distance(t, c), c)
        
    print()

SyntaxError: unexpected EOF while parsing (<ipython-input-156-d585f0080adc>, line 6)

# Rappi data test

In [157]:
import csv

labels = []
samples = []

# took out subcategory or else evaluation takes forever

with open("rappi_sub_categories.csv", newline='') as csvfile:
    reader = csv.DictReader(csvfile, delimiter=',')
    for row in reader:
        #labels.append(row['category'] + ' ' + row['subcategory'])
        labels.append(row['category'])
labels = list(set(labels))

with open("rappi_human.csv", newline='') as csvfile:
    reader = csv.DictReader(csvfile, delimiter=',')
    for row in reader:
        text = row['translated_title'] + ' ' + row['translated_description']
        label = row['category'] #+ ' ' + row['subcategory']
        if label not in labels:
            continue
            
        samples.append((text, label))
        
    

In [152]:
print(len(samples))
print(len(labels))

1125
182


In [153]:
for sample in samples: 
    print(sample)
    predictions = [(embedding_distance(sample[0], l), l) for l in tqdm(labels)]
    predictions = sorted(predictions)
    print(predictions[:5])
    break

  0%|          | 0/182 [00:00<?, ?it/s]

('Japanese Mafer Premium Roasted Peanuts with Lemon 180 g Japanese Mafer Premium Roasted Peanuts with Lemon 180 g - Prices including VAT', 'snacks and confectionery')


100%|██████████| 182/182 [00:42<00:00,  4.28it/s]

[(0.75331056, 'office supplies'), (0.7533928, 'medical equipment and supplies'), (0.75372857, 'school supplies'), (0.75402176, 'electronic supplies'), (0.75913936, 'cards')]



