In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Based on: https://gmihaila.github.io/tutorial_notebooks/pretrain_transformers_pytorch/

In [None]:
import io
import os
import math
import torch
import warnings
from tqdm.notebook import tqdm
from ml_things import plot_dict, fix_text
from transformers import (
                          CONFIG_MAPPING,
                          MODEL_FOR_MASKED_LM_MAPPING,
                          MODEL_FOR_CAUSAL_LM_MAPPING,
                          PreTrainedTokenizer,
                          TrainingArguments,
                          AutoConfig,
                          LongformerConfig,
                          AutoTokenizer,
                          AutoModelWithLMHead,
                          AutoModelForCausalLM,
                          AutoModelForMaskedLM,
                          LineByLineTextDataset,
                          TextDataset,
                          DataCollatorForLanguageModeling,
                          DataCollatorForWholeWordMask,
                          DataCollatorForPermutationLanguageModeling,
                          PretrainedConfig,
                          Trainer,
                          set_seed,
                          )

# Suppress deprecation-related warning categories to keep cell output readable.
for noisy_category in (DeprecationWarning, FutureWarning):
  warnings.filterwarnings('ignore', category=noisy_category)

# Fix the random seed so repeated runs are reproducible.
set_seed(4444)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class ModelDataArguments(object):
  r"""Model and data configuration needed to perform pre-training.

  The constructor only validates and stores values; the helper functions in
  this notebook (`get_model_config`, `get_tokenizer`, `get_model`,
  `get_dataset`, `get_collator`) read them later.

  Arguments:

    train_data_file (:obj:`str`, `optional`):
      Path of the text file used by `get_dataset` when `evaluate=False`.

    eval_data_file (:obj:`str`, `optional`):
      Path of the text file used by `get_dataset` when `evaluate=True`.

    line_by_line (:obj:`bool`, `optional`, defaults to :obj:`False`):
      When true `get_dataset` builds a `LineByLineTextDataset`, otherwise a
      `TextDataset`.

    mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
      Forwarded as the `mlm` flag of the data collator in `get_collator`.

    mlm_probability (:obj:`float`, `optional`, defaults to :obj:`0.15`):
      Forwarded as `mlm_probability` of the data collator in `get_collator`.

    whole_word_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
      Stored only; no helper in this notebook reads it.

    plm_probability (:obj:`float`, `optional`, defaults to :obj:`1/6`):
      Stored only; no helper in this notebook reads it.

    max_span_length (:obj:`int`, `optional`, defaults to :obj:`5`):
      Stored only; no helper in this notebook reads it.

    block_size (:obj:`int`, `optional`, defaults to :obj:`-1`): 
      It refers to the window size that is moved across the text file. 
      Set to -1 to use maximum allowed length. `get_tokenizer` overwrites
      this attribute with the resolved value.

    overwrite_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
      Forwarded to `TextDataset` by `get_dataset` when `line_by_line` is
      false.

    model_type (:obj:`str`, `optional`): 
      Type of model used: bert, roberta, gpt2. Only read by
      `get_model_config` as a `CONFIG_MAPPING` key when neither
      `model_config_name` nor `model_name_or_path` is given.
      More details: https://huggingface.co/transformers/pretrained_models.html

    model_config_name (:obj:`str`, `optional`):
      Name or path handed to `AutoConfig.from_pretrained`; takes precedence
      over `model_name_or_path` in `get_model_config`.

    tokenizer_name: (:obj:`str`, `optional`)
      Tokenizer used to process data for training the model. 
      It usually has same name as model_name_or_path: bert-base-cased, 
      roberta-base, gpt2 etc. Takes precedence over `model_name_or_path`
      in `get_tokenizer`.

    model_name_or_path (:obj:`str`, `optional`): 
      Path to existing transformers model or name of 
      transformer model to be used: bert-base-cased, roberta-base, gpt2 etc. 
      More details: https://huggingface.co/transformers/pretrained_models.html

    model_cache_dir (:obj:`str`, `optional`): 
      Path to cache files to save time when re-running code.

    ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
      Forwarded to `from_pretrained` by `get_model`.

  Raises:

    ValueError: when neither `tokenizer_name` nor `model_name_or_path` is
      given, because no tokenizer could be loaded.
  """

  def __init__(self, train_data_file=None, eval_data_file=None, 
               line_by_line=False, mlm=False, mlm_probability=0.15, 
               whole_word_mask=False, plm_probability=float(1/6), 
               max_span_length=5, block_size=-1, overwrite_cache=False, 
               model_type=None, model_config_name=None, tokenizer_name=None, 
               model_name_or_path=None, model_cache_dir=None, ignore_mismatched_sizes=False):
    
    # Check if a new model will be loaded from scratch.
    if not any([model_config_name, model_name_or_path]):
      # Shorten the warning layout to `Category: message`. Note that this
      # replaces `warnings.formatwarning` for the whole process, not only
      # for the warning issued below.
      warnings.formatwarning = lambda message,category,*args,**kwargs: \
                               '%s: %s\n' % (category.__name__, message)
      # Display warning.
      warnings.warn('You are planning to train a model from scratch!')

    # Check if a new tokenizer wants to be loaded.
    if not any([tokenizer_name, model_name_or_path]):
      # Can't train tokenizer from scratch here! Raise error.
      raise ValueError('You want to train tokenizer from scratch! ' \
                    'That is not possible yet! You can train your own ' \
                    'tokenizer separately and use path here to load it!')
      
    # Set all data related arguments.
    self.train_data_file = train_data_file
    self.eval_data_file = eval_data_file
    self.line_by_line = line_by_line
    self.mlm = mlm
    self.whole_word_mask = whole_word_mask
    self.mlm_probability = mlm_probability
    self.plm_probability = plm_probability
    self.max_span_length = max_span_length
    self.block_size = block_size
    self.overwrite_cache = overwrite_cache

    # Set all model and tokenizer arguments.
    self.model_type = model_type
    self.model_config_name = model_config_name
    self.tokenizer_name = tokenizer_name
    self.model_name_or_path = model_name_or_path
    self.model_cache_dir = model_cache_dir
    self.ignore_mismatched_sizes = ignore_mismatched_sizes

  def __repr__(self):
    r"""Show every stored setting so that printing an instance is useful."""
    # Attributes are listed in the order `__init__` assigned them.
    settings = ', '.join('%s=%r' % pair for pair in vars(self).items())
    return '%s(%s)' % (type(self).__name__, settings)


