In [24]:
!pip install groq
!pip install datasets

!pip install python-dotenv



True

In [39]:
from dotenv import load_dotenv
# Load environment variables from the .env file one directory above the notebook.
load_dotenv("../.env")

True

In [12]:
# Mount gdrive for colab
import os

# Set to True when running on Google Colab so that Google Drive gets mounted.
colab = False
if colab:
    from google.colab import drive

    drive.mount('/content/drive')
    local_drive_mount = "/content/drive/MyDrive/"
else:
    # The local Google Drive folder is machine specific. It can be overridden
    # with the SOEN691_DRIVE_MOUNT environment variable; the previous
    # hard-coded location is kept as the default so existing setups still work.
    local_drive_mount = os.environ.get(
        "SOEN691_DRIVE_MOUNT",
        "/Users/dbaeka/Library/CloudStorage/GoogleDrive-dbaekajnr@gmail.com/My Drive/",
    )

In [13]:
# Set up Logging
import logging
import os

# Logs are written below the mounted drive folder, next to the soen691 results.
log_folder_path = os.path.join(local_drive_mount, 'soen691/logs')
os.makedirs(log_folder_path, exist_ok=True)

log_file_path = os.path.join(log_folder_path, 'logfile.txt')

# force=True removes handlers already attached to the root logger, so
# re-running this cell reconfigures logging instead of being silently ignored.
logging.basicConfig(
    filename=log_file_path,
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True
)

## Load Test Dataset
The 500-sample test set is a subset of the 5000-sample test set (its first 500 entries). If we run inference on the 5000-sample set, we therefore do not need to run it again on the 500-sample set: the corresponding predictions can simply be extracted from the 5000-sample results during post-processing, before evaluation.

In [14]:
from datasets import load_dataset
from tqdm import tqdm
import os

In [15]:
# Define IOUtils
import io
import json


def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode, encoding="utf-8")
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode, encoding="utf-8")
    return f


def jdump(obj, f: str, mode="w", indent=4, default=str):
    """Dump a str, dict or list to a file.

    Args:
        obj: A dict or list (serialised as JSON) or a str (written verbatim).
        f: Destination path, or an already opened ``io.IOBase`` object.
        mode: Mode used when ``f`` is a path and has to be opened.
        indent: Indentation passed to ``json.dump``.
        default: Fallback serialiser passed to ``json.dump``.

    Raises:
        ValueError: If ``obj`` is neither a dict, a list nor a str.

    The file object is always closed before returning, including when
    serialisation fails or ``obj`` has an unsupported type.
    """
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, (dict, list)):
            json.dump(obj, f, indent=indent, default=default)
        elif isinstance(obj, str):
            f.write(obj)
        else:
            raise ValueError(f"Unexpected type: {type(obj)}")
    finally:
        # Close on every path so a failed dump does not leak the handle.
        f.close()


