# SBERT-Jittor General Use Demo

This demo walks through common `SBERTJittor` usage patterns: constructing models with
different encoders and projection heads, loading a checkpoint from the Hugging Face (HF) Hub,
and encoding text.

## 0) (Optional) Path setup and warning control

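- The first cell adds the repo root to `sys.path`. Run it only if imports from `model/` or `utils/` fail in this notebook.
- The second cell keeps the HF cache in a local folder and silences the noisy `TRANSFORMERS_CACHE` deprecation warning.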

In [1]:
# Run this only if you get import errors
import sys
from pathlib import Path
from typing import Optional

# Walk up from the current directory to the repo root (identified by
# model/sbert_model.py) so `model/` and `utils/` can be imported here.
def _find_repo_root(start: Path) -> Optional[Path]:
    for p in [start, *start.parents]:
        if (p / 'model' / 'sbert_model.py').is_file():
            return p
    return None

repo_root = _find_repo_root(Path.cwd())
if repo_root is None:
    print('Repo root not found')
elif str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
    print('Added repo root to sys.path:', repo_root)
else:
    print('Repo root already on sys.path:', repo_root)


Added repo root to sys.path: /root/SBERT_JITTOR


In [2]:
# Optional HF cache + warning control
import os
import warnings

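# Put the HF cache in a local folder (setdefault keeps any HF_HOME you already set).
# Run this before importing transformers / huggingface_hub, which read HF_HOME at import time.
os.environ.setdefault('HF_HOME', './.hf_cache')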
# Avoid deprecated TRANSFORMERS_CACHE warning
os.environ.pop('TRANSFORMERS_CACHE', None)
# Silence the deprecation warning in output
warnings.filterwarnings(
    'ignore',
    message='Using `TRANSFORMERS_CACHE` is deprecated',
    category=FutureWarning,
)
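The `message` argument of `warnings.filterwarnings` is a regular expression matched case-insensitively against the *start* of the warning text, so the filter above only hides warnings that begin with that phrase. A small check of that behavior; the sample messages and the `is_silenced` helper are made up for illustration and are not the exact text Transformers emits:

```python
import warnings

PATTERN = 'Using `TRANSFORMERS_CACHE` is deprecated'

def is_silenced(text: str) -> bool:
    """Return True if a FutureWarning with this text is hidden by the notebook's filter."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning by default...
        warnings.simplefilter('always')
        # ...except the ones this filter ignores (it is inserted in front, so it wins).
        warnings.filterwarnings('ignore', message=PATTERN, category=FutureWarning)
        warnings.warn(text, FutureWarning)
    return len(caught) == 0

print(is_silenced('Using `TRANSFORMERS_CACHE` is deprecated and will be removed.'))  # True
print(is_silenced('Note: Using `TRANSFORMERS_CACHE` is deprecated'))  # False: no match at the start
```

If a future Transformers release rewords the warning, the pattern has to be updated to match the new prefix.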


## 1) Imports

In [3]:
import jittor as jt
from transformers import AutoTokenizer

from model.sbert_model import SBERTJittor
from utils.jt_utils import setup_device

[38;5;2m[i 0106 13:34:11.678819 48 log.cc:351] Load log_sync: 1[m
[38;5;2m[i 0106 13:34:11.881909 48 compiler.py:956] Jittor(1.3.10.0) src: /root/shared-nvme/venvs/ann_pip/lib/python3.12/site-packages/jittor[m
[38;5;2m[i 0106 13:34:11.885009 48 compiler.py:957] g++ at /usr/bin/g++(12.4.0)[m
[38;5;2m[i 0106 13:34:11.885337 48 compiler.py:958] cache_path: /root/.cache/jittor/jt1.3.10/g++12.4.0/py3.12.3/Linux-5.15.0-1xc3/INTELRXEONRGOLx51/8891/default[m
[38;5;2m[i 0106 13:34:11.924667 48 install_cuda.py:96] cuda_driver_version: [12, 4][m
[38;5;2m[i 0106 13:34:11.928680 48 __init__.py:412] Found /root/.cache/jittor/jtcuda/cuda12.2_cudnn8_linux/bin/nvcc(12.2.140) at /root/.cache/jittor/jtcuda/cuda12.2_cudnn8_linux/bin/nvcc.[m
[38;5;2m[i 0106 13:34:11.932087 48 __init__.py:412] Found addr2line(2.42) at /usr/bin/addr2line.[m
[38;5;2m[i 0106 13:34:12.085337 48 compiler.py:1013] cuda key:cu12.2.140_sm_89[m
[38;5;2m[i 0106 13:34:13.059772 48 __init__.py:227] Total mem: 1007.52GB

In [4]:
# Device setup: use_cuda=True runs on the GPU; set it to False to stay on the CPU
use_cuda = True
setup_device(use_cuda)


[38;5;2m[i 0106 13:34:20.121113 48 cuda_flags.cc:55] CUDA enabled.[m


## 2) Basic SBERTJittor usage patterns

Construct different SBERT variants and inspect output dimensions.
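All four variants in the next cell use `pooling='mean'`. As a rough illustration of what masked mean pooling usually means in SBERT-style models, here is a NumPy sketch that averages token vectors while ignoring padding positions. The `mean_pool` name and the array shapes are assumptions for illustration, not `SBERTJittor`'s internal code:

```python
import numpy as np

def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Masked mean over the sequence axis (hypothetical helper).

    token_embeddings: (batch, seq_len, hidden) encoder outputs
    attention_mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    returns:          (batch, hidden) sentence embeddings
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                    # padding contributes 0
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                    # guard against all-padding rows
    return summed / counts

# Two real tokens plus one padding position: the padding vector is ignored.
tokens = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pool(tokens, mask))  # [[2. 3.]]
```

Dividing by the count of real tokens, rather than by `seq_len`, keeps embeddings of short sentences from shrinking when they are padded inside a batch.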

In [5]:
# 1) Basic SBERT (mean pooling)
model1 = SBERTJittor('bert-base-uncased', pooling='mean', head_type='none')
print('model1 output_dim:', model1.output_dim)

# 2) RoBERTa SBERT
model2 = SBERTJittor('roberta-base', pooling='mean', head_type='none')
print('model2 output_dim:', model2.output_dim)
print('vocab size:', model2.config.vocab_size)
print('max position:', model2.config.max_position_embeddings)

# 3) Linear projection head
model3 = SBERTJittor('bert-base-uncased', pooling='mean', head_type='linear', output_dim=256)
print('model3 output_dim:', model3.output_dim)

# 4) MLP projection head
model4 = SBERTJittor('bert-base-uncased', pooling='mean', head_type='mlp', output_dim=128, num_layers=2)
print('model4 output_dim:', model4.output_dim)


SBERTJittor initialized:
  Encoder: bert-base-uncased
  Pooling: mean
  Head: none
  Output dim: 768
model1 output_dim: 768
SBERTJittor initialized:
  Encoder: roberta-base
  Pooling: mean
  Head: none
  Output dim: 768
model2 output_dim: 768
vocab size: 50265
max position: 514
SBERTJittor initialized:
  Encoder: bert-base-uncased
  Pooling: mean
  Head: linear
  Output dim: 256
model3 output_dim: 256
SBERTJittor initialized:
  Encoder: bert-base-uncased
  Pooling: mean
  Head: mlp
  Output dim: 128
model4 output_dim: 128
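`head_type='linear'` and `head_type='mlp'` map the 768-dimensional pooled vector to `output_dim`, which is why model3 and model4 report 256 and 128. The NumPy sketch below only illustrates that shape arithmetic. The initialization, the ReLU activation, and the assumption that hidden MLP layers keep the input width are all hypothetical and may differ from what `SBERTJittor` actually builds:

```python
import numpy as np

def init_linear(in_dim: int, out_dim: int, rng: np.random.Generator):
    # Hypothetical init: scaled Gaussian weights and a zero bias
    w = rng.standard_normal((in_dim, out_dim)) / np.sqrt(in_dim)
    return w, np.zeros(out_dim)

def linear_head(x: np.ndarray, out_dim: int, rng: np.random.Generator) -> np.ndarray:
    # One affine map: (batch, hidden) -> (batch, out_dim)
    w, b = init_linear(x.shape[-1], out_dim, rng)
    return x @ w + b

def mlp_head(x: np.ndarray, out_dim: int, num_layers: int, rng: np.random.Generator) -> np.ndarray:
    # Assumption: hidden layers keep the input width and use ReLU;
    # only the final layer projects to out_dim.
    hidden = x.shape[-1]
    h = x
    for i in range(num_layers):
        is_last = i == num_layers - 1
        w, b = init_linear(h.shape[-1], out_dim if is_last else hidden, rng)
        h = h @ w + b
        if not is_last:
            h = np.maximum(h, 0.0)  # ReLU between layers
    return h

rng = np.random.default_rng(0)
pooled = rng.standard_normal((2, 768))  # stand-in for two mean-pooled sentence vectors
print(linear_head(pooled, 256, rng).shape)                 # (2, 256)
print(mlp_head(pooled, 128, num_layers=2, rng=rng).shape)  # (2, 128)
```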


## 3) Load a pretrained SBERT checkpoint from HF

Download a Jittor checkpoint hosted on the HF Hub, together with its tokenizer, and encode a sample sentence.

In [6]:
repo_id = 'Kyle-han/roberta-base-nli-mean-tokens'
model, tokenizer, _ = SBERTJittor.from_pretrained(
    repo_id,
    return_tokenizer=True,
)

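# Tokenize to NumPy arrays, then convert them to Jittor Vars.
# RoBERTa tokenizers return no token_type_ids, so None is passed in that case.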
batch = tokenizer('hello world', return_tensors='np')
input_ids = jt.array(batch['input_ids'])
attention_mask = jt.array(batch['attention_mask'])
token_type_ids = jt.array(batch['token_type_ids']) if 'token_type_ids' in batch else None

emb = model.encode(input_ids, attention_mask, token_type_ids)
print('embedding shape:', emb.shape)


Fetching 40 files: 100%|█████████████████████████████| 40/40 [00:00<00:00, 1570.64it/s]


SBERTJittor initialized:
  Encoder: roberta-base
  Pooling: mean
  Head: none
  Output dim: 768
embedding shape: [1,768,]
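SBERT embeddings are typically compared with cosine similarity. A minimal NumPy sketch, assuming you first convert the Jittor output to an array with `emb.numpy()`; the `cosine_similarity` helper is illustrative and not part of this repo:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Pairwise cosine similarity between rows of a (n, d) and b (m, d) -> (n, m)."""
    # L2-normalize each row; eps avoids division by zero for all-zero vectors
    a = a / np.maximum(np.linalg.norm(a, axis=-1, keepdims=True), eps)
    b = b / np.maximum(np.linalg.norm(b, axis=-1, keepdims=True), eps)
    return a @ b.T

# Toy 2-D vectors; with the model above you would pass e.g. emb.numpy() instead.
x = np.array([[1.0, 0.0], [1.0, 1.0]])
print(np.round(cosine_similarity(x, x), 3))
```

To compare several sentences, encode each one as in the cell above, convert the results with `.numpy()`, and stack the rows into a single `(n, 768)` array before calling the helper.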
