In [1]:
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer

# Load the pre-trained CodeT5 models and tokenizer
small_model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
large_model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-small')

def speculative_decoding(input_text, num_candidates=3, max_length=50):

    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    
    final_sequence = input_ids

    for _ in range(max_length):
        candidate_outputs = small_model.generate(final_sequence, num_return_sequences=num_candidates, max_length=final_sequence.size(-1) + 1, do_sample=True)
        candidate_tokens = [output[-1].unsqueeze(0) for output in candidate_outputs]
        
        decoder_input_ids = final_sequence

        with torch.no_grad():
            large_model_outputs = large_model(input_ids, decoder_input_ids=decoder_input_ids)
            large_model_logits = large_model_outputs.logits[:, -1, :]
        
        predicted_token = torch.argmax(large_model_logits, dim=-1).item()

        match_found = False
        for candidate_token in candidate_tokens:
            if candidate_token.item() == predicted_token:
                final_sequence = torch.cat((final_sequence, candidate_token.unsqueeze(0)), dim=-1)
                match_found = True
                break
        
        if not match_found:
            final_sequence = torch.cat((final_sequence, torch.tensor([[predicted_token]])), dim=-1)
        
        decoded_sequence = tokenizer.decode(final_sequence[0], skip_special_tokens=True)
        
        if decoded_sequence.endswith('</s>'):
            break
    
    return decoded_sequence

# Example usage
# input_code = "def add(a, b):"
# generated_code = speculative_decoding(input_code)
# print("Generated Code:", generated_code)


In [10]:
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer

file_model = "sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"
file_assis_model = "sh/saved_models/translate/cs-java/codet5_small_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"

assistant_model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
model.load_state_dict(torch.load(file_model))
assistant_model.load_state_dict(torch.load(file_assis_model))
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-small')
input_code = "public override void Serialize(ILittleEndianOutput out1){out1.WriteShort(field_1_vcenter);}"

inputs = tokenizer(input_code, return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.3)
print("Model output: ", tokenizer.batch_decode(outputs, skip_special_tokens=True))
print("Correct answer: public void serialize(LittleEndianOutput out) {out.writeShort(field_1_vcenter);}")

Model output:  ['public void serialize(LittleEndianOutput out) {out.writeShort(field_1']
Correct answer: public void serialize(LittleEndianOutput out) {out.writeShort(field_1_vcenter);}


In [1]:
from dotmap import DotMap
import torch

WORKDIR="/home/okozlova/diplom_oksana/CodeT5/CodeT5"
MODEL_TAG = 'codet5_large'
MODEL_DIR = 'sh/saved_models'
TASK = 'translate'
SUB_TASK = 'cs-java'
DATA_TAG = 'all'
LR = 5e-5
BS = 4
SRC_LEN = 320
TRG_LEN = 256
PATIENCE = 5
EPOCH = 100
FULL_MODEL_TAG=f'{MODEL_TAG}_{DATA_TAG}_lr{LR}_bs{BS}_src{SRC_LEN}_trg{TRG_LEN}_pat{PATIENCE}_e{EPOCH}'
OUTPUT_DIR=f'{MODEL_DIR}/{TASK}/{SUB_TASK}/{FULL_MODEL_TAG}'
CACHE_DIR=f'{OUTPUT_DIR}/cache_data'
RES_DIR=f'{OUTPUT_DIR}/prediction'
LOG=f'{OUTPUT_DIR}/train.log'
DATA_DIR = '/home/okozlova/diplom_oksana/CodeT5/CodeT5/data/translate/'
args_dict = {
    "task": 'translate',
    "sub_task": SUB_TASK,
    "lang": 'java', # 'c_sharp',
    "model_tag": MODEL_TAG,
    "res_dir": 'sh/results',
    "model_dir": 'sh/saved_models',
    "summary_dir": 'tensorboard',
    "data_num": -1,
    "gpu": 0,
    "do_train": True,
    "do_eval": True,
    "do_eval_bleu": True,
    "do_test": True,
    "model_type": 'codet5',
    "num_train_epochs": EPOCH, # 100
    "warmup_steps": 1000,
    "learning_rate": LR,
    "patience": PATIENCE, 
    "tokenizer_name": 'Salesforce/codet5-large',
    "model_name_or_path": 'Salesforce/codet5-large',
    'data_dir': WORKDIR + '/data',
    'cache_path': CACHE_DIR,
    'output_dir': OUTPUT_DIR,
    'summary_dir': 'tensorboard',
    'save_last_checkpoints': True,
    'always_save_model': True,
    'res_dir': RES_DIR,
    'train_batch_size': BS,
    'eval_batch_size': BS, 
    'max_source_length': SRC_LEN,
    'max_target_length': TRG_LEN,
    'seed': 1234,
    'local_rank': 0,
    'no_cuda': False,
    'n_gpu': 1,
    'device': torch.device("cuda", 1),
    'train_filename': f'{DATA_DIR}/train.java-cs.txt.cs,{DATA_DIR}/train.java-cs.txt.java',
    'dev_filename': f'{DATA_DIR}/valid.java-cs.txt.cs,{DATA_DIR}/valid.java-cs.txt.java',
    'test_filename': f'{DATA_DIR}/test.java-cs.txt.cs,{DATA_DIR}/test.java-cs.txt.java',
    'weight_decay': 0.0,
    'adam_epsilon': 1e-8,
    'gradient_accumulation_steps': 1, # Number of updates steps to accumulate before performing a backward/update pass
    'load_model_path': None,
    'cpu_cont': 16,
    'beam_size': 10
    
}