def jload(f, mode="r"):
    """Load a .json file and return the parsed content.

    Args:
        f: Source path, or an already opened ``io.IOBase`` object.
        mode: Mode used when ``f`` is a path and has to be opened.

    Returns:
        Whatever ``json.load`` produces for the file content.

    The file object is closed before returning, also when parsing fails.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close even if the content is not valid JSON.
        f.close()

In [16]:
from typing import Optional, Sequence

# Hugging Face identifiers of the candidate models. The selected one is also
# used (through prettify) to name the per-model results directory.
# NOTE(review): "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" has no entry in
# groq_model_map, so selecting it gives no Groq model — confirm before use.
models = [
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-Coder-7B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
]

# set model to use for notebook
# (index into `models`)
model_set = 3
model_name = models[model_set]

# set test to use
# (used in the dataset repository name and in the input/output directory names)
test_name = "test_5000"

# Number of samples per batch handed to review_comment_generation.
BATCH_SIZE = 1


def shard_dataset(repo_name, output_dir, chunk_size: int = 10_000):
    """Split the ``test`` split of ``repo_name`` into consecutive shard files.

    Every slice of at most ``chunk_size`` entries is written with ``jdump`` to
    ``{output_dir}/shard_{k}_input.json``, where ``k`` counts the slices from 0.
    """
    test_split = load_dataset(repo_name)['test']
    shard_starts = range(0, len(test_split), chunk_size)
    for shard_no, start in enumerate(shard_starts):
        stop = start + chunk_size
        jdump(test_split[start:stop], f"{output_dir}/shard_{shard_no}_input.json")


def prettify(name: str) -> str:
    """Return ``name`` with every ``/``, ``-`` and ``.`` replaced by ``_``."""
    underscore_table = str.maketrans({"/": "_", "-": "_", ".": "_"})
    return name.translate(underscore_table)

In [17]:
# set shard size to use
# (number of test samples written to each shard file)
shard_size = 32

base_results_dir = os.path.join(local_drive_mount, "soen691/results/")
# Shard inputs live in <results>/<prettified model name>/<test_name>_input.
test_results_dir = os.path.join(base_results_dir, prettify(model_name), f"{test_name}_input")

# Loads the hashed test set and (re)writes all shard input files each time this cell runs.
shard_dataset(f"dbaeka/soen_691_msg_{test_name}_hashed", test_results_dir, shard_size)

In [18]:
!pip install rank_bm25



In [19]:
## Set up BM25
from rank_bm25 import BM25Okapi

train_dataset = load_dataset("dbaeka/soen_691_msg_train")['train']
print(f"Train data length: {len(train_dataset)}")
print("\n-----------------\n")

# Read the "patch" column once instead of iterating over row dicts: row
# iteration decodes every column of every training example only to use "patch".
# Tokenisation is a plain split on single spaces, matching how queries are split.
tokenized_corpus = [patch_text.split(" ") for patch_text in train_dataset["patch"]]
bm25 = BM25Okapi(tokenized_corpus)

Train data length: 117739

-----------------



In [20]:
import numpy as np

# Query example
query = "CrossProduct"
tokenized_query = query.split(" ")

# Get scores (one BM25 score per training document)
scores = bm25.get_scores(tokenized_query)

# Sort documents by score (descending order)
sorted_indices = np.argsort(scores)[::-1]  # Get indices sorted by highest score

# Fetch the column once: it does not change between loop iterations, and
# re-reading it for every rank repeats the same work.
patch = train_dataset["patch"]

# Show top 3 matches
# A score of 0.0000 means no query token occurs in that document, i.e. the
# listed documents are not real matches.
top_k = 3
print("Top Retrieved Documents:")
for i in range(top_k):
    index = sorted_indices[i]
    print(f"Rank {i + 1}: Score {scores[index]:.4f} - {patch[index]}")

Top Retrieved Documents:
Rank 1: Score 0.0000 - @@ -537,7 +537,7 @@ define([
             var docUri = new Uri(document.location.href);
             var modelUri = new Uri(model._basePath);
             model._baseUri = modelUri.resolve(docUri);
-        });
+        }, getFailedLoadFunction(model, 'gltf', url));
 
         return model;
     };

Rank 2: Score 0.0000 - @@ -346,7 +346,7 @@ public class ExtensionsITCase extends BaseITCase {
                     .build())
                 .build());
 
-        await().atMost(30, TimeUnit.SECONDS).pollInterval(250, TimeUnit.MILLISECONDS).untilAsserted(() -> {
+        await().atMost(60, TimeUnit.SECONDS).pollInterval(250, TimeUnit.MILLISECONDS).untilAsserted(() -> {
             // Get extension details again, we need to use RAW here as Jackson will not be able
             // to write the `uses` property (access = READ_ONLY)
             final ResponseEntity<Map<String, Object>> got = get("/api/v1/extensions/" + id,

Rank 3: Score 0.0000 

In [42]:
# Set up Groq
from groq import Groq

# Never embed credentials in the notebook: read the key from the environment,
# which load_dotenv("../.env") populates. The key that used to be hard-coded
# in this cell must be treated as compromised and rotated.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY is not set. Add it to ../.env or export it before running this cell."
    )

client = Groq(
    api_key=groq_api_key,
)

# The key is deliberately not printed: cell outputs are saved with the notebook.

# Map models
# Keys are the Hugging Face identifiers used in `models`, values the model
# names sent to Groq. Note that the two "7B" keys map to 32B Groq models.
groq_model_map = {
    "Qwen/QwQ-32B": "qwen-qwq-32b",
    "Qwen/Qwen2.5-7B-Instruct": "qwen-2.5-32b",
    "Qwen/Qwen2.5-Coder-7B-Instruct": "qwen-2.5-coder-32b",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B": "deepseek-r1-distill-qwen-32b",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B": "deepseek-r1-distill-llama-70b-specdec"
}
mapped_model = groq_model_map.get(model_name)
if mapped_model is None:
    # Fail here rather than sending model=None with every request later on.
    raise KeyError(
        f"No Groq model is mapped for {model_name!r}; known models: {sorted(groq_model_map)}"
    )
print(f"Model: {mapped_model}")
logging.debug(f"Model: {mapped_model}")

[REDACTED: an API key was printed in this cell output. Rotate the exposed key and clear the output before sharing the notebook.]
Model: deepseek-r1-distill-qwen-32b


In [81]:
import torch
import re

# First user message of every dialog sent to the model.
INSTRUCTION_PROMPT = "Please GIVE FORMAL Codereview for software developers in ONE SENTENCE for testcase, implementing Few Shot Learning from example. Dont start with Codereview/review. Just give the answer."

# When True, the "summary" / "callgraph" fields are added to the prompt context.
WITH_SUMMARY = False
WITH_CALLGRAPH = False
# Seeds torch below and is sent as the `seed` argument of each completion request.
SEED = 0
# Number of generations wanted per sample; samples with a different number of
# stored results are processed again by review_comment_generation.
NUM_OF_RESULTS = 5
# Number of BM25-retrieved training examples used as few-shot context.
NUM_OF_FEW_SHOT = 2
# Sampling temperature that review_comment_generation passes to forward().
TEMPERATURE = 0.7
# When False, extract_cot_and_answer stores the placeholder "NO THINKING" as chain of thought.
IS_REASONING_MODEL = True
# Seconds to sleep after each successfully processed batch.
PAUSE_DURATION = 4
# True: one request with n=NUM_OF_RESULTS; False: NUM_OF_RESULTS requests with n=1.
BATCH_CALL = False

# NOTE(review): no torch operation is visible in this notebook, so this seed
# presumably has no effect on the remote generations — confirm.
torch.manual_seed(SEED)


def extract_cot_and_answer(response):
    """Split a raw model reply into its reasoning part and its final answer.

    Args:
        response: Text returned by the model; ``None`` is treated as an empty
            reply.

    Returns:
        A dict with the keys ``"cot"`` and ``"answer"``:

        * ``"cot"`` is the stripped text between ``<think>`` and ``</think>``
          (empty when there is no such pair), or the placeholder
          ``"NO THINKING"`` when ``IS_REASONING_MODEL`` is false.
        * ``"answer"`` is the stripped text after the first ``</think>``.
          A reply without any think markup is returned entirely as the answer.
          A reply with an opening ``<think>`` that is never closed yields an
          empty answer.
    """
    if response is None:
        response = ""

    # Extract content within <think>...</think>
    cot_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    cot = cot_match.group(1).strip() if cot_match else ""

    if not IS_REASONING_MODEL:
        cot = "NO THINKING"

    # Extract content after </think>
    answer_match = re.search(r"</think>\s*(.*)", response, re.DOTALL)
    if answer_match:
        answer = answer_match.group(1).strip()
    elif "<think>" in response:
        # The reasoning block was never closed, so there is no final answer.
        answer = ""
    else:
        # No reasoning markup at all: the whole reply is the answer. Without
        # this branch every reply of a non-reasoning model would be dropped.
        answer = response.strip()

    return {"cot": cot, "answer": answer}


def get_response(messages, num_of_results: int, max_new_tokens: int, temperature: float):
    """Request completions for one dialog and parse every returned choice.

    Args:
        messages: Sequence of dialogs; only ``messages[0]`` is sent.
        num_of_results: Number of choices requested (``n``).
        max_new_tokens: Value passed as ``max_tokens``.
        temperature: Sampling temperature.

    Returns:
        One ``extract_cot_and_answer`` dict per choice in the response, in
        response order.
    """
    # NOTE(review): the same SEED is sent with every request; if the backend
    # honours it, repeated calls for one dialog may give identical samples — confirm.
    request_kwargs = dict(
        model=mapped_model,
        messages=messages[0],
        max_tokens=max_new_tokens,
        temperature=temperature,
        n=num_of_results,
        stop=["</s>"],
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        seed=SEED,
    )
    completion = client.chat.completions.create(**request_kwargs)
    logging.debug(f"Fingerprints: {completion.system_fingerprint}")

    def _log_and_parse(text):
        # Log the raw reply before parsing it.
        logging.debug(f"Model Response: {text}")
        logging.debug("_" * 70)
        return extract_cot_and_answer(text)

    return [_log_and_parse(choice.message.content) for choice in completion.choices]

def forward(messages, max_new_tokens: int = 2048, temperature: float = 0.05) -> Optional[Sequence[str]]:
    """Generate ``NUM_OF_RESULTS`` parsed replies for every dialog in ``messages``.

    Args:
        messages: Sequence of dialogs (each a list of chat messages).
        max_new_tokens: Token limit forwarded to ``get_response``.
        temperature: Sampling temperature forwarded to ``get_response``.

    Returns:
        A list with one entry per dialog, in input order. Each entry is the
        list of ``{"cot", "answer"}`` dicts produced for that dialog. (The
        return annotation is kept as it was for compatibility; the value is a
        list of lists of dicts.)

    With ``BATCH_CALL`` one request asks for all results of a dialog at once;
    otherwise ``NUM_OF_RESULTS`` single-result requests are sent, with a one
    second pause after each.
    """
    # Imported locally so the function does not rely on a later cell having
    # already imported ``sleep``.
    from time import sleep

    logging.debug("Generating")
    results = []
    # Handle every dialog of the batch. get_response only sends the first
    # dialog it is given, so each dialog is passed in its own one-element list.
    for dialog in messages:
        if BATCH_CALL:
            logging.debug("Result")
            logging.debug("_" * 70)
            results.append(get_response([dialog], NUM_OF_RESULTS, max_new_tokens, temperature))
            continue

        dialog_results = []
        for i in range(NUM_OF_RESULTS):
            logging.debug(f"Result {i + 1}")
            logging.debug("_" * 70)
            result = get_response([dialog], 1, max_new_tokens, temperature)
            dialog_results.append(result[0])
            sleep(1)
        results.append(dialog_results)
    return results


def get_bm25_review_context(example, train_data, num_shot: int = 1):
    """Build few-shot chat messages from the training examples closest to ``example``.

    Args:
        example: Query text; it is split on single spaces and scored with the
            module-level ``bm25`` index.
        train_data: Column-indexable training data providing ``"patch"`` and
            ``"msg"`` (plus ``"summary"`` / ``"callgraph"`` when the matching
            ``WITH_*`` flag is set), aligned with the ``bm25`` corpus.
        num_shot: Number of examples to include, best score first.

    Returns:
        A flat list of alternating ``user`` / ``assistant`` messages, two per
        selected example. Empty when ``num_shot`` is not positive.
    """
    if num_shot <= 0:
        # ``arr[-0:]`` is the whole array, so without this guard a request for
        # zero shots would return every training example.
        return []

    tokenized_query = example.split(" ")
    scores_arr = np.array(bm25.get_scores(tokenized_query))
    sorted_indices = scores_arr.argsort()[-num_shot:][::-1]

    # Look the columns up once; they do not change inside the loop.
    patches = train_data["patch"]
    reviews = train_data["msg"]
    summaries = train_data["summary"] if WITH_SUMMARY else None
    callgraphs = train_data["callgraph"] if WITH_CALLGRAPH else None

    msg = []
    for i in sorted_indices:
        context = "Code: \t" + patches[i] + "\n"
        if WITH_SUMMARY:
            context = context + "Summary: \t" + summaries[i] + "\n"
        if WITH_CALLGRAPH:
            context = context + "Callgraph: \t" + callgraphs[i] + "\n"
        context = context + "Codereview: "
        msg.append({"role": "user", "content": context})
        # The assistant turn mimics the reply format: a think block followed by the review.
        context = "<think>\n...some explantion here...\n</think>\n\n" + reviews[i] + " </s>" + "\n\n"
        msg.append({"role": "assistant", "content": context})
    return msg

In [82]:
from time import sleep


def review_comment_generation(model_name: str, test_name: str, shard_index: int, base_dir: str, batch_size: int = 32):
    """Generate review comments for one input shard and store them on disk.

    Args:
        model_name: Model identifier; its prettified form names the directory.
        test_name: Test-set name used for the ``_input`` / ``_output`` folders.
        shard_index: Index of the shard file to process.
        base_dir: Root directory holding the per-model result folders.
        batch_size: Number of samples per call to ``forward``.

    Results are kept in ``shard_{shard_index}_output.json`` as a mapping from
    sample hash to its list of non-empty results. Samples that already have
    exactly ``NUM_OF_RESULTS`` stored results are skipped, so an interrupted
    run can be resumed. The output file is rewritten after every batch.
    A failing batch is logged with its traceback and skipped.
    """
    input_dir = os.path.join(base_dir, prettify(model_name), f"{test_name}_input")
    input_path = os.path.join(input_dir, f"shard_{shard_index}_input.json")
    input_data = jload(input_path)
    input_list = [{"hash": h, "value": v} for h, v in zip(input_data["hash"], input_data["value"])]

    output_dir = os.path.join(base_dir, prettify(model_name), f"{test_name}_output")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"shard_{shard_index}_output.json")

    # Load existing results if they exist
    existing_results = jload(output_path) if os.path.exists(output_path) else {}

    # Keep only samples that do not yet have NUM_OF_RESULTS stored results
    filtered_input = [
        sample for sample in input_list
        if sample["hash"] not in existing_results or len(existing_results[sample["hash"]]) != NUM_OF_RESULTS
    ]

    for i in tqdm(range(0, len(filtered_input), batch_size)):
        end_index = min(i + batch_size, len(filtered_input))
        batch = filtered_input[i:end_index]

        print(f"Processing batch {i} to {end_index}")
        logging.debug(f"Processing batch {i} to {end_index}")

        prompts = []
        for sample in batch:
            value = sample["value"]

            # Dialog layout: instruction, few-shot examples, then the test sample.
            dialog = [{"role": "user", "content": INSTRUCTION_PROMPT}]
            dialog.extend(get_bm25_review_context(value["patch"], train_dataset, num_shot=NUM_OF_FEW_SHOT))

            # Optional fields are only read when they are actually used, so
            # samples without them still work while the flags are off.
            context = "Code: \t" + value["patch"] + "\n"
            if WITH_SUMMARY:
                context = context + "Summary: \t" + value["summary"] + "\n"
            if WITH_CALLGRAPH:
                context = context + "Callgraph: \t" + value["callgraph"] + "\n"
            context = context + "Codereview: "

            dialog.append({"role": "user", "content": context})
            prompts.append(dialog)

            logging.debug("################context ####################")
            logging.debug(dialog)
        try:
            results = forward(prompts, temperature=TEMPERATURE)
            for sample, result in zip(batch, results):
                # Drop generations where both the reasoning and the answer are blank.
                filtered_result = [
                    r for r in result
                    if r.get("cot", "").strip() != "" or r.get("answer", "").strip() != ""
                ]
                if filtered_result:
                    existing_results[sample["hash"]] = filtered_result
            jdump(existing_results, output_path)

            sleep(PAUSE_DURATION)
        except Exception as e:
            # The message needs a %s placeholder: extra positional arguments are
            # only used for %-formatting. logging.exception also records the traceback.
            logging.exception("Error: %s", e)
            sleep(3)

    logging.info(f"Completed processing shard {shard_index}")

In [85]:
import random

# NOTE(review): presumably ceil(5000 / 32) for the test_5000 set with
# shard_size = 32 — update if the test set or the shard size changes.
TOTAL_SHARDS = 157

shard_indices = list(range(TOTAL_SHARDS))
# Shards are visited in random order; `random` is not seeded in the visible
# cells, so the order differs from run to run.
random.shuffle(shard_indices)

for shard_idx in tqdm(shard_indices):
    print(f"Processing shard {shard_idx}")
    logging.info(f"Processing shard {shard_idx}")
    # Samples that already have NUM_OF_RESULTS stored results are skipped
    # inside, so this loop can be re-run after an interruption.
    review_comment_generation(model_name, test_name, shard_idx, base_results_dir, BATCH_SIZE)

  0%|          | 0/157 [00:00<?, ?it/s]

Processing shard 3



0it [00:00, ?it/s][A


Processing shard 147



  0%|          | 0/32 [00:00<?, ?it/s][A

Processing batch 0 to 1



  3%|▎         | 1/32 [00:39<20:31, 39.73s/it][A

Processing batch 1 to 2



  6%|▋         | 2/32 [01:05<15:37, 31.24s/it][A

Processing batch 2 to 3



  9%|▉         | 3/32 [01:39<15:50, 32.78s/it][A

Processing batch 3 to 4



 12%|█▎        | 4/32 [02:43<20:56, 44.86s/it][A

Processing batch 4 to 5



 16%|█▌        | 5/32 [03:42<22:31, 50.06s/it][A

Processing batch 5 to 6



 19%|█▉        | 6/32 [04:11<18:37, 42.98s/it][A

Processing batch 6 to 7



 22%|██▏       | 7/32 [05:16<20:50, 50.01s/it][A

Processing batch 7 to 8



 25%|██▌       | 8/32 [06:01<19:23, 48.50s/it][A

Processing batch 8 to 9



 28%|██▊       | 9/32 [07:16<21:46, 56.79s/it][A

Processing batch 9 to 10



 31%|███▏      | 10/32 [07:52<18:31, 50.52s/it][A

Processing batch 10 to 11



 34%|███▍      | 11/32 [08:50<18:24, 52.58s/it][A

Processing batch 11 to 12



 38%|███▊      | 12/32 [09:31<16:25, 49.27s/it][A

Processing batch 12 to 13



 41%|████      | 13/32 [09:56<13:16, 41.94s/it][A

Processing batch 13 to 14



 44%|████▍     | 14/32 [10:20<10:57, 36.53s/it][A

Processing batch 14 to 15



 47%|████▋     | 15/32 [10:56<10:17, 36.30s/it][A

Processing batch 15 to 16



 50%|█████     | 16/32 [11:50<11:06, 41.67s/it][A

Processing batch 16 to 17



 53%|█████▎    | 17/32 [13:04<12:50, 51.36s/it][A

Processing batch 17 to 18



 56%|█████▋    | 18/32 [13:54<11:53, 50.98s/it][A

Processing batch 18 to 19



 59%|█████▉    | 19/32 [15:07<12:29, 57.65s/it][A

Processing batch 19 to 20



 62%|██████▎   | 20/32 [16:40<13:38, 68.20s/it][A

Processing batch 20 to 21



 66%|██████▌   | 21/32 [18:22<14:22, 78.38s/it][A

Processing batch 21 to 22



 69%|██████▉   | 22/32 [18:48<10:24, 62.47s/it][A

Processing batch 22 to 23



 72%|███████▏  | 23/32 [23:26<19:04, 127.12s/it][A

Processing batch 23 to 24



 75%|███████▌  | 24/32 [23:52<12:54, 96.82s/it] [A

Processing batch 24 to 25



 78%|███████▊  | 25/32 [28:11<16:58, 145.50s/it][A

Processing batch 25 to 26



 81%|████████▏ | 26/32 [52:35<54:06, 541.11s/it][A

Processing batch 26 to 27



 84%|████████▍ | 27/32 [53:08<32:24, 388.82s/it][A

Processing batch 27 to 28



 88%|████████▊ | 28/32 [53:40<18:46, 281.50s/it][A

Processing batch 28 to 29



 91%|█████████ | 29/32 [54:10<10:18, 206.12s/it][A

Processing batch 29 to 30



 94%|█████████▍| 30/32 [55:05<05:21, 160.92s/it][A

Processing batch 30 to 31



 97%|█████████▋| 31/32 [56:02<02:09, 129.60s/it][A

Processing batch 31 to 32



100%|██████████| 32/32 [56:36<00:00, 106.14s/it][A
  1%|▏         | 2/157 [56:36<73:07:20, 1698.33s/it]

Processing shard 56



  0%|          | 0/32 [00:00<?, ?it/s][A

Processing batch 0 to 1



  3%|▎         | 1/32 [00:28<14:48, 28.66s/it][A

Processing batch 1 to 2



  6%|▋         | 2/32 [00:58<14:40, 29.35s/it][A

Processing batch 2 to 3



  9%|▉         | 3/32 [02:17<25:14, 52.24s/it][A

Processing batch 3 to 4



 12%|█▎        | 4/32 [03:23<26:48, 57.43s/it][A

Processing batch 4 to 5



 16%|█▌        | 5/32 [04:23<26:21, 58.57s/it][A

Processing batch 5 to 6



 19%|█▉        | 6/32 [05:30<26:37, 61.45s/it][A

Processing batch 6 to 7



 22%|██▏       | 7/32 [06:24<24:34, 58.98s/it][A

Processing batch 7 to 8



 25%|██▌       | 8/32 [06:51<19:30, 48.78s/it][A

Processing batch 8 to 9



 28%|██▊       | 9/32 [07:19<16:09, 42.17s/it][A

Processing batch 9 to 10



 31%|███▏      | 10/32 [07:54<14:38, 39.95s/it][A

Processing batch 10 to 11



 34%|███▍      | 11/32 [08:37<14:20, 40.97s/it][A

Processing batch 11 to 12



 38%|███▊      | 12/32 [09:27<14:33, 43.67s/it][A

Processing batch 12 to 13



 41%|████      | 13/32 [10:05<13:18, 42.04s/it][A

Processing batch 13 to 14



 44%|████▍     | 14/32 [10:59<13:39, 45.53s/it][A

Processing batch 14 to 15



 47%|████▋     | 15/32 [11:51<13:26, 47.46s/it][A

Processing batch 15 to 16



 50%|█████     | 16/32 [12:26<11:40, 43.80s/it][A

Processing batch 16 to 17



 53%|█████▎    | 17/32 [13:15<11:20, 45.39s/it][A

Processing batch 17 to 18



 56%|█████▋    | 18/32 [14:01<10:38, 45.63s/it][A

Processing batch 18 to 19



 59%|█████▉    | 19/32 [15:01<10:45, 49.66s/it][A

Processing batch 19 to 20



 62%|██████▎   | 20/32 [15:53<10:04, 50.36s/it][A

Processing batch 20 to 21



 66%|██████▌   | 21/32 [17:00<10:09, 55.41s/it][A

Processing batch 21 to 22



 69%|██████▉   | 22/32 [17:52<09:04, 54.45s/it][A

Processing batch 22 to 23



 72%|███████▏  | 23/32 [18:49<08:18, 55.35s/it][A

Processing batch 23 to 24



 75%|███████▌  | 24/32 [20:04<08:09, 61.14s/it][A

Processing batch 24 to 25



 78%|███████▊  | 25/32 [20:52<06:39, 57.08s/it][A

Processing batch 25 to 26



 81%|████████▏ | 26/32 [22:04<06:10, 61.80s/it][A

Processing batch 26 to 27



 84%|████████▍ | 27/32 [22:46<04:38, 55.67s/it][A

Processing batch 27 to 28



 88%|████████▊ | 28/32 [24:08<04:14, 63.71s/it][A

Processing batch 28 to 29



 91%|█████████ | 29/32 [24:57<02:57, 59.27s/it][A

Processing batch 29 to 30



 94%|█████████▍| 30/32 [26:04<02:02, 61.44s/it][A

Processing batch 30 to 31



 97%|█████████▋| 31/32 [26:54<00:58, 58.14s/it][A

Processing batch 31 to 32



100%|██████████| 32/32 [45:47<00:00, 85.86s/it] [A
  2%|▏         | 3/157 [1:42:24<91:21:00, 2135.46s/it]

Processing shard 77



  0%|          | 0/32 [00:00<?, ?it/s][A

Processing batch 0 to 1



  3%|▎         | 1/32 [00:27<14:01, 27.16s/it][A

Processing batch 1 to 2



  6%|▋         | 2/32 [01:03<16:20, 32.67s/it][A

Processing batch 2 to 3



  9%|▉         | 3/32 [01:58<20:46, 42.99s/it][A

Processing batch 3 to 4



 12%|█▎        | 4/32 [03:04<24:11, 51.83s/it][A

Processing batch 4 to 5



 16%|█▌        | 5/32 [04:05<24:52, 55.27s/it][A

Processing batch 5 to 6



 19%|█▉        | 6/32 [04:55<23:12, 53.54s/it][A

Processing batch 6 to 7



 22%|██▏       | 7/32 [05:59<23:39, 56.76s/it][A

Processing batch 7 to 8



 25%|██▌       | 8/32 [07:21<25:54, 64.77s/it][A

Processing batch 8 to 9



 28%|██▊       | 9/32 [08:31<25:32, 66.62s/it][A

Processing batch 9 to 10



 31%|███▏      | 10/32 [09:10<21:12, 57.84s/it][A

Processing batch 10 to 11



 34%|███▍      | 11/32 [09:54<18:45, 53.61s/it][A

Processing batch 11 to 12



 38%|███▊      | 12/32 [10:31<16:11, 48.58s/it][A

Processing batch 12 to 13



 41%|████      | 13/32 [11:12<14:39, 46.30s/it][A

Processing batch 13 to 14



 44%|████▍     | 14/32 [12:16<15:33, 51.83s/it][A

Processing batch 14 to 15



 47%|████▋     | 15/32 [12:59<13:52, 48.94s/it][A

Processing batch 15 to 16



 50%|█████     | 16/32 [13:27<11:25, 42.85s/it][A

Processing batch 16 to 17



 53%|█████▎    | 17/32 [13:56<09:37, 38.50s/it][A

Processing batch 17 to 18



 56%|█████▋    | 18/32 [14:28<08:31, 36.56s/it][A

Processing batch 18 to 19



 59%|█████▉    | 19/32 [15:30<09:35, 44.28s/it][A

Processing batch 19 to 20



 62%|██████▎   | 20/32 [16:01<08:01, 40.15s/it][A

Processing batch 20 to 21



 66%|██████▌   | 21/32 [16:26<06:32, 35.69s/it][A

Processing batch 21 to 22



 69%|██████▉   | 22/32 [17:33<07:31, 45.17s/it][A

Processing batch 22 to 23


 69%|██████▉   | 22/32 [17:34<07:59, 47.95s/it]
  2%|▏         | 3/157 [1:59:58<102:39:05, 2399.65s/it]


KeyboardInterrupt: 