In [1]:
import torch
import torch.nn as nn
import numpy as np
import random
import os
import logging
import argparse
import math
import numpy as np
from tqdm import tqdm
import multiprocessing
import time
from dotmap import DotMap
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import T5Config, T5ForConditionalGeneration, RobertaTokenizer

In [2]:
from models import build_or_load_gen_model
from evaluator import smooth_bleu
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from utils import get_filenames, get_elapse_time, load_and_cache_gen_data
from configs import set_seed, set_dist

In [3]:
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'

# Configure the root logger once for the whole notebook (INFO and above).
# logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)


def eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer):
    """Evaluate the model's perplexity on a development set.

    Runs teacher-forced forward passes over ``eval_data`` in order, without
    gradients. It returns ``exp`` of the mean per-batch loss.

    Args:
        args: run configuration. Reads ``eval_batch_size``, ``device``,
            ``model_type`` and ``n_gpu``.
        eval_data: dataset whose batches unpack into ``(source_ids, target_ids)``.
        eval_examples: the raw examples behind ``eval_data``. Only their count
            is used, for logging.
        model: the seq2seq model, possibly wrapped in ``torch.nn.DataParallel``.
        tokenizer: supplies ``pad_token_id``, used to build the attention masks.

    Returns:
        Perplexity rounded to 5 decimals. Every batch contributes equally to
        the mean, regardless of its size.

    Raises:
        ValueError: if ``eval_data`` yields no batches.
    """
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                 num_workers=4, pin_memory=True)
    logger.info("  " + "***** Running ppl evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    model.eval()
    eval_loss, batch_num = 0.0, 0
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval ppl"):
        source_ids, target_ids = (t.to(args.device) for t in batch)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        target_mask = target_ids.ne(tokenizer.pad_token_id)

        with torch.no_grad():
            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss

        # Under DataParallel each replica contributes its own loss. Reduce it to a
        # scalar, exactly as the training loop does, before calling .item().
        if args.n_gpu > 1:
            loss = loss.mean()
        eval_loss += loss.item()
        batch_num += 1

    if batch_num == 0:
        raise ValueError("eval_ppl_epoch: evaluation data is empty, cannot compute perplexity")
    eval_loss = eval_loss / batch_num
    eval_ppl = round(np.exp(eval_loss), 5)
    return eval_ppl

In [4]:
def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag, criteria):
    """Generate predictions for ``eval_data`` and score them against ``eval_examples``.

    Predicted, reference and source texts are written to
    ``<args.res_dir>/test_<criteria>.output``, ``.gold`` and ``.src``. The
    ``test_`` prefix is fixed and is used for every ``split_tag``.

    Scoring depends on ``args.task``:
      * ``'defect'``: exact match of the decoded text against ``'false'``/``'true'``.
        bleu and codebleu are reported as 0.
      * ``'summarize'``: smoothed BLEU-4 computed from ``idx``-prefixed lines.
      * anything else: BLEU from ``_bleu``. CodeBLEU is additionally computed for
        ``'concode'``, ``'translate'`` and ``'refine'``, and is 0 otherwise.

    Args:
        args: run configuration. Reads ``eval_batch_size``, ``data_num``,
            ``device``, ``model_type``, ``beam_size``, ``task``,
            ``max_target_length``, ``res_dir`` and ``lang``.
        eval_data: dataset whose batches carry the source ids in position 0.
        eval_examples: raw examples aligned with ``eval_data``. Each exposes
            ``source`` and ``target`` (and ``idx`` for summarization).
        model: the generation model.
        tokenizer: used for ``pad_token_id`` and for decoding generated ids.
        split_tag: split name, used only in log/progress messages.
        criteria: suffix of the dump file names.

    Returns:
        dict with keys ``'em'`` (percentage of exact matches), ``'bleu'`` and
        ``'codebleu'``. ``'codebleu'`` is the CodeBLEU value multiplied by 100.
    """
    logger.info("  ***** Running bleu evaluation on {} data*****".format(split_tag))
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    loader_kwargs = {'sampler': SequentialSampler(eval_data), 'batch_size': args.eval_batch_size}
    if args.data_num == -1:
        # Only the full-size run uses worker processes and pinned host memory.
        loader_kwargs.update(num_workers=4, pin_memory=True)
    eval_dataloader = DataLoader(eval_data, **loader_kwargs)

    model.eval()
    generated_ids = []
    progress = tqdm(eval_dataloader, total=len(eval_dataloader),
                    desc="Eval bleu for {} set".format(split_tag))
    for batch in progress:
        source_ids = batch[0].to(args.device)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        with torch.no_grad():
            if args.model_type != 'roberta':
                beams = model.generate(source_ids,
                                       attention_mask=source_mask,
                                       use_cache=True,
                                       num_beams=args.beam_size,
                                       early_stopping=args.task == 'summarize',
                                       max_length=args.max_target_length)
                generated_ids.extend(beams.cpu().numpy())
            else:
                beams = model(source_ids=source_ids, source_mask=source_mask)
                # Keep the first hypothesis per example (presumably the best beam).
                generated_ids.extend(hyps[0].cpu().numpy() for hyps in beams)

    pred_nls = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in generated_ids]

    output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria))
    gold_fn = os.path.join(args.res_dir, "test_{}.gold".format(criteria))
    src_fn = os.path.join(args.res_dir, "test_{}.src".format(criteria))

    if args.task == 'defect':
        label_text = {0: 'false', 1: 'true'}
        golds = [label_text[ex.target] for ex in eval_examples]
        eval_acc = np.mean([int(pred == gold) for pred, gold in zip(pred_nls, golds)])
        result = {'em': eval_acc * 100, 'bleu': 0, 'codebleu': 0}

        with open(output_fn, 'w') as out_f, open(gold_fn, 'w') as gold_f, open(src_fn, 'w') as src_f:
            for pred_nl, ex in zip(pred_nls, eval_examples):
                out_f.write(pred_nl.strip() + '\n')
                gold_f.write(label_text[ex.target] + '\n')
                src_f.write(ex.source.strip() + '\n')
            logger.info("Save the predictions into %s", output_fn)
    else:
        is_summarize = args.task == 'summarize'
        exact_matches, predictions = [], []
        with open(output_fn, 'w') as out_f, open(gold_fn, 'w') as gold_f, open(src_fn, 'w') as src_f:
            for pred_nl, ex in zip(pred_nls, eval_examples):
                exact_matches.append(pred_nl.strip() == ex.target.strip())
                if is_summarize:
                    # smooth-bleu4 evaluation works on "<idx>\t<text>" lines
                    prefix = str(ex.idx) + '\t'
                    predictions.append(prefix + pred_nl)
                    out_f.write(prefix + pred_nl.strip() + '\n')
                    gold_f.write(prefix + ex.target.strip() + '\n')
                    src_f.write(prefix + ex.source.strip() + '\n')
                else:
                    out_f.write(pred_nl.strip() + '\n')
                    gold_f.write(ex.target.strip() + '\n')
                    src_f.write(ex.source.strip() + '\n')

        codebleu = 0.0
        if is_summarize:
            gold_map, prediction_map = smooth_bleu.computeMaps(predictions, gold_fn)
            bleu = round(smooth_bleu.bleuFromMaps(gold_map, prediction_map)[0], 2)
        else:
            bleu = round(_bleu(gold_fn, output_fn), 2)
            if args.task in ('concode', 'translate', 'refine'):
                codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang)

        result = {'em': np.mean(exact_matches) * 100, 'bleu': bleu, 'codebleu': codebleu * 100}

    logger.info("***** Eval results *****")
    for key, value in sorted(result.items()):
        logger.info("  %s = %s", key, str(round(value, 4)))

    return result

