# SBERT-Jittor Fine-tuning Demo

Step-by-step NLI fine-tuning, STS evaluation, and STS regression fine-tuning, with every loop written out explicitly instead of calling the project's training helpers.


In [None]:
import os
import sys
from pathlib import Path

def _find_repo_root(start: Path):
    for p in [start] + list(start.parents):
        if (p / 'model' / 'sbert_model.py').is_file():
            return p
    return None

# Locate the repository so `model`, `losses` and `utils` are importable, and make
# it the working directory so relative paths used later ('./data',
# './.hf_cache') resolve from the repo root.
repo_root = _find_repo_root(Path.cwd())
if repo_root is None:
    print('SBERT_JITTOR root not found. Set sys.path manually.')
else:
    repo_root_str = str(repo_root)
    # Move (rather than duplicate) the root to the front of sys.path so that
    # re-running this cell does not keep stacking identical entries.
    if repo_root_str in sys.path:
        sys.path.remove(repo_root_str)
    sys.path.insert(0, repo_root_str)
    os.chdir(repo_root)
    print(f'Using repo root: {repo_root}')


In [None]:
import os
import warnings

# Use a repo-local Hugging Face cache unless HF_HOME is already configured.
# This runs before `transformers` is imported further down.
os.environ.setdefault('HF_HOME', './.hf_cache')
# Drop the legacy TRANSFORMERS_CACHE variable; its use triggers a deprecation
# FutureWarning, which is also silenced here in case it is still emitted.
if 'TRANSFORMERS_CACHE' in os.environ:
    del os.environ['TRANSFORMERS_CACHE']
warnings.filterwarnings(
    action='ignore',
    message='Using `TRANSFORMERS_CACHE` is deprecated',
    category=FutureWarning,
)


In [None]:
import math
import numpy as np
import jittor as jt
from jittor.dataset import DataLoader
from transformers import AutoTokenizer
from scipy.stats import pearsonr, spearmanr

from model.sbert_model import SBERTJittor
from losses.softmax_loss import SoftmaxLoss
from losses.regression_loss import RegressionLoss
from utils.data_loader import prepare_nli_dataset, prepare_sts_dataset, collate_nli, collate_sts
from utils.jt_utils import _to_jittor_batch, setup_device

# Configure Jittor's execution device before building any model.
# NOTE(review): the True flag presumably requests GPU/CUDA execution when
# available — confirm against utils/jt_utils.setup_device.
setup_device(True)


In [None]:
# Dataset root (populate it with utils/download_data.py beforehand) and the
# tokenization cache location; None selects the helpers' default (data/_cache).
data_dir, cache_dir = './data', None


In [None]:
# Fast tokenizer for the same checkpoint the SBERT encoder is built from.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    use_fast=True,
)


In [None]:
# NLI training data: SNLI and MultiNLI train splits, tokenized to at most 128
# tokens in chunks of 1024 examples; overwrite_cache=False leaves any existing
# cache under cache_dir in place.
train_dataset = prepare_nli_dataset(
    data_dir=data_dir,
    cache_dir=cache_dir,
    overwrite_cache=False,
    datasets=['SNLI', 'MultiNLI'],
    split='train',
    tokenizer=tokenizer,
    tokenize_batch_size=1024,
    max_length=128,
)

# Shuffled mini-batches of 16 sentence pairs, assembled by collate_nli.
train_loader = DataLoader(
    train_dataset,
    collate_batch=collate_nli,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)


In [None]:
# Sentence encoder: bert-base-uncased with mean pooling and no extra head.
model = SBERTJittor('bert-base-uncased', pooling='mean', head_type='none')
# 3-label classification objective built on top of the encoder.
train_loss = SoftmaxLoss(model=model, num_labels=3, ablation=0)
# AdamW over every parameter the loss module exposes
# (presumably the encoder plus the classifier — confirm in losses/softmax_loss.py).
nli_params = train_loss.parameters()
optimizer = jt.optim.AdamW(nli_params, weight_decay=0.01, lr=2e-5)


In [None]:
# NLI training loop (1 epoch demo).
# Linear learning-rate warmup over the first 10% of steps, then constant LR.
peak_lr = 2e-5  # must equal the lr the AdamW optimizer was created with
nli_batch_size = 16  # must equal the batch_size passed to train_loader
steps_per_epoch = math.ceil(len(train_dataset) / nli_batch_size)
warmup_steps = max(int(steps_per_epoch * 0.1), 1)
global_step = 0

# Make sure training-time behaviour (e.g. dropout) is active; other cells
# switch the model to eval mode and a re-run would otherwise inherit that.
model.train()

