# Evaluation Demo (STS-B)

This demo evaluates a pretrained SBERT model on the STS-Benchmark dataset.

Flow:

0) (Optional) Path setup + warning control (only if you hit import warnings/errors)
1) Load config and device
2) Load pretrained SBERT + tokenizer from HF
3) Load datasets
4) Run evaluation loop
5) Metrics + NaN checks

## 0) (Optional) Path setup and warning control

- Use this only if you hit import errors in this notebook.
- This step also silences noisy HF cache deprecation warnings.

In [None]:
# Run this only if you get import errors
import os
import sys
from pathlib import Path

# Find the repo root and add it to sys.path
# so `model/` and `utils/` can be imported in notebooks.
def _find_repo_root(start: Path):
    """Return the closest directory that holds ``model/sbert_model.py``.

    The search begins at ``start`` itself and then walks outward through
    ``start.parents`` (nearest ancestor first). The first directory in
    which ``model/sbert_model.py`` exists as a regular file is returned.

    Args:
        start: Directory from which the upward search begins.

    Returns:
        The matching directory as a ``Path``, or ``None`` if neither
        ``start`` nor any of its ancestors contains the marker file.
    """
    marker = Path('model') / 'sbert_model.py'
    candidates = (start, *start.parents)
    return next((folder for folder in candidates if (folder / marker).is_file()), None)

# Resolve the repo root relative to the current working directory and make
# sure it is the first entry on sys.path. Nothing is changed when the root
# cannot be located or is already present on sys.path.
repo_root = _find_repo_root(Path.cwd())
if not repo_root or str(repo_root) in sys.path:
    print('Repo root not found or already on sys.path')
else:
    sys.path.insert(0, str(repo_root))
    print('Added repo root to sys.path:', repo_root)

In [None]:
# Optional HF cache + warning control
import os
import warnings

# Default the HF cache to a folder next to this notebook, but leave any
# HF_HOME value that is already set in the environment untouched.
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = './.hf_cache'

# Remove TRANSFORMERS_CACHE if present so the deprecated variable is not seen.
if 'TRANSFORMERS_CACHE' in os.environ:
    del os.environ['TRANSFORMERS_CACHE']

# Hide the matching FutureWarning from cell output.
warnings.filterwarnings(
    action='ignore',
    category=FutureWarning,
    message='Using `TRANSFORMERS_CACHE` is deprecated',
)

In [None]:
import math
import numpy as np
import jittor as jt
from jittor.dataset import DataLoader
from transformers import AutoTokenizer
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm

from model.sbert_model import SBERTJittor
from utils.data_loader import prepare_sts_dataset, collate_sts
from utils.jt_utils import _to_jittor_batch
from utils.training_utils import setup_device

## 1) Load config and device

Set basic runtime configuration and select CPU/GPU.

In [None]:
# Runtime configuration for this evaluation run.

# Model: Hugging Face repo id to load.
repo_id = 'Kyle-han/roberta-base-nli-mean-tokens'

# Data: local dataset root, dataset name, and the STS-B split to evaluate.
data_dir = './data'
dataset_name = 'STS-B'
split = 'test'

# Batching / tokenization limits.
batch_size, max_length = 32, 128

# Device selection via the project helper.
# NOTE(review): the meaning of the `True` flag is defined in
# utils.training_utils.setup_device — presumably it requests GPU; confirm.
setup_device(True)

## 2) Load pretrained SBERT + tokenizer from HF

Fetch the model and tokenizer from the Hugging Face Hub.

In [None]:
# Fetch the pretrained SBERT model together with its tokenizer for `repo_id`.
# With return_tokenizer=True the call is unpacked into three values here;
# the third one is not used in this notebook and is discarded.
model, tokenizer, _ = SBERTJittor.from_pretrained(repo_id, return_tokenizer=True)

## 3) Load datasets

Load the STS-B split and build the Jittor DataLoader.

In [None]:
# Options for preparing the STS-B split with the tokenizer loaded above.
# No cache directory is passed and an existing cache is not overwritten.
_dataset_options = {
    'data_dir': data_dir,
    'dataset_name': dataset_name,
    'split': split,
    'tokenizer': tokenizer,
    'max_length': max_length,
    'cache_dir': None,
    'overwrite_cache': False,
    'tokenize_batch_size': 1024,
}
sts_dataset = prepare_sts_dataset(**_dataset_options)

# Evaluation loader: original order (no shuffling), keep the final partial
# batch, and collate samples with the project's STS collate function.
_loader_options = {
    'batch_size': batch_size,
    'shuffle': False,
    'num_workers': 4,
    'drop_last': False,
    'collate_batch': collate_sts,
}
sts_loader = DataLoader(sts_dataset, **_loader_options)

## 4) Run evaluation loop

Compute cosine similarity for each pair with a visible progress bar.

In [None]:
# Running lists: one predicted cosine similarity and one dataset score per pair.
all_preds, all_scores = [], []

model.eval()
# len(sts_dataset) / batch_size, rounded up, is handed to tqdm as the bar total.
total_batches = math.ceil(len(sts_dataset) / batch_size)
with jt.no_grad():
    progress = tqdm(sts_loader, total=total_batches, desc='STS-B eval', leave=False)
    for raw_batch in progress:
        tensors = _to_jittor_batch(raw_batch, for_sts=True)

        # Encode each side of the sentence pair separately.
        # token_type_ids are optional: .get() passes None when the key is absent.
        left = model.encode(
            tensors['input_ids_a'],
            tensors['attention_mask_a'],
            tensors.get('token_type_ids_a'),
        )
        right = model.encode(
            tensors['input_ids_b'],
            tensors['attention_mask_b'],
            tensors.get('token_type_ids_b'),
        )
        left_np, right_np = left.numpy(), right.numpy()

        # Cosine similarity along axis 1; the 1e-9 term keeps the
        # denominator non-zero when an embedding has zero norm.
        dots = np.sum(left_np * right_np, axis=1)
        norms = np.linalg.norm(left_np, axis=1) * np.linalg.norm(right_np, axis=1) + 1e-9
        cosine = dots / norms

        all_preds += cosine.tolist()
        all_scores += tensors['scores'].numpy().reshape(-1).tolist()

## 5) Metrics + NaN checks

Report the Pearson and Spearman correlations (scaled by 100) and print whether the dataset scores or the predictions contain any NaNs.

In [None]:
# Correlation between predicted similarities and dataset scores.
# Only the correlation statistic is kept; the second returned value is dropped.
pearson, _ = pearsonr(all_preds, all_scores)
spearman, _ = spearmanr(all_preds, all_scores)

# Report both correlations scaled by 100.
metrics = {'pearson': pearson * 100, 'spearman': spearman * 100}
print(metrics)

# Sanity flags: True means at least one NaN is present in that array.
for label, values in (
    ('scores nan:', sts_dataset.arrays['scores']),
    ('preds nan:', all_preds),
):
    print(label, np.isnan(values).any())