# In [None]:
import sys
import os
import pandas as pd
import numpy as np
from PyQt6.QtWidgets import (QApplication, QMainWindow, QFileDialog, QVBoxLayout, 
                           QHBoxLayout, QWidget, QPushButton, QSizePolicy, 
                           QComboBox, QLabel, QGridLayout)
from PyQt6.QtCore import Qt, pyqtSignal, QSize
from PyQt6.QtGui import QColor
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
from matplotlib.widgets import SpanSelector
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, rgb2hex
from matplotlib.cm import ScalarMappable
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

class RoundButton(QPushButton):
    """Fixed-size (40 x 40) circular push button representing a single well.

    A new button is enabled with a grey (#808080) background.
    :meth:`set_inactive` switches it to light grey (#D3D3D3) and disables it;
    WellPlate uses this for wells that are not in its active set.
    """

    # Stylesheet shared by both states; only the background colour (the %s
    # placeholder) differs. The text is kept identical to the original literal.
    _STYLE_TEMPLATE = (
        "\n"
        "            QPushButton {\n"
        "                border-radius: 20px;\n"
        "                border: 2px solid #444444;\n"
        "                font-weight: bold;\n"
        "                color: white;\n"
        "                background-color: %s;\n"
        "            }\n"
        "        "
    )

    def __init__(self, text, parent=None):
        super().__init__(text, parent)
        self.setFixedSize(40, 40)
        self._apply_background("#808080")

    def _apply_background(self, color):
        """Apply the shared round-button stylesheet with *color* as background."""
        self.setStyleSheet(self._STYLE_TEMPLATE % color)

    def set_inactive(self):
        """Switch to the light-grey inactive look and disable the button."""
        self._apply_background("#D3D3D3")
        self.setEnabled(False)

