In [1]:
from datasets import load_dataset
import random
from transformers import pipeline
import torch
from sentence_transformers import SentenceTransformer, util
import pandas as pd
from tqdm import tqdm
import evaluate
import gc

# Loading Dataset

The dataset can be found at https://huggingface.co/datasets/code-search-net/code_search_net

In [2]:
dataset = load_dataset('code-search-net/code_search_net')

Here are the columns in the dataset. We will be using `func_documentation_string` for documentation and `func_code_string` for code.

In [3]:
dataset['train'][0].keys()

dict_keys(['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'])

Getting a random sample from the test split to test our baseline model. We will increase the size once we finalize the experimental plan

In [4]:
dataset_sample = random.choices(dataset['test'], k=50)

# Testing the Alternative model

Here we define the system prompt for the Llama 4 model.

In [5]:
prompt = \
'''You are a helpful agent designed to simplify code documentation for beginner programmers.
You will be provided with a block of code and the existing doucmentation that accompanies it.
Simplify the given documentation, using the provided code as context, so that it is understandable
to beginner programmers. Output absolutely nothing else besides the simplified documentation.
Make sure to keep any documentation formatting codes present in the simplified documentation.
If you feel that the existing documentation is simple enough and meaning would be lost by simplifying
it further, feel free to keep the documentation as is. Here is the original documentation and code:'''

Creating the pipeline for the Llama 2 model using the HuggingFace transformers library. Modified from the example here: https://huggingface.co/docs/transformers/en/model_doc/llama2

In [6]:
pipe = pipeline(
    task="text-generation",
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

Fetching 50 files:   0%|          | 0/50 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/50 [00:00<?, ?it/s]

Device set to use cuda:0


Testing the pipeline. Modified from examples given here: https://huggingface.co/docs/transformers/main/en/chat_templating

In [7]:
message = [
    {"role": "system", "content": prompt},
    {"role": "user", "content": f"Documentation:\n{dataset['test'][0]['func_documentation_string']}\n\nCode:\n{dataset['test'][0]['func_code_string']}"}
]
print(f"Original Documentation:\n{dataset['test'][0]['func_documentation_string']}\n")
print(f"Code:\n{dataset['test'][0]['func_code_string']}\n")
output = pipe(message, pad_token_id=pipe.tokenizer.eos_token_id, max_new_tokens=2000)
print("Simplified Documentation:\n" + output[0]['generated_text'][-1]['content'])

Original Documentation:
Extracts video ID from URL.

Code:
def get_vid_from_url(url):
        """Extracts video ID from URL.
        """
        return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)') or \
          match1(url, r'youtube\.com/watch/([^/?]+)') or \
          parse_query_param(url, 'v') or \
          parse_query_param(parse_query_param(url, 'u'), 'v')

Simplified Documentation:
``` 
def get_vid_from_url(url):
    """
    Extracts the video ID from a YouTube URL.

    This function takes a YouTube URL as input and returns the video ID.
    It supports various YouTube URL formats, including:
    - youtu.be
    - youtube.com/embed
    - youtube.com/v
    - youtube.com/watch
    - URLs with a query parameter 'v'
    """
    return match1(url, r'youtu\.be/([^?/]+)') or \
          match1(url, r'youtube\.com/embed/([^/?]+)') or \
          match1(url, r'youtube\.com/v/([^/?]+)'

Loading the evaluation model used for computing semantic similarity. Taken from example here: https://huggingface.co/tasks/sentence-similarity

In [8]:
eval_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

Running inference on the dataset sample

In [9]:
semantic_similarities = []
metrics = evaluate.combine(['rouge', 'meteor'])

for instance in tqdm(dataset_sample):
    message = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": f"Documentation:\n{instance['func_documentation_string']}\n\nCode:\n{instance['func_code_string']}"}
    ]
    
    gc.collect()
    torch.cuda.empty_cache()

    result = pipe(message, pad_token_id=pipe.tokenizer.eos_token_id, max_new_tokens=2000)[0]['generated_text'][-1]['content']

    embedding_original = eval_model.encode(instance['func_documentation_string'], convert_to_tensor=True)
    embedding_predicted = eval_model.encode(result, convert_to_tensor=True)

    semantic_similarities.append(util.pytorch_cos_sim(embedding_original, embedding_predicted).item())
    metrics.add(predictions=result, references=instance['func_documentation_string'])

[nltk_data] Downloading package wordnet to
[nltk_data]     /home/j/jwoods03/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/j/jwoods03/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/j/jwoods03/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
 18%|█▊        | 9/50 [08:31<26:08, 38.25s/it]   You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
100%|██████████| 50/50 [44:46<00:00, 53.73s/it] 


Summary statistics for semantic similarity results (alternative model)

In [10]:
pd.DataFrame(semantic_similarities).describe()

Unnamed: 0,0
count,50.0
mean,0.700589
std,0.208311
min,-0.032901
25%,0.665341
50%,0.742608
75%,0.848514
max,0.933072


ROUGE and METEOR results

In [11]:
metrics.compute()

{'rouge1': 0.40630260978545596,
 'rouge2': 0.2399953505621984,
 'rougeL': 0.3604821523347549,
 'rougeLsum': 0.39214090448061795,
 'meteor': 0.43048831851121055}