<a href="https://colab.research.google.com/github/ferdouszislam/java-docstring-generator/blob/main/javadoc_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Transformers installation
! pip install transformers datasets evaluate accelerate rouge_score

In [None]:
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (AutoTokenizer, DataCollatorForSeq2Seq,
                          AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments,
                          Seq2SeqTrainer, pipeline)

In [None]:
# Only the first 100 rows of the train split are requested (`train[:100]`),
# which keeps this demo small.
javadoc_ds = load_dataset(
    path='Shuu12121/java-treesitter-dedupe_doc-filtered-dataset',
    split='train[:100]',
)

# Peek at the first record to see which columns the dataset exposes.
sample = javadoc_ds[0]
print(f"features: {sample.keys()}")

# Uncomment to inspect one raw code/docstring pair:
# print(f"code:\n {sample['code']}")
# print(f"docstring:\n {sample['docstring']}")

In [None]:
# Hold out 20% of the rows for evaluation. The seed is fixed so that a
# Restart & Run All produces the same split (and comparable BLEU scores).
javadoc_ds = javadoc_ds.train_test_split(test_size=0.2, seed=42)
print(javadoc_ds)

In [None]:
# Checkpoint identifier shared by the tokenizer, the data collator and the
# model cells below.
model_name = "google-t5/t5-small"
# Tokenizer matching the checkpoint; used by preprocess_ds and compute_metrics.
tokenizer = AutoTokenizer.from_pretrained(model_name)

Tokenize dataset

In [None]:
# Instruction prepended to every Java snippet, both here and at inference time.
INPUT_PREFIX = "Generate JavaDoc for the function: "


