In [12]:
import torch
import torch.nn as nn
import numpy as np
import random
import os
import logging
import argparse
import math
import numpy as np
from tqdm import tqdm
import multiprocessing
import time
from dotmap import DotMap
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import T5Config, T5ForConditionalGeneration, RobertaTokenizer

In [13]:
from models import build_or_load_gen_model
from evaluator import smooth_bleu
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from utils import get_filenames, get_elapse_time, load_and_cache_gen_data
from configs import set_seed, set_dist

In [3]:
# Root logging: timestamped, INFO-level records for this notebook and for the
# project modules it imports (their loggers propagate to the root handler).
_log_config = dict(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logging.basicConfig(**_log_config)
logger = logging.getLogger(__name__)


def eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer):
    """Compute the perplexity of ``model`` on ``eval_data``.

    The loss returned by the model for each batch is summed, divided by the
    number of batches and exponentiated, i.e. perplexity is averaged per batch
    (not per token).

    Args:
        args: run configuration; reads ``eval_batch_size``, ``device``,
            ``model_type`` and ``n_gpu``.
        eval_data: dataset yielding ``(source_ids, target_ids)`` tensor pairs.
        eval_examples: raw examples; only their count is logged.
        model: the seq2seq model, optionally wrapped in ``DataParallel``.
        tokenizer: provides ``pad_token_id`` used to build attention masks.

    Returns:
        The perplexity, rounded to 5 decimals.

    Raises:
        ValueError: if ``eval_data`` yields no batches.
    """
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                 num_workers=4, pin_memory=True)
    # Start evaluating model
    logger.info("  " + "***** Running ppl evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    model.eval()
    eval_loss, batch_num = 0.0, 0
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval ppl"):
        source_ids, target_ids = (t.to(args.device) for t in batch)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        target_mask = target_ids.ne(tokenizer.pad_token_id)

        with torch.no_grad():
            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss

        if args.n_gpu > 1:
            # DataParallel gathers one loss per GPU; average them (as the training loop does)
            # so that .item() receives a scalar.
            loss = loss.mean()
        eval_loss += loss.item()
        batch_num += 1

    if batch_num == 0:
        raise ValueError("eval_ppl_epoch: evaluation data produced no batches")
    return round(np.exp(eval_loss / batch_num), 5)

In [4]:
def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag, criteria):
    """Generate predictions for ``eval_data``, write them to disk and score them.

    Predictions, references and sources are written to
    ``<args.res_dir>/test_<criteria>.{output,gold,src}``. The ``test_`` prefix is
    used for every split; ``criteria`` is what distinguishes the files.

    Args:
        args: run configuration; reads ``eval_batch_size``, ``data_num``, ``device``,
            ``model_type``, ``beam_size``, ``task``, ``max_target_length`` and ``res_dir``.
        eval_data: dataset whose batches carry the source ids in position 0.
        eval_examples: raw examples aligned with ``eval_data``. Their ``target`` and
            ``source`` fields are used (and ``idx`` for the summarize task).
        model: the seq2seq model, optionally wrapped in ``DataParallel``.
        tokenizer: supplies ``pad_token_id`` and decodes the generated ids.
        split_tag: split name used only in log/progress messages (e.g. ``'dev'``).
        criteria: suffix for the output file names (e.g. ``'e3'`` or ``'best-bleu'``).

    Returns:
        dict with ``'em'`` (exact match in percent) and ``'bleu'``. For the defect
        task ``'em'`` is the accuracy and ``'bleu'``/``'codebleu'`` are 0.
    """
    logger.info("  ***** Running bleu evaluation on {} data*****".format(split_tag))
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_sampler = SequentialSampler(eval_data)
    if args.data_num == -1:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                     num_workers=4, pin_memory=True)
    else:
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    model.eval()
    # DataParallel does not expose ``generate``; call it on the wrapped module instead.
    gen_model = model.module if hasattr(model, 'module') else model
    pred_ids = []
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)):
        source_ids = batch[0].to(args.device)
        source_mask = source_ids.ne(tokenizer.pad_token_id)
        with torch.no_grad():
            if args.model_type == 'roberta':
                preds = model(source_ids=source_ids, source_mask=source_mask)
                # keep only the top beam of each example
                top_preds = [pred[0].cpu().numpy() for pred in preds]
            else:
                preds = gen_model.generate(source_ids,
                                           attention_mask=source_mask,
                                           use_cache=True,
                                           num_beams=args.beam_size,
                                           early_stopping=args.task == 'summarize',
                                           max_length=args.max_target_length)
                top_preds = list(preds.cpu().numpy())
            pred_ids.extend(top_preds)

    pred_nls = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in pred_ids]

    output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria))
    gold_fn = os.path.join(args.res_dir, "test_{}.gold".format(criteria))
    src_fn = os.path.join(args.res_dir, "test_{}.src".format(criteria))

    if args.task in ['defect']:
        target_dict = {0: 'false', 1: 'true'}
        golds = [target_dict[ex.target] for ex in eval_examples]
        eval_acc = np.mean([int(p == g) for p, g in zip(pred_nls, golds)])
        result = {'em': eval_acc * 100, 'bleu': 0, 'codebleu': 0}

        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                f.write(pred_nl.strip() + '\n')
                f1.write(target_dict[gold.target] + '\n')
                f2.write(gold.source.strip() + '\n')
            logger.info("Save the predictions into %s", output_fn)
    else:
        dev_accs, predictions = [], []
        with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1, open(src_fn, 'w') as f2:
            for pred_nl, gold in zip(pred_nls, eval_examples):
                dev_accs.append(pred_nl.strip() == gold.target.strip())
                if args.task in ['summarize']:
                    # for smooth-bleu4 evaluation: lines are "<idx>\t<text>"
                    predictions.append(str(gold.idx) + '\t' + pred_nl)
                    f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n')
                    f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n')
                    f2.write(str(gold.idx) + '\t' + gold.source.strip() + '\n')
                else:
                    f.write(pred_nl.strip() + '\n')
                    f1.write(gold.target.strip() + '\n')
                    f2.write(gold.source.strip() + '\n')

        if args.task == 'summarize':
            (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn)
            bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
        else:
            bleu = round(_bleu(gold_fn, output_fn), 2)
            # CodeBLEU is currently disabled; to re-enable it:
            # if args.task in ['concode', 'translate', 'refine']:
            #    codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang)

        result = {'em': np.mean(dev_accs) * 100, 'bleu': bleu}
        # result['codebleu'] = codebleu * 100

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(round(result[key], 4)))

    return result

