In [1]:
! pip install datasets transformers[torch] evaluate nltk



In [2]:
import nltk
from datasets import load_dataset
import evaluate
import numpy as np
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Prepare and tokenize dataset
docstrings_and_code = load_dataset(
    "code_search_net", name="python", split="train[:5000]"
)
docstrings_and_code = docstrings_and_code.train_test_split(test_size=0.2)

In [3]:
import re

t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
bart_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-3")


def remove_docstrings(text: list[str]) -> list[str]:
    return [re.sub(r"(['\"]{3}).*?\1", "", t, flags=re.DOTALL) for t in text]


def preprocess_function(batch, tokenizer):
    source_tokenized = tokenizer(
        remove_docstrings(batch["func_code_string"]),
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    target_tokenized = tokenizer(
        batch["func_documentation_string"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )
    batch = {k: v for k, v in source_tokenized.items()}
    # Ignore padding in the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in l]
        for l in target_tokenized["input_ids"]
    ]
    return batch


def preprocess_function_bart(examples):
    return preprocess_function(examples, bart_tokenizer)


def preprocess_function_t5(examples):
    return preprocess_function(examples, t5_tokenizer)


to_remove = [
    "repository_name",
    "func_path_in_repository",
    "func_name",
    "whole_func_string",
    "language",
    "func_code_tokens",
    "func_documentation_tokens",
    "split_name",
    "func_code_url",
]

tokenized_bart = docstrings_and_code.map(
    preprocess_function_bart, batched=True, remove_columns=to_remove
)
tokenized_t5 = docstrings_and_code.map(
    preprocess_function_t5, batched=True, remove_columns=to_remove
)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [4]:
# Setup evaluation
nltk.download("punkt", quiet=True)
metric = evaluate.load("google_bleu")


def compute_metrics(eval_preds, tokenizer):
    preds, labels = eval_preds

    # decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects newline after each sentence
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    return metric.compute(predictions=decoded_preds, references=decoded_labels)


def compute_metrics_bart(eval_preds):
    return compute_metrics(eval_preds, bart_tokenizer)


def compute_metrics_t5(eval_preds):
    return compute_metrics(eval_preds, t5_tokenizer)

In [5]:
# Load pretrained model and evaluate model after each epoch
model_bart = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-3")
data_collator_bart = DataCollatorForSeq2Seq(tokenizer=bart_tokenizer, model=model_bart)

training_args_bart = Seq2SeqTrainingArguments(
    output_dir="./results_bart",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
)

trainer_bart = Seq2SeqTrainer(
    model=model_bart,
    args=training_args_bart,
    train_dataset=tokenized_bart["train"],
    eval_dataset=tokenized_bart["test"],
    tokenizer=bart_tokenizer,
    data_collator=data_collator_bart,
    compute_metrics=compute_metrics_bart,
)

trainer_bart.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Google Bleu
1,No log,3.052546,0.083112
2,3.379600,2.861705,0.100391


TrainOutput(global_step=500, training_loss=3.379584228515625, metrics={'train_runtime': 774.9918, 'train_samples_per_second': 10.323, 'train_steps_per_second': 0.645, 'total_flos': 1238315827200000.0, 'train_loss': 3.379584228515625, 'epoch': 2.0})

In [6]:
# (Same thing for T5) Also train and evaluate the model
model_t5 = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
data_collator_t5 = DataCollatorForSeq2Seq(tokenizer=t5_tokenizer, model=model_t5)

training_args_t5 = Seq2SeqTrainingArguments(
    output_dir="./results_t5",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
)

trainer_t5 = Seq2SeqTrainer(
    model=model_t5,
    args=training_args_t5,
    train_dataset=tokenized_t5["train"],
    eval_dataset=tokenized_t5["test"],
    tokenizer=t5_tokenizer,
    data_collator=data_collator_t5,
    compute_metrics=compute_metrics_t5,
)

trainer_t5.train()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Google Bleu
1,No log,3.456028,0.038232
2,3.859500,3.371251,0.041452




TrainOutput(global_step=500, training_loss=3.85949755859375, metrics={'train_runtime': 276.9874, 'train_samples_per_second': 28.882, 'train_steps_per_second': 1.805, 'total_flos': 371781009408000.0, 'train_loss': 3.85949755859375, 'epoch': 2.0})

In [7]:
def compare(test_samples, model, tokenizer):
    inputs = tokenizer(
        test_samples["func_code_string"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


bart_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(
    "sshleifer/distilbart-xsum-12-3"
)

test_samples = docstrings_and_code["test"].select(range(16))

summaries_before_tuning = compare(test_samples, bart_before_tuning, bart_tokenizer)[1]
summaries_after_tuning = compare(test_samples, model_bart, bart_tokenizer)[1]

In [8]:
from tabulate import tabulate

print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            summaries_after_tuning,
            summaries_before_tuning,
        ),
        headers=["Id", "Docstring after", "Docstring before"],
    )
)
print("\nTarget docstrings:\n")
print(
    tabulate(
        list(enumerate(test_samples["func_documentation_string"])),
        headers=["Id", "Target docstring"],
    )
)
print("\nSource code:\n")
print(
    tabulate(list(enumerate(test_samples["func_code_string"])), headers=["Id", "Code"])
)

  Id  Docstring after                                                                    Docstring before
----  ---------------------------------------------------------------------------------  ---------------------------------------------------------------------------------------------------------
   0  Converts a multiPlanesRupture node into a single rupture.                          The multi-PlanesRupture system is based on a list of objects that can cause it to rupture.
   1  Create an origin mapping for a CDNCDN.                                             The SoftLayerCDNManager is a great place to start.

          :param account_id: ID of the content_id of the account.
   2  Write a csv file to a file.                                                        Wrapping a simple csv file can be exported to an export, according to the data centre.
   3  Build the nodal plane distribution as a Node instance.                             The Openquake.pmf.PMF is based on the results o