# Training BART using the Accelerate API
Training BART on multiple GPUs / multiple nodes with the Accelerate API.

In [87]:
# imports 
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir)

import argparse
from datetime import datetime
import json
import os
import pickle
import random
import time
import wandb
from pathlib import Path
import gc
import glob
from tqdm import tqdm
import torch
from torch.optim import AdamW
from transformers import get_scheduler
from torch.utils.data import DataLoader

import numpy as np
from transformers import TrainingArguments, Trainer, BartConfig, BartForConditionalGeneration
from transformers.file_utils import logging

# custom (project-local) modules
from dataset import SpectroDataset, SpectroDataCollator
sys.path.append('data')
sys.path.append('bart_spektro')
from modeling_bart_spektro import BartSpektoForConditionalGeneration
from configuration_bart_spektro import BartSpektroConfig
from data_preprocess1 import print_args
from bart_spektro_tokenizer import BartSpektroTokenizer
from tokenizers import Tokenizer

In [83]:
# Initial variables (nothing after this cell should need to be changed)

# --- model / data selection ---
which_bart = "spektro"  # either "original" or "spektro"
data_type = "1K"

# --- batching and schedule ---
bs = 32   # batch size; 4 is the alternative value left in the original note
gas = 16  # gradient accumulation steps
SEQ_LEN = 200
# The BART models were trained for 10 epochs in total.
# Alternative source: int(os.environ["TOTAL_EPOCHS"])  (original note: 21 + 8 = BP)
num_epochs = 10

# --- TOKENIZER ---
tokenizer_type = "_bbpe_1M"  # use "" for the custom spektro tokenizer
tokenizer = Tokenizer.from_file(
    "./tokenizer/bbpe_tokenizer/bart" + tokenizer_type + "_tokenizer.model"
)

# --- resuming ---
resume_training = False  # alternative source: bool(int(os.environ["RESUME_TRAINING"]))
resume_wandb_id = ""     # e.g. "1l7305qk"

model = None  # placeholder so that the name is always defined

# Optional: find the last checkpoint automatically (uncomment to use)
# models_pth = "/storage/projects/msml/mg_neims_branch/models/bart_trial"
# runs = glob.glob(models_pth)
# checkpoints =  glob.glob(runs[-1]+"/checkpoint-*")
# checkpoints.sort(key=lambda x: int(x.split("-")[-1]))
# load_checkpoint = checkpoints[-1]
# print(f"last checkpoint: {load_checkpoint}")

# Optional: point at one specific checkpoint instead
# load_checkpoint = "./models/bart_2022-06-01-04_30_20/checkpoint-4152/"

In [22]:
# Number of GPUs: display the raw value of the PBS_NGPUS environment variable.
# Indexing os.environ directly raises KeyError when the variable is not set.
os.environ["PBS_NGPUS"]

'4'

In [23]:
# Arguments for the training run (fixed defaults; parsed from an empty argv below).


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be disabled.

    Args:
        value: a bool (returned unchanged) or a string such as "true", "0", "no".

    Returns:
        The parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"1", "true", "t", "yes", "y"}:
        return True
    if lowered in {"0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Timestamp used in the default save name, formatted like 2022-06-21-17_22_59.
now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime())

