In [1]:
import os
import sys
# Project root. Overridable through the COMMIT_MSG_GEN_DIR environment variable so
# the notebook is not tied to a single machine's home directory.
DIR_PREFIX = os.environ.get(
    "COMMIT_MSG_GEN_DIR", "/home/user/commits/commit_messages_generation"
)

# Put the project root first on sys.path so project packages such as
# CommitChronicle_preprocessing and metrics can be imported.
sys.path.insert(0, DIR_PREFIX)
# GPU selection has to be configured before torch initialises CUDA (torch is
# imported later in the notebook). Enumerate devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" #<- for the common server(SSH)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import re
import torch
import torch.nn as nn
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, pipeline
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
import datasets
from peft import PeftModel, PeftConfig
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from transformers import TrainingArguments, Trainer
import wandb
from transformers import TrainerCallback
from accelerate import Accelerator
from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig
from datasets import load_dataset
from pprint import pprint
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import evaluate

#import custom scripts
from CommitChronicle_preprocessing.DatasetParser import DatasetParser

In [10]:
# CommitChronicle is read from a local copy under the project root.
dataset = load_dataset(f"{DIR_PREFIX}/CommitChronicle")
train_data, val_data = dataset["train"], dataset["validation"]

Resolving data files:   0%|          | 0/63 [00:00<?, ?it/s]

In [6]:
checkpoint = "JetBrains-Research/cmg-codet5-without-history"
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [23]:
# Load the tokenizer and seq2seq model for the checkpoint chosen above.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# device_map="auto" delegates weight placement to accelerate.
model_load_kwargs = {"device_map": "auto"}
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, **model_load_kwargs)

In [15]:
def parse_input(example):
    """
    Build the flat text model input for one CommitChronicle example.

    For every file modification in ``example["mods"]`` this emits the old path,
    the new path and the raw diff, each followed by a newline. Missing (falsy)
    paths or diffs are rendered as empty strings. The per-file chunks are joined
    with an extra newline and stored in ``example["model_input"]``. No special
    tokens are inserted.

    Example of usage with the whole dataset:
    >>> train_data = train_data.map(parse_input, num_proc=8)

    Args:
        example: dict with a ``"mods"`` list; each mod is a dict with
            ``"old_path"``, ``"new_path"`` and ``"diff"`` keys.

    Returns:
        The same ``example`` dict, mutated to contain ``"model_input"``.
    """
    diffs = []
    for mod in example["mods"]:
        # Either path (and, defensively, the diff) may be None/empty; use "".
        old_path = mod["old_path"] or ""
        new_path = mod["new_path"] or ""
        code_diff = mod["diff"] or ""
        diffs.append(old_path + "\n" + new_path + "\n" + code_diff + "\n")
    example["model_input"] = "\n".join(diffs)
    return example

In [28]:
# Prepare the validation split with the project's DatasetParser: drop unused
# columns, then apply its tokenize_data and squeeze_dataset steps in order.
parser = DatasetParser(tokenizer)
val_data.set_format('torch')
val_data = parser.remove_useless_columns(val_data)
for preprocessing_step in (parser.tokenize_data, parser.squeeze_dataset):
    val_data = val_data.map(preprocessing_step, num_proc=10)
val_data

Map (num_proc=10):   0%|          | 0/1554042 [00:00<?, ? examples/s]

Map (num_proc=10):   0%|          | 0/1554042 [00:00<?, ? examples/s]

