# Downstream Demo (MR)

Step-by-step MR (movie-review sentiment) training and evaluation, written out inline instead of calling the project's training/evaluation helpers.


In [None]:
# 1) Repository path setup (run this only if you get import errors)
import os
import sys
from pathlib import Path

def _find_repo_root(start: Path):
    for p in [start] + list(start.parents):
        if (p / 'model' / 'sbert_model.py').is_file():
            return p
    return None

# Locate the repository root from the current working directory and make it
# both importable (front of sys.path) and the working directory.
repo_root = _find_repo_root(Path.cwd())
if repo_root is None:
    print('SBERT_JITTOR root not found. Set sys.path manually.')
else:
    repo_root_str = str(repo_root)
    # Notebook cells are often re-run; drop any existing copy first so the repo
    # root appears exactly once in sys.path, at the front.
    sys.path[:] = [entry for entry in sys.path if entry != repo_root_str]
    sys.path.insert(0, repo_root_str)
    os.chdir(repo_root)
    print(f'Using repo root: {repo_root}')


In [None]:
# 2) HF cache + warning control
import os
import warnings

# Point the Hugging Face cache at ./.hf_cache unless HF_HOME is already set.
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = './.hf_cache'

# TRANSFORMERS_CACHE is deprecated (see the warning text filtered below):
# remove it from the environment and ignore that FutureWarning.
if 'TRANSFORMERS_CACHE' in os.environ:
    del os.environ['TRANSFORMERS_CACHE']
warnings.filterwarnings('ignore', 'Using `TRANSFORMERS_CACHE` is deprecated', FutureWarning)


## 3) Imports


In [None]:
import math
import numpy as np
import jittor as jt
from jittor import nn
from jittor.dataset import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

from model.sbert_model import SBERTJittor
from utils.data_loader import prepare_text_classification_dataset
from utils.jt_utils import _to_jittor_batch_single, setup_device


## 4) Config


In [None]:
# Hugging Face Hub id of the pretrained SBERT checkpoint to load.
repo_id = 'Kyle-han/roberta-base-nli-mean-tokens'

# Dataset location and tokenization length passed to prepare_text_classification_dataset.
data_dir = './data'
max_length = 128

# Batch size shared by the train/validation/test DataLoaders.
batch_size = 32


## 5) Device setup


In [None]:
# Configure the Jittor execution device before the model is built.
# NOTE(review): presumably `True` requests the GPU/CUDA backend — confirm in utils.jt_utils.setup_device.
setup_device(True)


## 6) Load model + tokenizer from HF


In [None]:
# Load the pretrained SBERT encoder and its tokenizer for `repo_id` from Hugging Face.
# The third return value is discarded here; TODO confirm its meaning in SBERTJittor.from_pretrained.
model, tokenizer, _ = SBERTJittor.from_pretrained(
    repo_id,
    return_tokenizer=True,
)


## 7) Load MR datasets + DataLoaders


In [None]:
# Arguments shared by every MR split; only `split` differs between the calls below.
mr_dataset_kwargs = dict(
    data_dir=data_dir,
    dataset_name='MR',
    tokenizer=tokenizer,
    max_length=max_length,
    cache_dir=None,
    overwrite_cache=False,
    tokenize_batch_size=1024,
)
train_ds = prepare_text_classification_dataset(split='train', **mr_dataset_kwargs)
val_ds = prepare_text_classification_dataset(split='validation', **mr_dataset_kwargs)
test_ds = prepare_text_classification_dataset(split='test', **mr_dataset_kwargs)

# Worker count shared by all three loaders. Only the training split is
# shuffled; validation/test keep a fixed order.
num_workers = 4
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)


## 8) Train (one-epoch demo)


In [None]:
from itertools import islice

# Head/optimizer hyperparameters (MR is a binary task).
num_labels = 2
lr = 2e-5

# Linear classification head on top of the sentence embeddings from model.encode.
clf = nn.Linear(model.output_dim, num_labels)
# Encoder and head are fine-tuned jointly by a single optimizer.
optimizer = nn.Adam(list(model.parameters()) + list(clf.parameters()), lr=lr)
loss_fn = nn.CrossEntropyLoss()

steps_per_epoch = math.ceil(len(train_ds) / batch_size)
model.train()
clf.train()
running_loss = 0.0
num_steps = 0
# islice caps the epoch at steps_per_epoch batches. Unlike a `break` inside the
# loop, it lets tqdm's iterator finish normally, so the bar closes at 100%
# instead of stopping one step short.
batches = islice(train_loader, steps_per_epoch)
for step, batch in enumerate(tqdm(batches, total=steps_per_epoch, desc='MR train'), 1):
    jt_batch = _to_jittor_batch_single(batch)
    reps = model.encode(jt_batch['input_ids'], jt_batch['attention_mask'], jt_batch.get('token_type_ids'))
    logits = clf(reps)
    loss = loss_fn(logits, jt_batch['labels'])
    # In Jittor, passing the loss to optimizer.step also runs the backward pass.
    optimizer.step(loss)
    running_loss += loss.item()
    num_steps = step

# Mean of per-batch losses over the epoch (guarded against an empty loader).
print({'MR train': {'mean_batch_loss': running_loss / max(num_steps, 1), 'steps': num_steps}})


## 9) Evaluation (validation + test)


In [None]:
def eval_loop(loader, name):
    """Evaluate the encoder + classification head on *loader* and report loss/accuracy.

    Reads the notebook globals ``model``, ``clf``, ``loss_fn`` and ``batch_size``.
    Both modules are switched to eval mode and are left in eval mode afterwards.

    Args:
        loader: DataLoader whose batches ``_to_jittor_batch_single`` accepts;
            its ``.dataset`` length is used to size the progress bar.
        name: label used for the progress bar and the printed result.

    Returns:
        dict with ``'loss'`` (sample-weighted mean loss) and ``'accuracy'``
        (percent). The same values are printed as ``{name: {...}}``.
    """
    model.eval()
    clf.eval()
    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    num_batches = math.ceil(len(loader.dataset) / batch_size)
    with jt.no_grad():
        for batch in tqdm(loader, total=num_batches, desc=f'{name} eval'):
            jt_batch = _to_jittor_batch_single(batch)
            labels = jt_batch['labels']
            reps = model.encode(jt_batch['input_ids'], jt_batch['attention_mask'], jt_batch.get('token_type_ids'))
            logits = clf(reps)
            loss = loss_fn(logits, labels)
            # jt.argmax returns an (indices, values) pair; [0] keeps the predicted class ids.
            preds = jt.argmax(logits, dim=1)[0]
            n = labels.shape[0]
            total_correct += jt.sum(preds == labels).item()
            total_samples += n
            # Assumes loss_fn returns a per-batch mean, so re-weight by batch size.
            total_loss += loss.item() * n
    # max(..., 1) guards against division by zero on an empty loader.
    avg_loss = total_loss / max(total_samples, 1)
    acc = total_correct / max(total_samples, 1) * 100
    metrics = {'loss': avg_loss, 'accuracy': acc}
    print({name: metrics})
    return metrics

val_metrics = eval_loop(val_loader, 'MR validation')
test_metrics = eval_loop(test_loader, 'MR test')
