# Function: Trim to Relevant Part - Case Type Specific

In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, StructType, StructField

# Return type of the `trim_judgment` UDF: a struct of three nullable strings,
# listed in the same order as the tuple the function returns.
schema = StructType([
    StructField("trimmed_text", StringType(), nullable=True),
    StructField("trimmedType", StringType(), nullable=True),
    StructField("text_around_trimmed_point", StringType(), nullable=True),
])


@udf(returnType=schema)
def trim_judgment(full_text):
    """
    Trim a judgment's text down to its tail to save LLM tokens.

    If "查明" occurs in the text, everything up to and including its LAST
    occurrence is dropped and the remainder is returned. Otherwise the first
    third of the text is dropped. The trimming method and a short window of
    text around the cut are returned as well, for diagnosis.

    Parameters:
    - full_text: The judgment text. Anything that is not a ``str`` (``None``,
      or e.g. the float NaN pandas produces for an empty cell) is reported as
      an error instead of raising.

    Returns a 3-tuple:
    - trimmed_text (str or None): The text kept after trimming; ``None`` on
      invalid input.
    - trimmedType (str): '查明', 'last 2/3', or 'ERROR' for invalid input.
    - text_around_trimmed_point (str): Up to 10 characters before the cut
      plus up to 10 after it (in the '查明' case the keyword itself sits in
      between, so up to 22 characters); '' on invalid input.
    """
    # Non-string input cannot be searched or sliced; report it, don't crash.
    if not isinstance(full_text, str):
        return None, "ERROR", ""

    keyword = "查明"
    context_chars = 10
    text_length = len(full_text)

    # rfind: search from the end so only the last occurrence counts.
    keyword_pos = full_text.rfind(keyword)

    if keyword_pos != -1:
        # Keep what follows the keyword; the keyword itself is excluded.
        cut_pos = keyword_pos + len(keyword)
        trimmed_type = keyword
        window_start = max(0, keyword_pos - context_chars)
        window_end = min(text_length, cut_pos + context_chars)
    else:
        # No keyword: drop the first third (integer division) of the text.
        cut_pos = text_length // 3
        trimmed_type = "last 2/3"
        window_start = max(0, cut_pos - context_chars)
        window_end = min(text_length, cut_pos + context_chars)

    trimmed_text = full_text[cut_pos:]
    text_around_trimmed_point = full_text[window_start:window_end]
    return trimmed_text, trimmed_type, text_around_trimmed_point

# Function: Post Request and fetch Claude Response

In [0]:
import time
import requests
import json
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, StringType

# Struct returned per row by the fetch step: the model's reply followed by the
# two trimming diagnostics, every field a nullable string.
schema = StructType(
    [
        StructField(field_name, StringType(), nullable=True)
        for field_name in (
            "response_text",
            "trimmedType",
            "text_around_trimmed_point",
        )
    ]
)

# The API key is looked up through dbutils.secrets instead of being hard-coded.
api_key = dbutils.secrets.get(scope="OhMyGPTAPI", key="OhMyGPTAPI")

#@udf(returnType=schema)

def trim_and_fetch_facts(judgment_text):
    """
    Trim a judgment and ask the Claude model for the drugs and amounts in it.

    The text is first shortened with `trim_judgment`, then POSTed to the
    messages endpoint together with a system prompt that asks for the kind
    and amount of each drug. The request is tried up to three times with
    exponential back-off (1s, 2s) between attempts.

    Parameters:
    - judgment_text: Full text of one judgment.

    Returns a 3-tuple:
    - response_text: The model's reply. ``None`` when there was no text to
      send (invalid or empty input), in which case no request is made.
    - trimmedType: The trimming method reported by `trim_judgment`.
    - text_around_trimmed_point: Context around the cut, from `trim_judgment`.
    If every attempt fails, returns
    ("Failed to fetch facts from API", "Error", "N/A").
    """
    # `trim_judgment` is decorated with pyspark's @udf; calling the wrapper
    # builds a Column expression instead of running the Python body. The
    # wrapper keeps the undecorated function on `.func`, so use that when it
    # exists and fall back to the name itself when it is a plain function.
    trim = getattr(trim_judgment, "func", trim_judgment)
    trimmed_judgment_text, trimmedType, text_around_trimmed_point = trim(judgment_text)

    # Nothing to send: skip the API call (and its retries) entirely.
    if not trimmed_judgment_text:
        return None, trimmedType, text_around_trimmed_point

    url = "https://api.ohmygpt.com/v1/messages"
    payload = json.dumps({
        "model": "claude-3-haiku-20240307",
        "stream": False,
        "system": "user will give you a legal judgement in Chinese, and you need to extract the kind and amount of drugs involved in the case. only reply using this format, '<kind of drug A>, <amount (in grams)>; <kind of drug B>, <amount(in grams)>'.  If you can't extract the kind of drug or amount. reply 'NA'. reply using Chinese.",
        "messages": [{"role": "user", "content": trimmed_judgment_text}],
        "max_tokens": 4096
    })
    headers = {
        "Authorization": 'Bearer ' + api_key,
        'Content-Type': 'application/json'
    }

    max_retries = 3
    request_timeout_seconds = 60  # never let one hung connection stall the run
    for attempt in range(max_retries):
        try:
            response = requests.post(
                url, headers=headers, data=payload, timeout=request_timeout_seconds
            )
            # Turn HTTP error statuses into a RequestException so they are
            # logged with their status and retried like other failures.
            response.raise_for_status()
            # Parse the response JSON
            parsed_response = json.loads(response.text)
            # Check if the expected keys/path exists
            if "content" in parsed_response and len(parsed_response["content"]) > 0 and "text" in parsed_response["content"][0]:
                response_text = parsed_response["content"][0]["text"]
            else:
                raise KeyError("Unexpected response structure")
            return response_text, trimmedType, text_around_trimmed_point
        except (requests.exceptions.RequestException, json.JSONDecodeError, KeyError) as e:
            print(f"Error on attempt {attempt+1}: {e}")
            # Back off before the next attempt; no point waiting after the last.
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)

    # Default values in case of failure
    return "Failed to fetch facts from API", "Error", "N/A"