In [5]:
# If dotmap is missing from the kernel's environment, run once: %pip install dotmap


In [6]:
from dotmap import DotMap

In [7]:
# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) so runs are repeatable.
seed_ = 1234
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed_)
del _seed_fn

In [8]:
# --- Run configuration -------------------------------------------------------
# Every path and name below is derived from these constants, so changing one
# (e.g. LR or MODEL_TAG) updates the output / checkpoint directories consistently.
WORKDIR = "/home/okozlova/diplom_oksana/CodeT5/CodeT5"
MODEL_TAG = 'codet5p-220m'
MODEL_DIR = 'sh/saved_models'
TASK = 'translate'
SUB_TASK = 'cs-java'
DATA_TAG = 'all'
LR = 5e-5
BS = 16
SRC_LEN = 320
TRG_LEN = 256
PATIENCE = 5
EPOCH = 100
FULL_MODEL_TAG = f'{MODEL_TAG}_{DATA_TAG}_lr{LR}_bs{BS}_src{SRC_LEN}_trg{TRG_LEN}_pat{PATIENCE}_e{EPOCH}'
OUTPUT_DIR = f'{MODEL_DIR}/{TASK}/{SUB_TASK}/{FULL_MODEL_TAG}'
CACHE_DIR = f'{OUTPUT_DIR}/cache_data'
RES_DIR = f'{OUTPUT_DIR}/prediction'
LOG = f'{OUTPUT_DIR}/train.log'
# Derived from WORKDIR rather than repeating the absolute path.
DATA_DIR = os.path.join(WORKDIR, 'data', TASK)


def _src_trg_files(split):
    """Return the "source_file,target_file" pair for ``split`` (C# source, Java target)."""
    stem = os.path.join(DATA_DIR, f'{split}.java-cs.txt')
    return f'{stem}.cs,{stem}.java'


args_dict = {
    "task": TASK,
    "sub_task": SUB_TASK,
    "lang": 'java',  # alternative: 'c_sharp'
    "model_tag": MODEL_TAG,
    "model_dir": MODEL_DIR,
    "data_num": -1,
    "gpu": 0,
    "do_train": True,
    "do_eval": True,
    "do_eval_bleu": True,
    "do_test": True,
    "model_type": 'codet5',
    "num_train_epochs": EPOCH,
    "warmup_steps": 1000,
    "learning_rate": LR,
    "patience": PATIENCE,
    "tokenizer_name": f'Salesforce/{MODEL_TAG}',
    "model_name_or_path": f'Salesforce/{MODEL_TAG}',
    'data_dir': os.path.join(WORKDIR, 'data'),
    'cache_path': CACHE_DIR,
    'output_dir': OUTPUT_DIR,
    'summary_dir': 'tensorboard',
    'save_last_checkpoints': True,
    'always_save_model': True,
    'res_dir': RES_DIR,
    'train_batch_size': BS,
    'eval_batch_size': BS,
    'max_source_length': SRC_LEN,
    'max_target_length': TRG_LEN,
    'seed': 1234,  # keep in sync with seed_ used to seed the global RNGs
    'local_rank': 0,
    'no_cuda': False,
    'n_gpu': 1,
    'device': torch.device("cuda", 0),
    'train_filename': _src_trg_files('train'),
    'dev_filename': _src_trg_files('valid'),
    'test_filename': _src_trg_files('test'),
    'weight_decay': 0.0,
    'adam_epsilon': 1e-8,
    'gradient_accumulation_steps': 1,  # Number of updates steps to accumulate before performing a backward/update pass
    'load_model_path': None,
    'cpu_cont': 16,
    'beam_size': 10,
}


In [9]:
# Wrap the plain config dict so the project helpers can use attribute access
# (args.task, args.device, ...).
# NOTE(review): DotMap is expected to return an empty, falsy DotMap for a missing key
# instead of raising (the test loop's `if args.res_fn:` relies on that) — confirm.
args = DotMap(args_dict)

In [10]:
logger.info(args)
t0 = time.time()

config, model, tokenizer = build_or_load_gen_model(args)
# Warm-start from a previously fine-tuned checkpoint kept under this run's output dir.
# Derived from args.output_dir, so it follows the config constants instead of a hard-coded path.
file = os.path.join(args.output_dir, 'checkpoint-best-decompose', 'pytorch_model.bin')
print(file)
logger.info("Reload model from {}".format(file))
# Load the weights onto the CPU; model.to() below moves them to the target device once.
model.load_state_dict(torch.load(file, map_location='cpu'))
model.to(args.device)
if args.n_gpu > 1:
    model = torch.nn.DataParallel(model)
# Worker pool used by load_and_cache_gen_data for preprocessing.
pool = multiprocessing.Pool(args.cpu_cont)
# Run summary (best scores, early stops), appended across runs.
fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+')

