In [1]:
import glob
import json
import os
import pickle
import re
import shutil
from typing import List

import torch
from dataset import AMRDataSetFast, DataCollatorForSeq2Seq
from model_utils import (activate_embeds, assert_all_frozen, freeze_embeds,
                         freeze_params, get_ETMG2graph,
                         get_inverse_sqrt_schedule_with_warmup, get_MTEG2text,
                         get_MTMG2partial, get_MTMG2TG, get_PTPG2partial)
from run_multitask_unified_pretraining import smart_emb_init
from spring_amr.tokenization_bart import PENMANBartTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm, trange
from transformers import AutoConfig, AutoModelForMaskedLM

# Sets the TOKENIZERS_PARALLELISM environment variable for this process.
# NOTE(review): the training DataLoader below uses worker processes (num_workers=6);
# confirm that enabling tokenizer parallelism is intended together with that.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [18]:
# Run configuration for the multitask (text + AMR graph) pre-training job.
# Later cells add derived entries (e.g. 'n_gpu', 'device', 'train_batch_size').
args = dict(
    # --- model and data locations ---
    model_name_or_path='vinai/bartpho-syllable',
    output_dir='/AIHCM/KGQA/NLPCore/graph2text/AMRBART/output/pre-train/tmt',
    train_file='/AIHCM/KGQA/NLPCore/graph2text/data/tmt/all_v2_p1.jsonl',
    # NOTE(review): same file as train_file, so validation runs on the training data - confirm.
    val_file='/AIHCM/KGQA/NLPCore/graph2text/data/tmt/all_v2_p1.jsonl',
    test_file='/AIHCM/KGQA/NLPCore/graph2text/data/tmt/all_v2_p2.jsonl',
    # --- optimisation ---
    block_size=256,                  # input sequence length after tokenization (optional)
    smart_init=False,                # initialise the AMR vocabulary from similar tokens
    max_steps=100000,                # total number of training steps when > 0 (overrides the epoch count)
    gradient_accumulation_steps=1,   # batches accumulated before each optimizer update
    weight_decay=0.0,                # weight decay, if any
    learning_rate=5e-5,              # initial learning rate for AdamW
    adam_epsilon=1e-8,               # epsilon for the AdamW optimizer
    warmup_steps=2500,               # number of warmup steps handed to the LR scheduler
    # --- pre-training objectives (entries marked * are enabled) ---
    mlm_amr=True,                    # * [Empty text + Masked Graph -> Graph]
    mlm_text=True,                   # * [Masked Text + Empty Graph -> Text]
    mlm_text_plus_amr=True,          # * [Masked text + Graph -> Text] (dynamic masking rate)
    mlm_amr_plus_text=False,         # [Text + Masked Graph -> Graph] (dynamic masking rate)
    mlm_joint_to_amr=False,          # [Masked Text + Masked Graph -> Graph]
    mlm_joint_to_text=False,         # [Masked Text + Masked Graph -> Text]
    mlm_joint_to_joint=False,
    joint_train_interval=1,          # interval (in steps) of the joint AMR and text training
    max_grad_norm=1.0,               # gradient clipping threshold
    logging_steps=1000,              # log / evaluate every X update steps
    evaluate_during_training=True,
    # --- parameter freezing and checkpoint retention ---
    freeze_embeds=False,
    freeze_encoder=False,
    freeze_decoder=False,
    save_total_limit=None,
    # --- CUDA ---
    no_cuda=True,
    cuda_index=1,
    fp16=False,                      # 16-bit (mixed) precision through NVIDIA apex instead of 32-bit
    fp16_opt_level='O1',             # apex AMP optimization level, one of 'O0', 'O1', 'O2', 'O3'
    per_gpu_train_batch_size=1,      # training batch size per GPU/CPU
    per_gpu_eval_batch_size=1,       # evaluation batch size per GPU/CPU
)
t_total = args['max_steps']

In [19]:
# Pick the compute device and record how many GPUs the run may use.
if not args['no_cuda']:
    # GPU run: bind this process to the configured CUDA device.
    torch.cuda.set_device(args['cuda_index'])
    device = torch.device("cuda", args['cuda_index'])
    args['n_gpu'] = torch.cuda.device_count()
else:
    # CPU run: no GPU is used at all.
    device = torch.device("cpu")
    args['n_gpu'] = 0
args['device'] = device

In [20]:
# Display the device selected in the previous cell.
args['device']

device(type='cpu')

