In [0]:
import requests
import json
import re
import base64

class DatabricksClient:
    """Small wrapper around the Databricks REST API.

    It finds failed job runs, collects task errors, asks a model serving
    endpoint for a one-line fix, and patches the affected notebook.

    Read requests go through ``_handle_response``, which raises on any
    non-200 status. The notebook import in ``update_notebook`` reports
    failures by printing instead.
    """

    def __init__(self, instance, token):
        """Create a client.

        Args:
            instance: Workspace base URL. A trailing slash is stripped because
                every endpoint path below already starts with "/".
            token: Access token sent as a bearer token on every request.
        """
        self.instance = instance.rstrip("/")
        self.token = token
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def _get(self, path, params):
        """GET ``path`` (relative to the workspace URL) and return the decoded JSON body."""
        response = requests.get(f"{self.instance}{path}", headers=self.headers, params=params)
        return self._handle_response(response)

    def get_job_details(self, job_id):
        """Return the JSON description of ``job_id`` from ``/api/2.1/jobs/get``."""
        return self._get("/api/2.1/jobs/get", {"job_id": job_id})

    def get_runs(self, job_id):
        """Return completed JOB_RUN runs of ``job_id`` (the ``runs`` list, or [] if absent).

        NOTE(review): only the single response page is read; no page token is
        followed, so older runs may be missing if the endpoint paginates.
        """
        params = {"job_id": job_id, "active_only": "false", "completed_only": "true", "run_type": "JOB_RUN"}
        return self._get("/api/2.1/jobs/runs/list", params).get("runs", [])

    def get_failed_runs(self, job_id):
        """Return the runs from ``get_runs`` whose ``state.result_state`` is ``FAILED``."""
        runs = self.get_runs(job_id)
        return [run for run in runs if run.get("state", {}).get("result_state") == "FAILED"]

    def get_task_run_output(self, run_id):
        """Collect the error reported by each failed task of a run.

        Args:
            run_id: Id of the parent job run.

        Returns:
            A list of ``(error, notebook_path)`` tuples, one per failed task
            whose output contains an ``error``. ``notebook_path`` is ``None``
            for tasks without a ``notebook_task``.
        """
        run_details = self._get("/api/2.1/jobs/runs/get", {"run_id": run_id})
        error_messages = []
        for task in run_details.get("tasks", []):
            if task.get("state", {}).get("result_state") != "FAILED":
                continue
            notebook_path = task.get("notebook_task", {}).get("notebook_path")
            # Output is fetched with the task's own run id, not the parent run id.
            task_output = self._get("/api/2.1/jobs/runs/get-output", {"run_id": task.get("run_id")})
            error = task_output.get("error")
            if error:
                error_messages.append((error, notebook_path))
        return error_messages

    @staticmethod
    def _undefined_name(error_message):
        """Return the name quoted in ``error_message``.

        Prefers the ``name '<x>' is not defined`` form of a NameError and falls
        back to the first single-quoted token.

        Raises:
            ValueError: If the message contains no quoted token.
        """
        match = re.search(r"name '([^']+)' is not defined", error_message) or re.search(
            r"'([^']*)'", error_message
        )
        if match is None:
            raise ValueError(f"No quoted name found in error message: {error_message!r}")
        return match.group(1)

    def suggest_fix(self, job_id, notebook_path, error_message):
        """Ask the model serving endpoint for a one-line fix to a NameError.

        Args:
            job_id: Unused; kept for interface compatibility.
            notebook_path: Workspace path of the failing notebook. When given,
                its code is appended to the prompt so the model has context.
            error_message: Error text from the failed task.

        Returns:
            The decoded JSON response of the serving endpoint.

        Raises:
            ValueError: If no undefined name can be extracted from ``error_message``.
            Exception: On a non-200 response (see ``_handle_response``).
        """
        endpoint = "/serving-endpoints/databricks-meta-llama-3-1-8b-instruct/invocations"
        variable_name = self._undefined_name(error_message)
        prompt = (
            f"The following code is causing a NameError: name '{variable_name}' is not defined.\n"
            "Please provide a single line of code that fixes this error.\n"
        )
        if notebook_path:
            # The prompt refers to "the following code"; without it the model
            # can only guess a value for the name.
            prompt += "\n" + self.get_notebook_code(notebook_path) + "\n"
        payload = {"messages": [{"role": "user", "content": prompt}]}

        response = requests.post(
            f"{self.instance}{endpoint}",
            headers=self.headers,
            data=json.dumps(payload),
        )
        return self._handle_response(response)

    def extract_suggested_fix_code(self, suggested_fix):
        """Return the first line containing "=" from a chat-completion response.

        Surrounding whitespace and backticks are stripped from each line, so a
        markdown fence marker is skipped and an inline-code line like
        `x = 1` yields "x = 1".

        Returns:
            The stripped line, or "" when no line contains "=".
        """
        suggested_fix_text = suggested_fix["choices"][0]["message"]["content"]
        for line in suggested_fix_text.splitlines():
            candidate = line.strip().strip("`").strip()
            if "=" in candidate:
                return candidate
        return ""

    def _export_notebook_json(self, notebook_path):
        """Export ``notebook_path`` in JUPYTER format and return the parsed ipynb dict."""
        exported = self._get("/api/2.0/workspace/export", {"path": notebook_path, "format": "JUPYTER"})
        # The export body carries the notebook base64-encoded under "content".
        return json.loads(base64.b64decode(exported["content"]).decode("utf-8"))

    def get_notebook_code(self, notebook_path):
        """Return the source of every code cell in the notebook, joined by newlines."""
        notebook_json = self._export_notebook_json(notebook_path)
        return "\n".join(
            "".join(cell["source"]) for cell in notebook_json["cells"] if cell["cell_type"] == "code"
        )

    def update_notebook_code(self, notebook_code, suggested_fix_code):
        """Return ``notebook_code`` with ``suggested_fix_code`` appended on a new line."""
        return notebook_code + "\n" + suggested_fix_code

    @staticmethod
    def _apply_fix_to_cells(cells, fix_code, error_variable):
        """Replace, in place, the first code line that references ``error_variable``.

        The name is matched as a whole word, so ``x`` does not match ``xyz``.
        The replaced line keeps its original indentation. At most one line in
        the whole notebook is changed.

        Returns:
            True if a line was replaced, False otherwise.
        """
        if not error_variable:
            return False
        pattern = re.compile(rf"\b{re.escape(error_variable)}\b")
        for cell in cells:
            if cell.get("cell_type") != "code":
                continue
            source = cell.get("source", [])
            # ipynb allows "source" to be one string or a list of lines;
            # iterating a string directly would walk characters.
            lines = source.splitlines(keepends=True) if isinstance(source, str) else list(source)
            for index, line in enumerate(lines):
                if pattern.search(line):
                    indent = line[: len(line) - len(line.lstrip())]
                    lines[index] = f"{indent}{fix_code}\n"
                    cell["source"] = lines
                    return True
        return False

    def update_notebook(self, notebook_path, updated_notebook_code, error_variable):
        """Patch a notebook with a one-line fix and re-import it over the original.

        The first code line referencing ``error_variable`` is replaced by
        ``updated_notebook_code`` (see ``_apply_fix_to_cells``). If no line
        matches, nothing is uploaded.

        NOTE(review): the matched line is replaced rather than preceded by the
        fix, so the original statement is dropped. Confirm that is intended.

        Import failures are printed, not raised.
        """
        notebook_json = self._export_notebook_json(notebook_path)
        if not self._apply_fix_to_cells(notebook_json["cells"], updated_notebook_code, error_variable):
            print(f"No line referencing '{error_variable}' found in {notebook_path}; notebook not updated")
            return

        encoded_content = base64.b64encode(json.dumps(notebook_json).encode()).decode()
        # TLS verification stays on: this request carries the bearer token.
        response = requests.post(
            f"{self.instance}/api/2.0/workspace/import",
            headers=self.headers,
            json={"path": notebook_path, "format": "JUPYTER", "content": encoded_content, "language": "PYTHON", "overwrite": True},
        )
        if response.status_code != 200:
            print(f"Error updating notebook: {response.text}")
        else:
            print("Notebook updated successfully")

    def _handle_response(self, response, return_json=True):
        """Raise on a failed request; otherwise optionally decode the body.

        Args:
            response: A ``requests`` response object.
            return_json: When true, parse and return the JSON body.

        Returns:
            The decoded JSON body if ``return_json`` is true, else ``None``.

        Raises:
            Exception: If the status code is not 200 (message includes the
                body), or the body is not valid JSON.
        """
        print(f"Response status code: {response.status_code}")
        if response.status_code != 200:
            # A non-200 response may still have a JSON body; it is an error,
            # not data, and must not be returned to the caller.
            print(f"Error: {response.text}")
            raise Exception(f"Error: {response.text}")
        if not return_json:
            return None
        try:
            return response.json()
        except ValueError as e:
            # json.JSONDecodeError (and requests' variant) subclass ValueError.
            raise Exception(f"Error parsing response: {e}") from e

