In [5]:
import os
import tempfile

import gradio as gr
import pandas as pd
import requests

# Base URL of the FastAPI service that serves the ICD-10-CM classifier.
# Point this at the real deployment when not running locally.
FASTAPI_URL = "http://localhost:8000"

def classify_text_prompt(text, answer_set):
    """Classify a single piece of text via the FastAPI ``/predict_text`` endpoint.

    Args:
        text: Free text entered by the user.
        answer_set: Answer-set labels selected in the UI; forwarded to the
            service unchanged under the ``"answers"`` key.

    Returns:
        The ``"prediction"`` field of the service's JSON response on success,
        otherwise a human-readable error message (the UI displays whatever is
        returned in a textbox, so errors are returned rather than raised).
    """
    if not text or not text.strip():
        # Nothing to classify; don't make a pointless round-trip.
        return "Please enter text to classify."
    try:
        response = requests.post(
            f"{FASTAPI_URL}/predict_text",
            json={"text": text, "answers": answer_set},
            # (connect, read) seconds: fail fast when the service is unreachable
            # and never leave the UI waiting forever on a stalled request.
            timeout=(10, 120),
        )
    except requests.RequestException as e:
        return f"Error in prediction: could not reach the service ({e})."
    if response.status_code != 200:
        return "Error in prediction."
    try:
        return response.json()["prediction"]
    except (ValueError, KeyError, TypeError):
        # Non-JSON body, or a JSON payload without a "prediction" key.
        return "Error in prediction."

def classify_excel(file, column_name, answer_set):
    """Classify one column of an Excel file via the FastAPI ``/predict_excel`` endpoint.

    Not referenced by the Gradio UI below; the batch tab uses ``classify_file``.

    Args:
        file: Anything ``pandas.read_excel`` accepts (a path or a file object).
        column_name: Column whose values are sent for classification.
        answer_set: Answer-set labels forwarded to the service as ``"answers"``.

    Returns:
        The path of the written ``classified_output.xlsx`` on success,
        otherwise the string ``"Error in classification."``.
    """
    df = pd.read_excel(file)
    # Send every cell as a string: raw values may include NaN (empty cells) or
    # Timestamps, neither of which can be JSON-encoded by requests. This matches
    # how classify_file prepares its payload.
    input_data = df[column_name].astype(str).tolist()
    try:
        response = requests.post(
            f"{FASTAPI_URL}/predict_excel",
            json={"data": input_data, "answers": answer_set},
            # Connect timeout only: a large batch may legitimately take a long time.
            timeout=(10, None),
        )
    except requests.RequestException:
        return "Error in classification."
    if response.status_code != 200:
        return "Error in classification."
    df["ICD-10-CM"] = response.json()["predictions"]
    save_path = "classified_output.xlsx"
    df.to_excel(save_path, index=False)
    return save_path