In [21]:
# The tokenizer is restored from a pickle file instead of being built here.
# Disabled alternative that constructs it directly from the pretrained model:
# tokenizer = PENMANBartTokenizer.from_pretrained(
#     args['model_name_or_path'], 
#     collapse_name_ops=False, 
#     use_pointer_tokens=True, 
#     raw_graph=False,
# )
# SECURITY: pickle.load can execute arbitrary code contained in the file;
# only load a tokenizer pickle that was produced by a trusted source.
with open('/AIHCM/KGQA/NLPCore/graph2text/models/vinai/bartpho-syllable/tokenizer.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

# A non-positive block_size means "use the tokenizer's maximum length";
# otherwise the requested value is capped at that maximum.
if args['block_size'] <= 0:
    args['block_size'] = tokenizer.model_max_length
else:
    args['block_size'] = min(args['block_size'], tokenizer.model_max_length)

In [22]:
%%capture

# Load the pretrained configuration and weights named in the run configuration.
checkpoint_name = args['model_name_or_path']
config = AutoConfig.from_pretrained(checkpoint_name, cache_dir=None)
model = AutoModelForMaskedLM.from_pretrained(
    checkpoint_name,
    config=config,
    cache_dir=None,
    from_tf=(".ckpt" in checkpoint_name),  # a ".ckpt" name is loaded as a TensorFlow checkpoint
)
# Resize the embedding matrix to the size of the loaded tokenizer's vocabulary.
model.resize_token_embeddings(len(tokenizer))

In [23]:
%%capture

# Optionally freeze whole parts of the network before training.
# The "151m" / "201m" / "40m" figures are the original author's notes.
if args['freeze_encoder']:  # 151m
    encoder = model.get_encoder()
    freeze_params(encoder)
    assert_all_frozen(encoder)

if args['freeze_decoder']:  # 201m
    decoder = model.get_decoder()
    freeze_params(decoder)
    assert_all_frozen(decoder)

# Embeddings are either frozen or explicitly (re)activated.  # 40m
(freeze_embeds if args['freeze_embeds'] else activate_embeds)(model)

model.to(args['device'])

In [24]:
# Names of all parameters that still require gradients; the cell shows how many there are.
train_params = [name for name, param in model.named_parameters() if param.requires_grad]
len(train_params)

515

In [25]:
# Total number of scalar values across the parameters that require gradients.
sum(param.numel() for param in model.parameters() if param.requires_grad)

396830720

In [28]:
%%capture

# Build the dataset wrapper over the three JSONL splits and prepare it.
AMRDataset = AMRDataSetFast(
    tokenizer=tokenizer,
    test_file=args['test_file'],
    validation_file=args['val_file'],
    train_file=args['train_file'],
    max_tgt_length=256,
    max_src_length=args['block_size'],
    pad_to_max_length=False,
)
AMRDataset.setup()

In [32]:
# Take the train / validation splits from the prepared dataset and report their sizes.
train_dataset, dev_dataset = AMRDataset.train_dataset, AMRDataset.valid_dataset
print('train samples: ', len(train_dataset))
print('dev samples: ', len(dev_dataset))

train samples:  8178
dev samples:  8178


In [33]:
# Collator shared by the training and evaluation dataloaders.
# Label padding uses -100 and no padding multiple is requested.
seq2seq_collate_fn = DataCollatorForSeq2Seq(
    tokenizer,
    pad_to_multiple_of=None,
    label_pad_token_id=-100,
    model=model,
)

# Training

In [34]:
# Optional embedding initialisation, run only when 'smart_init' is enabled in the config.
if args['smart_init']:
    smart_emb_init(tokenizer, model)

In [35]:
# Effective batch size: per-device size times the GPU count (counted as 1 on CPU).
args['train_batch_size'] = max(1, args['n_gpu']) * args['per_gpu_train_batch_size']

In [36]:
# Training batches are drawn in random order and assembled by the seq2seq collator.
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args['train_batch_size'],
    sampler=train_sampler,
    num_workers=6,
    collate_fn=seq2seq_collate_fn,
)

In [37]:
# Number of epochs needed to reach max_steps optimizer updates (rounded up by one).
optimizer_steps_per_epoch = len(train_dataloader) // args['gradient_accumulation_steps']
num_train_epochs = args['max_steps'] // optimizer_steps_per_epoch + 1
num_train_epochs

13

