## This notebook aims to fine-tune CodeT5 for generating descriptive comments for code snippets, aiding developers in maintaining readable codebases.

### PHASE 1: Setup and Environment Preparation

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install evaluate sentencepiece rouge_score
!pip install --upgrade transformers
!pip install datasets==3.0.0

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=0f08322d3a4b3c1dd0cd9ea9cfd2a37d0c99f401bdd1314a9d16a2a7e9add44b
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score, evaluate
Successfully installed evaluate-0.4.5 rouge_score-0.1.2
Collecting transformers
  Downloading transformers-4.54.0-py3-none-any.whl.metadata (41 kB)


In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_dataset, DatasetDict

### PHASE 2: Dataset Acquisition & Preprocessing
Objective: Use the CodeSearchNet Python subset and format it for CodeT5.

In [4]:
# Load the python subset of CodeSeachNet

from datasets import load_dataset
ds = load_dataset("code_search_net", "python")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


code_search_net.py: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

The repository for code_search_net contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/code_search_net.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


python.zip:   0%|          | 0.00/941M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/412178 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/22176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23107 [00:00<?, ? examples/s]

#### Filtering Invalid Samples, applying preprocessing & Tokenization

In [5]:
ds

DatasetDict({
    train: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 412178
    })
    test: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 22176
    })
    validation: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 23107
    })
})

In [6]:
ds["train"].column_names

['repository_name',
 'func_path_in_repository',
 'func_name',
 'whole_func_string',
 'language',
 'func_code_string',
 'func_code_tokens',
 'func_documentation_string',
 'func_documentation_tokens',
 'split_name',
 'func_code_url']

In [7]:
ds["train"][3]["func_code_string"]

'def ensure_dir(d):\n    """\n    Check to make sure the supplied directory path does not exist, if so, create it. The\n    method catches OSError exceptions and returns a descriptive message instead of\n    re-raising the error.\n\n    :type d: str\n    :param d: It is the full path to a directory.\n\n    :return: Does not return anything, but creates a directory path if it doesn\'t exist\n             already.\n    """\n    if not os.path.exists(d):\n        try:\n            os.makedirs(d)\n        except OSError as oe:\n            # should not happen with os.makedirs\n            # ENOENT: No such file or directory\n            if os.errno == errno.ENOENT:\n                msg = twdd("""One or more directories in the path ({}) do not exist. If\n                           you are specifying a new directory for output, please ensure\n                           all other directories in the path currently exist.""")\n                return msg.format(d)\n            else:\n           

In [8]:
ds["train"][3]["func_documentation_string"]

"Check to make sure the supplied directory path does not exist, if so, create it. The\n    method catches OSError exceptions and returns a descriptive message instead of\n    re-raising the error.\n\n    :type d: str\n    :param d: It is the full path to a directory.\n\n    :return: Does not return anything, but creates a directory path if it doesn't exist\n             already."

In [9]:
"""
Filtering Function (has_valid_fields):

Ensures code and docstrings are non-empty, sufficiently long (>50 chars for code, >10 for docs), and non-placeholder (no "TODO").

Excludes overly long docstrings (>5 newlines) to maintain conciseness.

Apply Filter:

Filters the training split to create a cleaner dataset for fine-tuning, removing low-quality examples.

Purpose:
Prepares a high-quality dataset for training models (e.g., CodeT5) to generate accurate and relevant code explanations.
"""
def has_valid_fields(example):
    # More stringent filtering
    return (
        example["func_code_string"] and
        example["func_documentation_string"] and
        len(example["func_code_string"].strip()) > 50 and  # Minimum code length
        len(example["func_documentation_string"].strip()) > 10 and  # Minimum doc length
        not example["func_documentation_string"].startswith("TODO") and
        example["func_documentation_string"].count('\n') < 5  # Avoid overly long docs
    )

train_dataset = ds["train"].filter(has_valid_fields)


Filter:   0%|          | 0/412178 [00:00<?, ? examples/s]

In [10]:
"""
Docstring Cleaning:

Removes whitespace and triple-quotes to ensure clean, consistent docstrings.

Structured Input-Target Format:

Input: Code prefixed with Generate docstring: (task prompt).

Target: Cleaned docstring (expected model output).

Purpose: Prepares data for seq2seq training (code → docstring generation).

Dataset Processing:

Applies process to all examples and removes unused columns to optimize memory and compatibility.

"""