def main():
    """Patch notebooks whose job runs failed with a NameError.

    The function finds failed runs of one job and asks the model endpoint for
    a one-line fix to each NameError. It patches each affected notebook at
    most once.

    Configuration comes from the environment:
        DATABRICKS_HOST    workspace URL (defaults to the workspace below)
        DATABRICKS_TOKEN   access token (required)
        DATABRICKS_JOB_ID  job to inspect (defaults to the job below)

    Raises:
        RuntimeError: If DATABRICKS_TOKEN is not set.
    """
    import os

    # A token used to be hard-coded here. Any token committed to source should
    # be treated as leaked and revoked; credentials now come from the environment.
    instance = os.environ.get("DATABRICKS_HOST", "https://dbc-576eaed4-f615.cloud.databricks.com/").rstrip("/")
    token = os.environ.get("DATABRICKS_TOKEN")
    if not token:
        raise RuntimeError("Set DATABRICKS_TOKEN to a Databricks access token.")
    job_id = int(os.environ.get("DATABRICKS_JOB_ID", 559672901702591))

    client = DatabricksClient(instance, token)
    updated_notebooks = set()

    for run in client.get_failed_runs(job_id):
        run_id = run["run_id"]
        try:
            error_messages = client.get_task_run_output(run_id)
        except Exception as e:
            # One unreadable run should not abort processing of the others.
            print(f"Error fetching output for run {run_id}: {e}")
            continue

        for error_message, notebook_path in error_messages:
            # Skip non-notebook tasks and notebooks already patched in this pass.
            if not notebook_path or notebook_path in updated_notebooks:
                continue
            print(f"Error message: {error_message}")
            print(f"Notebook path: {notebook_path}")

            # The prompt and the patch logic are NameError-specific; other
            # errors would produce a misleading prompt.
            match = re.search(r"name '([^']+)' is not defined", error_message)
            if match is None:
                print(f"Not a NameError, skipping: {error_message}")
                continue
            variable_name = match.group(1)

            try:
                fix_suggestion = client.suggest_fix(job_id, notebook_path, error_message)
                suggested_fix_code = client.extract_suggested_fix_code(fix_suggestion)
                print(f"Suggested fix code: {suggested_fix_code}")
                if not suggested_fix_code:
                    # Applying an empty fix would overwrite the offending line with a blank one.
                    print(f"No usable fix suggested for {notebook_path}; skipping")
                    continue
                client.update_notebook(notebook_path, suggested_fix_code, variable_name)
                updated_notebooks.add(notebook_path)
            except Exception as e:
                print(f"Error processing error message: {error_message}. Error: {str(e)}")

# Run the repair pass only when executed as a script (or as a notebook cell,
# where __name__ is also "__main__"), not when this module is imported.
if __name__ == "__main__":
    main()