In [42]:
# Split the parameters into two optimizer groups: names containing one of the
# `no_decay` markers never get weight decay, all others use the configured value.
no_decay = ["bias", "LayerNorm.weight"]
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": args['weight_decay']},
    {"params": no_decay_params, "weight_decay": 0.0},
]

In [43]:
# AdamW over the two parameter groups, using the configured epsilon and learning rate.
optimizer = AdamW(
    optimizer_grouped_parameters,
    eps=args['adam_epsilon'],
    lr=args['learning_rate'],
)

In [44]:
# Learning-rate schedule driven by the configured warmup length and total step budget.
scheduler = get_inverse_sqrt_schedule_with_warmup(
    optimizer,
    num_training_steps=t_total,
    num_warmup_steps=args['warmup_steps'],
)

In [45]:
# Force full precision for this run; this overrides the 'fp16' entry of the config cell.
args['fp16'] = False
if args['fp16']:
    # Only reached when fp16 is enabled: wrap model and optimizer with apex AMP.
    from apex import amp
    model, optimizer = amp.initialize(model, optimizer, opt_level=args['fp16_opt_level'])

In [46]:
# Display the number of GPUs recorded by the device-selection cell.
args['n_gpu']

0

In [47]:
# Multi-gpu training (should be after apex fp16 initialization)
# NOTE(review): the model is wrapped in DataParallel only when fp16 is enabled as
# well, so a multi-GPU full-precision run is left unwrapped - confirm this is intended.
if args['n_gpu'] > 1 and args['fp16']:
    model = torch.nn.DataParallel(model)

In [48]:
def ids_to_clean_text(tokenizer, generated_ids: List[int]):
    """Render a sequence of token ids as a space-separated string of tokens.

    Positions holding -100 (the label padding value used in this notebook) are
    shown as the tokenizer's pad token.

    Args:
        tokenizer: object exposing ``pad_token_id`` and ``convert_ids_to_tokens``.
        generated_ids: the token ids, as a tensor or a plain list of ints.
            The argument is never modified.

    Returns:
        The tokens joined by single spaces.
    """
    ids = torch.as_tensor(generated_ids)
    # Out-of-place masked_fill on purpose: callers pass views of the real batch
    # tensors (e.g. ``labels[idx]``), and the in-place variant would overwrite
    # the -100 markers of the labels that are afterwards fed to the model.
    ids = ids.masked_fill(ids == -100, tokenizer.pad_token_id)
    gen_text = tokenizer.convert_ids_to_tokens(ids)
    return " ".join(gen_text)

def save_dummy_batch2(output_dir, input_ids, dec_inp_ids, labels, tokenizer, prefix="train"):
    """Dump one batch to disk, as raw ids and as token strings, for manual inspection.

    Two JSON files are written into ``output_dir``: ``dummy_{prefix}_ids.json``
    and ``dummy_{prefix}_token.json``. Each holds one entry per example of the batch.

    Args:
        output_dir: directory receiving the files; created when missing.
        input_ids: encoder input ids, indexed by example.
        dec_inp_ids: decoder input ids, indexed by example.
        labels: label ids, indexed by example.
        tokenizer: tokenizer handed to ``ids_to_clean_text``.
        prefix: tag inserted into both file names.
    """
    # The first call happens on the very first training step, possibly before
    # anything else has created the output directory.
    os.makedirs(output_dir, exist_ok=True)

    dummy_ids, dummy_tokens = [], []
    for idx in range(len(input_ids)):
        dummy_ids.append({
            "input_ids": input_ids[idx].tolist(),
            "label_ids": labels[idx].tolist(),
            "dec_inp_ids": dec_inp_ids[idx].tolist(),
        })
        dummy_tokens.append({
            "input_tokens": ids_to_clean_text(tokenizer, input_ids[idx]),
            "label_tokens": ids_to_clean_text(tokenizer, labels[idx]),
            "dec_inp_tokens": ids_to_clean_text(tokenizer, dec_inp_ids[idx]),
        })

    ids_path = os.path.join(output_dir, f"dummy_{prefix}_ids.json")
    with open(ids_path, "w", encoding="utf-8") as fout:
        json.dump(dummy_ids, fout, indent=4, ensure_ascii=False)

    tokens_path = os.path.join(output_dir, f"dummy_{prefix}_token.json")
    with open(tokens_path, "w", encoding="utf-8") as fout:
        json.dump(dummy_tokens, fout, indent=4, ensure_ascii=False)