In [7]:
# Seed every RNG this notebook relies on, in the same order as before:
# Python's `random`, NumPy's global RNG, torch's CPU RNG, and all CUDA devices.
seed_ = 1234
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed_)
del _seed_rng

In [8]:
# ---- Run configuration ------------------------------------------------------
# WORKDIR can be overridden with the CODET5_WORKDIR environment variable, so the
# notebook runs on another machine/user without editing a hard-coded home path.
WORKDIR = os.environ.get("CODET5_WORKDIR", "/home/okozlova/diplom_oksana/CodeT5/CodeT5")
MODEL_TAG = 'codet5_base'  # alternatives: 'codet5_small', 'codet5_large', 'codet5p-220m'
MODEL_DIR = 'sh/saved_models'
TASK = 'translate'
SUB_TASK = 'cs-java'       # source files are *.cs, targets are *.java (see *_filename below)
DATA_TAG = 'all'
LR = 5e-5
BS = 16                    # used for both train and eval batch size
SRC_LEN = 320              # max_source_length
TRG_LEN = 256              # max_target_length
PATIENCE = 5               # early stop once BOTH ppl and bleu+em stall for more than this many epochs
EPOCH = 100
SEED = 1234
GPU = 0
FULL_MODEL_TAG = f'{MODEL_TAG}_{DATA_TAG}_lr{LR}_bs{BS}_src{SRC_LEN}_trg{TRG_LEN}_pat{PATIENCE}_e{EPOCH}'
OUTPUT_DIR = f'{MODEL_DIR}/{TASK}/{SUB_TASK}/{FULL_MODEL_TAG}'
CACHE_DIR = f'{OUTPUT_DIR}/cache_data'
RES_DIR = f'{OUTPUT_DIR}/prediction'
LOG = f'{OUTPUT_DIR}/train.log'
# No trailing slash: the file names below append '/<name>' themselves.
DATA_DIR = f'{WORKDIR}/data/{TASK}'