def _save_model_state(model, ckpt_dir, label):
    """Save ``model``'s weights to ``<ckpt_dir>/pytorch_model.bin`` and log the path.

    Args:
        model: the model being trained, possibly wrapped in ``DataParallel``.
        ckpt_dir: checkpoint directory; created if missing.
        label: name used in the log line ("last", "best ppl", "best bleu").
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    # Unwrap DataParallel so the saved keys carry no ``module.`` prefix.
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(ckpt_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Save the %s model into %s", label, output_model_file)


def _patience_exhausted(patience, *no_improve_counts):
    """Return True once every no-improvement counter exceeds ``patience``."""
    return all(count > patience for count in no_improve_counts)


if args.do_train:
    # TensorBoard writer is created only for the main process on full-data runs;
    # it stays None otherwise and every use below checks for that.
    tb_writer = None
    if args.local_rank in [-1, 0] and args.data_num == -1:
        summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:]))
        tb_writer = SummaryWriter(summary_fn)

    # Prepare training data loader
    train_examples, train_data = load_and_cache_gen_data(args, args.train_filename, pool, tokenizer, 'train')
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size,
                                  num_workers=4, pin_memory=True)

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    # The scheduler steps once per optimizer update, i.e. every
    # gradient_accumulation_steps batches, so count updates rather than batches.
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    num_train_optimization_steps = args.num_train_epochs * updates_per_epoch
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=args.warmup_steps,
                                                num_training_steps=num_train_optimization_steps)

    # Start training
    train_example_num = len(train_data)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_example_num)
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Batch num = %d", math.ceil(train_example_num / args.train_batch_size))
    logger.info("  Num epoch = %d", args.num_train_epochs)

    dev_dataset = {}
    global_step, best_bleu_em, best_ppl = 0, -1, 1e6
    not_loss_dec_cnt = 0
    # With BLEU eval disabled the BLEU counter starts far past any patience,
    # so early stopping then depends on perplexity alone.
    not_bleu_em_inc_cnt = 0 if args.do_eval_bleu else 1e6

    for cur_epoch in range(0, int(args.num_train_epochs)):
        bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training")
        nb_tr_examples, nb_tr_steps, tr_loss = 0, 0, 0
        model.train()
        for batch in bar:
            source_ids, target_ids = tuple(t.to(args.device) for t in batch)
            source_mask = source_ids.ne(tokenizer.pad_token_id)
            target_mask = target_ids.ne(tokenizer.pad_token_id)

            if args.model_type == 'roberta':
                loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
                                   target_ids=target_ids, target_mask=target_mask)
            else:
                outputs = model(input_ids=source_ids, attention_mask=source_mask,
                                labels=target_ids, decoder_attention_mask=target_mask)
                loss = outputs.loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()

            nb_tr_examples += source_ids.size(0)
            nb_tr_steps += 1
            loss.backward()

            if nb_tr_steps % args.gradient_accumulation_steps == 0:
                # Update parameters
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1
                # Running mean of the (un-scaled) loss over the nb_tr_steps batches seen this epoch.
                train_loss = round(tr_loss * args.gradient_accumulation_steps / nb_tr_steps, 4)
                bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3)))

        if args.do_eval:
            # Eval model with dev dataset (loaded once, then reused every epoch)
            if 'dev_loss' not in dev_dataset:
                dev_dataset['dev_loss'] = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev')
            eval_examples, eval_data = dev_dataset['dev_loss']

            eval_ppl = eval_ppl_epoch(args, eval_data, eval_examples, model, tokenizer)
            result = {'epoch': cur_epoch, 'global_step': global_step, 'eval_ppl': eval_ppl}
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
            logger.info("  " + "*" * 20)
            if tb_writer is not None:
                tb_writer.add_scalar('dev_ppl', eval_ppl, cur_epoch)

            # save last checkpoint
            if args.save_last_checkpoints:
                _save_model_state(model, os.path.join(args.output_dir, 'checkpoint-last'), 'last')

            if eval_ppl < best_ppl:
                not_loss_dec_cnt = 0
                logger.info("  Best ppl:%s", eval_ppl)
                logger.info("  " + "*" * 20)
                fa.write("[%d] Best ppl changed into %.4f\n" % (cur_epoch, eval_ppl))
                best_ppl = eval_ppl

                # Save best checkpoint for best ppl
                if args.always_save_model:
                    _save_model_state(model, os.path.join(args.output_dir, 'checkpoint-best-ppl'), 'best ppl')
            else:
                not_loss_dec_cnt += 1
                logger.info("Ppl does not decrease for %d epochs", not_loss_dec_cnt)
                if _patience_exhausted(args.patience, not_bleu_em_inc_cnt, not_loss_dec_cnt):
                    early_stop_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % (
                        cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
                    logger.info(early_stop_str)
                    fa.write(early_stop_str)
                    break
            logger.info("***** CUDA.empty_cache() *****")
            torch.cuda.empty_cache()
            if args.do_eval_bleu:
                eval_examples, eval_data = load_and_cache_gen_data(args, args.dev_filename, pool, tokenizer, 'dev',
                                                                   only_src=True, is_sample=True)

                result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev', 'e%d' % cur_epoch)
                dev_bleu, dev_em = result['bleu'], result['em']
                # Model-selection score depends on the task.
                if args.task in ['summarize']:
                    dev_bleu_em = dev_bleu
                elif args.task in ['defect']:
                    dev_bleu_em = dev_em
                else:
                    dev_bleu_em = dev_bleu + dev_em
                if tb_writer is not None:
                    tb_writer.add_scalar('dev_bleu_em', dev_bleu_em, cur_epoch)
                if dev_bleu_em > best_bleu_em:
                    not_bleu_em_inc_cnt = 0
                    logger.info("  [%d] Best bleu+em: %.2f (bleu: %.2f, em: %.2f)",
                                cur_epoch, dev_bleu_em, dev_bleu, dev_em)
                    logger.info("  " + "*" * 20)
                    best_bleu_em = dev_bleu_em
                    fa.write("[%d] Best bleu+em changed into %.2f (bleu: %.2f, em: %.2f)\n" % (
                        cur_epoch, best_bleu_em, dev_bleu, dev_em))
                    # Save best checkpoint for best bleu
                    if args.data_num == -1 or args.always_save_model:
                        _save_model_state(model, os.path.join(args.output_dir, 'checkpoint-best-bleu'), 'best bleu')
                else:
                    not_bleu_em_inc_cnt += 1
                    logger.info("Bleu does not increase for %d epochs", not_bleu_em_inc_cnt)
                    fa.write(
                        "[%d] Best bleu+em (%.2f) does not drop changed for %d epochs, cur bleu+em: %.2f (bleu: %.2f, em: %.2f)\n" % (
                            cur_epoch, best_bleu_em, not_bleu_em_inc_cnt, dev_bleu_em, dev_bleu, dev_em))
                    if _patience_exhausted(args.patience, not_bleu_em_inc_cnt, not_loss_dec_cnt):
                        stop_early_str = "[%d] Early stop as not_bleu_em_inc_cnt=%d, and not_loss_dec_cnt=%d\n" % (
                            cur_epoch, not_bleu_em_inc_cnt, not_loss_dec_cnt)
                        logger.info(stop_early_str)
                        fa.write(stop_early_str)
                        break
        logger.info("***** CUDA.empty_cache() *****")
        torch.cuda.empty_cache()

    if tb_writer is not None:
        tb_writer.close()
    logger.info("Finish training and take %s", get_elapse_time(t0))

if args.do_test:
    logger.info("  " + "***** Testing *****")
    logger.info("  Batch size = %d", args.eval_batch_size)

    # Checkpoints are saved from the unwrapped model (no ``module.`` key prefix),
    # so load them into the bare module when training wrapped it in DataParallel.
    model_to_load = model.module if hasattr(model, 'module') else model
    for criteria in ['best-bleu']:
        file = os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
        logger.info("Reload model from {}".format(file))
        model_to_load.load_state_dict(torch.load(file, map_location='cpu'))
        eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                           only_src=True, is_sample=False)
        result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria)
        test_bleu, test_em = result['bleu'], result['em']
        test_codebleu = result.get('codebleu', 0)
        result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu)
        logger.info(result_str)
        fa.write(result_str)
        # NOTE(review): ``res_fn`` is not in args_dict; this relies on DotMap returning a
        # falsy empty value for missing keys — confirm.
        if args.res_fn:
            with open(args.res_fn, 'a+') as f:
                f.write('[Time: {}] {}\n'.format(get_elapse_time(t0), file))
                f.write(result_str)
elapsed = get_elapse_time(t0)
logger.info("Finish and take {}".format(elapsed))
# summary.log is opened in append mode, so terminate the line; otherwise the
# next run's first entry is glued onto this one.
fa.write("Finish and take {}\n".format(elapsed))
fa.close()
# Release the preprocessing worker processes created for this run.
pool.close()
pool.join()


05/24/2024 12:59:16 - INFO - __main__ -   DotMap(task='translate', sub_task='cs-java', lang='java', model_tag='codet5p-220m', res_dir='sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/prediction', model_dir='sh/saved_models', summary_dir='tensorboard', data_num=-1, gpu=0, do_train=True, do_eval=True, do_eval_bleu=True, do_test=True, model_type='codet5', num_train_epochs=100, warmup_steps=1000, learning_rate=5e-05, patience=5, tokenizer_name='Salesforce/codet5p-220m', model_name_or_path='Salesforce/codet5p-220m', data_dir='/home/mzhelezin/diplom_oksana/CodeT5/CodeT5/data', cache_path='sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data', output_dir='sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100', save_last_checkpoints=True, always_save_model=True, train_batch_size=16, eval_batch_size=16, max_source_length=320, max_target_length=256, seed=1234, local_rank=0, no_

args.model_type:  codet5


05/24/2024 12:59:18 - INFO - models -   Finish loading model [223M] from Salesforce/codet5p-220m
05/24/2024 12:59:18 - INFO - __main__ -   Reload model from sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-decompose/pytorch_model.bin


sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-decompose/pytorch_model.bin


05/24/2024 12:59:24 - INFO - utils -   Read 10295 examples, avg src len: 15, avg trg len: 13, max src len: 118, max trg len: 136
05/24/2024 12:59:24 - INFO - utils -   [TOKENIZE] avg src len: 56, avg trg len: 45, max src len: 404, max trg len: 391
05/24/2024 12:59:24 - INFO - utils -   Load cache data from sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data/train_all.pt
05/24/2024 12:59:24 - INFO - __main__ -   ***** Running training *****
05/24/2024 12:59:24 - INFO - __main__ -     Num examples = 10295
05/24/2024 12:59:24 - INFO - __main__ -     Batch size = 16
05/24/2024 12:59:24 - INFO - __main__ -     Batch num = 644
05/24/2024 12:59:24 - INFO - __main__ -     Num epoch = 100
[0] Train loss 0.021: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:24<00:00,  1.98it/s]
05/24/2024 13:04:49 - INFO - utils -   Read 499 examples, avg src len: 16, avg trg len: 15, max src len: 99, max trg len

ngram match: 0.7678423301186774, weighted ngram match: 0.7746447792035402, syntax_match: 0.8942449156209433, dataflow_match: 0.9019180470793374


05/24/2024 13:07:55 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/24/2024 13:07:55 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[1] Train loss 0.006: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:26<00:00,  1.97it/s]
05/24/2024 13:13:23 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 13:13:23 - INFO - __main__ -     Num examples = 499
05/24/2024 13:13:23 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 13:13:28 - INFO - __main__ -     epoch = 1
05/24/2024 13:13:28 - INFO - __main__ -     eval_ppl = 1.03558
05/24/2024 13:13:28 - INFO - __main__ -     global_step = 1288
05/24/2024 13:13:28 - INFO - __main__ -     **************

ngram match: 0.7730145884303201, weighted ngram match: 0.7956025378623895, syntax_match: 0.8984855041107745, dataflow_match: 0.8836094158674804


05/24/2024 13:16:46 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/24/2024 13:16:46 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[2] Train loss 0.004: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:26<00:00,  1.97it/s]
05/24/2024 13:22:13 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 13:22:13 - INFO - __main__ -     Num examples = 499
05/24/2024 13:22:13 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 13:22:18 - INFO - __main__ -     epoch = 2
05/24/2024 13:22:18 - INFO - __main__ -     eval_ppl = 1.03604
05/24/2024 13:22:18 - INFO - __main__ -     global_step = 1932
05/24/2024 13:22:18 - INFO - __main__ -     **************

ngram match: 0.7943773747408851, weighted ngram match: 0.7992650595224503, syntax_match: 0.9016010385114669, dataflow_match: 0.8990845684394071


05/24/2024 13:25:31 - INFO - __main__ -   Save the best bleu model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin
05/24/2024 13:25:31 - INFO - __main__ -   ***** CUDA.empty_cache() *****
[3] Train loss 0.004: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:26<00:00,  1.97it/s]
05/24/2024 13:30:58 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 13:30:58 - INFO - __main__ -     Num examples = 499
05/24/2024 13:30:58 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 13:31:03 - INFO - __main__ -     epoch = 3
05/24/2024 13:31:03 - INFO - __main__ -     eval_ppl = 1.03743
05/24/2024 13:31:03 - INFO - __main__ -     global_step = 2576
05/24/2024 13:31:03 - INFO - __main__ -     **************

ngram match: 0.7734674821393224, weighted ngram match: 0.784006890371543, syntax_match: 0.8888792730419731, dataflow_match: 0.8761987794245859


[4] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:26<00:00,  1.97it/s]
05/24/2024 13:39:47 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 13:39:47 - INFO - __main__ -     Num examples = 499
05/24/2024 13:39:47 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 13:39:52 - INFO - __main__ -     epoch = 4
05/24/2024 13:39:52 - INFO - __main__ -     eval_ppl = 1.0376
05/24/2024 13:39:52 - INFO - __main__ -     global_step = 3220
05/24/2024 13:39:52 - INFO - __main__ -     ********************
05/24/2024 13:39:53 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/24/2024 13:39:53 - INFO - __main__ -   Ppl does not decrease for 3 epochs


ngram match: 0.7729337580169928, weighted ngram match: 0.7832488181173539, syntax_match: 0.9034184335785375, dataflow_match: 0.9106364428945074


[5] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/24/2024 13:48:28 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 13:48:28 - INFO - __main__ -     Num examples = 499
05/24/2024 13:48:28 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 13:48:34 - INFO - __main__ -     epoch = 5
05/24/2024 13:48:34 - INFO - __main__ -     eval_ppl = 1.03724
05/24/2024 13:48:34 - INFO - __main__ -     global_step = 3864
05/24/2024 13:48:34 - INFO - __main__ -     ********************
05/24/2024 13:48:34 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/24/2024 13:48:34 - INFO - __main__ -   Ppl does not decrease for 4 epochs

ngram match: 0.7846679724055893, weighted ngram match: 0.8036756559589567, syntax_match: 0.9012548680225011, dataflow_match: 0.9071490845684395


[6] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/24/2024 13:57:16 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 13:57:16 - INFO - __main__ -     Num examples = 499
05/24/2024 13:57:16 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 13:57:22 - INFO - __main__ -     epoch = 6
05/24/2024 13:57:22 - INFO - __main__ -     eval_ppl = 1.03795
05/24/2024 13:57:22 - INFO - __main__ -     global_step = 4508
05/24/2024 13:57:22 - INFO - __main__ -     ********************
05/24/2024 13:57:22 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/24/2024 13:57:22 - INFO - __main__ -   Ppl does not decrease for 5 epochs

ngram match: 0.7841568447350579, weighted ngram match: 0.7990756290428533, syntax_match: 0.9018606663781913, dataflow_match: 0.9003923278116827


[7] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:27<00:00,  1.97it/s]
05/24/2024 14:05:57 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 14:05:57 - INFO - __main__ -     Num examples = 499
05/24/2024 14:05:57 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.04it/s]
05/24/2024 14:06:02 - INFO - __main__ -     epoch = 7
05/24/2024 14:06:02 - INFO - __main__ -     eval_ppl = 1.03715
05/24/2024 14:06:02 - INFO - __main__ -     global_step = 5152
05/24/2024 14:06:02 - INFO - __main__ -     ********************
05/24/2024 14:06:03 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/24/2024 14:06:03 - INFO - __main__ -   Ppl does not decrease for 6 epochs

ngram match: 0.7928441182172812, weighted ngram match: 0.7992552257674952, syntax_match: 0.8999567286888793, dataflow_match: 0.9067131647776809


[8] Train loss 0.003: 100%|██████████████████████████████████████████████████████████████████████████████████| 644/644 [05:26<00:00,  1.97it/s]
05/24/2024 14:14:31 - INFO - __main__ -     ***** Running ppl evaluation *****
05/24/2024 14:14:31 - INFO - __main__ -     Num examples = 499
05/24/2024 14:14:31 - INFO - __main__ -     Batch size = 16
Eval ppl: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:05<00:00,  6.03it/s]
05/24/2024 14:14:36 - INFO - __main__ -     epoch = 8
05/24/2024 14:14:36 - INFO - __main__ -     eval_ppl = 1.03815
05/24/2024 14:14:36 - INFO - __main__ -     global_step = 5796
05/24/2024 14:14:36 - INFO - __main__ -     ********************
05/24/2024 14:14:37 - INFO - __main__ -   Save the last model into sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-last/pytorch_model.bin
05/24/2024 14:14:37 - INFO - __main__ -   Ppl does not decrease for 7 epochs

ngram match: 0.7828271110822578, weighted ngram match: 0.798854407122174, syntax_match: 0.8906966681090437, dataflow_match: 0.8938535309503052


05/24/2024 14:17:57 - INFO - utils -   Read 1000 examples, avg src len: 14, avg trg len: 13, max src len: 94, max trg len: 98
05/24/2024 14:17:57 - INFO - utils -   Load cache data from sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/cache_data/test_src_all.pt
05/24/2024 14:17:57 - INFO - __main__ -     ***** Running bleu evaluation on test data*****
05/24/2024 14:17:57 - INFO - __main__ -     Num examples = 1000
05/24/2024 14:17:57 - INFO - __main__ -     Batch size = 16
Eval bleu for test set: 100%|██████████████████████████████████████████████████████████████████████████████████| 63/63 [05:09<00:00,  4.91s/it]
05/24/2024 14:23:16 - INFO - __main__ -   ***** Eval results *****
05/24/2024 14:23:16 - INFO - __main__ -     bleu = 79.7
05/24/2024 14:23:16 - INFO - __main__ -     codebleu = 85.1677
05/24/2024 14:23:16 - INFO - __main__ -     em = 65.3
05/24/2024 14:23:16 - INFO - __main__ -   [best-bleu] bleu-4: 79.70, em: 65.3000, codebleu: 85.1677

ngram match: 0.7969467608100157, weighted ngram match: 0.8056190591647376, syntax_match: 0.9067143125822832, dataflow_match: 0.8974262868565717


In [10]:
def count_parameters(model):
    """Return the total number of scalar values in `model`'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [12]:
# Display the trainable-parameter count of `model` (which must already exist from an earlier cell).
count_parameters(model)

222882048

In [10]:
# inference: reload the best-BLEU checkpoint and score it on the test split.

t0 = time.time()

logger.info("  " + "***** Testing *****")
logger.info("  Batch size = %d", args.eval_batch_size)
# args.model_name_or_path = Salesforce/codet5-base'
# The worker pool is handed to load_and_cache_gen_data below; it is created
# before the model is built/moved (same order as before) and released at the end.
pool = multiprocessing.Pool(args.cpu_cont)
try:
    config, model, tokenizer = build_or_load_gen_model(args)
    model.to(args.device)

    for criteria in ['best-bleu']:
        # Hard-coded checkpoint; the generic form is
        # os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
        file = 'sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin'
        print(file)
        logger.info("Reload model from {}".format(file))
        # map_location lets a checkpoint saved on another device load onto args.device.
        model.load_state_dict(torch.load(file, map_location=args.device))
        eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                           only_src=True, is_sample=False)
        result = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test', criteria)
        test_bleu, test_em = result['bleu'], result['em']
        test_codebleu = result['codebleu'] if 'codebleu' in result else 0
        result_str = "[%s] bleu-4: %.2f, em: %.4f, codebleu: %.4f\n" % (criteria, test_bleu, test_em, test_codebleu)
        logger.info(result_str)