parser = argparse.ArgumentParser()
parser.add_argument("--lr",default=5e-5, type=float, help="learning rate")
parser.add_argument("--seed",default=42, type=int,  help="seed to replicate results")
parser.add_argument("--gradient-accumulation-steps",default=gas, type=int, help="gradient_accumulation_steps")
parser.add_argument("--batch-size",default=bs, type=int,  help="batch_size")
parser.add_argument("--warmup",default=500, type=int,  help="warmup steps for learning rate")
parser.add_argument("--weight-decay",default=0.01, type=float,  help="weight decay rate parameter")
# Fall back to the local CPU count when PBS_NCPUS is not set, instead of raising KeyError.
# A string default (the env value) is still converted by argparse through type=int.
parser.add_argument("--num-workers",default=os.environ.get("PBS_NCPUS", os.cpu_count() or 1), type=int,  help="num of cpus available")
parser.add_argument("--device",default=torch.device('cuda'), help="torch.device object")
parser.add_argument("--num-train-epochs",default=num_epochs, type=int,  help="number of training epochs")
parser.add_argument("--save-dir",default='/storage/projects/msml/mg_neims_branch/MassGenie/models', type=str,  help="Path to save trained model")
parser.add_argument("--save-name", type=str, default=f'bart_{now}', help="Name of the model, used for saves")
parser.add_argument("--load-checkpoint", type=str, default='', help="Path to the checkpoint to resume training")
# _str2bool instead of bool, so that "--fp16 False" really yields False.
parser.add_argument("--fp16",default=True, type=_str2bool, required=False, help="whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument("--train-data-path",default=f'/storage/projects/msml/mg_neims_branch/MassGenie/data/trial_set/{data_type}{tokenizer_type}_bart_prepared_data_train.pkl', type=str, help="Path to jsonl train dataset")
parser.add_argument("--valid-data-path",default=f'/storage/projects/msml/mg_neims_branch/MassGenie/data/trial_set/{data_type}{tokenizer_type}_bart_prepared_data_valid.pkl', type=str, help="Path to jsonl validation dataset")
parser.add_argument("--log-steps",default=50, type=int,  help="number of steps between logs")
parser.add_argument("--eval-steps",default=7142, type=int,  help="number of steps between evaluations")
parser.add_argument("--wandb", action='store_true', default=True, help="optinal logging via Weights&Biases")
# --wandb defaults to True and is store_true, so it could never be switched off;
# --no-wandb writes False into the same destination.
parser.add_argument("--no-wandb", dest="wandb", action='store_false', help="disable logging via Weights&Biases")
parser.add_argument("--wandb-resume", action='store_true', default=resume_training, help="resume logging via wandb, needs an valid run ID set in args.wandb-id")
parser.add_argument("--wandb-id", type=str, default=wandb.util.generate_id(), help="Process unique wandb ID used for resumin the training process")

# An empty argv is passed on purpose: inside a notebook sys.argv belongs to the kernel.
args = parser.parse_args([])
arg_log = print_args(args)

# extended outputs
logging.set_verbosity_info()

arguments:
  lr ........................... 5e-05
  seed ......................... 42
  gradient_accumulation_steps .. 16
  batch_size ................... 32
  warmup ....................... 500
  weight_decay ................. 0.01
  num_workers .................. 16
  device ....................... cuda
  num_train_epochs ............. 10
  save_dir ..................... /storage/projects/msml/mg_neims_branch/MassGenie/models
  save_name .................... bart_2022-06-21-17_22_59
  load_checkpoint .............. 
  fp16 ......................... True
  train_data_path .............. /storage/projects/msml/mg_neims_branch/MassGenie/data/trial_set/1K_bbpe_1M_bart_prepared_data_train.pkl
  valid_data_path .............. /storage/projects/msml/mg_neims_branch/MassGenie/data/trial_set/1K_bbpe_1M_bart_prepared_data_valid.pkl
  log_steps .................... 50
  eval_steps ................... 7142
  wandb ........................ True
  wandb_resume ................. False
  wandb_id ..

In [24]:
# BART CONFIGURATION
# Both model variants share every hyper-parameter below; the "spektro" variant
# additionally receives max_log_id.
# NOTE(review): `min_len` is passed through exactly as before; the generation
# setting in transformers is spelled `min_length` — confirm which one is intended.
shared_config_kwargs = dict(
    vocab_size=len(tokenizer.get_vocab()),
    max_position_embeddings=SEQ_LEN,
    max_length=SEQ_LEN,
    min_len=0,
    # encoder
    encoder_layers=12,
    encoder_ffn_dim=4096,
    encoder_attention_heads=16,
    encoder_layerdrop=0.0,
    # decoder
    decoder_layers=12,
    decoder_ffn_dim=4096,
    decoder_attention_heads=16,
    decoder_layerdrop=0.0,
    # shared architecture / regularisation
    activation_function='gelu',
    d_model=1024,
    dropout=0.2,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=0.02,
    classifier_dropout=0.0,
    scale_embedding=False,
    use_cache=True,
    # special token ids
    pad_token_id=2,
    bos_token_id=3,
    eos_token_id=0,
    is_encoder_decoder=True,
    decoder_start_token_id=3,
    forced_eos_token_id=0,
)

if which_bart == "spektro":
    config = BartSpektroConfig(**shared_config_kwargs, max_log_id=9)
elif which_bart == "original":
    config = BartConfig(**shared_config_kwargs)


In [84]:
# DATA
# Apply the configured seed before anything random happens (shuffling, weight
# initialisation); args.seed was declared "to replicate results" but never used.
# torch.manual_seed seeds the CPU and CUDA generators.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

is_original_bart = which_bart == "original"
train_data = SpectroDataset(args.train_data_path, original=is_original_bart)
valid_data = SpectroDataset(args.valid_data_path, original=is_original_bart)

# drops
# decoder_input_ids must be present, so a missing column still raises here.
train_data.data.drop(columns=["decoder_input_ids"], inplace=True)
valid_data.data.drop(columns=["decoder_input_ids"], inplace=True)

if is_original_bart:
    # position_ids may be absent; errors="ignore" skips only a missing column
    # instead of a bare `except: pass` that hid every error and skipped the
    # validation drop whenever the train drop failed.
    train_data.data.drop(columns=["position_ids"], inplace=True, errors="ignore")
    valid_data.data.drop(columns=["position_ids"], inplace=True, errors="ignore")

# Use the parsed batch size so the loaders agree with the logged arguments.
train_dataloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
# Validation order does not need to be random; keeping it fixed makes evaluation repeatable.
valid_dataloader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False)

# clean memory
gc.collect()
torch.cuda.empty_cache()

# MODEL
if which_bart == "original":
    model = BartForConditionalGeneration(config)
elif which_bart == "spektro":
    model = BartSpektoForConditionalGeneration(config)