def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False):
    ordering_and_checkpoint_path = []

    glob_checkpoints = glob.glob(os.path.join(args['output_dir'], "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False):
    """Delete the oldest checkpoints so that at most ``args['save_total_limit']`` remain.

    Nothing happens when the limit is unset, zero or negative, or when the
    number of checkpoints does not exceed it. The order used to decide which
    checkpoints are "oldest" is the one produced by ``_sorted_checkpoints``.
    """
    limit = args['save_total_limit']
    if not limit or limit <= 0:
        return

    existing = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    surplus = len(existing) - limit
    if surplus <= 0:
        return

    # The list is sorted ascending, so the surplus entries at the front are removed.
    for stale_checkpoint in existing[:surplus]:
        print("Deleting older checkpoint [{}] due to args['save_total_limit']".format(stale_checkpoint))
        shutil.rmtree(stale_checkpoint)

In [49]:
def _eval_objective_loss(model, device, model_inputs):
    """Run one forward pass for a single objective and return its loss.

    ``model_inputs`` is the ``(masked_input, attention_mask, dec_input, labels)``
    tuple returned by one of the ``get_*`` batch builders. All four tensors are
    moved to ``device`` before the call.
    """
    masked_input, attention_mask, dec_input, labels = model_inputs
    outputs = model(
        input_ids=masked_input.to(device),
        # The attention mask must live on the same device as the other inputs.
        attention_mask=attention_mask.to(device),
        decoder_input_ids=dec_input.to(device),
        labels=labels.to(device),
    )
    return outputs[0]  # the loss is the first element of the model output


def evaluate(args, eval_dataset, collate_fn, model, tokenizer, prefix=""):
    """Compute the summed loss of all enabled objectives over ``eval_dataset``.

    For every batch, each objective whose flag is set in ``args`` builds its own
    masked inputs and is run through the model without gradients. The
    per-objective losses are added, and the batch totals are averaged over the
    number of batches. The model is switched to eval mode and left in that mode.

    Args:
        args: run configuration dict. Reads 'output_dir', 'device', 'n_gpu',
            'per_gpu_eval_batch_size' and the ``mlm_*`` objective flags; writes
            'eval_batch_size'.
        eval_dataset: dataset to iterate sequentially.
        collate_fn: collate function handed to the DataLoader.
        model: model to evaluate.
        tokenizer: tokenizer handed to the batch builders.
        prefix: optional sub-directory of 'output_dir' for the results file.

    Returns:
        dict with "eval_loss" (float, mean batch loss) and "perplexity"
        (0-dim tensor, ``exp(eval_loss)``). Both values are also printed and
        appended to ``<output_dir>/<prefix>/eval_results.txt``.

    Raises:
        ValueError: if no objective flag is enabled or the dataset yields no batch.
    """
    eval_output_dir = args['output_dir']
    # The results file lives under <output_dir>/<prefix>/, so create that full path.
    os.makedirs(os.path.join(eval_output_dir, prefix), exist_ok=True)

    args['eval_batch_size'] = args['per_gpu_eval_batch_size'] * max(1, args['n_gpu'])

    # (config flag, builder of that objective's model inputs), in evaluation order.
    objectives = (
        ('mlm_amr', lambda batch: get_ETMG2graph(batch, tokenizer, mlm_prob=0.35)),
        ('mlm_text', lambda batch: get_MTEG2text(batch, tokenizer, mlm_prob=0.35)),
        ('mlm_text_plus_amr', lambda batch: get_PTPG2partial(batch, tokenizer, inp="text")),
        ('mlm_amr_plus_text', lambda batch: get_PTPG2partial(batch, tokenizer, inp="amr")),
        ('mlm_joint_to_text', lambda batch: get_MTMG2partial(batch, tokenizer, inp="text", mlm_prob=0.35)),
        ('mlm_joint_to_amr', lambda batch: get_MTMG2partial(batch, tokenizer, inp="amr", mlm_prob=0.35)),
        ('mlm_joint_to_joint', lambda batch: get_MTMG2TG(batch, tokenizer, mlm_prob=0.35)),
    )
    active_builders = [build_inputs for flag, build_inputs in objectives if args[flag]]
    if not active_builders:
        # Without any objective the summed loss would be the plain int 0.
        raise ValueError("evaluate() needs at least one enabled mlm_* objective in args")

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=args['eval_batch_size'],
        collate_fn=collate_fn,
        num_workers=4,
    )

    # Eval!
    print("***** Running evaluation *****")
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    pbar = tqdm(eval_dataloader, desc="Evaluating")
    for batch in pbar:
        with torch.no_grad():
            loss = 0
            for build_inputs in active_builders:
                loss = loss + _eval_objective_loss(model, args['device'], build_inputs(batch))
            batch_loss = loss.mean().item()

        pbar.set_postfix(lm_loss=batch_loss)
        eval_loss += batch_loss
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        raise ValueError("evaluate() received a dataset that produced no batches")

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity, "eval_loss": eval_loss}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "a") as writer:
        for key in sorted(result.keys()):
            print(f"{key} = {result[key]}")
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result

