In [2]:
# -----------------------------------------------------------------------------
# Code-Aware RAG Implementation for Python Docstring Generation (Refactored)
# -----------------------------------------------------------------------------
# This script implements the "Code-Aware Query RAG" strategy with a focus on
# clear, modular functions. It works by:
# 1. Parsing code to extract key entities.
# 2. Constructing an enriched query.
# 3. Retrieving context from Pinecone.
# 4. Generating a docstring using a two-stage LLM process.
#
# Prerequisites:
# - Python 3.7+
# - Pinecone account (API Key and Environment)
# - Ollama installed and running locally
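# - Ollama models pulled (must match OLLAMA_GENERATOR_MODEL / OLLAMA_HELPER_MODEL below):
#   - ollama pull deepseek-coder:6.7b
#   - ollama pull deepseek-r1:1.5b
# - Python libraries installed:
#   - pip install pinecone-client sentence-transformers requests beautifulsoup4 ollama pandas rouge nltk bert-score openpyxl hf_xet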
# -----------------------------------------------------------------------------

In [4]:
import requests
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone, ServerlessSpec, PodSpec
import ollama
import os
import uuid
import time
import pickle
import pandas as pd
import re
import pandas as pd
from rouge import Rouge
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu
import nltk
import warnings
warnings.filterwarnings('ignore')
import openpyxl
from bert_score import score
import itertools
import hf_xet
import zlib
import subprocess
import tempfile
import traceback

In [6]:
# --- Credentials ---
# The Pinecone key is read from the environment so the secret never lives in
# the notebook or its version history. Set it before starting the kernel, e.g.
#   export PINECONE_API_KEY="..."
# An unset variable yields "" and is reported by the index-setup cell.
# NOTE(review): the key that used to be hard-coded on this line has been
# exposed and should be revoked/rotated.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", "")
PINECONE_ENVIRONMENT = "us-east-1"

# --- Constants ---
INDEX_NAME = "code-aware-rag-docstring"
EMBEDDING_MODEL = 'all-MiniLM-L6-v2' # HuggingFace sentence transformer
OLLAMA_GENERATOR_MODEL = 'deepseek-coder:6.7b' # Local Ollama model that writes the docstring (`ollama pull deepseek-coder:6.7b`)
OLLAMA_HELPER_MODEL = 'deepseek-r1:1.5b' # Local Ollama model that rewrites the retrieved context (`ollama pull deepseek-r1:1.5b`)

In [8]:
# --- Pinecone index settings ---
# The dimension must equal the embedding size of EMBEDDING_MODEL
# (384 for all-MiniLM-L6-v2).
VECTOR_DIMENSION = 384
METRIC = "cosine"
CLOUD = "aws"
REGION = "us-east-1"

# --- Knowledge-base sources ---
# Pages about docstring conventions and Python style; each one is scraped and
# indexed by the data-loading cell, in the order listed here.
TARGET_URL = [
    "https://peps.python.org/pep-0257/",
    "https://www.kaggle.com/code/hagzilla/what-are-docstrings",
    "https://github.com/keleshev/pep257/blob/master/pep257.py",
    "https://github.com/chadrik/doc484",
    "https://zerotomastery.io/blog/python-docstring/",
    "https://google.github.io/styleguide/pyguide.html",
    "https://www.geeksforgeeks.org/python-docstrings/",
    "https://pandas.pydata.org/docs/development/contributing_docstring.html",
    "https://www.coding-guidelines.lftechnology.com/docs/python/docstrings/",
    "https://realpython.com/python-pep8/",
    "https://pypi.org/project/AIDocStringGenerator/",
    "https://www.geeksforgeeks.org/pep-8-coding-style-guide-python/",
    "https://llego.dev/posts/write-python-docstrings-guide-documenting-functions/",
    "https://www.datacamp.com/tutorial/pep8-tutorial-python-code",
    "https://www.programiz.com/python-programming/docstrings",
    "https://marketplace.visualstudio.com/items?itemName=ShanthoshS.docstring-generator-ext",
    "https://stackoverflow.com/questions/3898572/what-are-the-most-common-python-docstring-formats",
    "https://stackoverflow.com/questions/78753860/what-is-the-proper-way-of-including-examples-in-python-docstrings",
    "https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html",
    "https://www.dataquest.io/blog/documenting-in-python-with-docstrings/",
    "https://www.tutorialspoint.com/python/python_docstrings.htm",
]

# --- Initialise the embedding model, the Pinecone client and the Ollama client ---
# Creates the module-level objects `model`, `pc` and `ollama_client` that the
# later cells use.
print("Initializing services...")
try:
    # Sentence-transformer used to embed both the scraped pages and the queries.
    model = SentenceTransformer(EMBEDDING_MODEL)
    print("Embedding model loaded.")

    # Pinecone
    pc = Pinecone(api_key=PINECONE_API_KEY)
    print("Pinecone initialized.") # Environment info is handled internally

    # Ollama Client
    ollama_client = ollama.Client()
    print(f"Ollama client initialized. Attempting to use model: {OLLAMA_GENERATOR_MODEL}")
    print(f"Ensure '{OLLAMA_GENERATOR_MODEL}' is available locally in Ollama (`ollama pull {OLLAMA_GENERATOR_MODEL}`).")


except Exception as e:
    print(f"Error initializing services: {e}")
    # Raise SystemExit instead of calling the interactive helper `exit()`:
    # run as a script, `exit()` ended with status 0 although start-up failed,
    # and in IPython/Jupyter it asks the kernel to shut down. SystemExit(1)
    # reports the failure and keeps the original exception chained.
    raise SystemExit(1) from e

Initializing services...
Embedding model loaded.
Pinecone initialized.
Ollama client initialized. Attempting to use model: deepseek-coder:6.7b
Ensure 'deepseek-coder:6.7b' is available locally in Ollama (`ollama pull deepseek-coder:6.7b`).


In [10]:
# --- 1. Initialize Pinecone ---
# Connects to INDEX_NAME, creating it as a serverless index first when it does
# not exist yet. On success `pinecone_index` is the handle used by the
# data-loading and retrieval cells.
pinecone_index = None
if not PINECONE_API_KEY:
    print("ERROR: Pinecone API key not found in environment variables.")
    # SystemExit keeps the exit status of the former `exit(1)` for script runs
    # without going through the interactive `exit()` helper.
    raise SystemExit(1)
try:
    pc = Pinecone(api_key=PINECONE_API_KEY)
    existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]
    print(f"Available Pinecone indexes: {existing_indexes}")

    if INDEX_NAME not in existing_indexes:
        print(f"Index '{INDEX_NAME}' not found. Creating new index...")
        pc.create_index(
            name=INDEX_NAME, dimension=VECTOR_DIMENSION, metric=METRIC,
            spec=ServerlessSpec(cloud=CLOUD, region=REGION)
        )
        # Index creation is asynchronous: poll until Pinecone reports it ready.
        # NOTE(review): there is no upper bound on this wait.
        while not pc.describe_index(INDEX_NAME).status["ready"]:
            print(f"Waiting for index '{INDEX_NAME}' to become ready...")
            time.sleep(5)
        print(f"Index '{INDEX_NAME}' created and ready.")
    else:
        print(f"Connecting to existing index '{INDEX_NAME}'.")
        # Optional: Clear index if you want to re-index fresh
        # print(f"WARNING: Deleting all vectors from existing index '{INDEX_NAME}'...")
        # index_to_clear = pc.Index(INDEX_NAME)
        # index_to_clear.delete(delete_all=True)
        # print(f"All vectors deleted from '{INDEX_NAME}'.")

    pinecone_index = pc.Index(INDEX_NAME)
    print(f"Successfully connected to index '{INDEX_NAME}'. Stats: {pinecone_index.describe_index_stats()}")
except Exception as e:
    print(f"ERROR: Failed to initialize or connect to Pinecone index '{INDEX_NAME}': {e}")
    # Same exit status as before, with the original error chained for debugging.
    raise SystemExit(1) from e

Available Pinecone indexes: ['fusion-rag-docstring', 'self-rag-docstring', 'rag-docstring', 'corrective-rag-docstring', 'code-aware-rag-docstring']
Connecting to existing index 'code-aware-rag-docstring'.
Successfully connected to index 'code-aware-rag-docstring'. Stats: {'dimension': 384,
 'index_fullness': 0.0,
 'metric': 'cosine',
 'namespaces': {'': {'vector_count': 14}},
 'total_vector_count': 14,
 'vector_type': 'dense'}


In [11]:
# --- Load Data into Pinecone (Only if index is empty) ---
# Each URL in TARGET_URL is fetched, reduced to plain text, embedded as ONE
# vector and upserted with the text itself stored as metadata. The whole step
# is skipped when the index already reports at least one vector.
index_stats = pinecone_index.describe_index_stats()
if index_stats.total_vector_count == 0:
    total_docs_loaded = 0
    # Loop through each URL in the list
    for url in TARGET_URL:
        print(f"\nProcessing URL: {url}")
        try:
            # Fetch URL content
            response = requests.get(url, timeout=30) # Use timeout
            response.raise_for_status() # Check for HTTP errors

            # Parse HTML
            # Prefer the page's <main>, then <article>, then <body>, so that
            # text outside the main content is left out where possible.
            soup = BeautifulSoup(response.content, 'html.parser')
            main_content = soup.find('main') or soup.find('article') or soup.find('body')
            page_text = ""
            if main_content:
                page_text = main_content.get_text(separator='\n', strip=True)
            else:
                page_text = soup.get_text(separator='\n', strip=True) # Fallback

            if not page_text or len(page_text) < 50: # Basic check for meaningful content
                print(f" -> Warning: Could not extract sufficient text content from {url}. Skipping.")
                continue # Skip to the next URL

            print(f" -> Extracted text length: {len(page_text)} characters.")

            # Generate embedding
            # Note: Encoding large pages as a single vector might lose detail.
            # Chunking the text into smaller parts is better for real applications.
            embedding = model.encode(page_text).tolist()

            # Prepare and upsert data
            # A random UUID is used as the vector id, so loading the same URL
            # twice would create a second vector rather than overwrite the first.
            doc_id = str(uuid.uuid4())
            # NOTE(review): the complete page text is stored as metadata.
            # Pinecone presumably limits metadata size per vector, which would
            # make the upsert fail for long pages and may explain why the
            # index holds 14 vectors for 21 URLs — confirm.
            metadata = {"text": page_text, "source": url} # Store the specific URL as source

            pinecone_index.upsert(vectors=[(doc_id, embedding, metadata)])
            print(f" -> Data from {url} loaded into Pinecone with ID: {doc_id}")
            total_docs_loaded += 1
            time.sleep(0.5) # Small delay to be polite to the server

        except requests.exceptions.RequestException as e:
            # Handle errors fetching specific URL, continue with the next
            print(f" -> Error fetching URL {url}: {e}")
            continue
        except Exception as e:
            # Handle other errors during processing/upserting for this URL
            print(f" -> Error processing or upserting data for {url}: {e}")
            continue

    if total_docs_loaded > 0:
        # Short pause before re-reading the stats, so the new vectors have a
        # chance to be counted.
        print("Waiting a moment for indexing...")
        time.sleep(2)
        print(pinecone_index.describe_index_stats()) # Show final stats
    else:
        print("Warning: No documents were loaded into the index.")

else:
    print(f"\nIndex already contains {index_stats.total_vector_count} vectors. Skipping data loading.")


Index already contains 14 vectors. Skipping data loading.


# --- Sample input: a hard-coded class definition (later overwritten by the evaluation loop) ---
user_code = ("""
class Conv(Layer):
    def __init__(self, rank, filters, kernel_size, strides=1, padding='valid', data_format=None, dilation_rate=1, groups=1, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None, conv_op=None, **kwargs):
        super(Conv, self).__init__(trainable=trainable, name=name, activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
        self.rank = rank
        if isinstance(filters, float):
            filters = int(filters)
        if filters is not None and filters < 0:
            raise ValueError(f'Received a negative value for `filters`.Was expecting a positive value, got {filters}.')
        self.filters = filters
        self.groups = groups or 1
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, 'dilation_rate')
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=self.rank + 2)
        self._validate_init()
        self._is_causal = self.padding == 'causal'
        self._channels_first = self.data_format == 'channels_first'
        self._tf_data_format = conv_utils.convert_data_format(self.data_format, self.rank + 2)

    def _validate_init(self):
        if self.filters is not None and self.filters % self.groups != 0:
            raise ValueError('The number of filters must be evenly divisible by the number of groups. Received: groups={}, filters={}'.format(self.groups, self.filters))
        if not all(self.kernel_size):
            raise ValueError('The argument `kernel_size` cannot contain 0(s). Received: %s' % (self.kernel_size,))
        if not all(self.strides):
            raise ValueError('The argument `strides` cannot contains 0(s). Received: %s' % (self.strides,))
        if self.padding == 'causal' and (not isinstance(self, (Conv1D, SeparableConv1D))):
            raise ValueError('Causal padding is only supported for `Conv1D`and `SeparableConv1D`.')

    def build(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        if input_channel % self.groups != 0:
            raise ValueError('The number of input channels must be evenly divisible by the number of groups. Received groups={}, but the input has {} channels (full input shape is {}).'.format(self.groups, input_channel, input_shape))
        kernel_shape = self.kernel_size + (input_channel // self.groups, self.filters)
        self.kernel = self.add_weight(name='kernel', shape=kernel_shape, initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True, dtype=self.dtype)
        if self.use_bias:
            self.bias = self.add_weight(name='bias', shape=(self.filters,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True, dtype=self.dtype)
        else:
            self.bias = None
        channel_axis = self._get_channel_axis()
        self.input_spec = InputSpec(min_ndim=self.rank + 2, axes={channel_axis: input_channel})
        if self.padding == 'causal':
            tf_padding = 'VALID'
        elif isinstance(self.padding, str):
            tf_padding = self.padding.upper()
        else:
            tf_padding = self.padding
        tf_dilations = list(self.dilation_rate)
        tf_strides = list(self.strides)
        tf_op_name = self.__class__.__name__
        if tf_op_name == 'Conv1D':
            tf_op_name = 'conv1d'
        self._convolution_op = functools.partial(nn_ops.convolution_v2, strides=tf_strides, padding=tf_padding, dilations=tf_dilations, data_format=self._tf_data_format, name=tf_op_name)
        self.built = True

    def call(self, inputs):
        input_shape = inputs.shape
        if self._is_causal:
            inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
        outputs = self._convolution_op(inputs, self.kernel)
        if self.use_bias:
            output_rank = outputs.shape.rank
            if self.rank == 1 and self._channels_first:
                bias = array_ops.reshape(self.bias, (1, self.filters, 1))
                outputs += bias
            elif output_rank is not None and output_rank > 2 + self.rank:

                def _apply_fn(o):
                    return nn.bias_add(o, self.bias, data_format=self._tf_data_format)
                outputs = conv_utils.squeeze_batch_dims(outputs, _apply_fn, inner_rank=self.rank + 1)
            else:
                outputs = nn.bias_add(outputs, self.bias, data_format=self._tf_data_format)
        if not context.executing_eagerly():
            out_shape = self.compute_output_shape(input_shape)
            outputs.set_shape(out_shape)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def _spatial_output_shape(self, spatial_input_shape):
        return [conv_utils.conv_output_length(length, self.kernel_size[i], padding=self.padding, stride=self.strides[i], dilation=self.dilation_rate[i]) for i, length in enumerate(spatial_input_shape)]

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        batch_rank = len(input_shape) - self.rank - 1
        if self.data_format == 'channels_last':
            return tensor_shape.TensorShape(input_shape[:batch_rank] + self._spatial_output_shape(input_shape[batch_rank:-1]) + [self.filters])
        else:
            return tensor_shape.TensorShape(input_shape[:batch_rank] + [self.filters] + self._spatial_output_shape(input_shape[batch_rank + 1:]))

    def _recreate_conv_op(self, inputs):
        return False

    def get_config(self):
        config = {'filters': self.filters, 'kernel_size': self.kernel_size, 'strides': self.strides, 'padding': self.padding, 'data_format': self.data_format, 'dilation_rate': self.dilation_rate, 'groups': self.groups, 'activation': activations.serialize(self.activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint)}
        base_config = super(Conv, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _compute_causal_padding(self, inputs):
        left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
        if getattr(inputs.shape, 'ndims', None) is None:
            batch_rank = 1
        else:
            batch_rank = len(inputs.shape) - 2
        if self.data_format == 'channels_last':
            causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
        else:
            causal_padding = [[0, 0]] * batch_rank + [[0, 0], [left_pad, 0]]
        return causal_padding

    def _get_channel_axis(self):
        if self.data_format == 'channels_first':
            return -1 - self.rank
        else:
            return -1

    def _get_input_channel(self, input_shape):
        channel_axis = self._get_channel_axis()
        if input_shape.dims[channel_axis].value is None:
            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
        return int(input_shape[channel_axis])

    def _get_padding_op(self):
        if self.padding == 'causal':
            op_padding = 'valid'
        else:
            op_padding = self.padding
        if not isinstance(op_padding, (list, tuple)):
            op_padding = op_padding.upper()
        return op_padding
""")
# Stop early when the sample is blank.
if not user_code.strip():
    print("No code provided. Exiting.")
    # SystemExit(0) keeps the exit status of the former `exit()` for script
    # runs without invoking the interactive helper, which in IPython/Jupyter
    # asks the kernel to shut down.
    raise SystemExit(0)

In [15]:
def parse_code_for_entities(user_code):
    """Extract the class name, base classes and method names from source text.

    Parsing is regex-based rather than AST-based, so it also works on snippets
    that do not compile, at the price of being approximate (nested functions
    are reported like methods, for instance).

    Args:
        user_code: Python source code as a string. Any other value (``None``,
            a pandas ``NaN``, ...) produces empty entities instead of an error.

    Returns:
        dict with three keys:
            ``class_name`` (str): name of the first ``class`` statement, or ``""``.
            ``parent_classes`` (list[str]): base-class expressions of that same
                class statement; empty when it has none.
            ``public_methods`` (list[str]): every ``def`` name in order of
                appearance, except dunder names other than ``__init__``.
                Single-underscore helpers are kept.
    """
    if not isinstance(user_code, str):
        return {"class_name": "", "parent_classes": [], "public_methods": []}

    # One anchored match yields both the name and the bases, so they always
    # belong to the same class statement, and the word "class" inside a
    # comment or string is not mistaken for a definition.
    class_header = re.search(
        r"^[ \t]*class\s+(\w+)\s*(?:\((.*?)\))?\s*:",
        user_code,
        re.MULTILINE | re.DOTALL,
    )
    class_name = class_header.group(1) if class_header else ""
    raw_parents = class_header.group(2) if class_header and class_header.group(2) else ""

    # Splitting "" or "A, " leaves empty pieces behind; drop them.
    parent_classes = [part.strip() for part in raw_parents.split(',') if part.strip()]

    method_names = re.findall(r"def\s+(\w+)\s*\(", user_code)
    public_methods = [
        name for name in method_names
        if not name.startswith('__') or name == '__init__'
    ]

    return {
        "class_name": class_name,
        "parent_classes": parent_classes,
        "public_methods": public_methods,
    }

In [17]:
def construct_enriched_query(entities):
    """Turn the extracted code entities into one natural-language query.

    Args:
        entities: dict with the keys ``class_name``, ``parent_classes`` and
            ``public_methods``.

    Returns:
        A single space-joined string. It always starts with the generic
        docstring phrase; a clause is appended for each non-empty entity, and
        at most the first three method names are mentioned.
    """
    class_name = entities["class_name"]
    bases = entities["parent_classes"]
    methods = entities["public_methods"]

    fragments = ["python docstring conventions and examples"]
    fragments += [f"for a class named {class_name}"] if class_name else []
    fragments += [f"that inherits from {' and '.join(bases)}"] if bases else []
    fragments += [f"with methods like {', '.join(methods[:3])}"] if methods else []
    return " ".join(fragments)

In [19]:
def retrieve_context(query, pinecone_index, emb_model):
    """Embed a query and return the text of the best-matching indexed document.

    Args:
        query: Query text to embed.
        pinecone_index: Pinecone index handle exposing ``query(...)``.
        emb_model: Embedding model exposing ``encode(text)`` whose result has
            ``tolist()``.

    Returns:
        str: the ``text`` metadata of the top match. An empty string is
        returned when the query or index is missing, when nothing matches, or
        when retrieval fails (the error is printed, not raised). Every path
        returns a string so callers can safely use the result in a truth test
        or a prompt.
    """
    if not pinecone_index or not query:
        # This guard used to return a 3-tuple, which is truthy and therefore
        # looked like real context to the caller.
        return ""

    try:
        embedding = emb_model.encode(query).tolist()
        results = pinecone_index.query(vector=embedding, top_k=1, include_metadata=True)
        if not results.matches:
            return "" # Successful query, no matches
        return results.matches[0].metadata.get('text', '')
    except Exception as e:
        # Best effort: a retrieval problem must not abort docstring generation.
        print(f"    -> Pinecone retrieval error: {e}")
        return ""

In [21]:
def revised_prompt(ctx, helper_model_name, client=None):
    """Ask the helper LLM to condense retrieved context into a focused prompt.

    Args:
        ctx: Context text retrieved from the knowledge base.
        helper_model_name: Name of the Ollama model that does the rewriting.
        client: Optional Ollama client exposing ``generate(model=..., prompt=...)``.
            Defaults to the module-level ``ollama_client``.

    Returns:
        str: the helper model's rewritten prompt, stripped of surrounding
        whitespace. When the helper call fails, the error is printed and the
        original ``ctx`` is returned unchanged, so the generator still
        receives the retrieved context.
    """
    rewriter_prompt = f"""
    You are an helpful assistant that refines prompts. Given the following context from the RAG knowledge base along with python code: {ctx}, generate an optimized prompt for another AI whose sole task is to create a Python docstring for the code and your output.
    The optimized prompt should clearly state the task, and subtly incorporate hints from the context if relevant, without necessarily repeating the entire context.
    Focus on creating a self-contained, clear instruction for the next AI.
    
    Generate only the optimized context prompt text for the docstring generation AI.
    """
    try:
        llm = client if client is not None else ollama_client
        rewriter_response = llm.generate(
            model=helper_model_name,
            prompt=rewriter_prompt,
            #options={'temperature': 0.3} # Lower temperature for more focused rewriting
        )
        return rewriter_response.get('response', '').strip()
    except Exception as e:
        # The fallback used to be the rewriter's own instructions, which would
        # then be passed on to the generator as if they were context.
        print(f"    -> Prompt rewriting failed, using the raw context instead: {e}")
        return ctx

In [23]:
def final_content_generation(context, user_code, rewritten_req, ollama_client, OLLAMA_MODEL):
    """Generate the docstring for ``user_code`` with the generator LLM.

    Args:
        context: Retrieved context. Only its truthiness is used, to decide
            whether ``rewritten_req`` is included in the conversation.
        user_code: Source code to document.
        rewritten_req: Rewritten context/prompt produced by the helper model.
        ollama_client: Client exposing ``chat(model=..., messages=...)``.
        OLLAMA_MODEL: Name of the generator model.

    Returns:
        str: the model reply with a surrounding Markdown code fence removed.
        An empty string is returned when the chat call fails (the error is
        printed, not raised).
    """
    messages = [
        {'role': 'system', 'content': 'You are an expert Python programmer tasked with generating docstrings. You will receive context (if found), and the code to document in the final message. Use the context only if it is directly relevant to explaining the provided code. Return only the docstring and dont include the given python code in the output.'}
    ]

    # Add the rewritten context as a separate user message, if context exists
    if context:
        messages.append({'role': 'user', 'content': f"Here is potentially relevant context retrieved from python\n{user_code}\n and content :\n---\n{rewritten_req}\n---"})
    else:
        # Explicitly state if no context was found or provided
        messages.append({'role': 'user', 'content': "No specific context was retrieved or provided for this request."})

    # Add the final user message with the code and the explicit request
    messages.append({'role': 'user', 'content': f"Based on any relevant context provided earlier, generate the Python docstring for the following code:\n```python\n{user_code}\n```\n\nOutput *only* the complete docstring content itself, starting with triple quotes. Dont include python codes in the output. You need to generate a single docstring as whole for given python class code."})

    # Defined before the try block so the function still returns a value when
    # the chat call raises (it used to fail with UnboundLocalError).
    generated_docstring = ""
    try:
        response = ollama_client.chat(
            model=OLLAMA_MODEL,
            messages=messages
        )

        generated_docstring = response.get('message', {}).get('content', '').strip()

        # Models often wrap the answer in a Markdown fence; peel it off.
        if generated_docstring.startswith("```python"):
            generated_docstring = generated_docstring[len("```python"):].strip()
        elif generated_docstring.startswith("```"):
            generated_docstring = generated_docstring[len("```"):].strip()

        if generated_docstring.endswith("```"):
            generated_docstring = generated_docstring[:-len("```")].strip()

        # Only warn when the reply is not triple-quoted; the text is returned as is.
        if not (generated_docstring.startswith('"""') or generated_docstring.startswith("'''")):
            print("(Model might not have generated a perfectly formatted docstring)")
    except Exception as e:
        print(f"Error communicating with Ollama chat endpoint: {e}")
    return generated_docstring

In [25]:
def run_code_aware_rag_pipeline(user_code, pinecone_index, emb_model, ollama_cli, helper_model, generator_model):
    """Execute the full Code-Aware RAG pipeline for one piece of source code.

    Steps: parse the code for entities, build an enriched query, retrieve
    context from Pinecone, let the helper model rewrite that context, then ask
    the generator model for the docstring.

    Args:
        user_code: Source code to document.
        pinecone_index: Pinecone index handle used for retrieval.
        emb_model: Embedding model used to embed the query.
        ollama_cli: Ollama client used for the final generation step.
        helper_model: Name of the Ollama model that rewrites the context.
        generator_model: Name of the Ollama model that writes the docstring.

    Returns:
        The raw docstring text returned by ``final_content_generation``.

    Side effects:
        Appends the retrieved context to the module-level
        ``retrieved_contexts_list``, which therefore has to exist before the
        first call.
    """
    # 1. Parse code
    entities = parse_code_for_entities(user_code)

    # 2. Construct enriched query
    # The code is appended on its own line; plain concatenation fused the last
    # word of the description with the first token of the code.
    enriched_query = f"{construct_enriched_query(entities)}\n{user_code}"

    # 3. Retrieve context
    retrieved_ctx = retrieve_context(enriched_query, pinecone_index, emb_model)
    retrieved_contexts_list.append(retrieved_ctx)
    if not retrieved_ctx:
        print(" -> No relevant context was retrieved.")

    # The helper model turns the raw context into a focused prompt.
    rewritten_request = revised_prompt(retrieved_ctx, helper_model)

    # 4. Generate docstring
    raw_doc = final_content_generation(retrieved_ctx, user_code, rewritten_request, ollama_cli, generator_model)
    return raw_doc

In [27]:
# Load the evaluation set. Later cells read its "Code_without_comments" and
# "Comments" columns.
# NOTE(review): unpickling can execute arbitrary code — only load this file
# when it comes from a trusted source.
class_files_df = pd.read_pickle('class_files_df.pkl')

In [29]:
# Reference docstrings taken from the "Comments" column.
# NOTE(review): the visible cells read the column directly and never use this
# list — confirm it is needed further down.
ground_truth = class_files_df["Comments"].to_list()

In [31]:
# Module-level accumulators for one evaluation run.
generated_docstrings_list = []  # raw pipeline output, one entry per row
retrieved_contexts_list = []    # appended to inside run_code_aware_rag_pipeline
rewritten_contexts_list = []    # NOTE(review): nothing in the visible cells appends to this list

In [33]:
# Run the pipeline once per row and store the raw model output.
# Both accumulators are reset first so that re-running this cell starts from
# scratch; otherwise the lists keep growing and the column assignment below
# fails because the list is longer than the frame.
generated_docstrings_list = []
retrieved_contexts_list.clear()
for i, row in class_files_df.iterrows():
    user_code = row["Code_without_comments"]
    output = run_code_aware_rag_pipeline(user_code=user_code, pinecone_index=pinecone_index, emb_model=model, ollama_cli=ollama_client, helper_model=OLLAMA_HELPER_MODEL, generator_model=OLLAMA_GENERATOR_MODEL)
    generated_docstrings_list.append(output)
class_files_df["RAG_Docstring"] = generated_docstrings_list

    -> Pinecone retrieval error: Failed to connect; did you specify the correct index name?
 -> No relevant context was retrieved.
(Model might not have generated a perfectly formatted docstring)


In [35]:
def _between_delimiters(text, delimiter):
    """Return the stripped text between the first and last ``delimiter``.

    Returns None when the delimiter is absent or occurs only once.
    """
    start = text.find(delimiter)
    end = text.rfind(delimiter)
    if start == -1 or end <= start:
        return None
    return text[start + len(delimiter):end].strip()


def clean_rag_docstring(docstring_text):
    """Reduce a raw model reply to the bare docstring text.

    Non-string values and strings starting with ``# ERROR:`` or ``# SKIPPED:``
    are returned untouched. Otherwise, in order:

    1. surrounding whitespace and a Markdown code fence are removed;
    2. the text between the first and last triple double quotes is taken,
       falling back to triple single quotes when that yields nothing;
    3. if neither yields text, the reply itself is used, minus a matching pair
       of enclosing triple quotes;
    4. a leading ``class Name:`` line is dropped.

    Returns:
        The cleaned text (possibly empty), or the input unchanged in the
        pass-through cases above.
    """
    if not isinstance(docstring_text, str):
        return docstring_text
    if docstring_text.startswith(("# ERROR:", "# SKIPPED:")):
        return docstring_text

    text = docstring_text.strip()

    # Peel off the opening fence (language-tagged form first), then the closing one.
    for fence in ("```python", "```"):
        if text.startswith(fence):
            text = text[len(fence):].strip()
            break
    if text.endswith("```"):
        text = text[:-3].strip()

    body = _between_delimiters(text, '"""') or _between_delimiters(text, "'''")
    if not body:
        # Nothing usable between quotes: keep the reply, unwrapping it when it
        # is entirely enclosed in one kind of triple quote.
        body = text
        for quote in ('"""', "'''"):
            if len(body) >= 6 and body.startswith(quote) and body.endswith(quote):
                body = body[3:-3].strip()
                break

    return re.sub(r"(?i)^class\s+\w+:\s*\n?", "", body).strip()

# Clean every generated docstring in place. The column is cast to str first,
# so clean_rag_docstring only ever receives strings here.
# NOTE(review): the cast presumably turns missing values into text such as
# "nan"/"None" — confirm that is acceptable for the metrics below.
class_files_df["RAG_Docstring"] = class_files_df["RAG_Docstring"].astype(str).apply(clean_rag_docstring)

In [37]:
# Preview the first few raw generations only; displaying the whole list floods
# the notebook with output.
generated_docstrings_list[:3]

['"""\nAdamax is an extension of the Adam optimization algorithm that includes the element-wise application of the exponential moving average across all training examples, \nas well as a separate update step on the per-example parameter gradients. \n\nThe learning rate schedule for this optimizer can be configured via constructor arguments. Additionally, the class attributes \'_HAS_AGGREGATE_GRAD\' and \'_USES_Nesterov\' provide information about the optimization algorithm’s characteristics.\n\nAttributes:\n    _HAS_AGGREGATE_GRAD (bool): Flag indicating that Adamax updates are applied on aggregate gradients by iterating over them once. \n                                 This attribute is constant and its value is True.\n\nMethods:\n    __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7, name=\'Adamax\', **kwargs)\n        Initializes Adamax with given hyperparameters and default values if not provided. \n        Stores the learning rate schedule in \'self._deca

In [39]:
def calculate_rouge(df, reference_column, hypothesis_column):
    """Add a ROUGE-1 F-score column comparing two text columns row by row.

    Args:
        df: DataFrame holding both columns. It is modified in place.
        reference_column: Column with the reference text.
        hypothesis_column: Column with the generated text.

    Returns:
        The same DataFrame object, with the new column
        ``'ROUGE-1 ' + reference_column``. Rows whose reference or hypothesis
        contains nothing scorable receive 0.0.
    """
    rouge = Rouge()

    def calculate_score(row):
        # str() mirrors how the other metric cells read these columns and keeps
        # a non-string cell from raising on .lower().
        hypothesis = str(row[hypothesis_column]).lower()
        reference = str(row[reference_column]).lower()
        if not hypothesis.strip() or not reference.strip():
            return 0.0
        try:
            scores = rouge.get_scores(hypothesis, reference)
        except ValueError:
            # The rouge package raises ValueError when a text holds no
            # sentence it can score (e.g. punctuation only); count that as no
            # overlap instead of aborting the whole evaluation.
            return 0.0
        return scores[0]['rouge-1']['f']

    df['ROUGE-1 ' + reference_column] = df.apply(calculate_score, axis=1)
    return df

# Calculate ROUGE-1 scores
# calculate_rouge adds the column to the frame it receives and returns that
# same object, so data_1 and class_files_df refer to one DataFrame.
data_1 = calculate_rouge(class_files_df, 'Comments', 'RAG_Docstring')

In [40]:
def calculate_bleu(df, reference_column, hypothesis_column):
    """Add a sentence-level BLEU column comparing two text columns row by row.

    Both texts are lower-cased and split on whitespace; BLEU uses uniform
    weights over 1- to 4-grams.

    Args:
        df: DataFrame holding both columns. It is modified in place.
        reference_column: Column with the reference text.
        hypothesis_column: Column with the generated text.

    Returns:
        The same DataFrame object, with the new column
        ``'BLEU Score ' + reference_column``.
    """
    # Fetch the punkt data only when it is missing, rather than contacting the
    # download server on every call. The scoring below splits on whitespace
    # and does not use punkt itself; the download is kept for compatibility.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    def calculate_score(row):
        # str() mirrors how the other metric cells read these columns and keeps
        # a non-string cell from raising on .lower().
        reference_tokens = [str(row[reference_column]).lower().split()]
        hypothesis_tokens = str(row[hypothesis_column]).lower().split()
        return sentence_bleu(reference_tokens, hypothesis_tokens, weights=(0.25, 0.25, 0.25, 0.25))

    df['BLEU Score ' + reference_column] = df.apply(calculate_score, axis=1)
    return df

In [41]:
# Calculate BLEU scores
# Adds the 'BLEU Score Comments' column; data_1 stays the same DataFrame object.
data_1 = calculate_bleu(data_1, 'Comments', 'RAG_Docstring')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/balajivenktesh/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [42]:
# Calculate the BERTScore F1 between a reference and a generated text
def calculate_bert_score(ground_truth, generated):
    """Return the BERTScore F1 of ``generated`` against ``ground_truth``.

    Args:
        ground_truth: Reference text.
        generated: Candidate (generated) text.

    Returns:
        float: the F1 component of BERTScore, computed with the
        ``bert-base-uncased`` model.
    """
    # bert_score.score takes the candidates first and the references second.
    # The two were passed the other way round, which swaps precision and
    # recall; F1 should be unaffected, but the call now follows the API.
    # NOTE(review): score() is called once per pair and may reload the model
    # on every call — consider scoring all rows in one batch.
    _, _, bert_score_f1 = score([generated], [ground_truth], lang='en', model_type='bert-base-uncased')

    return bert_score_f1.item()

In [43]:
# Calculate BERTScore F1 for every row (reference "Comments" vs generated "RAG_Docstring")
list_append_1 = []
for index, row in data_1.iterrows():
    # Both cells are cast to str before scoring.
    list_append_1.append(calculate_bert_score(str(row["Comments"]), str(row["RAG_Docstring"])))

In [46]:
# The "Accuracy" column holds the per-row BERTScore F1 values collected in
# list_append_1; it is not a classification accuracy.
data_1["Accuracy"] = list_append_1

In [49]:
# Estimate the number of syllables in a single word
def count_syllables(word):
    """Estimate the number of syllables in ``word`` with a vowel-group heuristic.

    Every run of consecutive vowels (``a e i o u y``) counts as one syllable,
    and one is subtracted when the word ends in ``e``, ``es`` or ``ed``. The
    check is case-insensitive.

    Args:
        word: A token; every character that is not an ASCII letter is ignored.

    Returns:
        int: the estimate. It is at least 1 for any token that contains a
        letter, and 0 for a token without letters (digits, punctuation, "").
    """
    # Keep letters only and normalise case once, so the vowel scan and the
    # ending check treat "MAKE" and "make" alike.
    letters = re.sub(r'[^a-zA-Z]', '', word).lower()
    if not letters:
        # A token without letters used to be counted as one syllable.
        return 0

    syllables = 0
    previous_was_vowel = False
    for char in letters:
        is_vowel = char in 'aeiouy'
        if is_vowel and not previous_was_vowel:
            syllables += 1
        previous_was_vowel = is_vowel

    # Adjust syllable count for words ending in 'e', 'es' or 'ed'
    if letters.endswith(('e', 'es', 'ed')):
        syllables -= 1

    # Every real word has at least one syllable (covers vowel-less words and
    # words whose only vowel group was removed by the adjustment above).
    return max(syllables, 1)

In [50]:
# Calculate Flesch reading score
def flesch_reading_ease(text):
    """Compute the Flesch Reading Ease score of ``text``.

    Sentences are approximated as the number of ``.``, ``!`` and ``?``
    characters plus one; words are runs of word characters; syllables are
    summed over the whitespace-separated tokens via ``count_syllables``.

    Args:
        text: The text to score.

    Returns:
        float: ``206.835 - 1.015 * words/sentences - 84.6 * syllables/words``,
        or 0.0 when the text contains no words.
    """
    sentences = text.count('.') + text.count('!') + text.count('?') + 1
    words = len(re.findall(r'\b\w+\b', text))
    if words == 0:
        # An empty or punctuation-only docstring used to raise ZeroDivisionError.
        return 0.0

    # NOTE(review): syllables are counted over text.split() tokens while words
    # are counted with the regex above, so the two tokenisations can differ.
    syllables = sum(count_syllables(word) for word in text.split())

    # Calculate Flesch Reading Ease score
    return 206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / words)

In [51]:
# Calculate the Flesch Reading Ease score of every generated docstring
list_append_2 = []
for index, row in data_1.iterrows():
    # Only the generated text is scored; the reference is not involved.
    list_append_2.append(flesch_reading_ease(str(row["RAG_Docstring"])))

In [52]:
# Store the per-row Flesch Reading Ease scores.
data_1["Ease"] = list_append_2

In [53]:
#%%
def compress(input):
    """Encode a string with ``str.encode()`` defaults and return its zlib-compressed bytes."""
    encoded = input.encode()
    return zlib.compress(encoded)

In [54]:
def conciness(ground_truth, generated):
    """Return the compressed-size ratio of the generated text to the reference.

    Both strings are compressed with ``compress`` and the byte lengths of the
    results are divided. Values below 1.0 mean the generated text compresses to
    fewer bytes than the reference.

    Args:
        ground_truth: Reference text (the denominator).
        generated: Generated text (the numerator).

    Returns:
        float: len(compressed generated) / len(compressed ground_truth).
    """
    # Use len() rather than sys.getsizeof(): getsizeof reports the memory
    # footprint of the bytes object, which adds a fixed header to both sides and
    # pulls the ratio towards 1. A zlib stream always has a header and checksum,
    # so the denominator is never zero.
    reference_size = len(compress(ground_truth))
    generated_size = len(compress(generated))
    return generated_size / reference_size

In [55]:
# Calculate the conciseness ratio of every generated docstring against its reference comment
list_append_3 = [
    conciness(str(row["Comments"]), str(row["RAG_Docstring"]))
    for _, row in data_1.iterrows()
]

In [56]:
# The "Conciseness" column holds conciness(Comments, RAG_Docstring) for each row
# (list_append_3 must contain exactly one entry per row of data_1).
data_1["Conciseness"] = list_append_3

In [57]:
def _signature_parameters(code_str):
    """Return the raw parameter chunks of the first ``def`` header in code_str.

    The parameter list is scanned character by character so that commas nested
    inside brackets or quotes (``Dict[str, int]``, ``(1, 2)``, ``","``) do not
    split a parameter, and so that signatures spanning several lines or carrying
    a ``-> T`` annotation are handled. Returns None when no ``def`` header with
    a closing parenthesis is found.
    """
    header = re.search(r"\bdef\s+\w+\s*\(", code_str)
    if header is None:
        return None

    chunks, current = [], []
    depth = 0
    quote = None
    for char in code_str[header.end():]:
        if quote is not None:
            # Inside a string literal: only the matching quote ends it.
            if char == quote:
                quote = None
        elif char in "\"'":
            quote = char
        elif char in "([{":
            depth += 1
        elif char in ")]}":
            if depth == 0:
                # This is the parenthesis that closes the parameter list.
                chunks.append("".join(current))
                return chunks
            depth -= 1
        elif char == "," and depth == 0:
            chunks.append("".join(current))
            current = []
            continue
        current.append(char)
    return None  # unbalanced or truncated header


def calculate_parameter_coverage(code_str, docstring_str):
    """
    Calculates the proportion of function/method parameters mentioned in the docstring.

    Only the first ``def`` (or ``async def``) in ``code_str`` is inspected.
    ``self``, ``cls``, ``*args``/``**kwargs`` and the bare ``*`` and ``/``
    markers are not counted as parameters. A parameter is covered when its name
    appears in the docstring as a whole word (case-insensitive).

    Args:
        code_str: Source code containing the function definition.
        docstring_str: The docstring text to check.

    Returns:
        float | None: A value from 0.0 to 1.0; 1.0 when there are no parameters
        to document; None when no function header is found in the code.
    """
    chunks = _signature_parameters(code_str)
    if chunks is None:
        return None

    actual_params = []
    for chunk in chunks:
        # Drop the default value and the annotation to keep just the name.
        name = chunk.split('=')[0].split(':')[0].strip()
        if not name or name in ('self', 'cls', '/') or name.startswith('*'):
            continue
        actual_params.append(name)

    if not actual_params:
        return 1.0

    docstring_lower = docstring_str.lower()
    # Whole-word match only: a substring test such as "x:" would wrongly treat
    # "max:" as documenting a parameter called x.
    covered_params = sum(
        1
        for name in actual_params
        if re.search(r"\b" + re.escape(name.lower()) + r"\b", docstring_lower)
    )
    return covered_params / len(actual_params)

In [58]:
# --- Return Value Coverage Calculation Function ---
def calculate_return_coverage(code_str, docstring_str):
    """
    Checks if the docstring mentions a return value if the code seems to have one.

    The code is considered to produce a value when any line starts with
    ``return`` or ``yield`` followed by an expression other than a plain
    ``None`` (trailing ``#`` comments are ignored). The docstring covers it when
    it contains "return" or "yield" (case-insensitive, so "Returns"/"Yields"
    match too).

    Args:
        code_str: Source code of the function.
        docstring_str: The docstring text to check.

    Returns:
        float: 1.0 if covered or not applicable, 0.0 if the code produces a
        value that the docstring never mentions.
    """
    # \b keeps identifiers such as "return_value" or "returns" from matching,
    # while still accepting the parenthesised form "return(x)".
    value_statement = re.compile(r"^(?:return|yield)\b(.*)$")

    produces_value = False
    for line in code_str.splitlines():
        found = value_statement.match(line.strip())
        if not found:
            continue
        # Drop a trailing comment so "return None  # note" counts as returning None.
        expression = found.group(1).split('#', 1)[0].strip()
        if expression and expression != "None":
            produces_value = True
            break

    if not produces_value:
        return 1.0

    docstring_lower = docstring_str.lower()
    return_keywords = ("return", "yield")
    if any(keyword in docstring_lower for keyword in return_keywords):
        return 1.0
    return 0.0

In [59]:
# --- Basic Faithfulness Metric Function ---
def calculate_basic_faithfulness(generated_docstring, retrieved_context_text):
    """
    Token-overlap proxy for faithfulness of a docstring to its retrieved context.

    Both texts are split into lowercase word tokens with a fixed stop-word list
    removed. The score is the share of distinct docstring tokens that also occur
    in the context.

    Returns:
        float | None: A value from 0.0 to 1.0; 0.0 when the docstring has no
        tokens left after filtering; None when tokenizing raises (the error is
        printed).
    """
    stop_words = {
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "should",
        "can", "could", "may", "might", "must", "and", "or", "but", "if", "of",
        "at", "by", "for", "with", "about", "to", "in", "on", "this", "that",
        "it", "its", "you", "your", "i", "me", "my", "he", "she", "him", "her",
        "they", "them", "their",
    }

    def content_tokens(text):
        # Distinct lowercase word tokens that are not stop words.
        lowered = (found.lower() for found in re.findall(r'\b\w+\b', text))
        return {token for token in lowered if token not in stop_words}

    try:
        gen_tokens = content_tokens(generated_docstring)
        ctx_tokens = content_tokens(retrieved_context_text)
    except Exception as e:
        print(f"Error tokenizing for faithfulness: {e}")
        return None

    if len(gen_tokens) == 0:
        return 0.0

    shared_tokens = gen_tokens & ctx_tokens
    return len(shared_tokens) / len(gen_tokens)

In [60]:
def calculate_exception_coverage(code_str, docstring_str):
    """Return the share of raised exception types that the docstring mentions.

    Exception names are taken from ``raise <Name>`` statements. For a dotted
    name such as ``errors.CustomError`` only the last component is used. A bare
    ``raise`` (re-raise) names no exception and is ignored.

    Args:
        code_str: Source code to scan for ``raise`` statements.
        docstring_str: The docstring text to check.

    Returns:
        float | None: 1.0 when the code raises nothing or the docstring contains
        a "Raises:" section (accepted as-is, even if not every exception is
        named); otherwise the fraction of raised names that appear as whole
        words in the docstring. None when either argument is not a string, or
        the docstring is blank or starts with "# ERROR:" / "# SKIPPED:".
    """
    if not isinstance(code_str, str) or not isinstance(docstring_str, str):
        return None
    if not docstring_str.strip() or docstring_str.startswith(("# ERROR:", "# SKIPPED:")):
        return None

    # [ \t]+ (not \s+) keeps a bare "raise" from swallowing the first word of
    # the next line; \b keeps words like "praise" from matching.
    raised_names = re.findall(r"\braise[ \t]+(\w+(?:\.\w+)*)", code_str)
    raised_exceptions = {dotted.rsplit(".", 1)[-1] for dotted in raised_names}
    if not raised_exceptions:
        return 1.0  # No exceptions to cover

    docstring_lower = docstring_str.lower()
    # If a "Raises:" section exists, it's good, even if not all specific exceptions are named (simple check)
    if "raises:" in docstring_lower:
        return 1.0

    covered_exceptions = sum(
        1
        for exc_name in raised_exceptions
        if re.search(r"\b" + re.escape(exc_name.lower()) + r"\b", docstring_lower)
    )
    return covered_exceptions / len(raised_exceptions)

In [61]:
# --- Adherence to Docstring Conventions (Pydocstyle) ---
PYDOCSTYLE_ENABLED = True


def _build_pydocstyle_snippet(code_str, docstring_content):
    """Return code_str with docstring_content spliced in as a docstring.

    The docstring goes right after the first class header, else the first
    function header, else at the top as a module docstring.
    """
    # Sanitize content for embedding within triple quotes
    safe_content = docstring_content.replace('\\', '\\\\')  # Escape backslashes
    safe_content = safe_content.replace('"""', '\\"\\"\\"')  # Escape internal triple-double-quotes
    safe_content = safe_content.replace("'''", "\\'\\'\\'")  # Escape internal triple-single-quotes

    # First line as is, continuation lines indented with 4 spaces
    # (the docstring itself is inserted with a 4-space indent).
    first_line, *continuation = safe_content.split('\n')
    indented_docstring_body = '\n'.join([first_line] + ['    ' + line for line in continuation])

    class_match = re.search(r"^(.*\bclass\s+\w+\s*\(?.*\)?:)", code_str, re.MULTILINE)
    func_match = re.search(r"^(.*\b(async\s+)?def\s+\w+\s*\(?.*\)?:)", code_str, re.MULTILINE)
    header_match = class_match or func_match  # a class header takes priority

    if header_match is None:
        # Fallback: treat as module-level docstring if no class/def found.
        # The escaped text is used so that embedded triple quotes cannot break parsing.
        return f'"""{safe_content}"""\n{code_str}'

    header_end = header_match.end()
    remainder = code_str[header_end:]
    # Ensure there is a body if the original was just a header (or header + comment)
    if not remainder.strip() or remainder.strip().startswith("#"):
        remainder = "\n    pass" + remainder
    return code_str[:header_end] + f'\n    """{indented_docstring_body}"""' + remainder


def _count_pydocstyle_violations(report):
    """Count the violations in pydocstyle's stdout, ignoring D100.

    pydocstyle prints a location line followed by an indented "Dnnn: message"
    line for each violation, so only the "Dnnn:" lines are counted.
    """
    codes = re.findall(r"^\s*(D\d{3}):", report, re.MULTILINE)
    # D100 (missing module docstring) is an artifact of linting a synthetic file.
    return sum(1 for code in codes if code != "D100")


def check_docstring_adherence_pydocstyle(code_str, generated_docstring_content):
    """
    Checks adherence of a generated docstring to PEP 257 using pydocstyle.

    The generated_docstring_content should be the *content* of the docstring,
    not including the triple quotes. It is inserted into code_str, written to a
    temporary .py file, and checked with the ``pydocstyle`` command.

    Returns:
        float: A score from 0.0 to 1.0. Each violation other than D100 costs
               0.1, so 1.0 means no violations; 0.0 means ten or more, or that
               pydocstyle could not parse the file.
               Returns None if PYDOCSTYLE_ENABLED is false or an error occurs
               while running the check (for example pydocstyle is missing).
    """
    if not PYDOCSTYLE_ENABLED:
        return None

    code_for_pydocstyle_check = _build_pydocstyle_snippet(code_str, generated_docstring_content)

    filtered_errors_count = 0
    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as tmp_file:
            # Record the path before writing so the finally block can delete
            # the file even if the write itself fails.
            tmp_file_path = tmp_file.name
            tmp_file.write(code_for_pydocstyle_check)

        command = ['pydocstyle', tmp_file_path]
        process = subprocess.run(command, capture_output=True, text=True, encoding='utf-8')

        parse_failure_markers = ("Cannot parse file", "unexpected EOF while parsing", "invalid syntax")
        if process.stderr and any(marker in process.stderr for marker in parse_failure_markers):
            print(f"Pydocstyle CRITICAL PARSE ERROR for temp file {tmp_file_path}: {process.stderr}")
            print("--- Content written to temp file that failed parsing: ---")
            print(code_for_pydocstyle_check)
            print("--------------------------------------------------------")
            return 0.0  # Penalize heavily for parse error

        filtered_errors_count = _count_pydocstyle_violations(process.stdout)

    except Exception as e:
        print(f"An exception occurred during pydocstyle check: {str(e)}")
        if tmp_file_path and os.path.exists(tmp_file_path):  # Check if tmp_file_path was assigned
            try:
                with open(tmp_file_path, 'r', encoding='utf-8') as f_err:
                    print(f"Content of temp file '{tmp_file_path}' that caused exception:\n{f_err.read()}")
            except Exception as read_err:
                print(f"Could not read temp file {tmp_file_path}: {read_err}")
        return None  # Error during check
    finally:
        if tmp_file_path and os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)

    # Normalize score based on filtered errors.
    # Using 10 as the denominator makes the score less harsh than 5.
    return max(0.0, 1.0 - (filtered_errors_count / 10.0))

In [62]:
# Per-row metric lists; the scoring cells that follow fill them with one value
# per row of data_1 before they are attached as columns.
param_coverage_list = []
return_coverage_list = []
faithfulness_list = []

In [63]:
# Rebuild the list from scratch so that re-running this cell cannot append a
# second copy of the scores (which would break the column assignment that follows).
param_coverage_list = [
    calculate_parameter_coverage(str(row["Code_without_comments"]), str(row["RAG_Docstring"]))
    for _, row in data_1.iterrows()
]

In [64]:
# Attach the calculate_parameter_coverage() results; entries may be None where
# that function returned None.
data_1["Parameter_Coverage"] = param_coverage_list

In [65]:
# Rebuild the list from scratch so that re-running this cell cannot append a
# second copy of the scores (which would break the column assignment that follows).
return_coverage_list = [
    calculate_return_coverage(str(row["Code_without_comments"]), str(row["RAG_Docstring"]))
    for _, row in data_1.iterrows()
]

In [66]:
# Attach the calculate_return_coverage() result for each row.
data_1["Return_Coverage"] = return_coverage_list

In [67]:
# NOTE(review): retrieved_contexts_list is not defined in this part of the
# notebook; presumably an earlier cell builds it with one entry per data_1 row - confirm.
data_1["Retrieved_Contexts"] = retrieved_contexts_list

In [68]:
# Rebuild the list from scratch so that re-running this cell cannot append a
# second copy of the scores (which would break the column assignment that follows).
faithfulness_list = [
    calculate_basic_faithfulness(str(row["RAG_Docstring"]), str(row["Retrieved_Contexts"]))
    for _, row in data_1.iterrows()
]

In [69]:
# Attach the token-overlap score from calculate_basic_faithfulness() for each row.
data_1["Faithfulness_Score"] = faithfulness_list

In [70]:
# Holds one pydocstyle adherence score per row of data_1.
pydocstyle_adherence_list_1 = []

In [71]:
# Rebuild the list from scratch so that re-running this cell cannot append a
# second copy of the scores. Each call launches the external `pydocstyle` command.
pydocstyle_adherence_list_1 = [
    check_docstring_adherence_pydocstyle(str(row["Code_without_comments"]), str(row["RAG_Docstring"]))
    for _, row in data_1.iterrows()
]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [72]:
# Attach the check_docstring_adherence_pydocstyle() results; entries are None
# where the check itself failed.
data_1["PythonStyle_Adherence"] = pydocstyle_adherence_list_1

In [73]:
# Holds one exception-coverage value per row of data_1.
exception_coverage_list = []

In [74]:
# Rebuild the list from scratch so that re-running this cell cannot append a
# second copy of the scores (which would break the column assignment that follows).
exception_coverage_list = [
    calculate_exception_coverage(str(row["Code_without_comments"]), str(row["RAG_Docstring"]))
    for _, row in data_1.iterrows()
]

In [75]:
# Attach the calculate_exception_coverage() results; entries may be None where
# that function returned None.
data_1["Exception_Coverage"] = exception_coverage_list

In [76]:
# Export the scored table to an Excel workbook.
# NOTE(review): nothing in this section creates ./deepseek/ - confirm it exists before running.
data_1.to_excel('./deepseek/Code_Aware_RAG.xlsx')

In [77]:
# Also save the scored table as a pickle in the same ./deepseek/ directory.
data_1.to_pickle('./deepseek/Code_Aware_RAG.pkl')