### Distilling CodeT5 (codet5-base) into a student model for test assertion generation

First, we install the required packages:

In [None]:
!pip install tree-sitter==0.23.0

Let's start by understanding the data format. The file data_generation/data/codet5/distillation_data_training.jsonl contains the distillation data for the teacher model, covering both its inputs and its outputs.

In [3]:
from data.load_dataset import load_dataset

NUM_LINES_TO_INSPECT = 5
DATA_PATH = "data_generation/data/codet5/distillation_data_training.jsonl"

inspected_data = load_dataset(DATA_PATH, NUM_LINES_TO_INSPECT)

Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 166.68it/s]
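The load_dataset helper lives in the repository's data package and is not shown here. For readers without the repository, a minimal stand-in that reads the first N entries of a JSON Lines file (one JSON object per line) could look like the sketch below. The name load_jsonl is hypothetical, and unlike the original it does not display a progress bar.

```python
import json
from itertools import islice
from typing import Any, Dict, List, Optional


def load_jsonl(path: str, max_lines: Optional[int] = None) -> List[Dict[str, Any]]:
    """Parse the first max_lines lines of a JSON Lines file (all lines if None)."""
    entries: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in islice(f, max_lines):
            line = line.strip()
            if line:  # tolerate blank lines
                entries.append(json.loads(line))
    return entries


# Usage, mirroring the cell above:
# inspected_data = load_jsonl("data_generation/data/codet5/distillation_data_training.jsonl", 5)
```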

Now let's take a closer look at the first parsed JSON entry:

In [4]:
print(inspected_data[0].keys())
print(inspected_data[0]["test_method_masked"])
print(inspected_data[0]["assertions"])
print(inspected_data[0]["predicted_assertions"])
print(inspected_data[0]["model_type"])
print(inspected_data[0]["performance_metrics"])
print(inspected_data[0]["compressed_logits"])

dict_keys(['focal_method', 'test_method_masked', 'assertions', 'predicted_assertions', 'compressed_logits', 'model_type', 'performance_metrics'])
@Test
@Category(UnitTest.class)
public void testGet1stHalfRect() throws Exception
{
  LongRectangle tiles = new LongRectangle(915, 203, 917, 204);

  KVIterator<TileIdWritable, MrGeoRaster> iter = reader.get(tiles);

  int ndx = 0;
  while (iter.hasNext())
  {
        // <ASSERTION_PLACEHOLDER>
  }

}
['Assert.assertEquals("Unexpected tileid: ", resultTiles[ndx++], iter.currentKey().get());', 'Assert.assertEquals("Wrong number of items", 6, ndx);']
['Assert.assertEquals("Unexpected tileid: ", resultTiles[ndx++], iter.currentKey().get());']
codet5
{'exact_matches': 1, 'generated_count': 1, 'reference_count': 2, 'precision': 1.0, 'recall': 0.5, 'f1': 0.6666666666666666, 'accuracy': 0.5, 'similarity_score_avg': 1.0, 'similarity_scores': [1.0]}
{'format': 'lz4', 'compression': {'bits': 4, 'original_size_bytes': 65740800, 'bit_compressed_size_byte

As we can see, each entry contains the focal method under test, the test method with its assertions masked, the original (reference) assertions, the teacher model's predicted assertions, and the teacher's output logits in compressed form, which we will use in the student model's loss function. Each entry also records the type of teacher model that generated the predictions (here codet5, i.e. a codet5-base model) and performance metrics comparing the teacher's predictions with the reference assertions.
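The compressed_logits metadata gives a sense of how large these logits are. Assuming float32 storage, 65,740,800 bytes is 16,435,200 values, which is consistent with 512 target positions times 32,100 (the vocabulary size of the CodeT5 checkpoints). The metadata also reports 4-bit compression followed by lz4. The repository's actual scheme is not shown, so the sketch below uses simple per-tensor min-max quantization with two codes packed per byte purely as an illustrative assumption (the lz4 step is omitted); it is not the project's implementation.

```python
from typing import List, Sequence, Tuple

FLOAT32_BYTES = 4
CODET5_VOCAB_SIZE = 32100  # vocab_size of the Salesforce CodeT5 checkpoints
MAX_TARGET_LENGTH = 512    # assumption: target length the teacher logits were padded to


def num_logit_values(original_size_bytes: int) -> int:
    """Number of float32 values stored in an uncompressed logits tensor."""
    return original_size_bytes // FLOAT32_BYTES


def quantize(values: Sequence[float], bits: int = 4) -> Tuple[List[int], float, float]:
    """Min-max quantization to integer codes in [0, 2**bits - 1] (illustrative scheme)."""
    levels = (1 << bits) - 1  # largest code: 15 for 4 bits
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, lo, scale


def dequantize(codes: Sequence[int], lo: float, scale: float) -> List[float]:
    """Map integer codes back to approximate logit values."""
    return [lo + c * scale for c in codes]


def pack_nibbles(codes: Sequence[int]) -> bytes:
    """Store two 4-bit codes per byte (high nibble first), padding odd lengths with 0."""
    padded = list(codes) + [0] * (len(codes) % 2)
    return bytes((padded[i] << 4) | padded[i + 1] for i in range(0, len(padded), 2))


# 65,740,800 bytes / 4 = 16,435,200 values = 512 positions x 32,100 vocabulary entries
assert num_logit_values(65_740_800) == MAX_TARGET_LENGTH * CODET5_VOCAB_SIZE
```

With 4-bit codes, each reconstructed logit is off by at most half a quantization step, which is the precision trade-off that makes storing full teacher distributions for every example feasible.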

Now, for every entry in the dataset, we need to construct a student input that follows the same format as the teacher's input (as defined in data_generation/train_codet5_assertions.py), and tokenize it. This is handled by the StudentDataset class (data/student_dataset), which manages and tokenizes the student model's input data. Finally, train_model.py and evaluate_model.py contain the functions for training and evaluating student models, and the student_models directory holds the configurations and architectures of the different student models.
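As a rough illustration of what a class like StudentDataset does, here is a minimal map-style dataset sketch. The input template is an assumption copied from the simplified format used in this notebook's inference cell; the authoritative format is the one in data_generation/train_codet5_assertions.py. The target format (reference assertions joined one per line) and the names StudentDatasetSketch, build_input_text and build_target_text are hypothetical.

```python
from typing import Any, Callable, Dict, List, Sequence

# Assumed template: the simplified format used in the inference cell of this notebook.
# The authoritative format is defined in data_generation/train_codet5_assertions.py.
INPUT_TEMPLATE = "FOCAL METHOD:\n{focal}\n\nTEST METHOD:\n{test}"


def build_input_text(entry: Dict[str, Any]) -> str:
    """Combine the focal method and the masked test method into one source string."""
    return INPUT_TEMPLATE.format(
        focal=entry.get("focal_method", ""),
        test=entry.get("test_method_masked", ""),
    )


def build_target_text(entry: Dict[str, Any]) -> str:
    """Assumption: the target is the reference assertions, one per line."""
    return "\n".join(entry.get("assertions", []))


class StudentDatasetSketch:
    """Hypothetical map-style dataset: raw JSONL entries in, token ids out."""

    def __init__(self, entries: Sequence[Dict[str, Any]], tokenize: Callable[[str], List[int]]):
        self.entries = entries
        self.tokenize = tokenize

    def __len__(self) -> int:
        return len(self.entries)

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        entry = self.entries[idx]
        return {
            "input_ids": self.tokenize(build_input_text(entry)),
            "labels": self.tokenize(build_target_text(entry)),
        }
```

With a Hugging Face tokenizer, tokenize could be something like lambda s: tokenizer(s, max_length=512, truncation=True)["input_ids"]. In the real pipeline, each item would presumably also carry the decompressed teacher logits, since those drive the distillation loss.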

Let's start by distilling the codet5-base teacher's outputs into a pretrained codet5-small student and comparing the two models' results. The codet5_small_pretrained_1_00_weight_config configuration uses the pretrained codet5-small model, trains for 5 epochs, and learns exclusively from the teacher logits (not from the ground-truth assertions). The model is evaluated after every epoch. Note that the run below is a short one: it trains for 1 epoch on 9 examples and validates on 100 examples.
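A 1.00 weight on the teacher logits suggests the common weighted knowledge-distillation objective with the ground-truth term switched off. The per-token sketch below assumes the standard Hinton-style formulation, alpha * T^2 * KL(teacher || student) + (1 - alpha) * cross-entropy; the repository's exact loss, its temperature, and the precise meaning of the weight are assumptions here.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    alpha: float = 1.0,
    temperature: float = 1.0,
) -> float:
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * cross-entropy(label), for one token."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    soft_term = temperature ** 2 * kl
    # Hard-label term on the ground truth, computed at temperature 1
    hard_term = -math.log(softmax(student_logits)[label])
    return alpha * soft_term + (1.0 - alpha) * hard_term


# alpha = 1.0 (a "1.00 weight" on the teacher): the ground-truth label has no effect
print(distillation_loss([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=0, alpha=1.0))
print(distillation_loss([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=2, alpha=1.0))  # same value
```

In a sequence model this quantity would be averaged over all non-padding target positions; with alpha = 1.0 the student only ever sees the teacher's distribution, so any teacher errors are imitated rather than corrected.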

In [5]:
from train_model import run_student_training

In [6]:
from student_models.codet5_small_pretrained_1_00_weight_config import codet5_small_pretrained_1_00_weight_config

run_student_training(codet5_small_pretrained_1_00_weight_config(''))

Loading training dataset from data_generation/data/codet5/distillation_data_training.jsonl...
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 132.35it/s]
Using first 9 examples
Loading validation dataset from data_generation/data/codet5/distillation_data_validation.jsonl...
Loading dataset: 100%|██████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 171.53it/s]
Using first 100 examples
Training on 9 examples, validating on 100 examples
Loading model: Salesforce/codet5-small
Using device: cuda


  scaler = torch.cuda.amp.GradScaler() if args["fp16"] else None


Starting training for 1 epochs...


  with torch.cuda.amp.autocast():
Epoch 1/1: 100%|████████████████████████████████████████████████████| 3/3 [00:32<00:00, 10.70s/it, loss=25.6, ex/s=0.4]



Epoch 1/1 completed in 32.10s (0.28 examples/s)
  Average training loss: 25.5509
  Evaluating epoch 1...


Evaluating:   4%|██▊                                                                    | 1/25 [00:33<13:17, 33.25s/it]

Number of assertions should be the same


Evaluating:  12%|████████▌                                                              | 3/25 [00:59<06:54, 18.83s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  16%|███████████▎                                                           | 4/25 [01:20<06:57, 19.89s/it]

Unterminated character/string literal at """, line 14: Antlr node.addFilterToDomain("
<ASSERTION_PLACEHOLDER>
            Antlr node.addFilterToDomain("custom",<ASSERTION_PLACEHOLDER>
            Antlr node.addFilterToDomain("custom",<ASSERTION_PLACEHOLDER>
            Antlr node.addFilterToDomain("custom",<ASSERTION_PLACEHOLDER>
            Antlr node.addFilterToDomain("custom",<ASSERTION_PLACEHOLDER>
            Antlr node.addFilterToDomain("
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  20%|██████████████▏                                                        | 5/25 [01:41<06:43, 20.18s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  24%|█████████████████                                                      | 6/25 [01:45<04:41, 14.81s/it]

Number of assertions should be the same


Evaluating:  28%|███████████████████▉                                                   | 7/25 [02:06<04:58, 16.61s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  32%|██████████████████████▋                                                | 8/25 [02:15<04:03, 14.30s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  36%|█████████████████████████▌                                             | 9/25 [02:37<04:27, 16.71s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  40%|████████████████████████████                                          | 10/25 [02:59<04:33, 18.26s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  44%|██████████████████████████████▊                                       | 11/25 [03:20<04:29, 19.23s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  48%|█████████████████████████████████▌                                    | 12/25 [03:43<04:21, 20.14s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  52%|████████████████████████████████████▍                                 | 13/25 [04:03<04:01, 20.11s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  56%|███████████████████████████████████████▏                              | 14/25 [04:25<03:48, 20.75s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  60%|██████████████████████████████████████████                            | 15/25 [04:44<03:21, 20.16s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  64%|████████████████████████████████████████████▊                         | 16/25 [05:06<03:07, 20.81s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  68%|███████████████████████████████████████████████▌                      | 17/25 [05:12<02:11, 16.44s/it]

Number of assertions should be the same
Number of assertions should be the same


Evaluating:  72%|██████████████████████████████████████████████████▍                   | 18/25 [05:16<01:28, 12.64s/it]

Number of assertions should be the same


Evaluating:  76%|█████████████████████████████████████████████████████▏                | 19/25 [05:37<01:30, 15.04s/it]

Could not process token at "\", line 11: \\\\{
<ASSERTION_PLACEHOLDER>{
            // <ASSERTION_PLACEHOLDER>
            \\\\{
            // <ASSERTION_PLACEHOLDER>
            \\\\{
            // <ASSERTION_PLACEHOLDER>
            \\\\{
            // <ASSERTION_PLACEHOLDER>
            \\\\{
            // <ASSERTION_PLACEHOLDER>
            \\\\// <ASSERTION_PLACEHOLDER>
            \\\\ }
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  80%|████████████████████████████████████████████████████████              | 20/25 [05:59<01:25, 17.16s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  84%|██████████████████████████████████████████████████████████▊           | 21/25 [06:18<01:11, 17.89s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  88%|█████████████████████████████████████████████████████████████▌        | 22/25 [06:38<00:55, 18.52s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  92%|████████████████████████████████████████████████████████████████▍     | 23/25 [06:57<00:37, 18.67s/it]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same


Evaluating:  96%|███████████████████████████████████████████████████████████████████▏  | 24/25 [07:16<00:18, 18.73s/it]

Number of assertions should be the same


Evaluating:  96%|███████████████████████████████████████████████████████████████████▏  | 24/25 [07:32<00:18, 18.85s/it]

KeyboardInterrupt
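
The evaluation above was interrupted manually (KeyboardInterrupt) after 24 of 25 batches. Next, we load the saved student model and generate assertions for a small hand-written example:
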
In [3]:
from transformers import T5ForConditionalGeneration, RobertaTokenizer
import torch

# --- Configuration ---
# Directory written by save_pretrained() at the end of training (a checkpoint directory also works)
MODEL_PATH = "./output_models/student_model_output_codet5_small_pretrained_1_00_weight/final_model"
# The student uses the codet5-small tokenizer; if the tokenizer was saved alongside the
# model, RobertaTokenizer.from_pretrained(MODEL_PATH) works as well
TOKENIZER_NAME = "Salesforce/codet5-small"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load model and tokenizer ---
print(f"Loading model from: {MODEL_PATH}")
loaded_model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)
loaded_model.to(DEVICE)
loaded_model.eval()  # Evaluation mode: disables dropout

print(f"Loading tokenizer: {TOKENIZER_NAME}")
tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_NAME)

# --- Prepare a sample input (should mirror how StudentDataset builds inputs) ---
# Note: the training data marks the assertion location with "// <ASSERTION_PLACEHOLDER>";
# this hand-written example uses "/* MASK */" instead.
sample_data_item = {
    "focal_method": "public int calculate(int a, int b) { return a + b; }",
    "test_method_masked": "@Test public void testCalc() { int result = new MyClass().calculate(2,3); /* MASK */ }",
    # "original_target": "assertEquals(5, result);"  # Not needed for inference, only for comparison
}

focal_method_str = sample_data_item.get("focal_method", "")
test_method_masked_str = sample_data_item.get("test_method_masked", "")
# If clean_assertion_placeholders was applied during training data prep, apply it here too:
# focal_method_str = clean_assertion_placeholders(focal_method_str)
# test_method_masked_str = clean_assertion_placeholders(test_method_masked_str)

# Simplified input construction for this example
input_text = f"FOCAL METHOD:\n{focal_method_str}\n\nTEST METHOD:\n{test_method_masked_str}"
print(f"\nInput text for model:\n{input_text}")

# --- Tokenize ---
# Use the same max source length as in training. Padding a single example to max_length
# is not strictly required, but keeps inference inputs consistent with training.
MAX_SRC_LENGTH_INFERENCE = 1024  # Should match training
inputs = tokenizer(
    input_text,
    max_length=MAX_SRC_LENGTH_INFERENCE,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
).to(DEVICE)

# --- Generate ---
MAX_TGT_LENGTH_INFERENCE = 512  # Should match the training target length
print("\nGenerating assertions...")
with torch.no_grad():  # No gradients are needed for inference
    generated_ids = loaded_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=MAX_TGT_LENGTH_INFERENCE,
        num_beams=4,  # Use the same beam size as in evaluation
        early_stopping=True,
    )

generated_assertions_str = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"\nGenerated Assertions:\n{generated_assertions_str}")

# --- (Optional) Evaluate the output against a reference ---
# reference_assertions = "assertEquals(5, result);"
# metrics = evaluate_assertions(generated_assertions_str, reference_assertions)
# print(f"\nMetrics for this sample: {metrics}")
# parsable = check_java_parsability(generated_assertions_str)
# print(f"Is parsable by javalang: {parsable}")

Loading model from: ./output_models/student_model_output_codet5_small_pretrained_1_00_weight/final_model
Loading tokenizer: Salesforce/codet5-small

Input text for model:
FOCAL METHOD:
public int calculate(int a, int b) { return a + b; }

TEST METHOD:
@Test public void testCalc() { int result = new MyClass().calculate(2,3); /* MASK */ }

Generating assertions...

Generated Assertions:
assertEquals(2, result);
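
The student produces a well-formed assertion, but with the wrong expected value (2 instead of 5). Finally, we evaluate the teacher on the 100 validation examples, which gives a reference point for comparing the student:
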
In [4]:
from evaluation.evaluate_teacher import evaluate_teacher
from student_models.codet5_small_pretrained_1_00_weight_config import codet5_small_pretrained_1_00_weight_config

evaluate_teacher(codet5_small_pretrained_1_00_weight_config(''))

Loading dataset: 100%|██████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 180.83it/s]

Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
Number of assertions should be the same
  Similarity score: 0.8048
  Accuracy: 0.4277
  F1 score: 0.4277
  CodeBLEU score: 0.6029
    n-gram match score: 0.5148
    Weighted n-gram match score: 0.5331
    Syntax match score: 0.5385
    Dataflow match score: 0.0450
  Parsability rate: 0.9900


{'precision': 0.4503311258278146,
 'recall': 0.40718562874251496,
 'f1': 0.42767295597484284,
 'accuracy': 0.4276729559748428,
 'avg_per_sample_accuracy': 0.44116666666666665,
 'similarity_score_avg': 0.8047781100549055,
 'avg_per_sample_f1': 0.44755555555555554,
 'total_exact_matches': 68,
 'total_generated': 151,
 'total_reference': 167,
 'avg_codebleu_score': 0.6028612059568116,
 'avg_ngram_score': 0.5147804409752849,
 'avg_weighted_ngram_score': 0.5331358280734061,
 'avg_syntax_match_score': 0.5385285547785548,
 'avg_dataflow_match_score': 0.045,
 'total_assertion_blocks': 100,
 'parsable_assertion_blocks': 99,
 'parsability_rate': 0.99}
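The aggregate precision, recall, and F1 above are consistent with micro-averaging: they can be recomputed from the total counts (68 exact matches out of 151 generated and 167 reference assertions), whereas avg_per_sample_f1 and avg_per_sample_accuracy presumably average per-entry scores, which is why they differ. The same formulas also reproduce the per-entry performance_metrics of the first training entry inspected earlier. A quick check with a hypothetical helper, micro_prf (the accuracy metric's definition is not shown, so it is left out):

```python
from typing import Dict


def micro_prf(exact_matches: int, generated: int, reference: int) -> Dict[str, float]:
    """Precision, recall and F1 from exact-match counts."""
    precision = exact_matches / generated if generated else 0.0
    recall = exact_matches / reference if reference else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Validation totals reported above: 68 exact matches, 151 generated, 167 reference assertions
print(micro_prf(68, 151, 167))  # precision ~0.4503, recall ~0.4072, f1 ~0.4277

# Per-entry metrics of the first training entry: 1 exact match, 1 generated, 2 reference
print(micro_prf(1, 1, 2))  # {'precision': 1.0, 'recall': 0.5, 'f1': 0.6666666666666666}
```

Micro-averaging weights every assertion equally, so tests with many assertions dominate the aggregate; the per-sample averages weight every test method equally instead.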