In [38]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.trainer import TrainingArguments, Trainer
import logging
from utils import metrics
import datasets
import mlflow
from utils import container, text_utils

In [39]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [41]:
pre_trained_model_name = 'bert-google-uncase-base'
logger.critical("Build pre-trained model {}".format(pre_trained_model_name))
base_pre_trained_model_path = '/home/ubuntu/likun/nlp_pretrained/{}'.format(pre_trained_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_pre_trained_model_path)

Build pre-trained model bert-google-uncase-base


In [None]:
mlflow.set_tracking_uri("http://10.10.111.130:5005")
mlflow.set_experiment("bert_classify_finetune_experiment")

mlflow_tags = {
    "paper": "How to Fine-Tune BERT for Text Classification",
    "dl_frame": "huggingface-pytorch",
    "pretrain-model": pre_trained_model_name,
    "mlflow.runName": "bert_imdb_baseline_v2"
}

In [42]:
logger.critical("Build Training and validating dataset")
dataset_args = {
    "name": "imdb",
    "data_cache_dir": f"/home/ubuntu/likun/huggingface_dataset",
    "train_size": 25000,
    "val_size": 0,
    "test_size": 25000,
    "max_length": 510
}
mlflow_tags.update(dataset_args)
dataset = datasets.load_dataset(dataset_args['name'], cache_dir=dataset_args['data_cache_dir'])

# num_labels = dataset['train'].features['label-coarse'].num_classes


Build Training and validating dataset
Reusing dataset imdb (/home/ubuntu/likun/huggingface_dataset/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)


In [43]:
print(dataset)
print(dataset['train'].features)
print('train size {}'.format(len(dataset['train'])))

print('Train dataset stat:')
text_utils.text_stat([example['text'] for example in dataset['train']])
dataset['train'][1]

print('test size {}'.format(len(dataset['test'])))
print('Test dataset stat:')
text_utils.text_stat([example['text'] for example in dataset['test']])

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
{'text': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None)}
train size 25000
Train dataset stat:
Min length: 52, Max length: 13704, Avg length: 1325.06964
test size 25000
Test dataset stat:
Min length: 32, Max length: 12988, Avg length: 1293.7924


In [44]:
# dataset = dataset.map(lambda example: {'label': example['label-coarse']}, remove_columns=['label-coarse', 'label-fine'])

In [45]:
if dataset_args['train_size'] == len(dataset['train']):
    train_dataset = dataset['train']
elif dataset_args['val_size'] != 0:
    train_dataset = dataset['train'].train_test_split(train_size=dataset_args['train_size'],
                                                      test_size=dataset_args['val_size'])
    train_dataset, val_dataset = train_dataset['train'], train_dataset['test']
else:
    train_dataset = dataset['train'].train_test_split(train_size=dataset_args['train_size'],
                                                      test_size=1)
    train_dataset, _ = train_dataset['train'], train_dataset['test']

if dataset_args['test_size'] < len(dataset['test']):
    test_dataset = dataset['test'].train_test_split(train_size=dataset_args['test_size'])
    test_dataset = test_dataset['train']
else:
    test_dataset = dataset['test']

In [46]:
def head_tail_encode(examples):
    examples['text'] = list(map(lambda t: t[:128] + t[-382:] if len(t) > 510 else t,examples['text']))
    return tokenizer(examples['text'] , max_length=dataset_args['max_length'], truncation=True, padding='max_length')

def standard_encode(examples):
    return tokenizer(examples['text'] , max_length=dataset_args['max_length'], truncation=True, padding='max_length')

In [47]:
encode_func = head_tail_encode
train_dataset = train_dataset.map(encode_func, batched=True)
test_dataset = test_dataset.map(encode_func, batched=True)
if dataset_args['val_size'] != 0:
    val_dataset = val_dataset.map(encode_func, batched=True)
else:
    val_dataset = test_dataset


HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))




In [48]:
logger.critical("Setup the training environment")
model = AutoModelForSequenceClassification.from_pretrained(base_pre_trained_model_path,
                                                           num_labels=num_labels,
                                                           output_attentions=False,
                                                           output_hidden_states=False)

Setup the training environment
Some weights of the model checkpoint at /home/ubuntu/likun/nlp_pretrained/bert-google-uncase-base were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForS

In [52]:
train_batch_size = 8
warmup_steps = int(len(train_dataset) // train_batch_size * 0.1)
training_args = TrainingArguments(
    output_dir='/home/ubuntu/likun/nlp_save_kernels/{}'.format(mlflow_tags['mlflow.runName']),  # output directory
    num_train_epochs=4,  # total number of training epochs
    per_device_train_batch_size=train_batch_size,  # batch size per device during training
    per_device_eval_batch_size=32,  # batch size for evaluation
    warmup_steps=warmup_steps,  # number of warmup steps for learning rate scheduler
    weight_decay=0.95,  # strength of weight decay
    logging_dir='/home/ubuntu/likun/nlp_training_logs/{}'.format(mlflow_tags['mlflow.runName']),  # directory for storing logs
    logging_steps=300,
    learning_rate=2e-5,
    seed=44,
    no_cuda=False,
    evaluate_during_training=True
)
train_params = {k: v for k, v in training_args.__dict__.items() if (isinstance(v, int) or isinstance(v, float)) and not isinstance(v, bool)}
trainer = Trainer(
    model=model,  # the instantiated 🤗 Transformers model to be trained
    args=training_args,  # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset,  # evaluation dataset
    compute_metrics=metrics.classify_metrics,
)

In [53]:
with mlflow.start_run():
    logger.critical("Start to train")
    train_res = trainer.train()
    
    if dataset_args['val_size'] != 0:
        logger.critical("Start to evaluate")
        eval_res = trainer.evaluate()
    
    logger.critical("Start to test")
    test_res = trainer.predict(test_dataset)
    
    # 记录实验参数
    mlflow.set_tags(mlflow_tags)
    mlflow.log_params(train_params)

    # 记录测试的评估指标
    mlflow.log_metric("train_loss", train_res.training_loss)
    if dataset_args['val_size'] != 0:
        mlflow.log_metrics(eval_res)
    mlflow.log_metrics({k.replace('eval', 'test'): v for k, v in test_res.metrics.items()})
    

Start to train


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 1.8880080342292787, 'learning_rate': 6.41025641025641e-07, 'epoch': 0.006397952655150352, 'step': 10}
{'loss': 1.8170116543769836, 'learning_rate': 1.282051282051282e-06, 'epoch': 0.012795905310300703, 'step': 20}
{'loss': 1.6331066012382507, 'learning_rate': 1.9230769230769234e-06, 'epoch': 0.019193857965451054, 'step': 30}
{'loss': 1.4644789934158324, 'learning_rate': 2.564102564102564e-06, 'epoch': 0.025591810620601407, 'step': 40}
{'loss': 1.3486561179161072, 'learning_rate': 3.205128205128206e-06, 'epoch': 0.03198976327575176, 'step': 50}
{'loss': 1.2790204882621765, 'learning_rate': 3.846153846153847e-06, 'epoch': 0.03838771593090211, 'step': 60}
{'loss': 1.1935676336288452, 'learning_rate': 4.487179487179488e-06, 'epoch': 0.044785668586052464, 'step': 70}
{'loss': 1.1019884467124939, 'learning_rate': 5.128205128205128e-06, 'epoch': 0.05118362124120281, 'step': 80}
{'loss': 1.0263014137744904, 'learning_rate': 5.769230769230769e-06, 'epoch': 0.05758157389635317, 'step': 

{'loss': 0.26144900023937223, 'learning_rate': 1.855892255892256e-05, 'epoch': 0.473448496481126, 'step': 740}
{'loss': 0.22176869660615922, 'learning_rate': 1.8525252525252526e-05, 'epoch': 0.4798464491362764, 'step': 750}
{'loss': 0.3019829779863358, 'learning_rate': 1.8491582491582495e-05, 'epoch': 0.48624440179142675, 'step': 760}
{'loss': 0.29297216087579725, 'learning_rate': 1.845791245791246e-05, 'epoch': 0.4926423544465771, 'step': 770}
{'loss': 0.1666583776473999, 'learning_rate': 1.8424242424242425e-05, 'epoch': 0.4990403071017274, 'step': 780}
{'loss': 0.335477802157402, 'learning_rate': 1.839057239057239e-05, 'epoch': 0.5054382597568778, 'step': 790}
{'loss': 0.34468733668327334, 'learning_rate': 1.8356902356902356e-05, 'epoch': 0.5118362124120281, 'step': 800}
{'loss': 0.28547153025865557, 'learning_rate': 1.8323232323232325e-05, 'epoch': 0.5182341650671785, 'step': 810}
{'loss': 0.26273577064275744, 'learning_rate': 1.828956228956229e-05, 'epoch': 0.5246321177223289, 'ste

{'loss': 0.24675323069095612, 'learning_rate': 1.6101010101010103e-05, 'epoch': 0.9404990403071017, 'step': 1470}
{'loss': 0.301469224691391, 'learning_rate': 1.606734006734007e-05, 'epoch': 0.946896992962252, 'step': 1480}
{'loss': 0.2819507151842117, 'learning_rate': 1.6033670033670034e-05, 'epoch': 0.9532949456174025, 'step': 1490}
{'loss': 0.1753104269504547, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.9596928982725528, 'step': 1500}
{'loss': 0.267755015194416, 'learning_rate': 1.5966329966329968e-05, 'epoch': 0.9660908509277031, 'step': 1510}
{'loss': 0.26680568009614947, 'learning_rate': 1.5932659932659933e-05, 'epoch': 0.9724888035828535, 'step': 1520}
{'loss': 0.3581096053123474, 'learning_rate': 1.5898989898989902e-05, 'epoch': 0.9788867562380038, 'step': 1530}
{'loss': 0.35513498783111574, 'learning_rate': 1.5865319865319868e-05, 'epoch': 0.9852847088931542, 'step': 1540}
{'loss': 0.2280543178319931, 'learning_rate': 1.5831649831649833e-05, 'epoch': 0.9916826615483045

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.22317509055137635, 'learning_rate': 1.5764309764309767e-05, 'epoch': 1.0044785668586051, 'step': 1570}
{'loss': 0.20434503853321076, 'learning_rate': 1.5730639730639732e-05, 'epoch': 1.0108765195137557, 'step': 1580}
{'loss': 0.15082778483629228, 'learning_rate': 1.5696969696969698e-05, 'epoch': 1.017274472168906, 'step': 1590}
{'loss': 0.1930029347538948, 'learning_rate': 1.5663299663299666e-05, 'epoch': 1.0236724248240563, 'step': 1600}
{'loss': 0.20182546526193618, 'learning_rate': 1.5629629629629632e-05, 'epoch': 1.0300703774792066, 'step': 1610}
{'loss': 0.153373883664608, 'learning_rate': 1.5595959595959597e-05, 'epoch': 1.036468330134357, 'step': 1620}
{'loss': 0.0989776760339737, 'learning_rate': 1.5562289562289563e-05, 'epoch': 1.0428662827895074, 'step': 1630}
{'loss': 0.13197975605726242, 'learning_rate': 1.552861952861953e-05, 'epoch': 1.0492642354446577, 'step': 1640}
{'loss': 0.19552517384290696, 'learning_rate': 1.5494949494949497e-05, 'epoch': 1.0556621880998

{'loss': 0.1970289707183838, 'learning_rate': 1.3306397306397308e-05, 'epoch': 1.471529110684581, 'step': 2300}
{'loss': 0.13032741695642472, 'learning_rate': 1.3272727272727275e-05, 'epoch': 1.4779270633397312, 'step': 2310}
{'loss': 0.23026261925697328, 'learning_rate': 1.323905723905724e-05, 'epoch': 1.4843250159948815, 'step': 2320}
{'loss': 0.14627012312412263, 'learning_rate': 1.3205387205387206e-05, 'epoch': 1.490722968650032, 'step': 2330}
{'loss': 0.17240799963474274, 'learning_rate': 1.3171717171717173e-05, 'epoch': 1.4971209213051824, 'step': 2340}
{'loss': 0.1505471259355545, 'learning_rate': 1.3138047138047138e-05, 'epoch': 1.5035188739603327, 'step': 2350}
{'loss': 0.1913965255022049, 'learning_rate': 1.3104377104377107e-05, 'epoch': 1.5099168266154832, 'step': 2360}
{'loss': 0.1654383882880211, 'learning_rate': 1.3070707070707072e-05, 'epoch': 1.5163147792706333, 'step': 2370}
{'loss': 0.18454862534999847, 'learning_rate': 1.303703703703704e-05, 'epoch': 1.52271273192578

{'loss': 0.13295522779226304, 'learning_rate': 1.084848484848485e-05, 'epoch': 1.9385796545105567, 'step': 3030}
{'loss': 0.254815936088562, 'learning_rate': 1.0814814814814816e-05, 'epoch': 1.944977607165707, 'step': 3040}
{'loss': 0.17257849872112274, 'learning_rate': 1.0781144781144781e-05, 'epoch': 1.9513755598208573, 'step': 3050}
{'loss': 0.25620238184928895, 'learning_rate': 1.0747474747474748e-05, 'epoch': 1.9577735124760078, 'step': 3060}
{'loss': 0.13200246393680573, 'learning_rate': 1.0713804713804714e-05, 'epoch': 1.964171465131158, 'step': 3070}
{'loss': 0.1859428808093071, 'learning_rate': 1.0680134680134683e-05, 'epoch': 1.9705694177863085, 'step': 3080}
{'loss': 0.122044138610363, 'learning_rate': 1.0646464646464648e-05, 'epoch': 1.9769673704414588, 'step': 3090}
{'loss': 0.16623858362436295, 'learning_rate': 1.0612794612794615e-05, 'epoch': 1.983365323096609, 'step': 3100}
{'loss': 0.14927822947502137, 'learning_rate': 1.057912457912458e-05, 'epoch': 1.9897632757517596

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.22688088119029998, 'learning_rate': 1.0511784511784513e-05, 'epoch': 2.00255918106206, 'step': 3130}
{'loss': 0.08465697318315506, 'learning_rate': 1.0478114478114478e-05, 'epoch': 2.0089571337172103, 'step': 3140}
{'loss': 0.04573826193809509, 'learning_rate': 1.0444444444444445e-05, 'epoch': 2.015355086372361, 'step': 3150}
{'loss': 0.11550106704235077, 'learning_rate': 1.041077441077441e-05, 'epoch': 2.0217530390275114, 'step': 3160}
{'loss': 0.07424029856920242, 'learning_rate': 1.037710437710438e-05, 'epoch': 2.0281509916826614, 'step': 3170}
{'loss': 0.04768117666244507, 'learning_rate': 1.0343434343434345e-05, 'epoch': 2.034548944337812, 'step': 3180}
{'loss': 0.04327450692653656, 'learning_rate': 1.030976430976431e-05, 'epoch': 2.040946896992962, 'step': 3190}
{'loss': 0.056641629338264464, 'learning_rate': 1.0276094276094277e-05, 'epoch': 2.0473448496481126, 'step': 3200}
{'loss': 0.1299564242362976, 'learning_rate': 1.0242424242424242e-05, 'epoch': 2.05374280230326

{'loss': 0.09448942542076111, 'learning_rate': 8.053872053872055e-06, 'epoch': 2.469609724888036, 'step': 3860}
{'loss': 0.035338136553764346, 'learning_rate': 8.02020202020202e-06, 'epoch': 2.476007677543186, 'step': 3870}
{'loss': 0.09118714481592179, 'learning_rate': 7.986531986531986e-06, 'epoch': 2.4824056301983366, 'step': 3880}
{'loss': 0.15406369119882585, 'learning_rate': 7.952861952861953e-06, 'epoch': 2.488803582853487, 'step': 3890}
{'loss': 0.023316648602485657, 'learning_rate': 7.91919191919192e-06, 'epoch': 2.495201535508637, 'step': 3900}
{'loss': 0.12533117681741715, 'learning_rate': 7.885521885521886e-06, 'epoch': 2.5015994881637877, 'step': 3910}
{'loss': 0.06203100979328156, 'learning_rate': 7.851851851851853e-06, 'epoch': 2.507997440818938, 'step': 3920}
{'loss': 0.0631633684039116, 'learning_rate': 7.81818181818182e-06, 'epoch': 2.5143953934740884, 'step': 3930}
{'loss': 0.024477842450141906, 'learning_rate': 7.784511784511785e-06, 'epoch': 2.5207933461292384, 'st

{'loss': 0.04652988165616989, 'learning_rate': 5.562289562289563e-06, 'epoch': 2.943058221369162, 'step': 4600}
{'loss': 0.08826788663864135, 'learning_rate': 5.528619528619529e-06, 'epoch': 2.9494561740243124, 'step': 4610}
{'loss': 0.07413851320743561, 'learning_rate': 5.494949494949495e-06, 'epoch': 2.9558541266794625, 'step': 4620}
{'loss': 0.02076859474182129, 'learning_rate': 5.461279461279462e-06, 'epoch': 2.962252079334613, 'step': 4630}
{'loss': 0.10044800639152526, 'learning_rate': 5.427609427609428e-06, 'epoch': 2.968650031989763, 'step': 4640}
{'loss': 0.06044094562530518, 'learning_rate': 5.3939393939393945e-06, 'epoch': 2.9750479846449136, 'step': 4650}
{'loss': 0.04366032183170319, 'learning_rate': 5.3602693602693615e-06, 'epoch': 2.981445937300064, 'step': 4660}
{'loss': 0.06885612905025482, 'learning_rate': 5.326599326599327e-06, 'epoch': 2.987843889955214, 'step': 4670}
{'loss': 0.10312979072332382, 'learning_rate': 5.292929292929293e-06, 'epoch': 2.9942418426103647, 

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.09409982562065125, 'learning_rate': 5.259259259259259e-06, 'epoch': 3.000639795265515, 'step': 4690}
{'loss': 0.03151584565639496, 'learning_rate': 5.225589225589226e-06, 'epoch': 3.0070377479206654, 'step': 4700}
{'loss': 0.0032574862241744997, 'learning_rate': 5.191919191919193e-06, 'epoch': 3.013435700575816, 'step': 4710}
{'loss': 0.02173630744218826, 'learning_rate': 5.158249158249159e-06, 'epoch': 3.019833653230966, 'step': 4720}
{'loss': 0.1328670486807823, 'learning_rate': 5.124579124579125e-06, 'epoch': 3.0262316058861165, 'step': 4730}
{'loss': 0.03928768932819367, 'learning_rate': 5.090909090909091e-06, 'epoch': 3.0326295585412666, 'step': 4740}
{'loss': 0.05909337103366852, 'learning_rate': 5.0572390572390574e-06, 'epoch': 3.039027511196417, 'step': 4750}
{'loss': 0.0026448220014572144, 'learning_rate': 5.023569023569024e-06, 'epoch': 3.0454254638515676, 'step': 4760}
{'loss': 0.031428244709968564, 'learning_rate': 4.98989898989899e-06, 'epoch': 3.051823416506717

{'loss': 0.1120539665222168, 'learning_rate': 2.8013468013468016e-06, 'epoch': 3.4676903390914906, 'step': 5420}
{'loss': 0.03635291159152985, 'learning_rate': 2.7676767676767678e-06, 'epoch': 3.474088291746641, 'step': 5430}
{'loss': 0.009180432558059693, 'learning_rate': 2.7340067340067344e-06, 'epoch': 3.480486244401791, 'step': 5440}
{'loss': 0.004460978507995606, 'learning_rate': 2.7003367003367e-06, 'epoch': 3.4868841970569417, 'step': 5450}
{'loss': 0.024199698865413666, 'learning_rate': 2.666666666666667e-06, 'epoch': 3.4932821497120923, 'step': 5460}
{'loss': 0.0017611771821975709, 'learning_rate': 2.3973063973063978e-06, 'epoch': 3.5444657709532947, 'step': 5540}
{'loss': 0.14871213734149932, 'learning_rate': 2.363636363636364e-06, 'epoch': 3.5508637236084453, 'step': 5550}
{'loss': 0.02596128135919571, 'learning_rate': 2.32996632996633e-06, 'epoch': 3.557261676263596, 'step': 5560}
{'loss': 0.004273319244384765, 'learning_rate': 2.2962962962962964e-06, 'epoch': 3.56365962891

{'loss': 0.046607944369316104, 'learning_rate': 1.0774410774410776e-07, 'epoch': 3.9795265515035187, 'step': 6220}
{'loss': 0.03854084014892578, 'learning_rate': 7.407407407407409e-08, 'epoch': 3.9859245041586693, 'step': 6230}
{'loss': 0.06943604648113251, 'learning_rate': 4.040404040404041e-08, 'epoch': 3.9923224568138194, 'step': 6240}
{'loss': 0.08676650822162628, 'learning_rate': 6.734006734006735e-09, 'epoch': 3.99872040946897, 'step': 6250}


Start to test






HBox(children=(FloatProgress(value=0.0, description='Prediction', max=391.0, style=ProgressStyle(description_w…




In [54]:
logger.critical("Save the model and config")
trainer.save_model()
tokenizer.save_vocabulary(training_args.output_dir)
logger.critical("Experiment Finnished")

Save the model and config
Experiment Finnished