finally:
    # Release the worker processes; without this every re-run of the cell leaks a pool.
    pool.close()
    pool.join()
logger.info("Finish and take {}".format(get_elapse_time(t0)))

05/14/2024 14:36:45 - INFO - __main__ -     ***** Testing *****
05/14/2024 14:36:45 - INFO - __main__ -     Batch size = 4


args.model_type:  codet5


05/14/2024 14:36:47 - INFO - models -   Finish loading model [738M] from Salesforce/codet5-large
05/14/2024 14:36:49 - INFO - __main__ -   Reload model from sh/saved_models/translate/cs-java/codet5_large_all_lr5e-05_bs4_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin


sh/saved_models/translate/cs-java/codet5_large_all_lr5e-05_bs4_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin


05/14/2024 14:36:50 - INFO - utils -   Read 1000 examples, avg src len: 14, avg trg len: 13, max src len: 94, max trg len: 98
05/14/2024 14:36:50 - INFO - utils -   Load cache data from sh/saved_models/translate/cs-java/codet5_large_all_lr5e-05_bs4_src320_trg256_pat5_e100/cache_data/test_src_all.pt
05/14/2024 14:36:50 - INFO - __main__ -     ***** Running bleu evaluation on test data*****
05/14/2024 14:36:50 - INFO - __main__ -     Num examples = 1000
05/14/2024 14:36:50 - INFO - __main__ -     Batch size = 4
Eval bleu for test set: 100%|█████████████████████████████████████████████████████████████████████████████████████| 250/250 [09:47<00:00,  2.35s/it]
05/14/2024 14:46:44 - INFO - __main__ -   ***** Eval results *****
05/14/2024 14:46:44 - INFO - __main__ -     bleu = 78.06
05/14/2024 14:46:44 - INFO - __main__ -     codebleu = 83.9113
05/14/2024 14:46:44 - INFO - __main__ -     em = 66.2
05/14/2024 14:46:44 - INFO - __main__ -   [best-bleu] bleu-4: 78.06, em: 66.2000, codebleu: 83.