In [50]:
# Bookkeeping for a run that starts from scratch: step / epoch counters at zero,
# no best score yet, and all loss accumulators empty.
global_step = epochs_trained = epoch_step = steps_trained_in_current_epoch = 0
best_score = float("inf")
tr_loss = logging_loss = epoch_loss = 0.0

# Clear any gradients left on the model before training starts.
model.zero_grad()

In [52]:
# Make sure the output directory exists before the first-batch debug dumps and
# the checkpoints are written into it.
os.makedirs(args['output_dir'], exist_ok=True)


def _train_objective_loss(model_inputs, dump_prefix, dump_batch):
    """Forward one pre-training objective on the current batch and return its loss.

    ``model_inputs`` is the ``(masked_input, attention_mask, dec_input, labels)``
    tuple produced by one of the ``get_*`` batch builders. When ``dump_batch``
    is true, the tensors are also written to ``args['output_dir']`` through
    ``save_dummy_batch2`` (tagged with ``dump_prefix``) for inspection.
    """
    masked_input, attention_mask, dec_input, labels = (
        tensor.to(args['device']) for tensor in model_inputs
    )
    if dump_batch:
        save_dummy_batch2(args['output_dir'], masked_input, dec_input, labels, tokenizer, prefix=dump_prefix)
    outputs = model(
        input_ids=masked_input,
        attention_mask=attention_mask,
        decoder_input_ids=dec_input,
        labels=labels,
    )
    return outputs[0]  # the loss is the first element of the model output


train_iterator = trange(
    epochs_trained,
    int(num_train_epochs),
    desc="Epoch",
    disable=False
)