for step, batch in enumerate(train_loader, 1):
    global_step += 1
    # Set the warmup LR *before* the update so the first step is scaled too
    # (assigning it after optimizer.step ran step 1 at the full peak LR).
    if global_step <= warmup_steps:
        optimizer.lr = peak_lr * (global_step / warmup_steps)

    jt_batch = _to_jittor_batch(batch, for_sts=False)
    labels = jt_batch['labels']
    loss, logits = train_loss(jt_batch, labels)
    optimizer.step(loss)

    if step % 100 == 0:
        # jt.argmax returns an (indices, values) pair; keep the class indices.
        preds = jt.argmax(logits, dim=1)[0]
        acc = (jt.sum(preds == labels).item() / labels.shape[0]) * 100
        print(f'step {step} loss={loss.item():.4f} acc={acc:.2f}%')
    if step >= steps_per_epoch:
        break


In [None]:
# STS-B validation pairs, tokenized to at most 128 tokens in chunks of 1024;
# overwrite_cache=False leaves any existing cache under cache_dir in place.
sts_dataset = prepare_sts_dataset(
    data_dir=data_dir,
    cache_dir=cache_dir,
    overwrite_cache=False,
    dataset_name='STS-B',
    split='validation',
    tokenizer=tokenizer,
    tokenize_batch_size=1024,
    max_length=128,
)

# Unshuffled batches of 32 pairs, assembled by collate_sts, so evaluation
# visits pairs in dataset order.
sts_loader = DataLoader(
    sts_dataset,
    collate_batch=collate_sts,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)


In [None]:
# STS evaluation (explicit): cosine similarity between the two embeddings of
# each pair, correlated with the gold scores and reported as Pearson /
# Spearman x 100.
def _cosine_rows(left, right):
    """Row-wise (axis 1) cosine similarity of two numpy arrays; 1e-9 avoids division by zero."""
    scale = np.linalg.norm(left, axis=1) * np.linalg.norm(right, axis=1) + 1e-9
    return np.sum(left * right, axis=1) / scale


all_preds = []
all_scores = []
model.eval()
with jt.no_grad():
    for batch in sts_loader:
        jt_batch = _to_jittor_batch(batch, for_sts=True)
        emb_a = model.encode(
            jt_batch['input_ids_a'],
            jt_batch['attention_mask_a'],
            jt_batch.get('token_type_ids_a'),
        )
        emb_b = model.encode(
            jt_batch['input_ids_b'],
            jt_batch['attention_mask_b'],
            jt_batch.get('token_type_ids_b'),
        )
        all_preds.extend(_cosine_rows(emb_a.numpy(), emb_b.numpy()).tolist())
        gold = jt_batch['scores'].numpy().ravel()
        all_scores.extend(gold.tolist())

pearson, _ = pearsonr(all_preds, all_scores)
spearman, _ = spearmanr(all_preds, all_scores)
print({'pearson': pearson * 100, 'spearman': spearman * 100})


In [None]:
# STS regression fine-tuning (1 epoch demo).
# NOTE(review): this fits on the STS-B *validation* split that the evaluation
# loop above scores against, so re-running that evaluation afterwards is not a
# held-out measurement. Use a train split for real experiments.
reg_loss = RegressionLoss()
optimizer = jt.optim.Adam(model.parameters(), lr=2e-5)

# The evaluation loop called model.eval(); switch back so training-time
# behaviour (e.g. dropout) is active while fitting.
model.train()

# Dedicated shuffled loader: sts_loader is deliberately unshuffled for
# evaluation, and training on a fixed order is a poor default.
sts_batch_size = 32
sts_train_loader = DataLoader(
    sts_dataset,
    batch_size=sts_batch_size,
    shuffle=True,
    num_workers=4,
    collate_batch=collate_sts,
)

steps_per_epoch = math.ceil(len(sts_dataset) / sts_batch_size)
for step, batch in enumerate(sts_train_loader, 1):
    jt_batch = _to_jittor_batch(batch, for_sts=True)
    emb_a = model.encode(jt_batch['input_ids_a'], jt_batch['attention_mask_a'], jt_batch.get('token_type_ids_a'))
    emb_b = model.encode(jt_batch['input_ids_b'], jt_batch['attention_mask_b'], jt_batch.get('token_type_ids_b'))
    loss = reg_loss(emb_a, emb_b, jt_batch['scores'])
    optimizer.step(loss)
    if step % 100 == 0:
        print(f'step {step} loss={loss.item():.4f}')
    if step >= steps_per_epoch:
        break
