In [1]:
# Import modules from the project structure
import src.etl.data_aggregation as da

In [2]:
import inspect

# Fetch the source text of the aggregation function via inspect and echo it.
# The same string is later handed to ollama_chat_prompt as the value for the
# prompt's {source_code} placeholder.
source_code = inspect.getsource(da.calculate_total_price_by_date)
print(source_code)

def calculate_total_price_by_date(df_products, df_transactions):
    """Aggregate quantity and price totals per transaction date.

    Transactions are joined to products where the Product_ID values match
    (the join type is whatever ``DataFrame.join`` uses by default), a
    per-row ``Total_Price`` of ``Quantity * Price`` is added, and the rows
    are grouped by ``Transaction_Date``.

    Returns a DataFrame with ``Transaction_Date``, ``Total_Quantity`` (sum
    of ``Quantity``) and ``Total_Price`` (sum of the per-row totals).
    """
    same_product = df_transactions.Product_ID == df_products.Product_ID
    row_total = F.col("Quantity") * F.col("Price")

    return (
        df_transactions.join(df_products, same_product)
        .withColumn("Total_Price", row_total)
        .groupBy("Transaction_Date")
        .agg(
            F.sum("Quantity").alias("Total_Quantity"),
            F.sum("Total_Price").alias("Total_Price"),
        )
    )



In [3]:
# Prompt template for unit test generation.
# Adjacent string literals are concatenated verbatim, so every numbered
# instruction ends with an explicit newline; otherwise the model receives the
# list as one run-on sentence. The source is placed on its own lines rather
# than mid-sentence. {source_code} is the only placeholder.
template_unit = (
    "You are tasked with creating unit tests for PySpark functions.\n"
    "Please follow these instructions carefully:\n\n"
    "1. **Source code:** This is the source code:\n\n"
    "{source_code}\n\n"
    "2. **Unit test:** Create a set of unit tests to validate random scenarios for the provided function.\n"
    "3. **Test code:** It should use the unittest library and be ready to run.\n"
    "4. **Mocked data:** Define schemas and the data types for the mocked data to avoid PySpark errors when inferring schema.\n"
)

In [4]:
# Print the retrieved source text again for reference.
print(source_code)

def calculate_total_price_by_date(df_products, df_transactions):
    """Aggregate quantity and price totals per transaction date.

    Transactions are joined to products where the Product_ID values match
    (the join type is whatever ``DataFrame.join`` uses by default), a
    per-row ``Total_Price`` of ``Quantity * Price`` is added, and the rows
    are grouped by ``Transaction_Date``.

    Returns a DataFrame with ``Transaction_Date``, ``Total_Quantity`` (sum
    of ``Quantity``) and ``Total_Price`` (sum of the per-row totals).
    """
    same_product = df_transactions.Product_ID == df_products.Product_ID
    row_total = F.col("Quantity") * F.col("Price")

    return (
        df_transactions.join(df_products, same_product)
        .withColumn("Total_Price", row_total)
        .groupBy("Transaction_Date")
        .agg(
            F.sum("Quantity").alias("Total_Quantity"),
            F.sum("Total_Price").alias("Total_Price"),
        )
    )



In [5]:
from langchain_ollama import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate

# Ollama model names that can be passed as `modelname`: "qwen2.5", "codellama"

def ollama_chat_prompt(source_code, template, modelname="qwen2.5"):
    """Fill a prompt template with source code and run it through an Ollama model.

    Args:
        source_code: Value substituted for the template's ``source_code``
            variable.
        template: Prompt template text, parsed with
            ``ChatPromptTemplate.from_template``.
        modelname: Model name forwarded to ``OllamaLLM``.

    Returns:
        Whatever invoking the prompt-to-model chain returns (the notebook
        writes it straight to a text file, so presumably a string).
    """
    llm = OllamaLLM(model=modelname)
    pipeline = ChatPromptTemplate.from_template(template) | llm
    return pipeline.invoke({"source_code": source_code})

In [6]:
# Ask the default model to generate unit tests for the inspected function.
result = ollama_chat_prompt(source_code=source_code, template=template_unit)

In [7]:
# Persist the generated unit tests. UTF-8 is requested explicitly: without an
# encoding argument open() falls back to the platform locale, which can fail
# on characters in the model output that the locale cannot represent.
with open("result_unit_tests.txt", "w", encoding="utf-8") as file:
    file.write(result)

In [8]:
# Prompt template for data quality check generation.
# Adjacent string literals are concatenated verbatim, so each numbered
# instruction ends with an explicit newline to keep the list readable for the
# model. The source is placed on its own lines rather than mid-sentence.
# {source_code} is the only placeholder.
template_quality = (
    "You are tasked with creating data quality checks for PySpark functions.\n"
    "Please follow these instructions carefully:\n\n"
    "1. **Source code:** This is the source code:\n\n"
    "{source_code}\n\n"
    "2. **Data Quality Checks:** Create a set of data quality checks to validate random scenarios for the provided function.\n"
)

In [9]:
# Ask the default model to generate data quality checks for the inspected function.
result = ollama_chat_prompt(source_code=source_code, template=template_quality)

In [10]:
# Persist the generated data quality checks. UTF-8 is requested explicitly:
# without an encoding argument open() falls back to the platform locale, which
# can fail on characters in the model output that the locale cannot represent.
with open("result_quality.txt", "w", encoding="utf-8") as file:
    file.write(result)