In [1]:
import torch

# Confirm that PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [2]:
# JSTS validation split (JGLUE v1.1); downloaded over HTTP in the JSTS section.
jsts_url = "https://raw.githubusercontent.com/yahoojapan/JGLUE/main/datasets/jsts-v1.1/valid-v1.1.json"
# JSICK test split (tab-separated); read with pandas in the JSICK section.
jsick_url = "https://github.com/verypluming/JSICK/raw/main/jsick/test.tsv"
# Per query, how many entries of the hard-negative list are kept as candidates.
miracle_n_hard_negs = 300
# Cut-off k: a positive counts as retrieved when it ranks in the top k candidates.
miracle_n_recall = 30

In [3]:
# Parameters
# Model name or path handed to `from_pretrained` for both tokenizer and model.
# NOTE(review): the recorded output of the final cells shows a local checkpoint
# path instead of this value, so it was presumably overridden at run time — confirm.
model_id = "oshizo/japanese-e5-mistral-1.9b"

# Instruction texts passed to `get_detailed_instruct`: the first for the
# JSTS/JSICK sentence pairs, the second for MIRACL queries.
# The trailing space is part of each string and ends up in the prompt as-is.
sts_task_description = "Retrieve semantically similar text. "
search_task_description = (
    "Given a question, retrieve Wikipedia passages that answer the question. "
)

# Model

In [4]:
import torch
import torch.nn.functional as F

from torch import Tensor


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Select, for every sequence in the batch, the hidden state of its last token.

    Args:
        last_hidden_states: indexed as ``[batch, position, ...]``.
        attention_mask: indexed as ``[batch, position]``; non-padding
            positions are 1, padding positions are 0.

    Returns:
        One hidden-state vector per batch row.
    """
    n_rows = attention_mask.shape[0]

    # When the final position is unmasked in every row, the batch is treated
    # as left-padded: the last real token always sits in the last column.
    if attention_mask[:, -1].sum() == n_rows:
        return last_hidden_states[:, -1]

    # Otherwise padding is on the right: the last real token of a row is at
    # index (number of unmasked positions - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(
        last_hidden_states.shape[0], device=last_hidden_states.device
    )
    return last_hidden_states[row_index, last_positions]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the two-line prompt: an ``Instruct:`` line, then a ``Query:`` line.

    Both arguments are inserted verbatim; nothing is stripped or escaped.
    """
    prompt = "Instruct: {task_description}\nQuery: {query}".format(
        task_description=task_description,
        query=query,
    )
    return prompt

In [5]:
from transformers import AutoTokenizer, AutoModel


# Tokenizer and model are both loaded from `model_id`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(
    model_id,
    device_map="auto",  # device placement is left to the loader
    # NOTE(review): presumably needs the flash-attn package and a supported GPU — confirm.
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,  # half-precision weights
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Adapted from the usage section at:
# https://huggingface.co/intfloat/e5-mistral-7b-instruct#usage
@torch.no_grad()
def get_embedding(text):
    """Embed a single text and return its L2-normalised vector.

    The text is tokenised (truncated to 512 tokens), an EOS token is appended,
    the model is run without gradient tracking, and the hidden state of the
    last token is taken as the embedding.

    Returns:
        A tensor with one row (the batch holds exactly one text).
    """
    encoded = tokenizer(
        [text],
        max_length=512,
        return_attention_mask=False,
        padding=False,
        truncation=True,
    )

    # EOS is appended by hand, after truncation, so it is always the final token.
    eos_id = tokenizer.eos_token_id
    encoded["input_ids"] = [ids + [eos_id] for ids in encoded["input_ids"]]

    # `pad` builds the attention mask and converts everything to tensors.
    padded = tokenizer.pad(
        encoded, padding=True, return_attention_mask=True, return_tensors="pt"
    )

    hidden_states = model(**padded).last_hidden_state
    pooled = last_token_pool(hidden_states, padded["attention_mask"])

    # Unit-length rows.
    return F.normalize(pooled, p=2, dim=1)