def process(example):
    # Clean the docstring
    doc = example["func_documentation_string"].strip()
    # Remove common prefixes that might confuse the model
    doc = doc.replace('"""', '').replace("'''", '').strip()

    return {
        "input_text": "Generate docstring: " + example["func_code_string"],
        "target_text": doc
    }

processed_dataset = train_dataset.map(process).remove_columns(train_dataset.column_names)

Map:   0%|          | 0/253254 [00:00<?, ? examples/s]

In [11]:
# Manual augmentation approach

"""
- Manual augmentation creates multiple prompt variations for each code snippet (e.g., `"Explain this function: "`, `"Document: "`).
- Goal: Improve model generalization by training it to respond to different phrasings of the same task.
"""
def create_augmented_dataset(dataset):
    """Create an augmented dataset by expanding each example"""
    all_examples = []

    for example in dataset:
        code = example["input_text"].replace("Generate docstring: ", "")
        target = example["target_text"]

        # Create variations
        variations = [
            {"input_text": "Generate docstring: " + code, "target_text": target},
            {"input_text": "Explain this function: " + code, "target_text": target},
            {"input_text": "What does this code do: " + code, "target_text": target},
            {"input_text": "Document: " + code, "target_text": target},
        ]

        all_examples.extend(variations)

    # Convert back to dataset
    from datasets import Dataset
    return Dataset.from_list(all_examples)

# Apply augmentation
print("Creating augmented dataset...")
augmented_dataset = create_augmented_dataset(processed_dataset)
print(f"Original dataset size: {len(processed_dataset)}")
print(f"Augmented dataset size: {len(augmented_dataset)}")

Creating augmented dataset...
Original dataset size: 253254
Augmented dataset size: 1013016


In [12]:
for i, example in enumerate(processed_dataset.shuffle(seed=42).select(range(3))):
    print(f"\n--- Example {i+1} ---")
    print("Input:\n", example["input_text"][:500])
    print("\nTarget:\n", example["target_text"])



--- Example 1 ---
Input:
 Generate docstring: def _compute_weight_std(self, C, mag):
        """
        Common part of equations 8 and 9, page 971.
        """
        if mag < 6.0:
            return C['a1']
        elif mag >= 6.0 and mag < 6.5:
            return C['a1'] + (C['a2'] - C['a1']) * ((mag - 6.0) / 0.5)
        else:
            return C['a2']

Target:
 Common part of equations 8 and 9, page 971.

--- Example 2 ---
Input:
 Generate docstring: def copy(self, *, shallow=False):
        """Return a copy of a table."""
        table = type(self)()
        for label in self.labels:
            if shallow:
                column = self[label]
            else:
                column = np.copy(self[label])
            self._add_column_and_format(table, label, column)
        return table

Target:
 Return a copy of a table.

--- Example 3 ---
Input:
 Generate docstring: def run_display_description(self):
        """Print profile name with programMain."""
        # disp

In [13]:
"""
Tokenization
"""

from transformers import AutoTokenizer

model_checkpoint = "Salesforce/codet5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function(example):
    # Longer sequences for complex functions
    model_input = tokenizer(
        example["input_text"],
        truncation=True,
        padding="max_length",
        max_length=768  # Increased from 512
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["target_text"],
            truncation=True,
            padding="max_length",
            max_length=256  # Increased from 128
        )

    model_input["labels"] = labels["input_ids"]
    return model_input

# Apply the improved tokenization
tokenized_dataset = processed_dataset.map(preprocess_function, batched=True)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/253254 [00:00<?, ? examples/s]



In [14]:
import os

# Define the path where you want to save it in your Drive
save_path = "/content/drive/My Drive/Code Comment Generator/my_tokenized_dataset_enhanced_2"

# Create the directory if it doesn't exist
os.makedirs(save_path, exist_ok=True)

# Save the dataset
tokenized_dataset.save_to_disk(save_path)
print(f"Dataset saved to: {save_path}")

Saving the dataset (0/4 shards):   0%|          | 0/253254 [00:00<?, ? examples/s]

Dataset saved to: /content/drive/My Drive/Code Comment Generator/my_tokenized_dataset_enhanced_2


### PHASE 3: Fine-tuning the Model

In [None]:
# # load the tokenize dataset from drive
# from datasets import load_from_disk

