# In [45]:
import tkinter as tk
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk

class tkApp:
    """Tkinter window that integrates and plots SIR / SEIR epidemic models.

    The left panel holds sliders for the rates (beta, gamma, mu, sigma) and
    the initial population fractions (S0, I0, E0), plus two check boxes that
    pick the model variant: "with demography" switches on the birth/death
    rate mu, and "with Exposed" switches from SIR to SEIR.  *Calculate*
    integrates the selected model with a classical 4th-order Runge-Kutta
    scheme and draws the curves in the right panel.

    Constructing an instance builds the window and then calls
    ``self.root.mainloop()``, so the constructor returns only once the Tk
    main loop ends (e.g. after *Quit* destroys the root window).
    """

    # RK4 step count and step size; the plotted horizon is N_STEPS * DT.
    N_STEPS = 100
    DT = 0.5

    def __init__(self):
        #---------------initialize frame-------------
        self.root = tk.Tk()
        self.root.geometry("1200x800")
        self.root.title("SIR Model for Infectious Diseases")
        self.root.update()

        self.left_frame = tk.Frame(master=self.root)
        self.right_frame = tk.Frame(master=self.root)
        self.button_frame = tk.Frame(master=self.root)

        # Output widgets; destroyed and recreated on every Calculate press.
        self.widget = None
        self.toolbar = None
        self.plot_label = None

        #----------initialize variables----------
        # The check-box variables are read directly when Calculate is pressed,
        # so the model choice is always defined (0 = unchecked) even if a box
        # was never clicked.
        self.beta_var = tk.StringVar()
        self.gamma_var = tk.StringVar()
        self.mu_var = tk.StringVar()
        self.sigma_var = tk.StringVar()
        self.s0_var = tk.StringVar()
        self.i0_var = tk.StringVar()
        self.e0_var = tk.StringVar()
        self.demography_var = tk.IntVar()
        self.exposed_var = tk.IntVar()

        #-------------left frame: one row per control--------
        self.beta_frame = tk.Frame(master=self.left_frame)
        self.beta_label, self.beta_slider = self._add_slider(
            self.beta_frame, 'beta', 2, 0.1, self.beta_var)

        self.gamma_frame = tk.Frame(master=self.left_frame)
        self.gamma_label, self.gamma_slider = self._add_slider(
            self.gamma_frame, 'gamma', 1, 0.1, self.gamma_var)

        self.mu_frame = tk.Frame(master=self.left_frame)
        self.mu_label, self.mu_slider = self._add_slider(
            self.mu_frame, 'mu', 1, 0.00001, self.mu_var)
        self.mu_checkbutton = tk.Checkbutton(self.mu_frame,
                                             text='Check for with demography',
                                             variable=self.demography_var)
        self.mu_checkbutton.pack(side='left')

        self.S_frame = tk.Frame(master=self.left_frame)
        self.S_label, self.S_slider = self._add_slider(
            self.S_frame, 'S0', 1, 0.01, self.s0_var)

        self.I_frame = tk.Frame(master=self.left_frame)
        self.I_label, self.I_slider = self._add_slider(
            self.I_frame, 'I0', 1, 0.01, self.i0_var)

        self.E_frame = tk.Frame(master=self.left_frame)
        self.E_label, self.E_slider = self._add_slider(
            self.E_frame, 'E0', 1, 0.01, self.e0_var)
        self.E_checkbutton = tk.Checkbutton(self.E_frame,
                                            text='Check for with Exposed',
                                            variable=self.exposed_var)
        self.E_checkbutton.pack(side='left')

        self.sigma_frame = tk.Frame(master=self.left_frame)
        self.sigma_label, self.sigma_slider = self._add_slider(
            self.sigma_frame, 'sigma', 1, .01, self.sigma_var)

        #button widgets
        self.calc_button = tk.Button(self.button_frame,
                                     text='Calculate',
                                     command=self._on_calculate,
                                     height=2,
                                     width=10)
        self.quit_button = tk.Button(self.button_frame,
                                     text='Quit',
                                     command=self.root.destroy)
        self.calc_button.pack(side='left')
        self.quit_button.pack(side='left')

        #--------pack frames----------
        for row in (self.beta_frame, self.gamma_frame, self.mu_frame,
                    self.S_frame, self.I_frame, self.E_frame,
                    self.sigma_frame):
            row.pack(side='top')

        self.left_frame.pack(side='left')
        self.right_frame.pack(side='left')
        self.button_frame.pack(side='bottom')

        self.root.mainloop()

    @staticmethod
    def _add_slider(parent, text, to, resolution, variable):
        """Pack a label and a horizontal 0..``to`` slider side by side into ``parent``.

        Returns the ``(label, slider)`` pair.
        """
        label = tk.Label(master=parent, text=text)
        slider = tk.Scale(master=parent,
                          from_=0,
                          to=to,
                          resolution=resolution,
                          orient=tk.HORIZONTAL,
                          variable=variable)
        label.pack(side='left')
        slider.pack(side='left')
        return label, slider

    #-------------4th order Runge Kutta method for solving ODEs---------
    @staticmethod
    def _rk4(deriv, y0, n, dt):
        """Integrate ``dy/dt = deriv(t, y)`` with ``n`` classical RK4 steps.

        Args:
            deriv: callable ``(t, y) -> tuple`` returning one derivative per
                component of the state tuple ``y``.
            y0: initial state, a sequence of floats.
            n: number of steps.
            dt: step size.

        Returns:
            A list with one list per state component, each holding the
            n+1 values at times ``0, dt, ..., n*dt``.
        """
        half = dt / 2
        y = tuple(y0)
        states = [y]
        for step in range(n):
            t = step * dt
            k1 = deriv(t, y)
            k2 = deriv(t + half, tuple(v + half * d for v, d in zip(y, k1)))
            k3 = deriv(t + half, tuple(v + half * d for v, d in zip(y, k2)))
            k4 = deriv(t + dt, tuple(v + dt * d for v, d in zip(y, k3)))
            y = tuple(v + dt / 6 * (d1 + 2 * d2 + 2 * d3 + d4)
                      for v, d1, d2, d3, d4 in zip(y, k1, k2, k3, k4))
            states.append(y)
        # Transpose the list of states into one series per component.
        return [list(series) for series in zip(*states)]

    @classmethod
    def _simulate(cls, beta, gamma, s0, i0, mu=0.0, sigma=None, e0=0.0,
                  n=N_STEPS, dt=DT):
        """Integrate the SIR model, or the SEIR model when ``sigma`` is given.

        Compartments are fractions of the population, and the recovered
        fraction is whatever remains: R = 1 - S - I (SIR) or
        R = 1 - S - E - I (SEIR).  ``mu`` is the birth/death rate of the
        model with demography; ``mu=0`` reduces it to the model without
        demography.  ``e0`` is ignored for SIR.

        Returns:
            ``(t, curves)``: ``t`` is the list of the n+1 sample times
            ``0, dt, ..., n*dt`` and ``curves`` is a list of
            ``(label, matplotlib format string, values)`` triples in plotting
            order (susceptible, [exposed,] infected, recovered).

        Raises:
            ValueError: if an initial fraction is negative or the initial
                fractions sum to more than 1 (R0 would be negative).
        """
        initial = (s0, i0) if sigma is None else (s0, e0, i0)
        # Small tolerance so floating-point rounding of slider values that
        # add up to 1 is not rejected.
        if min(initial) < 0 or sum(initial) > 1 + 1e-9:
            raise ValueError(
                'initial fractions must be non-negative and sum to at most 1')

        if sigma is None:
            def deriv(t, y):
                s, i = y
                infection = beta * s * i
                return (mu - infection - mu * s,
                        infection - gamma * i - mu * i)

            s, i = cls._rk4(deriv, initial, n, dt)
            curves = [('susceptible', 'r', s), ('infected', 'b', i)]
        else:
            def deriv(t, y):
                s, e, i = y
                infection = beta * s * i
                return (mu - infection - mu * s,
                        infection - (mu + sigma) * e,
                        sigma * e - (mu + gamma) * i)

            s, e, i = cls._rk4(deriv, initial, n, dt)
            curves = [('susceptible', 'r', s), ('exposed', 'k', e),
                      ('infected', 'b', i)]

        # Whatever is not S, (E) or I has recovered.
        recovered = [1 - sum(vals) for vals in zip(*(c[2] for c in curves))]
        curves.append(('recovered', 'g', recovered))
        t = [step * dt for step in range(n + 1)]
        return t, curves

    def _on_calculate(self):
        """Read the controls, integrate the selected model and redraw the plot."""
        # Remove the previous result before drawing a new one.
        for old in (self.widget, self.toolbar, self.plot_label):
            if old is not None:
                old.destroy()
        self.widget = self.toolbar = self.plot_label = None

        with_demography = bool(self.demography_var.get())
        with_exposed = bool(self.exposed_var.get())

        beta = float(self.beta_var.get())
        gamma = float(self.gamma_var.get())
        s0 = float(self.s0_var.get())
        i0 = float(self.i0_var.get())
        # mu, sigma and E0 only enter the model when their box is checked.
        mu = float(self.mu_var.get()) if with_demography else 0.0
        if with_exposed:
            sigma = float(self.sigma_var.get())
            e0 = float(self.e0_var.get())
        else:
            sigma, e0 = None, 0.0

        try:
            t, curves = self._simulate(beta, gamma, s0, i0, mu=mu,
                                       sigma=sigma, e0=e0,
                                       n=self.N_STEPS, dt=self.DT)
        except ValueError as err:
            # Report invalid initial conditions instead of plotting a
            # negative recovered fraction.
            self.plot_label = tk.Label(master=self.right_frame,
                                       text=f'Cannot solve: {err}')
            self.plot_label.pack()
            return

        model = 'SEIR' if with_exposed else 'SIR'
        variant = 'with' if with_demography else 'without'
        self.plot_label = tk.Label(master=self.right_frame,
                                   text=f'Solved {model} {variant} demography')
        self.plot_label.pack()

        fig = Figure(figsize=(5, 5), dpi=100)
        ax = fig.subplots()
        for label, fmt, values in curves:
            ax.plot(t, values, fmt, label=label)
        ax.set_xlabel('time')
        ax.legend()

        canvas = FigureCanvasTkAgg(fig, master=self.right_frame)
        canvas.draw()

        self.toolbar = NavigationToolbar2Tk(canvas, self.right_frame)
        self.toolbar.update()

        self.widget = canvas.get_tk_widget()
        self.widget.pack()


# ---------- script entry point ----------
# Instantiating tkApp builds the window and runs the Tk main loop, so this
# statement only returns after the window has been closed.
if __name__ == '__main__':
    TK_Window = tkApp()