In [0]:
%pip install gitpython tqdm databricks-langchain

In [0]:
dbutils.library.restartPython()

In [0]:
import os
import tempfile
import shutil
from git import Repo
from tqdm import tqdm
import re
import json
import ast

from pyspark.sql.functions import lit

from databricks_langchain import ChatDatabricks


# Repository that is cloned and reviewed by main().
GITHUB_REPO_URL = "https://github.com/birbalin25/CodeCritic.git"
# Databricks model-serving endpoint used for the review (the commented line is an
# alternative endpoint that can be swapped in).
# DATABRICKS_MODEL_ENDPOINT = "databricks-meta-llama-3-1-405b-instruct"
DATABRICKS_MODEL_ENDPOINT = "databricks-claude-3-7-sonnet"

# Never hard-code workspace credentials in a notebook: exports and rendered outputs
# get shared. Read them from the environment instead. Any variable that is unset is
# simply not passed, so ChatDatabricks uses whatever default authentication it
# supports (presumably the ambient notebook credentials on Databricks — TODO confirm).
_llm_auth_kwargs = {
    kwarg: value
    for kwarg, value in (
        ("host", os.environ.get("DATABRICKS_HOST")),
        ("token", os.environ.get("DATABRICKS_TOKEN")),
    )
    if value
}
llm = ChatDatabricks(model=DATABRICKS_MODEL_ENDPOINT, temperature=0, **_llm_auth_kwargs)

def clone_repo(github_url):
    """Clone ``github_url`` into a fresh temporary directory and return its path.

    The directory is created with ``tempfile.mkdtemp(prefix="repo_")``; the caller
    owns it and is responsible for deleting it. If cloning fails, the directory is
    removed here before the exception is re-raised, because the caller never
    receives the path and could not clean it up.

    Args:
        github_url: Any URL accepted by ``git.Repo.clone_from``.

    Returns:
        Path of the directory holding the cloned working tree.

    Raises:
        Whatever ``Repo.clone_from`` raises (e.g. on network/auth failure).
    """
    print(f"\n📥 Cloning repo: {github_url} ...")
    temp_dir = tempfile.mkdtemp(prefix="repo_")
    try:
        Repo.clone_from(github_url, temp_dir)
    except BaseException:
        # Don't leak a half-populated temp directory on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir

def get_python_files(base_path):
    """Collect reviewable ``.py`` and ``.ipynb`` files under ``base_path``.

    Skipped:
      * ``setup.py``, ``__init__.py`` and ``_README.py`` (exact names);
      * any file whose lower-cased name starts with ``test``;
      * everything in a directory whose path contains ``__pycache__``.

    Returns:
        Full paths, in ``os.walk`` traversal order.
    """
    skip_names = {"setup.py", "__init__.py", "_README.py"}
    wanted_suffixes = (".py", ".ipynb")
    collected = []

    for dir_path, _subdirs, names in os.walk(base_path):
        # Whole directory is excluded, so there is no need to look at its files.
        if "__pycache__" in dir_path:
            continue
        for name in names:
            if not name.endswith(wanted_suffixes):
                continue
            if name in skip_names or name.lower().startswith("test"):
                continue
            collected.append(os.path.join(dir_path, name))

    return collected

def analyze_file_with_langchain(file_content, file_name, chat_model=None):
    """Ask the chat model for a performance review of one source file.

    The file text is embedded verbatim in the user message between
    ``---BEGIN FILE CONTENT---`` / ``---END FILE CONTENT---`` markers. The prompt
    also asks for join columns as a ``Join_columns_dictionary``.

    Args:
        file_content: Full text of the file to review.
        file_name: Name/path shown to the model in the prompt.
        chat_model: Object exposing ``invoke(messages)``. Defaults to the
            module-level ``llm``, so existing callers are unaffected. Pass another
            model (or a stub) to use a different endpoint or to test.

    Returns:
        ``response["content"]`` when ``"content" in response`` is true, otherwise
        ``str(response)``.
    """
    model = llm if chat_model is None else chat_model
    messages = [
        {
            "role": "system",
            "content": (
                "You are a senior performance engineer with deep expertise in Apache Spark, Python, Scala, and SQL. Your role is to analyze code for inefficiencies and provide precise, actionable optimization recommendations"
            ),
        },
        {
            "role": "user",
            "content": (
                f"Review the Python file `{file_name}`:\n"
                f"---BEGIN FILE CONTENT---\n{file_content}\n---END FILE CONTENT---\n"
                """
                Analyze the given code for inefficiencies and provide clear, concise improvement suggestions.
                - If the code is not related to Spark, respond with one or two concise sentences summarizing its purpose.
                - If the code is simple and requires no improvements, respond exactly with: Code is simple. No change needed.
                - If the code is Spark-related, identify and explain any inefficiencies and provide recommendations for improvements.
                - If Spark tables are joined, determine whether they use Spark SQL or DataFrame syntax.
                - Extract and return the join columns in the following format:
                    Join_columns_dictionary = {'table1': ['column1'], 'table2': ['column1']}
                - Use the actual table or alias names as keys, and list the columns used in the join as values.
                - If there are multiple joins, include all relevant tables and columns.
                - If no table joins are present, respond with: No table joins are present in this code.
                """
            ),
        },
    ]
    response = model.invoke(messages)
    # NOTE(review): LangChain chat models return a message object, not a dict. The
    # `"content" in response` test appears not to match such objects, so in practice
    # the str() repr ("content='...' additional_kwargs=...") is returned. main()
    # regex-parses exactly that repr, so changing this return shape also requires
    # changing main().
    return response["content"] if "content" in response else str(response)