# load_path = "/content/drive/My Drive/Code Comment Generator/my_tokenized_dataset_enhanced"
# tokenized_dataset = load_from_disk(load_path)
# print(f"Dataset loaded from: {load_path}")

In [15]:
import evaluate
import numpy as np
import torch

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    print(f"Predictions shape: {predictions.shape}, dtype: {predictions.dtype}")
    print(f"Labels shape: {labels.shape}, dtype: {labels.dtype}")

    try:
        # Convert to numpy arrays and ensure correct data type
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()

        # Ensure the arrays are of integer type and within valid range
        predictions = predictions.astype(np.int32)
        labels = labels.astype(np.int32)

        # For sequence-to-sequence models, predictions might be logits
        # If predictions have 3 dimensions, take the argmax to get token IDs
        if len(predictions.shape) == 3:
            predictions = np.argmax(predictions, axis=-1)

        # Clip values to valid token ID range to prevent out-of-range errors
        vocab_size = tokenizer.vocab_size
        predictions = np.clip(predictions, 0, vocab_size - 1)
        labels = np.clip(labels, 0, vocab_size - 1)

        # Replace -100 in labels with pad_token_id
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        print(f"Predictions range: {predictions.min()} to {predictions.max()}")
        print(f"Labels range: {labels.min()} to {labels.max()}")
        print(f"Tokenizer vocab size: {vocab_size}")

        # Decode predictions and labels safely
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Clean up any empty strings
        decoded_preds = [pred.strip() if pred.strip() else "empty" for pred in decoded_preds]
        decoded_labels = [label.strip() if label.strip() else "empty" for label in decoded_labels]

        # Debug: Print a few examples
        print(f"Sample predictions: {decoded_preds[:3]}")
        print(f"Sample labels: {decoded_labels[:3]}")

        # Compute metrics with error handling for each metric
        try:
            rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
            rouge_scores = {
                "rouge1": round(rouge_result["rouge1"], 2),
                "rouge2": round(rouge_result["rouge2"], 2),
                "rougeL": round(rouge_result["rougeL"], 2),
            }
        except Exception as rouge_error:
            print(f"ROUGE computation failed: {rouge_error}")
            rouge_scores = {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}

        try:
            bleu_result = bleu.compute(predictions=decoded_preds, references=[[label] for label in decoded_labels])
            bleu_score = round(bleu_result["bleu"], 2)
        except Exception as bleu_error:
            print(f"BLEU computation failed: {bleu_error}")
            bleu_score = 0.0

        final_scores = {
            "bleu": bleu_score,
            **rouge_scores
        }

        print(f"Computed metrics: {final_scores}")
        return final_scores

    except Exception as e:
        print(f"Error in compute_metrics: {e}")
        print(f"Error type: {type(e)}")
        import traceback
        traceback.print_exc()

        # Return default values if computation fails
        return {
            "bleu": 0.0,
            "rouge1": 0.0,
            "rouge2": 0.0,
            "rougeL": 0.0,
        }

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

In [16]:
import os
os.environ["WANDB_DISABLED"] = "true"

# Load model and tokenizer
model_checkpoint = "Salesforce/codet5-small"
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

print("Loading model...")
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Create dataset splits with larger sizes
print("Creating dataset splits...")
train_size = 1000
eval_size = 200

# Make sure tokenized_dataset is available
shuffled_dataset = tokenized_dataset.shuffle(seed=42)
train_dataset = shuffled_dataset.select(range(train_size))
eval_dataset = shuffled_dataset.select(range(train_size, train_size + eval_size))

print(f"Training samples: {len(train_dataset)}")
print(f"Evaluation samples: {len(eval_dataset)}")

# Enhanced Training arguments
from transformers import Seq2SeqTrainingArguments

output_dir = "/content/drive/MyDrive/codet5-small-training-enhanced"
model_save_dir = "/content/drive/MyDrive/codet5-small-comment-generator-enhanced"

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    learning_rate=3e-4,  # Higher learning rate
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=8,
    weight_decay=0.01,
    logging_dir=f'{output_dir}/logs',
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,    # Less frequent saving
    load_best_model_at_end=True,
    metric_for_best_model="eval_bleu",  # Use ROUGE instead of loss
    greater_is_better=True,
    predict_with_generate=True,
    fp16=True,
    gradient_accumulation_steps=2,  # Simulate larger batch size
    report_to=None,
    dataloader_pin_memory=False,

    # Valid generation parameters for evaluation
    generation_max_length=256,  # Increased
    generation_num_beams=6,     # More beams
)