# Each key appears exactly once. Python keeps the LAST value of a duplicated
# key, so a repeated entry would silently override an earlier one.
args_dict = {
    "task": TASK,
    "sub_task": SUB_TASK,
    "lang": 'java',  # passed to CodeBLEU as the target language ('c_sharp' for the reverse direction)
    "model_tag": MODEL_TAG,
    "res_dir": RES_DIR,  # where eval_bleu_epoch writes its .output/.gold/.src dumps
    "model_dir": MODEL_DIR,
    "summary_dir": 'tensorboard',
    "data_num": -1,  # -1 enables tensorboard and DataLoader workers below; presumably "use all examples"
    "gpu": GPU,
    "do_train": True,
    "do_eval": True,
    "do_eval_bleu": True,
    "do_test": True,
    "model_type": 'codet5',
    "num_train_epochs": EPOCH,
    "warmup_steps": 1000,
    "learning_rate": LR,
    "patience": PATIENCE,
    "tokenizer_name": 'Salesforce/codet5-base',
    "model_name_or_path": 'Salesforce/codet5-base',
    'data_dir': f'{WORKDIR}/data',
    'cache_path': CACHE_DIR,
    'output_dir': OUTPUT_DIR,
    'save_last_checkpoints': True,
    'always_save_model': True,
    'train_batch_size': BS,
    'eval_batch_size': BS,
    'max_source_length': SRC_LEN,
    'max_target_length': TRG_LEN,
    'seed': SEED,
    'local_rank': 0,
    'no_cuda': False,
    'n_gpu': 1,
    'device': torch.device("cuda", GPU),
    'train_filename': f'{DATA_DIR}/train.java-cs.txt.cs,{DATA_DIR}/train.java-cs.txt.java',
    'dev_filename': f'{DATA_DIR}/valid.java-cs.txt.cs,{DATA_DIR}/valid.java-cs.txt.java',
    'test_filename': f'{DATA_DIR}/test.java-cs.txt.cs,{DATA_DIR}/test.java-cs.txt.java',
    'weight_decay': 0.0,
    'adam_epsilon': 1e-8,
    'gradient_accumulation_steps': 1,  # batches accumulated per optimizer step
    'load_model_path': None,
    'cpu_cont': 16,  # size of the multiprocessing pool used for data loading/caching
    'beam_size': 10,
}


In [9]:
# Wrap the plain config dict so settings are reachable as attributes (args.task, args.device, ...).
# NOTE(review): DotMap in its default dynamic mode is understood to return an empty (falsy)
# DotMap for keys missing from args_dict instead of raising. A misspelled key therefore fails
# silently; the training cell's `if args.res_fn:` check depends on this, since 'res_fn' is not
# in args_dict. Verify against the installed dotmap version.
args = DotMap(args_dict)

In [10]:
logger.info(args)
t0 = time.time()


def _unwrap_model(model):
    """Return the bare model, stripping a ``torch.nn.DataParallel`` wrapper if present."""
    return model.module if hasattr(model, 'module') else model


