In [1]:
import os

import torch

# Load pre-tokenized data for testing.
#
# The active assignments below expect a file that holds a train/test split. Unsplit data
# (used for testing our code parrot jupyter errors dataset and the jupyter errors dataset)
# is loaded differently: comment out the six split assignments and uncomment the three
# lines marked "unsplit data" at the bottom of this cell.
#
# Ensure the path is correct. The tokenizer, as well as the code that saves tokenized
# content, is in the run model file.

# Built with os.path.join so the path resolves on every OS; a literal with backslashes is
# only a valid directory path on Windows.
load_path = os.path.join("dataset", "tokenized_content", "file_name.pt")

# NOTE(review): torch.load unpickles the file, so only load files produced by our own
# tokenization step.
tokenized_data = torch.load(load_path)

train_ids = tokenized_data['train_ids']
test_ids = tokenized_data['test_ids']
train_masks = tokenized_data['train_masks']
test_masks = tokenized_data['test_masks']
train_labels = tokenized_data['train_labels']
test_labels = tokenized_data['test_labels']

# Uncomment the lines below to load unsplit data
# test_ids = tokenized_data['test_ids']
# test_masks = tokenized_data['test_masks']
# test_labels = tokenized_data['test_labels']

print("Tokenized data loaded successfully.")

  tokenized_data = torch.load(load_path)


Tokenized data loaded successfully.


The following cell contains our configuration for cell-level bug detection using Flake8. Flake8 was adapted for cell-level bug detection by first decoding the tokenized content and mapping each notebook cell to its corresponding line numbers, preserving the line numbering. The notebooks were then converted into Python scripts, allowing Flake8 to analyze the entire file. We considered running Flake8 on each cell individually, but this approach caused many false positives. When Flake8 detects an error, we map the error’s line number back to the corresponding cell and create a prediction array in which the buggy cell is marked with a 1 and all preceding cells are marked with 0. Since Flake8 stops at the first fatal error it encounters and does not report subsequent errors, the cells after the flagged cell are not considered, and we trim the labels accordingly to ensure a fair comparison. If no errors are detected, the prediction array consists entirely of 0s, indicating no buggy cells. The error codes in our Flake8 configuration were selected to reduce false positives in bug detection by excluding things such as stylistic recommendations.

In [None]:
from transformers import RobertaTokenizer
import tempfile
import subprocess
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import buggy_cell_vector_evalualtion_clean
import re

# tokenizer setup for decoding
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')

# setting up the special tokens used for finding cell boundaries in tokenized content:
# <CELL_1> .. <CELL_1023> open a cell, <END_CELL_1> .. <END_CELL_1023> close it
start_special_tokens = [f"<CELL_{i}>" for i in range(1, 1024)]
end_special_tokens = [f"<END_CELL_{i}>" for i in range(1, 1024)]
all_special_tokens = start_special_tokens + end_special_tokens

# Add tokens if not already in the vocabulary.
# The vocabulary is fetched once instead of once per marker, and the missing markers are
# added in a single call. The list order is unchanged, so the markers are registered in
# the same order as when they were added one at a time.
existing_vocab = tokenizer.get_vocab()
missing_tokens = [token for token in all_special_tokens if token not in existing_vocab]
if missing_tokens:
    tokenizer.add_tokens(missing_tokens)

# setting up our evaluation class
vector_eval = buggy_cell_vector_evalualtion_clean.VectorEval()

###### Decode every notebook and record which lines each cell occupies
flat_codes, flat_labels, all_cell_line_ranges = [], [], []

# Marker patterns, shared by boundary detection and by the final clean-up.
cell_open_re = re.compile(r"<CELL_\d+>")
cell_close_re = re.compile(r"<END_CELL_\d+>")

notebook_iter = tqdm(
    zip(test_ids, test_masks, test_labels),
    total=len(test_ids),
    desc="Decoding & cleaning notebooks",
    dynamic_ncols=True,
)

for chunks_ids, chunks_masks, chunk_label_lists in notebook_iter:
    # Keep the first 4 chunks only: the same number of chunks used when evaluating JupOtter.
    file_ids = chunks_ids[:4]
    chunks_label = chunk_label_lists[:4]

    # Special tokens are kept while decoding so the cell markers stay visible in the text.
    decoded_with_cells = tokenizer.decode(
        file_ids.reshape(-1).tolist(),
        skip_special_tokens=False,
    )

    # Walk the decoded text line by line and collect one (start_line, end_line) tuple per
    # cell. Both values are indices into `lines`.
    lines = decoded_with_cells.split('\n')
    cell_line_ranges = []
    open_start = None  # index of the first content line of the cell currently open

    for idx, line in enumerate(lines):
        opens_cell = cell_open_re.search(line) is not None
        closes_cell = cell_close_re.search(line) is not None
        if not (opens_cell or closes_cell):
            continue

        # Any marker line ends the cell that is currently open, if there is one.
        if open_start is not None:
            cell_line_ranges.append((open_start, idx - 1))

        # An opening marker wins when both appear on one line; content starts next line.
        open_start = idx + 1 if opens_cell else None

    # A cell still open at the end of the text runs to the last line.
    if open_start is not None:
        cell_line_ranges.append((open_start, len(lines) - 1))

    # Strip the tokenizer's own special tokens, then the cell markers. Only marker text is
    # removed, never a newline, so line positions stay as recorded above.
    decoded_clean = decoded_with_cells
    for special in tokenizer.all_special_tokens:
        decoded_clean = decoded_clean.replace(special, "")
    decoded_clean = cell_open_re.sub("", decoded_clean)
    decoded_clean = cell_close_re.sub("", decoded_clean)

    # Flatten the per-chunk label tensors into one list of ints for the notebook.
    notebook_labels = []
    for chunk_labels in chunks_label:
        for item in chunk_labels:
            notebook_labels.append(int(item.item()))

    flat_codes.append(decoded_clean)
    flat_labels.append(notebook_labels)
    all_cell_line_ranges.append(cell_line_ranges)

