In [None]:
import gzip
import json
import logging
import math
import os
import random

import torch.nn
from sentence_transformers import LoggingHandler, util
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers.cross_encoder.evaluation import CESoftmaxAccuracyEvaluator
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, IterableDataset

### Data from https://huggingface.co/datasets/sentence-transformers/embedding-training-data

In [None]:
# Files that are loaded as (text, text) pairs.
pairs = ["gooaq_pairs.jsonl.gz"]

# Files that are loaded as (query, positive, negative) triplets.
triplets = [
    "quora_duplicates_triplets.jsonl.gz",
    "AllNLI.jsonl.gz",
    "specter_train_triples.jsonl.gz",
    "msmarco-triplets.jsonl.gz",
]

# Root data directory. It can be overridden with the TREC_DATA_DIR environment
# variable so the notebook is not tied to a single machine; when the variable
# is unset the original location is used.
path = os.environ.get("TREC_DATA_DIR", "/Users/g.salazar.2/Downloads/trec_data")

# Sub-directories of `path`. They keep their leading slash because they are
# appended to `path` by plain string concatenation.
st_path = "/sentence_transformers_embedding_data"
esci_path = "/esci"
trec_path = "/trec"


In [None]:
# Full location of every embedding-training-data file, one list per format.
st_data_dir = path + st_path
st_pairs_files = [st_data_dir + "/" + file_name for file_name in pairs]
st_triplets_files = [st_data_dir + "/" + file_name for file_name in triplets]

In [None]:

def _as_text_list(value):
    """Return ``value`` as a list of texts, wrapping a lone string in a list."""
    if isinstance(value, str):
        return [value]
    return list(value)


def get_triplet_example(raw_example):
    """Turn one raw triplet record into labelled (query, text) examples.

    Two record layouts are accepted:

    * a dict with the keys ``'query'``, ``'pos'`` and ``'neg'``;
    * a sequence ``[query, positive, negative]``.

    In the dict layout ``'pos'`` and ``'neg'`` may each be a single string or
    a collection of strings; every text becomes its own example.
    NOTE(review): some files of the upstream collection appear to store
    ``'pos'`` as a list - confirm against the data. Wrapping such a list as a
    single text would produce an example whose second "text" is not a string.

    Args:
        raw_example: one decoded JSON record in either layout.

    Returns:
        A list of ``InputExample``. Positives carry label 3 and negatives
        label 0. Dict records yield the negatives first, then the positives;
        sequence records yield the positive first, then the negative.
    """
    if isinstance(raw_example, dict):
        query = raw_example['query']
        positives = _as_text_list(raw_example['pos'])
        # A bare string must be wrapped, otherwise iterating it would emit
        # one "negative" per character.
        negatives = _as_text_list(raw_example['neg'])
        examples = [InputExample(texts=[query, text], label=0) for text in negatives]
        examples.extend(InputExample(texts=[query, text], label=3) for text in positives)
        return examples

    query, pos, neg = raw_example[0], raw_example[1], raw_example[2]
    return [
        InputExample(texts=[query, pos], label=3),
        InputExample(texts=[query, neg], label=0),
    ]
    
def get_pair_example(raw_example):
    """Wrap the first two items of ``raw_example`` in an InputExample labelled 3."""
    first_text, second_text = raw_example[0], raw_example[1]
    return InputExample(texts=[first_text, second_text], label=3)



In [None]:
def load_pair_dataset(filepath):
    """Read a gzipped JSON-lines file of text pairs into a list of examples.

    Args:
        filepath: path of a gzip-compressed file holding one JSON record per
            line.

    Returns:
        A list with one item per non-blank line, each produced by passing the
        decoded record to ``get_pair_example``.
    """
    # Decode explicitly as UTF-8: without `encoding`, text mode falls back to
    # the platform's locale encoding, which is not UTF-8 on every system.
    with gzip.open(filepath, 'rt', encoding='utf-8') as fIn:
        # Blank lines (e.g. a trailing newline) are skipped instead of being
        # handed to json.loads, which would raise on them.
        return [get_pair_example(json.loads(line)) for line in fIn if line.strip()]

def load_triplet_dataset(filepath):
    """Read a gzipped JSON-lines file of triplets into a flat list of examples.

    Args:
        filepath: path of a gzip-compressed file holding one JSON record per
            line.

    Returns:
        A single flat list: each non-blank line is decoded and passed to
        ``get_triplet_example``, and the examples it returns are appended in
        file order.
    """
    examples = []
    # Decode explicitly as UTF-8: without `encoding`, text mode falls back to
    # the platform's locale encoding, which is not UTF-8 on every system.
    with gzip.open(filepath, 'rt', encoding='utf-8') as fIn:
        for line in fIn:
            # Skip blank lines rather than letting json.loads raise on them.
            if not line.strip():
                continue
            examples.extend(get_triplet_example(json.loads(line)))
    return examples

In [None]:
# Pool every example into one list: pair files first, then triplet files.
full_set = []
for loader, files in ((load_pair_dataset, st_pairs_files),
                      (load_triplet_dataset, st_triplets_files)):
    for file in files:
        print("Processing "+file)
        full_set.extend(loader(file))


### Dataset size

In [None]:
# Number of examples collected so far from the pair and triplet files.
len(full_set)

### Data from Amazon ESCI

In [None]:
import pandas as pd