def get_model_config(args: ModelDataArguments):
  r"""Resolve the configuration object for the model described by `args`.

  Lookup order: `args.model_config_name` first, then
  `args.model_name_or_path`; either one is loaded through
  `AutoConfig.from_pretrained` using `args.model_cache_dir` as cache folder.
  When both are `None`, a fresh config is instantiated from
  `CONFIG_MAPPING[args.model_type]`.

  Returns:
    The loaded or newly created configuration object.
  """

  # The first entry that is not None decides where the config comes from.
  pretrained_source = args.model_config_name
  if pretrained_source is None:
    pretrained_source = args.model_name_or_path

  # Nothing to load: build the default config for this model type.
  if pretrained_source is None:
    return CONFIG_MAPPING[args.model_type]()

  return AutoConfig.from_pretrained(pretrained_source,
                                    cache_dir=args.model_cache_dir)

def get_tokenizer(args: ModelDataArguments):
  r"""Load the tokenizer and resolve `args.block_size` against its limit.

  The tokenizer is loaded with `AutoTokenizer.from_pretrained` from
  `args.tokenizer_name` when set, otherwise from `args.model_name_or_path`,
  using `args.model_cache_dir` as cache folder.

  Side effect: `args.block_size` is overwritten in place. A value `<= 0`
  becomes `tokenizer.model_max_length`; a positive value is capped at
  `tokenizer.model_max_length`.

  Returns:
    The loaded tokenizer.

  Raises:
    ValueError: when neither `args.tokenizer_name` nor
      `args.model_name_or_path` is set, so there is nothing to load.
  """

  # `tokenizer_name` wins over `model_name_or_path` when both are provided.
  tokenizer_source = args.tokenizer_name or args.model_name_or_path
  if not tokenizer_source:
    # Fail with a clear message instead of hitting an unbound local below.
    raise ValueError('Cannot load a tokenizer: set `tokenizer_name` or '
                     '`model_name_or_path`!')

  tokenizer = AutoTokenizer.from_pretrained(tokenizer_source,
                                            cache_dir=args.model_cache_dir)

  # NOTE(review): some tokenizers presumably report a very large
  # `model_max_length` when no limit is configured - confirm the resolved
  # block size is sensible for the tokenizer used here.
  if args.block_size <= 0:
    args.block_size = tokenizer.model_max_length
  else:
    # Never go beyond tokenizer maximum length.
    args.block_size = min(args.block_size, tokenizer.model_max_length)

  return tokenizer
  

