# SBERT-Jittor Evaluation (Step-by-step)

Run the cells in order, from top to bottom. Each cell performs one step of the evaluation, and later cells rely on names defined by earlier ones.


In [None]:
import os
import sys
from pathlib import Path

def _find_repo_root(start: Path):
    """Return the nearest directory at or above *start* that holds the repo.

    A directory counts as the repository root when it contains the file
    ``model/sbert_model.py``. *start* itself is checked first, then each of
    its parents from nearest to farthest.

    Args:
        start: Directory from which the upward search begins.

    Returns:
        The first matching directory as a ``Path``, or ``None`` when neither
        *start* nor any of its parents contains the marker file.
    """
    marker = Path('model') / 'sbert_model.py'
    candidates = (start, *start.parents)
    return next((d for d in candidates if (d / marker).is_file()), None)

# Locate the repository by searching upward from the current working directory.
repo_root = _find_repo_root(Path.cwd())
if repo_root is not None:
    # Work from the repo root so relative paths used later (e.g. './data')
    # resolve there, and put it first on sys.path so the `model` and `utils`
    # packages can be imported.
    os.chdir(repo_root)
    sys.path.insert(0, str(repo_root))
    print(f'Using repo root: {repo_root}')
else:
    # Nothing is changed in this case; the user has to fix sys.path by hand.
    print('SBERT_JITTOR root not found. Set sys.path manually.')


In [None]:
import os
import warnings

# Use a cache directory next to the working directory unless HF_HOME has
# already been chosen in the environment.
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = './.hf_cache'

# Drop the legacy variable if it is set; a missing variable is not an error.
os.environ.pop('TRANSFORMERS_CACHE', None)

# Silence FutureWarnings whose text starts with the deprecation notice below
# (the message is a regular expression matched against the start of the text).
warnings.filterwarnings(
    action='ignore',
    category=FutureWarning,
    message='Using `TRANSFORMERS_CACHE` is deprecated',
)


In [None]:
import numpy as np
import jittor as jt
from jittor.dataset import DataLoader
from transformers import AutoTokenizer
from scipy.stats import pearsonr, spearmanr

from model.sbert_model import SBERTJittor
from utils.data_loader import prepare_sts_dataset, collate_sts
from utils.jt_utils import _to_jittor_batch, setup_device


In [None]:
# --- Evaluation settings (edit here, then re-run the cells below) ---

# Checkpoint identifier handed to SBERTJittor.from_pretrained.
repo_id = 'Kyle-han/roberta-base-nli-mean-tokens'

# Which dataset to evaluate on; all three are forwarded to prepare_sts_dataset.
data_dir = './data'
dataset_name = 'STS-B'
split = 'test'

# Tokenization length passed to prepare_sts_dataset and DataLoader batch size.
max_length = 128
batch_size = 32


In [None]:
# Configure the compute device before the model is loaded in the next cell.
# NOTE(review): the positional True presumably requests CUDA/GPU execution —
# confirm against utils.jt_utils.setup_device.
setup_device(True)


In [None]:
# Load the checkpoint named by `repo_id`. With return_tokenizer=True the call
# yields three values here; the third is not needed and is discarded.
# NOTE(review): repo_id looks like a Hugging Face Hub id, so this presumably
# downloads on first use — confirm against SBERTJittor.from_pretrained.
model, tokenizer, _ = SBERTJittor.from_pretrained(repo_id, return_tokenizer=True)


In [None]:
# Build the tokenized STS dataset for the configured split.
sts_dataset = prepare_sts_dataset(
    # Where the data lives and which part of it to load.
    data_dir=data_dir,
    dataset_name=dataset_name,
    split=split,
    # Tokenization settings.
    tokenizer=tokenizer,
    max_length=max_length,
    tokenize_batch_size=1024,
    # Caching: no explicit cache directory, and do not overwrite a cache.
    cache_dir=None,
    overwrite_cache=False,
)

# Iterate in dataset order (shuffle=False); batches are assembled by the
# project's collate_sts function.
sts_loader = DataLoader(
    sts_dataset,
    batch_size=batch_size,
    collate_batch=collate_sts,
    num_workers=4,
    shuffle=False,
)


In [None]:
# Accumulators filled batch by batch: predicted cosine similarities and the
# corresponding reference scores, both as flat Python lists.
all_preds = []
all_scores = []

# Put the model in eval mode and run the whole pass without gradients.
model.eval()
with jt.no_grad():
    for batch in sts_loader:
        jt_batch = _to_jittor_batch(batch, for_sts=True)

        # Encode each side of the sentence pair. token_type_ids are optional:
        # .get() passes None when the batch does not carry them.
        emb_a = model.encode(
            jt_batch['input_ids_a'],
            jt_batch['attention_mask_a'],
            jt_batch.get('token_type_ids_a'),
        )
        emb_b = model.encode(
            jt_batch['input_ids_b'],
            jt_batch['attention_mask_b'],
            jt_batch.get('token_type_ids_b'),
        )
        emb_a_np, emb_b_np = emb_a.numpy(), emb_b.numpy()

        # Cosine similarity along axis 1 (one value per pair); the 1e-9 term
        # keeps the division finite when an embedding has zero norm.
        dot = (emb_a_np * emb_b_np).sum(axis=1)
        denom = np.linalg.norm(emb_a_np, axis=1) * np.linalg.norm(emb_b_np, axis=1) + 1e-9
        sim = dot / denom
        all_preds.extend(sim.tolist())

        # Flatten the reference scores so they line up with `sim`.
        gold = jt_batch['scores'].numpy().reshape(-1)
        all_scores.extend(gold.tolist())


In [None]:
# Correlate predictions with the reference scores. Both SciPy calls return a
# (statistic, p-value) pair; only the statistic is kept.
pearson = pearsonr(all_preds, all_scores)[0]
spearman = spearmanr(all_preds, all_scores)[0]

# Report the correlations scaled by 100.
print(dict(pearson=pearson * 100, spearman=spearman * 100))

# Sanity checks: a NaN in either input would make the correlations NaN.
print('scores nan:', np.any(np.isnan(sts_dataset.arrays['scores'])))
print('preds nan:', np.any(np.isnan(all_preds)))