# The three ESCI files sit in the same directory under the data root.
esci_dir = path + esci_path
df_examples = pd.read_parquet(esci_dir + "/shopping_queries_dataset_examples.parquet")
df_products = pd.read_parquet(esci_dir + "/shopping_queries_dataset_products.parquet")
df_sources = pd.read_csv(esci_dir + "/shopping_queries_dataset_sources.csv")

In [None]:
# Attach the product columns to every example row; examples with no matching
# product are kept (left join) on the (locale, product id) key.
product_key = ['product_locale', 'product_id']
df_examples_products = df_examples.merge(
    df_products,
    how='left',
    left_on=product_key,
    right_on=product_key,
)

In [None]:
# Keep the rows flagged large_version == 1, then separate them by `split`.
is_large_version = df_examples_products["large_version"] == 1
df_task_1 = df_examples_products[is_large_version]
df_task_1_train = df_task_1.loc[df_task_1["split"] == "train"]
df_task_1_test = df_task_1.loc[df_task_1["split"] == "test"]

In [None]:
# Show the "train" rows of the large-version subset as the cell output.
df_task_1_train

In [None]:
# Drop the "jp" locale and keep only the columns used to build training pairs.
is_not_jp_locale = df_task_1_train["product_locale"] != "jp"
esci_western_products = df_task_1_train.loc[is_not_jp_locale, ["query", "product_title", "esci_label"]]

In [None]:
# Queries to exclude even though their rows are not in the "jp" locale.
non_western_queries = ['자전거트레일러', '골프공', '가마솥', '茶叶', '肽', '睡衣 女', '眼镜框', 'земфира', 'кроссовки', 'مبخرة', 'محفظه رجاليه']
is_excluded_query = esci_western_products['query'].isin(non_western_queries)
clean_esci_phase1 = esci_western_products[~is_excluded_query]

In [None]:
# Numeric grade assigned to each ESCI letter.
esci_labels = {
    "E": 3,
    "S": 2,
    "C": 1,
    "I": 0,
}
# Swap the letters in the `esci_label` column for their numeric grade.
esci_numeric_labels = clean_esci_phase1.replace(to_replace={'esci_label': esci_labels})

esci_numeric_labels

In [None]:
def get_esci_pair_example(row):
    """Build an InputExample from a row's query, product title and ESCI label."""
    query_text = row['query']
    title_text = row['product_title']
    return InputExample(texts=[query_text, title_text], label=row['esci_label'])
    
# One InputExample per ESCI row, appended to the combined example pool.
train_esci_data = esci_numeric_labels.apply(get_esci_pair_example, axis=1)

full_set.extend(train_esci_data)

### Dataset size after adding the ESCI examples

In [None]:
# Number of examples in the pool once the ESCI rows have been appended.
len(full_set)

In [None]:

# Route log records through sentence-transformers' LoggingHandler.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

# Hyper-parameters.
train_batch_size = 4
num_epochs = 1
num_labels = 4      # the labels built in this notebook are the integers 0..3
max_length = 512    # token limit handed to the CrossEncoder below
evaluation_steps = 100
lr = 7e-6

# Split inside this cell so it runs on a fresh kernel. Another cell further
# down performs the same split, but it has not executed yet when this cell is
# reached in top-to-bottom order, which left train_set/test_set undefined.
# shuffle=False keeps full_set's order: the last 33% becomes the test set.
train_set, test_set = train_test_split(full_set, test_size=0.33, random_state=42, shuffle=False)

train_dataloader = DataLoader(train_set, shuffle=True, batch_size=train_batch_size)
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up

default_activation_function = torch.nn.Identity()

# max_length was defined above but never passed on; hand it to the model so
# over-long inputs are truncated at the intended limit.
model = CrossEncoder('microsoft/deberta-v3-large', num_labels=num_labels,
                     max_length=max_length,
                     tokenizer_args={'pad_token': '[PAD]'},
                     default_activation_function=default_activation_function)

# test_set is a list of InputExample and the model emits num_labels scores per
# pair. CERerankingEvaluator expects query/positive/negative dicts and a single
# score per pair, so it cannot consume either; the softmax-accuracy evaluator
# is built for labelled examples scored by a multi-class model.
# NOTE(review): the evaluator runs over the whole test_set every
# `evaluation_steps` steps - consider sub-sampling if that is too slow.
evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(test_set, name='train-eval')


logger.info("Warmup-steps: {}".format(warmup_steps))
# The previous L1Loss object was never handed to fit(), so training already
# used fit()'s default for num_labels > 1 (cross-entropy). Make that explicit
# and actually pass it.
loss_fct = torch.nn.CrossEntropyLoss()

model.config.pad_token_id = model.tokenizer.pad_token_id

model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          loss_fct=loss_fct,
          epochs=num_epochs,
          evaluation_steps=evaluation_steps,
          warmup_steps=warmup_steps,
          optimizer_params={'lr': lr},
          output_path="model_saved")



In [None]:
# Hold out 33% of full_set as test_set; the remainder becomes train_set.
# NOTE(review): with shuffle=False the split follows full_set's order, so the
# test set is the tail of the list (the sources appended last) and
# random_state has no effect - confirm this is intended.
# NOTE(review): this cell sits after the training cell, which also reads
# train_set/test_set - check the execution order on a fresh kernel.
from sklearn.model_selection import train_test_split
(train_set, test_set) = train_test_split(full_set, test_size=0.33, random_state=42, shuffle=False)