def _save_checkpoint(model, output_dir, label):
    """Save the unwrapped model's ``state_dict`` to ``<output_dir>/pytorch_model.bin``.

    ``output_dir`` is created if missing. ``label`` only appears in the log line.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(_unwrap_model(model).state_dict(), output_model_file)
    logger.info("Save the %s model into %s", label, output_model_file)


def _should_stop_early(cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt, patience, summary_file):
    """Return True once both stall counters exceed ``patience``.

    When that happens, the reason is also logged and appended to ``summary_file``.
    """
    if not_bleu_em_inc_cnt > patience and not_loss_dec_cnt > patience:
        early_stop_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % (
            cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
        logger.info(early_stop_str)
        summary_file.write(early_stop_str)
        return True
    return False


# Create the run's output and prediction directories up front. Nothing in this
# notebook does so, so a fresh run would otherwise fail when summary.log or the
# prediction dumps are opened.
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.res_dir, exist_ok=True)

config, model, tokenizer = build_or_load_gen_model(args)
model.to(args.device)
if args.n_gpu > 1:
    model = torch.nn.DataParallel(model)
pool = multiprocessing.Pool(args.cpu_cont)
fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+')
tb_writer = None

# try/finally so the worker pool, summary log and tensorboard writer are released
# even if training is interrupted (e.g. a KeyboardInterrupt in the notebook).
try:
    if args.do_train:
        if args.local_rank in [-1, 0] and args.data_num == -1:
            summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:]))
            tb_writer = SummaryWriter(summary_fn)

        train_examples, train_data = load_and_cache_gen_data(args, args.train_filename, pool, tokenizer, 'train')
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size,
                                      num_workers=4, pin_memory=True)

        # No weight decay on biases and LayerNorm weights.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': args.weight_decay},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        # The scheduler advances once per optimizer step, i.e. once every
        # gradient_accumulation_steps batches, so size the schedule in optimizer steps.
        steps_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
        num_train_optimization_steps = args.num_train_epochs * steps_per_epoch
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=args.warmup_steps,
                                                    num_training_steps=num_train_optimization_steps)

        # Start training
        train_example_num = len(train_data)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", train_example_num)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Batch num = %d", math.ceil(train_example_num / args.train_batch_size))
        logger.info("  Num epoch = %d", args.num_train_epochs)

        dev_dataset = {}  # caches the dev set used for ppl so it is loaded only once
        global_step, best_bleu_em, best_ppl = 0, -1, 1e6
        # Without BLEU evaluation the bleu counter starts at 1e6 (permanently "stale"),
        # so early stopping then depends on perplexity alone.
        not_loss_dec_cnt, not_bleu_em_inc_cnt = 0, 0 if args.do_eval_bleu else 1e6

        for cur_epoch in range(int(args.num_train_epochs)):
            bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training")
            nb_tr_examples, nb_tr_steps, tr_loss = 0, 0, 0
            model.train()
            for batch in bar:
                source_ids, target_ids = (t.to(args.device) for t in batch)
                source_mask = source_ids.ne(tokenizer.pad_token_id)
                target_mask = target_ids.ne(tokenizer.pad_token_id)

                if args.model_type == 'roberta':
                    loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                       target_ids=target_ids, target_mask=target_mask)
                else:
                    outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                    labels=target_ids, decoder_attention_mask=target_mask)
                    loss = outputs.loss

                if args.n_gpu > 1:
                    loss = loss.mean()  # reduce the per-replica losses returned under DataParallel
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()

                nb_tr_examples += source_ids.size(0)
                nb_tr_steps += 1
                loss.backward()

                if nb_tr_steps % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    global_step += 1
                    # Mean un-scaled loss over the batches seen so far this epoch.
                    # nb_tr_steps was already incremented for this batch, so no +1.
                    train_loss = round(tr_loss * args.gradient_accumulation_steps / nb_tr_steps, 4)
                    bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3)))

            if args.do_eval:
                if 'dev_loss' in dev_dataset:
                    eval_examples, eval_data = dev_dataset['dev_loss']
                else:
                    eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev')
                    dev_dataset['dev_loss'] = eval_examples, eval_data

                eval_ppl = eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer)
                result = {'epoch': cur_epoch, 'global_step': global_step, 'eval_ppl': eval_ppl}
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                logger.info("  " + "*" * 20)
                if tb_writer is not None:
                    tb_writer.add_scalar('dev_ppl', eval_ppl, cur_epoch)

                if args.save_last_checkpoints:
                    _save_checkpoint(model, os.path.join(args.output_dir, 'checkpoint-last'), 'last')

                if eval_ppl < best_ppl:
                    not_loss_dec_cnt = 0
                    logger.info("  Best ppl:%s", eval_ppl)
                    logger.info("  " + "*" * 20)
                    fa.write("[%d] Best ppl changed into %.4f\n" % (cur_epoch, eval_ppl))
                    best_ppl = eval_ppl
                    if args.always_save_model:
                        _save_checkpoint(model, os.path.join(args.output_dir, 'checkpoint-best-ppl'), 'best ppl')
                else:
                    not_loss_dec_cnt += 1
                    logger.info("Ppl does not decrease for %d epochs", not_loss_dec_cnt)
                    if _should_stop_early(cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt, args.patience, fa):
                        break
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
                if args.do_eval_bleu:
                    eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev',
                                                                       only_src=True, is_sample=True)

                    result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev', 'e%d' % cur_epoch)
                    dev_bleu, dev_em = result['bleu'], result['em']
                    if args.task in ['summarize']:
                        dev_bleu_em = dev_bleu
                    elif args.task in ['defect']:
                        dev_bleu_em = dev_em
                    else:
                        dev_bleu_em = dev_bleu + dev_em
                    if tb_writer is not None:
                        tb_writer.add_scalar('dev_bleu_em', dev_bleu_em, cur_epoch)
                    if dev_bleu_em > best_bleu_em:
                        not_bleu_em_inc_cnt = 0
                        logger.info("  [%d] Best bleu+em: %.2f (bleu: %.2f, em: %.2f)",
                                    cur_epoch, dev_bleu_em, dev_bleu, dev_em)
                        logger.info("  " + "*" * 20)
                        best_bleu_em = dev_bleu_em
                        fa.write("[%d] Best bleu+em changed into %.2f (bleu: %.2f, em: %.2f)\n" % (
                            cur_epoch, best_bleu_em, dev_bleu, dev_em))
                        # Save best checkpoint for best bleu
                        if args.data_num == -1 or args.always_save_model:
                            _save_checkpoint(model, os.path.join(args.output_dir, 'checkpoint-best-bleu'),
                                             'best bleu')
                    else:
                        not_bleu_em_inc_cnt += 1
                        logger.info("Bleu does not increase for %d epochs", not_bleu_em_inc_cnt)
                        fa.write(
                            "[%d] Best bleu+em (%.2f) does not drop changed for %d epochs, cur bleu+em: %.2f (bleu: %.2f, em: %.2f)\n" % (
                                cur_epoch, best_bleu_em, not_bleu_em_inc_cnt, dev_bleu_em, dev_bleu, dev_em))
                        if _should_stop_early(cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt, args.patience, fa):
                            break
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()

        if tb_writer is not None:
            tb_writer.close()
            tb_writer = None
        logger.info("Finish training and take %s", get_elapse_time(t0))

    if args.do_test:
        logger.info("  " + "***** Testing *****")
        logger.info("  Batch size = %d", args.eval_batch_size)

        for criteria in ['best-bleu']:
            ckpt_file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
            logger.info("Reload model from {}".format(ckpt_file))
            # Checkpoints store the unwrapped model's state_dict, so load into the unwrapped
            # model (DataParallel would expect 'module.'-prefixed keys). map_location keeps
            # the tensors on args.device.
            _unwrap_model(model).load_state_dict(torch.load(ckpt_file, map_location=args.device))
            eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                               only_src=True, is_sample=False)
            result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria)
            test_bleu, test_em = result['bleu'], result['em']
            test_codebleu = result.get('codebleu', 0)
            result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu)
            logger.info(result_str)
            fa.write(result_str)
            if args.res_fn:
                with open(args.res_fn, 'a+') as f:
                    f.write('[Time: {}] {}\n'.format(get_elapse_time(t0), ckpt_file))
                    f.write(result_str)
    logger.info("Finish and take {}".format(get_elapse_time(t0)))
    # Trailing newline so the next run appending to summary.log starts on its own line.
    fa.write("Finish and take {}\n".format(get_elapse_time(t0)))
finally:
    if tb_writer is not None:
        tb_writer.close()
    fa.close()
    # No tasks are pending at this point; stop and reap the worker processes.
    pool.terminate()
    pool.join()


05/13/2024 12:46:29 - INFO - __main__ -   DotMap(task='translate', sub_task='cs-java', lang='java', model_tag='codet5_base', res_dir='sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/prediction', model_dir='sh/saved_models', summary_dir='tensorboard', data_num=-1, gpu=0, do_train=True, do_eval=True, do_eval_bleu=True, do_test=True, model_type='codet5', num_train_epochs=100, warmup_steps=1000, learning_rate=5e-05, patience=5, tokenizer_name='Salesforce/codet5-base', model_name_or_path='Salesforce/codet5-base', data_dir='/home/mzhelezin/diplom_oksana/CodeT5/CodeT5/data', cache_path='sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data', output_dir='sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100', save_last_checkpoints=True, always_save_model=True, train_batch_size=16, eval_batch_size=16, max_source_length=320, max_target_length=256, seed=1234, local_rank=0, no_cuda=F

args.model_type:  codet5


05/13/2024 12:46:31 - INFO - models -   Finish loading model [223M] from Salesforce/codet5-base
05/13/2024 12:46:37 - INFO - utils -   Read 10295 examples, avg src len: 15, avg trg len: 13, max src len: 118, max trg len: 136
05/13/2024 12:46:37 - INFO - utils -   [TOKENIZE] avg src len: 56, avg trg len: 45, max src len: 404, max trg len: 391
05/13/2024 12:46:37 - INFO - utils -   Load cache data from sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data/train_all.pt
05/13/2024 12:46:37 - INFO - __main__ -   ***** Running training *****
05/13/2024 12:46:37 - INFO - __main__ -     Num examples = 10295
05/13/2024 12:46:37 - INFO - __main__ -     Batch size = 16
05/13/2024 12:46:37 - INFO - __main__ -     Batch num = 644
05/13/2024 12:46:37 - INFO - __main__ -     Num epoch = 100
[0] Train loss 0.17: 100%|████████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:25<00:00,  1.98it/s]
05/13/2024 12:52:02 

ngram match: 0.6637586094527332, weighted ngram match: 0.6683766329126941, syntax_match: 0.8302033751622674, dataflow_match: 0.798823016564952


05/13/2024 12:55:11 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 12:55:11 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[1] Train loss 0.055: 100%|███████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.96it/s]
05/13/2024 13:00:39 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:00:39 - INFO - __main__ -     Num examples = 499
05/13/2024 13:00:39 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 13:00:45 - INFO - __main__ -     epoch = 1
05/13/2024 13:00:45 - INFO - __main__ -     eval_ppl = 1.04429
05/13/2024 13:00:45 - INFO - __main__ -     global_step = 1288
05/13/2024 13:00:45 - INFO - __main__ -     *****

ngram match: 0.7304604693411234, weighted ngram match: 0.7357119208223809, syntax_match: 0.8632626568585028, dataflow_match: 0.8323888404533566


05/13/2024 13:03:47 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 13:03:47 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[2] Train loss 0.04: 100%|████████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.96it/s]
05/13/2024 13:09:15 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:09:15 - INFO - __main__ -     Num examples = 499
05/13/2024 13:09:15 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 13:09:20 - INFO - __main__ -     epoch = 2
05/13/2024 13:09:20 - INFO - __main__ -     eval_ppl = 1.03917
05/13/2024 13:09:20 - INFO - __main__ -     global_step = 1932
05/13/2024 13:09:20 - INFO - __main__ -     *****

ngram match: 0.7512912518720135, weighted ngram match: 0.756249297194844, syntax_match: 0.8820424058848984, dataflow_match: 0.8700959023539668


05/13/2024 13:12:19 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 13:12:19 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[3] Train loss 0.03: 100%|████████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.96it/s]
05/13/2024 13:17:47 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:17:47 - INFO - __main__ -     Num examples = 499
05/13/2024 13:17:47 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 13:17:52 - INFO - __main__ -     epoch = 3
05/13/2024 13:17:52 - INFO - __main__ -     eval_ppl = 1.03704
05/13/2024 13:17:52 - INFO - __main__ -     global_step = 2576
05/13/2024 13:17:52 - INFO - __main__ -     *****

ngram match: 0.7659669577119461, weighted ngram match: 0.7708417258217134, syntax_match: 0.8925140631761143, dataflow_match: 0.8698779424585876


05/13/2024 13:20:52 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 13:20:52 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[4] Train loss 0.023: 100%|███████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 13:26:20 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:26:20 - INFO - __main__ -     Num examples = 499
05/13/2024 13:26:20 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 13:26:25 - INFO - __main__ -     epoch = 4
05/13/2024 13:26:25 - INFO - __main__ -     eval_ppl = 1.03543
05/13/2024 13:26:25 - INFO - __main__ -     global_step = 3220
05/13/2024 13:26:25 - INFO - __main__ -     *****

ngram match: 0.7829057855080891, weighted ngram match: 0.7880062920103581, syntax_match: 0.8967546516659455, dataflow_match: 0.8949433304272014


05/13/2024 13:29:20 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 13:29:20 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[5] Train loss 0.019: 100%|███████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 13:34:48 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:34:48 - INFO - __main__ -     Num examples = 499
05/13/2024 13:34:48 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 13:34:53 - INFO - __main__ -     epoch = 5
05/13/2024 13:34:53 - INFO - __main__ -     eval_ppl = 1.03538
05/13/2024 13:34:53 - INFO - __main__ -     global_step = 3864
05/13/2024 13:34:53 - INFO - __main__ -     *****

ngram match: 0.7858288561206691, weighted ngram match: 0.790895052193847, syntax_match: 0.9016875811337084, dataflow_match: 0.9003923278116827


05/13/2024 13:37:51 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 13:37:51 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[6] Train loss 0.015: 100%|███████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 13:43:18 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:43:18 - INFO - __main__ -     Num examples = 499
05/13/2024 13:43:18 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 13:43:24 - INFO - __main__ -     epoch = 6
05/13/2024 13:43:24 - INFO - __main__ -     eval_ppl = 1.03521
05/13/2024 13:43:24 - INFO - __main__ -     global_step = 4508
05/13/2024 13:43:24 - INFO - __main__ -     *****

ngram match: 0.7798880788481343, weighted ngram match: 0.7845152807024449, syntax_match: 0.8991778450887062, dataflow_match: 0.8916739319965127


[7] Train loss 0.012: 100%|███████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 13:51:52 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 13:51:52 - INFO - __main__ -     Num examples = 499
05/13/2024 13:51:52 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 13:51:58 - INFO - __main__ -     epoch = 7
05/13/2024 13:51:58 - INFO - __main__ -     eval_ppl = 1.03473
05/13/2024 13:51:58 - INFO - __main__ -     global_step = 5152
05/13/2024 13:51:58 - INFO - __main__ -     ********************
05/13/2024 13:51:59 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 13:51:59 - INFO - __main__ -     Best ppl:1.03473
05/13/

ngram match: 0.7848160427066382, weighted ngram match: 0.7898810565084641, syntax_match: 0.8922544353093899, dataflow_match: 0.8705318221447254


05/13/2024 13:55:03 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 13:55:03 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[8] Train loss 0.01: 100%|████████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:00:31 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:00:31 - INFO - __main__ -     Num examples = 499
05/13/2024 14:00:31 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 14:00:36 - INFO - __main__ -     epoch = 8
05/13/2024 14:00:36 - INFO - __main__ -     eval_ppl = 1.03612
05/13/2024 14:00:36 - INFO - __main__ -     global_step = 5796
05/13/2024 14:00:36 - INFO - __main__ -     *****

ngram match: 0.790452285534264, weighted ngram match: 0.7953546494024162, syntax_match: 0.8946776287321506, dataflow_match: 0.8982127288578902


[9] Train loss 0.009: 100%|███████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:09:09 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:09:09 - INFO - __main__ -     Num examples = 499
05/13/2024 14:09:09 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 14:09:14 - INFO - __main__ -     epoch = 9
05/13/2024 14:09:14 - INFO - __main__ -     eval_ppl = 1.0366
05/13/2024 14:09:14 - INFO - __main__ -     global_step = 6440
05/13/2024 14:09:14 - INFO - __main__ -     ********************
05/13/2024 14:09:15 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 14:09:15 - INFO - __main__ -   Ppl does not decrease for 

ngram match: 0.7816054146299463, weighted ngram match: 0.7861170532459663, syntax_match: 0.8997836434443963, dataflow_match: 0.8916739319965127


[10] Train loss 0.007: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:17:52 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:17:52 - INFO - __main__ -     Num examples = 499
05/13/2024 14:17:52 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 14:17:57 - INFO - __main__ -     epoch = 10
05/13/2024 14:17:57 - INFO - __main__ -     eval_ppl = 1.03884
05/13/2024 14:17:57 - INFO - __main__ -     global_step = 7084
05/13/2024 14:17:57 - INFO - __main__ -     ********************
05/13/2024 14:17:58 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 14:17:58 - INFO - __main__ -   Ppl does not decrease fo

ngram match: 0.791960713520326, weighted ngram match: 0.7978910853341352, syntax_match: 0.903331890956296, dataflow_match: 0.8984306887532694


05/13/2024 14:21:05 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 14:21:05 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[11] Train loss 0.006: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:26:32 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:26:32 - INFO - __main__ -     Num examples = 499
05/13/2024 14:26:32 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 14:26:38 - INFO - __main__ -     epoch = 11
05/13/2024 14:26:38 - INFO - __main__ -     eval_ppl = 1.03865
05/13/2024 14:26:38 - INFO - __main__ -     global_step = 7728
05/13/2024 14:26:38 - INFO - __main__ -     ****

ngram match: 0.7901473952241362, weighted ngram match: 0.7984051181416916, syntax_match: 0.8999567286888793, dataflow_match: 0.9040976460331299


[12] Train loss 0.005: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:35:07 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:35:07 - INFO - __main__ -     Num examples = 499
05/13/2024 14:35:07 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 14:35:12 - INFO - __main__ -     epoch = 12
05/13/2024 14:35:12 - INFO - __main__ -     eval_ppl = 1.03892
05/13/2024 14:35:12 - INFO - __main__ -     global_step = 8372
05/13/2024 14:35:12 - INFO - __main__ -     ********************
05/13/2024 14:35:13 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 14:35:13 - INFO - __main__ -   Ppl does not decrease fo

ngram match: 0.7924721221980376, weighted ngram match: 0.7981688962035566, syntax_match: 0.8960623106880139, dataflow_match: 0.8894943330427202


05/13/2024 14:38:26 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 14:38:26 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[13] Train loss 0.005: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:43:54 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:43:54 - INFO - __main__ -     Num examples = 499
05/13/2024 14:43:54 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 14:43:59 - INFO - __main__ -     epoch = 13
05/13/2024 14:43:59 - INFO - __main__ -     eval_ppl = 1.03949
05/13/2024 14:43:59 - INFO - __main__ -     global_step = 9016
05/13/2024 14:43:59 - INFO - __main__ -     ****

ngram match: 0.7935825838606049, weighted ngram match: 0.7995637868413059, syntax_match: 0.8950237992211164, dataflow_match: 0.8908020924149956


[14] Train loss 0.004: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 14:52:29 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 14:52:29 - INFO - __main__ -     Num examples = 499
05/13/2024 14:52:29 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 14:52:35 - INFO - __main__ -     epoch = 14
05/13/2024 14:52:35 - INFO - __main__ -     eval_ppl = 1.04101
05/13/2024 14:52:35 - INFO - __main__ -     global_step = 9660
05/13/2024 14:52:35 - INFO - __main__ -     ********************
05/13/2024 14:52:35 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 14:52:35 - INFO - __main__ -   Ppl does not decrease fo

ngram match: 0.7978330652346486, weighted ngram match: 0.8056532019343771, syntax_match: 0.9010817827780182, dataflow_match: 0.8962510897994769


[15] Train loss 0.004: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 15:01:10 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 15:01:10 - INFO - __main__ -     Num examples = 499
05/13/2024 15:01:10 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 15:01:15 - INFO - __main__ -     epoch = 15
05/13/2024 15:01:15 - INFO - __main__ -     eval_ppl = 1.04158
05/13/2024 15:01:15 - INFO - __main__ -     global_step = 10304
05/13/2024 15:01:15 - INFO - __main__ -     ********************
05/13/2024 15:01:16 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 15:01:16 - INFO - __main__ -   Ppl does not decrease f

ngram match: 0.7849812824163174, weighted ngram match: 0.7894414247304992, syntax_match: 0.898745131977499, dataflow_match: 0.8827375762859634


[16] Train loss 0.004: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 15:09:39 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 15:09:39 - INFO - __main__ -     Num examples = 499
05/13/2024 15:09:39 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/13/2024 15:09:44 - INFO - __main__ -     epoch = 16
05/13/2024 15:09:44 - INFO - __main__ -     eval_ppl = 1.04026
05/13/2024 15:09:44 - INFO - __main__ -     global_step = 10948
05/13/2024 15:09:44 - INFO - __main__ -     ********************
05/13/2024 15:09:45 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 15:09:45 - INFO - __main__ -   Ppl does not decrease f

ngram match: 0.7878155188817185, weighted ngram match: 0.7921267247250896, syntax_match: 0.8937256598874946, dataflow_match: 0.8864428945074107


[17] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 15:18:18 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 15:18:18 - INFO - __main__ -     Num examples = 499
05/13/2024 15:18:18 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 15:18:23 - INFO - __main__ -     epoch = 17
05/13/2024 15:18:23 - INFO - __main__ -     eval_ppl = 1.04261
05/13/2024 15:18:23 - INFO - __main__ -     global_step = 11592
05/13/2024 15:18:23 - INFO - __main__ -     ********************
05/13/2024 15:18:24 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 15:18:24 - INFO - __main__ -   Ppl does not decrease f

ngram match: 0.7881985513676216, weighted ngram match: 0.7927758468749043, syntax_match: 0.8959757680657724, dataflow_match: 0.9049694856146469


[18] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/13/2024 15:26:46 - INFO - __main__ -     ***** Running ppl evaluation *****
05/13/2024 15:26:46 - INFO - __main__ -     Num examples = 499
05/13/2024 15:26:46 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/13/2024 15:26:52 - INFO - __main__ -     epoch = 18
05/13/2024 15:26:52 - INFO - __main__ -     eval_ppl = 1.04231
05/13/2024 15:26:52 - INFO - __main__ -     global_step = 12236
05/13/2024 15:26:52 - INFO - __main__ -     ********************
05/13/2024 15:26:53 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/13/2024 15:26:53 - INFO - __main__ -   Ppl does not decrease f

ngram match: 0.7840599602002781, weighted ngram match: 0.7894498714684238, syntax_match: 0.9013414106447425, dataflow_match: 0.8997384481255449


05/13/2024 15:29:51 - INFO - utils -   Read 1000 examples, avg src len: 14, avg trg len: 13, max src len: 94, max trg len: 98
05/13/2024 15:29:51 - INFO - utils -   Load cache data from sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data/test_src_all.pt
05/13/2024 15:29:51 - INFO - __main__ -     ***** Running bleu evaluation on test data*****
05/13/2024 15:29:51 - INFO - __main__ -     Num examples = 1000
05/13/2024 15:29:51 - INFO - __main__ -     Batch size = 16
Eval bleu for test set: 100%|███████████████████████████████████████████████████████████████████████████████████████| 63/63 [04:51<00:00,  4.62s/it]
05/13/2024 15:34:52 - INFO - __main__ -   ***** Eval results *****
05/13/2024 15:34:52 - INFO - __main__ -     bleu = 79.08
05/13/2024 15:34:52 - INFO - __main__ -     codebleu = 84.8959
05/13/2024 15:34:52 - INFO - __main__ -     em = 66.6
05/13/2024 15:34:52 - INFO - __main__ -   [best-bleu] bleu-4: 79.08, em: 66.6000, codebleu: 84.8959

ngram match: 0.7907789938454809, weighted ngram match: 0.799421932584435, syntax_match: 0.9048335527553132, dataflow_match: 0.9007996001999


Process ForkPoolWorker-5:
Process ForkPoolWorker-13:
Process ForkPoolWorker-12:


In [14]:
# inference
# Reload each saved checkpoint (only 'best-bleu' here) into a freshly built model and
# report BLEU / exact-match / CodeBLEU on the test split.
# NOTE(review): this cell relies on `args`, `pool`, `t0` and `eval_bleu_epoch` having been
# defined by earlier cells of this notebook; run those first on a fresh kernel.

logger.info("  " + "***** Testing *****")
logger.info("  Batch size = %d", args.eval_batch_size)
config, model, tokenizer = build_or_load_gen_model(args)
model.to(args.device)

for criteria in ['best-bleu']:
    # Derive the checkpoint path from the criterion instead of hard-coding the best-bleu
    # file, so extending the list (e.g. 'best-ppl', 'last') loads the matching weights.
    # Assumes args.output_dir is the run directory the training log saved into
    # ('sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100')
    # -- TODO confirm if args was modified after training.
    file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
    logger.info("Reload model from {}".format(file))
    # map_location lets a checkpoint saved on GPU be restored onto args.device
    # (e.g. CPU or a different GPU index) instead of failing at load time.
    model.load_state_dict(torch.load(file, map_location=args.device))
    eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                       only_src=True, is_sample=False)
    result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria)
    test_bleu, test_em = result['bleu'], result['em']
    # CodeBLEU may be absent from the result; report 0 in that case.
    test_codebleu = result.get('codebleu', 0)
    result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu)
    logger.info(result_str)

logger.info("Finish and take {}".format(get_elapse_time(t0)))

05/13/2024 16:55:29 - INFO - __main__ -     ***** Testing *****
05/13/2024 16:55:29 - INFO - __main__ -     Batch size = 16


args.model_type:  codet5


05/13/2024 16:55:30 - INFO - models -   Finish loading model [223M] from Salesforce/codet5-base
05/13/2024 16:55:30 - INFO - __main__ -   Reload model from sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/13/2024 16:55:31 - INFO - utils -   Read 1000 examples, avg src len: 14, avg trg len: 13, max src len: 94, max trg len: 98
05/13/2024 16:55:31 - INFO - utils -   Load cache data from sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data/test_src_all.pt
05/13/2024 16:55:31 - INFO - __main__ -     ***** Running bleu evaluation on test data*****
05/13/2024 16:55:31 - INFO - __main__ -     Num examples = 1000
05/13/2024 16:55:31 - INFO - __main__ -     Batch size = 16


sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin


Eval bleu for test set: 100%|███████████████████████████████████████████████████████████████████████████████████████| 63/63 [04:51<00:00,  4.63s/it]
05/13/2024 17:00:32 - INFO - __main__ -   ***** Eval results *****
05/13/2024 17:00:32 - INFO - __main__ -     bleu = 79.08
05/13/2024 17:00:32 - INFO - __main__ -     codebleu = 84.8959
05/13/2024 17:00:32 - INFO - __main__ -     em = 66.6
05/13/2024 17:00:32 - INFO - __main__ -   [best-bleu] bleu-4: 79.08, em: 66.6000, codebleu: 84.8959

05/13/2024 17:00:32 - INFO - __main__ -   Finish and take 4h14m


ngram match: 0.7907789938454809, weighted ngram match: 0.799421932584435, syntax_match: 0.9048335527553132, dataflow_match: 0.9007996001999