args = DotMap(args_dict)


In [2]:
import torch
import logging
import os
import random
import time
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import T5ForConditionalGeneration, RobertaTokenizer
from tqdm import tqdm
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from evaluator import smooth_bleu
from utils import read_examples, convert_examples_to_features, calc_stats

logger = logging.getLogger(__name__)

def load_and_cache_gen_data(args, filename, tokenizer, split_tag, only_src=False, is_sample=False):
    data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num
    cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + ('_src' if only_src else '') + data_tag)

    examples = read_examples(filename, args.data_num, args.task)

    if is_sample:
        examples = random.sample(examples, min(5000, len(examples)))
    if split_tag == 'train':
        calc_stats(examples, tokenizer, is_tokenize=True)
    else:
        calc_stats(examples)
    if os.path.exists(cache_fn) and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        data = torch.load(cache_fn)
    else:
        if is_sample:
            logger.info("Sample 5k data for computing BLEU from %s", filename)
        else:
            logger.info("Create cache data into %s", cache_fn)
        features = [convert_examples_to_features((example, idx, tokenizer, args, split_tag)) for idx, example in enumerate(tqdm(examples))]
        all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
        if split_tag == 'test' or only_src:
            data = TensorDataset(all_source_ids)
        else:
            all_target_ids = torch.tensor([f.target_ids for f in features], dtype=torch.long)
            data = TensorDataset(all_source_ids, all_target_ids)
        if args.local_rank in [-1, 0] and not is_sample:
            torch.save(data, cache_fn)
    return examples, data

def eval_bleu_epoch(args, eval_data, eval_examples, model, assistant_model, tokenizer, split_tag, criteria, speculative):
    logger.info("***** Running BLEU evaluation on {} data *****".format(split_tag))
    logger.info("Num examples = %d", len(eval_examples))
    logger.info("Batch size = %d", args.eval_batch_size)

    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1, num_workers=4, pin_memory=True)

    model.eval()
    assistant_model.eval()

    pred_ids = []
    start_time = time.time()

    for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval BLEU for {} set".format(split_tag)):
        source_ids = batch[0].to(args.device)
        source_mask = source_ids.ne(tokenizer.pad_token_id).to(args.device)
        with torch.no_grad():
            if speculative:
                preds = model.generate(
                    source_ids,
                    attention_mask=source_mask,
                    do_sample=True,
                    assistant_model=assistant_model,
                    # temperature=0.3,
                    num_beams=1,
                    max_length=args.max_target_length
                )
            else:
                preds = model.generate(
                    source_ids,
                    attention_mask=source_mask,
                    use_cache=True,
                    num_beams=1,
                    max_length=args.max_target_length,
                )
            top_preds = list(preds.cpu().numpy())
            pred_ids.extend(top_preds)

    end_time = time.time()
    inference_time = end_time - start_time

    pred_nls = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in pred_ids]

    output_fn = os.path.join(args.res_dir, "test_{}.output".format(criteria))
    gold_fn = os.path.join(args.res_dir, "test_{}.gold".format(criteria))

    with open(gold_fn, 'w') as f1:
        for gold in eval_examples:
            f1.write(gold.target.strip() + '\n')

    with open(output_fn, 'w') as f:
        for pred_nl in pred_nls:
            f.write(pred_nl.strip() + '\n')

    bleu = round(_bleu(gold_fn, output_fn), 2)
    codebleu = calc_code_bleu.get_codebleu(gold_fn, output_fn, args.lang)

    result = {
        'inference_time': inference_time,
        'bleu': bleu,
        'codebleu': codebleu * 100
    }

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(round(result[key], 4)))

    return result


