# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import ttk, messagebox
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import threading
import queue
import time

# --- Optional: Guard PennyLane import so the GUI still works without it ---
try:
    import pennylane as qml
    PENNYLANE_AVAILABLE = True
except Exception:
    qml = None
    PENNYLANE_AVAILABLE = False


class QuantumMarketClassifier:
    """Quantum + Classical Market Direction Classifier with a responsive Tkinter GUI.

    Key upgrades in this version:
    - Non-blocking training via background thread (no GUI freeze)
    - Progress bar + live status updates
    - Safer UI updates from thread using a Queue + root.after polling
    - Start/Stop controls; Compare becomes enabled when training completes
    - Configurable quantum params (qubits/layers/epochs) from the UI
    - Robust quantum embedding (pads/truncates feature vector)
    - More stable MLP training; more realistic price simulation
    - Minor ttk style fixes for better cross-platform appearance
    """

    def __init__(self, root):
        """Initialise application state, build the widgets and start polling.

        Args:
            root: The Tk window that hosts the application; its title and
                geometry are set here.
        """
        # --- Window ---
        root.title("Quantum Market Direction Classifier (Threaded)")
        root.geometry("1200x820")
        self.root = root

        # --- Background-training plumbing ---
        self.training_thread = None
        self.stop_event = threading.Event()
        self.msg_queue = queue.Queue()

        # --- Quantum hyper-parameters (defaults, then the live values) ---
        self.default_qubits, self.default_layers, self.default_epochs = 3, 2, 6
        self.n_qubits = self.default_qubits
        self.n_layers = self.default_layers
        self.q_epochs = self.default_epochs

        # The simulator device only exists when PennyLane could be imported.
        self.device = (
            qml.device("default.qubit", wires=self.n_qubits)
            if PENNYLANE_AVAILABLE
            else None
        )

        # --- Data and model holders, filled in by load_data / training ---
        self.data = None
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.scaler = StandardScaler()
        self.models = {}
        self.q_params = None
        self.q_model = None

        # --- Widgets first, styling second, then the queue poller ---
        self._build_ui()
        self._apply_style()
        self._poll_queue()

    # ------------------------ UI BUILD ------------------------ #
    def _build_ui(self):
        """Create and lay out every widget of the main window.

        Builds, top to bottom: the parameter panel (data inputs and quantum
        spinboxes), the action buttons, the status line with progress bar and
        a three-tab notebook (data plot, model comparison, circuit
        placeholder). All widgets that other methods touch are stored as
        attributes on ``self``.
        """
        # Figures that live inside Tk canvases are created as plain Figure
        # objects. Going through pyplot would additionally register each
        # figure with pyplot's global figure manager, which is not wanted
        # when the application owns the widgets itself.
        from matplotlib.figure import Figure

        self.main = ttk.Frame(self.root, padding=10)
        self.main.pack(fill=tk.BOTH, expand=True)

        # ---- Parameter panel ----
        top = ttk.LabelFrame(self.main, text="Input & Model Parameters", padding=10)
        top.pack(fill=tk.X, pady=6)

        # Row 0: data parameters
        ttk.Label(top, text="Ticker Symbol:").grid(row=0, column=0, sticky=tk.W, padx=(0, 6), pady=3)
        self.ticker_var = tk.StringVar(value="AAPL")
        ttk.Entry(top, textvariable=self.ticker_var, width=10).grid(row=0, column=1, sticky=tk.W, pady=3)

        ttk.Label(top, text="Period (days):").grid(row=0, column=2, sticky=tk.W, padx=(16, 6))
        self.period_var = tk.IntVar(value=200)
        ttk.Entry(top, textvariable=self.period_var, width=8).grid(row=0, column=3, sticky=tk.W)

        ttk.Label(top, text="Test Size:").grid(row=0, column=4, sticky=tk.W, padx=(16, 6))
        self.test_size_var = tk.DoubleVar(value=0.2)
        ttk.Entry(top, textvariable=self.test_size_var, width=6).grid(row=0, column=5, sticky=tk.W)

        # Row 1: quantum parameters (label in `column`, spinbox right of it)
        def add_spinbox(caption, variable, low, high, column, **label_grid):
            ttk.Label(top, text=caption).grid(row=1, column=column, sticky=tk.W, **label_grid)
            ttk.Spinbox(top, from_=low, to=high, textvariable=variable, width=6,
                        command=self._update_quantum_params).grid(row=1, column=column + 1, sticky=tk.W)

        self.qubits_var = tk.IntVar(value=self.default_qubits)
        add_spinbox("Qubits:", self.qubits_var, 2, 8, 0, pady=3)
        self.layers_var = tk.IntVar(value=self.default_layers)
        add_spinbox("Layers:", self.layers_var, 1, 8, 2)
        self.epochs_var = tk.IntVar(value=self.default_epochs)
        add_spinbox("Q Epochs:", self.epochs_var, 1, 30, 4)

        # ---- Action buttons ----
        btns = ttk.Frame(self.main)
        btns.pack(fill=tk.X, pady=8)

        def add_button(caption, callback, **options):
            button = ttk.Button(btns, text=caption, command=callback, **options)
            button.pack(side=tk.LEFT, padx=4)
            return button

        # Everything except "Load" starts disabled until its precondition holds.
        self.btn_load = add_button("1) Load Data", self.load_data)
        self.btn_train = add_button("2) Train (non-blocking)", self.start_training, state=tk.DISABLED)
        self.btn_stop = add_button("Stop Training", self.request_stop, state=tk.DISABLED)
        self.btn_compare = add_button("3) Compare Models", self.compare_models, state=tk.DISABLED)

        # ---- Status line + progress bar ----
        statf = ttk.Frame(self.main)
        statf.pack(fill=tk.X, pady=(0, 6))
        self.status_var = tk.StringVar(value="Ready.")
        self.status_lbl = ttk.Label(statf, textvariable=self.status_var)
        self.status_lbl.pack(side=tk.LEFT)

        self.progress = ttk.Progressbar(statf, length=280, mode='determinate')
        self.progress.pack(side=tk.RIGHT)

        # ---- Notebook with its three tabs ----
        nb = ttk.Notebook(self.main)
        nb.pack(fill=tk.BOTH, expand=True)
        self.nb = nb

        self.tab_data = ttk.Frame(nb)
        self.tab_models = ttk.Frame(nb)
        self.tab_circuit = ttk.Frame(nb)
        nb.add(self.tab_data, text="Data Visualization")
        nb.add(self.tab_models, text="Model Comparison")
        nb.add(self.tab_circuit, text="Quantum Circuit")

        # Data plot
        self.data_fig = Figure(figsize=(8, 4), dpi=100)
        self.data_ax = self.data_fig.add_subplot(111)
        self.data_canvas = FigureCanvasTkAgg(self.data_fig, master=self.tab_data)
        self.data_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        self._init_data_plot()

        # Results table + bar plot
        self._init_model_tab()

        # Circuit placeholder (replaced with a short note once training is done)
        self.circuit_fig = Figure(figsize=(8, 3.5), dpi=100)
        self.circuit_ax = self.circuit_fig.add_subplot(111)
        self.circuit_canvas = FigureCanvasTkAgg(self.circuit_fig, master=self.tab_circuit)
        self.circuit_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        self.circuit_ax.text(0.5, 0.5, "Quantum Circuit Visualization\n(appears after training)",
                             ha='center', va='center')
        self.circuit_ax.axis('off')
        self.circuit_canvas.draw()

    def _apply_style(self):
        """Apply the white background / Helvetica look to the ttk widgets.

        Also sets option-database defaults for foreground and background on
        the root window.
        """
        style = ttk.Style()
        # Basic theme adjustments
        style.configure('TFrame', background='white')
        # The ttk style for a LabelFrame is spelled 'TLabelframe' (lower-case
        # "f"), and its caption is styled through the 'TLabelframe.Label'
        # sub-style; a 'TLabelFrame' entry is never looked up by the widget.
        style.configure('TLabelframe', background='white')
        style.configure('TLabelframe.Label', background='white', font=('Helvetica', 10, 'bold'))
        style.configure('TButton', font=('Helvetica', 9))
        style.configure('TNotebook', background='white')
        style.configure('TNotebook.Tab', font=('Helvetica', 9))
        # Ensure entries are readable across platforms
        style.configure('TEntry', fieldbackground='white', foreground='black')
        self.root.option_add('*foreground', 'black')
        self.root.option_add('*background', 'white')

    # ------------------------ DATA ------------------------ #
    def load_data(self):
        """Simulate a price series, derive features and split/scale them.

        The ticker is only used for labelling; prices are a simulated
        geometric random walk of ``max(40, period)`` days. Features are
        ``Close``, ``Return``, ``SMA_5`` and ``SMA_20``; the training target
        is the direction of the *next* day's return, so no feature reveals
        the label. The same-day ``Direction`` column is kept for plotting.

        On success the train/test arrays, the fitted scaler and ``self.data``
        are replaced, previously trained models are discarded, the Train
        button is enabled and the data plot is redrawn. Any failure is shown
        in an error dialog and leaves the previous state untouched.
        """
        # Replacing the training arrays while the worker thread is using
        # them would mix two datasets in one training run.
        if self.training_thread is not None and self.training_thread.is_alive():
            messagebox.showinfo("Training", "Training is in progress. Stop it or wait before reloading data.")
            return

        try:
            ticker = self.ticker_var.get().strip().upper()
            period = max(40, int(self.period_var.get()))
            test_size = float(self.test_size_var.get())
            if not 0.0 < test_size < 1.0:
                raise ValueError("Test size must be a fraction strictly between 0 and 1.")

            dates = pd.date_range(end=pd.Timestamp.today(), periods=period)

            # Geometric random walk for realistic positive prices
            base_price = 100.0
            random_walk = np.random.normal(0, 0.01, period).cumsum()
            prices = base_price * np.exp(random_walk)

            df = pd.DataFrame({
                'Date': dates,
                'Close': prices,
            })
            df['Return'] = df['Close'].pct_change()
            df['SMA_5'] = df['Close'].rolling(5).mean()
            df['SMA_20'] = df['Close'].rolling(20).mean()
            # Same-day direction: used by the plot to colour up/down days.
            df['Direction'] = (df['Return'] > 0).astype(int)
            # Next-day return drives the prediction target. Using the
            # same-day direction as the label would make it a deterministic
            # function of the 'Return' feature.
            df['NextReturn'] = df['Return'].shift(-1)
            # Drops the SMA warm-up rows and the last row (no next day yet).
            df = df.dropna().reset_index(drop=True)
            df['Target'] = (df['NextReturn'] > 0).astype(int)

            features = df[['Close', 'Return', 'SMA_5', 'SMA_20']].values
            target = df['Target'].values

            # Chronological split (no shuffling) so the test set is the most
            # recent part of the series.
            X_train, X_test, y_train, y_test = train_test_split(
                features, target, test_size=test_size, shuffle=False
            )
            X_train = self.scaler.fit_transform(X_train)
            X_test = self.scaler.transform(X_test)

            # Publish the new state only after every step succeeded.
            self.data = df
            self.X_train, self.X_test = X_train, X_test
            self.y_train, self.y_test = y_train, y_test

            # Models fitted on the previous dataset/scaler are no longer
            # comparable with the new test set.
            self.models = {}
            self.q_model = None
            self.q_params = None
            self.btn_compare.config(state=tk.DISABLED)

            # Enable training
            self.btn_train.config(state=tk.NORMAL)

            # Plot
            self._plot_data()

            messagebox.showinfo("Success", f"Loaded {len(self.data)} records for {ticker}")
            self._set_status("Data loaded.")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load data:\n{e}")
            self._set_status("Load failed.")

    def _init_data_plot(self):
        self.data_ax.set_title("Market Data")
        self.data_ax.set_xlabel("Date")
        self.data_ax.set_ylabel("Price")
        self.data_ax.grid(True, linestyle='--', alpha=0.6)
        self.data_ax.plot([], [], label='Price')
        self.data_ax.plot([], [], label='SMA 5')
        self.data_ax.plot([], [], label='SMA 20')
        self.data_ax.legend()
        self.data_canvas.draw()

    def _plot_data(self):
        ax = self.data_ax
        ax.clear()
        ax.plot(self.data['Date'], self.data['Close'], label='Close', linewidth=1.4)
        ax.plot(self.data['Date'], self.data['SMA_5'], label='SMA 5', alpha=0.8)
        ax.plot(self.data['Date'], self.data['SMA_20'], label='SMA 20', alpha=0.8)
        up = self.data[self.data['Direction'] == 1]
        dn = self.data[self.data['Direction'] == 0]
        ax.scatter(up['Date'], up['Close'], s=12, color='green', label='Up day')
        ax.scatter(dn['Date'], dn['Close'], s=12, color='red', label='Down day')
        ax.set_title(f"Market Data – {self.ticker_var.get().upper()}")
        ax.legend()
        self.data_canvas.draw()

    # ------------------------ TRAINING (THREAD) ------------------------ #
    def start_training(self):
        """Validate inputs and launch training on a background thread.

        Runs on the GUI thread. The spinbox values are read and validated
        here and handed to the worker as plain integers, so the worker never
        has to touch a Tk variable. Does nothing (besides informing the user)
        when no data is loaded or a training run is still alive.
        """
        if self.X_train is None:
            messagebox.showwarning("Load data", "Please load data first.")
            return
        if self.training_thread and self.training_thread.is_alive():
            messagebox.showinfo("Training", "Training already in progress…")
            return

        # Tk variables may hold typed (unvalidated) text; .get() raises
        # TclError for non-numeric content.
        try:
            n_qubits = int(self.qubits_var.get())
            n_layers = int(self.layers_var.get())
            q_epochs = int(self.epochs_var.get())
            if n_qubits < 1 or n_layers < 1 or q_epochs < 0:
                raise ValueError("qubits and layers must be >= 1, epochs must be >= 0")
        except (tk.TclError, ValueError) as exc:
            messagebox.showerror("Invalid parameters", f"Check qubits / layers / epochs:\n{exc}")
            self._set_status("Invalid quantum parameters.")
            return

        self.stop_event.clear()
        self.btn_train.config(state=tk.DISABLED)
        self.btn_stop.config(state=tk.NORMAL)
        self.btn_compare.config(state=tk.DISABLED)
        self.progress['value'] = 0
        self.progress['maximum'] = 100
        self._set_status("Starting training…")

        self.training_thread = threading.Thread(
            target=self._train_all_models_worker,
            args=(n_qubits, n_layers, q_epochs),
            daemon=True,
        )
        self.training_thread.start()

    def request_stop(self):
        """Ask a running training thread to stop at its next checkpoint.

        Only sets ``stop_event``; the worker checks it between stages and
        between quantum epochs, so a model fit that is already running is
        not interrupted.
        """
        worker = self.training_thread
        if worker is None or not worker.is_alive():
            self._set_status("No training in progress.")
            return
        self.stop_event.set()
        self._set_status("Stop requested…")

    def _train_all_models_worker(self, n_qubits=None, n_layers=None, q_epochs=None):
        """Train the quantum and classical models; runs on the worker thread.

        Communicates with the GUI exclusively through the message queue
        (``("status", text)``, ``("progress", value)``, ``("done", text)``,
        ``("error", text)``) and never touches Tk objects.

        Args:
            n_qubits: Number of qubits; ``None`` keeps ``self.n_qubits``.
            n_layers: Number of entangling layers; ``None`` keeps
                ``self.n_layers``.
            q_epochs: Number of quantum update steps; ``None`` keeps
                ``self.q_epochs``.

        The trained models are stored on ``self`` only after every stage has
        finished, so a cancelled or failed run never leaves unfitted
        estimators behind for ``compare_models``.
        """
        def cancelled():
            # Report cancellation to the GUI exactly once, at a checkpoint.
            if self.stop_event.is_set():
                self._queue_msg(("done", "Training cancelled."))
                return True
            return False

        try:
            # Results of an earlier run must not survive a new (possibly
            # cancelled) one.
            self.models = {}
            self.q_model = None
            self.q_params = None

            # 1) Update quantum params & device
            if n_qubits is not None:
                self.n_qubits = int(n_qubits)
            if n_layers is not None:
                self.n_layers = int(n_layers)
            if q_epochs is not None:
                self.q_epochs = int(q_epochs)
            if PENNYLANE_AVAILABLE:
                self.device = qml.device("default.qubit", wires=self.n_qubits)
            else:
                self.device = None
            self._queue_msg(("status", f"Quantum params: {self.n_qubits} qubits, {self.n_layers} layers, {self.q_epochs} epochs"))
            self._queue_progress(8)

            # Work on one consistent snapshot of the training arrays.
            X_train, y_train = self.X_train, self.y_train

            if cancelled():
                return

            # 2) Train Quantum VQC (fast demo loop)
            q_model, q_params = None, None
            if PENNYLANE_AVAILABLE:
                q_model, q_params = self._train_quantum_vqc(
                    X_train, y_train, epochs=self.q_epochs
                )
            else:
                self._queue_msg(("status", "PennyLane not installed – skipping quantum model."))
            # The quantum loop reports progress up to 55; continue from there
            # so the bar never moves backwards.
            self._queue_progress(55)
            if cancelled():
                return

            # 3) Initialize classical models
            rf = RandomForestClassifier(n_estimators=120, random_state=42)
            svm = SVC(kernel='rbf', probability=True, random_state=42)
            mlp = MLPClassifier(hidden_layer_sizes=(24, 12), max_iter=1000, random_state=42)

            # 4) Fit classical models, checking for a stop request in between
            self._queue_msg(("status", "Fitting Random Forest…"))
            rf.fit(X_train, y_train)
            self._queue_progress(60)
            if cancelled():
                return

            self._queue_msg(("status", "Fitting SVM…"))
            svm.fit(X_train, y_train)
            self._queue_progress(80)
            if cancelled():
                return

            self._queue_msg(("status", "Fitting Neural Network…"))
            mlp.fit(X_train, y_train)
            self._queue_progress(96)

            # Publish everything at once, now that all models are fitted.
            self.q_model, self.q_params = q_model, q_params
            self.models = {
                "Random Forest": rf,
                "SVM": svm,
                "Neural Network": mlp
            }

            # Done!
            self._queue_msg(("done", "Training completed."))
        except Exception as e:
            self._queue_msg(("error", str(e)))

    def _train_quantum_vqc(self, X, y, epochs=6):
        """Train a small variational quantum classifier with PennyLane.

        The circuit angle-embeds the first ``n_qubits`` features (zero-padded
        when there are fewer), applies ``StronglyEntanglingLayers`` and
        measures the PauliZ expectation on wire 0. Each "epoch" is a single
        Adam step on one randomly drawn training sample.

        Args:
            X: Training samples, indexable by row.
            y: Labels in {0, 1}, aligned with ``X``.
            epochs: Number of single-sample update steps.

        Returns:
            ``(qnode, params)`` where ``qnode(params, x)`` yields the circuit
            output for one sample, or ``(None, None)`` when PennyLane is not
            installed. A positive output means class 1.
        """
        n_qubits = self.n_qubits
        n_layers = self.n_layers

        if not PENNYLANE_AVAILABLE:
            return None, None

        def embed(x):
            # Ensure length == n_qubits (pad with zeros or truncate)
            x = np.asarray(x, dtype=float)
            padded = np.zeros(n_qubits, dtype=float)
            m = min(n_qubits, x.shape[0])
            padded[:m] = x[:m]
            return padded

        @qml.qnode(self.device)
        def qnode(params, x):
            qml.AngleEmbedding(embed(x), wires=range(n_qubits))
            qml.StronglyEntanglingLayers(params, wires=range(n_qubits))
            return qml.expval(qml.PauliZ(0))

        # The weights must be a PennyLane tensor flagged as trainable so the
        # optimizer differentiates with respect to them.
        params = qml.numpy.array(
            np.random.normal(0, np.pi, (n_layers, n_qubits, 3)),
            requires_grad=True,
        )

        # Squared error between the circuit output and the +/-1 target
        def loss_fn(params, x, y_true):
            y_hat = qnode(params, x)
            return (y_true - y_hat) ** 2

        # Optimizer
        opt = qml.AdamOptimizer(0.08)
        n = len(X)
        if n == 0:
            # Nothing to learn from; hand back the untrained circuit.
            return qnode, params

        for ep in range(epochs):
            if self.stop_event.is_set():
                self._queue_msg(("status", "Quantum training stopped."))
                break
            # random single-sample update for speed
            idx = np.random.randint(0, n)
            xb = X[idx]
            # PauliZ expectations lie in [-1, 1], so labels {0, 1} are mapped
            # to {-1, +1}; this matches thresholding the output at 0.
            target = 2.0 * float(y[idx]) - 1.0
            # Sample and target go in as keyword arguments: with several
            # positional arguments step() returns a list of all of them
            # instead of just the updated weights.
            params = opt.step(loss_fn, params, x=xb, y_true=target)
            self._queue_msg(("status", f"Quantum epoch {ep+1}/{epochs}"))
            # Progress between 30 and 55 during quantum
            prog = min(55, 30 + int((ep + 1) / max(1, epochs) * 25))
            self._queue_progress(prog)

        return qnode, params

    # ------------------------ COMPARISON & PLOTS ------------------------ #
    def compare_models(self):
        """Score every trained model on the test split and show the results.

        The quantum model (when present) predicts class 1 for a positive
        circuit output; the classical models use ``predict``. Accuracies go
        into the results table and the bar chart. A warning is shown when
        there is nothing to compare, and any exception is reported in an
        error dialog.
        """
        try:
            scores = []

            # Quantum predictions (if available)
            circuit, weights = self.q_model, self.q_params
            if circuit is not None and weights is not None:
                votes = []
                for sample in self.X_test:
                    votes.append(1 if circuit(weights, sample) > 0 else 0)
                quantum_accuracy = accuracy_score(self.y_test, np.array(votes))
                scores.append(("Quantum VQC", quantum_accuracy))

            # Classical
            for label, estimator in self.models.items():
                predicted = estimator.predict(self.X_test)
                scores.append((label, accuracy_score(self.y_test, predicted)))

            if len(scores) == 0:
                messagebox.showwarning("No models", "Please train models first.")
                return

            self._update_results_table(scores)
            self._update_comparison_plot(scores)
            self._set_status("Comparison updated.")
        except Exception as e:
            messagebox.showerror("Error", f"Comparison failed:\n{e}")

    def _init_model_tab(self):
        """Build the "Model Comparison" tab: results table above a bar chart.

        Creates ``results_tree`` (columns "Model" and "Accuracy") and the
        ``comp_fig`` / ``comp_ax`` / ``comp_canvas`` trio used by
        ``_update_comparison_plot``.
        """
        # The chart is embedded in a Tk canvas, so it is created as a plain
        # Figure rather than through pyplot, which would also register it
        # with pyplot's global figure manager.
        from matplotlib.figure import Figure

        frame = ttk.Frame(self.tab_models)
        frame.pack(fill=tk.BOTH, expand=True, pady=8)

        self.results_tree = ttk.Treeview(frame, columns=("Model", "Accuracy"), show='headings', height=6)
        self.results_tree.heading('Model', text='Model')
        self.results_tree.heading('Accuracy', text='Accuracy')
        self.results_tree.column('Model', width=240)
        self.results_tree.column('Accuracy', width=140, anchor=tk.CENTER)
        self.results_tree.pack(fill=tk.X, padx=8)

        # Bar plot underneath
        self.comp_fig = Figure(figsize=(8, 4), dpi=100)
        self.comp_ax = self.comp_fig.add_subplot(111)
        self.comp_canvas = FigureCanvasTkAgg(self.comp_fig, master=self.tab_models)
        self.comp_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def _update_results_table(self, results):
        self.results_tree.delete(*self.results_tree.get_children())
        for name, acc in results:
            self.results_tree.insert('', 'end', values=(name, f"{acc:.2%}"))

    def _update_comparison_plot(self, results):
        ax = self.comp_ax
        ax.clear()
        names = [r[0] for r in results]
        accs = [r[1] for r in results]
        bars = ax.bar(names, accs)
        ax.set_title("Model Accuracy Comparison")
        ax.set_ylabel("Accuracy")
        ax.set_ylim(0, 1)
        for b in bars:
            h = b.get_height()
            ax.text(b.get_x() + b.get_width()/2., h, f"{h:.2%}", ha='center', va='bottom')
        self.comp_canvas.draw()

    # ------------------------ THREAD-SAFE UI HELPERS ------------------------ #
    def _queue_msg(self, item):
        self.msg_queue.put(item)

    def _queue_progress(self, val):
        self.msg_queue.put(("progress", val))

    def _set_status(self, text):
        self.status_var.set(text)
        self.status_lbl.update_idletasks()

    def _poll_queue(self):
        try:
            while True:
                item = self.msg_queue.get_nowait()
                kind = item[0]
                if kind == "status":
                    self._set_status(item[1])
                elif kind == "progress":
                    self.progress['value'] = min(100, float(item[1]))
                elif kind == "done":
                    self._set_status(item[1])
                    self.btn_stop.config(state=tk.DISABLED)
                    self.btn_train.config(state=tk.NORMAL)
                    self.btn_compare.config(state=tk.NORMAL)
                    self.progress['value'] = 100
                    # Update circuit tab text to indicate trained
                    self.circuit_ax.clear()
                    self.circuit_ax.text(0.5, 0.5, "Quantum circuit trained\n(abstracted – see code)",
                                         ha='center', va='center')
                    self.circuit_ax.axis('off')
                    self.circuit_canvas.draw()
                elif kind == "error":
                    messagebox.showerror("Training Error", item[1])
                    self.btn_stop.config(state=tk.DISABLED)
                    self.btn_train.config(state=tk.NORMAL)
                    self.btn_compare.config(state=tk.DISABLED)
                    self._set_status("Error during training.")
        except queue.Empty:
            pass
        # Continue polling
        self.root.after(120, self._poll_queue)

    def _update_quantum_params(self):
        """Copy the spinbox values into the quantum settings.

        Used as the spinboxes' ``command`` callback. Updates ``n_qubits``,
        ``n_layers`` and ``q_epochs`` in that order and, when PennyLane is
        available, recreates the simulator device for the new qubit count.
        """
        bindings = (
            ("n_qubits", self.qubits_var),
            ("n_layers", self.layers_var),
            ("q_epochs", self.epochs_var),
        )
        for attribute, variable in bindings:
            setattr(self, attribute, int(variable.get()))

        if not PENNYLANE_AVAILABLE:
            return
        self.device = qml.device("default.qubit", wires=self.n_qubits)

    # ------------------------ RUN ------------------------ #
    def run(self):
        self.root.mainloop()


if __name__ == "__main__":
    # Script entry point: create the Tk root, build the app on it and block
    # in the event loop.
    app = QuantumMarketClassifier(tk.Tk())
    root = app.root
    app.run()
