# In [18]:
import tkinter as tk
from tkinter import ttk, messagebox, filedialog
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from RK4_models import rk4_sir, rk4_sir_demog, rk4_seir, rk4_seir_demog

class SIRModelApp:
    """Tkinter GUI for running SIR/SEIR epidemic models and plotting the results.

    Model parameters and initial compartment fractions are chosen with sliders
    (or loaded from a preset); "Calculate" integrates the selected model with
    the ``rk4_*`` solvers from ``RK4_models`` and embeds a Matplotlib figure in
    the window. Constructing the class builds the window and then blocks inside
    ``mainloop()`` until the window is closed.
    """

    # Preset parameter sets for the "Parameter Presets" combobox. Each entry maps
    # the name of a Tk-variable attribute (see _init_variables) to the value that
    # is loaded into it. "Custom" has no entry and leaves the current values alone.
    _PRESETS = {
        "Measles": {
            "b": 1.5, "g": 0.1, "m": 0.0, "sig": 0.2,
            "sinitial": 0.95, "iinitial": 0.05, "einitial": 0.0,
            "statusmu": 0, "statuse": 0, "days": 60,
        },
        "COVID-19": {
            "b": 0.45, "g": 0.1, "m": 0.00003, "sig": 0.2,
            "sinitial": 0.989, "iinitial": 0.01, "einitial": 0.001,
            "statusmu": 1, "statuse": 1, "days": 100,
        },
        "Flu": {
            "b": 0.3, "g": 0.2, "m": 0.00003, "sig": 0.33,
            "sinitial": 0.97, "iinitial": 0.02, "einitial": 0.005,
            "statusmu": 1, "statuse": 1, "days": 60,
        },
    }

    def __init__(self):
        """Build the main window and run the Tk event loop (blocks until closed)."""
        self.root = tk.Tk()
        self.root.geometry("1400x900")
        self.root.title("SIR Model for Infectious Diseases")

        # Widgets of the currently displayed plot; None until the first run.
        self.widget = None
        self.toolbar = None
        self.plot_label = None
        # Figure of the most recent successful run, used by save_plot().
        self.fig = None

        self._init_variables()
        self._init_frames()
        self._init_widgets()

        # Blocks here until the window is destroyed (e.g. via the Quit button).
        self.root.mainloop()

    def _init_variables(self):
        """Create the Tk variables that back the sliders and checkboxes."""
        self.b = tk.DoubleVar(value=0.5)          # β slider ("Transmission coefficient")
        self.g = tk.DoubleVar(value=0.1)          # γ slider ("Recovery rate")
        self.m = tk.DoubleVar(value=0.0)          # μ slider ("Birth/death rate")
        self.sig = tk.DoubleVar(value=0.1)        # σ slider (labelled "Latent period" in the UI)
        self.days = tk.IntVar(value=30)           # number of simulated days
        self.sinitial = tk.DoubleVar(value=0.9)   # S₀ fraction
        self.iinitial = tk.DoubleVar(value=0.1)   # I₀ fraction
        self.einitial = tk.DoubleVar(value=0.0)   # E₀ fraction (used only when statuse is set)
        self.statusmu = tk.IntVar(value=0)        # 1 = use the "with demography" solvers
        self.statuse = tk.IntVar(value=0)         # 1 = use the SEIR solvers / include E₀

    def _init_frames(self):
        """Create and lay out the controls column, plot area and button bar."""
        self.left_frame = tk.Frame(self.root, width=300)
        self.right_frame = tk.Frame(self.root)
        self.button_frame = tk.Frame(self.root)

        # Pack order matters: each slave's parcel is carved out of the space left
        # over by the slaves packed before it. Packing the bottom button bar first
        # gives it the full window width; packed after the two side='left'
        # frames it would be squeezed into a narrow column right of the plot.
        self.button_frame.pack(side='bottom', fill='x', pady=10)
        self.left_frame.pack(side='left', fill='y', padx=10, pady=10)
        self.right_frame.pack(side='left', fill='both', expand=True, padx=10, pady=10)

    def _init_widgets(self):
        """Create the preset selector, parameter sliders and action buttons."""
        # Parameter presets dropdown
        preset_label = tk.Label(self.left_frame, text="Parameter Presets")
        preset_label.pack()
        self.preset_combo = ttk.Combobox(self.left_frame, state="readonly")
        self.preset_combo['values'] = ["Custom", "Measles", "COVID-19", "Flu"]
        self.preset_combo.current(0)
        self.preset_combo.pack()
        self.preset_combo.bind("<<ComboboxSelected>>", self._load_preset)

        # Sliders (μ and E₀ also get an "Enable" checkbox bound to statusmu / statuse)
        self.s_slider = self._create_slider(self.left_frame, 'β', self.b, 0, 2, 0.01, "Transmission coefficient")
        self.g_slider = self._create_slider(self.left_frame, 'γ', self.g, 0, 1, 0.01, "Recovery rate")
        self.m_slider = self._create_slider(self.left_frame, 'μ', self.m, 0, 0.5, 0.0001, "Birth/death rate", checkbutton=True, var=self.statusmu)
        self.sig_slider = self._create_slider(self.left_frame, 'σ', self.sig, 0, 1, 0.01, "Latent period")
        self.sinit_slider = self._create_slider(self.left_frame, 'S₀', self.sinitial, 0, 1, 0.01, "Initial susceptible")
        self.iinit_slider = self._create_slider(self.left_frame, 'I₀', self.iinitial, 0, 1, 0.01, "Initial infected")
        self.einit_slider = self._create_slider(self.left_frame, 'E₀', self.einitial, 0, 1, 0.01, "Initial exposed", checkbutton=True, var=self.statuse)
        self.days_slider = self._create_slider(self.left_frame, 'Days', self.days, 1, 100, 1, "Simulation days", is_integer=True)

        # Buttons
        self.calc_button = tk.Button(self.button_frame, text='Calculate', command=self.solve, height=2, width=10)
        self.save_button = tk.Button(self.button_frame, text='Save Plot', command=self.save_plot, height=2, width=10)
        self.quit_button = tk.Button(self.button_frame, text='Quit', command=self.root.destroy)
        self.calc_button.pack(side='left', padx=5, pady=5)
        self.save_button.pack(side='left', padx=5, pady=5)
        self.quit_button.pack(side='left', padx=5, pady=5)

    def _create_slider(self, parent, label, variable, frm, to, res, info, checkbutton=False, var=None, is_integer=False):
        """Create one labelled horizontal slider row and return its Scale widget.

        The row holds a "?" button that shows ``info`` in a dialog, the text
        ``label``, a Scale from ``frm`` to ``to`` with step ``res`` bound to
        ``variable``, and, when ``checkbutton`` is true and ``var`` is given,
        an "Enable <label>" checkbox bound to ``var``. ``is_integer`` is
        accepted but currently unused.
        """
        frame = tk.Frame(parent)
        info_button = tk.Button(frame, text='?', command=lambda: messagebox.showinfo(f'{label} info', info))
        text_label = tk.Label(frame, text=label)
        slider = tk.Scale(frame, from_=frm, to=to, resolution=res, orient='horizontal', variable=variable)
        info_button.pack(side='left')
        text_label.pack(side='left')
        slider.pack(side='left')
        if checkbutton and var is not None:
            check = tk.Checkbutton(frame, text=f"Enable {label}", variable=var)
            check.pack(side='left')
        frame.pack(side='top')
        return slider

    def _load_preset(self, event=None):
        """Load the preset selected in the combobox into the Tk variables.

        Bound to ``<<ComboboxSelected>>``; ``event`` is unused. Selecting
        "Custom" (or any name without an entry in _PRESETS) changes nothing.
        """
        preset = self._PRESETS.get(self.preset_combo.get(), {})
        for attr, value in preset.items():
            getattr(self, attr).set(value)

    @staticmethod
    def _initial_recovered(s0, i0, e0, tol=1e-9):
        """Return the initial recovered fraction ``1 - (s0 + i0 + e0)``.

        Returns None when ``s0 + i0 + e0`` exceeds 1 by more than ``tol``.
        The tolerance and the clamp at 0 absorb float round-off, so inputs
        that sum to exactly 1 on paper (e.g. 0.9 + 0.1) are accepted and
        yield 0.0 rather than a tiny negative number.
        """
        total = s0 + i0 + e0
        if total > 1.0 + tol:
            return None
        return max(0.0, 1.0 - total)

    def _clear_plot(self):
        """Destroy the canvas widget, toolbar and label of the previous plot, if any."""
        for attr in ("widget", "toolbar", "plot_label"):
            old = getattr(self, attr)
            if old is not None:
                old.destroy()
                setattr(self, attr, None)

    def _run_model(self, n, dt, s0, e0, i0, r0):
        """Integrate the model selected by the μ / E₀ checkboxes.

        Returns ``(title, curves, infected)``. ``curves`` is a list of
        ``(series, matplotlib_format, legend_label)`` tuples in plotting order,
        and ``infected`` is the I series returned by the solver.
        """
        beta = self.b.get()
        gamma = self.g.get()
        mu = self.m.get()
        sigma = self.sig.get()
        with_demography = bool(self.statusmu.get())
        with_exposed = bool(self.statuse.get())

        if with_exposed:
            if with_demography:
                s, e, i, r = rk4_seir_demog(n, beta, gamma, mu, sigma, s0, e0, i0, r0, dt)
            else:
                s, e, i, r = rk4_seir(n, beta, gamma, sigma, s0, e0, i0, r0, dt)
            curves = [(s, 'r', 'Susceptible'), (e, 'k', 'Exposed'),
                      (i, 'b', 'Infected'), (r, 'g', 'Recovered')]
        else:
            if with_demography:
                s, i, r = rk4_sir_demog(n, beta, gamma, mu, s0, i0, r0, dt)
            else:
                s, i, r = rk4_sir(n, beta, gamma, s0, i0, r0, dt)
            curves = [(s, 'r', 'Susceptible'), (i, 'b', 'Infected'),
                      (r, 'g', 'Recovered')]

        model = 'SEIR' if with_exposed else 'SIR'
        qualifier = 'with' if with_demography else 'without'
        return f"{model} {qualifier} Demography", curves, i

    def _draw_plot(self, n, title, curves, infected):
        """Build the figure and embed it, with a toolbar and title label, in right_frame."""
        fig = Figure(figsize=(7, 6), dpi=100)
        ax = fig.add_subplot(111)
        # One x value per day 0..n. NOTE(review): assumes the rk4_* solvers
        # return n + 1 samples per series — confirm against RK4_models.
        t = list(range(n + 1))
        for series, fmt, name in curves:
            ax.plot(t, series, fmt, label=name)

        # Plot cumulative infections as dashed line on secondary y-axis.
        # NOTE(review): this is the running sum of the infected *fraction* at
        # each sample (infected-fraction-days), not a count of new infections;
        # confirm this is the intended "Cumulative Infections" metric.
        cumulative = np.cumsum(infected)
        ax2 = ax.twinx()
        ax2.plot(t, cumulative, 'm--', label='Cumulative Infected')
        ax2.set_ylabel('Cumulative Infections', color='m')
        ax2.tick_params(axis='y', colors='m')

        ax.set_title(title)
        ax.set_xlabel("Days")
        ax.set_ylabel("Population Fraction")
        ax.legend(loc='upper left')
        ax2.legend(loc='upper right')

        self.plot_label = tk.Label(self.right_frame, text=title)
        self.plot_label.pack()

        canvas = FigureCanvasTkAgg(fig, master=self.right_frame)
        self.toolbar = NavigationToolbar2Tk(canvas, self.right_frame)
        self.toolbar.update()
        self.widget = canvas.get_tk_widget()

        canvas.draw()
        self.widget.pack()

        # Save fig reference for saving later
        self.fig = fig

    def solve(self):
        """Validate the inputs, run the selected model and replace the displayed plot.

        If S₀ + I₀ + E₀ exceeds 1, an error dialog is shown and the current
        plot (and the figure save_plot() would save) is left untouched.
        """
        n = self.days.get()
        # Passed to the solvers as `dt` — presumably the RK4 step; finer for short runs.
        dt = 0.25 if n < 30 else 0.5 if n < 70 else 1

        s0 = self.sinitial.get()
        i0 = self.iinitial.get()
        # E₀ only counts when the exposed compartment is enabled.
        e0 = self.einitial.get() if self.statuse.get() else 0

        r0 = self._initial_recovered(s0, i0, e0)
        if r0 is None:
            total = s0 + i0 + e0
            messagebox.showerror("Input Error", f"S₀ + I₀ + E₀ = {total:.2f} > 1. Adjust sliders so total ≤ 1.")
            return

        title, curves, infected = self._run_model(n, dt, s0, e0, i0, r0)
        # Only remove the old plot once the new results are available.
        self._clear_plot()
        self._draw_plot(n, title, curves, infected)

    def save_plot(self):
        """Ask for a file name and save the most recent figure there.

        Warns if no simulation has been run yet; reports write/format errors
        from ``savefig`` in an error dialog instead of letting them escape
        the Tk callback.
        """
        fig = getattr(self, "fig", None)
        if fig is None:
            messagebox.showwarning("Save Plot", "No plot to save. Please run a simulation first.")
            return
        file_path = filedialog.asksaveasfilename(defaultextension=".png",
                                                 filetypes=[("PNG files", "*.png"), ("All files", "*.*")])
        if not file_path:
            # Dialog was cancelled.
            return
        try:
            fig.savefig(file_path)
        except (OSError, ValueError) as exc:
            messagebox.showerror("Save Plot", f"Could not save plot:\n{exc}")
            return
        messagebox.showinfo("Save Plot", f"Plot saved to {file_path}")

if __name__ == "__main__":
    # The constructor builds the window and blocks in Tk's mainloop until it is closed.
    app = SIRModelApp()
