# SBERT-Jittor Downstream Demo

This notebook loads a pretrained Jittor checkpoint and runs MR/SST downstream tasks.


In [None]:
from argparse import Namespace
from pathlib import Path
import jittor as jt

from model.sbert_model import SBERTJittor
from utils.download_data import download_mr, download_glue
from training.mr.train_mr import train as train_mr
from training.sst.train_sst import train as train_sst


In [None]:
# Root directory for the downstream-task data; it is passed as `data_dir`
# to both the MR and SST training configurations further down.
data_dir = './data'
# The download helpers are left commented out; uncomment them to fetch the
# datasets (presumably only needed on first use -- confirm against
# utils.download_data).
# download_mr(data_dir)
# download_glue(data_dir)  # SST-2


In [None]:
# NOTE(review): this cell repeats the `data_dir` assignment of the previous
# cell; the value is identical, so re-running it is harmless.
data_dir = './data'
# download_mr(data_dir)   # run once
# download_glue(data_dir)  # SST-2, run once (download_glue is the helper imported above; download_sst is not imported)


In [None]:
# Load the pretrained SBERT model. With return_tokenizer=True the call is
# unpacked into the model, its tokenizer and the directory the files live in.
model, tokenizer, repo_dir = SBERTJittor.from_pretrained(
    'Kyle-han/roberta-base-nli-mean-tokens',
    return_tokenizer=True,
)
repo_dir = Path(repo_dir)

# Locate the Jittor checkpoint inside the repo directory. Sorting makes the
# choice deterministic when several '*.pkl' files exist (glob order depends on
# the filesystem), and an explicit error replaces the bare StopIteration that
# next() would raise when no checkpoint is present.
pkl_files = sorted(repo_dir.glob('*.pkl'))
if not pkl_files:
    raise FileNotFoundError(
        f"No '*.pkl' checkpoint found in {repo_dir}"
    )
ckpt = pkl_files[0]

# Switch Jittor to GPU execution when a CUDA device is available.
if jt.has_cuda:
    jt.flags.use_cuda = 1


In [None]:
# MR downstream: assemble the run configuration group by group, then train.
_mr_repo = str(repo_dir)  # the same path serves as base model and tokenizer dir

# Model / checkpoint selection.
_mr_model = {
    'base_model': _mr_repo,
    'tokenizer_dir': _mr_repo,
    'encoder_checkpoint': None,
    'jittor_checkpoint': str(ckpt),
    'pooling': 'mean',
}

# Dataset location and tokenization cache.
_mr_data = {
    'data_dir': data_dir,
    'cache_dir': None,
    'overwrite_cache': False,
    'tokenize_batch_size': 1024,
}

# Optimisation and runtime settings.
_mr_train = {
    'batch_size': 32,
    'eval_batch_size': 32,
    'epochs': 1,
    'lr': 2e-5,
    'warmup_ratio': 0.1,
    'max_length': 128,
    'num_workers': 4,
    'use_cuda': jt.has_cuda,
    'num_labels': 2,
    'train_encoder': False,
}

# Logging cadence and output location.
_mr_output = {
    'log_steps': 100,
    'eval_steps': 500,
    'output_dir': './checkpoints/mr_debug',
    'run_name': None,
}

# Groups are unpacked in the original keyword order.
mr_args = Namespace(**_mr_model, **_mr_data, **_mr_train, **_mr_output)

train_mr(mr_args)


In [None]:
# SST downstream: collect every setting in one mapping, wrap it in a
# Namespace, and hand it to the SST training entry point.
_sst_repo = str(repo_dir)  # the same path serves as base model and tokenizer dir

_sst_settings = {
    # model / checkpoint
    'base_model': _sst_repo,
    'tokenizer_dir': _sst_repo,
    'encoder_checkpoint': None,
    'jittor_checkpoint': str(ckpt),
    'pooling': 'mean',
    # data and cache
    'data_dir': data_dir,
    'cache_dir': None,
    'overwrite_cache': False,
    'tokenize_batch_size': 1024,
    # optimisation and runtime
    'batch_size': 32,
    'eval_batch_size': 32,
    'epochs': 1,
    'lr': 2e-5,
    'warmup_ratio': 0.1,
    'max_length': 128,
    'num_workers': 4,
    'use_cuda': jt.has_cuda,
    'num_labels': 2,
    'train_encoder': False,
    # logging and output
    'log_steps': 100,
    'eval_steps': 500,
    'output_dir': './checkpoints/sst_debug',
    'run_name': None,
}
sst_args = Namespace(**_sst_settings)

train_sst(sst_args)
