In [0]:
from databricks.sdk.service.workspace import ImportFormat
from pyspark.sql.functions import trim 

import base64

import databricks.sdk

import nbformat
import numpy as np 

import pandas as pd
import dspy

In [0]:
class Module(dspy.Module):
  """Run a set of DSPy question modules against a workspace notebook.

  The notebook is exported through the workspace client, flattened into one
  context string, and passed to every module in `module_dict`. The per-module
  responses are collected into a pandas DataFrame.

  Attributes:
    workspace_client: Client whose `workspace.export(...)` is used to fetch
      notebooks (presumably a `databricks.sdk.WorkspaceClient` -- confirm
      against the caller).
    module_dict: Mapping of question key -> module, or None until
      `set_module_dict` has been called with a supported name.
    extract_code: When True only code cells are used as context and the
      answers carry `reasoning`/`code_snippet`/`question`; when False markdown
      cells are included too and answers also carry `score`/`explanation`.
    question_dict: Question text per module key, used to look up few-shot
      examples by exact match.
  """

  def __init__(self, workspace_client, extract_code=False):
    """Store the workspace client and the extraction mode.

    Args:
      workspace_client: Object exposing `workspace.export(path, format=...)`.
      extract_code: See the class docstring.
    """
    # Let dspy.Module initialise its own state before we attach attributes.
    super().__init__()
    self.workspace_client = workspace_client
    self.module_dict = None
    self.extract_code = extract_code
    # These strings are compared for exact equality against the `question`
    # column of the few-shot table (see get_training_examples), so they must
    # not be reworded or re-wrapped.
    self.question_dict = {
      'module_1': "Extract the exact code snippet from the context that creates a table/view called 'workloads' and retrieves its schema. Do not generate any code.",
      'module_2': "Extract the exact code snippet from the context that returns only a distinct list of workspaceId. Do not generate any code.",
      'module_3': "Extract the exact code snippet from the context that returns the number of unique clusters. Do not generate, modify, or infer any code.",
      # The first two lines of this prompt end with two spaces before the
      # newline. They are spelled out explicitly so that whitespace-stripping
      # editors cannot silently change the string.
      'module_4': (
        "Extract the exact code snippet from the context that returns the workload hours each day for the workspace ID in ordered fashion.  \n"
        "**Do not generate, modify, or suggest any code. Only extract what is explicitly present.**  \n"
        "**If no matching snippet is found, return an empty string for `code_snippet` and assign a score of 0.**"
      ),
      'module_5': "Extract the exact code snippet from the context that returns interactive node hours per day on the different Spark versions over time. Do not generate any code",
      'module_6': "Extract the exact code snippet from the context which returns top two most recently shipped (shipDate) Line Items per Part using window function. Do not generate any code.",
    }

  def get_module_dict(self):
    """Return the current module mapping (None if it has not been set)."""
    return self.module_dict

  def set_module_dict(self, module_name):
    """Load the module mapping for `module_name`.

    Only "SQL" is recognised; any other value leaves `module_dict` unchanged.
    `SQLModule` is not defined in this class and must be available in the
    enclosing scope.
    """
    if module_name == "SQL":
      self.module_dict = SQLModule().get_module_dict()

  def get_code_from_notebooks(self, notebook_path):
    """Export a notebook and return its cell sources as one string.

    Args:
      notebook_path: Workspace path of the notebook to export.

    Returns:
      The sources of the selected cells, in notebook order, separated by
      newlines. Only code cells are selected when `extract_code` is True;
      code and markdown cells otherwise.
    """
    # NOTE(review): `format` is given an ImportFormat member; the export API
    # presumably expects ExportFormat (same "JUPYTER" value) -- confirm.
    exported = self.workspace_client.workspace.export(notebook_path,
                                                      format=ImportFormat.JUPYTER)
    # The exported content is base64 text wrapping the UTF-8 .ipynb JSON.
    ipynb = base64.decodebytes(exported.content.encode('ascii')).decode("utf-8")
    notebook = nbformat.reads(ipynb, as_version=4)

    wanted_types = ('code',) if self.extract_code else ('code', 'markdown')
    sources = [cell.source for cell in notebook.cells
               if cell.cell_type in wanted_types]

    # Join with a newline so the last line of one cell is not fused with the
    # first line of the next; a bare ''.join would corrupt the very snippets
    # the prompts ask the model to extract verbatim.
    return '\n'.join(sources)

  # For few shot learning
  def get_training_examples(self, table_name, key):
    """Build few-shot examples for one question from a Spark table.

    Relies on a `spark` session being available in the enclosing scope.

    Args:
      table_name: Table with at least `question`, `context`, `score` and
        `code_snippet` columns (these are the columns read below).
      key: Key into `question_dict` selecting the question to match.

    Returns:
      A list of `dspy.Example` objects whose input field is `text`.
    """
    table = spark.read.table(table_name)
    matching = table.filter(trim(table.question) == self.question_dict[key])
    return [
      dspy.Example(text=row.context,
                   score=row.score,
                   code_snippet=row.code_snippet).with_inputs("text")
      for row in matching.collect()
    ]

  def evaluate_responses(self, context, few_shots_table_name):
    """Run every module in `module_dict` on `context` and tabulate the answers.

    `set_module_dict` must have populated `module_dict` beforehand; while it
    is still None this method raises AttributeError.

    Args:
      context: Notebook text passed to each module as `context`.
      few_shots_table_name: Optional table name; when truthy, matching
        examples are loaded and assigned as demos before each module runs.

    Returns:
      A pandas DataFrame with one row per module.
    """
    answer_list = []
    for key, module in self.module_dict.items():
      if few_shots_table_name:
        module.generate_answer.demos = self.get_training_examples(
          few_shots_table_name, key)

      response = module(context=context)
      # The question text is read back from the signature held by the
      # module's inner predictor.
      signature = module.generate_answer.__dict__['predict'].__dict__['signature']
      if self.extract_code:
        response_dict = {
          'reasoning': response.reasoning,
          'code_snippet': response.code_snippet,
          'question': signature.__fields__['code_snippet'].json_schema_extra['desc'],
        }
      else:
        response_dict = {
          'question': signature.instructions,
          'code_snippet': response.code_snippet,
          'score': response.score,
          'explanation': response.explanation,
          'chain_of_thought_reasoning': response.reasoning,
        }
      answer_list.append(response_dict)

    return pd.DataFrame(answer_list)

  def get_error_and_answer_dict(self, context_path, few_shots_table_name=None):
    """Fetch a notebook, evaluate it, and return the context with the answers.

    Args:
      context_path: Workspace path of the notebook to evaluate.
      few_shots_table_name: Optional few-shot table, forwarded to
        `evaluate_responses`.

    Returns:
      A dict with `context` (the notebook text) and `answers_dict` (the
      answers DataFrame, with a `context_url` column added and `score` cast
      to float64 when not in extract-code mode).
    """
    context = self.get_code_from_notebooks(context_path)

    answers = self.evaluate_responses(context, few_shots_table_name)
    answers['context_url'] = context_path
    # An empty module_dict yields a frame without a `score` column; skip the
    # cast instead of failing with KeyError.
    if not self.extract_code and 'score' in answers.columns:
      answers['score'] = answers['score'].astype('float64')
    return {
      'context': context,
      'answers_dict': answers
    }