class WellPlate(QWidget):
    """Clickable well-plate view with an intensity heatmap and a bar chart.

    Wells are named by row letter plus zero-padded column number (``"A01"``).
    Only wells registered via :meth:`set_active_wells` are clickable and get
    heatmap colours; clicking one emits :attr:`well_clicked` with its ID.
    """

    # Emitted with the well ID when an active well button is clicked.
    well_clicked = pyqtSignal(str)

    # Background for active wells without a finite value (e.g. no data points
    # inside the selected mass range); same grey as RoundButton's active default.
    _NO_VALUE_COLOR = "#808080"

    def __init__(self, parent=None):
        super().__init__(parent)
        self.layout = QVBoxLayout()
        self.setLayout(self.layout)

        # Plate layout selector. Items are added before the signal is connected
        # so populating the combo box does not trigger a rebuild.
        layout_selector = QHBoxLayout()
        self.layout_combo = QComboBox()
        self.layout_combo.addItems(['96 Well Plate', '2x 44 Well Plates'])
        self.layout_combo.currentTextChanged.connect(self.change_plate_layout)
        layout_selector.addWidget(QLabel("Plate Layout:"))
        layout_selector.addWidget(self.layout_combo)
        layout_selector.addStretch()
        self.layout.addLayout(layout_selector)

        self.info_label = QLabel()
        self.layout.addWidget(self.info_label)

        # Horizontal row: plates on the left, bar chart on the right.
        plate_and_chart = QHBoxLayout()
        self.layout.addLayout(plate_and_chart)

        left_side = QWidget()
        left_layout = QVBoxLayout()
        left_side.setLayout(left_layout)

        self.plates_widget = QWidget()
        self.plates_layout = QVBoxLayout()
        self.plates_widget.setLayout(self.plates_layout)
        left_layout.addWidget(self.plates_widget)

        plate_and_chart.addWidget(left_side)

        # Start with the 96 well plate and no active wells.
        self.current_layout = "96"
        self.buttons = {}
        self.active_wells = set()
        # Wells in the order given to set_active_wells(). Sequence values passed
        # to update_heatmap() are matched to wells by this order (a set has no
        # stable order, so it cannot be used for that).
        self._well_order = []
        # Values last passed to update_heatmap(); re-applied when the plate
        # layout or the colormap changes.
        self.last_values = None
        self.setup_96_well_plate()

        # Colormap selection.
        self.colormap_combo = QComboBox()
        self.colormap_combo.addItems(['viridis', 'plasma', 'inferno', 'magma', 'cividis'])
        self.colormap_combo.currentTextChanged.connect(self.change_colormap)
        self.layout.addWidget(self.colormap_combo)

        self.colormap = plt.get_cmap('viridis')

        # Right side: bar chart.
        right_side = QWidget()
        right_layout = QVBoxLayout()
        right_side.setLayout(right_layout)

        self.barchart_figure = Figure(figsize=(6, 8))
        self.barchart_canvas = FigureCanvas(self.barchart_figure)
        self.barchart_canvas.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        right_layout.addWidget(self.barchart_canvas)

        plate_and_chart.addWidget(right_side)

        # Plates and bar chart share the horizontal space equally.
        plate_and_chart.setStretch(0, 1)
        plate_and_chart.setStretch(1, 1)

    def _add_well_button(self, grid, well, row, col):
        """Create the button for *well* at (*row*, *col*) of *grid* and register it."""
        button = RoundButton(well)
        button.clicked.connect(lambda _, w=well: self.on_well_clicked(w))
        grid.addWidget(button, row, col)
        self.buttons[well] = button
        if well not in self.active_wells:
            button.set_inactive()

    def setup_96_well_plate(self):
        """Build one 8 x 12 plate (rows A-H, columns 01-12)."""
        self.clear_plates()
        grid = QGridLayout()
        self.plates_layout.addLayout(grid)
        for row in range(8):
            for col in range(12):
                self._add_well_button(grid, f"{chr(65 + row)}{col + 1:02d}", row, col)

    def setup_44_well_plates(self):
        """Build two 4 x 11 plates: plate 1 holds rows A-D, plate 2 rows E-H (cols 01-11)."""
        self.clear_plates()
        for plate_num in range(2):
            grid = QGridLayout()
            self.plates_layout.addWidget(QLabel(f"Plate {plate_num + 1}"))
            self.plates_layout.addLayout(grid)
            first_row = 65 + 4 * plate_num  # 'A' for plate 1, 'E' for plate 2
            for row in range(4):
                for col in range(11):
                    self._add_well_button(grid, f"{chr(first_row + row)}{col + 1:02d}", row, col)
            if plate_num == 0:
                self.plates_layout.addSpacing(20)  # space between the two plates

    def clear_plates(self):
        """Remove all plate labels/buttons and forget the old button objects."""
        while self.plates_layout.count():
            item = self.plates_layout.takeAt(0)
            widget = item.widget()
            sub_layout = item.layout()
            if widget is not None:
                widget.deleteLater()
            elif sub_layout is not None:
                while sub_layout.count():
                    sub_widget = sub_layout.takeAt(0).widget()
                    if sub_widget is not None:
                        sub_widget.deleteLater()
        # The removed buttons are scheduled for deletion. Drop them here (for
        # every layout, not only the 96 well one) so later heatmap updates never
        # touch a deleted widget.
        self.buttons = {}

    def _rebuild_plates(self):
        """Rebuild the buttons for the currently selected plate layout."""
        if self.current_layout == "96":
            self.setup_96_well_plate()
        else:
            self.setup_44_well_plates()

    def change_plate_layout(self, layout_text):
        """Switch between the 96 well and 2x 44 well layouts, keeping the heatmap."""
        self.current_layout = "96" if layout_text == "96 Well Plate" else "44"
        self._rebuild_plates()
        if self.last_values is not None:
            self.update_heatmap(self.last_values)

    def set_active_wells(self, wells):
        """Register the wells that have data and rebuild the plate.

        The order of *wells* (duplicates removed, first occurrence kept) defines
        how sequence values passed to :meth:`update_heatmap` map onto wells.
        """
        self._well_order = list(dict.fromkeys(wells))
        self.active_wells = set(self._well_order)
        # Heatmap values from a previous data set no longer match these wells.
        self.last_values = None
        self.barchart_figure.clear()
        self.barchart_canvas.draw_idle()
        self._rebuild_plates()

    def on_well_clicked(self, well):
        """Emit :attr:`well_clicked` for active wells only."""
        if well in self.active_wells:
            self.well_clicked.emit(well)

    def change_colormap(self, colormap_name):
        """Use the named matplotlib colormap and re-colour the current heatmap."""
        self.colormap = plt.get_cmap(colormap_name)
        if self.last_values is not None:
            self.update_heatmap(self.last_values)

    def update_heatmap(self, values):
        """Colour the active wells by *values* and redraw the bar chart.

        Args:
            values: Either a dict mapping well ID -> value, or a sequence of
                values in the same order as the wells last passed to
                :meth:`set_active_wells`. Non-finite values (NaN, e.g. a well
                with no points in the selected range) get a neutral colour and
                are left out of the colour scale, the chart and the statistics.
        """
        if values is None or len(values) == 0 or not self.buttons:
            return

        self.last_values = values
        if isinstance(values, dict):
            value_dict = dict(values)
        else:
            # zip stops at the shorter input, so a length mismatch cannot raise.
            value_dict = dict(zip(self._well_order, values))
        finite = {well: value for well, value in value_dict.items() if np.isfinite(value)}
        norm = (Normalize(vmin=min(finite.values()), vmax=max(finite.values()))
                if finite else None)

        for well, button in self.buttons.items():
            if well not in self.active_wells:
                continue
            value = finite.get(well)
            if value is None:
                hex_color = self._NO_VALUE_COLOR
            else:
                hex_color = rgb2hex(self.colormap(norm(value)))
            button.setStyleSheet(f"""
                QPushButton {{
                    background-color: {hex_color};
                    color: white;
                    border-radius: 20px;
                    border: 2px solid #444444;
                    font-weight: bold;
                }}
            """)
            button.setEnabled(True)

        self._draw_barchart(finite, norm)

    def _draw_barchart(self, finite, norm):
        """Draw a horizontal bar per well (sorted by ID) plus mean and +/-1 std lines."""
        self.barchart_figure.clear()
        ax = self.barchart_figure.add_subplot(111)
        wells = sorted(finite)
        if wells:
            bar_values = [finite[well] for well in wells]
            y_pos = np.arange(len(wells))
            ax.barh(y_pos, bar_values, align='center',
                    color=[self.colormap(norm(value)) for value in bar_values])
            ax.set_yticks(y_pos)
            ax.set_yticklabels(wells, fontsize=4)
            ax.invert_yaxis()  # first well at the top

            # Population (ddof=0) statistics over the plotted wells.
            mean_value = np.mean(bar_values)
            std_dev = np.std(bar_values)
            ax.axvline(mean_value, color='blue', linestyle=':', label='Mean')
            ax.axvline(mean_value + std_dev, color='red', linestyle=':', label='+1 Std Dev')
            ax.axvline(mean_value - std_dev, color='red', linestyle=':', label='-1 Std Dev')
            ax.legend(fontsize='x-small')
        ax.set_xlabel('Intensity')
        ax.set_title('Intensity by Well')

        self.barchart_figure.tight_layout()
        self.barchart_canvas.draw()

    def set_info_label(self, mass_range, mean, std_dev):
        """Show the selected mass range and its summary statistics above the plate."""
        self.info_label.setText(f"Mass Range: {mass_range[0]:.2f} - {mass_range[1]:.2f}, Mean: {mean:.2f}, Std Dev: {std_dev:.2f}")