ngram match: 0.780570768115452, weighted ngram match: 0.7873614145521678, syntax_match: 0.9012130900883957, dataflow_match: 0.8873063468265867


Process ForkPoolWorker-16:


In [None]:
import torch
import torch.nn as nn
import tensorly as tl
from tensorly.tenalg import svd_interface
from nltk.translate.bleu_score import corpus_bleu
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import json

# Module-level accumulator of accepted decompositions: save_decomposition_info()
# appends one dict per accepted layer and write_decomposition_info_to_file()
# dumps the whole list as JSON. Re-running this cell resets it.
decomposition_info = []

def count_parameters(model):
    """Sum of ``numel()`` over the trainable (``requires_grad``) parameters of `model`.

    NOTE: identical in behaviour to the count_parameters defined in an earlier
    cell; it is repeated so this cell can be run on its own.
    """
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(map(torch.numel, trainable))

def decompose_weights(weight_matrix, rank):
    """Truncated SVD of a 2-D weight tensor via tensorly's ``svd_interface``.

    The tensor is detached and copied to a CPU numpy array before the SVD; the
    three factors are converted back to torch tensors on `weight_matrix`'s device.

    Returns:
        (U, S, V) such that weight_matrix ~= U @ diag(S) @ V (the convention
        reconstruct_weights uses), keeping `rank` singular values.
    """
    home_device = weight_matrix.device
    host_matrix = weight_matrix.cpu().detach().numpy()
    U, S, V = svd_interface(host_matrix, n_eigenvecs=rank)
    return tuple(torch.tensor(factor).to(home_device) for factor in (U, S, V))

