In [None]:
import os
import tempfile

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import squidpy as sq
from IPython.display import display
from matplotlib.transforms import Affine2D

class VoilaApp:
    
    def __init__(self):
        self.adata = None  # Scanpy AnnData object to hold the loaded data
        self.setup_ui()

    def setup_ui(self):
        """Create all widgets, connect their callbacks and build the layout.

        The finished layout is stored in ``self.main_layout`` as an ``HBox``
        with three children: the data/cluster controls (left), the plot
        output (middle) and the gene plotting controls (right).
        """
        # File upload widget
        self.file_upload = widgets.FileUpload(accept='.h5ad,.hdf5', multiple=False)
        self.file_upload.observe(self.handle_file_upload, names='value')

        # Status output. It is added to the left column below; an Output
        # widget that is never part of a displayed layout swallows everything
        # printed into it, which would hide upload and load messages.
        self.status_output = widgets.Output()

        # Left side widgets (Clusters, Spatial Plot)
        self.clusters_list_widget = widgets.SelectMultiple(options=[])
        self.other_clusters_list_widget = widgets.SelectMultiple(options=[])

        self.radio_all = widgets.RadioButtons(options=['All', 'Select'], value='All', description='Spatial Plot:')
        self.axis_input = widgets.FloatText(value=0, description='Axis:')
        self.start_button = widgets.Button(description="Start")
        self.start_button.on_click(self.plot_spatial_scatter)

        self.deg_plot_button = widgets.Button(description="DEG Analysis")
        self.deg_plot_button.on_click(self.deg_plot)

        # Saving starts disabled because there is no DEG result yet.
        self.deg_save_button = widgets.Button(description="Save DEG Result", disabled=True)
        self.deg_save_button.on_click(self.save_deg_result)

        left_vbox = widgets.VBox([
            self.file_upload,
            self.status_output,
            widgets.Label("Clusters"),
            self.clusters_list_widget,
            widgets.Label("Spatial Plot"),
            self.radio_all,
            self.axis_input,
            self.start_button,
            widgets.Label("Other Clusters"),
            self.other_clusters_list_widget,
            self.deg_plot_button,
            self.deg_save_button,
        ])

        # Right side widgets (Genes, Plotting options)
        self.genes_list_widget = widgets.SelectMultiple(options=[])
        self.dot_plot_radio_button = widgets.RadioButtons(options=['Dot plot', 'Violin plot', 'Feature plot', 'Scatter plot'], description="Plot Type:")
        self.genes_start_button = widgets.Button(description="Plot Genes")
        self.genes_start_button.on_click(self.plot_genes)

        right_vbox = widgets.VBox([
            widgets.Label("Genes"),
            self.genes_list_widget,
            self.dot_plot_radio_button,
            self.genes_start_button,
        ])

        # Output area for plots
        self.plot_output = widgets.Output()

        # Main layout
        self.main_layout = widgets.HBox([left_vbox, self.plot_output, right_vbox])

    def handle_file_upload(self, change):
        with self.status_output:
            self.status_output.clear_output()
            try:
                # Check if any file has been uploaded
                uploaded_file = next(iter(self.file_upload.value.values()))
                content = uploaded_file['content']

                # Save the uploaded file temporarily
                temp_file_path = 'temp_file.h5ad'
                with open(temp_file_path, 'wb') as f:
                    f.write(content)

                # Load the data
                self.load_data(temp_file_path)

                # Remove the temporary file
                os.remove(temp_file_path)
                
                print(f"File {uploaded_file['name']} successfully uploaded and processed.")
            except Exception as e:
                print(f"Error processing file: {str(e)}")


    def load_data(self, file_name):
        with self.status_output:
            self.status_output.clear_output()
            try:
                # Load the data using scanpy's read_h5ad function
                self.adata = sc.read_h5ad(file_name)
                
                # Update the clusters and genes list
                self.update_lists()
                print(f"Loaded '{file_name}'")
            except Exception as e:
                print(f"Failed to load file: {e}")

    def update_lists(self):
        # Update cluster and gene lists based on loaded data
        if self.adata is not None:
            self.clusters_list_widget.options = sorted(self.adata.obs['cell_type_2'].unique())
            self.other_clusters_list_widget.options = sorted(self.adata.obs['cell_type_2'].unique())
            self.genes_list_widget.options = sorted(self.adata.var_names.to_list())

    def plot_spatial_scatter(self, _):
        with self.plot_output:
            self.plot_output.clear_output()
            axis = 360
            num = self.axis_input.value
            if num > 0:
                axis = 180 / num

            if self.adata is not None:
                fig, ax = plt.subplots()

                # Extract x and y coordinates from adata
                x_coords = self.adata.obsm['spatial'][:, 0]
                y_coords = self.adata.obsm['spatial'][:, 1]

                # Calculate the center of the plot
                center_x = (x_coords.max() + x_coords.min()) / 2
                center_y = (y_coords.max() + y_coords.min()) / 2

                # Apply rotation transformation around the center
                rotation = Affine2D().rotate_around(center_x, center_y, -np.pi / axis)
                ax.transData = rotation + ax.transData

                # Plot using squidpy
                sq.pl.spatial_scatter(self.adata, shape=None, color="cell_type_2", ax=ax, frameon=False, title=None)

                plt.show()

    def plot_genes(self, _):
        with self.plot_output:
            self.plot_output.clear_output()
            selected_genes = list(self.genes_list_widget.value)
            
            if not selected_genes:
                print("Please select at least one gene.")
                return

            axis = 360
            num = self.axis_input.value
            if num > 0:
                axis = 180 / num

            if self.dot_plot_radio_button.value == 'Scatter plot':
                num_genes = len(selected_genes)
                fig, axes = plt.subplots(1, num_genes, figsize=(5 * num_genes, 5))

                if num_genes == 1:
                    axes = [axes]

                for gene, ax in zip(selected_genes, axes):
                    x_coords = self.adata.obsm['spatial'][:, 0]
                    y_coords = self.adata.obsm['spatial'][:, 1]

                    center_x = (x_coords.max() + x_coords.min()) / 2
                    center_y = (y_coords.max() + y_coords.min()) / 2

                    rotation = Affine2D().rotate_around(center_x, center_y, -np.pi / axis)
                    ax.transData = rotation + ax.transData

                    sq.pl.spatial_scatter(self.adata, shape=None, color=gene, ax=ax, frameon=False, title=None, cmap='Purples', size=1)

                    ax.set_title(f"{gene}")

                    coords = rotation.transform(np.vstack([x_coords, y_coords]).T)
                    x_min, y_min = coords.min(axis=0)
                    x_max, y_max = coords.max(axis=0)

                    ax.set_xlim(x_min-1000, x_max+1000)
                    ax.set_ylim(y_min-1000, y_max+1000)

                plt.show()

            else:
                if self.dot_plot_radio_button.value == 'Dot plot':
                    sc.pl.dotplot(self.adata, selected_genes, groupby="cell_type_2")
                elif self.dot_plot_radio_button.value == 'Violin plot':
                    sc.pl.stacked_violin(self.adata, selected_genes, groupby="cell_type_2")
                elif self.dot_plot_radio_button.value == 'Feature plot':
                    selected_genes.append('cell_type_2')
                    sc.pl.umap(self.adata, color=selected_genes, ncols=4)
                plt.show()

    def deg_plot(self, _):
        self.deg_save_button.disabled = False

        self.selected_clusters = list(self.clusters_list_widget.value)
        self.selected_other_clusters = list(self.other_clusters_list_widget.value)

        if not self.selected_clusters or not self.selected_other_clusters:
            print("Please select at least one cluster in both Cluster 1 and Cluster 2.")
            return
        
        merged_cluster_A = '_'.join(self.selected_clusters)
        merged_cluster_B = '_'.join(self.selected_other_clusters)

        self.adata.obs['cell_type_3'] = self.adata.obs['cell_type_2']
        self.adata.obs['cell_type_3'] = self.adata.obs['cell_type_3'].astype(str)
        self.adata.obs['cell_type_3'][self.adata.obs['cell_type_3'].isin(self.selected_clusters)] = merged_cluster_A
        self.adata.obs['cell_type_3'][self.adata.obs['cell_type_3'].isin(self.selected_other_clusters)] = merged_cluster_B                        
        self.adata.obs['cell_type_3'] = self.adata.obs['cell_type_3'].astype('category')

        sc.tl.rank_genes_groups(self.adata, 'cell_type_3', groups=[merged_cluster_A], reference=merged_cluster_B, method='wilcoxon')
        
        with self.plot_output:
            self.plot_output.clear_output()
            sc.pl.rank_genes_groups(self.adata, n_genes=25, sharey=False)
            plt.show()

    def save_deg_result(self, _):
        if 'rank_genes_groups' in self.adata.uns:            
            result = self.adata.uns['rank_genes_groups']
            groups = result['names'].dtype.names            
            result_df = pd.DataFrame(
                {group + '_' + key: result[key][group]
                for group in groups for key in ['names', 'scores', 'logfoldchanges', 'pvals', 'pvals_adj']}
            )
            # Save to CSV
            self.save_to_csv(result_df)

    def save_to_csv(self, df):
        """Show ``df`` in the plot output together with an export notice.

        No file is written; the output area is cleared, the notice is
        printed and the DataFrame is rendered with ``display``.

        Args:
            df: The DataFrame to show.
        """
        notice = (
            "Saving results to CSV is not directly supported in Voila. "
            "Please use Jupyter's built-in save features to export the DataFrame."
        )
        output_area = self.plot_output
        with output_area:
            output_area.clear_output()
            print(notice)
            display(df)

    def run(self):
        """Show the application by displaying its main layout."""
        layout = self.main_layout
        display(layout)

# Run the application: constructing VoilaApp builds all widgets, and run()
# displays the assembled main layout.
app = VoilaApp()
app.run()