def retrain_model(new_dataset):
    """Upload a new training dataset to the FastAPI ``/retrain`` endpoint.

    Args:
        new_dataset: The uploaded file, given as a path (``str`` or
            ``os.PathLike``) or as a file object whose ``.name`` is its path;
            ``None`` when nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    if new_dataset is None:
        # Without a file, requests would silently send an empty multipart body.
        return "Please upload a dataset to retrain the model."
    path = new_dataset if isinstance(new_dataset, (str, os.PathLike)) else new_dataset.name
    try:
        # Open the file here so the handle is always closed and a path input
        # uploads the file's contents rather than the path string itself.
        with open(path, "rb") as fh:
            response = requests.post(
                f"{FASTAPI_URL}/retrain",
                files={"file": fh},
                # Connect timeout only: retraining may legitimately take a long time.
                timeout=(10, None),
            )
    except (OSError, requests.RequestException):
        return "Error in retraining model."
    if response.status_code == 200:
        return "Model retrained successfully!"
    return "Error in retraining model."

def load_file_columns(file):
    """Return the column names of an uploaded CSV or Excel file.

    Args:
        file: The uploaded file, given as a path (``str`` or ``os.PathLike``)
            or as a file object whose ``.name`` is its path; ``None`` when
            nothing is uploaded.

    Returns:
        A ``(error, columns)`` tuple: ``(None, [column names])`` on success,
        otherwise ``(error message, None)``.
    """
    if file is None:
        return "No file uploaded.", None
    try:
        path = os.fspath(file if isinstance(file, (str, os.PathLike)) else file.name)
        lowered = path.lower()  # accept upper-case extensions such as ".CSV"
        # Only the header row is needed, so skip parsing the data rows.
        if lowered.endswith(".csv"):
            df = pd.read_csv(path, nrows=0)
        elif lowered.endswith((".xls", ".xlsx")):
            df = pd.read_excel(path, nrows=0)
        else:
            return "Unsupported file type.", None
        return None, list(df.columns)
    except Exception as e:
        # UI boundary: report any read/parse failure as a message, not a crash.
        return f"Error loading file: {e}", None


def classify_file(file, column_name, answer_set):
    """Classify one column of a CSV/Excel file via the FastAPI ``/predict_excel`` endpoint.

    The predictions are added as an ``ICD-10-CM`` column, and the result is
    written in the same format as the input (CSV in -> CSV out, otherwise .xlsx).

    Args:
        file: The uploaded file, given as a path (``str`` or ``os.PathLike``)
            or as a file object whose ``.name`` is its path; ``None`` when
            nothing is uploaded.
        column_name: Column whose values are sent for classification.
        answer_set: Answer-set labels forwarded to the service as ``"answers"``.

    Returns:
        A ``(status message, output path or None)`` tuple for the UI.
    """
    if file is None:
        return "Please upload a CSV or Excel file.", None
    if column_name is None:
        return "Please select a column to classify.", None
    try:
        path = os.fspath(file if isinstance(file, (str, os.PathLike)) else file.name)
        lowered = path.lower()  # accept upper-case extensions such as ".CSV"
        is_csv = lowered.endswith(".csv")
        if is_csv:
            df = pd.read_csv(path)
        elif lowered.endswith((".xls", ".xlsx")):
            df = pd.read_excel(path)
        else:
            return "Unsupported file type.", None

        if column_name not in df.columns:
            return "Selected column not found in the file.", None

        # Stringify so NaN/Timestamp cells remain JSON-encodable.
        input_data = df[column_name].astype(str).tolist()

        response = requests.post(
            f"{FASTAPI_URL}/predict_excel",
            json={"data": input_data, "answers": answer_set},
            # Connect timeout only: a large batch may legitimately take a long time.
            timeout=(10, None),
        )
        if response.status_code != 200:
            return f"Error in classification: {response.text}", None

        df["ICD-10-CM"] = response.json()["predictions"]

        # Write into a fresh private directory so concurrent requests cannot
        # overwrite each other's output; the file name itself is unchanged.
        out_dir = tempfile.mkdtemp(prefix="classified_")
        if is_csv:
            save_path = os.path.join(out_dir, "classified_output.csv")
            df.to_csv(save_path, index=False)
        else:
            save_path = os.path.join(out_dir, "classified_output.xlsx")
            df.to_excel(save_path, index=False)

        return "Classification successful! Download file below.", save_path
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return f"Error processing file: {e}", None

# Create the Gradio Interface
with gr.Blocks() as gradio_app:
    with gr.Tab("Prompt Text Classification"):
        gr.Markdown("### Classify Text to ICD-10-CM")
        text_input = gr.Textbox(label="Enter Text to Classify")
        answer_set = gr.CheckboxGroup(choices=["Set 1", "Set 2", "Set 3"], label="Select Answer Set")
        classify_button = gr.Button("Classify")
        text_output = gr.Textbox(label="Classification Result")
        classify_button.click(classify_text_prompt, inputs=[text_input, answer_set], outputs=text_output)

    with gr.Tab("Batch File Classification"):
        gr.Markdown("### Classify CSV or Excel File to ICD-10-CM")
        file_input = gr.File(label="Upload CSV or Excel File")
        column_selector = gr.Dropdown(choices=[], label="Select Column for Input", interactive=True)
        answer_set_selection = gr.CheckboxGroup(choices=["Set 1", "Set 2", "Set 3"], label="Select Answer Set")
        classify_file_button = gr.Button("Classify File")
        classification_status = gr.Textbox(label="Status")
        file_output = gr.File(label="Download Classified File")

        # Dynamically load columns when a file is uploaded (or cleared)
        def update_columns(file):
            """Return (column dropdown update, status text) for the current upload."""
            if file is None:
                # The upload was cleared: reset the dropdown without an error message.
                return gr.update(choices=[], value=None), ""
            error, columns = load_file_columns(file)
            if error:
                return gr.update(choices=[], value=None), error
            if not columns:
                # Guards the columns[0] default below against a header-less file.
                return gr.update(choices=[], value=None), "No columns found in the file."
            return gr.update(choices=columns, value=columns[0]), ""

        file_input.change(update_columns, inputs=[file_input], outputs=[column_selector, classification_status])

        # Perform classification when button is clicked
        classify_file_button.click(
            classify_file,
            inputs=[file_input, column_selector, answer_set_selection],
            outputs=[classification_status, file_output]
        )

    with gr.Tab("Retrain Model"):
        gr.Markdown("### Retrain the Classification Model")
        new_data_upload = gr.File(label="Upload New Dataset (CSV or Excel)")
        retrain_button = gr.Button("Retrain Model")
        retrain_result = gr.Textbox(label="Retrain Result")
        retrain_button.click(retrain_model, inputs=new_data_upload, outputs=retrain_result)

# Launch the Gradio app
gradio_app.launch()


* Running on local URL:  http://127.0.0.1:7863

To create a public link, set `share=True` in `launch()`.


