In [23]:
import os
import sys

DIR_PREFIX = "/home/user/commits/commit_messages_generation/"

sys.path.insert(0, DIR_PREFIX)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" #<- for the common server(SSH)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [24]:
import torch
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, pipeline
import pandas as pd
import numpy as np
import datasets
from transformers import AutoModelForSeq2SeqLM
from pprint import pprint
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import evaluate
from torch.utils.data import DataLoader

# import custom scripts
from CommitChronicle_preprocessing.DatasetParser import DatasetParser
from metrics.bnorm.bleu_norm import BLEUNorm

bnorm = BLEUNorm()

In [25]:
checkpoint = DIR_PREFIX + "model/t5p_CommitChron_v2/checkpoint-225000"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, local_files_only=True).to(
    device
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [26]:
def model_inference(model, tokenizer, text, return_dict=False, seqs=5):
    prompt = text
    input = tokenizer(
        prompt, return_tensors="pt", truncation=True, padding="max_length"
    ).to(device)
    model.eval()
    with torch.no_grad():
        sample_outputs = model.generate(
            **input,
            max_new_tokens=25,
            top_k=50,
            num_return_sequences=seqs,
            num_beams=5,
            no_repeat_ngram_size=2,
            do_sample=True,
            early_stopping=True,
            top_p=0.95,
        )
    if not return_dict:
        for i, sample_output in enumerate(sample_outputs):
            print(
                "{}: {}".format(
                    i, tokenizer.decode(sample_output, skip_special_tokens=True)
                )
            )
            print("-" * 80)
    else:
        res = []
        for i, sample_output in enumerate(sample_outputs):
            res.append(tokenizer.decode(sample_output, skip_special_tokens=True))
        return res

In [27]:
dataset = datasets.load_dataset(DIR_PREFIX + "/CommitChronicle/")
train_data = dataset["train"]
val_data = dataset["validation"]
langs = np.unique(val_data["language"])
parser = DatasetParser(tokenizer)

Resolving data files:   0%|          | 0/63 [00:00<?, ?it/s]

In [28]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # take only first 10 tokens to compute metric
    preds = preds[:, :10]
    labels = labels[:, :10]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = {}
    result = bnorm.compute(predictions=decoded_preds, references=decoded_labels)[
        "b_norm"
    ]
    result = {"BLEU_norm": result}

    prediction_lens = [
        torch.count_nonzero(pred != tokenizer.pad_token_id).item() for pred in preds
    ]
    result["prediction_len"] = np.array(prediction_lens).mean()
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [29]:
generation_conf = transformers.GenerationConfig(
    max_new_tokens=25,
    top_k=50,
    num_beams=5,
    no_repeat_ngram_size=2,
    do_sample=True,
    early_stopping=True,
    top_p=0.95,
)

In [30]:
parser = DatasetParser(tokenizer)
val_data.set_format("torch")
val_data = val_data.map(parser.parse_input, num_proc=8)
val_data = parser.remove_useless_columns(val_data)
val_data = val_data.map(parser.tokenize_data, num_proc=10)
val_data = val_data.map(parser.squeeze_dataset, num_proc=10)
val_data

Dataset({
    features: ['message', 'language', 'model_input', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1554042
})

In [None]:
SAMPLING_ITERATIONS = 10

for sampling_iteration in range(SAMPLING_ITERATIONS):
    print(f"ITERATION NUMBER {sampling_iteration}")
    # ======================================================= Sampling from val Set
    samples_to_eval = int(1e5)
    random_indeces = np.random.choice(
        len(val_data), size=samples_to_eval, replace=False
    )
    val_data = val_data.select(random_indeces)

    languages_used = np.array(val_data["language"])
    unique_langs = np.unique(languages_used)
    lang_datasets = {}
    for lang in unique_langs:
        lang_datasets[lang] = []

    def distribute_samples(example):
        global lang_datasets
        lang = example["language"]
        lang_datasets[lang].append(example)
        return example

    val_data = val_data.map(distribute_samples, num_proc=1, load_from_cache_file=False)
    for lang in lang_datasets.keys():
        lang_datasets[lang] = datasets.Dataset.from_list(lang_datasets[lang])
        lang_datasets[lang].set_format("torch")

    lang_datasets = datasets.DatasetDict(lang_datasets)
    # =======================================================
    
    by_lang_results = {}
    total_BLEU = 0
    total_MSG_LEN = 0
    total_batches = 0
    for lang in lang_datasets.keys():
        print(f"{lang:=^75}")
        eval_dataloader = DataLoader(lang_datasets[lang], batch_size=32)
        BLEU = 0
        MSG_LEN = 0
        progress = tqdm(enumerate(eval_dataloader), total=len(eval_dataloader))
        for i, batch in progress:
            batch_messages = batch["message"]
            batch_langs = batch["language"]
            batch_input_ids = batch["input_ids"].to(device)
            batch_attn = batch["attention_mask"].to(device)
            batch_labels = batch["labels"]

            batch_encodings = {
                "input_ids": batch_input_ids,
                "attention_mask": batch_attn,
                # "decoder_input_ids": batch_input_ids.clone(),
            }
            batch_preds = model.generate(
                **batch_encodings, generation_config=generation_conf
            )
            eval_preds = (batch_preds, batch_labels)
            batch_result = compute_metrics(eval_preds)
            BLEU += batch_result["BLEU_norm"]
            total_BLEU += batch_result["BLEU_norm"]
            MSG_LEN += batch_result["prediction_len"]
            total_MSG_LEN += batch_result["prediction_len"]
            progress.set_description(
                f"BLEU: {BLEU / (i+1):.3f} LEN : {MSG_LEN / (i+1):.3f}"
            )
        n = len(eval_dataloader)
        total_batches += n
        by_lang_results[lang] = {
            "BLEU_norm": BLEU / len(eval_dataloader),
            "MSG_LEN": MSG_LEN / len(eval_dataloader),
        }

    by_lang_results["Total"] = {
        "BLEU_norm": total_BLEU / total_batches,
        "MSG_LEN": total_MSG_LEN / total_batches,
    }

    metrics_df = pd.DataFrame(by_lang_results).T
    metrics_df.to_csv(f"metric_distribution/BLEU_norm_results_{sampling_iteration}.csv")

ITERATION NUMBER 0


Map:   0%|          | 0/100000 [00:00<?, ? examples/s]



  0%|          | 0/119 [00:00<?, ?it/s]



  0%|          | 0/175 [00:00<?, ?it/s]



  0%|          | 0/409 [00:00<?, ?it/s]



  0%|          | 0/35 [00:00<?, ?it/s]



  0%|          | 0/8 [00:00<?, ?it/s]



  0%|          | 0/269 [00:00<?, ?it/s]



  0%|          | 0/4 [00:00<?, ?it/s]



  0%|          | 0/399 [00:00<?, ?it/s]



  0%|          | 0/340 [00:00<?, ?it/s]



  0%|          | 0/60 [00:00<?, ?it/s]



  0%|          | 0/173 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/65 [00:00<?, ?it/s]



  0%|          | 0/427 [00:00<?, ?it/s]



  0%|          | 0/82 [00:00<?, ?it/s]



  0%|          | 0/123 [00:00<?, ?it/s]



  0%|          | 0/27 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/58 [00:00<?, ?it/s]



  0%|          | 0/357 [00:00<?, ?it/s]

ITERATION NUMBER 1


Map:   0%|          | 0/100000 [00:00<?, ? examples/s]



  0%|          | 0/119 [00:00<?, ?it/s]



  0%|          | 0/175 [00:00<?, ?it/s]



  0%|          | 0/409 [00:00<?, ?it/s]



  0%|          | 0/35 [00:00<?, ?it/s]



  0%|          | 0/8 [00:00<?, ?it/s]



  0%|          | 0/269 [00:00<?, ?it/s]



  0%|          | 0/4 [00:00<?, ?it/s]



  0%|          | 0/399 [00:00<?, ?it/s]



  0%|          | 0/340 [00:00<?, ?it/s]



  0%|          | 0/60 [00:00<?, ?it/s]



  0%|          | 0/173 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/65 [00:00<?, ?it/s]



  0%|          | 0/427 [00:00<?, ?it/s]



  0%|          | 0/82 [00:00<?, ?it/s]



  0%|          | 0/123 [00:00<?, ?it/s]



  0%|          | 0/27 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/58 [00:00<?, ?it/s]



  0%|          | 0/357 [00:00<?, ?it/s]

ITERATION NUMBER 2


Map:   0%|          | 0/100000 [00:00<?, ? examples/s]



  0%|          | 0/119 [00:00<?, ?it/s]



  0%|          | 0/175 [00:00<?, ?it/s]



  0%|          | 0/409 [00:00<?, ?it/s]



  0%|          | 0/35 [00:00<?, ?it/s]



  0%|          | 0/8 [00:00<?, ?it/s]



  0%|          | 0/269 [00:00<?, ?it/s]



  0%|          | 0/4 [00:00<?, ?it/s]



  0%|          | 0/399 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)





  0%|          | 0/60 [00:00<?, ?it/s]

In [None]:
123