# Create trainer
from transformers import Seq2SeqTrainer

print("Creating trainer...")
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,  # Keep your existing compute_metrics function
)

# Start training
print("Starting training...")
try:
    training_history = trainer.train()

    # Save the model
    print("Saving model...")
    trainer.save_model(model_save_dir)
    tokenizer.save_pretrained(model_save_dir)

    print("✅ Training completed successfully!")
    print(f"Model saved to: {model_save_dir}")

except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

    # Save whatever progress was made
    try:
        trainer.save_model(f"{model_save_dir}_partial")
        tokenizer.save_pretrained(f"{model_save_dir}_partial")
        print(f"Partial model saved to: {model_save_dir}_partial")
    except Exception as save_error:
        print(f"Failed to save partial model: {save_error}")

print("Training process completed!")

Loading tokenizer...
Loading model...


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Creating dataset splits...
Training samples: 1000
Evaluation samples: 200
Creating trainer...


model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

Starting training...


Step,Training Loss,Validation Loss,Bleu,Rouge1,Rouge2,Rougel
200,0.0113,0.003124,0.88,0.98,0.98,0.98
400,0.0045,0.002056,0.98,0.99,0.98,0.99
600,0.004,0.001433,0.96,0.99,0.98,0.99
800,0.001,0.001017,0.99,0.99,0.99,0.99
1000,0.0006,0.000938,0.99,0.99,0.99,0.99


