In [None]:
import torch

# Ask PyTorch whether a CUDA device is usable; as the last expression in the
# cell, the resulting bool is displayed as the cell output.
torch.cuda.is_available()

In [None]:
# Evaluation data sources: the JSTS validation file (parsed line by line as
# JSON further down) and the JSICK test file (tab-separated).
jsts_url = "https://raw.githubusercontent.com/yahoojapan/JGLUE/main/datasets/jsts-v1.1/valid-v1.1.json"
jsick_url = "https://github.com/verypluming/JSICK/raw/main/jsick/test.tsv"
# Retrieval settings: at most `miracle_n_hard_negs` hard negatives are taken
# per query, and recall is measured over the `miracle_n_recall` closest passages.
miracle_n_hard_negs = 300
miracle_n_recall = 30
# Strings prepended to the texts before encoding (STS sentences, retrieval
# queries, retrieval passages). An empty string means no prefix.
sts_prefix = ""
retrieve_query_prefix = ""
retrieve_passage_prefix = ""

# Model

In [None]:
from sentence_transformers import SentenceTransformer

# NOTE(review): `model_id` is not assigned in any cell visible in this notebook;
# presumably it is injected from outside (e.g. a parameters cell) — confirm,
# otherwise a fresh Restart & Run All fails here with NameError.
model = SentenceTransformer(model_id)
# Set the model's sequence-length limit to 512
# (presumably tokens, with longer inputs truncated — TODO confirm).
model.max_seq_length = 512

# JSTS

In [None]:
import json
import pandas as pd
from urllib.request import urlopen

# The JSTS file holds one JSON object per line. The response is opened in a
# `with` block so the HTTP connection is released as soon as parsing finishes,
# and blank lines (e.g. a trailing newline) are skipped instead of being
# handed to json.loads, which would raise on them.
with urlopen(jsts_url) as response:
    df = pd.DataFrame([json.loads(line) for line in response if line.strip()])
df.head(1)

In [None]:
# (rows, columns) of the JSTS frame, shown as a quick sanity check.
df.shape

## Encode

In [None]:
# Encode both sentence columns with the optional STS prefix prepended.
# The prefixed Series are converted to plain lists: `encode` is documented for
# a string or a list of strings, and integer lookups on a Series are
# label-based, so passing the Series directly only works while the frame keeps
# its default 0..n-1 index.
sentence1_embs = model.encode((sts_prefix + df["sentence1"]).tolist())
sentence2_embs = model.encode((sts_prefix + df["sentence2"]).tolist())
# Display the two embedding-matrix shapes; they should match.
sentence1_embs.shape, sentence2_embs.shape

## Correlation Score

In [None]:
from scipy.spatial.distance import cosine, euclidean
from scipy.stats import spearmanr

# scipy's `cosine` is a distance, so the similarity of each sentence pair is
# its complement.
df["similarity"] = [
    1 - cosine(emb_a, emb_b)
    for emb_a, emb_b in zip(sentence1_embs, sentence2_embs)
]

# Spearman rank correlation between predicted similarity and the gold label;
# element 0 of the result is the correlation coefficient.
jsts_score = spearmanr(df["similarity"], df["label"])[0]
jsts_score

# JSICK

In [None]:
# Load the JSICK test file (tab-separated). Note that this rebinds `df`, so
# the JSTS frame is no longer reachable under that name from here on.
df = pd.read_csv(jsick_url, sep="\t")
df.head(1)

In [None]:
# (rows, columns) of the JSICK frame, shown as a quick sanity check.
df.shape

## Encode

In [None]:
# Encode both Japanese sentence columns with the optional STS prefix prepended.
# The prefixed Series are converted to plain lists: `encode` is documented for
# a string or a list of strings, and integer lookups on a Series are
# label-based, so passing the Series directly only works while the frame keeps
# its default 0..n-1 index.
sentence1_embs = model.encode((sts_prefix + df["sentence_A_Ja"]).tolist())
sentence2_embs = model.encode((sts_prefix + df["sentence_B_Ja"]).tolist())
# Display the two embedding-matrix shapes; they should match.
sentence1_embs.shape, sentence2_embs.shape

## Correlation Score

In [None]:
from scipy.spatial.distance import cosine
from scipy.stats import spearmanr

# scipy's `cosine` is a distance, so the similarity of each sentence pair is
# its complement.
df["similarity"] = [
    1 - cosine(emb_a, emb_b)
    for emb_a, emb_b in zip(sentence1_embs, sentence2_embs)
]