# Пути к моделям
file_model = "sh/saved_models/translate/cs-java/codet5_base_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"
file_assis_model = "sh/saved_models/translate/cs-java/codet5_small_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"

# Загрузка моделей и токенизатора
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
assistant_model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
model.load_state_dict(torch.load(file_model, map_location=args.device))
assistant_model.load_state_dict(torch.load(file_assis_model, map_location=args.device))
model.to(args.device)
assistant_model.to(args.device)
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-small')

eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, tokenizer, "test")

result_with_speculative = eval_bleu_epoch(args, eval_data, eval_examples, model, assistant_model, tokenizer, "test", "speculative", True)

args.use_speculative_decoding = False
result_without_speculative = eval_bleu_epoch(args, eval_data, eval_examples, model, assistant_model, tokenizer, "test", "no_speculative", False)

# Вывод результатов
print("With Speculative Decoding:", result_with_speculative)
print("Without Speculative Decoding:", result_without_speculative)


Eval BLEU for test set: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:03<00:00,  3.29it/s]


ngram match: 0.7691698609766356, weighted ngram match: 0.7766263290183562, syntax_match: 0.8943953357156291, dataflow_match: 0.8863068465767117


Eval BLEU for test set: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:37<00:00,  2.96it/s]


ngram match: 0.7886146091898436, weighted ngram match: 0.7961580300663933, syntax_match: 0.8999905962008652, dataflow_match: 0.894552723638181
With Speculative Decoding: {'inference_time': 303.8070125579834, 'bleu': 76.92, 'codebleu': 83.16245930718333}
Without Speculative Decoding: {'inference_time': 337.8633964061737, 'bleu': 78.87, 'codebleu': 84.48289897738206}


In [3]:
# Пути к моделям
file_model = "sh/saved_models/translate/cs-java/codet5_large_all_lr5e-05_bs4_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"
file_assis_model = "sh/saved_models/translate/cs-java/codet5_small_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"

# Загрузка моделей и токенизатора
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-large')
assistant_model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
model.load_state_dict(torch.load(file_model, map_location=args.device))
assistant_model.load_state_dict(torch.load(file_assis_model, map_location=args.device))
model.to(args.device)
assistant_model.to(args.device)
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-small')

# Загрузка и кэширование данных
eval_examples, eval_data = load_and_cache_gen_data(args, args.test_filename, tokenizer, "test")

# Выполнение оценки BLEU и CodeBLEU
result_with_speculative = eval_bleu_epoch(args, eval_data, eval_examples, model, assistant_model, tokenizer, "test", "speculative", True)

# Отключаем speculative decoding для повторной оценки
args.use_speculative_decoding = False
result_without_speculative = eval_bleu_epoch(args, eval_data, eval_examples, model, assistant_model, tokenizer, "test", "no_speculative", False)

# Вывод результатов
print("With Speculative Decoding:", result_with_speculative)
print("Without Speculative Decoding:", result_without_speculative)


Eval BLEU for test set: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:36<00:00,  2.97it/s]


ngram match: 0.7507469060240761, weighted ngram match: 0.7570898394281561, syntax_match: 0.8884239232649991, dataflow_match: 0.8584457771114443


Eval BLEU for test set: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [10:30<00:00,  1.59it/s]


