# SBERT MR Demo (Jittor)

This notebook downloads the MR (movie review) dataset, trains SBERTJittor on MR, and prints the best dev accuracy.


## 1. Setup
Imports and CUDA setup.


In [None]:
import os
import jittor as jt

# Use the GPU backend whenever Jittor reports CUDA support; otherwise stay on CPU.
cuda_available = bool(jt.has_cuda)
jt.flags.use_cuda = int(cuda_available)
print('CUDA:', bool(jt.flags.use_cuda))

## 2. Download MR dataset


In [None]:
from datasets import load_dataset

data_dir = '../data'  # change to wherever you keep datasets
os.makedirs(data_dir, exist_ok=True)

# Build the target path once; it is reused for the cache check, the save and the logs.
mr_dir = os.path.join(data_dir, 'MR')

# A saved DatasetDict carries a 'dataset_dict.json' marker at its root. When it is
# present we reuse the on-disk copy so that "Restart & Run All" does not download
# and rewrite the dataset every time.
if os.path.isfile(os.path.join(mr_dir, 'dataset_dict.json')):
    from datasets import load_from_disk  # local import keeps this cell self-contained

    print('MR already present at', mr_dir, '- skipping download')
    mr = load_from_disk(mr_dir)
else:
    print('Downloading MR (rotten_tomatoes)...')
    mr = load_dataset('rotten_tomatoes')
    mr.save_to_disk(mr_dir)
    print('MR saved to', mr_dir)

print('Train:', len(mr['train']), 'Val:', len(mr['validation']), 'Test:', len(mr['test']))

## 3. Configure paths
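Set the base model name, pooling strategy, data and output directories, and the path to the pretrained Jittor encoder checkpoint. The checkpoint is required: the training cell loads it unconditionally.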


In [None]:
# Encoder name handed to SBERTJittor (encoder_name) and used to name the saved checkpoint.
base_model = 'roberta-base'
# Pooling strategy handed to SBERTJittor (pooling).
pooling = 'mean'
# Root data directory; the MR dataset is read from <data_dir>/MR.
data_dir = '../data'
# NOTE(review): not referenced by the cells visible in this notebook - confirm whether it is still needed.
cache_dir = '../data/tokenized'

# Jittor checkpoint from a previously trained SRoBERTa model; it is loaded with
# jt.load before fine-tuning, so the file must exist.
jittor_ckpt = './important_checkpoints/nli/sroberta-base_best.pkl' # change this to your own path

# Directory that receives the best MR checkpoint (created on demand).
output_dir = './checkpoints/mr_demo'


## 4. Train on MR
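Fine-tunes the encoder together with a linear classifier on MR for three epochs, evaluates on the validation split after each epoch, and saves the checkpoint with the best dev accuracy.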


In [None]:
from tqdm import tqdm
from datasets import load_from_disk
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np
import jittor as jt
from jittor import nn

from model.sbert_model import SBERTJittor

def _jt_array(data, dtype):
    """Convert ``data`` to a NumPy array of ``dtype`` and wrap it with ``jt.array``."""
    as_numpy = np.asarray(data, dtype=dtype)
    return jt.array(as_numpy)

def _to_batch(batch):
    """Turn a collated batch mapping into a dict of Jittor arrays.

    ``input_ids`` and ``labels`` are converted as int32 and ``attention_mask``
    as float32. ``token_type_ids`` is converted (int32) only when the incoming
    batch contains that key.
    """
    required_fields = (
        ('input_ids', 'int32'),
        ('attention_mask', 'float32'),
        ('labels', 'int32'),
    )
    out = {field: _jt_array(batch[field], dtype) for field, dtype in required_fields}
    # Tokenizers that emit no segment ids simply leave this key out.
    if 'token_type_ids' in batch:
        out['token_type_ids'] = _jt_array(batch['token_type_ids'], 'int32')
    return out

def collate_mr(batch):
    """Stack a list of per-example dicts into a dict of NumPy arrays.

    Each field is gathered across the examples in ``batch``: ``input_ids`` and
    ``labels`` become int32 arrays, ``attention_mask`` a float32 array.
    ``token_type_ids`` (int32) is included only when the first example has it.
    """
    def _stack(field, dtype):
        # Collect one field from every example and materialise it as a single array.
        return np.asarray([example[field] for example in batch], dtype=dtype)

    out = {
        'input_ids': _stack('input_ids', np.int32),
        'attention_mask': _stack('attention_mask', np.float32),
        'labels': _stack('labels', np.int32),
    }
    first_example = batch[0]
    if 'token_type_ids' in first_example:
        out['token_type_ids'] = _stack('token_type_ids', np.int32)
    return out

def prepare_mr_dataset(split, tokenizer, max_length=128):
    """Load one split of the saved MR dataset and tokenize it.

    The dataset is read from ``<data_dir>/MR`` (``data_dir`` is a notebook-level
    variable). Every ``text`` entry is tokenized with padding and truncation to
    ``max_length``, the ``label`` column is copied to ``labels``, and the
    original columns are removed from the result.

    Args:
        split: Name of the split to load (e.g. ``'train'`` or ``'validation'``).
        tokenizer: Callable tokenizer applied to a batch of texts.
        max_length: Fixed sequence length used for padding and truncation.

    Returns:
        The mapped dataset holding the tokenizer outputs plus ``labels``.
    """
    def _encode(examples):
        encoded = tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
        encoded['labels'] = examples['label']
        return encoded

    mr_path = os.path.join(data_dir, 'MR')
    split_ds = load_from_disk(mr_path)[split]
    return split_ds.map(_encode, batched=True, remove_columns=split_ds.column_names)

# Tokenizer files are read from a local directory rather than by model name.
tokenizer_dir = './hf/roberta-base'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=True)

# Tokenize the two splits used here (train first, then validation).
train_ds, dev_ds = (
    prepare_mr_dataset(split_name, tokenizer) for split_name in ('train', 'validation')
)

# Both loaders share batch size and collate function; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, collate_fn=collate_mr)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
dev_loader = DataLoader(dev_ds, shuffle=False, **loader_kwargs)

# Build the encoder without a checkpoint, then restore weights from the Jittor file below.
model = SBERTJittor(encoder_name=base_model, pooling=pooling, head_type='none', checkpoint_path=None)

# NOTE(review): .pkl checkpoints are presumably pickle-based - only load files you trust.
payload = jt.load(jittor_ckpt)

# The checkpoint is either a bare state dict or a dict wrapping it under 'model_state'.
has_wrapped_state = isinstance(payload, dict) and 'model_state' in payload
model.load_state_dict(payload['model_state'] if has_wrapped_state else payload)

# Two-way linear classification head on top of the sentence representation.
classifier = nn.Linear(model.output_dim, 2)

# The encoder and the head are optimised together by a single Adam instance.
trainable_params = list(model.parameters()) + list(classifier.parameters())
optimizer = nn.Adam(trainable_params, lr=2e-5)
loss_fct = nn.CrossEntropyLoss()

best_acc = -1.0
for epoch in range(3):
    epoch_no = epoch + 1  # 1-based epoch number used in progress and log messages

    # ---- training pass ----
    total_loss = 0.0
    total_samples = 0
    for batch in tqdm(train_loader, desc=f'Epoch {epoch_no}'):
        jt_batch = _to_batch(batch)
        n_in_batch = jt_batch['labels'].shape[0]
        reps = model.encode(
            jt_batch['input_ids'],
            jt_batch['attention_mask'],
            jt_batch.get('token_type_ids'),
        )
        logits = classifier(reps)
        loss = loss_fct(logits, jt_batch['labels'])
        # Passing the loss to step() lets the optimizer run the backward pass and the update.
        optimizer.step(loss)
        # Weight the batch loss by batch size so the epoch figure is a per-sample average.
        total_loss += loss.item() * n_in_batch
        total_samples += n_in_batch
    print(f'Epoch {epoch_no} train loss:', total_loss / max(total_samples, 1))

    # ---- evaluation pass on the dev split ----
    model.eval()
    classifier.eval()
    correct = 0
    total = 0
    with jt.no_grad():
        for batch in dev_loader:
            jt_batch = _to_batch(batch)
            reps = model.encode(
                jt_batch['input_ids'],
                jt_batch['attention_mask'],
                jt_batch.get('token_type_ids'),
            )
            logits = classifier(reps)
            # NOTE(review): jt.argmax is expected to return an (index, value) pair; [0] keeps the indices.
            preds = jt.argmax(logits, dim=1)[0]
            is_hit = preds == jt_batch['labels']
            correct += jt.sum(is_hit).item()
            total += jt_batch['labels'].shape[0]
    acc = correct / max(total, 1) * 100
    print(f'Epoch {epoch_no} dev acc:', acc)

    # Switch back to training mode before the next epoch.
    model.train()
    classifier.train()

    # Keep only the checkpoint with the highest dev accuracy seen so far.
    if acc > best_acc:
        best_acc = acc
        os.makedirs(output_dir, exist_ok=True)
        safe_model = base_model.replace('/', '_')  # model names may contain '/'
        ckpt_path = os.path.join(output_dir, f'{safe_model}_best.pkl')
        best_state = {
            'model_state': model.state_dict(),
            'classifier_state': classifier.state_dict(),
        }
        jt.save(best_state, ckpt_path)
        print('Saved best checkpoint to', ckpt_path)

print('Best dev acc:', best_acc)