In [2]:
import io
import requests
import ipywidgets as widgets
from IPython.display import display

# Define widgets
# Single audio-file picker; its 'value' trait is observed by on_upload_change.
file_upload = widgets.FileUpload(
    accept='audio/*',  # Accept audio files
    multiple=False  # Do not allow multiple file uploads
)

# Each button gets its own auto-width layout so the label is shown in full;
# with the default fixed button width, long descriptions such as
# "Overfitted Multilayer Perceptron" get cut off.
button_mlp_overfitted = widgets.Button(
    description="Overfitted Multilayer Perceptron",
    layout=widgets.Layout(width='auto'),
)
button_mlp_no_overfit = widgets.Button(
    description="Multilayer Perceptron (no overfit)",
    layout=widgets.Layout(width='auto'),
)
button_cnn = widgets.Button(
    description="CNN",
    layout=widgets.Layout(width='auto'),
)

# Shared area where the upload and prediction handlers print their messages.
output = widgets.Output()

# Function to handle file upload
# Most recent upload as a (file_name, BytesIO) tuple, or None when no usable
# file is available. Written by on_upload_change, read by predict_model.
uploaded_file_content = None


def on_upload_change(change):
    """Store the file selected in ``file_upload`` for later prediction requests.

    Registered as an observer of the widget's ``value`` trait. Both
    ``FileUpload.value`` layouts are accepted:

    * ipywidgets 7.x: a dict mapping file name to
      ``{'metadata': {'name': ...}, 'content': bytes}``
    * ipywidgets 8.x: a tuple of dicts with ``'name'`` and ``'content'`` keys

    On success the module-level ``uploaded_file_content`` becomes
    ``(file_name, io.BytesIO(content))``, the tuple form that ``requests``
    accepts for a multipart file field. Status and error messages are printed
    into the shared ``output`` widget, which is cleared first.

    Args:
        change: The traitlets change notification; ``change['new']`` holds the
            widget's new ``value``.
    """
    global uploaded_file_content
    # Forget the previous file up front so that a cleared or unreadable upload
    # can never be silently replaced by a stale one on the next prediction.
    uploaded_file_content = None
    with output:
        output.clear_output()
        value = change['new']
        if not value:
            return
        # Broad catch on purpose: this is a UI callback boundary, and any parse
        # failure is reported to the user instead of escaping the handler.
        try:
            if isinstance(value, dict):
                # ipywidgets 7.x layout: {file_name: {'metadata': ..., 'content': ...}}
                uploaded_file = next(iter(value.values()))
                file_name = uploaded_file['metadata']['name']
            else:
                # ipywidgets 8.x layout: ({'name': ..., 'content': ...}, ...)
                uploaded_file = value[0]
                file_name = uploaded_file['name']
            # Copy into bytes: 8.x delivers a memoryview, 7.x delivers bytes.
            file_content = bytes(uploaded_file['content'])

            # Store the uploaded file content
            uploaded_file_content = (file_name, io.BytesIO(file_content))
            print(f"File '{file_name}' uploaded successfully.")

        except Exception as e:
            print(f"Error during file upload: {e}")


# Attach handler to the file upload widget
file_upload.observe(on_upload_change, names='value')

# Function to handle prediction requests
def predict_model(endpoint, timeout=60):
    """POST the uploaded audio file to a prediction endpoint and print the result.

    The file stored in the module-level ``uploaded_file_content`` is sent as
    multipart form data under the field name ``"file"``. The JSON response is
    expected to contain ``predicted_genre`` and ``confidence`` keys. All
    messages are printed into the shared ``output`` widget, which is cleared
    first.

    Args:
        endpoint: Full URL of the prediction route.
        timeout: Seconds to wait for the server, passed to ``requests.post``.
            ``None`` waits indefinitely. Without a limit, an unresponsive
            server would block the notebook kernel.
    """
    with output:
        output.clear_output()
        if not uploaded_file_content:
            print("No file uploaded. Please upload an audio file first.")
            return

        file_name, file_obj = uploaded_file_content
        # The stored BytesIO is reused across button clicks, and requests reads
        # a file object to EOF. Passing the stream itself would therefore send
        # an empty file on every request after the first. Send the full buffer.
        payload = (file_name, file_obj.getvalue())

        try:
            response = requests.post(
                endpoint,
                files={"file": payload},
                timeout=timeout,
            )
            response.raise_for_status()
            response_json = response.json()
        except requests.exceptions.RequestException as e:
            print(f"Error during request to FastAPI: {e}")
            return
        except ValueError as e:
            # Older requests versions raise a plain ValueError
            # (json.JSONDecodeError) for a non-JSON body instead of a
            # RequestException subclass.
            print(f"Error: response is not valid JSON: {e}")
            return

        if not isinstance(response_json, dict):
            print("Error: Missing 'predicted_genre' or 'confidence' in the response.")
            return

        predicted_genre = response_json.get('predicted_genre')
        confidence = response_json.get('confidence')

        if not (predicted_genre and confidence is not None):
            print("Error: Missing 'predicted_genre' or 'confidence' in the response.")
            return

        try:
            confidence_value = float(confidence)
        except (TypeError, ValueError):
            print(f"Error: non-numeric 'confidence' in the response: {confidence!r}")
            return
        print(f"Predicted Genre: {predicted_genre}, Confidence: {confidence_value:.2f}")

# Button click event handlers
# Base URL of the FastAPI prediction service. Change it here to target another
# host or port instead of editing every handler.
API_BASE_URL = "http://localhost:8000"


def on_click_mlp_overfitted(b):
    """Request a prediction from the overfitted MLP route.

    Args:
        b: The button passed by ``on_click`` (unused).
    """
    predict_model(f"{API_BASE_URL}/predict_overfitted_mlp")


def on_click_mlp_no_overfit(b):
    """Request a prediction from the non-overfitted MLP route.

    Args:
        b: The button passed by ``on_click`` (unused).
    """
    predict_model(f"{API_BASE_URL}/predict_mlp_no_overfit")


def on_click_cnn(b):
    """Request a prediction from the CNN route.

    Args:
        b: The button passed by ``on_click`` (unused).
    """
    predict_model(f"{API_BASE_URL}/predict_cnn")


# Attach handlers to the buttons
button_mlp_overfitted.on_click(on_click_mlp_overfitted)
button_mlp_no_overfit.on_click(on_click_mlp_no_overfit)
button_cnn.on_click(on_click_cnn)

# Display the widgets
display(file_upload, button_mlp_overfitted, button_mlp_no_overfit, button_cnn, output)


FileUpload(value={}, accept='audio/*', description='Upload')

Output()