# ---- Static analysis: run Flake8 on each decoded notebook and score it per cell ----

# Error codes passed to Flake8's --select option.
FLAKE8_SELECT = "E9,F402,F405,F406,F407,F501,F502,F503,F505,F506,F507,F508,F509,F521,F524,F525,F621,F622,F633,F701,F702,F704,F706,F707,F821,F822,F823,F831,F901"

# Upper bound for a single Flake8 run. Without a timeout, subprocess.run never raises
# TimeoutExpired and the handler in the loop below could not fire.
FLAKE8_TIMEOUT_SECONDS = 60


def _first_error_line(flake8_stdout):
    """Return the 1-based line number of the first location in Flake8's output, or None.

    The "path:row:col:" prefix of a report line is tried first. The looser "line N"
    pattern is only a fallback, because message text can mention a different line than
    the one being reported (for example "... from line 3 ...").
    """
    located = re.search(r":(\d+):\d+:", flake8_stdout)
    if located is None:
        located = re.search(r"line (\d+)", flake8_stdout)
    return int(located.group(1)) if located else None


def _predict_cells(error_line_index, cell_ranges):
    """Return the 0/1 prediction vector for a notebook whose error is at error_line_index.

    Cells before the one containing the line get 0, the containing cell gets 1, and the
    vector stops there. If no range contains the line, every cell gets 0.
    """
    prediction = []
    for first_line, last_line in cell_ranges:
        if first_line <= error_line_index <= last_line:
            prediction.append(1)
            break
        prediction.append(0)
    return prediction


results = []

tq = tqdm(
    enumerate(zip(flat_codes, flat_labels, all_cell_line_ranges)),
    total=len(flat_codes),
    desc="Static analysis Eval",
    dynamic_ncols=True,
    leave=True,
)
skippedNoLineNum = 0
buggy_pred = 0
non_buggy_pred = 0
skipped = 0
for notebook_idx, (code, label, cell_ranges) in tq:
    cell_level_prediction = []

    # Flake8 needs a file on disk; delete=False keeps it after the handle is closed so the
    # subprocess can open it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", encoding="utf-8", delete=False) as tmp_file:
        tmp_file.write(code)
        tmp_filename = tmp_file.name

    try:
        result = subprocess.run(
            ["flake8", f"--select={FLAKE8_SELECT}", tmp_filename],
            capture_output=True,
            text=True,
            encoding='utf-8',
            timeout=FLAKE8_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired:
        skipped += 1
        print(f"Skipped file {tmp_filename} due to timeout.")
        continue  # skip this file and move on
    finally:
        # Runs on success, on timeout and on any other error, so no temp file is left behind.
        os.remove(tmp_filename)

    if result.returncode != 0:
        buggy_pred += 1
        error_line = _first_error_line(result.stdout)
        if error_line is None:
            # NOTE(review): the notebook is counted here but is still handed to the
            # evaluator below with an empty prediction, as before. Confirm that
            # VectorEval treats empty vectors as a no-op.
            skippedNoLineNum += 1
        else:
            # Flake8 rows are 1-based, while cell_ranges holds 0-based indices into the
            # decoded text's lines, so convert before matching.
            cell_level_prediction = _predict_cells(error_line - 1, cell_ranges)
    else:
        non_buggy_pred += 1
        # no buggy predictions, so make array of 0s indicating no bugs in any cells
        cell_level_prediction = [0] * len(cell_ranges)

    # Labels are trimmed to the prediction length because nothing after the first
    # reported error is predicted.
    vector_eval.eval_vector(cell_level_prediction, label[:len(cell_level_prediction)])
    results.append((cell_level_prediction, label))

    # live metrics for the current notebook, only used to display in the tqdm bar
    preds_so_far = cell_level_prediction
    if preds_so_far:
        labels_so_far = label[:len(preds_so_far)]
        f1 = f1_score(labels_so_far, preds_so_far, zero_division=0)
        acc = accuracy_score(labels_so_far, preds_so_far)
        recall = recall_score(labels_so_far, preds_so_far, zero_division=0)
        tq.set_postfix({'F1': f"{f1:.3f}", 'Acc': f"{acc:.3f}", 'Recall': f"{recall:.3f}"})
        tq.refresh()  # update tqdm bar

vector_eval.print_results()
vector_eval.reset()

# Evaluate: exact-match rate between each prediction vector and its untrimmed label list.
correct = sum([pred == true for pred, true in results])
total = len(results)
accuracy = correct / total if total else 0.0  # guard: no notebooks evaluated

print(f"\nFile-level Bug Detection via Flake8 completed.")
print(f"Skipped {skipped} files due to timeout.")
print(f"Skipped {skippedNoLineNum} files due to no line number found in error message.")
print(f"Buggy predictions: {buggy_pred}, Non-buggy predictions: {non_buggy_pred}")