class MassSpectrumAnalyzer(QMainWindow):
    """Main window: load per-well mass spectra, plot them and map an m/z range.

    Each ``.csv`` file in the chosen folder is read as two unnamed columns
    (mass-to-charge, intensity). The last three characters of the file name
    before the extension are the well ID (e.g. ``..._A01.csv`` -> ``A01``).
    Left-dragging on the spectrum selects a mass range that drives the
    well-plate heatmap, the CSV export and the PCA. Right-drag zooms, and a
    right double-click resets the zoom.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Mass Spectrum Analyzer")
        self.setGeometry(100, 100, 1400, 800)

        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.layout = QVBoxLayout(self.central_widget)

        self.well_plate = WellPlate()
        self.well_plate.well_clicked.connect(self.update_well_spectrum)
        self.layout.addWidget(self.well_plate)

        self.figure = Figure(figsize=(5, 4), dpi=100)
        self.canvas = FigureCanvas(self.figure)
        self.layout.addWidget(self.canvas)

        self.toolbar = NavigationToolbar(self.canvas, self)
        self.layout.addWidget(self.toolbar)

        button_layout = QHBoxLayout()
        self.select_folder_button = QPushButton("Select Folder")
        self.select_folder_button.clicked.connect(self.select_folder)
        button_layout.addWidget(self.select_folder_button)

        self.average_spectrum_button = QPushButton("Show Average Spectrum")
        self.average_spectrum_button.clicked.connect(self.plot_average_spectrum)
        button_layout.addWidget(self.average_spectrum_button)

        self.export_csv_button = QPushButton("Export to CSV")
        self.export_csv_button.clicked.connect(self.export_to_csv)
        button_layout.addWidget(self.export_csv_button)

        self.pca_button = QPushButton("Perform PCA")
        self.pca_button.clicked.connect(self.perform_pca)
        button_layout.addWidget(self.pca_button)

        self.layout.addLayout(button_layout)

        self.data = {}                  # well ID -> DataFrame(mass_to_charge, intensity)
        self.average_spectrum = None    # Series: intensity indexed by mass_to_charge
        self.current_spectrum = None    # 'average' or the well ID currently plotted
        self.span = None                # keeps the SpanSelector alive
        self.last_selected_range = None  # (xmin, xmax) from the SpanSelector
        self.zoom_start = None          # (x, y) where a right-drag zoom started
        self.pca_window = None

        self.canvas.mpl_connect('button_press_event', self.on_mouse_press)
        self.canvas.mpl_connect('button_release_event', self.on_mouse_release)

    def select_folder(self):
        """Ask for a folder, load its spectra and show their average."""
        folder = QFileDialog.getExistingDirectory(self, "Select Folder")
        if folder:
            self.load_data(folder)
            self.plot_average_spectrum()

    def load_data(self, folder):
        """Read every CSV spectrum in *folder* and build the average spectrum.

        Repeated mass-to-charge values within a file are averaged. The average
        spectrum is computed on the union of all mass-to-charge values, and a
        value missing from a spectrum counts as 0. If the folder has no CSV
        files, ``average_spectrum`` is set to None.
        """
        self.data = {}
        # Sorted so the well order (plate, chart and export order) does not
        # depend on the filesystem's listing order.
        for file in sorted(os.listdir(folder)):
            stem, ext = os.path.splitext(file)
            file_path = os.path.join(folder, file)
            if ext.lower() != '.csv' or not os.path.isfile(file_path):
                continue
            well = stem[-3:]  # well ID from the end of the file name
            df = pd.read_csv(file_path, names=["mass_to_charge", "intensity"])
            self.data[well] = df.groupby('mass_to_charge')['intensity'].mean().reset_index()

        # Same order as self.data, which is the order heatmap values are produced in.
        self.well_plate.set_active_wells(list(self.data))

        if not self.data:
            self.average_spectrum = None  # nothing to average (pd.concat([]) raises)
            return

        all_mz = sorted(set().union(*(df['mass_to_charge'] for df in self.data.values())))
        # reindex aligns intensities by mass-to-charge label. Assigning a Series
        # through .loc[labels] would align on its row index instead.
        aligned_spectra = [
            df.set_index('mass_to_charge')['intensity'].reindex(all_mz).fillna(0)
            for df in self.data.values()
        ]
        self.average_spectrum = pd.concat(aligned_spectra, axis=1).mean(axis=1)

    def _plot_spectrum(self, mz, intensity, title):
        """Replace the main plot with one spectrum and attach a fresh range selector."""
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        ax.plot(mz, intensity)
        ax.set_xlabel('Mass to Charge')
        ax.set_ylabel('Intensity')
        ax.set_title(title)
        # Left button only, so the right-drag zoom does not also change the
        # selected mass range.
        self.span = SpanSelector(ax, self.on_select, 'horizontal', useblit=True,
                                 props=dict(alpha=0.5, facecolor='red'), button=1)
        self.zoom_start = None
        self.canvas.draw()

    def plot_average_spectrum(self):
        """Plot the average spectrum; does nothing before any data is loaded."""
        if self.average_spectrum is None:
            return
        self._plot_spectrum(self.average_spectrum.index, self.average_spectrum.values,
                            'Average Mass Spectrum')
        self.current_spectrum = 'average'
        if self.last_selected_range is not None:
            self.update_heatmap(self.last_selected_range)

    def update_well_spectrum(self, well):
        """Plot the spectrum of *well*, if it was loaded."""
        if well not in self.data:
            return
        spectrum = self.data[well]
        self._plot_spectrum(spectrum['mass_to_charge'], spectrum['intensity'],
                            f'Mass Spectrum for Well {well}')
        self.current_spectrum = well
        if self.last_selected_range is not None:
            self.update_heatmap(self.last_selected_range)

    def on_select(self, xmin, xmax):
        """SpanSelector callback: remember the range and refresh the heatmap."""
        self.last_selected_range = (xmin, xmax)
        self.update_heatmap((xmin, xmax))

    def _mean_intensities(self, mass_range):
        """Mean intensity per well (in self.data order) over the inclusive range.

        A well with no points in the range yields NaN.
        """
        low, high = mass_range
        means = []
        for spectrum in self.data.values():
            mz = spectrum['mass_to_charge']
            means.append(spectrum.loc[(mz >= low) & (mz <= high), 'intensity'].mean())
        return means

    def update_heatmap(self, mass_range):
        """Colour the plate by mean intensity in *mass_range* and show its stats."""
        if not self.data:
            return
        values = self._mean_intensities(mass_range)
        self.well_plate.update_heatmap(values)

        # Statistics over wells that actually have data in the range.
        finite = np.asarray(values, dtype=float)
        finite = finite[np.isfinite(finite)]
        mean_value = float(np.mean(finite)) if finite.size else float('nan')
        std_dev = float(np.std(finite)) if finite.size else float('nan')
        self.well_plate.set_info_label(mass_range, mean_value, std_dev)

    def export_to_csv(self):
        """Save the per-well mean intensity for the selected range as a CSV file."""
        if self.last_selected_range is None:
            return

        file_name, _ = QFileDialog.getSaveFileName(self, "Save CSV", "", "CSV Files (*.csv)")
        if not file_name:
            return
        low, high = self.last_selected_range
        range_label = f"{low:.2f} - {high:.2f}"
        rows = [
            {'Well': well, 'Mass Range': range_label, 'Intensity': intensity}
            for well, intensity in zip(self.data, self._mean_intensities(self.last_selected_range))
        ]
        # Explicit columns keep the header even when no wells are loaded.
        pd.DataFrame(rows, columns=['Well', 'Mass Range', 'Intensity']).to_csv(file_name, index=False)

    def perform_pca(self):
        """Run a 2-component PCA over the wells' spectra in the selected range."""
        if not self.data or self.last_selected_range is None:
            return

        low, high = sorted(self.last_selected_range)
        wells = list(self.data)
        spectra = [
            (df['mass_to_charge'].to_numpy(dtype=float),
             df['intensity'].fillna(0).to_numpy(dtype=float))
            for df in self.data.values()
        ]
        # Wells are generally sampled at different m/z values, so their slices of
        # the range have different lengths and cannot be stacked directly. Use
        # every m/z value inside the range as a shared feature axis and
        # interpolate each spectrum onto it (0 outside its own m/z span).
        all_mz = np.unique(np.concatenate([mz for mz, _ in spectra]))
        grid = all_mz[(all_mz >= low) & (all_mz <= high)]
        # Two components need at least two samples and two features.
        if len(wells) < 2 or grid.size < 2:
            return
        X = np.vstack([
            np.interp(grid, mz, intensity, left=0.0, right=0.0) if mz.size else np.zeros_like(grid)
            for mz, intensity in spectra
        ])

        X_scaled = StandardScaler().fit_transform(X)
        pca_result = PCA(n_components=2).fit_transform(X_scaled)

        # NOTE(review): PCAWindow is not defined in this file as shown; it is
        # presumably provided elsewhere (e.g. another notebook cell) - confirm.
        self.pca_window = PCAWindow(pca_result, wells)
        self.pca_window.show()

    def on_mouse_press(self, event):
        """Right button: start a zoom drag, or reset the zoom on double-click."""
        if event.button != 3:
            return
        if event.dblclick:
            # Matplotlib flags double-clicks on press events, not on release.
            self.zoom_start = None
            self.reset_zoom()
        elif event.inaxes is not None:
            self.zoom_start = (event.xdata, event.ydata)
        else:
            self.zoom_start = None  # pressed outside the axes: no data coordinates

    def on_mouse_release(self, event):
        """Right button: finish a zoom drag started inside the axes."""
        if event.button != 3 or self.zoom_start is None:
            return
        start, self.zoom_start = self.zoom_start, None
        if event.inaxes is not None:
            self.zoom(start, (event.xdata, event.ydata))

    def zoom(self, start, end):
        """Zoom the main plot to the rectangle spanned by *start* and *end*."""
        if not self.figure.axes or None in (*start, *end):
            return
        x_min, x_max = sorted([start[0], end[0]])
        y_min, y_max = sorted([start[1], end[1]])
        if x_min == x_max or y_min == y_max:
            return  # a plain click, not a drag: nothing to zoom to
        ax = self.figure.axes[0]
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        self.canvas.draw()

    def reset_zoom(self):
        """Restore the main plot's view to fit all of its data."""
        if not self.figure.axes:
            return
        ax = self.figure.axes[0]
        ax.relim()
        # set_xlim/set_ylim in zoom() switch autoscaling off, so a bare
        # autoscale_view() would do nothing. Re-enable autoscaling and rescale.
        ax.autoscale(enable=True)
        self.canvas.draw()

if __name__ == '__main__':
    # Reuse an existing QApplication when the cell is re-run in an IPython or
    # notebook session; PyQt refuses to create a second instance.
    app = QApplication.instance() or QApplication(sys.argv)
    window = MassSpectrumAnalyzer()
    window.show()
    sys.exit(app.exec())  # Note: In PyQt6, exec_() is now just exec()
