# SBERT-Jittor General Use Demo

This demo shows common SBERTJittor usage patterns: constructing models,
loading a checkpoint from the Hugging Face (HF) Hub, and encoding text.

## 0) (Optional) Path setup and warning control

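- Run the path-setup cell only if you hit import errors; it adds the repository root to `sys.path` so `model/` and `utils/` can be imported.
- The second cell points the HF cache at a local `./.hf_cache` folder (unless `HF_HOME` is already set) and silences the noisy `TRANSFORMERS_CACHE` deprecation warning.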

In [None]:
# Run this only if you get import errors
import os
import sys
from pathlib import Path

# Find the repo root and add it to sys.path
# so `model/` and `utils/` can be imported in notebooks.
def _find_repo_root(start: Path):
    for p in [start] + list(start.parents):
        if (p / 'model' / 'sbert_model.py').is_file():
            return p
    return None

repo_root = _find_repo_root(Path.cwd())
if repo_root and str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
    print('Added repo root to sys.path:', repo_root)
else:
    print('Repo root not found or already on sys.path')


In [None]:
# Optional HF cache + warning control
import os
import warnings

# Default the Hugging Face cache to a folder next to this notebook; an
# HF_HOME already present in the environment takes precedence.
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = './.hf_cache'

# TRANSFORMERS_CACHE is the deprecated cache variable; removing it from the
# environment avoids the deprecation warning it triggers.
if 'TRANSFORMERS_CACHE' in os.environ:
    del os.environ['TRANSFORMERS_CACHE']

# Keep that FutureWarning out of cell output in case it is still raised.
warnings.filterwarnings(
    action='ignore',
    category=FutureWarning,
    message='Using `TRANSFORMERS_CACHE` is deprecated',
)


## 1) Imports

In [None]:
import jittor as jt
from transformers import AutoTokenizer

from model.sbert_model import SBERTJittor


## 2) Basic SBERTJittor usage patterns

Construct several SBERTJittor variants (different backbones and projection heads) and print their output dimensions.

In [None]:
# Every variant below uses pooling='mean'; they differ only in the backbone
# checkpoint and the projection head placed on top of the pooled output.

# 1) BERT backbone, no projection head
model1 = SBERTJittor('bert-base-uncased', head_type='none', pooling='mean')
print('model1 output_dim:', model1.output_dim)

# 2) RoBERTa backbone, no projection head; also show a few backbone config fields
model2 = SBERTJittor('roberta-base', head_type='none', pooling='mean')
print('model2 output_dim:', model2.output_dim)
cfg2 = model2.config
print('vocab size:', cfg2.vocab_size)
print('max position:', cfg2.max_position_embeddings)

# 3) BERT backbone with a linear head, requesting output_dim=256
model3 = SBERTJittor(
    'bert-base-uncased',
    head_type='linear',
    output_dim=256,
    pooling='mean',
)
print('model3 output_dim:', model3.output_dim)

# 4) BERT backbone with a 2-layer MLP head, requesting output_dim=128
model4 = SBERTJittor(
    'bert-base-uncased',
    head_type='mlp',
    num_layers=2,
    output_dim=128,
    pooling='mean',
)
print('model4 output_dim:', model4.output_dim)


## 3) Load a pretrained SBERT checkpoint from HF

Load a Jittor checkpoint hosted on the Hugging Face Hub together with its tokenizer, then encode a sample sentence.

In [None]:
repo_id = 'Kyle-han/roberta-base-nli-mean-tokens'
# With return_tokenizer=True, from_pretrained is unpacked into three values;
# only the model and tokenizer are needed here.
model, tokenizer, _ = SBERTJittor.from_pretrained(repo_id, return_tokenizer=True)

# Tokenize a single sentence into NumPy arrays, then wrap each field as a Jittor Var.
batch = tokenizer('hello world', return_tensors='np')
input_ids = jt.array(batch['input_ids'])
attention_mask = jt.array(batch['attention_mask'])
# token_type_ids may be absent (e.g. for RoBERTa-style tokenizers); pass None then.
token_type_ids = None
if 'token_type_ids' in batch:
    token_type_ids = jt.array(batch['token_type_ids'])

emb = model.encode(input_ids, attention_mask, token_type_ids)
print('embedding shape:', emb.shape)
