### Distilling DeepSeek Coder 1.3B into a student model for test case assertion generation

First, we install and import the required packages:

In [23]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [67]:
!pip3 install transformers==4.26.1

In [23]:
import json
import torch
from torch.utils.data import DataLoader
from transformers import (
    T5ForConditionalGeneration,
    RobertaTokenizer,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
import os
import time
from tqdm import tqdm
import re
from difflib import SequenceMatcher
import matplotlib.pyplot as plt
import numpy as np

from Data.student_dataset import StudentDataset

Let's start by understanding the data format. The Data/distillation_data_training.jsonl file contains the teacher model's data (both inputs and outputs), with one JSON object per line.

In [41]:
NUM_LINES_TO_INSPECT = 5
DATA_PATH = "Data/distillation_data_training.jsonl"

inspected_data = []

with open(DATA_PATH, 'r') as data_file:
    for i, line_content in enumerate(data_file):
        if i >= NUM_LINES_TO_INSPECT:
            break
        data = json.loads(line_content.strip())
        inspected_data.append(data)

Now let's take a closer look at one of the parsed entries:

In [42]:
print(inspected_data[1].keys())
print(inspected_data[1]["test_method_masked"])
print(inspected_data[1]["original_target"])
print(inspected_data[1]["predicted_assertions"])
print(inspected_data[1]["model_type"])
print(inspected_data[1]["compressed_logits"])

dict_keys(['focal_file', 'test_method_masked', 'original_target', 'predicted_assertions', 'model_type', 'compressed_logits'])
@Test
  public void testGetWebSocketContainerReturnsDefaultContainerFactory() {
    System.setProperty(WebSocketServiceImpl.WEB_SOCKET_CONTAINER_FACTORY_PROPERTY, "");

    WebSocketContainer webSocketContainer;

    webSocketContainer = WebSocketServiceImpl.getWebSocketContainer();
        // ASSERTION PLACEHOLDER

    System.setProperty(
        WebSocketServiceImpl.WEB_SOCKET_CONTAINER_FACTORY_PROPERTY,
        CustomWebSocketContainerFactory.class.getName());
    webSocketContainer = WebSocketServiceImpl.getWebSocketContainer();
        // ASSERTION PLACEHOLDER
  }
assertTrue(webSocketContainer instanceof ClientManager);
assertNull(webSocketContainer);
assertNotNull(webSocketContainer);
assertNotNull(webSocketContainer);
codet5
{'format': 'lz4', 'compression': {'bits': 4, 'original_size_bytes': 65740800, 'bit_compressed_size_bytes': 8217600, 'final_size_byte

As we can see, each entry contains the focal file (the file with the class under test), the test method with its assertions masked, the original target assertions, the teacher model's predicted assertions, and the teacher's compressed output logits, which we will use in the student model's loss function. It also records the type of teacher model that generated the assertions (here codet5, indicating a CodeT5-small model).
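
The compressed_logits metadata printed above is consistent with a simple layout: 65,740,800 bytes is exactly 512 × 32,100 float32 values (32,100 being CodeT5-small's vocabulary size, 512 presumably the padded target length), and 4-bit codes take one eighth of that, i.e. 8,217,600 bytes. The sketch below checks that arithmetic and shows one possible 4-bit scheme: global min-max quantization with two codes packed per byte. The helpers quantize_4bit and dequantize_4bit are hypothetical; the scheme actually used to produce compressed_logits (including the lz4 step) lives in the data-generation code and may differ.

```python
import numpy as np

# Assumption: logits are float32 with shape (target_length, vocab_size).
target_length, vocab_size = 512, 32100
original_bytes = target_length * vocab_size * np.dtype(np.float32).itemsize
packed_bytes = target_length * vocab_size // 2  # two 4-bit codes per byte
print(original_bytes, packed_bytes)  # 65740800 8217600


def quantize_4bit(logits):
    """Hypothetical: min-max quantize a float array to 16 levels, pack two codes per byte."""
    lo, hi = float(logits.min()), float(logits.max())
    scale = (hi - lo) / 15 if hi > lo else 1.0
    codes = np.clip(np.round((logits - lo) / scale), 0, 15).astype(np.uint8).ravel()
    if codes.size % 2:  # pad to an even count so codes pair up into bytes
        codes = np.concatenate([codes, np.zeros(1, dtype=np.uint8)])
    packed = (codes[0::2] << 4) | codes[1::2]  # high nibble, low nibble
    return packed, lo, scale, logits.shape


def dequantize_4bit(packed, lo, scale, shape):
    """Hypothetical: unpack the 4-bit codes and map them back to approximate logits."""
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4
    codes[1::2] = packed & 0x0F
    n = int(np.prod(shape))
    return (codes[:n].astype(np.float32) * scale + lo).reshape(shape)
```

With 16 levels the reconstruction error per logit is at most half a quantization step, which is usually acceptable for soft targets that are passed through a temperature-scaled softmax anyway.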

Now, for every entry in the dataset, we need to construct a student model input that follows the same format as the teacher model's input (as defined in DataGeneration/train_teacher_model.py), and then tokenize it. The StudentDataset class manages and tokenizes the student model's input data:

In [43]:
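# class StudentDataset(Dataset):
# This class was moved to Data/student_dataset.py so it can be used with multi-process
# data loading (DataLoader with num_workers > 0), whose worker processes must be able to import it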
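
StudentDataset itself lives in Data/student_dataset.py and is not reproduced here. One convention it has to follow is visible in the training loop below: target positions that should not contribute to the loss hold the value -100 (the loop builds its mask with labels.view(-1) != -100, and Hugging Face seq2seq models ignore -100 in their cross-entropy loss). A minimal sketch of that padding convention, using a hypothetical pad_labels helper and made-up token ids:

```python
IGNORE_INDEX = -100  # value the training loop masks out (labels != -100)


def pad_labels(sequences, max_length, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: truncate/pad target token-id lists to max_length.

    Padding uses ignore_index so those positions are excluded from both the
    cross-entropy and the distillation loss. Illustrates the convention only.
    """
    padded = []
    for seq in sequences:
        seq = list(seq[:max_length])
        padded.append(seq + [ignore_index] * (max_length - len(seq)))
    return padded


def active_mask(labels, ignore_index=IGNORE_INDEX):
    """Flattened boolean mask of the positions that contribute to the loss."""
    return [tok != ignore_index for row in labels for tok in row]


labels = pad_labels([[101, 7, 2], [55, 2]], max_length=4)
print(labels)                    # [[101, 7, 2, -100], [55, 2, -100, -100]]
print(sum(active_mask(labels)))  # 5
```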

With the dataset class in place, we also need a function to load the raw data:

In [55]:
def load_dataset(jsonl_path, max_samples=None, teacher_model="codet5"):
    """Load entries generated by `teacher_model` from a JSONL file, with an optional sample limit"""
    required_fields = ("focal_file", "test_method_masked", "original_target", "model_type")
    data = []
    total_lines = 0

    # First count lines (stopping at max_samples, if given) to size the progress bar
    with open(jsonl_path, 'r') as f:
        for _ in f:
            total_lines += 1
            if max_samples and total_lines >= max_samples:
                break

    # Then load with progress bar
    with open(jsonl_path, 'r') as f:
        for line in tqdm(f, total=total_lines, desc="Loading dataset"):
            if not line.strip():
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                print("Warning: Skipping invalid JSON line")
                continue

            # Validate required fields are present
            if not all(field in entry for field in required_fields):
                print("Warning: Skipping entry with missing fields")
                continue

            # Keep only entries produced by the requested teacher model
            if entry["model_type"] != teacher_model:
                continue

            data.append(entry)
            if max_samples and len(data) >= max_samples:
                break

    return data
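
Since load_dataset keeps only the entries whose model_type matches teacher_model, it can be worth checking how a file is split across teacher models before choosing max_samples. A small stdlib-only sketch; the helper name count_model_types is ours, not part of the project:

```python
import json
from collections import Counter


def count_model_types(jsonl_path, max_lines=None):
    """Stream a JSONL file and count entries per model_type.

    Blank lines are skipped; undecodable lines and entries without a
    model_type are counted under placeholder keys.
    """
    counts = Counter()
    with open(jsonl_path, "r") as data_file:
        for i, line in enumerate(data_file):
            if max_lines is not None and i >= max_lines:
                break
            if not line.strip():
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                counts["<invalid json>"] += 1
                continue
            counts[entry.get("model_type", "<missing>")] += 1
    return counts


# Usage (path from the inspection cell earlier):
# print(count_model_types("Data/distillation_data_training.jsonl"))
```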

We also need a function that evaluates the student model on a validation set, so that we can select the best student model so far without relying on the training data (on which the model may overfit). Here we evaluate the model against the ground truth (the original assertions) rather than against the teacher model's logits, because our end goal is generating correct assertions. So while the training step uses the teacher's logits (so that the student learns to mimic the teacher's behaviour), the final output is evaluated against the ground truth.

In [46]:
def calculate_similarity(reference, candidate):
    """Calculate string similarity using SequenceMatcher"""
    return SequenceMatcher(None, reference, candidate).ratio()


def normalize_assertion(assertion):
    """Normalize assertion text for more reliable comparison"""
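    # Collapse runs of whitespace into a single space and trim the ends
    assertion = re.sub(r'\s+', ' ', assertion).strip()

    # Replace the first (expected) argument of assertEquals with a placeholder,
    # so only the second (actual) argument is compared
    assertion = re.sub(r'assertEquals\(\s*[^,]+,\s*([^)]+)\)', r'assertEquals(VALUE, \1)', assertion)

    # Lower-case the "assert" prefix of assertEquals/That/True/False; because \1 reuses
    # the matched text, the rest of the method name keeps its original case
    assertion = re.sub(r'assert(Equals|That|True|False)', r'assert\1', assertion, flags=re.IGNORECASE)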

    return assertion


def evaluate_assertions(generated_assertions, reference_assertions):
    """Evaluate the quality of generated assertions against reference assertions"""

    # Parse individual assertions if provided as multiline string
    if isinstance(generated_assertions, str):
        # Split by semicolons or newlines
        generated_list = re.split(r';|\n', generated_assertions)
        generated_list = [a.strip() + ';' for a in generated_list if a.strip()]
    else:
        generated_list = generated_assertions

    if isinstance(reference_assertions, str):
        reference_list = re.split(r';|\n', reference_assertions)
        reference_list = [a.strip() + ';' for a in reference_list if a.strip()]
    else:
        reference_list = reference_assertions

    # Normalize assertions
    normalized_generated = [normalize_assertion(a) for a in generated_list]
    normalized_reference = [normalize_assertion(a) for a in reference_list]

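    # Calculate exact matches as a multiset intersection, so each reference assertion
    # can be matched at most once (keeps precision and recall within [0, 1])
    from collections import Counter
    exact_matches = sum((Counter(normalized_generated) & Counter(normalized_reference)).values())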

    # Calculate similarity scores
    similarity_scores = []
    for gen in normalized_generated:
        best_sim = 0
        for ref in normalized_reference:
            sim = calculate_similarity(gen, ref)
            best_sim = max(best_sim, sim)
        similarity_scores.append(best_sim)

    # Calculate metrics
    precision = exact_matches / len(normalized_generated) if normalized_generated else 0
    recall = exact_matches / len(normalized_reference) if normalized_reference else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    accuracy = exact_matches / max(len(normalized_generated), len(normalized_reference)) if max(
        len(normalized_generated), len(normalized_reference)) > 0 else 0

    return {
        "exact_matches": exact_matches,
        "generated_count": len(normalized_generated),
        "reference_count": len(normalized_reference),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
        "similarity_score_avg": sum(similarity_scores) / len(similarity_scores) if similarity_scores else 0,
        "similarity_scores": similarity_scores
    }


def evaluate_model(model, tokenizer, dataloader, device):
    """Evaluate model on dataloader"""
    model.eval()
    eval_loss = 0.0
    all_metrics = {
        "exact_matches": 0,
        "generated_count": 0,
        "reference_count": 0,
        "similarity_scores": [],
        "accuracy_scores": [],
        "f1_scores": []
    }

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            # Get loss
            loss = outputs.loss
            eval_loss += loss.item()

            # Generate predictions for evaluation
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )

            # Decode and evaluate
            for i in range(len(input_ids)):
                generated_text = tokenizer.decode(generated_ids[i], skip_special_tokens=True)
                reference_text = batch["original_target"][i]

                try:
                    metrics = evaluate_assertions(generated_text, reference_text)

                    # Update metrics
                    all_metrics["exact_matches"] += metrics["exact_matches"]
                    all_metrics["generated_count"] += metrics["generated_count"]
                    all_metrics["reference_count"] += metrics["reference_count"]
                    all_metrics["similarity_scores"].extend(metrics["similarity_scores"])
                    all_metrics["accuracy_scores"].append(metrics["accuracy"])
                    all_metrics["f1_scores"].append(metrics["f1"])
                except Exception as e:
                    print(f"Error evaluating assertion: {e}")

    # Calculate overall metrics
    avg_loss = eval_loss / len(dataloader)

    if all_metrics["generated_count"] > 0 and all_metrics["reference_count"] > 0:
        overall_precision = all_metrics["exact_matches"] / all_metrics["generated_count"]
        overall_recall = all_metrics["exact_matches"] / all_metrics["reference_count"]
        precision_plus_recall = overall_precision + overall_recall
        overall_f1 = (2 * overall_precision * overall_recall / precision_plus_recall
                      if precision_plus_recall > 0 else 0)
        # 2 * matches / (generated + reference): algebraically the same value as the micro-averaged F1
        total_count = all_metrics["generated_count"] + all_metrics["reference_count"]
        overall_accuracy = all_metrics["exact_matches"] * 2 / total_count
    else:
        overall_precision = 0
        overall_recall = 0
        overall_f1 = 0
        overall_accuracy = 0

    def mean_or_zero(values):
        return sum(values) / len(values) if values else 0

    avg_similarity = mean_or_zero(all_metrics["similarity_scores"])
    avg_per_sample_accuracy = mean_or_zero(all_metrics["accuracy_scores"])
    avg_f1 = mean_or_zero(all_metrics["f1_scores"])
    avg_f1 = sum(all_metrics["f1_scores"]) / len(all_metrics["f1_scores"]) if all_metrics["f1_scores"] else 0

    eval_results = {
        "precision": overall_precision,
        "recall": overall_recall,
        "f1": overall_f1,
        "accuracy": overall_accuracy,
        "avg_per_sample_accuracy": avg_per_sample_accuracy,
        "similarity_score_avg": avg_similarity,
        "avg_per_sample_f1": avg_f1,
        "total_exact_matches": all_metrics["exact_matches"],
        "total_generated": all_metrics["generated_count"],
        "total_reference": all_metrics["reference_count"]
    }

    return avg_loss, eval_results
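
To see why both exact-match and similarity metrics are reported, consider the entry inspected earlier. Assuming the first two printed assertions are its original_target and the last two are the teacher's predicted_assertions, none of the predictions matches exactly, yet assertNotNull(webSocketContainer); is textually very close to assertNull(webSocketContainer);. The sketch below recomputes both quantities with difflib.SequenceMatcher and a multiset-based exact match; it skips normalization, which leaves these particular strings unchanged.

```python
from collections import Counter
from difflib import SequenceMatcher

# Assumption: these are the original_target and predicted_assertions printed earlier.
reference = [
    "assertTrue(webSocketContainer instanceof ClientManager);",
    "assertNull(webSocketContainer);",
]
generated = [
    "assertNotNull(webSocketContainer);",
    "assertNotNull(webSocketContainer);",
]

# Exact matches as a multiset intersection: each reference assertion is matched once at most
exact_matches = sum((Counter(generated) & Counter(reference)).values())

# Best similarity of each generated assertion against any reference assertion
best_similarity = [
    max(SequenceMatcher(None, gen, ref).ratio() for ref in reference) for gen in generated
]

print(exact_matches)                             # 0
print([round(s, 3) for s in best_similarity])   # [0.954, 0.954]
```

Exact-match precision and recall are both 0 for this entry, while the average similarity is about 0.95: the similarity score rewards near-misses such as a wrong null check on the right variable, which exact matching cannot distinguish from unrelated output.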

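We also need a function to train the student model. Its loss combines a cross-entropy term on the ground-truth labels (weighted by alpha_ce) with a temperature-scaled KL-divergence term between the student's and the teacher's output distributions (weighted by alpha_distil):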

In [47]:
def train_student_model(model, tokenizer, train_dataloader, val_dataloader, args):
    """Train the student model to match the teacher's output on the assertion generation task"""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Setup tensorboard if available
    try:
        from torch.utils.tensorboard import SummaryWriter
        tensorboard_writer = SummaryWriter(log_dir=os.path.join(args["output_dir"], "tensorboard"))
        use_tensorboard = True
    except ImportError:
        use_tensorboard = False

    # Prepare optimizer and scheduler
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args["weight_decay"],
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args["learning_rate"])

    # Calculate total training steps (optimizer updates)
    steps_per_epoch = max(1, len(train_dataloader) // args["gradient_accumulation_steps"])
    if args["max_steps"] > 0:
        t_total = args["max_steps"]
        num_epochs = -(-args["max_steps"] // steps_per_epoch)  # ceiling division
    else:
        t_total = steps_per_epoch * args["epochs"]
        num_epochs = args["epochs"]

    # Create scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args["warmup_steps"],
        num_training_steps=t_total
    )

    # Mixed precision training if requested
    scaler = torch.cuda.amp.GradScaler() if args["fp16"] else None

    # Track metrics
    best_val_loss = float('inf')
    global_step = 0
    epochs_without_improvement = 0

    # For per-batch metrics
    batch_loss_window_size = args["batch_metrics_window"]
    batch_losses = []
    batch_accuracies = []
    batch_similarities = []
    eval_pool_size = args["batch_eval_pool"]  # Number of examples to evaluate in each batch

    # Create output directory
    os.makedirs(args["output_dir"], exist_ok=True)

    # Save training arguments
    with open(os.path.join(args["output_dir"], "training_args.json"), "w") as f:
        json.dump(args, f, indent=4)

    # Create metrics file
    metrics_file = os.path.join(args["output_dir"], "training_metrics.csv")
    with open(metrics_file, "w") as f:
        f.write("epoch,batch,global_step,loss,accuracy,similarity,lr,examples_per_second\n")

    # Main training loop
    print(f"Starting training for {num_epochs} epochs...")
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()
        examples_processed = 0

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            batch_start_time = time.time()
            examples_in_batch = len(batch["input_ids"])
            examples_processed += examples_in_batch

            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            teacher_logits = batch["teacher_logits"].to(device)

            # Forward pass; autocast is a no-op when fp16 is disabled
            with torch.autocast(device_type=device.type, enabled=args["fp16"]):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                student_logits = outputs.logits

                # The KL term compares the two distributions position by position,
                # so the student and teacher logits must have identical shapes
                if student_logits.shape != teacher_logits.shape:
                    raise ValueError(
                        f"Student logits {tuple(student_logits.shape)} and teacher logits "
                        f"{tuple(teacher_logits.shape)} do not match"
                    )

                # Only positions with a real target token (labels != -100) contribute
                active_loss_mask = labels.view(-1) != -100  # Flattened mask
                temperature = args["distillation_temp"]

                active_student_log_probs = torch.nn.functional.log_softmax(
                    student_logits.view(-1, student_logits.size(-1))[active_loss_mask] / temperature,
                    dim=-1
                )
                active_teacher_probs = torch.nn.functional.softmax(
                    teacher_logits.view(-1, teacher_logits.size(-1))[active_loss_mask] / temperature,
                    dim=-1
                )

                if active_student_log_probs.numel() > 0:
                    loss_fct_kl = torch.nn.KLDivLoss(reduction="batchmean")
                    loss_distill = loss_fct_kl(
                        active_student_log_probs,
                        active_teacher_probs
                    ) * (temperature ** 2)  # Scale by T^2
                else:
                    loss_distill = torch.zeros((), device=student_logits.device)

                # Scale the combined loss (both terms) for gradient accumulation
                loss = (args["alpha_ce"] * outputs.loss + args["alpha_distil"] * loss_distill) \
                    / args["gradient_accumulation_steps"]

            # Backward pass, with gradient scaling under fp16
            if args["fp16"]:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if (batch_idx + 1) % args["gradient_accumulation_steps"] == 0:
                if args["fp16"]:
                    # Unscale the gradients before clipping
                    scaler.unscale_(optimizer)

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), args["max_grad_norm"])

                # Update weights
                if args["fp16"]:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            # Track the loss
            loss_value = loss.item() * args["gradient_accumulation_steps"]
            epoch_loss += loss_value
            batch_losses.append(loss_value)
            if len(batch_losses) > batch_loss_window_size:
                batch_losses.pop(0)

            # Calculate time per example
            batch_time = time.time() - batch_start_time
            examples_per_second = examples_in_batch / batch_time if batch_time > 0 else 0

            # Per-batch metrics: Generate predictions for a few examples to calculate accuracy
            if args["track_batch_metrics"] and batch_idx % args["batch_metrics_every"] == 0:
                # Sample some examples from batch to evaluate
                eval_indices = np.random.choice(
                    range(len(input_ids)),
                    size=min(eval_pool_size, len(input_ids)),
                    replace=False
                )

                # Switch to eval mode temporarily
                model.eval()
                with torch.no_grad():
                    # Generate predictions for sampled examples
                    sampled_input_ids = input_ids[eval_indices]
                    sampled_attention_mask = attention_mask[eval_indices]

                    generated_ids = model.generate(
                        input_ids=sampled_input_ids,
                        attention_mask=sampled_attention_mask,
                        max_length=args["max_tgt_length"],
                        num_beams=4,
                        early_stopping=True
                    )

                    # Calculate accuracy and similarity
                    batch_accuracy = 0
                    batch_similarity = 0

                    for i, idx in enumerate(eval_indices):
                        generated_text = tokenizer.decode(generated_ids[i], skip_special_tokens=True)
                        reference_text = batch["original_target"][idx]

                        try:
                            metrics = evaluate_assertions(generated_text, reference_text)
                            batch_accuracy += metrics["accuracy"]
                            batch_similarity += metrics["similarity_score_avg"]
                        except Exception as e:
                            print(f"Error evaluating assertion: {e}")

                    # Average the metrics (eval_indices is a NumPy array, so test its length
                    # explicitly; `if eval_indices` raises for arrays with more than one element)
                    batch_accuracy /= max(len(eval_indices), 1)
                    batch_similarity /= max(len(eval_indices), 1)

                # Switch back to train mode
                model.train()

                # Track metrics
                batch_accuracies.append(batch_accuracy)
                batch_similarities.append(batch_similarity)
                if len(batch_accuracies) > batch_loss_window_size:
                    batch_accuracies.pop(0)
                if len(batch_similarities) > batch_loss_window_size:
                    batch_similarities.pop(0)

                # Calculate moving averages
                avg_loss = sum(batch_losses) / len(batch_losses)
                avg_accuracy = sum(batch_accuracies) / len(batch_accuracies) if batch_accuracies else 0
                avg_similarity = sum(batch_similarities) / len(batch_similarities) if batch_similarities else 0

                # Update progress bar
                progress_bar.set_postfix({
                    "loss": avg_loss,
                    "accuracy": avg_accuracy,
                    "similarity": avg_similarity,
                    "ex/s": f"{examples_per_second:.1f}"
                })

                # Log to tensorboard
                if use_tensorboard:
                    tensorboard_writer.add_scalar("batch_loss", avg_loss, global_step)
                    tensorboard_writer.add_scalar("batch_accuracy", avg_accuracy, global_step)
                    tensorboard_writer.add_scalar("batch_similarity", avg_similarity, global_step)
                    tensorboard_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tensorboard_writer.add_scalar("examples_per_second", examples_per_second, global_step)

                # Log to CSV
                with open(metrics_file, "a") as f:
                    f.write(
                        f"{epoch + 1},{batch_idx + 1},{global_step},{avg_loss:.6f},{avg_accuracy:.6f},{avg_similarity:.6f},{scheduler.get_last_lr()[0]:.8f},{examples_per_second:.2f}\n")
            else:
                # Just update with loss
                avg_loss = sum(batch_losses) / len(batch_losses)
                progress_bar.set_postfix({
                    "loss": avg_loss,
                    "ex/s": f"{examples_per_second:.1f}"
                })

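            # Evaluate periodically, right after an optimizer step (global_step starts at 0
            # and only advances every gradient_accumulation_steps batches)
            if (args["eval_steps"] > 0
                    and (batch_idx + 1) % args["gradient_accumulation_steps"] == 0
                    and global_step % args["eval_steps"] == 0):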
                val_loss, eval_results = evaluate_model(model, tokenizer, val_dataloader, device)

                # Log to tensorboard
                if use_tensorboard:
                    tensorboard_writer.add_scalar("eval_loss", val_loss, global_step)
                    for metric, value in eval_results.items():
                        if isinstance(value, (int, float)):
                            tensorboard_writer.add_scalar(f"eval_{metric}", value, global_step)

                # Print evaluation results
                print(f"\nEvaluation at step {global_step}:")
                print(f"  Validation loss: {val_loss:.4f}")
                print(f"  Similarity score: {eval_results['similarity_score_avg']:.4f}")
                print(f"  Accuracy: {eval_results['accuracy']:.4f}")
                print(f"  F1 score: {eval_results['f1']:.4f}")

                # Save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    print(f"  New best validation loss: {val_loss:.4f}")

                    # Save model checkpoint
                    model_dir = os.path.join(args["output_dir"], "best_model")
                    os.makedirs(model_dir, exist_ok=True)
                    model.save_pretrained(model_dir)
                    tokenizer.save_pretrained(model_dir)

                    # Save optimizer and scheduler
                    torch.save(optimizer.state_dict(), os.path.join(model_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(model_dir, "scheduler.pt"))

                    # Reset patience counter
                    epochs_without_improvement = 0
                else:
                    # Increment patience counter
                    epochs_without_improvement += 1

                # Early stopping
                if 0 < args["early_stopping_patience"] <= epochs_without_improvement:
                    print(f"Early stopping after {epochs_without_improvement} evaluations without improvement")
                    break

                # Back to training mode
                model.train()

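            # Save checkpoint, right after an optimizer step (see the evaluation condition above)
            if (args["save_steps"] > 0
                    and (batch_idx + 1) % args["gradient_accumulation_steps"] == 0
                    and global_step % args["save_steps"] == 0):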
                checkpoint_dir = os.path.join(args["output_dir"], f"checkpoint-{global_step}")
                os.makedirs(checkpoint_dir, exist_ok=True)
                model.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(checkpoint_dir)

                # Save optimizer and scheduler
                torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

            # Break if max steps reached
            if args["max_steps"] > 0 and global_step >= args["max_steps"]:
                break

        # Calculate average loss for the epoch
        epoch_avg_loss = epoch_loss / len(train_dataloader)
        train_losses.append(epoch_avg_loss)

        # Calculate epoch time and speed
        epoch_time = time.time() - epoch_start_time
        examples_per_second = examples_processed / epoch_time if epoch_time > 0 else 0

        print(f"\nEpoch {epoch + 1}/{num_epochs} completed in {epoch_time:.2f}s ({examples_per_second:.2f} examples/s)")
        print(f"  Average training loss: {epoch_avg_loss:.4f}")

        # Evaluate at the end of each epoch
        print(f"  Evaluating epoch {epoch + 1}...")
        val_loss, eval_results = evaluate_model(model, tokenizer, val_dataloader, device)
        val_losses.append(val_loss)

        # Log to tensorboard
        if use_tensorboard:
            tensorboard_writer.add_scalar("epoch_train_loss", epoch_avg_loss, epoch + 1)
            tensorboard_writer.add_scalar("epoch_val_loss", val_loss, epoch + 1)
            for metric, value in eval_results.items():
                if isinstance(value, (int, float)):
                    tensorboard_writer.add_scalar(f"epoch_eval_{metric}", value, epoch + 1)

        # Print evaluation results
        print(f"  Validation loss: {val_loss:.4f}")
        print(f"  Similarity score: {eval_results['similarity_score_avg']:.4f}")
        print(f"  Accuracy: {eval_results['accuracy']:.4f}")
        print(f"  F1 score: {eval_results['f1']:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"  New best validation loss: {val_loss:.4f}")

            # Save model checkpoint
            model_dir = os.path.join(args["output_dir"], "best_model")
            os.makedirs(model_dir, exist_ok=True)
            model.save_pretrained(model_dir)
            tokenizer.save_pretrained(model_dir)

            # Reset patience counter
            epochs_without_improvement = 0
        else:
            # Increment patience counter
            epochs_without_improvement += 1

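        # Early stopping (the counter also includes step-level evaluations)
        if 0 < args["early_stopping_patience"] <= epochs_without_improvement:
            print(f"Early stopping after {epochs_without_improvement} evaluations without improvement")
            break

        # Stop the epoch loop as well once max_steps is reached
        if args["max_steps"] > 0 and global_step >= args["max_steps"]:
            break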

    # Save final model
    final_model_dir = os.path.join(args["output_dir"], "final_model")
    os.makedirs(final_model_dir, exist_ok=True)
    model.save_pretrained(final_model_dir)
    tokenizer.save_pretrained(final_model_dir)

    # Close tensorboard writer
    if use_tensorboard:
        tensorboard_writer.close()

    # Plot loss curves
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.savefig(os.path.join(args["output_dir"], "loss_curves.png"))
    plt.close()

    return model, tokenizer, best_val_loss
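The early-stopping check above uses a chained comparison. `0 < patience <= epochs_without_improvement` is only true when the patience is positive, so setting `early_stopping_patience` to 0 (or to a negative value) disables early stopping entirely. The sketch below reproduces the same patience bookkeeping in isolation. The helper `epochs_until_stop` is hypothetical, and the loss values are made up for illustration, not results from this notebook:

```python
def epochs_until_stop(val_losses, patience):
    """Return how many epochs run before early stopping triggers.

    Mirrors the patience logic in train_student_model: the counter resets
    whenever the validation loss improves and increments otherwise.
    """
    best_val_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # patience <= 0 disables early stopping, because 0 < patience is False
        if 0 < patience <= epochs_without_improvement:
            return epoch + 1
    return len(val_losses)


# Hypothetical validation losses, for illustration only
losses = [1.0, 0.8, 0.85, 0.9, 0.95, 0.7]
print(epochs_until_stop(losses, patience=3))  # 5
print(epochs_until_stop(losses, patience=0))  # 6
```

With a patience of 3, training stops after the fifth epoch and never sees the later improvement to 0.7, which is the usual trade-off of aggressive early stopping.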

Next, we need to run the training of our student model. To do this, we define a training configuration and a function that runs the training with that configuration. We then train a sample student model (with the same architecture as the teacher) just to make sure that everything works end to end. Since this is only a test run, we use a small amount of data (1000 data points: 900 for training and 100 for validation, i.e. a 0.1 validation split) and evaluate only at the end of each epoch. The training is based purely on learning the teacher's logits, with no ground-truth loss (alpha_ce=0 and alpha_distil=1).
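With `alpha_ce=0` and `alpha_distil=1`, the student learns only from the teacher's output distribution. The loss computation inside `train_student_model` is not shown in this section, so the following is a minimal sketch of the standard temperature-scaled objective, L = alpha_ce * CE(student, labels) + alpha_distil * T^2 * KL(softmax(z_t / T) || softmax(z_s / T)). It assumes that teacher and student share a vocabulary (both use the CodeT5 tokenizer here) and that padded label positions are set to -100. The function name `distillation_loss` is hypothetical:

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=2.0, alpha_ce=0.0, alpha_distil=1.0,
                      ignore_index=-100):
    """Weighted sum of label cross-entropy and temperature-scaled KL to the teacher.

    student_logits, teacher_logits: (batch, seq_len, vocab)
    labels: (batch, seq_len), with padding positions set to ignore_index
    """
    vocab_size = student_logits.size(-1)
    s = student_logits.reshape(-1, vocab_size)
    t = teacher_logits.reshape(-1, vocab_size)
    flat_labels = labels.reshape(-1)

    # Only compare distributions at non-padding target positions
    mask = flat_labels != ignore_index
    s_masked, t_masked = s[mask], t[mask]

    # KL(teacher || student) on softened distributions; multiplying by T^2
    # keeps gradient magnitudes comparable across temperatures.
    # "batchmean" averages over the selected token positions.
    kl = F.kl_div(
        F.log_softmax(s_masked / temperature, dim=-1),
        F.softmax(t_masked / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2

    # Cross-entropy against the ground-truth tokens (ignored when alpha_ce=0)
    ce = F.cross_entropy(s, flat_labels, ignore_index=ignore_index)

    return alpha_ce * ce + alpha_distil * kl


# Usage: a student that already matches the teacher has (almost) zero loss
torch.manual_seed(0)
logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
labels[0, -1] = -100  # mark one position as padding
loss = distillation_loss(logits, logits.clone(), labels)
print(float(loss))  # close to 0.0: the student already matches the teacher
```

Setting `alpha_ce` above zero mixes in supervision from the reference assertions, which is useful when the teacher itself makes mistakes.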

In [48]:
def run_student_training(args):
    # Set random seed
    torch.manual_seed(args["seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args["seed"])

    # Load training dataset
    print(f"Loading training dataset from {args['data_path_training']}...")
    if args["max_samples_training"] is not None:
        train_data = load_dataset(args["data_path_training"], args["max_samples_training"])
        print(f"Using first {len(train_data)} examples")
    else:
        train_data = load_dataset(args["data_path_training"])
        print(f"Loaded {len(train_data)} examples")
        
    # Load validation dataset
    print(f"Loading validation dataset from {args['data_path_validation']}...")
    if args["max_samples_validation"] is not None:
        val_data = load_dataset(args["data_path_validation"], args["max_samples_validation"])
        print(f"Using first {len(val_data)} examples")
    else:
        val_data = load_dataset(args["data_path_validation"])
        print(f"Loaded {len(val_data)} examples")

    print(f"Training on {len(train_data)} examples, validating on {len(val_data)} examples")

    # Load tokenizer and model
    if args["model_name"]:
        print(f"Loading model: {args['model_name']}")
        model = T5ForConditionalGeneration.from_pretrained(args["model_name"])
    else:
        print("Using custom model")
        model = args["model"]
        
    tokenizer = RobertaTokenizer.from_pretrained(args["teacher_model_name"])

    # Create datasets
    train_dataset = StudentDataset(
        train_data,
        tokenizer,
        max_src_length=args["max_src_length"],
        max_tgt_length=args["max_tgt_length"]
    )
    val_dataset = StudentDataset(
        val_data,
        tokenizer,
        max_src_length=args["max_src_length"],
        max_tgt_length=args["max_tgt_length"]
    )

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args["batch_size"],
        shuffle=True,
        num_workers=args["num_workers"],
        pin_memory=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args["eval_batch_size"],
        shuffle=False,
        num_workers=args["num_workers"],
        pin_memory=True
    )

    # Check for CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Train model
    model, tokenizer, best_val_loss = train_student_model(
        model,
        tokenizer,
        train_dataloader,
        val_dataloader,
        args
    )

    print(f"Training completed! Best validation loss: {best_val_loss:.4f}")
    print(f"Trained model and checkpoints saved to {args['output_dir']}")
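`load_dataset` is defined earlier in the notebook and is not shown in this section. For reference, here is a minimal sketch of a capped JSONL reader that matches how `run_student_training` uses it (an optional maximum number of records, where `None` means reading everything). The name `load_jsonl` is hypothetical and this is not the notebook's own implementation:

```python
import json
import os
import tempfile
from itertools import islice


def load_jsonl(path, max_samples=None):
    """Read up to max_samples JSON records (one per line) from a JSONL file."""
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines so a trailing newline does not break json.loads
        lines = (line for line in f if line.strip())
        # islice(..., None) yields every line, so max_samples=None reads the whole file
        return [json.loads(line) for line in islice(lines, max_samples)]


# Usage: write a small temporary JSONL file and read it back
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        for i in range(5):
            f.write(json.dumps({"id": i}) + "\n")
    first_three = load_jsonl(path, max_samples=3)
    everything = load_jsonl(path)

print(len(first_three), len(everything))  # 3 5
```

Because `None` already means "no limit", the two `if`/`else` branches in `run_student_training` could collapse into a single call, and `islice` stops reading as soon as the cap is reached instead of parsing the whole file.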

In [49]:
config_codet5_small_test = {
    # --- Data Args ---
    "data_path_training": "Data/distillation_data_training.jsonl",
    "data_path_validation": "Data/distillation_data_validation.jsonl",
    "output_dir": "./output_models/student_model_output_codet5_small_test",
    "teacher_model_name": "Salesforce/codet5-small",
    "model_name": "Salesforce/codet5-small",
    "model": None, # Used for custom models
    "max_src_length": 1024,
    "max_tgt_length": 512,

    # --- Distillation Args ---
    "distillation_temp": 2.0,
    "alpha_ce": 0.0,            # Weight for student's own Cross-Entropy loss
    "alpha_distil": 1.0,        # Weight for distillation loss (e.g., KL divergence)

    # --- Training Args ---
    "epochs": 5,
    "batch_size": 8,
    "eval_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "warmup_steps": 0,
    "max_grad_norm": 1.0,
    "max_steps": -1,
    "fp16": True,
    "seed": 42,

    # --- Batch Metrics Args (optional; can be removed to simplify) ---
    "track_batch_metrics": False,
    "batch_metrics_every": 50,
    "batch_metrics_window": 50,
    "batch_eval_pool": 4,

    # --- Logging and Saving Args ---
    "logging_steps": 100,
    "eval_strategy": "epoch",
    "eval_steps": 0,
    "save_strategy": "epoch",
    "save_steps": 0,
    "save_total_limit": 2,
    "early_stopping_patience": 3,
    "num_workers": 2,           
    "max_samples_training": 900,
    "max_samples_validation": 100,
}
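As a quick sanity check of what this configuration implies: with 900 training examples and `batch_size=8`, the training DataLoader (which keeps the last partial batch, since `drop_last` defaults to `False`) yields ceil(900 / 8) = 113 batches per epoch, matching the `0/113` progress bar in the run below. How `train_student_model` turns this into a scheduler length is not shown here, so the sketch below assumes a common convention: divide by `gradient_accumulation_steps`, and let a positive `max_steps` override the total. The helper `num_optimizer_steps` is hypothetical:

```python
import math


def num_optimizer_steps(num_examples, batch_size, grad_accum_steps, epochs, max_steps=-1):
    """Estimate batches per epoch, optimizer steps per epoch, and total optimizer steps."""
    # A DataLoader with drop_last=False keeps the final partial batch
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # Assumed convention: one optimizer step per grad_accum_steps batches,
    # counting a trailing partial accumulation window as a step
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    total_steps = steps_per_epoch * epochs
    # Assumed convention: a positive max_steps overrides the epoch-based total
    if max_steps > 0:
        total_steps = max_steps
    return batches_per_epoch, steps_per_epoch, total_steps


print(num_optimizer_steps(900, 8, 1, 5))  # (113, 113, 565)
```

The total is the value one would pass as `num_training_steps` to `get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)`; with `warmup_steps=0`, the learning rate simply decays linearly from 5e-5 to 0 over those steps.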

In [56]:
run_student_training(config_codet5_small_test)

Loading training dataset from Data/distillation_data_training.jsonl...
Loading dataset: 100%|██████████████████████████████████████████████████████████████| 900/900 [00:08<00:00, 101.87it/s]
Using first 900 examples
Loading validation dataset from Data/distillation_data_validation.jsonl...
Loading dataset: 100%|██████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 103.07it/s]
Using first 100 examples
Training on 900 examples, validating on 100 examples
Loading model: Salesforce/codet5-small
Using device: cuda


  scaler = torch.cuda.amp.GradScaler() if args["fp16"] else None


Starting training for 5 epochs...


Epoch 1/5:   0%|                                                                               | 0/113 [00:18<?, ?it/s]


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "D:\Delft\Anaconda3\Lib\site-packages\torch\utils\data\_utils\worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "D:\Delft\Anaconda3\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Delft\Anaconda3\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "D:\Delft\y3\q4\LLMDistillation\Data\student_dataset.py", line 25, in __getitem__
    target_text = "\n".join(item['assertions'])
                            ~~~~^^^^^^^^^^^^^^
KeyError: 'assertions'
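The run fails before the first batch: `StudentDataset.__getitem__` (in `Data/student_dataset.py`) builds the target text from `item['assertions']`, but at least one of the loaded records does not contain that key. Because the lookup only happens lazily inside a DataLoader worker, the mismatch surfaces once training starts rather than when the data is loaded. A cheap way to catch such schema mismatches earlier is to inspect the loaded records before building the datasets. The helper below is a hypothetical sketch; its default required key is only `assertions` because that is the one key the traceback confirms `StudentDataset` reads, and the example records use made-up field names:

```python
from collections import Counter


def check_required_keys(records, required_keys=("assertions",)):
    """Return (missing, present): record indices lacking each required key,
    and a count of every key that actually occurs in the records."""
    missing_by_key = {
        key: [i for i, record in enumerate(records) if key not in record]
        for key in required_keys
    }
    # Counting the keys that do occur helps spot e.g. a renamed field
    present = Counter(key for record in records for key in record)
    missing = {key: idx for key, idx in missing_by_key.items() if idx}
    return missing, present


# Hypothetical records: the first one stores its target under a different name
records = [
    {"source": "...", "target": "assert x == 1"},
    {"source": "...", "assertions": ["assert y"]},
]
missing, present = check_required_keys(records)
print(missing)  # {'assertions': [0]}
```

Calling this on `train_data` and `val_data` inside `run_student_training`, before the `StudentDataset` objects are created, and raising a `ValueError` when `missing` is non-empty would turn the worker `KeyError` into an immediate, explicit error that also lists which keys the file does contain. While debugging, setting `num_workers` to 0 helps too, since the dataset is then indexed in the main process and the traceback is not wrapped in a worker-process error.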
