# Pretrain

In [1]:
PRETRAIN = "brt"
SLUG = "bs2"

In [2]:
model_name = dict(
    brt="bert-base-uncased",
    rbt="roberta-base"
)[PRETRAIN]
model_name

'bert-base-uncased'

In [3]:
from forgebox.imports import *
from transformers import (AutoModel,AutoModelForMaskedLM, 
                          AutoTokenizer, LineByLineTextDataset,
                          DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)

In [4]:
data = pd.concat(list(map(pd.read_csv,["train.csv", "test.csv"])))

Create a dataset based on purely text data

In [5]:
data['excerpt'] = data['excerpt'].apply(lambda x: x.replace('\n',''))

text  = '\n'.join(data.excerpt.tolist())

with open('text.txt','w') as f:
    f.write(text)

In [6]:
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.save_pretrained(f'./{PRETRAIN}_{SLUG}_pre');

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Pretrain use Masked Language Modeling

In [7]:
train_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="text.txt", #mention train text file here
    block_size=256)

valid_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="text.txt", #mention valid text file here
    block_size=256)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir=f"./weights/{PRETRAIN}_rbt_{SLUG}", #select model path for checkpoint
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=18,
    evaluation_strategy= 'steps',
    save_total_limit=2,
    eval_steps=250,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    load_best_model_at_end =True,
    prediction_loss_only=True,
    report_to = "none")

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset)



In [8]:
trainer.train()

***** Running training *****
  Num examples = 2841
  Num Epochs = 5
  Instantaneous batch size per device = 12
  Total train batch size (w. parallel, distributed & accumulation) = 24
  Gradient Accumulation steps = 1
  Total optimization steps = 595


Step,Training Loss,Validation Loss
250,No log,1.738972
500,1.877700,1.664761


***** Running Evaluation *****
  Num examples = 2841
  Batch size = 36
Saving model checkpoint to ./weights/brt_rbt_bs2/checkpoint-250
Configuration saved in ./weights/brt_rbt_bs2/checkpoint-250/config.json
Model weights saved in ./weights/brt_rbt_bs2/checkpoint-250/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 2841
  Batch size = 36
Saving model checkpoint to ./weights/brt_rbt_bs2/checkpoint-500
Configuration saved in ./weights/brt_rbt_bs2/checkpoint-500/config.json
Model weights saved in ./weights/brt_rbt_bs2/checkpoint-500/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./weights/brt_rbt_bs2/checkpoint-500 (score: 1.6647613048553467).


TrainOutput(global_step=595, training_loss=1.8641962644432772, metrics={'train_runtime': 264.9274, 'train_samples_per_second': 53.618, 'train_steps_per_second': 2.246, 'total_flos': 1861958654433900.0, 'train_loss': 1.8641962644432772, 'epoch': 5.0})

In [10]:
trainer.save_model(f"weights/pre_{PRETRAIN}{SLUG}")

Saving model checkpoint to weights/pre_brtbs2
Configuration saved in weights/pre_brtbs2/config.json
Model weights saved in weights/pre_brtbs2/pytorch_model.bin
