In [9]:
# load dataset
from dataclasses import dataclass, field
from typing import List
import os

# Tasks to load. Each name is used by load_samples as a sub-directory of
# ./data/ and as the file-name prefix of that task's .train/.test/.valid
# data files.
task_names = ["ATEC", "BQ", "LCQMC", "PAWSX", "STS-B"]


@dataclass
class Samples:
    """Column-wise container for sentence-pair samples.

    The three lists are parallel: ``load_samples`` appends one entry to each
    of them for every line it parses, so index ``i`` across the lists
    describes a single sample. Each list defaults to a fresh empty list.
    """

    # First tab-separated field of each parsed line.
    sentence_a: List[str] = field(default_factory=list)
    # Second tab-separated field of each parsed line.
    sentence_b: List[str] = field(default_factory=list)
    # Third tab-separated field of each parsed line, converted with int().
    labels: List[int] = field(default_factory=list)


def load_samples(task_name, load_limit: int = 0):
    """Load the train/test/valid splits of one task from ``./data/<task_name>/``.

    Each split is read from ``<task_name>.<split>.data``. The file is decoded
    as UTF-8 and is expected to hold one sample per line in the form
    ``sentence_a<TAB>sentence_b<TAB>label``. Blank lines are skipped.

    Args:
        task_name: Name of the task; used both as the directory name under
            ``./data/`` and as the file-name prefix.
        load_limit: When greater than 0, only the first ``load_limit`` lines
            of each file are considered; otherwise the whole file is read.

    Returns:
        A ``(train_samples, test_samples, valid_samples)`` tuple of
        ``Samples`` objects.

    Raises:
        OSError: If a split file cannot be opened.
        ValueError: If a non-blank line does not consist of exactly three
            tab-separated fields, or its label is not an integer.
    """
    task_dir = os.path.join('./data/', task_name)

    def load_data(load_path):
        """Parse one split file into a ``Samples`` instance."""
        samples = Samples()
        # Explicit encoding: the platform default is not UTF-8 everywhere,
        # which would garble or reject non-ASCII sentences.
        with open(load_path, encoding="utf-8") as f:
            # Stream the file so a small load_limit never reads it in full.
            for line_no, line in enumerate(f, start=1):
                if load_limit > 0 and line_no > load_limit:
                    break
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                columns = line.split("\t")
                if len(columns) != 3:
                    raise ValueError(
                        f"{load_path}:{line_no}: expected 3 tab-separated "
                        f"fields, got {len(columns)}"
                    )
                text_a, text_b, raw_label = columns
                # int() ignores the trailing newline still attached to the label.
                label = int(raw_label)
                samples.sentence_a.append(text_a)
                samples.sentence_b.append(text_b)
                samples.labels.append(label)
        return samples

    train_samples = load_data(os.path.join(task_dir, f"{task_name}.train.data"))
    test_samples = load_data(os.path.join(task_dir, f"{task_name}.test.data"))
    valid_samples = load_data(os.path.join(task_dir, f"{task_name}.valid.data"))
    return train_samples, test_samples, valid_samples


# Build the task -> split -> Samples lookup: for every task name, read its
# three split files and file them under "train" / "test" / "valid".
# load_limit=10 ** 7 caps each split file at ten million lines.
task_samples = {}
for task_name in task_names:
    train_samples, test_samples, valid_samples = load_samples(task_name, load_limit=10 ** 7)
    task_samples[task_name] = dict(
        train=train_samples,
        test=test_samples,
        valid=valid_samples,
    )

In [12]:
# load & test model
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

# Models to evaluate; each name is passed to SentenceTransformer(...) as-is.
model_names = ["stsb-xlm-r-multilingual"]

# Split of every task that gets scored. It never changes, so it is set once
# here instead of on every loop iteration.
dtype = "valid"

# Directory handed to the evaluator as output_path. Created up front because
# the evaluator writes its results file there and, depending on the installed
# sentence-transformers release, may not create the directory itself.
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

for model_name in model_names:
    model = SentenceTransformer(model_name)

    for task_name in task_names:
        # Index directly: a task missing from task_samples should raise a
        # KeyError naming the task, not a TypeError from subscripting the
        # None that .get() would return.
        samples = task_samples[task_name][dtype]
        evaluator = EmbeddingSimilarityEvaluator(
            samples.sentence_a,
            samples.sentence_b,
            samples.labels,
            name=f"{model_name}_{task_name}",
        )
        evaluator(model, output_path=output_dir)



In [25]:
# NOTE(review): `filename` is not assigned in any cell shown here; presumably
# it was set in a cell outside this view or left over in the kernel. Confirm
# this cell still works after Restart Kernel -> Run All.
filename

'similarity_evaluation_stsb-xlm-r-multilingual_LCQMC_results'