# JSTS

In [7]:
import json
import pandas as pd
from urllib.request import urlopen

# Download the JSTS validation split; the file holds one JSON object per line.
# The response is opened as a context manager so the HTTP connection is closed
# as soon as all lines have been parsed (the bare `urlopen(...).readlines()`
# call never closed it).
with urlopen(jsts_url) as response:
    df = pd.DataFrame([json.loads(line) for line in response])
df.head(1)

Unnamed: 0,sentence_pair_id,yjcaptions_id,sentence1,sentence2,label
0,0,100312_421853-104611-31624,レンガの建物の前を、乳母車を押した女性が歩いています。,厩舎で馬と女性とが寄り添っています。,0.0


In [8]:
# (rows, columns) of the JSTS frame.
df.shape

(1457, 5)

## Encode

In [9]:
# It may be more correct to leave the task description off the document side,
# but the evaluation score was slightly better with it attached to both sides.
query_texts = []
doc_texts = []
for sent_a, sent_b in zip(df["sentence1"], df["sentence2"]):
    query_texts.append(get_detailed_instruct(sts_task_description, sent_a))
    doc_texts.append(get_detailed_instruct(sts_task_description, sent_b))

# Peek at the first two prompts of each side.
query_texts[:2], doc_texts[:2]

(['Instruct: Retrieve semantically similar text. \nQuery: レンガの建物の前を、乳母車を押した女性が歩いています。',
  'Instruct: Retrieve semantically similar text. \nQuery: 山の上に顔の白い牛が2頭います。'],
 ['Instruct: Retrieve semantically similar text. \nQuery: 厩舎で馬と女性とが寄り添っています。',
  'Instruct: Retrieve semantically similar text. \nQuery: 曇り空の山肌で、牛が２匹草を食んでいます。'])

In [10]:
from tqdm.auto import tqdm

# Encode each text on its own (one forward pass per text): all queries first,
# then all docs, so rows i and i + len(query_texts) belong to the same pair.
all_embs = torch.cat(list(map(get_embedding, tqdm(query_texts + doc_texts))))
all_embs.shape

  0%|          | 0/2914 [00:00<?, ?it/s]

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([2914, 4096])

In [11]:
# Shape of the stacked embeddings: one row per encoded text.
all_embs.shape

torch.Size([2914, 4096])

## Correlation Score

In [12]:
from scipy.spatial.distance import cosine, euclidean
from scipy.stats import spearmanr

# The first half of `all_embs` holds the sentence1 embeddings, the second half
# the sentence2 embeddings, in the same row order as `df`.
n_pairs = len(query_texts)
query_embs = all_embs[:n_pairs]
doc_embs = all_embs[n_pairs:]

# scipy's `cosine` is a distance, hence 1 - distance for similarity.
df["similarity"] = [
    1 - cosine(query_vec, doc_vec) for query_vec, doc_vec in zip(query_embs, doc_embs)
]

# Spearman rank correlation between predicted similarity and the gold label.
jsts_score = spearmanr(df["similarity"], df["label"])[0]
jsts_score

0.825669979656227

# JSICK

In [16]:
# Load the JSICK test split (tab-separated). This rebinds `df`, so the JSTS
# frame from the previous section is no longer reachable under that name.
df = pd.read_csv(jsick_url, sep="\t")
df.head(1)

Unnamed: 0,pair_ID,data,sentence_A_En,sentence_B_En,entailment_label_En,relatedness_score_En,corr_entailment_labelAB_En,corr_entailment_labelBA_En,sentence_A_Ja,sentence_B_Ja,entailment_label_Ja,relatedness_score_Ja,image_ID,original_caption,semtag_short,semtag_long
0,6,test,There is no boy playing outdoors and there is ...,A group of kids is playing in a yard and an ol...,neutral,3.3,,,戸外で遊んでいる男の子は一人もおらず、微笑んでいる男性は一人もいない,子供たちのグループが庭で遊んでいて、後ろの方には年を取った男性が立っている,contradiction,2.3,3155657768_b83a7831e5.jpg,"The children are playing outdoors , while a ma...",Negation#Numerical,"Numerical;人;名詞,接尾,助数詞,*#Negation;ない;助動詞,*,*,*#..."