# Spearman rank correlation between predicted similarity and the Japanese
# relatedness score; element 0 of the result is the correlation coefficient.
jsick_score = spearmanr(df["similarity"], df["relatedness_score_Ja"])[0]
jsick_score

# MIRACL
* Requires a Hugging Face access token

In [None]:
import os
import dotenv

# Load environment variables from the local file "huggingface_access_token".
# override=True lets values from that file replace variables already set in
# the process. A later cell reads HF_ACCESS_TOKEN from os.environ.
# The call's return value is displayed as the cell output.
dotenv.load_dotenv("huggingface_access_token", override=True)

In [None]:
import datasets

# Queries and their positive passages: the "dev" split of the Japanese ("ja")
# MIRACL configuration, authenticated with the token from the environment.
# NOTE(review): recent `datasets` releases replace `use_auth_token` with
# `token` — confirm against the installed version.
ds = datasets.load_dataset(
    path="miracl/miracl",
    name="ja",
    split="dev",
    use_auth_token=os.environ["HF_ACCESS_TOKEN"],
)
ds

In [None]:
# all corpus texts
corpus = datasets.load_dataset("miracl/miracl-corpus", "ja")
corpus

In [None]:
# hard negatives
with open("./miracl_hard_negs_1000.json") as f:
    hn = json.loads(f.read())
len(hn), list(hn.keys())[:5], hn["0"].keys(), hn["0"]["docids"][:2], hn["0"]["indices"][
    :2
]

In [None]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def get_text(corpus_item):
    """Return the passage's title and body text joined by a single space.

    `corpus_item` is any mapping that provides string values under the keys
    "title" and "text".
    """
    return " ".join((corpus_item["title"], corpus_item["text"]))


# docid -> "title text" for every passage in the corpus "train" split, used to
# resolve hard-negative docids to their text.
corpus_dict = {item["docid"]: get_text(item) for item in corpus["train"]}

# Running totals over all queries (micro-averaged recall).
n_total_pos = 0
n_total_tp = 0

for item in ds:
    # Embed the query (a one-element batch, so the result is a 2-D array).
    query_emb = model.encode([retrieve_query_prefix + item["query"]])

    # Candidate pool = positives + up to `miracle_n_hard_negs` hard negatives.
    positive_docids = [pp["docid"] for pp in item["positive_passages"]]
    positive_texts = [get_text(pp) for pp in item["positive_passages"]]

    # Drop hard negatives that are actually positives; a set makes each
    # membership check O(1) instead of scanning the positives list.
    positive_docid_set = set(positive_docids)
    hn_docids = [
        docid
        for docid in hn[item["query_id"]]["docids"][:miracle_n_hard_negs]
        if docid not in positive_docid_set
    ]

    # Positives are placed first, so candidate indices [0, n_pos) are positives.
    target_docids = positive_docids + hn_docids
    target_texts = positive_texts + [corpus_dict[docid] for docid in hn_docids]

    target_embs = model.encode(
        [retrieve_passage_prefix + text for text in target_texts]
    )

    # Indices of the `miracle_n_recall` candidates closest to the query by
    # cosine distance (ascending distance = descending similarity).
    topk_indices = np.argsort(cdist(query_emb, target_embs, metric="cosine"))[0][
        :miracle_n_recall
    ]

    # argsort indices are unique, so counting those below n_pos is the number
    # of positives retrieved in the top-k.
    n_pos = len(positive_docids)
    n_tp = int(np.count_nonzero(topk_indices < n_pos))

    n_total_pos += n_pos
    n_total_tp += n_tp

# Recall@k over all queries; NaN rather than ZeroDivisionError when the split
# contains no positive passages at all.
miracl_recall = n_total_tp / n_total_pos if n_total_pos else float("nan")

n_total_pos, n_total_tp, miracl_recall

# Output

In [None]:
# Summary of this run: the model identifier followed by the JSTS, JSICK and
# MIRACL scores computed above.
model_id, jsts_score, jsick_score, miracl_recall

In [None]:
import json
import os

# Make sure the output directory exists; without this, open() raises
# FileNotFoundError on a fresh checkout where ./scores has not been created.
os.makedirs("./scores", exist_ok=True)

# Persist the scores as one JSON object. The file is named after the model id,
# with "/" replaced by "_" so an "org/name" id stays a single path component.
with open(f'./scores/{model_id.replace("/", "_")}.txt', "w") as f:
    f.write(
        json.dumps(
            {
                "model_id": model_id,
                "jsts": jsts_score,
                "jsick": jsick_score,
                "miracl": miracl_recall,
            }
        )
    )