# Downstream Demo (MR)

This demo tests how well a pretrained SBERT encoder transfers to a downstream task. It attaches a linear classification head to the frozen encoder, trains the head on the MR sentiment dataset, and evaluates the result.

Flow:
0) (Optional) Path setup + warning control (only if you hit import warnings/errors)
1) Load config and device
2) Load pretrained SBERT + tokenizer from HF
3) Load datasets
4) Train classifier head (SBERT frozen in this demo)
5) Evaluate on validation + test set

## 0) (Optional) Path setup and warning control

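- Run these cells only if you hit import errors or noisy warnings in this notebook.
- The first cell adds the repo root to `sys.path` so that `model/` and `utils/` can be imported.
- The second cell points the Hugging Face cache at `./.hf_cache` (unless `HF_HOME` is already set) and silences the `TRANSFORMERS_CACHE` deprecation warning.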

In [None]:
# Run this only if you get import errors
import os
import sys
from pathlib import Path

# Find the repo root and add it to sys.path
# so `model/` and `utils/` can be imported in notebooks.
def _find_repo_root(start: Path):
    for p in [start] + list(start.parents):
        if (p / 'model' / 'sbert_model.py').is_file():
            return p
    return None

# Prepend the repo root to the import path once; do nothing if it is missing
# or already present.
repo_root = _find_repo_root(Path.cwd())
if repo_root is None or str(repo_root) in sys.path:
    print('Repo root not found or already on sys.path')
else:
    sys.path.insert(0, str(repo_root))
    print('Added repo root to sys.path:', repo_root)

In [None]:
# Optional HF cache + warning control
import os
import warnings

# Default the Hugging Face cache to a notebook-local folder, but respect an
# HF_HOME that the user has already exported.
if 'HF_HOME' not in os.environ:
    os.environ['HF_HOME'] = './.hf_cache'

# Remove the legacy cache variable so transformers does not complain about it.
if 'TRANSFORMERS_CACHE' in os.environ:
    del os.environ['TRANSFORMERS_CACHE']

# Hide the remaining deprecation notice about TRANSFORMERS_CACHE.
warnings.filterwarnings(
    action='ignore',
    message='Using `TRANSFORMERS_CACHE` is deprecated',
    category=FutureWarning,
)

In [None]:
import math
import numpy as np
import jittor as jt
from jittor import nn
from jittor.dataset import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm

from model.sbert_model import SBERTJittor
from utils.data_loader import prepare_text_classification_dataset, collate_text_classification
from utils.jt_utils import _to_jittor_batch_single
from utils.training_utils import setup_device

## 1) Load config and device

Set basic runtime configuration and select CPU/GPU.

In [None]:
# Basic runtime configuration for this demo.
repo_id = 'Kyle-han/roberta-base-nli-mean-tokens'  # HF model id to load
data_dir = './data'  # local dataset root
max_length = 128  # forwarded as max_length when tokenizing the datasets
batch_size = 32  # shared by training and evaluation loaders

# NOTE(review): the True flag presumably requests GPU execution — confirm in
# utils.training_utils.setup_device.
setup_device(True)

## 2) Load pretrained SBERT + tokenizer from HF

Fetch the model and tokenizer from the Hugging Face Hub.

In [None]:
# Fetch the SBERT encoder and its tokenizer from the Hugging Face Hub.
# from_pretrained returns three values here; only the first two are used.
model, tokenizer, _ = SBERTJittor.from_pretrained(repo_id, return_tokenizer=True)

## 3) Load datasets

Load MR train/validation/test splits and build Jittor DataLoaders.

In [None]:
# Load MR datasets and build dataloaders.
# All three splits share identical tokenization and batching settings; only the
# split name (and whether batches are shuffled) differs, so both steps go
# through small helpers instead of three copy-pasted calls.


def _load_mr_split(split):
    """Tokenize one MR split using the notebook-level tokenizer and max_length."""
    return prepare_text_classification_dataset(
        data_dir=data_dir,
        dataset_name='MR',
        split=split,
        tokenizer=tokenizer,
        max_length=max_length,
        cache_dir=None,
        overwrite_cache=False,
        tokenize_batch_size=1024,
    )


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a Jittor DataLoader with the shared batch settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        drop_last=False,
        collate_batch=collate_text_classification,
    )


train_ds = _load_mr_split('train')
val_ds = _load_mr_split('validation')
test_ds = _load_mr_split('test')

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = _make_loader(train_ds, shuffle=True)
val_loader = _make_loader(val_ds, shuffle=False)
test_loader = _make_loader(test_ds, shuffle=False)

## 4) Train classifier head

Freeze the SBERT encoder and train a linear classifier head on its sentence embeddings for one epoch.

In [None]:
# Train a small classifier head on top of frozen SBERT.
# Freeze every encoder parameter so the optimizer only updates the head.
for param in model.parameters():
    param.stop_grad()

clf = nn.Linear(model.output_dim, 2)  # two output logits for MR sentiment
# NOTE(review): 2e-5 is a typical full fine-tuning LR; a freshly initialized
# linear probe often needs a larger one (e.g. 1e-3) — tune if accuracy is low.
optimizer = nn.Adam(clf.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

steps_per_epoch = math.ceil(len(train_ds) / batch_size)
# The encoder is frozen, so keep it in eval mode (dropout off). The head then
# trains on the same deterministic features it will see during evaluation.
model.eval()
clf.train()
for step, batch in enumerate(tqdm(train_loader, total=steps_per_epoch, desc='MR train'), 1):
    jt_batch = _to_jittor_batch_single(batch)
    # No gradients flow into the frozen encoder, so don't record its graph.
    with jt.no_grad():
        reps = model.encode(jt_batch['input_ids'], jt_batch['attention_mask'], jt_batch.get('token_type_ids'))
    logits = clf(reps)
    loss = loss_fn(logits, jt_batch['labels'])
    optimizer.step(loss)
    # Stop after exactly one pass over the training set.
    if step >= steps_per_epoch:
        break

## 5) Evaluate on validation + test set

Report the average cross-entropy loss and accuracy (%) on the held-out validation and test splits.

In [None]:
# Evaluation loop shared by the validation and test splits.

def eval_loop(loader, name, dataset_len):
    """Evaluate the SBERT encoder plus classifier head on one split and print metrics.

    Relies on the notebook-level ``model``, ``clf``, ``loss_fn`` and ``batch_size``.

    Args:
        loader: iterable of raw batches accepted by ``_to_jittor_batch_single``.
        name: label used for the progress bar and in the printed result.
        dataset_len: number of examples in the split; used only to size the
            progress bar.

    Prints ``{name: {'loss': ..., 'accuracy': ...}}``. Here ``loss`` is the
    per-example mean cross-entropy and ``accuracy`` is a percentage.
    """
    model.eval()
    clf.eval()
    n_batches = math.ceil(dataset_len / batch_size)
    correct, seen, loss_sum = 0, 0, 0.0
    with jt.no_grad():
        for raw in tqdm(loader, total=n_batches, desc=f'{name} eval'):
            batch_jt = _to_jittor_batch_single(raw)
            labels = batch_jt['labels']
            features = model.encode(
                batch_jt['input_ids'],
                batch_jt['attention_mask'],
                batch_jt.get('token_type_ids'),
            )
            logits = clf(features)
            batch_loss = loss_fn(logits, labels)
            # jt.argmax returns (indices, values); only the indices are needed.
            predicted = jt.argmax(logits, dim=1)[0]
            count = labels.shape[0]
            correct += jt.sum(predicted == labels).item()
            seen += count
            # loss_fn averages over the batch; re-weight by batch size to sum.
            loss_sum += batch_loss.item() * count
    denom = max(seen, 1)
    avg_loss = loss_sum / denom
    acc = correct / denom * 100
    print({name: {'loss': avg_loss, 'accuracy': acc}})

# Run the evaluation on the validation split first, then on the test split.
for _eval_loader, _eval_name, _eval_ds in (
    (val_loader, 'MR validation', val_ds),
    (test_loader, 'MR test', test_ds),
):
    eval_loop(_eval_loader, _eval_name, len(_eval_ds))