In [1]:
!pip install sentence-transformers -q

In [23]:
# Input / output roots. Kept as plain strings ending in "/" because they are
# concatenated directly with sub-paths and file names elsewhere.
BASE_DATASETS_PATH = '../DATASETS/'
EMBEDDINGS_PATH = './EMBEDDINGS/'

# Off-the-shelf checkpoints that are embedded for every dataset as a baseline.
baseline_st_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "allenai/scibert_scivocab_uncased",
]

# One entry per corpus. Each tuple is:
#   (display name == sub-directory, parquet file prefix, fine-tuned checkpoints)
datasets = [
    {
        "dataset_desc": desc,
        "dataset_path": f"{BASE_DATASETS_PATH}{desc}/",
        "file_prefix": prefix,
        "st_models": list(fine_tuned),
    }
    for desc, prefix, fine_tuned in (
        ("CSAbstruct", "csabstruct",
         ("gubartz/st_minilm_abstruct", "gubartz/st_scibert_abstruct")),
        ("PubMed-RCT", "pubmed_rct",
         ("gubartz/st_minilm_pubmed_rct", "gubartz/st_scibert_pubmed_rct")),
        ("PMC-Sents-FULL", "pmc_sents_full",
         ("gubartz/st_minilm_pmc_sents_full", "gubartz/st_scibert_pmc_sents_full")),
    )
]

In [22]:
import os

# Discard any index left over from an earlier run: the embedding loop opens
# this file in append mode, so stale lines would otherwise accumulate.
embeddings_meta_file = f'{EMBEDDINGS_PATH}embeddings_meta.jsonl'
if os.path.exists(embeddings_meta_file):
    os.remove(embeddings_meta_file)

In [26]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
import pickle
import json
import os

# Sentences per forward pass in SentenceTransformer.encode.
batch_size = 512

# JSONL index of every pickle written below (one JSON object per line).
embeddings_meta_file = f'{EMBEDDINGS_PATH}embeddings_meta.jsonl'

# Nothing else in this notebook creates the output directory; without it the
# first pickle / index write raises FileNotFoundError on a fresh checkout.
os.makedirs(EMBEDDINGS_PATH, exist_ok=True)

# Truncate the index so this cell is idempotent: entries are appended per model
# inside the loop, and re-running the cell would otherwise duplicate every line.
with open(embeddings_meta_file, 'w', encoding="utf-8"):
    pass

pbar_datasets = tqdm(datasets)

for dataset in pbar_datasets:
    pbar_datasets.set_description(dataset['dataset_desc'])

    train_dataset = f"{dataset['dataset_path']}{dataset['file_prefix']}_train.parquet"
    test_dataset = f"{dataset['dataset_path']}{dataset['file_prefix']}_test.parquet"

    df_train = pd.read_parquet(train_dataset)
    df_test = pd.read_parquet(test_dataset)

    # Sort by sentence length for less padding within each encode() batch.
    # Labels are read from the same sorted frames below, so row i of the
    # embeddings still corresponds to label i.
    df_train = df_train.sort_values(by="sentence", key=lambda x: x.str.len())
    df_test = df_test.sort_values(by="sentence", key=lambda x: x.str.len())

    train_sentences = list(df_train['sentence'].values)
    test_sentences = list(df_test['sentence'].values)

    y_train_true = list(df_train['label_id'].values)
    y_test_true = list(df_test['label_id'].values)

    # Dataset-specific fine-tuned checkpoints first, then the shared baselines.
    st_models = dataset['st_models'] + baseline_st_models
    pbar_st_models = tqdm(st_models, leave=False)

    for st_model_id in pbar_st_models:
        pbar_st_models.set_description(st_model_id)

        # "org/name" -> "name"; used in the output file name and the index.
        model_name = st_model_id.split("/")[-1]

        # NOTE: device is hard-coded, so this cell requires a CUDA GPU.
        st_model = SentenceTransformer(st_model_id, device='cuda')

        embeddings_train = st_model.encode(train_sentences,
                                           show_progress_bar=True,
                                           batch_size=batch_size)

        embeddings_test = st_model.encode(test_sentences,
                                          show_progress_bar=True,
                                          batch_size=batch_size)

        # Guard the row alignment that downstream consumers of the pickle rely on.
        assert len(embeddings_train) == len(y_train_true), "train embeddings/labels misaligned"
        assert len(embeddings_test) == len(y_test_true), "test embeddings/labels misaligned"

        # Dump embeddings together with their row-aligned labels.
        file_name = f"{dataset['dataset_desc']}_{model_name}.pickle"

        output_dict = {'embeddings_train': embeddings_train,
                       'embeddings_test': embeddings_test,
                       'y_train_true': y_train_true,
                       'y_test_true': y_test_true}
        with open(f"{EMBEDDINGS_PATH}{file_name}", 'wb') as pfile:
            pickle.dump(output_dict, pfile, protocol=pickle.HIGHEST_PROTOCOL)

        embedding_meta = {'dataset_desc': dataset['dataset_desc'],
                          'model_name': model_name,
                          'file_name': file_name}

        # Appended per model, so the index reflects completed work even if a
        # later model fails.
        with open(embeddings_meta_file, 'a', encoding="utf-8") as outfile:
            outfile.write(json.dumps(embedding_meta))
            outfile.write("\n")

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]