In [None]:
! pip install datasets transformers[torch] evaluate nltk

Collecting datasets
  Using cached datasets-2.15.0-py3-none-any.whl (521 kB)
Collecting evaluate
  Using cached evaluate-0.4.1-py3-none-any.whl (84 kB)
Collecting pyarrow-hotfix (from datasets)
  Using cached pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Using cached dill-0.3.7-py3-none-any.whl (115 kB)
Collecting multiprocess (from datasets)
  Using cached multiprocess-0.70.15-py310-none-any.whl (134 kB)
Collecting accelerate>=0.20.3 (from transformers[torch])
  Using cached accelerate-0.25.0-py3-none-any.whl (265 kB)
Collecting responses<0.19 (from evaluate)
  Using cached responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: pyarrow-hotfix, dill, responses, multiprocess, accelerate, datasets, evaluate
Successfully installed accelerate-0.25.0 datasets-2.15.0 dill-0.3.7 evaluate-0.4.1 multiprocess-0.70.15 pyarrow-hotfix-0.6 responses-0.18.0


In [None]:
import nltk
from datasets import load_dataset
import evaluate
import numpy as np
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Prepare the dataset: the first 10,000 Python examples of the
# "code_search_net" train split, divided 80/20 into train and test.
# The fixed seed keeps the partition identical across kernel restarts, so the
# scores and sample tables further down stay comparable between runs.
docstrings_and_code = load_dataset(
    "code_search_net", name="python", split="train[:10000]"
).train_test_split(test_size=0.2, seed=42)

Downloading builder script:   0%|          | 0.00/8.44k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/18.5k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/12.9k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/941M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/412178 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/22176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23107 [00:00<?, ? examples/s]

In [None]:
import re

# One tokenizer per checkpoint; both are loaded by hub name (T5 first, then BART).
t5_tokenizer, bart_tokenizer = map(
    AutoTokenizer.from_pretrained,
    ("google/flan-t5-base", "sshleifer/distilbart-xsum-12-3"),
)


def remove_docstrings(text):
    """Strip triple-quoted string literals from every source string in ``text``.

    Args:
        text: Iterable of source-code strings.

    Returns:
        A list with one cleaned string per input, in the same order.
    """
    # Three quote characters, then the shortest run (newlines included) up to
    # a repeat of exactly those three characters.
    triple_quoted = re.compile(r"(['\"]{3}).*?\1", flags=re.DOTALL)
    stripped = []
    for source in text:
        stripped.append(triple_quoted.sub("", source))
    return stripped


def preprocess_function(batch, tokenizer):
    """Tokenize a batch of code/docstring pairs into model inputs and labels.

    Args:
        batch: Mapping with ``"func_code_string"`` and
            ``"func_documentation_string"`` entries.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.

    Returns:
        dict holding every field of the tokenized source plus ``"labels"``.
    """
    # Source side: code with triple-quoted strings stripped, fixed at 128 tokens.
    code_without_docs = remove_docstrings(batch["func_code_string"])
    encoded_source = tokenizer(
        code_without_docs, padding="max_length", truncation=True, max_length=128
    )
    # Target side: the documentation string, fixed at 64 tokens.
    encoded_target = tokenizer(
        batch["func_documentation_string"],
        padding="max_length",
        truncation=True,
        max_length=64,
    )

    pad_id = tokenizer.pad_token_id
    features = dict(encoded_source.items())
    # Padding positions become -100 so they are ignored in the loss.
    features["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in row]
        for row in encoded_target["input_ids"]
    ]
    return features


def preprocess_function_bart(examples):
    """Run ``preprocess_function`` on ``examples`` with ``bart_tokenizer``."""
    return preprocess_function(batch=examples, tokenizer=bart_tokenizer)


def preprocess_function_t5(examples):
    """Run ``preprocess_function`` on ``examples`` with ``t5_tokenizer``."""
    return preprocess_function(batch=examples, tokenizer=t5_tokenizer)


# Raw columns dropped from the dataset while mapping.
to_remove = [
    "repository_name", "func_path_in_repository", "func_name",
    "whole_func_string", "language",
    "func_code_tokens", "func_documentation_tokens",
    "split_name", "func_code_url",
]

# Tokenize once per model family, in batches: BART first, then T5.
tokenized_bart, tokenized_t5 = (
    docstrings_and_code.map(preprocess, batched=True, remove_columns=to_remove)
    for preprocess in (preprocess_function_bart, preprocess_function_t5)
)

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.51k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [None]:
# Setup evaluation
# nltk.sent_tokenize (used in compute_metrics) needs the Punkt sentence data.
# Newer NLTK releases look up the "punkt_tab" resource instead of "punkt",
# so both are fetched to keep the notebook working across NLTK versions.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("google_bleu")


