# SBERT-Jittor Fine-tuning Demo

This notebook downloads its data from Hugging Face datasets, fine-tunes an SBERT model on NLI (SNLI + MultiNLI), and then runs STS-B regression fine-tuning.


In [None]:
import os
import sys
from pathlib import Path

def _find_repo_root(start: Path):
    for p in [start] + list(start.parents):
        if (p / 'model' / 'sbert_model.py').is_file():
            return p
    return None

# Locate the SBERT-Jittor checkout by walking up from the current directory.
repo_root = _find_repo_root(Path.cwd())
if repo_root is None:
    print('SBERT_JITTOR root not found. Set sys.path manually.')
else:
    repo_root_str = str(repo_root)
    # Drop any earlier copy first so re-running this cell does not keep
    # growing sys.path, while still giving the repo top import priority.
    sys.path[:] = [entry for entry in sys.path if entry != repo_root_str]
    sys.path.insert(0, repo_root_str)
    # Relative paths used later in the notebook (e.g. data_dir='./data') are
    # resolved against the working directory, so anchor it at the repo root.
    os.chdir(repo_root)
    print(f'Using repo root: {repo_root}')

# If you want to use code from the HF repo instead:
# from huggingface_hub import snapshot_download
# repo_dir = Path(snapshot_download('Kyle-han/roberta-base-nli-mean-tokens'))
# sys.path.insert(0, str(repo_dir))


In [None]:
import os
import jittor as jt
from argparse import Namespace
from transformers import AutoTokenizer

from utils.download_data import download_nli_datasets, download_sts_benchmark
from training.nli.train_nli import train as train_nli
from training.sts.train_sts import train as train_sts


In [None]:
# Data download. Set DOWNLOAD_DATA = True on the first run so the download
# helpers populate data_dir; leave it False afterwards to reuse the local copy.
DOWNLOAD_DATA = False

data_dir = './data'
os.makedirs(data_dir, exist_ok=True)
if DOWNLOAD_DATA:
    download_nli_datasets(data_dir)
    download_sts_benchmark(data_dir)


In [None]:
# Load the bert-base-uncased tokenizer from Hugging Face (fast implementation,
# use_fast=True) and encode a short sample as a smoke test. The encoding is the
# cell's last expression, so the notebook displays it.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    use_fast=True,
)
tokenizer(
    'hello world',
    return_tensors='np',
)


In [None]:
# NLI fine-tuning: collect the trainer's hyper-parameters in an ordered dict
# (keys kept in their original order) and hand them to train_nli as a Namespace.
nli_args = Namespace(**{
    # Model / encoder
    'base_model': 'bert-base-uncased',
    'pooling': 'mean',
    'loss': 'softmax',
    'ablation': 0,
    'encoder_checkpoint': None,
    'tokenizer_dir': None,
    'num_labels': 3,
    # Data
    'data_dir': data_dir,
    'datasets': ['SNLI', 'MultiNLI'],
    'max_length': 128,
    # Optimisation and device
    'batch_size': 16,
    'eval_batch_size': 32,
    'epochs': 1,
    'lr': 2e-5,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'use_cuda': jt.has_cuda,
    # Logging / evaluation / checkpointing
    'log_steps': 100,
    'eval_steps': 1000,
    'save_steps': 0,
    'start_from_checkpoints': None,
    'output_dir': None,
    # Data loading and tokenization cache
    'num_workers': 4,
    'cache_dir': None,
    'overwrite_cache': False,
    'tokenize_batch_size': 1024,
})

train_nli(nli_args)


In [None]:
# STS regression fine-tuning: same pattern as the NLI cell -- an ordered dict of
# hyper-parameters (keys kept in their original order) turned into a Namespace.
sts_args = Namespace(**{
    # Model / encoder
    'base_model': 'bert-base-uncased',
    'pooling': 'mean',
    # NOTE(review): with encoder_checkpoint=None this stage presumably starts
    # from base_model rather than from the model trained by the NLI cell; point
    # it at the NLI output to chain the two stages -- confirm against train_sts.
    'encoder_checkpoint': None,
    'tokenizer_dir': None,
    # Data and tokenization cache
    'data_dir': data_dir,
    'cache_dir': None,
    'overwrite_cache': False,
    'tokenize_batch_size': 1024,
    # Dataset splits
    'train_dataset': 'STS-B',
    'train_split': 'train',
    'eval_dataset': 'STS-B',
    'eval_split': 'validation',
    'test_dataset': 'STS-B',
    # NOTE(review): if the STS-B copy comes from GLUE, its test split has no
    # gold scores -- confirm the downloaded test split is labelled.
    'test_split': 'test',
    # Optimisation
    'batch_size': 32,
    'eval_batch_size': 32,
    'epochs': 1,
    'lr': 2e-5,
    'max_length': 128,
    # Logging / evaluation / checkpointing
    'log_steps': 20,
    'eval_steps': 180,
    'save_steps': 0,
    'disable_checkpoint': False,
    'num_workers': 4,
    'use_cuda': jt.has_cuda,
    # Score handling -- presumably gold scores lie in [0, score_scale]; confirm
    # how train_sts uses normalize_scores / score_scale.
    'normalize_scores': False,
    'score_scale': 5.0,
    'start_from_checkpoints': None,
    'output_dir': None,
})

train_sts(sts_args)