def reconstruct_weights(U, S, V):
    """Rebuild the full matrix U @ diag(S) @ V from its SVD factors.

    Args:
        U: left factor, (m, r).
        S: 1-D tensor of r singular values.
        V: right factor, (r, n).
    """
    sigma = torch.diag(S)
    sigma_v = sigma @ V
    return U @ sigma_v

def decompose_and_replace_linear_layer(model, layer_name, rank):
    """Replace the nn.Linear at dotted path `layer_name` with a rank-`rank` factorization.

    nn.Linear computes y = x @ W.T + b with W of shape (out, in). With
    W ~= U @ diag(S) @ V from decompose_weights (U: (out, r), S: (r,), V: (r, in)),
    y ~= ((x @ V.T) @ diag(S)) @ U.T + b. The factors are therefore applied
    right-to-left: V first, then diag(S), then U, which also carries the
    original bias (if the layer had one).

    Args:
        model: root nn.Module containing the layer.
        layer_name: dotted name as produced by model.named_modules().
        rank: number of singular values to keep.

    Returns:
        (original_weight, U, S, V): a detached clone of the original weight and the SVD factors.
    """
    original_layer = model.get_submodule(layer_name)
    original_weight = original_layer.weight.detach().clone()
    U, S, V = decompose_weights(original_layer.weight, rank)

    weight = original_layer.weight
    out_features, in_features = weight.shape
    r = S.numel()
    has_bias = original_layer.bias is not None

    project_in = nn.Linear(in_features, r, bias=False)         # weight (r, in) = V
    scale = nn.Linear(r, r, bias=False)                        # weight (r, r) = diag(S)
    project_out = nn.Linear(r, out_features, bias=has_bias)    # weight (out, r) = U (+ original bias)

    with torch.no_grad():
        project_in.weight.copy_(V)
        scale.weight.copy_(torch.diag(S))
        project_out.weight.copy_(U)
        if has_bias:
            project_out.bias.copy_(original_layer.bias)

    replacement = nn.Sequential(project_in, scale, project_out).to(device=weight.device, dtype=weight.dtype)

    # setattr(model, "a.b.c", ...) would register a new top-level child literally
    # named "a.b.c" instead of replacing the nested layer; resolve the parent first.
    parent_name, _, child_name = layer_name.rpartition('.')
    setattr(model.get_submodule(parent_name), child_name, replacement)

    return original_weight, U, S, V

def check_layer_condition(layer, condition):
    """Return True when `layer.weight` has more than `condition` rows.

    For nn.Linear the weight is (out_features, in_features), so this selects
    layers by output width, not by total parameter count.
    """
    out_rows = layer.weight.shape[0]
    return condition < out_rows

def measure_bleu_score(model, tokenizer, data, device):
    """Corpus BLEU (nltk ``corpus_bleu``) of ``model.generate`` outputs against references.

    Args:
        model: seq2seq model exposing ``generate(input_ids)``; moved to `device`.
        tokenizer: object with ``encode(text, return_tensors='pt')`` and ``decode(ids, skip_special_tokens=True)``.
        data: iterable of (source_text, reference_text) pairs.
        device: device for the model and the encoded inputs.

    Returns:
        ``corpus_bleu`` over whitespace-tokenized references and hypotheses
        (one hypothesis per source, using generate's default settings).

    The model is switched to eval mode for generation and its previous
    train/eval mode is restored afterwards.
    """
    model.to(device)
    references = []
    hypotheses = []
    was_training = model.training
    # Nothing else here puts the model in eval mode; in train mode dropout would
    # stay active and make the generated hypotheses (and the score) noisy.
    model.eval()
    try:
        with torch.no_grad():
            for src, ref in data:
                input_ids = tokenizer.encode(src, return_tensors='pt').to(device)
                output = model.generate(input_ids)
                hypothesis = tokenizer.decode(output[0], skip_special_tokens=True)
                references.append([ref.split()])
                hypotheses.append(hypothesis.split())
    finally:
        model.train(was_training)
    return corpus_bleu(references, hypotheses)