# Read CSV files and enrich them on Spark

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Full run on Spark: read every CSV of each cause of action, keep first-instance
# ('一审') cases, enrich them with the model's answer and write the result.
from pyspark.sql.functions import udf

# Initialize Spark session
spark = SparkSession.builder.appName("ExtractFacts").getOrCreate()

# Base path where the data is saved
base_path = "/mnt/processed_data_criminal_case_analysis"

# Define the set of Causes of Action to filter and sample
causes_of_action = {"drug_related"}

# Initialize a dictionary to hold the counts
cases_count = {}

# `trim_and_fetch_facts` is a plain Python function (it is also applied to a
# pandas Series elsewhere in this file), so it cannot be called on a Column
# directly; wrap it as a UDF that returns the `schema` struct.
# NOTE(review): the UDF issues one HTTP request per row from the executors and
# is serialized together with the globals it references — confirm this works
# on the target cluster.
fetch_facts_udf = udf(trim_and_fetch_facts, returnType=schema)

# Iterate over each cause of action to read the saved DataFrames, filter, and count the cases
for cause in causes_of_action:
    path = f"{base_path}/{cause}/*/*.csv"

    # Read the saved data
    df = spark.read.csv(path, header=True, inferSchema=True)

    # Filter to get rows where TrialProcedure contains '一审'
    df_filtered = df.filter(col("TrialProcedure").contains("一审"))

    # Apply the UDF to the DataFrame; the result is one struct column
    df_processed = df_filtered.withColumn("processed", fetch_facts_udf(col("FullText")))

    # Flatten the struct into three top-level columns
    df_final = df_processed.select(
        "*",  # Keep existing columns
        col("processed.response_text").alias("response_text"),
        col("processed.trimmedType").alias("trimmedType"),
        col("processed.text_around_trimmed_point").alias("text_around_trimmed_point")
    ).drop("processed")  # Drop the 'processed' struct column

    # Write the enriched DataFrame to CSV, replacing any previous output
    output_path = f"{base_path}/{cause}_DrugTypeAmount_March_15"
    df_final.write.csv(output_path, mode="overwrite", header=True)

    # Count the filtered cases (not the enriched frame, so the UDF is not re-run)
    count = df_filtered.count()
    cases_count[cause] = count

# Print the counts for each cause of action
for cause, count in cases_count.items():
    print(f"{cause} modified case count: {count}")


## Test Run: read a small csv file - on Spark

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Test run on Spark: same pipeline as the full run, restricted to the CSV files
# of one folder (2015_12).
from pyspark.sql.functions import udf

# Initialize Spark session
spark = SparkSession.builder.appName("CountCases").getOrCreate()

# Base path where the data is saved
base_path = "/mnt/processed_data_criminal_case_analysis"

# Define the set of Causes of Action to filter and sample
causes_of_action = {"drug_related"}

# Initialize a dictionary to hold the counts
cases_count = {}

# `trim_and_fetch_facts` is a plain Python function, so it cannot be called on
# a Column directly; wrap it as a UDF that returns the `schema` struct.
# NOTE(review): the UDF issues one HTTP request per row from the executors and
# is serialized together with the globals it references — confirm this works
# on the target cluster.
fetch_facts_udf = udf(trim_and_fetch_facts, returnType=schema)