def read_py(py_file):
    """Return the full text of a ``.py`` file.

    The file is decoded as UTF-8 with ``errors="ignore"``: bytes that are not
    valid UTF-8 are silently dropped instead of raising ``UnicodeDecodeError``.
    """
    with open(py_file, encoding="utf-8", errors="ignore") as source:
        return source.read()


def read_ipynb(py_file):
    """Return the concatenated source of every code cell in a Jupyter notebook.

    Markdown/raw cells are skipped. Each code cell's ``source`` is joined as-is and
    followed by a single ``"\\n"`` so consecutive cells stay on separate lines.
    A notebook without a ``cells`` key yields ``""``.

    Raises:
        json.JSONDecodeError / UnicodeDecodeError if the file is not valid
        UTF-8 JSON.
    """
    with open(py_file, 'r', encoding='utf-8') as handle:
        nb = json.load(handle)

    chunks = []
    for cell in nb.get('cells', []):
        if cell.get('cell_type') != 'code':
            continue
        chunks.extend(cell.get('source', []))
        chunks.append('\n')  # cell separator
    return ''.join(chunks)


def _extract_message_content(feedback):
    """Return the ``content='...'`` text from a stringified chat message, or None.

    ``feedback`` is expected to look like ``content='...' additional_kwargs=...``
    (the repr produced by analyze_file_with_langchain). Error strings and other
    values that don't match yield None. The extracted text is returned exactly as it
    appears in the repr, with escape sequences such as ``\\n`` left in place.
    """
    if not isinstance(feedback, str):
        return None
    match = re.search(r'content=(["\'])(.*?)\1\s+additional_kwargs=', feedback, re.DOTALL)
    return match.group(2) if match else None


def main(repo_url=None, max_files=20, delta_table="bircatalog.birschema.llm_op"):
    """Clone a repo, review its Python/notebook files with the LLM, and append to Delta.

    Args:
        repo_url: Repository to clone; defaults to ``GITHUB_REPO_URL``.
        max_files: Maximum number of files sent to the model (each one is an LLM
            call). ``None`` means no cap.
        delta_table: Fully qualified table the results are appended to.

    Writes one row per reviewed file with columns ``file``, ``llm_feedback``,
    ``llm_feedback_raw`` and ``llm_used``. Any error is printed rather than raised.
    The temporary clone is always removed.
    """
    if repo_url is None:
        repo_url = GITHUB_REPO_URL
    responses = []
    # Bound before the try so the finally clause cannot raise UnboundLocalError
    # (masking the real error) when clone_repo itself fails.
    repo_path = None

    try:
        repo_path = clone_repo(repo_url)

        python_files = get_python_files(repo_path)
        print(f"\n🔍 Found {len(python_files)} Python files to analyze.\n")

        python_files = python_files[:max_files]
        print(f"python_files izz {python_files}")

        for file in tqdm(python_files, desc="🔎 Analyzing Python files"):
            try:
                # if/else so file_content can never be a stale value from the
                # previous iteration.
                if file.endswith('.ipynb'):
                    file_content = read_ipynb(file)
                else:
                    file_content = read_py(file)

                feedback = analyze_file_with_langchain(file_content, file)
                responses.append({"file": file, "feedback": feedback})
            except Exception as e:
                responses.append({"file": file, "feedback": f"[Error reading file]: {e}"})

        print(f"type of responses is {type(responses)}")
        print(f"responses is {responses}")

        rows = [
            (item.get('file'), _extract_message_content(item.get('feedback', '')), str(item))
            for item in responses
        ]
        if not rows:
            # createDataFrame cannot build anything useful from zero rows.
            print("No files to analyze; nothing written.")
            return

        # Explicit schema: inference fails when llm_feedback is None in every row.
        df = spark.createDataFrame(
            rows, schema="file string, llm_feedback string, llm_feedback_raw string"
        )
        df = df.withColumn("llm_used", lit(DATABRICKS_MODEL_ENDPOINT))

        # To reset the table schema, write with .mode("overwrite")
        # .option("overwriteSchema", "true") instead of appending.
        df.write.format("delta").mode("append").saveAsTable(delta_table)

    except Exception as e:
        print(f"Error: {e}")

    finally:
        if repo_path is not None and os.path.exists(repo_path):
            shutil.rmtree(repo_path)

if __name__ == "__main__":
    main()


In [0]:
%sql

-- drop table bircatalog.birschema.llm_op


In [0]:
%sql
-- Preview stored reviews. llm_feedback keeps newlines as literal backslash-n
-- sequences (it is cut from a stringified message in main()), so turn every
-- escaped "\n" (not just "\n\n" pairs) back into a real line break.
SELECT
  regexp_replace(llm_feedback, '\\\\n', '\n') AS llm_feedback,
  llm_feedback_raw,
  file,
  llm_used
FROM bircatalog.birschema.llm_op