In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForMaskedLM, RobertaTokenizerFast
from transformers.trainer import TrainingArguments, Trainer
import logging
from utils import metrics
import datasets
import mlflow
from utils import container, text_utils
import tqdm

In [2]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [3]:
pre_trained_model_name = 'roberta-base'
logger.critical("Build pre-trained model {}".format(pre_trained_model_name))
base_pre_trained_model_path = '/home/ubuntu/likun/nlp_pretrained/{}'.format(pre_trained_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_pre_trained_model_path)
# tokenizer = RobertaTokenizerFast.from_pretrained(base_pre_trained_model_path)

Build pre-trained model roberta-base


In [4]:
mlflow.set_tracking_uri("http://10.10.111.130:5005")
mlflow.set_experiment("few-shot-finetune")

mlflow_tags = {
    "paper": "Diverse Few-Shot Text Classification with Multiple Metrics",
    "dl_frame": "huggingface-pytorch",
    "pretrain-model": pre_trained_model_name,
    "mlflow.runName": "few-shot-prompt-arsc"
}

In [10]:
logger.critical("Build Training and validating dataset")
dataset_args = {
    "dataset_name": "ARSC",
    "data_cache_dir": "/home/ubuntu/likun/huggingface_dataset",
    "train_size": 192,
    "per_class_train_size": 32,
    "val_size": 200,
    "test_size": 500,
    "max_length": 100,
    "shuffle": True,
    "val_source": "no"  # no, trian, val, test
}
mlflow_tags.update(dataset_args)
# dataset = datasets.load_dataset(dataset_args['dataset_name'],cache_dir=dataset_args['data_cache_dir'])
dataset = datasets.load_dataset('csv', data_files={'train': ['ARSC_train.csv'], 'test': ['ARSC_test.csv']})
# dataset = datasets.load_dataset(dataset_args['dataset_name'], dataset_args['dataset_second_name'],cache_dir=dataset_args['data_cache_dir'])
# num_labels = dataset['train'].features['label'].num_classes
num_labels = 2

Build Training and validating dataset
Using custom data configuration default
Reusing dataset csv (/home/ubuntu/.cache/huggingface/datasets/csv/default-d2f27f1c83bf5a44/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2)


In [11]:
text_col_name = 'text'
print(dataset)
print(dataset['train'].features)
print('train size {}'.format(len(dataset['train'])))

print('Train dataset stat:')
text_utils.text_stat([example[text_col_name] for example in dataset['train']])
dataset['train'][1]

print('test size {}'.format(len(dataset['test'])))
print('Test dataset stat:')
text_utils.text_stat([example[text_col_name] for example in dataset['test']])

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'domain', 'task_threshold', 'use'],
        num_rows: 119745
    })
    test: Dataset({
        features: ['text', 'label', 'domain', 'task_threshold', 'use'],
        num_rows: 18627
    })
})
{'text': Value(dtype='string', id=None), 'label': Value(dtype='int64', id=None), 'domain': Value(dtype='string', id=None), 'task_threshold': Value(dtype='string', id=None), 'use': Value(dtype='string', id=None)}
train size 119745
Train dataset stat:
Min length: 3, Max length: 19934, Avg length: 526.0338385736356
test size 18627
Test dataset stat:
Min length: 5, Max length: 14475, Avg length: 600.5344392548451


In [12]:
dataset['test'][0]

{'text': "old dvd 1997 looks better than new hd version . i've compared two , first dvd good color saturation , dark blacks looked right heavy red brown palette . hd version blue-green haze covering everything picture too bright . they also removed film grain made image look smooth creamy . sound quality better now , though , image course higher resolution .",
 'label': -1,
 'domain': 'dvd',
 'task_threshold': 't4',
 'use': 'dev'}

In [7]:
def label_select(dataset, per_label_num=32, label_name='label', label_is_index=True):
    lds = []
    for i, name in enumerate(dataset.features[label_name].names):
        label = i if label_is_index else name
        ld = dataset.filter(lambda example: example[label_name]==label).shuffle().select(range(per_label_num))
        lds.append(ld)
    ls = datasets.concatenate_datasets(lds)
    return ls