else:
    raise AttributeError('Wrong \'which_bart\' attribute. Assign \'original\' or \'spektro\'.')

In [51]:
# Preview the first five rows of the training data frame.
train_data.data.head(5)

Unnamed: 0,destereo_smiles,input_ids,encoder_attention_mask,decoder_attention_mask,labels,position_ids
629,COCCN1C(=O)C(=O)N(C1=O)CC(=O)c1c(N)n(C)c(=O)n(...,"[15, 28, 29, 30, 31, 32, 33, 39, 40, 41, 42, 4...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 224, 325, 20, 38, 260, 50, 12, 38, 260, 50...","[3, 1, 2, 5, 5, 4, 4, 5, 6, 7, 8, 6, 7, 9, 6, ..."
338,O=C(NC1CCS(=O)(=O)C1)CCC(=O)NC1CCS(=O)(=O)C1,"[17, 18, 26, 27, 28, 29, 30, 31, 32, 33, 34, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 224, 50, 32, 38, 11, 272, 20, 290, 260, 50...","[1, 1, 3, 6, 6, 7, 4, 5, 2, 1, 1, 2, 4, 7, 6, ..."
620,CN(C(=O)c1c(C)nc2n1CCN(C2)C(=O)c1cc(=O)n(c(=O)...,"[30, 33, 34, 38, 39, 40, 41, 42, 43, 44, 45, 4...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 224, 266, 11, 38, 260, 50, 12, 70, 20, 70,...","[2, 4, 0, 4, 7, 7, 7, 9, 7, 8, 6, 4, 4, 6, 6, ..."
396,Cn1ncc(c1)S(=O)(=O)N1CCN(CC1)S(=O)(=O)c1cnn(c1)C,"[33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 4...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 224, 275, 20, 310, 11, 70, 20, 12, 54, 260...","[2, 3, 3, 1, 5, 6, 6, 6, 8, 6, 5, 3, 4, 4, 4, ..."
251,O=C(C1CNCC(C1)C(=O)N1CCOCC1)NCCc1nncn1C,"[33, 36, 39, 40, 41, 42, 43, 44, 45, 51, 52, 5...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 224, 50, 32, 38, 11, 38, 20, 266, 261, 11,...","[4, 1, 8, 7, 9, 9, 8, 9, 2, 6, 8, 8, 8, 9, 9, ..."


In [30]:
# Resume training
if resume_training:
    # `load_checkpoint` is only assigned in commented-out lines of the
    # configuration cell, so report a missing definition with an actionable
    # message instead of a bare NameError.
    try:
        args.load_checkpoint = load_checkpoint
    except NameError:
        raise NameError(
            "resume_training is True but `load_checkpoint` is not defined; "
            "set it to a checkpoint directory in the configuration cell"
        ) from None
    if args.wandb_resume:
        # Reuse the ID of the run that is being resumed.
        args.wandb_id = resume_wandb_id

# Init wandb
if args.wandb:
    wandb.login()
    # resume="allow" is passed together with the run ID chosen above.
    wandb.init(id=args.wandb_id, resume="allow", entity="hajekad", project="BART_for_gcms")
    wandb.run.name = args.save_name + "_accelerate"



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.18 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


## Now the ***Accelerate*** stuff

In [90]:
from accelerate import Accelerator

# Accelerator created with its default settings.
accelerator = Accelerator()

# AdamW over all model parameters, using the parsed learning rate and weight decay.
optimizer = AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Total step count: one step per batch, for every epoch.
steps_per_epoch = len(train_dataloader)
num_training_steps = num_epochs * steps_per_epoch

# Linear schedule without warmup, spanning the whole run.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# Progress bar sized to the total number of steps; it is advanced manually.
progress_bar = tqdm(range(num_training_steps))

  0%|          | 0/150 [00:00<?, ?it/s]

In [91]:
# Accelerate preparation: pass the dataloaders, model and optimizer through
# accelerator.prepare and rebind the names to the returned objects.
# NOTE(review): lr_scheduler is not passed to prepare here — confirm that is intended.
train_dataloader, valid_dataloader, model, optimizer = accelerator.prepare(
    train_dataloader, valid_dataloader, model, optimizer)

### Training loop (basic)
Still to be added, at a minimum:
- wandb logging
- evaluation every x steps
- gradient accumulation steps (gas)?
- checkpoint saving
- fp16

In [92]:
# Basic training loop: one optimizer update per batch, repeated for num_epochs passes.
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # The batch is unpacked into keyword arguments; the loss is read from the model output.
        outputs = model(**batch)
        loss = outputs.loss
        # The backward pass goes through the accelerator instead of loss.backward().
        accelerator.backward(loss)

        optimizer.step()
        # The scheduler step below is commented out, so lr_scheduler never advances in this loop.
#         lr_scheduler.step()
        optimizer.zero_grad()
        # Advance the shared tqdm bar once per processed batch.
        progress_bar.update(1)

100%|██████████| 150/150 [01:30<00:00,  1.82it/s]