In [1]:
import sys
import os
import torch
import itertools
import numpy as np
import pandas as pd
from collections import OrderedDict

sys.path.append("../../")
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder, PoolingStrategy
from utils_nlp.eval.senteval2 import SentEvalConfig, ExperimentRunner

In [2]:
# device config
NUM_GPUS = 0

# model config
LANGUAGE = Language.ENGLISH
TO_LOWER = True
MAX_SEQ_LENGTH = 128
# Layer index and pooling strategy are not fixed here: their candidate values
# are listed in `experiment_parameters` further down.
# NOTE(review): `prepare` and `batcher` lowercase sentences unconditionally,
# independent of TO_LOWER.

# path config
CACHE_DIR = "./temp"
PATH_TO_SENTEVAL = "../../../SentEval"

In [3]:
# `exist_ok=True` already tolerates an existing directory, so a separate
# existence check is redundant (and would race with concurrent creation).
os.makedirs(CACHE_DIR, exist_ok=True)

In [4]:
# Build the BERT sentence encoder from the config constants above; `se` is
# passed as `model` to SentEvalConfig below.
se = BERTSentenceEncoder(
    max_len=MAX_SEQ_LENGTH,
    to_lower=TO_LOWER,
    cache_dir=CACHE_DIR,
    num_gpus=NUM_GPUS,
    language=LANGUAGE,
)

In [5]:
def prepare(params, samples):
    """Encode all sentences of a task once and index them for `batcher`.

    Args:
        params: Mutable parameter dict. Reads ``params["model"]`` (an object
            exposing ``encode``) and ``params["batch_size"]``; writes
            ``params["embeddings"]`` and ``params["sentence2idx"]``.
        samples: Iterable of tokenized sentences (each a sequence of str).

    Returns:
        None. Results are stored on ``params``.
    """
    # Sentences are keyed by joining tokens with spaces and lowercasing;
    # `batcher` must rebuild its lookup keys the same way.
    sentences = [" ".join(s).lower() for s in samples]
    params["embeddings"] = params["model"].encode(
        sentences,
        batch_size=params["batch_size"],
        as_numpy=False,
    )
    # Use the directly imported `OrderedDict`: the `collections` module itself
    # is not imported in this notebook. For duplicate sentences the last
    # position wins.
    params["sentence2idx"] = OrderedDict(
        (sentence, idx) for idx, sentence in enumerate(sentences)
    )

def batcher(params, batch):
    """Look up precomputed embeddings for one batch of sentences.

    Args:
        params: Parameter dict. ``params["sentence2idx"]`` and
            ``params["embeddings"]`` are set by ``prepare``; the embeddings
            are indexed with ``.loc`` on ``text_index``, ``layer_index`` and
            ``values`` columns. ``params["layer_index"]`` selects which
            ``layer_index`` rows to use.
        batch: Sequence of tokenized sentences (each a sequence of str).

    Returns:
        numpy.ndarray built from one embedding per sentence, in batch order.

    Raises:
        KeyError: if a sentence is missing from ``params["sentence2idx"]``.
    """
    # Key sentences exactly as `prepare` does so the lookups match.
    sentences = [" ".join(s).lower() for s in batch]
    sentence_indices = [params["sentence2idx"][s] for s in sentences]

    df = params["embeddings"]
    # The layer filter does not depend on the sentence, so apply it once
    # instead of rebuilding the same mask for every sentence in the batch.
    layer_df = df.loc[df["layer_index"] == params["layer_index"]]
    embeddings = [
        np.squeeze(
            layer_df.loc[layer_df["text_index"] == i, "values"].values
        ).tolist()
        for i in sentence_indices
    ]
    return np.array(embeddings)

In [6]:
# Bundle the SentEval checkout path, the encoder and the prepare/batcher
# callbacks; `transfer_tasks` lists only STSBenchmark.
senteval_config = SentEvalConfig(
    transfer_tasks=["STSBenchmark"],
    batcher_func=batcher,
    prepare_func=prepare,
    model=se,
    path_to_senteval=PATH_TO_SENTEVAL,
)

In [7]:
# Candidate values per parameter, handed to ExperimentRunner below.
experiment_parameters = dict(
    layer_index=[-1, -2],
    pooling_strategy=[PoolingStrategy.MEAN, PoolingStrategy.MAX],
)

In [8]:
# Combine the SentEval configuration with the parameter grid in a runner.
er = ExperimentRunner(
    experiment_parameters=experiment_parameters,
    senteval_config=senteval_config,
)