Dataset({
    features: ['message', 'language', 'model_input', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1554042
})

In [31]:
def model_inference(model, tokenizer, text, return_dict=False, seqs=5):
    """
    Generate candidate commit messages for a single model input.

    Args:
        model: seq2seq model exposing ``generate``, ``eval`` and ``device``.
        tokenizer: tokenizer used to encode ``text`` and decode the outputs.
        text: the model input string (e.g. a ``model_input`` field).
        return_dict: despite its name, when True the decoded messages are
            returned as a ``list[str]``; when False they are printed and the
            function returns None.
        seqs: number of sequences to generate. The beam width is raised to
            ``seqs`` when needed, because ``generate`` rejects
            ``num_return_sequences`` larger than ``num_beams``.
    """
    # Send inputs to the model's own device rather than relying on a notebook
    # global; this also follows wherever device_map placed the model.
    encodings = tokenizer(
        text, return_tensors="pt", truncation=True, padding='max_length'
    ).to(model.device)
    model.eval()
    with torch.no_grad():
        sample_outputs = model.generate(
            **encodings,
            max_new_tokens=25,
            top_k=50,
            num_return_sequences=seqs,
            num_beams=max(5, seqs),
            no_repeat_ngram_size=2,
            do_sample=True,
            early_stopping=True,
            top_p=0.95,
        )
    decoded = [
        tokenizer.decode(sample_output, skip_special_tokens=True)
        for sample_output in sample_outputs
    ]
    if return_dict:
        return decoded
    for i, message in enumerate(decoded):
        print("{}: {}".format(i, message))
        print("-" * 80)

In [32]:
def random_sample_inference():
    """
    Print one random validation example next to the model's generated messages.

    Picks a random row of the global ``val_data``, prints its ``model_input``
    (the code diffs) and reference ``message``, then prints every candidate
    returned by ``model_inference`` for that input.
    """
    n = np.random.randint(0, len(val_data))
    sample = val_data[n]
    print(f"Index - {n}")
    print(f"{'CODE_DIFFS':=^75}")
    print(sample["model_input"])
    print(f"{'MESSAGE':=^75}")
    # Reuse the already-fetched row instead of indexing the dataset again.
    print(sample["message"], '\n')
    print(f"{'GENERATED_MESSAGE':=^75}")
    generated = model_inference(model, tokenizer, sample["model_input"], return_dict=True)
    for elem in generated:
        print(elem)
        print("=" * 75)

In [35]:
# Evaluate on a random subset of the validation split. A seeded generator makes
# the chosen subset reproducible across kernel restarts.
# NOTE: this rebinds val_data, so re-running the cell resamples the subset
# (min() below keeps that from raising when fewer rows remain).
SAMPLE_SEED = 42
samples_to_eval = int(1e5)
sampling_rng = np.random.default_rng(SAMPLE_SEED)
random_indeces = sampling_rng.choice(
    len(val_data), size=min(samples_to_eval, len(val_data)), replace=False
)
val_data = val_data.select(random_indeces)

In [36]:
# Split the evaluation sample into one torch-formatted dataset per programming
# language. Row indices are selected directly instead of running a
# side-effecting Dataset.map that appends rows to a global dict (which only
# works with num_proc=1, since worker processes would fill their own copies).
languages_used = np.array(val_data['language'])
unique_langs = np.unique(languages_used)
lang_datasets = {}
for lang in unique_langs:
    # Positions of this language's rows, in their original order.
    lang_indices = np.flatnonzero(languages_used == lang)
    lang_subset = val_data.select(lang_indices)
    lang_subset.set_format('torch')
    lang_datasets[str(lang)] = lang_subset

lang_datasets = datasets.DatasetDict(lang_datasets)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [37]:
from metrics.bnorm.bleu_norm import BLEUNorm

# Normalized-BLEU scorer; compute_metrics calls
# bnorm.compute(predictions=..., references=...)["b_norm"] on decoded batches.
bnorm = BLEUNorm()

  bnorm = BLEUNorm()


In [39]:
# Decoding settings for the batched evaluation: 5 beams combined with
# top-k / top-p sampling, at most 25 new tokens per message.
# do_sample=True makes outputs stochastic unless the torch RNG is seeded.
generation_kwargs = dict(
    max_new_tokens=25,
    num_beams=5,
    early_stopping=True,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)
generation_conf = transformers.GenerationConfig(**generation_kwargs)

In [41]:
def compute_metrics(eval_preds, max_tokens=10):
    """
    Score a batch of generated messages against their references.

    Args:
        eval_preds: ``(preds, labels)`` pair of 2-D token-id matrices, given as
            torch tensors (CPU or GPU) or numpy arrays. ``preds`` come from
            ``model.generate``; label positions equal to -100 are treated as
            padding.
        max_tokens: only the first ``max_tokens`` tokens of every prediction
            and reference are scored (previously hard-coded to 10).

    Returns:
        ``{"BLEU_norm": ..., "prediction_len": ...}`` rounded to 4 decimals.
        ``BLEU_norm`` is ``bnorm.compute(...)["b_norm"]`` on the decoded
        strings; ``prediction_len`` is the mean number of non-pad tokens among
        the kept prediction tokens.
    """
    preds, labels = eval_preds
    # Move both to host numpy so GPU tensors from generate() and numpy arrays
    # (e.g. from a HF Trainer) are handled the same way.
    if torch.is_tensor(preds):
        preds = preds.detach().cpu().numpy()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    pad_id = tokenizer.pad_token_id

    # Encoder-decoder generate() output starts with the decoder start token,
    # which for T5/CodeT5 is the pad token. Drop that column so the first
    # `max_tokens` generated tokens are compared with the first `max_tokens`
    # label tokens (otherwise predictions get one token fewer than references).
    if preds.ndim == 2 and preds.shape[1] > 0 and np.all(preds[:, 0] == pad_id):
        preds = preds[:, 1:]

    preds = preds[:, :max_tokens]
    labels = labels[:, :max_tokens]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = {
        "BLEU_norm": bnorm.compute(
            predictions=decoded_preds, references=decoded_labels
        )["b_norm"],
        "prediction_len": np.count_nonzero(preds != pad_id, axis=1).mean(),
    }
    return {k: round(float(v), 4) for k, v in result.items()}

In [42]:
# Evaluate every language subset. Per-batch scores are weighted by batch size,
# so the short final batch of each language (and small languages in the
# 'Total' row) does not count as much as a full batch.
EVAL_BATCH_SIZE = 32
# generation_conf samples (do_sample=True); seed torch so reruns reproduce scores.
torch.manual_seed(0)

by_lang_results = {}
total_BLEU = 0.0
total_MSG_LEN = 0.0
total_batches = 0
total_samples = 0
for lang in lang_datasets.keys():
    print(f"{lang:=^75}")
    eval_dataloader = DataLoader(lang_datasets[lang], batch_size=EVAL_BATCH_SIZE)
    BLEU = 0.0
    MSG_LEN = 0.0
    lang_samples = 0
    progress = tqdm(eval_dataloader, total=len(eval_dataloader))
    for batch in progress:
        batch_input_ids = batch["input_ids"].to(device)
        batch_size = batch_input_ids.shape[0]
        batch_encodings = {
            "input_ids": batch_input_ids,
            "attention_mask": batch["attention_mask"].to(device),
        }
        batch_preds = model.generate(
            **batch_encodings, generation_config=generation_conf
        )
        batch_result = compute_metrics((batch_preds, batch["labels"]))
        # Treat the batch scores as per-sample means and weight by batch size.
        BLEU += batch_result["BLEU_norm"] * batch_size
        MSG_LEN += batch_result["prediction_len"] * batch_size
        lang_samples += batch_size
        progress.set_description(
            f"BLEU: {BLEU / lang_samples:.3f} LEN : {MSG_LEN / lang_samples:.3f}"
        )
    total_BLEU += BLEU
    total_MSG_LEN += MSG_LEN
    total_batches += len(eval_dataloader)
    total_samples += lang_samples
    lang_denom = max(lang_samples, 1)  # guard against an empty language subset
    by_lang_results[lang] = {
        "BLEU_norm": BLEU / lang_denom,
        "MSG_LEN": MSG_LEN / lang_denom,
    }

total_denom = max(total_samples, 1)
by_lang_results['Total'] = {
        "BLEU_norm": total_BLEU / total_denom,
        "MSG_LEN": total_MSG_LEN / total_denom,
    }



  0%|          | 0/118 [00:00<?, ?it/s]



  0%|          | 0/172 [00:00<?, ?it/s]



  0%|          | 0/411 [00:00<?, ?it/s]



  0%|          | 0/35 [00:00<?, ?it/s]



  0%|          | 0/8 [00:00<?, ?it/s]



  0%|          | 0/268 [00:00<?, ?it/s]



  0%|          | 0/4 [00:00<?, ?it/s]



  0%|          | 0/400 [00:00<?, ?it/s]



  0%|          | 0/344 [00:00<?, ?it/s]



  0%|          | 0/61 [00:00<?, ?it/s]



  0%|          | 0/174 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/66 [00:00<?, ?it/s]



  0%|          | 0/431 [00:00<?, ?it/s]



  0%|          | 0/77 [00:00<?, ?it/s]



  0%|          | 0/122 [00:00<?, ?it/s]



  0%|          | 0/28 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/54 [00:00<?, ?it/s]



  0%|          | 0/356 [00:00<?, ?it/s]

In [43]:
# One row per language plus the 'Total' row; columns BLEU_norm and MSG_LEN.
metrics_df = pd.DataFrame.from_dict(by_lang_results, orient="index")
metrics_df.to_csv('JB_BLEU_norm_results.csv')