In [17]:
# (rows, columns) of the JSICK frame.
df.shape

(4927, 16)

## Encode

In [18]:
# Same prompt format as for JSTS: the STS instruction is attached to both the
# A-side and the B-side Japanese sentence.
query_texts = []
doc_texts = []
for sent_a, sent_b in zip(df["sentence_A_Ja"], df["sentence_B_Ja"]):
    query_texts.append(get_detailed_instruct(sts_task_description, sent_a))
    doc_texts.append(get_detailed_instruct(sts_task_description, sent_b))

# Peek at the first two prompts of each side.
query_texts[:2], doc_texts[:2]

(['Instruct: Retrieve semantically similar text. \nQuery: 戸外で遊んでいる男の子は一人もおらず、微笑んでいる男性は一人もいない',
  'Instruct: Retrieve semantically similar text. \nQuery: 庭にいる男の子たちのグループが遊んでいて、男性が後ろの方に立っている'],
 ['Instruct: Retrieve semantically similar text. \nQuery: 子供たちのグループが庭で遊んでいて、後ろの方には年を取った男性が立っている',
  'Instruct: Retrieve semantically similar text. \nQuery: 幼い男の子たちが戸外で遊んでいて、その男性が近くで微笑んでいる'])

In [19]:
from tqdm.auto import tqdm

# Encode each text on its own (one forward pass per text): all A-side prompts
# first, then all B-side prompts, so rows i and i + len(query_texts) pair up.
all_embs = torch.cat(list(map(get_embedding, tqdm(query_texts + doc_texts))))
all_embs.shape

  0%|          | 0/9854 [00:00<?, ?it/s]

torch.Size([9854, 4096])

## Correlation Score

In [20]:
from scipy.spatial.distance import cosine
from scipy.stats import spearmanr

# The first half of `all_embs` holds the sentence A embeddings, the second half
# the sentence B embeddings, in the same row order as `df`.
n_pairs = len(query_texts)
query_embs = all_embs[:n_pairs]
doc_embs = all_embs[n_pairs:]

# scipy's `cosine` is a distance, hence 1 - distance for similarity.
df["similarity"] = [
    1 - cosine(query_vec, doc_vec) for query_vec, doc_vec in zip(query_embs, doc_embs)
]

# Spearman rank correlation against the Japanese relatedness score.
jsick_score = spearmanr(df["similarity"], df["relatedness_score_Ja"])[0]
jsick_score

0.8334555636238806

# MIRACL
* Requires a Hugging Face access token, read from the `HF_ACCESS_TOKEN` environment variable.

In [21]:
import os
import dotenv

# Load environment variables from the local file "huggingface_access_token";
# override=True lets values from the file replace variables that are already set.
# NOTE(review): the file presumably defines HF_ACCESS_TOKEN, which is read when
# the MIRACL dataset is loaded — confirm.
dotenv.load_dotenv("huggingface_access_token", override=True)

True

In [22]:
import datasets

# Queries and their judged positive/negative passages: Japanese MIRACL, dev split.
# The access token comes from the environment; a missing HF_ACCESS_TOKEN fails
# immediately with KeyError rather than with an opaque download error.
# `token=` is used because `datasets` deprecated the older `use_auth_token=`
# keyword and newer releases no longer accept it.
ds = datasets.load_dataset(
    "miracl/miracl", "ja", token=os.environ["HF_ACCESS_TOKEN"], split="dev"
)
ds



Dataset({
    features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
    num_rows: 860
})

In [23]:
# All corpus passages for Japanese MIRACL. No split is requested, so the result
# is a DatasetDict; the passages themselves are under its "train" key.
corpus = datasets.load_dataset("miracl/miracl-corpus", "ja")
corpus

DatasetDict({
    train: Dataset({
        features: ['docid', 'title', 'text'],
        num_rows: 6953614
    })
})