def get_model(args: ModelDataArguments, model_config):
  r"""Load a pre-trained model or build an untrained one from `model_config`.

  Arguments:

    args: Settings object; `model_name_or_path`, `model_cache_dir` and
      `ignore_mismatched_sizes` are read here.

    model_config: Configuration object, e.g. the value returned by
      `get_model_config`.

  Returns:
    With `args.model_name_or_path` set: the result of
    `AutoModelForCausalLM.from_pretrained`. Otherwise: a new model built by
    `AutoModelWithLMHead.from_config(model_config)`.

  Raises:
    ValueError: when `args.model_name_or_path` is set but the type of
      `model_config` is not a key of `MODEL_FOR_CAUSAL_LM_MAPPING`.
  """

  # No checkpoint given: build an untrained model.
  if not args.model_name_or_path:
    print("Training new model from scratch...")
    # Use the config handed to this function. The previous code read a
    # notebook-level `config` variable and silently ignored the argument.
    # NOTE(review): `AutoModelWithLMHead` is reported as deprecated by recent
    # transformers releases - confirm which auto class should replace it.
    return AutoModelWithLMHead.from_config(model_config)

  # Pre-trained weights are only loaded for causal language model configs.
  if type(model_config) not in MODEL_FOR_CAUSAL_LM_MAPPING.keys():
    raise ValueError(
        'Invalid `model_name_or_path`! It should be in %s!' % 
        (str(MODEL_FOR_CAUSAL_LM_MAPPING.keys())))

  return AutoModelForCausalLM.from_pretrained(
      args.model_name_or_path,
      # A path containing ".ckpt" is loaded as a TensorFlow checkpoint.
      from_tf=bool(".ckpt" in args.model_name_or_path),
      config=model_config,
      cache_dir=args.model_cache_dir,
      ignore_mismatched_sizes=args.ignore_mismatched_sizes)


def get_dataset(args: ModelDataArguments, tokenizer: PreTrainedTokenizer, 
                evaluate: bool=False):
  r"""Turn one of the configured text files into a dataset object.

  Arguments:

    args: Settings object; `train_data_file`, `eval_data_file`,
      `line_by_line`, `block_size` and `overwrite_cache` are read here.

    tokenizer: Tokenizer handed to the dataset class.

    evaluate: Selects `args.eval_data_file` when true, otherwise
      `args.train_data_file`.

  Returns:
    A `LineByLineTextDataset` when `args.line_by_line` is true, otherwise a
    `TextDataset`.
  """

  # Pick the split to read.
  if evaluate:
    text_path = args.eval_data_file
  else:
    text_path = args.train_data_file

  # Block-based dataset; the only variant that takes `overwrite_cache`.
  if not args.line_by_line:
    return TextDataset(tokenizer=tokenizer, file_path=text_path,
                       block_size=args.block_size,
                       overwrite_cache=args.overwrite_cache)

  # One example per line of the file.
  return LineByLineTextDataset(tokenizer=tokenizer, file_path=text_path,
                               block_size=args.block_size)


def get_collator(args: ModelDataArguments, model_config: PretrainedConfig, 
                 tokenizer: PreTrainedTokenizer):
  r"""Create the data collator that batches examples for the trainer.

  Always returns a `DataCollatorForLanguageModeling` configured with
  `args.mlm` and `args.mlm_probability`. `model_config` is accepted but not
  used by this implementation.
  """

  collator_options = {
      'tokenizer': tokenizer,
      'mlm': args.mlm,
      'mlm_probability': args.mlm_probability,
  }
  return DataCollatorForLanguageModeling(**collator_options)

In [None]:
# Single root for corpus, tokenizer, cache and model output so the notebook
# can be pointed at another location by editing one line.
DATA_ROOT = '/data/forta/ethereum'

# Define arguments for data, tokenizer and model arguments.
# NOTE(review): `mlm=False` makes the collator emit plain language-modeling
# labels, while the from-scratch model is built through a generic LM-head auto
# class; presumably Longformer resolves to a masked-LM head there - confirm
# that this training objective is the intended one.
model_data_args = ModelDataArguments(
                                    train_data_file=f'{DATA_ROOT}/text/pretraining/small_pretraining_train.txt',
                                    eval_data_file=f'{DATA_ROOT}/text/pretraining/small_pretraining_val.txt',
                                    line_by_line=True, 
                                    mlm=False,
                                    whole_word_mask=False,
                                    plm_probability=float(1/6),
                                    max_span_length=5,
                                    block_size=-1,
                                    overwrite_cache=False,
                                    # `model_type` is a CONFIG_MAPPING key (an architecture name),
                                    # not a checkpoint name. It is only read when neither
                                    # `model_config_name` nor `model_name_or_path` is set.
                                    model_type='longformer',
                                    model_config_name='allenai/longformer-base-4096',
                                    tokenizer_name=f'{DATA_ROOT}/tokenizer_longformer',
                                    model_name_or_path=None,
                                    model_cache_dir=f'{DATA_ROOT}/cache',
                                    ignore_mismatched_sizes=True,
                                    )

print(model_data_args)