Predictions shape: (200, 256), dtype: int64
Labels shape: (200, 256), dtype: int64
Predictions range: 0 to 31655
Labels range: 0 to 31655
Tokenizer vocab size: 32100
Sample predictions: ['Increment volume by 0.1 (or delta) unless it is already maxed.\n        Returns the new volume.', 'backwards compatibility function\n    :return:', 'select distinct values for a given field for a given a query']
Sample labels: ['Increment volume by 0.1 (or delta) unless it is already maxed.\n        Returns the new volume.', 'backwards compatibility function\n    :return:', 'select distinct values for a given field for a given a query']
Computed metrics: {'bleu': 0.88, 'rouge1': np.float64(0.98), 'rouge2': np.float64(0.98), 'rougeL': np.float64(0.98)}
Predictions shape: (200, 105), dtype: int64
Labels shape: (200, 256), dtype: int64
Predictions range: 0 to 31655
Labels range: 0 to 31655
Tokenizer vocab size: 32100
Sample predictions: ['Increment volume by 0.1 (or delta) unless it is already maxed.\n  

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Saving model...
✅ Training completed successfully!
Model saved to: /content/drive/MyDrive/codet5-small-comment-generator-enhanced
Training process completed!


### Testing the fine-tuned model on

In [19]:
# Complete test suite with all 20 test cases
all_test_inputs = [
    {
        "input_text": """Generate docstring: def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n - 1)"""
    },
    {
        "input_text": """Generate docstring: def is_prime(num):
    if num < 2:
        return False
    for i in range(2, int(num ** 0.5) + 1):
        if num % i == 0:
            return False
    return True"""
    },
    {
        "input_text": """Generate docstring: def reverse_string(s):
    return s[::-1]"""
    },
    {
        "input_text": """Generate docstring: def count_words(text):
    words = text.split()
    return len(words)"""
    },
    {
        "input_text": """Generate docstring: def find_max(lst):
    if not lst:
        return None
    max_val = lst[0]
    for val in lst[1:]:
        if val > max_val:
            max_val = val
    return max_val"""
    },
    {
        "input_text": """Generate docstring: def binary_search(arr, target):
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1"""
    },
    {
        "input_text": """Generate docstring: def merge_sort(arr):
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)"""
    },
    {
        "input_text": """Generate docstring: def fibonacci(n):
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b"""
    },
    {
        "input_text": """Generate docstring: def is_palindrome(s):
    s = ''.join(char.lower() for char in s if char.isalnum())
    return s == s[::-1]"""
    },
    {
        "input_text": """Generate docstring: def flatten_list(nested_list):
    result = []
    for item in nested_list:
        if isinstance(item, list):
            result.extend(flatten_list(item))
        else:
            result.append(item)
    return result"""
    },
    {
        "input_text": """Generate docstring: def calculate_gcd(a, b):
    while b:
        a, b = b, a % b
    return a"""
    },
    {
        "input_text": """Generate docstring: def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)"""
    },
    {
        "input_text": """Generate docstring: def validate_email(email):
    import re
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    return re.match(pattern, email) is not None"""
    },
    {
        "input_text": """Generate docstring: def remove_duplicates(lst):
    seen = set()
    result = []
    for item in lst:
        if item not in seen:
            seen.add(item)
            result.append(item)
    return result"""
    },
    {
        "input_text": """Generate docstring: def matrix_multiply(A, B):
    rows_A, cols_A = len(A), len(A[0])
    rows_B, cols_B = len(B), len(B[0])
    if cols_A != rows_B:
        raise ValueError("Cannot multiply matrices")
    result = [[0 for _ in range(cols_B)] for _ in range(rows_A)]
    for i in range(rows_A):
        for j in range(cols_B):
            for k in range(cols_A):
                result[i][j] += A[i][k] * B[k][j]
    return result"""
    },
    {
        "input_text": """Generate docstring: def decode_base64(encoded_string):
    import base64
    try:
        decoded_bytes = base64.b64decode(encoded_string)
        return decoded_bytes.decode('utf-8')
    except Exception:
        return None"""
    },
    {
        "input_text": """Generate docstring: def find_common_elements(list1, list2):
    return list(set(list1) & set(list2))"""
    },
    {
        "input_text": """Generate docstring: def calculate_compound_interest(principal, rate, time, n):
    amount = principal * (1 + rate / n) ** (n * time)
    return round(amount - principal, 2)"""
    },
    {
        "input_text": """Generate docstring: def parse_json_safe(json_string):
    import json
    try:
        return json.loads(json_string)
    except json.JSONDecodeError:
        return {}"""
    },
    {
        "input_text": """Generate docstring: def levenshtein_distance(s1, s2):
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)
    if len(s2) == 0:
        return len(s1)
    previous_row = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    return previous_row[-1]"""
    }
]

# Enhanced generation with better parameters
from transformers import pipeline

# Load the model and tokenizer
generator = pipeline("text2text-generation",
                     model=model,
                     tokenizer=tokenizer)

# Generate summaries with enhanced parameters for all 20 test cases
print("Generating enhanced docstrings for all test cases...")
print(f"Total test cases: {len(all_test_inputs)}")
print("=" * 100)

for i, item in enumerate(all_test_inputs, 1):
    print(f"\n=== Test Case {i} ===")

    # Method 1: Using pipeline with better parameters
    output = generator(
        item["input_text"],
        max_new_tokens=150,     # Use max_new_tokens instead of max_length
        num_beams=6,            # More beams for better search
        early_stopping=True,
        do_sample=True,         # Add some controlled randomness
        temperature=1,          # Control randomness
        top_p=0.5,              # Nucleus sampling
        repetition_penalty=1.2, # Avoid repetition
        length_penalty=1.0      # Encourage longer outputs
    )

    print("Code:", item["input_text"])
    print("Enhanced Docstring:", output[0]['generated_text'])
    print("-" * 80)

print("\n" + "=" * 100)
print("✅ All 20 test cases completed!")
print("=" * 100)

Device set to use cuda:0


Generating enhanced docstrings for all test cases...
Total test cases: 20

=== Test Case 1 ===
Code: Generate docstring: def factorial(n):
    if n == 0:
        return 1
    else:
        return n * factorial(n - 1)
Enhanced Docstring: 1
    else:
--------------------------------------------------------------------------------

=== Test Case 2 ===
Code: Generate docstring: def is_prime(num):
    if num < 2:
        return False
    for i in range(2, int(num ** 0.5) + 1):
        if num % i == 0:
            return False
    return True
Enhanced Docstring: num < 2:
--------------------------------------------------------------------------------

=== Test Case 3 ===
Code: Generate docstring: def reverse_string(s):
    return s[::-1]
Enhanced Docstring: return s[::-1]
--------------------------------------------------------------------------------

=== Test Case 4 ===
Code: Generate docstring: def count_words(text):
    words = text.split()
    return len(words)
Enhanced Docstring: Gener

I tried using few-shot prompting to get a better output without succeeding