In [24]:
# Hard negatives, keyed by query id (as a string). Each entry holds two lists,
# "docids" and "indices".
# The file is parsed straight from the handle with an explicit encoding, so the
# result does not depend on the platform's default text encoding.
with open("./miracl_hard_negs_1000.json", encoding="utf-8") as f:
    hn = json.load(f)

# Quick look at the structure: entry count, some keys, one entry's fields.
(
    len(hn),
    list(hn.keys())[:5],
    hn["0"].keys(),
    hn["0"]["docids"][:2],
    hn["0"]["indices"][:2],
)

(860,
 ['0', '3', '4', '5', '7'],
 dict_keys(['docids', 'indices']),
 ['2681119#0', '2681119#1'],
 [1393435, 1393436])

In [25]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def get_text(corpus_item):
    """Return a passage as one string: its title, a single space, then its text."""
    title, body = corpus_item["title"], corpus_item["text"]
    return " ".join((title, body))


# docid -> "title text" lookup for every passage, used to fetch the texts of
# hard negatives. This walks the whole "train" split and keeps every string in
# memory.
corpus_dict = {item["docid"]: get_text(item) for item in corpus["train"]}

In [26]:
# Recall within the top `miracle_n_recall` candidates, micro-averaged over all
# dev queries: (positives retrieved) / (positives available).
n_total_pos = 0
n_total_tp = 0

for item in tqdm(ds):
    # query
    query_emb = get_embedding(
        get_detailed_instruct(search_task_description, item["query"])
    )

    # Candidate pool = judged positives + up to `miracle_n_hard_negs` hard negatives.
    positive_docids = [pp["docid"] for pp in item["positive_passages"]]
    positive_texts = [get_text(pp) for pp in item["positive_passages"]]

    # Drop hard negatives that are actually positives. A set makes each
    # membership check O(1) instead of scanning the positives list every time.
    positive_docid_set = set(positive_docids)
    hn_docids = [
        docid
        for docid in hn[item["query_id"]]["docids"][:miracle_n_hard_negs]
        if docid not in positive_docid_set
    ]

    # Positives come first, so candidate indices below n_pos are the positives.
    target_texts = positive_texts + [corpus_dict[docid] for docid in hn_docids]

    # embedding (one forward pass per candidate)
    target_embs = torch.cat([get_embedding(text) for text in target_texts])

    # topK: smallest cosine distance first
    distances = cdist(query_emb, target_embs, metric="cosine")[0]
    topk_indices = np.argsort(distances)[:miracle_n_recall]

    n_pos = len(positive_docids)
    # argsort indices are unique, so counting those below n_pos gives the
    # number of positives that made it into the top k.
    n_tp = int((topk_indices < n_pos).sum())

    n_total_pos += n_pos
    n_total_tp += n_tp

miracl_recall = n_total_tp / n_total_pos

n_total_pos, n_total_tp, miracl_recall

  0%|          | 0/860 [00:00<?, ?it/s]

(1790, 1426, 0.7966480446927374)

# Output

In [27]:
# Summary: the evaluated model and its JSTS, JSICK and MIRACL scores.
model_id, jsts_score, jsick_score, miracl_recall

('../../modls/japanese-e5-mistral-1.9b-02/checkpoint-800000',
 0.825669979656227,
 0.8334555636238806,
 0.7966480446927374)

In [30]:
import json
import os

# Persist the three scores as one JSON object in ./scores/, in a file named
# after the model ("/" in the id is replaced by "_" to keep it a single file name).
scores_dir = "./scores"
# Without this, `open(..., "w")` raises FileNotFoundError when the directory
# has not been created yet.
os.makedirs(scores_dir, exist_ok=True)

with open(
    f'{scores_dir}/{model_id.replace("/", "_")}.txt', "w", encoding="utf-8"
) as f:
    json.dump(
        {
            "model_id": model_id,
            "jsts": jsts_score,
            "jsick": jsick_score,
            "miracl": miracl_recall,
        },
        f,
    )