# Define arguments for training
training_args = TrainingArguments(
                          # Disable wandb
                          report_to="none",
                          output_dir=f'{DATA_ROOT}/model_longformer',
                          overwrite_output_dir=True,
                          do_train=True, 
                          do_eval=True,
                          per_device_train_batch_size=4,
                          per_device_eval_batch_size=4,
                          # Evaluate once per epoch; `eval_steps` is therefore left unset.
                          evaluation_strategy='epoch',
                          # The plotting cell reuses `logging_steps` as the x-axis step size.
                          logging_steps=4,
                          eval_steps = None,
                          prediction_loss_only=True,
                          learning_rate = 5e-5,
                          weight_decay=0.01,
                          adam_epsilon = 1e-8,
                          max_grad_norm = 1.0,
                          num_train_epochs = 2,
                          # NOTE(review): presumably a negative value disables periodic
                          # checkpoint saving - confirm for the installed transformers version.
                          save_steps = -1,
                          )

print(training_args)

In [None]:
# Build the three objects every later cell relies on: `config`, `tokenizer`
# and `model`. The order matters: the model needs the config, and the
# embedding resize needs both the model and the tokenizer.

# Load model configuration.
print('Loading model configuration...')
config = get_model_config(model_data_args)

# Load model tokenizer.
# `get_tokenizer` also rewrites `model_data_args.block_size` in place.
print('Loading model`s tokenizer...')
tokenizer = get_tokenizer(model_data_args)
# Reuse the end-of-sequence token as padding token.
tokenizer.pad_token = tokenizer.eos_token

# Loading model.
print('Loading actual model...')
model = get_model(model_data_args, config)

# Resize model to fit all tokens in tokenizer.
model.resize_token_embeddings(len(tokenizer))

# fix model padding token id
# NOTE(review): this copies the eos id stored in the model config, while the
# tokenizer above uses its own eos token for padding - confirm both ids agree
# for the custom tokenizer.
model.config.pad_token_id = model.config.eos_token_id

In [None]:
# Setup train dataset if `do_train` is set.
print('Creating train dataset...')
train_dataset = get_dataset(model_data_args, tokenizer=tokenizer, evaluate=False) if training_args.do_train else None

# Setup evaluation dataset if `do_eval` is set.
print('Creating evaluate dataset...')
eval_dataset = get_dataset(model_data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None

# Get data collator to modify data format depending on type of model used.
data_collator = get_collator(model_data_args, config, tokenizer)

# Report the number of examples per split. A split that was not requested is
# `None`, and calling `len` on it would raise a TypeError, so it is skipped.
for split_dataset in (train_dataset, eval_dataset):
  if split_dataset is not None:
    print(len(split_dataset))

In [None]:
# Initialize Trainer.
print('Loading `trainer`...')
trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=data_collator,
                  train_dataset=train_dataset,
                  eval_dataset=eval_dataset,
                  )

if training_args.do_train:
  print('Start training...')
  # Setup model path if the model to train loaded from a local path.
  # Only an existing local directory qualifies; a hub model name or an unset
  # value leaves `model_path` as None.
  checkpoint_dir = model_data_args.model_name_or_path
  if checkpoint_dir is not None and os.path.isdir(checkpoint_dir):
    model_path = checkpoint_dir
  else:
    model_path = None
  # Run training.
  # `resume_from_checkpoint` is the supported keyword; the former
  # `model_path` keyword is deprecated in transformers.
  trainer.train(resume_from_checkpoint=model_path)
  # Save model.
  trainer.save_model()
  trainer.save_state()
  # The tokenizer is not saved here; it is loaded from its own directory.
  # if trainer.is_world_process_zero():
  #  tokenizer.save_pretrained(training_args.output_dir)

In [None]:
# Collect the logged losses. An entry holding a 'loss' key counts as a
# training record; any other entry holding 'eval_loss' counts as evaluation.
train_losses = [record['loss']
                for record in trainer.state.log_history
                if 'loss' in record]
eval_losses = [record['eval_loss']
               for record in trainer.state.log_history
               if 'loss' not in record and 'eval_loss' in record]

loss_history = {'train_loss': train_losses, 'eval_loss': eval_losses}

# Perplexity is the exponential of the corresponding loss value.
perplexity_history = {
    'train_perplexity': [math.exp(value) for value in train_losses],
    'eval_perplexity': [math.exp(value) for value in eval_losses],
}

# Print and plot losses first, then perplexities, with shared plot settings.
for history, plot_title in ((loss_history, 'Loss'),
                            (perplexity_history, 'Perplexity')):
  print(history)
  plot_dict(history, start_step=training_args.logging_steps,
            step_size=training_args.logging_steps, use_title=plot_title,
            use_xlabel='Train Steps', use_ylabel='Values', magnify=2)

In [None]:
# Final evaluation pass: report perplexity, i.e. exp of the evaluation loss.
if not training_args.do_eval:
  print('No evaluation needed. No evaluation data provided, `do_eval=False`!')
else:
  eval_output = trainer.evaluate()
  perplexity = math.exp(eval_output["eval_loss"])
  print('\nEvaluate Perplexity: {:10,.2f}'.format(perplexity))