select_train_dataset = label_select(dataset['train'], dataset_args['per_class_train_size'], label_name='label-coarse')
dataset['train'] = select_train_dataset

Loading cached processed dataset at /home/ubuntu/likun/huggingface_dataset/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-b246349c0734d46b.arrow
Loading cached shuffled indices for dataset at /home/ubuntu/likun/huggingface_dataset/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-97391c367abf3261.arrow
Loading cached processed dataset at /home/ubuntu/likun/huggingface_dataset/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-0378c94e6cde7404.arrow
Loading cached processed dataset at /home/ubuntu/likun/huggingface_dataset/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-033bf1ce24c5d8a3.arrow
Loading cached processed dataset at /home/ubuntu/likun/huggingface_dataset/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-e2958d0a3396289c.arrow
Loading cached processed dataset at /home/ubuntu/

In [8]:
if dataset_args['shuffle']:
    dataset = dataset.shuffle()

### 将文本变换成prompt-based

In [9]:
template = '{}:{}'
mask_token = '<mask>'
label_names = dataset['train'].features['label-coarse'].names
dataset = dataset.map(lambda example: {'labels': template.format(label_names[example['label-coarse']], example['text'])})
dataset = dataset.map(lambda example: {'text': template.format(mask_token, example['text'])})
dataset = dataset.map(lambda example: example, remove_columns=['label-coarse', 'label-fine'])

