<a href="https://colab.research.google.com/github/ferdouszislam/java-docstring-generator/blob/main/javadoc_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
# Transformers installation
! pip install transformers datasets evaluate accelerate rouge_score



In [23]:
from datasets import load_dataset
from transformers import (AutoTokenizer, DataCollatorForSeq2Seq,
                          AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments,
                          Seq2SeqTrainer, pipeline)
import evaluate
import numpy as np

In [24]:
javadoc_ds = load_dataset(
    'Shuu12121/java-treesitter-dedupe_doc-filtered-dataset',
    split='train[:50000]')

sample = javadoc_ds[0]
print(f"features: {sample.keys()}")

# print(f"code:\n {sample['code']}")
# print(f"docstring:\n {sample['docstring']}")

features: dict_keys(['code', 'docstring', 'func_name', 'language', 'repo', 'path', 'url', 'license'])


In [25]:
javadoc_ds = javadoc_ds.train_test_split(test_size=0.2)
print(javadoc_ds)

DatasetDict({
    train: Dataset({
        features: ['code', 'docstring', 'func_name', 'language', 'repo', 'path', 'url', 'license'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['code', 'docstring', 'func_name', 'language', 'repo', 'path', 'url', 'license'],
        num_rows: 10000
    })
})


In [26]:
model_name = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

Tokenize dataset

In [27]:
INPUT_PREFIX = "Generate JavaDoc for the function: "

def preprocess_ds(ds):
  inputs = [INPUT_PREFIX + doc for doc in ds["code"]]
  model_inputs = tokenizer(inputs, max_length=512, truncation=True)
  labels = tokenizer(text_target=ds["docstring"], max_length=256,
                      truncation=True)
  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [28]:
tokenized_javadoc_ds = javadoc_ds.map(preprocess_ds, batched=True)

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [29]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model_name)

Setup evaluation process

In [30]:
# Load BLEU metric
bleu = evaluate.load("bleu")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # BLEU expects list of references for each prediction
    references = [[label] for label in decoded_labels]

    # Compute BLEU score
    result = bleu.compute(predictions=decoded_preds, references=references)

    # Extract the scalar BLEU score
    bleu_score = result["bleu"]

    # Add generation length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    gen_len = np.mean(prediction_lens)

    return {"eval_bleu": round(bleu_score, 4), "eval_gen_len": round(gen_len, 4)}

Setup model

In [31]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

Train model

In [34]:
MODEL_STORE_PATH = "./demis_java_docstr_generator"

training_args = Seq2SeqTrainingArguments(
    output_dir=MODEL_STORE_PATH,
    eval_strategy="epoch",
    learning_rate=5e-5,                   # Increase (faster convergence)
    per_device_train_batch_size=16, # Reduced batch size
    per_device_eval_batch_size=16,  # Reduced batch size
    weight_decay=0.01,
    save_total_limit=1,                   # Keep fewer checkpoints
    num_train_epochs=4,                   # Reduce epochs
    predict_with_generate=True,
    fp16=True, #change to bf16=True for XPU
    report_to="none", # Disable wandb reporting
    gradient_checkpointing=True, # Enable gradient checkpointing
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_javadoc_ds["train"],
    eval_dataset=tokenized_javadoc_ds["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,2.8027,2.534927,0.0242,19.4508


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,2.8027,2.534927,0.0242,19.4508
2,2.6573,2.40802,0.0264,19.2745
3,2.5739,2.354269,0.0264,19.1496
4,2.535,2.340934,0.0274,19.1251


TrainOutput(global_step=10000, training_loss=2.724187255859375, metrics={'train_runtime': 6071.1706, 'train_samples_per_second': 26.354, 'train_steps_per_second': 1.647, 'total_flos': 1.8750363141144576e+16, 'train_loss': 2.724187255859375, 'epoch': 4.0})

Print evaluation scores

In [35]:
results = trainer.evaluate()

def print_results(results, model_name):
    print(f"\n{model_name} Results:")
    print("-" * 30)
    print(f"BLEU Score: {results['eval_bleu']:.4f}")
    print(f"Gen Length: {results['eval_gen_len']:.1f}")

print_results(results, "T5-Small Java Docstring")


T5-Small Java Docstring Results:
------------------------------
BLEU Score: 0.0274
Gen Length: 19.1


Save model

In [36]:
trainer.save_model(MODEL_STORE_PATH)

Inference

In [37]:
code_sample = """
public boolean isPalindrome(String str) {
    // Remove spaces and convert to lowercase for case-insensitive comparison
    String cleanedStr = str.replaceAll("\\s+", "").toLowerCase();

    int left = 0;
    int right = cleanedStr.length() - 1;

    while (left < right) {
        if (cleanedStr.charAt(left) != cleanedStr.charAt(right)) {
            return false;
        }
        left++;
        right--;
    }

    return true;
}
"""
code_sample = INPUT_PREFIX + code_sample

Inference via pipeline()

In [38]:
demis_java_docstr_generator = pipeline("text2text-generation",
                                       model=MODEL_STORE_PATH,
                                       tokenizer=MODEL_STORE_PATH,
                                       max_length=256) # Increased max_length for longer docstrings
demis_java_docstr_generator(code_sample)

Device set to use cuda:0


[{'generated_text': 'Returns true if the given string is a palindrome. @param str The string to be a palindrome. @return true if the string is a palindrome.'}]

Manual Inference

In [41]:
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_STORE_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_STORE_PATH)
inputs = tokenizer(code_sample, return_tensors="pt").input_ids
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)
tokenizer.decode(outputs[0], skip_special_tokens=True)

'Returns whether the given string is a polymorphic string. @param str The string to be a polymorphic string. @return true if the string is a polymorphic string.'