ngram match: 0.7743429767525913, weighted ngram match: 0.7851691343245184, syntax_match: 0.8981568553695694, dataflow_match: 0.8793103448275862
With Speculative Decoding: {'inference_time': 336.6193685531616, 'bleu': 75.08, 'codebleu': 81.36766114571688}
Without Speculative Decoding: {'inference_time': 630.3578195571899, 'bleu': 77.44, 'codebleu': 83.42448278185664}


In [13]:
import torch
import time
from transformers import T5ForConditionalGeneration, RobertaTokenizer
import sacrebleu

# Пути к моделям
file_model = "sh/saved_models/translate/cs-java/codet5_large_all_lr5e-05_bs4_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"
file_assis_model = "sh/saved_models/translate/cs-java/codet5_small_all_lr5e-05_bs16_src320_trg256_pat5_e100/checkpoint-best-bleu/pytorch_model.bin"
file_model = file_assis_model
# Загрузка моделей и токенизатора
assistant_model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
model.load_state_dict(torch.load(file_model))
assistant_model.load_state_dict(torch.load(file_assis_model))
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-small')

input_code = "public virtual void AddAll(NGit.Util.BlockList<T> src){if (src.size == 0){return;}int srcDirIdx = 0;for (; srcDirIdx < src.tailDirIdx; srcDirIdx++){AddAll(src.directory[srcDirIdx], 0, BLOCK_SIZE);}if (src.tailBlkIdx != 0){AddAll(src.tailBlock, 0, src.tailBlkIdx);}}"
inputs = tokenizer(input_code, return_tensors="pt")

# Correct output for BLEU calculation
correct_answer = "public void addAll(BlockList<T> src) {if (src.size == 0)return;int srcDirIdx = 0;for (; srcDirIdx < src.tailDirIdx; srcDirIdx++)addAll(src.directory[srcDirIdx], 0, BLOCK_SIZE);if (src.tailBlkIdx != 0)addAll(src.tailBlock, 0, src.tailBlkIdx);}"
references = [correct_answer]

# Инференс с speculative decoding
start_time = time.time()
outputs_with_speculative = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.3)
end_time = time.time()
time_with_speculative = end_time - start_time
output_with_speculative = tokenizer.batch_decode(outputs_with_speculative, skip_special_tokens=True)[0]

# Инференс без speculative decoding
start_time = time.time()
outputs_without_speculative = model.generate(**inputs, do_sample=True, temperature=0.3)
end_time = time.time()
time_without_speculative = end_time - start_time
output_without_speculative = tokenizer.batch_decode(outputs_without_speculative, skip_special_tokens=True)[0]

# Расчет BLEU-метрики
bleu_with_speculative = sacrebleu.corpus_bleu([output_with_speculative], [references])
bleu_without_speculative = sacrebleu.corpus_bleu([output_without_speculative], [references])

# Вывод результатов
print("Inference time with speculative decoding: {:.4f} seconds".format(time_with_speculative))
print("Inference time without speculative decoding: {:.4f} seconds".format(time_without_speculative))
print("Output with speculative decoding: ", output_with_speculative)
print("Output without speculative decoding: ", output_without_speculative)
print("Correct answer: ", correct_answer)
print("BLEU score with speculative decoding: ", bleu_with_speculative.score)
print("BLEU score without speculative decoding: ", bleu_without_speculative.score)

Inference time with speculative decoding: 0.2066 seconds
Inference time without speculative decoding: 0.1431 seconds
Output with speculative decoding:  public void addAll(BlockList<T> src) {if (src.size ==
Output without speculative decoding:  public void addAll(BlockList<T> src) {if (src.size ==
Correct answer:  public void addAll(BlockList<T> src) {if (src.size == 0)return;int srcDirIdx = 0;for (; srcDirIdx < src.tailDirIdx; srcDirIdx++)addAll(src.directory[srcDirIdx], 0, BLOCK_SIZE);if (src.tailBlkIdx != 0)addAll(src.tailBlock, 0, src.tailBlkIdx);}
BLEU score with speculative decoding:  3.3746151800550366
BLEU score without speculative decoding:  3.3746151800550366