def compute_metrics(eval_preds, tokenizer):
    """Decode predicted and reference token ids and score them with ``metric``.

    Args:
        eval_preds: ``(predictions, labels)`` pair. ``predictions`` may itself
            be a tuple, in which case its first element is used.
        tokenizer: Tokenizer used for decoding; supplies ``pad_token_id``.

    Returns:
        Whatever ``metric.compute`` returns for the decoded texts.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is a masking value, not a vocabulary id, so it must be swapped for
    # the pad id before decoding. Labels carry it for ignored padding.
    # NOTE(review): predictions can apparently contain -100 as well when the
    # trainer pads generated batches of different lengths before concatenating
    # them, so they get the same treatment.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Put each sentence on its own line. This is a convention taken from
    # ROUGE-Lsum scoring. NOTE(review): probably has no effect on google_bleu;
    # confirm before removing.
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    return metric.compute(predictions=decoded_preds, references=decoded_labels)


def compute_metrics_bart(eval_preds):
    """Score ``eval_preds`` via ``compute_metrics`` using ``bart_tokenizer``."""
    return compute_metrics(eval_preds=eval_preds, tokenizer=bart_tokenizer)


def compute_metrics_t5(eval_preds):
    """Score ``eval_preds`` via ``compute_metrics`` using ``t5_tokenizer``."""
    return compute_metrics(eval_preds=eval_preds, tokenizer=t5_tokenizer)

Downloading builder script:   0%|          | 0.00/8.64k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

In [None]:
# Load the pretrained BART checkpoint, fine-tune it for 5 epochs and evaluate
# after every epoch.
model_bart = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-3")
data_collator_bart = DataCollatorForSeq2Seq(tokenizer=bart_tokenizer, model=model_bart)

training_args_bart = Seq2SeqTrainingArguments(
    output_dir="./results_bart",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    # Score generated text rather than teacher-forced logits.
    predict_with_generate=True,
    # Reference docstrings are tokenized to at most 64 tokens during
    # preprocessing; give generation the same budget explicitly instead of
    # relying on whatever length the checkpoint's own config happens to set.
    generation_max_length=64,
)

trainer_bart = Seq2SeqTrainer(
    model=model_bart,
    args=training_args_bart,
    train_dataset=tokenized_bart["train"],
    eval_dataset=tokenized_bart["test"],
    tokenizer=bart_tokenizer,
    data_collator=data_collator_bart,
    compute_metrics=compute_metrics_bart,
)

trainer_bart.train()

Epoch,Training Loss,Validation Loss,Google Bleu
1,3.5169,2.937556,0.08391
2,2.8233,2.719209,0.092533
3,2.5668,2.616951,0.094124
4,2.4231,2.563851,0.09923
5,2.341,2.544365,0.099601


TrainOutput(global_step=2500, training_loss=2.734224658203125, metrics={'train_runtime': 4137.6885, 'train_samples_per_second': 9.667, 'train_steps_per_second': 0.604, 'total_flos': 6191579136000000.0, 'train_loss': 2.734224658203125, 'epoch': 5.0})

In [None]:
# (Same thing for T5) Also train and evaluate the model
model_t5 = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
data_collator_t5 = DataCollatorForSeq2Seq(tokenizer=t5_tokenizer, model=model_t5)

training_args_t5 = Seq2SeqTrainingArguments(
    output_dir="./results_t5",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    # Score generated text rather than teacher-forced logits.
    predict_with_generate=True,
    # Reference docstrings are tokenized to at most 64 tokens during
    # preprocessing. Without an explicit limit, evaluation-time generation
    # falls back to the model's generation defaults.
    # NOTE(review): the library default is believed to be 20 tokens when the
    # checkpoint sets no max_length, which would cut predictions far short of
    # the references and depress BLEU; confirm for flan-t5-base.
    generation_max_length=64,
)

trainer_t5 = Seq2SeqTrainer(
    model=model_t5,
    args=training_args_t5,
    train_dataset=tokenized_t5["train"],
    eval_dataset=tokenized_t5["test"],
    tokenizer=t5_tokenizer,
    data_collator=data_collator_t5,
    compute_metrics=compute_metrics_t5,
)

trainer_t5.train()

Epoch,Training Loss,Validation Loss,Google Bleu
1,3.363,2.896448,0.053291
2,3.0299,2.763633,0.058023
3,2.9116,2.701982,0.059132
4,2.8525,2.672171,0.059052
5,2.8186,2.663088,0.059125




TrainOutput(global_step=2500, training_loss=2.99513330078125, metrics={'train_runtime': 3296.7028, 'train_samples_per_second': 12.133, 'train_steps_per_second': 0.758, 'total_flos': 6847573524480000.0, 'train_loss': 2.99513330078125, 'epoch': 5.0})

In [None]:
def compare(test_samples, model, tokenizer, **generate_kwargs):
    """Generate docstrings for ``test_samples`` with ``model``.

    The code is preprocessed the same way as the training inputs: triple-quoted
    strings are removed with ``remove_docstrings`` and the result is
    padded/truncated to 128 tokens. Feeding the raw function (docstring
    included) would let the model copy the answer from its input.

    Args:
        test_samples: Mapping or dataset slice with a ``"func_code_string"``
            column.
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the code and decode the output.
        **generate_kwargs: Optional extra keyword arguments forwarded to
            ``model.generate`` (for example ``max_new_tokens``).

    Returns:
        Tuple ``(outputs, output_str)``: the generated ids and their decoded
        text with special tokens skipped.
    """
    inputs = tokenizer(
        remove_docstrings(test_samples["func_code_string"]),
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the model weights.
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(
        input_ids, attention_mask=attention_mask, **generate_kwargs
    )
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


# Fresh copy of the pretrained checkpoint to serve as the "before" baseline.
bart_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(
    "sshleifer/distilbart-xsum-12-3"
)

# First 16 examples of the test split.
test_samples = docstrings_and_code["test"].select(range(16))

# Decoded summaries only (second element of compare's result): baseline model
# first, then the fine-tuned one.
summaries_before_tuning, summaries_after_tuning = (
    compare(test_samples, candidate, bart_tokenizer)[1]
    for candidate in (bart_before_tuning, model_bart)
)

In [None]:
from tabulate import tabulate

# Generated docstrings side by side: fine-tuned model vs. pretrained baseline.
print(
    tabulate(
        [
            (idx, after, before)
            for idx, (after, before) in enumerate(
                zip(summaries_after_tuning, summaries_before_tuning)
            )
        ],
        headers=["Id", "Docstring after", "Docstring before"],
    )
)

# Reference docstrings for the same samples.
print(
    "\nTarget docstrings:\n",
    tabulate(
        list(enumerate(test_samples["func_documentation_string"])),
        headers=["Id", "Target docstring"],
    ),
    sep="\n",
)

# The source code of each sample.
print(
    "\nSource code:\n",
    tabulate(
        list(enumerate(test_samples["func_code_string"])),
        headers=["Id", "Code"],
    ),
    sep="\n",
)

  Id  Docstring after                                                                                                                         Docstring before
----  --------------------------------------------------------------------------------------------------------------------------------------  -----------------------------------------------------------------------------------------------------------------------------------------
   0  Retrieve the authentication token from Blink.                                                                                           A new version of the Blink authentication system has been launched in the US.

          :param is_retry: True if is retry: if True, then raise an exception.
   1  Generate a class based view based view for the update.                                                                                  A new type of view for UpdateView has been generated by the BBC.
   2  Check if param_list is a list with the size of t

In [None]:
# Unfortunately, we got rate limited by Google's GPU quota, so we couldn't run this part
from tabulate import tabulate

# Fresh copy of the pretrained checkpoint to serve as the "before" baseline.
t5_before_tuning = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

t5_summaries_before_tuning = compare(test_samples, t5_before_tuning, t5_tokenizer)[1]
t5_summaries_after_tuning = compare(test_samples, model_t5, t5_tokenizer)[1]

# Same layout as the BART table: fine-tuned T5 output next to the pretrained
# T5 baseline. The columns must follow the header order ("after", "before");
# the earlier version put the BART summaries in the first column and never
# displayed the T5 baseline.
print(
    tabulate(
        zip(
            range(len(t5_summaries_after_tuning)),
            t5_summaries_after_tuning,
            t5_summaries_before_tuning,
        ),
        headers=["Id", "Docstring after", "Docstring before"],
    )
)
print("\nTarget docstrings:\n")
print(
    tabulate(
        list(enumerate(test_samples["func_documentation_string"])),
        headers=["Id", "Target docstring"],
    )
)
print("\nSource code:\n")
print(
    tabulate(list(enumerate(test_samples["func_code_string"])), headers=["Id", "Code"])
)