def preprocess_ds(ds):
    """Tokenize a batch of code/docstring pairs for seq2seq training.

    Args:
        ds: A batch mapping with a "code" column (model inputs) and a
            "docstring" column (targets), each a list of strings.

    Returns:
        The tokenizer output for the prefixed code (truncated to 512 tokens)
        with an extra "labels" entry holding the token ids of the docstrings
        (truncated to 256 tokens).
    """
    prompts = [INPUT_PREFIX + snippet for snippet in ds["code"]]
    encoded = tokenizer(prompts, max_length=512, truncation=True)

    target_encoding = tokenizer(
        text_target=ds["docstring"],
        max_length=256,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

In [None]:
# Run preprocess_ds over the dataset splits; batched=True hands the function
# whole column lists, which is what its list comprehension over ds["code"] expects.
tokenized_javadoc_ds = javadoc_ds.map(preprocess_ds, batched=True)

In [None]:
# Batch collator handed to the Seq2SeqTrainer below.
# NOTE(review): `model` receives the checkpoint name string (`model_name`),
# not a loaded model object -- confirm this is intended.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model_name)

Set up the evaluation process

In [None]:
# Load BLEU metric; compute_metrics calls `bleu.compute` on the decoded text.
bleu = evaluate.load("bleu")

def compute_metrics(eval_pred):
    """Compute BLEU and the mean generated length for one evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may itself be a tuple, in which case its first
            element is used. Positions equal to ``-100`` in either array are
            treated as padding.

    Returns:
        dict with ``"eval_bleu"`` (BLEU score) and ``"eval_gen_len"`` (mean
        number of non-padding tokens per prediction), both rounded to four
        decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is not a valid token id: decoding it fails, and counting it would
    # inflate the generation length. Map it to the pad token in both arrays.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # BLEU expects a list of references for each prediction
    references = [[label] for label in decoded_labels]

    try:
        bleu_score = bleu.compute(predictions=decoded_preds,
                                  references=references)["bleu"]
    except ZeroDivisionError:
        # BLEU is undefined when the decoded text is empty (e.g. an untrained
        # model that emits nothing); report 0 instead of aborting evaluation.
        bleu_score = 0.0

    # Add generation length
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    gen_len = float(np.mean(prediction_lens))

    return {"eval_bleu": round(bleu_score, 4), "eval_gen_len": round(gen_len, 4)}

Set up the model

In [None]:
# Load the pretrained seq2seq checkpoint named by `model_name`; this object is
# the one passed to the trainer in the training cell.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

Train model

In [None]:
# Directory that receives checkpoints during training and the final model;
# the inference cells load the model and tokenizer from here.
MODEL_STORE_PATH = "./demis_java_docstr_generator"

# fp16 mixed precision needs a CUDA device. Enable it only when one is present
# so the notebook also runs on a CPU-only runtime instead of failing at startup.
USE_FP16 = torch.cuda.is_available()

training_args = Seq2SeqTrainingArguments(
    output_dir=MODEL_STORE_PATH,
    eval_strategy="epoch",                # Evaluate at the end of every epoch
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_total_limit=1,                   # Keep fewer checkpoints
    num_train_epochs=4,
    predict_with_generate=True,           # Metrics are computed on generated text
    fp16=USE_FP16,                        # Change to bf16=True for XPU
    report_to="none",                     # Disable wandb reporting
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_javadoc_ds["train"],
    eval_dataset=tokenized_javadoc_ds["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Print evaluation scores

In [None]:
# Final evaluation pass on the held-out split; print_results below reads the
# 'eval_bleu' and 'eval_gen_len' entries of the returned dict.
results = trainer.evaluate()

def print_results(results, model_name):
    """Print a short evaluation report for one model.

    Args:
        results: Metrics mapping containing 'eval_bleu' and 'eval_gen_len'.
        model_name: Label shown in the report heading.
    """
    report_lines = (
        f"\n{model_name} Results:",
        "-" * 30,
        f"BLEU Score: {results['eval_bleu']:.4f}",
        f"Gen Length: {results['eval_gen_len']:.1f}",
    )
    for line in report_lines:
        print(line)

# Show the final metrics under a human-readable model label.
print_results(results, "T5-Small Java Docstring")

Save model

In [None]:
# Persist the fine-tuned model to MODEL_STORE_PATH; the inference cells below
# load from this directory.
trainer.save_model(MODEL_STORE_PATH)

Inference

In [None]:
# Raw string: the Java regex literal "\\s+" must reach the model exactly as
# written. In a regular Python string "\\s+" collapses to "\s+", which is not
# the Java source the author intended.
java_source = r"""
public boolean isPalindrome(String str) {
    // Remove spaces and convert to lowercase for case-insensitive comparison
    String cleanedStr = str.replaceAll("\\s+", "").toLowerCase();

    int left = 0;
    int right = cleanedStr.length() - 1;

    while (left < right) {
        if (cleanedStr.charAt(left) != cleanedStr.charAt(right)) {
            return false;
        }
        left++;
        right--;
    }

    return true;
}
"""
# Use the same instruction prefix the training inputs were built with.
code_sample = INPUT_PREFIX + java_source

Inference via pipeline()

In [None]:
# Build a pipeline from the saved directory, which supplies both the model
# weights and the tokenizer files.
# NOTE(review): the "summarization" task may apply its own generation defaults
# on top of INPUT_PREFIX -- compare the output with the manual inference cell.
demis_java_docstr_generator = pipeline(
    task="summarization",
    model=MODEL_STORE_PATH,
    tokenizer=MODEL_STORE_PATH,
)
demis_java_docstr_generator(code_sample)

Manual Inference

In [None]:
# Reload tokenizer and model from disk so this cell exercises the saved
# artifacts rather than the in-memory training objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_STORE_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_STORE_PATH)

# Encode the prefixed Java snippet as PyTorch tensors and keep only the ids.
inputs = tokenizer(code_sample, return_tensors="pt")["input_ids"]

# do_sample=False with at most 100 new tokens.
outputs = model.generate(inputs, do_sample=False, max_new_tokens=100)
tokenizer.decode(outputs[0], skip_special_tokens=True)