for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, 
        desc="Iteration", 
        disable=True
    )

    for step, batch in enumerate(epoch_iterator):
        # Skip the first `steps_trained_in_current_epoch` batches (0 for a fresh run).
        if steps_trained_in_current_epoch > 0:
            steps_trained_in_current_epoch -= 1
            continue
        
        model.train()

        # The very first batch of the run is dumped to disk once per active objective.
        dump_batch = step == 0 and epoch == 0

        # Sum of the losses of every objective that is active on this step.
        loss = 0

        if args['mlm_amr']:     # [Empty text + Masked Graph -> graph]
            loss = loss + _train_objective_loss(
                get_ETMG2graph(batch, tokenizer, mlm_prob=0.35),
                "Etextamr2amr",
                dump_batch,
            )

        if args['mlm_text']:    # [Masked Text + Empty Graph -> text]
            loss = loss + _train_objective_loss(
                get_MTEG2text(batch, tokenizer, mlm_prob=0.35),
                "textEamr2text",
                dump_batch,
            )

        # The two "partial" objectives run only every `joint_train_interval` steps.
        # Their masking rate grows linearly with global_step, from 0.1 up to
        # 0.85 at max_steps.
        if args['mlm_text_plus_amr'] and step % args['joint_train_interval'] == 0:   # [Masked text + Graph -> Text]
            mlm_prob = 0.1 + global_step / args['max_steps'] * 0.75
            loss = loss + _train_objective_loss(
                get_PTPG2partial(batch, tokenizer, inp="text", mlm_prob=mlm_prob),
                "val_MtextAmr2text",
                dump_batch,
            )

        if args['mlm_amr_plus_text'] and step % args['joint_train_interval'] == 0:   # [Text + Masked Graph -> Graph]
            mlm_prob = 0.1 + global_step / args['max_steps'] * 0.75
            loss = loss + _train_objective_loss(
                get_PTPG2partial(batch, tokenizer, inp="amr", mlm_prob=mlm_prob),
                "val_TextMamr2amr",
                dump_batch,
            )

        if args['mlm_joint_to_text']:   # [Masked Text + Masked Graph -> text]
            loss = loss + _train_objective_loss(
                get_MTMG2partial(batch, tokenizer, inp="text", mlm_prob=0.35),
                "val_MtextMamr2text",
                dump_batch,
            )

        if args['mlm_joint_to_amr']:    # [Masked Text + Masked Graph -> graph]
            loss = loss + _train_objective_loss(
                get_MTMG2partial(batch, tokenizer, inp="amr", mlm_prob=0.35),
                "val_MtextMamr2amr",
                dump_batch,
            )

        if args['mlm_joint_to_joint']:
            loss = loss + _train_objective_loss(
                get_MTMG2TG(batch, tokenizer, mlm_prob=0.35),
                "val_MtextMamr2textamr",
                dump_batch,
            )

        # Scale so that the accumulated gradient matches one full-size update.
        loss = loss / args['gradient_accumulation_steps']

        epoch_iterator.set_postfix(lm_loss=loss.item(), lr=scheduler.get_last_lr()[0])

        if args['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        epoch_step += 1
        tr_loss += loss.item()
        epoch_loss += loss.item()

        # One optimizer update every `gradient_accumulation_steps` batches.
        if (step + 1) % args['gradient_accumulation_steps'] == 0:
            if args['fp16']:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args['max_grad_norm'])
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args['max_grad_norm'])

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if (args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0):

                if args['evaluate_during_training']:
                    results = evaluate(args, dev_dataset, seq2seq_collate_fn, model, tokenizer)
                    cur_score = results['perplexity'].item()

                    # Keep a checkpoint only when the dev perplexity improves.
                    if cur_score < best_score:
                        best_score = cur_score
                        checkpoint_prefix = "checkpoint"
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args['output_dir'],
                            "{}-{}-{:.3f}".format(checkpoint_prefix, global_step, best_score),
                        )
                        os.makedirs(output_dir, exist_ok=True)
                        model_to_save = (model.module if hasattr(model, "module") else model)  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        # The tokenizer is pickled next to the model of THIS checkpoint.
                        # (`epoch_output_dir` is not defined yet during the first epoch
                        # and would point at a different directory afterwards.)
                        with open(f'{output_dir}/tokenizer.pickle', 'wb') as handle:
                            pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)

                        torch.save(args, os.path.join(output_dir, "training_args.bin"))
                        print(f"Saving model checkpoint to {output_dir}")

                        _rotate_checkpoints(args, checkpoint_prefix)

                        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        print(f"Saving optimizer and scheduler states to {output_dir}")

                logging_loss = tr_loss

        if args['max_steps'] > 0 and global_step > args['max_steps']:
            epoch_iterator.close()
            break
    
    # Step budget exhausted: evaluate once more, save the final checkpoint and stop.
    if args['max_steps'] > 0 and global_step > args['max_steps']:
        results = evaluate(args, dev_dataset, seq2seq_collate_fn, model, tokenizer)
        cur_score = results["perplexity"].item()
        checkpoint_prefix = "checkpoint"
        # Save model checkpoint
        ckpt_output_dir = os.path.join(args['output_dir'], "{}-last-{:.3f}".format(checkpoint_prefix, cur_score))
        os.makedirs(ckpt_output_dir, exist_ok=True)
        model_to_save = (model.module if hasattr(model, "module") else model)  # Take care of distributed/parallel training
        model_to_save.save_pretrained(ckpt_output_dir)
        # Pickle the tokenizer into the final checkpoint directory itself.
        with open(f'{ckpt_output_dir}/tokenizer.pickle', 'wb') as handle:
            pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"Saving model checkpoint to {ckpt_output_dir}")
        train_iterator.close()
        break

    # End of a regular epoch: overwrite the rolling "last epoch" checkpoint.
    checkpoint_prefix = "checkpoint"
    epoch_output_dir = os.path.join(args['output_dir'], "{}-last-epoch".format(checkpoint_prefix),)
    os.makedirs(epoch_output_dir, exist_ok=True)
    model_to_save = (model.module if hasattr(model, "module") else model)  # Take care of distributed/parallel training
    model_to_save.save_pretrained(epoch_output_dir)
    with open(f'{epoch_output_dir}/tokenizer.pickle', 'wb') as handle:
        pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Saving model checkpoint to {epoch_output_dir}")
    # epoch_loss / epoch_step are not reset in this loop, so this is the mean
    # loss over every batch processed since they were initialised.
    avg_epoch_loss = epoch_loss / epoch_step
    print(f'avg_train_loss = {avg_epoch_loss}')


Epoch:   0%|          | 0/13 [13:06<?, ?it/s]


KeyboardInterrupt: 