# Iterate over each cause of action to read the saved DataFrames, filter, and count the cases
for cause in causes_of_action:
    # Construct the path to read the saved files for the current cause of action
    path = f"{base_path}/{cause}/2015_12_drug_related_judgment_data/*.csv"

    # Read the saved data
    df = spark.read.csv(path, header=True, inferSchema=True)

    # Filter to get rows where TrialProcedure contains '一审'
    df_filtered = df.filter(col("TrialProcedure").contains("一审"))

    # Apply the UDF to the DataFrame; the result is one struct column
    df_processed = df_filtered.withColumn("processed", fetch_facts_udf(col("FullText")))

    # Flatten the struct into three top-level columns
    df_final = df_processed.select(
        "*",  # Keep existing columns
        "processed.response_text",
        "processed.trimmedType",
        "processed.text_around_trimmed_point"
    ).drop("processed")  # Drop the 'processed' struct column

    # NOTE(review): this is the same folder the full run writes to, and
    # mode="overwrite" replaces its contents with this one-folder sample.
    output_path = f"{base_path}/{cause}_DrugTypeAmount_March_15"
    try:
        df_final.write.csv(output_path, mode="overwrite", header=True)
    except Exception as e:
        # Best effort for the test run: report the failure and still count.
        print(f"Error writing to output path: {e}")

    # Count the filtered cases (not the enriched frame, so the UDF is not re-run)
    count = df_filtered.count()
    cases_count[cause] = count

# Print the counts for each cause of action
for cause, count in cases_count.items():
    print(f"{cause} modified case count: {count}")


## Test Run: Move the Data Enrichment out of Spark

In [0]:
import os
import pandas as pd

def _to_local_dbfs_path(path):
    """Map a DBFS path ('dbfs:/mnt/x' or '/mnt/x') to the '/dbfs/mnt/x' form used for local file IO."""
    if path.startswith("dbfs:"):
        path = path[len("dbfs:"):]
    if path.startswith("/dbfs/"):
        return path
    return "/dbfs/" + path.lstrip("/")


def process_files(base_path, cause):
    """
    Enrich every CSV file of one cause of action with pandas, outside Spark.

    Each sub-directory of ``<base_path>/<cause>`` is listed with dbutils; for
    every ``.csv`` file in it, the rows whose 'TrialProcedure' contains '一审'
    are kept, `trim_and_fetch_facts` is applied to their 'FullText', and the
    rows plus three new columns ('ResponseText', 'TrimmedType',
    'TextAroundTrimmedPoint') are written to
    ``<base_path>/<cause>_DrugTypeAmount_March_15/<original file name>``.

    Parameters:
    - base_path (str): Root folder of the processed data, as passed to dbutils.
    - cause (str): Name of the cause-of-action sub-folder.

    Returns:
    - None. Results are written to disk; progress is printed.
    """
    # pandas does local file IO, so the output folder must be addressed in
    # the same "/dbfs/..." form the input files are converted to, and it has
    # to exist before to_csv is called.
    output_folder = _to_local_dbfs_path(f"{base_path}/{cause}_DrugTypeAmount_March_15")
    os.makedirs(output_folder, exist_ok=True)

    # Folder holding one sub-directory per batch of CSV files
    cause_folder = os.path.join(base_path, cause)

    for entry in dbutils.fs.ls(cause_folder):
        # Only sub-directories hold the CSV files
        if not entry.isDir():
            continue

        for sub_file in dbutils.fs.ls(entry.path):
            if not sub_file.name.endswith(".csv"):
                continue

            # dbutils reports "dbfs:/..." paths; pandas needs the local form
            file_path = _to_local_dbfs_path(sub_file.path)
            print(file_path)
            df = pd.read_csv(file_path)

            # Filter rows where 'TrialProcedure' column contains '一审'
            df_filtered = df[df['TrialProcedure'].str.contains('一审', na=False)]

            # Check for emptiness before enriching, not after
            if df_filtered.empty:
                print(f"No data after filtering for {sub_file.name}. Moving to the next file.")
                continue

            # One (response, trim type, context) tuple per row
            enriched_data = df_filtered['FullText'].apply(trim_and_fetch_facts)

            # Sanity check on the first result before expanding into columns
            if not isinstance(enriched_data.iloc[0], tuple) or len(enriched_data.iloc[0]) != 3:
                print(f"Data structure mismatch in file: {file_path}")
                continue

            # Expand the tuples into separate columns, aligned on the row index
            df_expanded = pd.DataFrame(
                enriched_data.tolist(),
                columns=['ResponseText', 'TrimmedType', 'TextAroundTrimmedPoint'],
                index=df_filtered.index,
            )
            df_enriched = pd.concat([df_filtered, df_expanded], axis=1)

            # The name already ends in ".csv"; reuse it as-is
            output_file_path = os.path.join(output_folder, os.path.basename(sub_file.name))
            df_enriched.to_csv(output_file_path, index=False)

# Example usage: enrich the drug-related CSV files under the data folder
base_path, cause = "/mnt/processed_data_criminal_case_analysis", "drug_related"
process_files(base_path=base_path, cause=cause)
