In [None]:
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.cross_encoder import CrossEncoder
import gzip
import json
import random
from torch.utils.data import Dataset, IterableDataset
import logging
from torch.utils.data import DataLoader
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
import math
from sentence_transformers import LoggingHandler, util
import torch.nn

#### Data from https://huggingface.co/datasets/sentence-transformers/embedding-training-data

In [2]:
pairs=["gooaq_pairs.jsonl.gz"]

triplets=["quora_duplicates_triplets.jsonl.gz", "AllNLI.jsonl.gz", "specter_train_triples.jsonl.gz", 
       "msmarco-triplets.jsonl.gz"]

path="/Users/g.salazar.2/Downloads/trec_data"


In [None]:
pairs_files = [path+"/"+x for x in pairs]
triplets_files = [path+"/"+x for x in triplets]        

In [None]:

def get_triplet_example(raw_example):
    if isinstance(raw_example, dict):
        query = raw_example['query']
        pos = raw_example['pos']
        neg = raw_example['neg']
        examples=[]
        pos_example = InputExample(texts=[query, pos], label=3)
        neg_examples = [InputExample(texts=[query, x], label=0) for x in neg]
        neg_examples.append(pos_example)
        return neg_examples
    else:
        query = raw_example[0]
        pos = raw_example[1]
        neg = raw_example[2]
        pos_example = InputExample(texts=[query, pos], label=3)
        neg_example = InputExample(texts=[query, neg], label=0)
        return [pos_example, neg_example]
    
def get_pair_example(raw_example):
    return InputExample(texts=[raw_example[0], raw_example[1]], label=3)



In [None]:
def load_pair_dataset(filepath):
    examples=[]
    with gzip.open(filepath, 'rt') as fIn:
            for line in fIn:
                example = get_pair_example(json.loads(line))
                examples.append(example)
    return examples

def load_triplet_dataset(filepath):
    examples=[]
    with gzip.open(filepath, 'rt') as fIn:
            for line in fIn:
                example = get_triplet_example(json.loads(line))
                examples.extend(example)
    return examples

In [None]:
full_set=[]
for file in pairs_files:
    print("Processing "+file)
    full_set.extend(load_pair_dataset(file))
    
for file in triplets_files:
    print("Processing "+file)
    full_set.extend(load_triplet_dataset(file))


### Data from Amazon ESCI

In [None]:
import pandas as pd
df_examples = pd.read_parquet(path+"/amazon_shopping_queries_dataset_examples.parquet")
df_products = pd.read_parquet(path+"/amazon_shopping_queries_dataset_products.parquet")
df_sources = pd.read_csv(path+"/amazon_shopping_queries_dataset_sources.csv")

In [5]:
df_examples_products = pd.merge(
    df_examples,
    df_products,
    how='left',
    left_on=['product_locale','product_id'],
    right_on=['product_locale', 'product_id']
)

In [9]:
df_task_1 = df_examples_products[df_examples_products["large_version"] == 1]
df_task_1_train = df_task_1[df_task_1["split"] == "train"]
df_task_1_test = df_task_1[df_task_1["split"] == "test"]

In [10]:
df_task_1_train["query", "product_title", "product_description"]

Unnamed: 0,example_id,query,query_id,product_id,product_locale,esci_label,small_version,large_version,split,product_title,product_description,product_bullet_point,product_brand,product_color
0,0,revent 80 cfm,0,B000MOO21W,us,I,0,1,train,Panasonic FV-20VQ3 WhisperCeiling 190 CFM Ceil...,,WhisperCeiling fans feature a totally enclosed...,Panasonic,White
1,1,revent 80 cfm,0,B07X3Y6B1V,us,E,0,1,train,Homewerks 7141-80 Bathroom Fan Integrated LED ...,,OUTSTANDING PERFORMANCE: This Homewerk's bath ...,Homewerks,80 CFM
2,2,revent 80 cfm,0,B07WDM7MQQ,us,E,0,1,train,Homewerks 7140-80 Bathroom Fan Ceiling Mount E...,,OUTSTANDING PERFORMANCE: This Homewerk's bath ...,Homewerks,White
3,3,revent 80 cfm,0,B07RH6Z8KW,us,E,0,1,train,Delta Electronics RAD80L BreezRadiance 80 CFM ...,This pre-owned or refurbished product has been...,Quiet operation at 1.5 sones\nBuilt-in thermos...,DELTA ELECTRONICS (AMERICAS) LTD.,White
4,4,revent 80 cfm,0,B07QJ7WYFQ,us,E,0,1,train,Panasonic FV-08VRE2 Ventilation Fan with Reces...,,The design solution for Fan/light combinations...,Panasonic,White
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2621267,2621267,ﾚﾃﾞｨｰｽ水着,130650,B07GR94N75,jp,E,0,1,train,レディース 水着 オーバーウェア ビキニ セパレーツ 無地 二点セット 海水浴 水泳 温泉 ...,,★メイン素材：90%Polyester、10%spandex\n★人気の女性用セクシーワイヤ...,kayiyasu,ブラック
2621268,2621268,ﾚﾃﾞｨｰｽ水着,130650,B0769J1VB8,jp,E,0,1,train,AOIF 競泳水着 高級 レディース フィットネス水着 めくれ防止 2091,AOIF 競泳水着 高級 レディース フィットネス水着 めくれ防止 2091,カラフルな水着から落ち着いた色の水着まで選べる楽しさ。抜群の耐久性とフィット感・長く使える安...,AOIF LLMY,ブラック
2621269,2621269,ﾚﾃﾞｨｰｽ水着,130650,B06XD8J9MG,jp,E,0,1,train,Tahoe レディース ノースリーブの水着 めくれ防止 フィットネス水着 キャップ付 フロン...,,体を優しく包みこむようにフィットし、長時間でも心地よくリラックスして着用できます。\n丈夫で...,Tahoe,グレー
2621270,2621270,ﾚﾃﾞｨｰｽ水着,130650,B01KZB82WQ,jp,E,0,1,train,LE MODE de toi(ル モード デ トア) レディース フィットネス水着 セパレー...,【▼ サイズ ※平置きの測定です。生地は伸縮性があるため多少前後します。 】<BR> 7S：...,フラットシーマ使用\n紫外線からお肌をしっかりガードUPF50＋\nキャップ付 パット付 U...,LE MODE de toi(ル モード デ トア),【めくれ防止】ブラック＆ピンク


In [None]:
from sklearn.model_selection import train_test_split
(train_set, test_set) = train_test_split(full_set, test_size=0.33, random_state=42, shuffle=False)

In [None]:
len(train_set), len(test_set)

In [None]:

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

train_batch_size = 4
num_epochs = 1
num_labels = 4
max_length = 512
evaluation_steps = 100
lr = 7e-6

train_dataloader = DataLoader(train_set, shuffle=True, batch_size=train_batch_size)
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up

default_activation_function = torch.nn.Identity()

model = CrossEncoder('microsoft/deberta-v3-large', num_labels=num_labels, 
                     tokenizer_args={'pad_token': '[PAD]'}, 
                     default_activation_function=default_activation_function)

evaluator = CERerankingEvaluator(test_set, name='train-eval')


logger.info("Warmup-steps: {}".format(warmup_steps))
loss_fct=torch.nn.L1Loss()

model.config.pad_token_id = model.tokenizer.pad_token_id

model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=evaluation_steps,
          warmup_steps=warmup_steps,
          optimizer_params={'lr': lr},
          output_path="model_saved")