HBox(children=(FloatProgress(value=0.0, max=192.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=192.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=192.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [10]:
# dataset = dataset.map(lambda example: {'label': example['label-coarse']}, remove_columns=['label-coarse', 'label-fine'])
dataset['train']['text'][0]

'<mask>:What presidential administration challenged Americans to explore The New Frontier ?'

In [11]:
if dataset_args['val_source'] == 'no':
    train_dataset = dataset['train'].select(range(dataset_args['train_size']))
    test_dataset = dataset['test'].select(range(dataset_args['test_size']))
    val_dataset = None
elif dataset_args['val_source'] == 'train':
    train_dataset = dataset['train'].train_test_split(train_size=dataset_args['train_size'],
                                                      test_size=dataset_args['val_size'])
    train_dataset, val_dataset = train_dataset['train'], train_dataset['test']
    test_dataset = dataset['test'].select(range(dataset_args['test_size']))
elif dataset_args['val_source'] == 'test':
    train_dataset = dataset['train'].select(range(dataset_args['train_size']))
    val_dataset = dataset['test'].select(range(dataset_args['val_size']))
    test_dataset = dataset['test'].select(range(dataset_args['test_size']))
elif dataset_args['val_source'] == 'val':
    train_dataset = dataset['train'].select(range(dataset_args['train_size']))
    val_dataset = dataset['validation'].select(range(dataset_args['val_size']))
    test_dataset = dataset['test'].select(range(dataset_args['test_size']))
else:
    raise Exception("Val source should be 'no', 'train', 'test', 'val' ")

### 文本编码

In [12]:
def head_tail_encode(examples):
    examples['text'] = list(map(lambda t: t[:128] + t[-382:] if len(t) > 510 else t,examples['text']))
    return tokenizer(examples['text'] , max_length=dataset_args['max_length'], truncation=True, padding='max_length')

def standard_encode(examples):
    return tokenizer(examples['text'] , max_length=dataset_args['max_length'], truncation=True, padding='max_length')

def mask_lm_encode(examples):
    res = tokenizer(examples['text'] , max_length=dataset_args['max_length'], truncation=True, padding='max_length')
    res['labels'] = tokenizer(examples['labels'] , max_length=dataset_args['max_length'], truncation=True, padding='max_length')['input_ids']
    return res

In [13]:
encode_func = mask_lm_encode
train_dataset = train_dataset.map(encode_func, batched=True)

if dataset_args['val_source'] != 'no':
    val_dataset = val_dataset.map(encode_func, batched=True)

test_dataset = test_dataset.map(encode_func, batched=True)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [14]:
logger.critical("Setup the training environment")
model = AutoModelForMaskedLM.from_pretrained(base_pre_trained_model_path)
# model = AutoModelForSequenceClassification.from_pretrained(base_pre_trained_model_path,
#                                                            num_labels=num_labels,
#                                                            output_attentions=False,
#                                                            output_hidden_states=False)
model.config.return_dict = True

Setup the training environment
Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at /home/ubuntu/likun/nlp_pretrained/roberta-base and are newly initialized: ['lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
def mask_predict_metrics(predict_res):
    mask_id = tokenizer.get_vocab()[mask_token]
    mask_index = [ids.index(mask_id) for ids in test_dataset['input_ids']]
    predict_ids = predict_res.predictions.argmax(2)
    y_predict = predict_ids[range(len(mask_index)), mask_index]
    y_true = predict_res.label_ids[range(len(mask_index)), mask_index]
    return metrics.base_classify_metrics(y_true, y_predict)

def iter_test_eval(test_dataset, chunk_size=32):
    start_index = 0
    end_index = 0
    mask_id = tokenizer.get_vocab()[mask_token]
    y_predict = []
    y_true = []
    while end_index < len(test_dataset):
        start_index = end_index
        end_index = min(end_index + chunk_size, len(test_dataset))
        select_dataset = test_dataset.select(range(start_index, end_index))
        predict_res = trainer.predict(select_dataset)
        mask_index = [ids.index(mask_id) for ids in select_dataset['input_ids']]
        predict_ids = predict_res.predictions.argmax(2)
        y_predict.extend(predict_ids[range(len(mask_index)), mask_index])
        y_true.extend(predict_res.label_ids[range(len(mask_index)), mask_index])
        print("Process end index: {}".format(end_index))
    mres = {'eval_{}'.format(k): v for k, v in metrics.base_classify_metrics(y_true, y_predict).items()}
    return mres

In [16]:
num_train_epochs = 100
train_batch_size = 16
if dataset_args['val_source'] == 'no':
    eval_steps = 1000000
else:
    eval_steps = 30
warmup_steps = int(len(train_dataset) * num_train_epochs // train_batch_size * 0.1)
training_args = TrainingArguments(
    output_dir='/home/ubuntu/likun/nlp_save_kernels/{}'.format(mlflow_tags['mlflow.runName']),  # output directory
    num_train_epochs=num_train_epochs,  # total number of training epochs
    per_device_train_batch_size=train_batch_size,  # batch size per device during training
    per_device_eval_batch_size=16,  # batch size for evaluation
    warmup_steps=warmup_steps,  # number of warmup steps for learning rate scheduler
    weight_decay=0.99,  # strength of weight decay
    logging_dir='/home/ubuntu/likun/nlp_training_logs/{}'.format(mlflow_tags['mlflow.runName']),  # directory for storing logs
    logging_steps=10,
    learning_rate=5e-5,
    seed=44,
    no_cuda=False,
    eval_steps=eval_steps,
    evaluate_during_training=True
)
train_params = {k: v for k, v in training_args.__dict__.items() if (isinstance(v, int) or isinstance(v, float)) and not isinstance(v, bool)}
trainer = Trainer(
    model=model,  # the instantiated 🤗 Transformers model to be trained
    args=training_args,  # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset,  # evaluation dataset
    compute_metrics=None,
)

In [17]:
with mlflow.start_run():
    logger.critical("Start to train")
    train_res = trainer.train()
    
    if dataset_args['val_source'] != 'no':
        logger.critical("Start to evaluate")
        eval_res = trainer.evaluate()
    
    logger.critical("Start to test")
#     test_res = trainer.predict(test_dataset).metrics
    test_res = iter_test_eval(test_dataset)
    
    # 记录实验参数
    mlflow.set_tags(mlflow_tags)
    mlflow.log_params(train_params)

    # 记录测试的评估指标
    mlflow.log_metric("train_loss", train_res.training_loss)
    if dataset_args['val_source'] != 'no':
        mlflow.log_metrics(eval_res)
    mlflow.log_metrics({k.replace('eval', 'test'): v for k, v in test_res.items()})
    

Start to train


HBox(children=(FloatProgress(value=0.0, description='Epoch', style=ProgressStyle(description_width='initial'))…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…






HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 22.11819076538086, 'learning_rate': 4.166666666666667e-06, 'epoch': 1.6666666666666665, 'step': 10}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 14.96749668121338, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.3333333333333335, 'step': 20}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 12.597815895080567, 'learning_rate': 1.25e-05, 'epoch': 5.0, 'step': 30}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 11.955266666412353, 'learning_rate': 1.6666666666666667e-05, 'epoch': 6.666666666666667, 'step': 40}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 11.306912994384765, 'learning_rate': 2.0833333333333336e-05, 'epoch': 8.333333333333334, 'step': 50}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 10.384649562835694, 'learning_rate': 2.5e-05, 'epoch': 10.0, 'step': 60}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 9.055826950073243, 'learning_rate': 2.916666666666667e-05, 'epoch': 11.666666666666666, 'step': 70}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 8.209088134765626, 'learning_rate': 3.3333333333333335e-05, 'epoch': 13.333333333333334, 'step': 80}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 7.542895698547364, 'learning_rate': 3.7500000000000003e-05, 'epoch': 15.0, 'step': 90}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 6.875672674179077, 'learning_rate': 4.166666666666667e-05, 'epoch': 16.666666666666668, 'step': 100}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 6.190883588790894, 'learning_rate': 4.5833333333333334e-05, 'epoch': 18.333333333333332, 'step': 110}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 5.488639068603516, 'learning_rate': 5e-05, 'epoch': 20.0, 'step': 120}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 4.746929550170899, 'learning_rate': 4.8958333333333335e-05, 'epoch': 21.666666666666668, 'step': 130}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 4.016482186317444, 'learning_rate': 4.791666666666667e-05, 'epoch': 23.333333333333332, 'step': 140}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 3.3163510084152223, 'learning_rate': 4.6875e-05, 'epoch': 25.0, 'step': 150}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 2.637743520736694, 'learning_rate': 4.5833333333333334e-05, 'epoch': 26.666666666666668, 'step': 160}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 2.008706068992615, 'learning_rate': 4.4791666666666673e-05, 'epoch': 28.333333333333332, 'step': 170}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 1.4475872039794921, 'learning_rate': 4.375e-05, 'epoch': 30.0, 'step': 180}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.9736446261405944, 'learning_rate': 4.270833333333333e-05, 'epoch': 31.666666666666668, 'step': 190}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.6205320417881012, 'learning_rate': 4.166666666666667e-05, 'epoch': 33.333333333333336, 'step': 200}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.37580873966217043, 'learning_rate': 4.0625000000000005e-05, 'epoch': 35.0, 'step': 210}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.231853711605072, 'learning_rate': 3.958333333333333e-05, 'epoch': 36.666666666666664, 'step': 220}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.14928898885846137, 'learning_rate': 3.854166666666667e-05, 'epoch': 38.333333333333336, 'step': 230}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.10910076647996902, 'learning_rate': 3.7500000000000003e-05, 'epoch': 40.0, 'step': 240}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.08364971205592156, 'learning_rate': 3.6458333333333336e-05, 'epoch': 41.666666666666664, 'step': 250}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.074151211977005, 'learning_rate': 3.541666666666667e-05, 'epoch': 43.333333333333336, 'step': 260}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.06287065446376801, 'learning_rate': 3.4375e-05, 'epoch': 45.0, 'step': 270}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.05747799202799797, 'learning_rate': 3.3333333333333335e-05, 'epoch': 46.666666666666664, 'step': 280}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.05354461558163166, 'learning_rate': 3.229166666666667e-05, 'epoch': 48.333333333333336, 'step': 290}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04754711128771305, 'learning_rate': 3.125e-05, 'epoch': 50.0, 'step': 300}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.046015160903334615, 'learning_rate': 3.0208333333333334e-05, 'epoch': 51.666666666666664, 'step': 310}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04523119889199734, 'learning_rate': 2.916666666666667e-05, 'epoch': 53.333333333333336, 'step': 320}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04421329572796821, 'learning_rate': 2.8125000000000003e-05, 'epoch': 55.0, 'step': 330}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04427529312670231, 'learning_rate': 2.7083333333333332e-05, 'epoch': 56.666666666666664, 'step': 340}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04125856384634972, 'learning_rate': 2.604166666666667e-05, 'epoch': 58.333333333333336, 'step': 350}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04272979088127613, 'learning_rate': 2.5e-05, 'epoch': 60.0, 'step': 360}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04213319011032581, 'learning_rate': 2.3958333333333334e-05, 'epoch': 61.666666666666664, 'step': 370}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04052279107272625, 'learning_rate': 2.2916666666666667e-05, 'epoch': 63.333333333333336, 'step': 380}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04173435792326927, 'learning_rate': 2.1875e-05, 'epoch': 65.0, 'step': 390}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04115041643381119, 'learning_rate': 2.0833333333333336e-05, 'epoch': 66.66666666666667, 'step': 400}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04033973775804043, 'learning_rate': 1.9791666666666665e-05, 'epoch': 68.33333333333333, 'step': 410}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04165144860744476, 'learning_rate': 1.8750000000000002e-05, 'epoch': 70.0, 'step': 420}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.039677802473306656, 'learning_rate': 1.7708333333333335e-05, 'epoch': 71.66666666666667, 'step': 430}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.041812087222933766, 'learning_rate': 1.6666666666666667e-05, 'epoch': 73.33333333333333, 'step': 440}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.039652717486023906, 'learning_rate': 1.5625e-05, 'epoch': 75.0, 'step': 450}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.03991901203989982, 'learning_rate': 1.4583333333333335e-05, 'epoch': 76.66666666666667, 'step': 460}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.04038015641272068, 'learning_rate': 1.3541666666666666e-05, 'epoch': 78.33333333333333, 'step': 470}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.0398643184453249, 'learning_rate': 1.25e-05, 'epoch': 80.0, 'step': 480}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.03887099735438824, 'learning_rate': 1.1458333333333333e-05, 'epoch': 81.66666666666667, 'step': 490}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.040584278479218486, 'learning_rate': 1.0416666666666668e-05, 'epoch': 83.33333333333333, 'step': 500}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.039634499326348306, 'learning_rate': 9.375000000000001e-06, 'epoch': 85.0, 'step': 510}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.03984738327562809, 'learning_rate': 8.333333333333334e-06, 'epoch': 86.66666666666667, 'step': 520}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.03949070945382118, 'learning_rate': 7.2916666666666674e-06, 'epoch': 88.33333333333333, 'step': 530}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.03881240412592888, 'learning_rate': 6.25e-06, 'epoch': 90.0, 'step': 540}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.03928893655538559, 'learning_rate': 5.208333333333334e-06, 'epoch': 91.66666666666667, 'step': 550}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.0404557004570961, 'learning_rate': 4.166666666666667e-06, 'epoch': 93.33333333333333, 'step': 560}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.0385601606220007, 'learning_rate': 3.125e-06, 'epoch': 95.0, 'step': 570}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.039016172662377356, 'learning_rate': 2.0833333333333334e-06, 'epoch': 96.66666666666667, 'step': 580}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

{'loss': 0.0395346038043499, 'learning_rate': 1.0416666666666667e-06, 'epoch': 98.33333333333333, 'step': 590}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=6.0, style=ProgressStyle(description_widt…

Start to test


{'loss': 0.03899840898811817, 'learning_rate': 0.0, 'epoch': 100.0, 'step': 600}




HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 32


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 64


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 96


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 128


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 160


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 192


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 224


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 256


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 288


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 320


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 352


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 384


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 416


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 448


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 480


HBox(children=(FloatProgress(value=0.0, description='Prediction', max=1.0, style=ProgressStyle(description_wid…


Process end index: 500


In [19]:
# logger.critical("Save the model and config")
# trainer.save_model()
# tokenizer.save_vocabulary(training_args.output_dir)
# logger.critical("Experiment Finnished")
test_res

{'eval_error_rate': 0.16600000000000004,
 'eval_accuracy': 0.834,
 'eval_precision': 0.7850103190604076,
 'eval_recall': 0.8608103391071836,
 'eval_f1_score': 0.7745530999472262}