def save_decomposition_info(layer_name, rank, loss_of_bleu, compression, initial_params, final_params):
    """Append one record describing an accepted layer decomposition.

    The record is a dict with keys, in order: layer_name, rank, loss_of_bleu,
    compression, initial_params, final_params. It is appended to the
    module-level ``decomposition_info`` list; nothing is returned.
    """
    record = dict(
        layer_name=layer_name,
        rank=rank,
        loss_of_bleu=loss_of_bleu,
        compression=compression,
        initial_params=initial_params,
        final_params=final_params,
    )
    decomposition_info.append(record)

def write_decomposition_info_to_file(filename):
    """Write the module-level ``decomposition_info`` list to `filename` as JSON (indent=4).

    The file is overwritten if it already exists.
    """
    with open(filename, mode='w') as handle:
        handle.write(json.dumps(decomposition_info, indent=4))

def decompose_transformer(transformer_model, tokenizer, rank, tolerance, condition, eval_examples, eval_data, device,
                          output_dir=None):
    """Greedily replace large nn.Linear layers with rank-`rank` SVD factorizations.

    Candidate layers are nn.Linear modules passing
    ``check_layer_condition(module, condition)``. Each candidate is factorized in
    place and the test-set BLEU is re-measured with ``eval_bleu_epoch``:
      * if bleu >= initial_bleu * (1 - tolerance) the factorization is kept, a
        record is appended via save_decomposition_info, and the current
        state_dict is saved to ``<output_dir>/pytorch_model.bin``;
      * otherwise the original layer object is put back.

    Args:
        transformer_model: model to modify in place.
        tokenizer, eval_examples, eval_data: forwarded to eval_bleu_epoch.
        rank: number of singular values kept per layer.
        tolerance: maximum accepted relative BLEU drop vs. the undecomposed model.
        condition: threshold passed to check_layer_condition.
        device: currently unused (only the commented-out measure_bleu_score call used it).
        output_dir: checkpoint directory; created if missing. Defaults to the
            codet5p-220m ``checkpoint-best-decompose`` directory used previously.

    Relies on the notebook-global ``args`` for eval_bleu_epoch.
    """
    if output_dir is None:
        output_dir = 'sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-decompose/'

    result = eval_bleu_epoch(args, eval_data, eval_examples, transformer_model, tokenizer, 'test', 'best-bleu')  # baseline metric
    initial_bleu = result['bleu']

    # Snapshot the candidates first: the loop swaps modules in the tree, and walking
    # named_modules() live would mutate what is being iterated and could revisit the
    # freshly inserted factor layers.
    candidates = [(name, module) for name, module in transformer_model.named_modules()
                  if isinstance(module, nn.Linear) and check_layer_condition(module, condition)]

    for name, module in candidates:
        initial_params = sum(p.numel() for p in module.parameters())
        _original_weight, U, S, V = decompose_and_replace_linear_layer(transformer_model, name, rank)
        # Count what is actually stored at `name` now (the diagonal S layer holds r*r
        # weights, plus any bias), so it is comparable with initial_params.
        final_params = sum(p.numel() for p in transformer_model.get_submodule(name).parameters())

        # Measure compression and the BLEU metric.
        compression = initial_params / final_params
        print('compression: ', compression)
        result = eval_bleu_epoch(args, eval_data, eval_examples, transformer_model, tokenizer, 'test', 'best-bleu')
        bleu_score = result['bleu']
        # bleu_score = measure_bleu_score(transformer_model, tokenizer, validation_data, device)

        if bleu_score >= initial_bleu * (1 - tolerance):
            print(f"Decomposed {name} with rank {rank}")
            save_decomposition_info(name, rank, initial_bleu - bleu_score, compression, initial_params, final_params)
            model_to_save = transformer_model.module if hasattr(transformer_model, 'module') else transformer_model
            os.makedirs(output_dir, exist_ok=True)
            output_model_file = os.path.join(output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)
            print('Amount of parameters: ', count_parameters(model_to_save))
        else:
            print(f"Restoring {name} due to unacceptable BLEU score drop")
            # Put the original layer object (weights and bias intact) back on its parent.
            parent_name, _, child_name = name.rpartition('.')
            setattr(transformer_model.get_submodule(parent_name), child_name, module)

    # Fine-tuning of the decomposed model (not implemented here):
    # finetune_model(transformer_model)

# Driver: reload the fine-tuned checkpoint, then greedily decompose its large Linear layers.
device = args.device  # keep in sync with where the model is placed below (was hard-coded cuda:0)
# The worker pool is handed to load_and_cache_gen_data; it is released at the end of the cell.
pool = multiprocessing.Pool(args.cpu_cont)
try:
    config, model, tokenizer = build_or_load_gen_model(args)
    model.to(args.device)
    file = 'sh/saved_models/translate/cs-java/codet5p-220m_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin' #os.path.join(args.output_dir, 'checkpoint-{}/pytorch_model.bin'.format(criteria))
    # map_location lets a checkpoint saved on another device load onto args.device.
    model.load_state_dict(torch.load(file, map_location=args.device))
    eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, pool, tokenizer, 'test',
                                                       only_src=True, is_sample=False)
    rank = 10  # Rank kept by each SVD factorization
    tolerance = 0.3  # Accept a layer if bleu >= initial_bleu * (1 - tolerance), i.e. at most a 30% relative BLEU drop
    condition = 1000  # check_layer_condition selects nn.Linear layers whose weight.size(0) (out_features) > 1000

    decompose_transformer(model, tokenizer, rank, tolerance, condition, eval_examples, eval_data, device)

    write_decomposition_info_to_file('decomposition_info.json')
finally:
    # Release the worker processes; without this every re-run of the cell leaks a pool.
    pool.close()
    pool.join()
