# In [7]:
#%%writefile wavesolverInPML.py
import numpy as np
import time
from scipy import weave
from scipy.weave import converters

class WaveSolverInPML(object):

    """
    WaveSolver       Class calculates the wave equation in a (possibly
                     inhomogeneous) medium with a perfectly matched layer
    nx, ny:          size of the calculation grid, excluding the PML
    wavelength:      wavelength of the wave in the background medium
    CFL:             CFL parameter dt*c/dx should be smaller than 1/sqrt(2)
    sim_duration:    duration of simulation
    pml_length:      size of the PML layer (grid cells added on every side)
    plotcallback:    callback function for plotting/saving the result;
                     it receives two parameters: the field without the PML
                     and the current time
    output:          number of outputs per period of the source
    slowdown:        sleep `slowdown` seconds after the calculation of each output
    rho1, c1:        density and speed of sound of background phase
    rho2, c2:        density and speed of sound of scatterers
    src_cycles:      duration of the src in cycles (1/frequency); a value
                     equal to 1 selects a single Gaussian-modulated pulse
    scr1pos, scr2pos: x grid index (excluding the PML) of the two intensity
                     screens; a negative value disables the screen

    Following the numerical scheme derived in http://arxiv.org/pdf/1001.0319v1.pdf
    Marcus J. Grote and Imbo Sim for the PML, accounts for inhomogeneous medium (rho & c)
    """

    def __init__(self, nx=128, ny=128, wavelength=.3, CFL=0.1,
                 sim_duration=2., output=5, slowdown=0, plotcallback=None,
                 pml_length=16, rho1=1., c1=1., rho2=5., c2=1., src_cycles=1,
                 scr1pos=-1, scr2pos=-1):

        self.__pml_length = pml_length
        # Grid sizes including the PML on both sides of each axis.
        self.__nx = nx+2*self.pml_length
        self.__ny = ny+2*self.pml_length
        self.__size = 2. #excluding the PML
        # Physical extent without the PML; the y extent is scaled by ny/nx.
        self.__xsize = 2.
        self.__ysize = self.__xsize/nx*ny
        self.__dx = self.__xsize/(self.__nx-2*self.pml_length-1)
        self.__c1 = c1
        self.__rho1 = rho1
        self.__c2 = c2
        self.__rho2 = rho2
        self.__sim_duration = sim_duration
        self.__slowdown = slowdown
        # The wavelength setter also derives the frequency nu and the source delay t0.
        self.wavelength = wavelength
        self.__u = np.zeros((self.__nx, self.__ny))    # amplitude at t
        self.__un = np.zeros((self.__nx, self.__ny))   # amplitude at t-dt
        self.__unn = np.zeros((self.__nx, self.__ny))  # amplitude at t-2*dt
        self.__rho = np.ones((self.__nx, self.__ny))*self.__rho1
        self.__cc = np.ones((self.__nx, self.__ny))*self.__c1*self.__c1
        # The CFL setter derives dt from dx and the largest speed of sound.
        self.CFL = CFL

        #PML variables
        self.__zeta1 = np.zeros(self.__nx)
        self.__zeta2 = np.zeros(self.__ny)
        self.__phix = np.zeros((self.__nx, self.__ny))
        self.__phiy = np.zeros((self.__nx, self.__ny))
        self.__phixn = np.zeros((self.__nx, self.__ny))
        self.__phiyn = np.zeros((self.__nx, self.__ny))
        self.__zetam = 20.

        #Screens: a negative position disables the screen. It must be checked
        #before shifting by the PML, otherwise -1 would become a valid index.
        self.__scr1 = np.zeros(self.__ny)
        self.__scr1pos = scr1pos+pml_length if scr1pos >= 0 else -1
        self.__scr2 = np.zeros(self.__ny)
        self.__scr2pos = scr2pos+pml_length if scr2pos >= 0 else -1

        self.output = output
        self.__n = 0  # index of the next timestep
        self.__src_function = None
        # `==` rather than `is`: identity of numbers is unreliable (1.0 is not 1).
        if src_cycles == 1:
            self.__src_timefunction = self.__gausspulse
            self.__src_duration = (1./self.__nu+self.__src_t0)*1.1
        else:
            self.__src_timefunction = self.__sinetrain
            self.__src_tmax = src_cycles/self.__nu+self.__src_t0
            self.__src_duration = (self.__src_t0+src_cycles/self.__nu+1./self.__nu*0.1)

        if plotcallback is not None:
            self.__plotcallback = plotcallback
        else:
            self.__plotcallback = self.simplecallback

        self.__absorber_y = []

        self.__init_pml()

    def __init_pml(self):
        """Fill the damping profiles zeta1 (along x) and zeta2 (along y) in the PML.

        The profile zetam*(xl - sin(2*pi*xl)/(2*pi)) is twice differentiable; it
        equals zetam at the outermost cell (xl = 1) and decreases towards the
        interior (xl -> 0). Interior cells keep zeta = 0.
        """
        for i in range(self.pml_length):
            xl = float(self.pml_length-i)/self.pml_length
            zeta = self.__zetam*(xl-np.sin(2.*np.pi*xl)/2./np.pi)
            self.__zeta1[i] = zeta
            self.__zeta1[self.__nx-i-1] = zeta
            self.__zeta2[i] = zeta
            self.__zeta2[self.__ny-i-1] = zeta

    def simplecallback(self, u, t):
        """Prints the current time"""
        print("time {0:.2}".format(t))

    def __plane_source_setup(self):
        """Configure a plane source along y at the left edge of the interior.

        Used by cfg_simple and cfg_simplediffraction.
        """
        self.__src_posx = self.pml_length
        # The source line runs along y (second array axis); slice bounds must be ints.
        self.__src_emissionlength = int(self.__ny-2*self.pml_length)
        self.__src_starty = int(self.__ny/2.-self.__src_emissionlength/2.)
        self.__src_function = self.__planesource

    def cfg_simple(self):
        """Plane wave entering from the left into the unmodified medium."""
        self.__plane_source_setup()
        #Timestepper
        self.__timestepper = self.__inhomogeneous_PML_stable

    def cfg_simplediffraction(self, disks=None, rects=None):
        """Plane wave from the left scattered by disks and/or rectangles.

        disks: structured array with fields 'radius', 'xpos', 'ypos'
        rects: structured array with fields 'x_up', 'y_up' (upper-left corner),
               'width', 'height'
        Coordinates are physical and exclude the PML. Cells inside a scatterer
        get speed c2 and density rho2. The caller's arrays are not modified.
        """
        self.__plane_source_setup()
        #need to account for PML size
        add_pml = self.pml_length*self.__xsize/(self.__nx-2.*self.pml_length)
        rx = self.__xsize/(self.__nx-2*self.pml_length)*self.__nx
        ry = self.__ysize/(self.__ny-2*self.pml_length)*self.__ny
        # 'ij' indexing gives xx, yy the same (nx, ny) layout as cc and rho.
        xx, yy = np.meshgrid(np.linspace(0., rx, self.__nx),
                             np.linspace(0., ry, self.__ny), indexing='ij')
        #Discs
        if disks is not None:
            for radius, xpos, ypos in zip(disks['radius'],
                                          disks['xpos']+add_pml,
                                          disks['ypos']+add_pml):
                inside = (xx-xpos)**2.+(yy-ypos)**2. < radius**2.
                self.__cc[inside] = self.__c2*self.__c2
                self.__rho[inside] = self.__rho2
        #Rectangles
        if rects is not None:
            for x_up, y_up, width, height in zip(rects['x_up']+add_pml,
                                                 rects['y_up']+add_pml,
                                                 rects['width'],
                                                 rects['height']):
                inside = ((xx >= x_up) & (xx <= x_up+width)
                          & (yy <= y_up) & (yy >= y_up-height))
                self.__cc[inside] = self.__c2*self.__c2
                self.__rho[inside] = self.__rho2
        # cc changed: recompute dt (CFL setter) and the number of timesteps per
        # output, which depends on dt.
        self.CFL = self.CFL
        self.output = self.__outputs_per_period
        #Timestepper
        self.__timestepper = self.__inhomogeneous_PML_stable

    def cfg_halfslit(self, ypos=.5, pos=0.3):
        """
        Half slit experiment, where border is oriented along y-axis,
        and the wave is coming from the left. The homogeneous wave
        equation is solved.
        ypos: y-position in real coordinates [0;y size]
        pos: x-position in real coordinates [0;size]
        """
        #Source
        self.__src_posx = self.pml_length
        self.__src_emissionlength = int(.95*(self.__ny-2.*self.pml_length))
        self.__src_starty = int(self.__ny/2.-self.__src_emissionlength/2.)
        self.__src_function = self.__planesource
        #Absorber: every y below ypos
        yy = np.linspace(0., self.__ysize, self.__ny-2*self.pml_length)
        self.__absorber_y = np.asarray(np.where(yy < ypos)[0])+self.pml_length
        self.__absorber_x = int(pos/self.__xsize*(self.__nx-2*self.pml_length))+self.pml_length
        #Timestepper
        self.__timestepper = self.__inhomogeneous_PML_stable

    def cfg_singleslit(self, ow=0.1, pos=0.3):
        """
        Single slit experiment, where slit is oriented along y-axis,
        and the wave is coming from the left. The homogeneous wave
        equation is solved.
        ow:  width of the slit
        pos: x-position in real coordinates [0;size]
        """
        #Source
        self.__src_posx = self.pml_length
        self.__src_emissionlength = int(.9*(self.__ny-2.*self.pml_length))
        self.__src_starty = int(self.__ny/2.-self.__src_emissionlength/2.)-1
        self.__src_function = self.__planesource
        #Absorber: every y outside the centred opening of width ow
        yy = np.linspace(0., self.__ysize, self.__ny-2*self.pml_length)
        self.__absorber_y = np.asarray(np.where((yy > self.__ysize/2.+ow/2.) |
                                                (yy < self.__ysize/2.-ow/2.))[0])+self.pml_length
        self.__absorber_x = int(pos/self.__xsize*(self.__nx-2*self.pml_length))+self.pml_length
        #Timestepper
        self.__timestepper = self.__inhomogeneous_PML_stable

    def cfg_doubleslit(self, ow=0.1, dw=1., pos=0.3):
        """
        Double slit experiment, where the slits are oriented along y-axis,
        and the wave is coming from the left. The homogeneous wave
        equation is solved.
        ow:  width of the slits
        dw: distance between the slits' centers
        pos: x-position between [0;size]
        """
        #Source
        self.__src_posx = self.pml_length
        self.__src_emissionlength = int(.9*(self.__ny-2.*self.pml_length))
        self.__src_starty = int(self.__ny/2.-self.__src_emissionlength/2.)
        self.__src_function = self.__planesource
        #Absorber: every y outside both slits, plus the wall between them
        yy = np.linspace(0., self.__ysize, self.__ny-2*self.pml_length)
        outside = (yy > self.__ysize/2.+dw/2.+ow/2.) | (yy < self.__ysize/2.-dw/2.-ow/2.)
        between = (yy > self.__ysize/2.-dw/2.+ow/2.) & (yy < self.__ysize/2.+dw/2.-ow/2.)
        self.__absorber_y = np.asarray(np.where(outside | between)[0])+self.pml_length
        self.__absorber_x = int(pos/self.__xsize*(self.__nx-2*self.pml_length))+self.pml_length
        #Timestepper
        self.__timestepper = self.__inhomogeneous_PML_stable

    def cfg_pointsources(self, points):
        """Use point sources.

        points: structured array with fields 'xpos', 'ypos' (physical
                coordinates without the PML) and 'ampl' (amplitude factor).
        """
        inner_x = self.__nx-2*self.pml_length
        inner_y = self.__ny-2*self.pml_length
        # astype(int) truncates like int(); unlike map() it is eager on Python 3.
        self.__src_posxa = (np.asarray(points['xpos'])/self.__xsize*inner_x).astype(int)+self.pml_length
        self.__src_posya = (np.asarray(points['ypos'])/self.__ysize*inner_y).astype(int)+self.pml_length
        self.__src_ampla = np.asarray(points['ampl'])
        self.__src_function = self.__pointsource
        #Timestepper
        self.__timestepper = self.__inhomogeneous_PML_stable

    def __pointsource(self, t):
        """Impose the point sources on the current field while the source is active."""
        if t < self.__src_duration:
            self.__un[self.__src_posxa, self.__src_posya] = \
                self.__src_ampla*self.__src_timefunction(t)

    def __planesource(self, t):
        """Impose the plane (line) source on the current field while the source is active."""
        if t < self.__src_duration:
            self.__un[self.__src_posx,
                      self.__src_starty:self.__src_starty+self.__src_emissionlength] = \
                self.__src_timefunction(t)

    def __gausspulse(self, t):
        """Gaussian-modulated sine pulse centred at t0 (bandwidth 1.4 at -6 dB)."""
        bw = 1.4
        bwr = -6
        ref = 10.0**(bwr / 20.0)
        a = -(np.pi * self.__nu * bw) ** 2 / (4.0 * np.log(ref))
        return -np.exp(-a *(t-self.__src_t0)**2.)*np.sin(2.*np.pi*self.__nu*(t-self.__src_t0))

    def __sinetrain(self, t):
        """Cosine train switched on at t0 and off at tmax with tanh ramps."""
        slp = 10.*self.__nu #how fast the pulse rises
        return -np.cos(2.*np.pi*self.__nu*(t-self.__src_t0))/4.\
               *(1-np.tanh((t-self.__src_tmax)*slp))*(1.+np.tanh((t-self.__src_t0)*slp))

    @property
    def pml_length(self):
        """size of PML layer"""
        return self.__pml_length

    @property
    def u(self):
        """Wave on the grid"""
        return self.__u

    @property
    def CFL(self):
        """CFL parameter (see class description)"""
        return self.__CFL

    @property
    def dt(self):
        """Length of each timestep"""
        return self.__dt

    @property
    def nx(self):
        """X-size of the computational domain without PML"""
        return self.__nx-2*self.pml_length

    @property
    def ny(self):
        """Y-size of the computational domain without PML"""
        return self.__ny-2*self.pml_length

    @property
    def nu(self):
        """Frequency"""
        return self.__nu

    @property
    def omega(self):
        """Angular frequency"""
        return 2.*self.__nu*np.pi

    @property
    def nt(self):
        """Number of timesteps"""
        return int(self.__sim_duration/self.__dt)

    @property
    def wavelength(self):
        """Wavelength"""
        return self.__wavelength

    @property
    def c1(self):
        """Background speed of sound"""
        return self.__c1

    @property
    def output(self):
        """Number of timesteps between two calls of the plotcallback function"""
        return self.__output

    @property
    def dx(self):
        """grid spacing"""
        return self.__dx

    @property
    def sizex(self):
        """grid size in x-direction, should be 2."""
        return self.__xsize

    @property
    def sizey(self):
        """grid size in y-direction"""
        return self.__ysize

    @property
    def scr1(self):
        """screen variable, integrated intensity (without the PML)"""
        # Explicit end index: [p:-p] would be empty for a PML of length 0.
        return self.__scr1[self.__pml_length:self.__ny-self.__pml_length]

    @property
    def scr2(self):
        """screen variable, integrated intensity (without the PML)"""
        return self.__scr2[self.__pml_length:self.__ny-self.__pml_length]

    @output.setter
    def output(self, output):
        # `output` is given in outputs per source period. Keep it so the step
        # count can be recomputed when dt changes, and use at least one step
        # per output so that solvestep() always advances the simulation.
        self.__outputs_per_period = output
        self.__output = max(1, int(1./self.nu/self.__dt/output))

    @CFL.setter
    def CFL(self, CFL):
        self.__CFL = CFL  #CFL number < 1/sqrt(2)
        self.__dt = CFL*self.__dx/self.__cc.max()**.5
        self.__nt = int(self.__sim_duration/self.__dt)

    @wavelength.setter
    def wavelength(self, wavelength):
        self.__wavelength = wavelength
        self.__nu = self.__c1/self.__wavelength
        self.__src_t0 = 1./self.__nu

    def __inhomogeneous_PML_stable(self):
        """Compute u at the next time level and update the PML fields phix, phiy.

        Reads un (t), unn (t-dt), phixn, phiyn, cc, rho and writes the interior
        (indices 1..n-2) of u, phix and phiy in place, using a C++ kernel
        compiled by scipy.weave.
        """
        u = self.__u
        un = self.__un
        unn = self.__unn
        dt = float(self.__dt)
        dx = self.__dx
        nx = self.__nx
        ny = self.__ny
        zeta1 = self.__zeta1
        zeta2 = self.__zeta2
        cc = self.__cc
        rho = self.__rho
        phix = self.__phix
        phiy = self.__phiy
        phixn = self.__phixn
        phiyn = self.__phiyn

        code = """
            for (int i=1; i<nx-1; ++i) {
                for (int j=1; j<ny-1; ++j) {
                    u(i,j) = 1./(1./(dt*dt)+(zeta1(i)+zeta2(j))/(2.*dt))*
                            ( un(i,j)*(2./dt/dt-zeta1(i)*zeta2(j))
                             +unn(i,j)*((zeta1(i)+zeta2(j))/2./dt-1./dt/dt)
                             +cc(i,j)/dx/dx*((un(i-1,j)+un(i+1,j)+un(i,j-1)+un(i,j+1)-4.*un(i,j))
                                             -0.25/rho(i,j)*
                                             ((rho(i+1,j)-rho(i-1,j))*(un(i+1,j)-un(i-1,j))
                                             +(rho(i,j+1)-rho(i,j-1))*(un(i,j+1)-un(i,j-1))))
                             +0.5/dx*( phixn(i,j-1)+phixn(i,j)-phixn(i-1,j-1)-phixn(i-1,j)
                                      +phiyn(i-1,j)+phiyn(i,j)-phiyn(i-1,j-1)-phiyn(i,j-1)));
                 }
            }
            for (int i=1; i<nx-1; ++i) {
                for (int j=1; j<ny-1; ++j) {
                    phix(i,j) = 1./(1./dt+zeta1(i)/2.)*(phixn(i,j)*(1./dt-zeta1(i)/2.)
                                 +(zeta2(j)-zeta1(i))*cc(i,j)*.25/dx*
                                   ( u(i+1,j+1)+u(i+1,j)-u(i,j+1)-u(i,j)
                                    +un(i+1,j+1)+un(i+1,j)-un(i,j+1)-un(i,j)));
                    phiy(i,j) = 1./(1./dt+zeta2(j)/2.)*(phiyn(i,j)*(1./dt-zeta2(j)/2.)
                                 +(zeta1(i)-zeta2(j))*cc(i,j)*.25/dx*
                                   ( u(i+1,j+1)+u(i,j+1)-u(i+1,j)-u(i,j)
                                    +un(i+1,j+1)+un(i,j+1)-un(i+1,j)-un(i,j)));
                }
            }
        """
        weave.inline(code, ['u', 'un', 'unn', 'dt', 'dx', 'nx', 'ny', 'zeta1', 'zeta2', 'cc', 'rho',
                            'phixn', 'phiyn', 'phix', 'phiy'],
                     type_converters=converters.blitz)

    def __timestep(self):
        """Advance by one timestep: source, field update, boundaries, time levels, screens."""
        #Call Source (imposed on the current field un at t = n*dt)
        if self.__src_function:
            self.__src_function(self.__n*self.__dt)

        #One time step
        self.__timestepper()

        #Impose reflective (zero-gradient) b.c. at the outer pml boundary
        self.__u[0, :] = self.__u[1, :]
        self.__u[-1, :] = self.__u[-2, :]
        self.__u[:, 0] = self.__u[:, 1]
        self.__u[:, -1] = self.__u[:, -2]

        #Consider an absorber using first-order plane wave radiation b.c.,
        #and zero amplitude directly behind it
        if np.size(self.__absorber_y) > 0:
            ax = self.__absorber_x
            ay = self.__absorber_y
            self.__u[ax, ay] = -self.__dx/self.__dt/self.__c1*\
                (self.__u[ax-1, ay]-self.__un[ax-1, ay])+self.__u[ax-1, ay]
            self.__u[ax+1, ay] = 0.

        #save values for the time derivative; un is replaced by a fresh copy
        #below, so the old array can become unn without copying
        self.__unn = self.__un        #n-1 time step
        self.__un = self.__u.copy()   #n time step
        self.__phixn = self.__phix.copy()
        self.__phiyn = self.__phiy.copy()

        #Update vertical screen variables (Intensity)
        if self.__scr1pos >= 0:
            self.__scr1 += self.__u[self.__scr1pos, :]**2.
        if self.__scr2pos >= 0:
            self.__scr2 += self.__u[self.__scr2pos, :]**2.

    def solvestep(self):
        """
        Advance the wave PDE by `output` timesteps and call the plot callback
        with the field (without PML) and the time it corresponds to.
        Returns True (without computing anything) once sim_duration has been
        reached, False otherwise.
        """
        if self.__n >= self.__nt:
            return True

        for step in range(self.__n, self.__n+self.__output):
            self.__n = step
            self.__timestep()
        # The last step is done: the next call must start with a new step
        # index (re-using it would repeat a source time, and with one step per
        # output the simulation would never advance).
        self.__n += 1

        p = self.pml_length
        # Explicit end indices: [p:-p] would be empty for a PML of length 0.
        self.__plotcallback(self.__u[p:self.__nx-p, p:self.__ny-p],
                            self.__n*self.__dt)
        time.sleep(self.__slowdown)

        return False
    

# In [39]:
#%matplotlib inline
import cv2
import base64
import io
import numpy as np
import time
from IPython import display #for continous display
from ipywidgets import widgets
import matplotlib.pyplot as plt #plotting
import matplotlib.patches as mpatches #patches to plot the slits etc.
from matplotlib.collections import PatchCollection
#import wavesolverInPML 

# Module-level state shared by the callbacks below: `a` holds the current
# WaveSolverInPML (created by init_solver) and `patches` the rectangles drawn
# for the screens (rebuilt by calc_patch). A `global` statement at module
# level is a no-op, so the names are initialised explicitly instead.
patches = []
a = None

def dovideo(obj):
    """Show the horizontally mirrored grayscale webcam stream until 'q' is pressed.

    obj: the clicked button passed by ``Button.on_click`` (not used).
    """
    cap = cv2.VideoCapture(0)
    try:
        while True:
            # Capture frame-by-frame
            ret, frame = cap.read()
            if not ret:
                # No frame (camera missing/busy or stream ended): stop instead
                # of passing None to cvtColor.
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            cv2.imshow('frame', np.fliplr(gray))
            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        # Always release the camera and close the window, even on errors.
        cap.release()
        cv2.destroyAllWindows()

# x-position (physical units) of the slit/half-slit screen; used both for
# drawing it (calc_patch) and for configuring the solver (plotall).
mypos=0.3
def plotwave(u,time):
    """Draw the wave field and the normalised screen intensities into the widgets.

    Serves as the solver's plot callback.
    u:    wave field without the PML, indexed [x, y]
    time: simulation time, written to w_text1
    The figure is rendered to an HTML <img> and stored in w_fig1.
    """
    # Rotate so that x runs horizontally and y vertically in the image.
    u = np.rot90(u)
    width = a.nx*a.dx
    height = a.ny*a.dx

    f = plt.figure(1, figsize=(6*1.5, 3*1.5), dpi=100)
    ax1 = plt.subplot2grid((1, 4), (0, 0), colspan=3)
    ax2 = plt.subplot2grid((1, 4), (0, 3), colspan=1)

    # Left panel: the wave field, the screen positions and the absorbers.
    ax1.imshow(u, origin='upper', extent=[0., width, 0., height],
               vmax=1, vmin=-1)
    ax1.set_xlim([0., width])
    ax1.set_ylim([0., height])
    ax1.axes.get_xaxis().set_visible(False)
    ax1.axes.get_yaxis().set_visible(False)
    ax1.plot([w_scr1.value, w_scr1.value], [0., height], '-.b')
    ax1.plot([w_scr2.value, w_scr2.value], [0., height], '-.g')
    ax1.add_collection(PatchCollection(patches, alpha=0.3))

    # Right panel: integrated intensity on each screen, normalised to its
    # maximum; a screen is only drawn once its maximum exceeds 1e-2.
    y = np.linspace(0, a.sizey, u.shape[0])
    for scr, style in ((a.scr1, 'b'), (a.scr2, 'g')):
        peak = scr.max()
        if peak > 1e-2:
            ax2.plot(scr/peak, y, style)
    ax2.set_xlim([0., 1.1])
    ax2.set_ylim([0., a.sizey])
    ax2.axes.get_xaxis().set_visible(False)
    ax2.axes.get_yaxis().set_visible(False)
    f.tight_layout()

    w_fig1.value = plot_to_html()
    plt.close()

    w_text1.value = "Time {0:.5f}".format(time)
    
def plot_to_html():
    """Return the current matplotlib figure as an inline base64-encoded PNG <img> tag."""
    with io.BytesIO() as png:
        plt.savefig(png, format='png')
        encoded = base64.b64encode(png.getvalue()).decode('ascii')
    return """<img src='data:image/png;base64,{}'/>""".format(encoded)
    
    
def linesource(a, n_p=None, dny=None):
    """Configure ``a`` with point sources of amplitude 1 on a line x = 0.1.

    a:   solver providing ``cfg_pointsources``
    n_p: number of sources (default: w_nosrc.value)
    dny: distance between neighbouring sources (default: w_distsrc.value);
         ignored for a single source
    The sources are centred at y = 1. The original offset used n_p instead of
    n_p-1 gaps and shifted the line down by dny/2. Returns ``a``.
    """
    if n_p is None:
        n_p = w_nosrc.value
    if dny is None:
        dny = w_distsrc.value
    p = np.zeros(n_p, dtype={'names':['ampl', 'xpos', 'ypos'],
                             'formats':['f8','f8','f8']})
    if n_p == 1:
        dny = 0.
    # Sources along a line: n_p-1 gaps of width dny, centred in [0, 2].
    p['ampl'] = 1.
    p['xpos'] = 0.1
    p['ypos'] = (2.-dny*(n_p-1))/2.+dny*np.arange(n_p)
    a.cfg_pointsources(p)
    return a

def halfcircle(a):
    """Configure ``a`` with five point sources (amplitude 5) on a half circle.

    The circle has radius 0.5 and centre (1, 1); the sources sit at angles
    0, pi/4, ..., pi. Returns ``a``.
    """
    count = 5
    sources = np.zeros(count, dtype={'names':['ampl', 'xpos', 'ypos'],
                                     'formats':['f8','f8','f8']})
    for k in range(count):
        angle = np.pi/(count-1)*k
        sources[k] = (5., .5*np.cos(angle)+1., .5*np.sin(angle)+1.)
    a.cfg_pointsources(sources)
    return a

def fresnelplate(a):
    """Configure ``a`` with point sources forming a Fresnel zone plate at x = 0.3.

    Zone radii are r_i = sqrt(i*l*(f + i*l/4)) with wavelength l = c1/nu and
    focus f = 1. Every other zone (i = 0, 2, ..., nn-1) is filled with sources
    of amplitude 1, mirrored about y = 1. Returns ``a``.
    """
    nn = 33  # order of plate
    f = 1.   # focus position
    l = a.c1/a.nu  # wavelength
    yy = np.linspace(0., 2., a.ny)
    ampl = np.zeros(a.ny)
    for i in range(0, nn, 2):
        r_in = (i*l*(f+i*l/4.))**.5
        r_out = ((i+1)*l*(f+(i+1)*l/4.))**.5
        k = np.where(((yy-1.) < r_out) & ((yy-1.) > r_in))[0]
        ampl[k] = 1.
        # Mirror about the centre: index j corresponds to ny-1-j (a.ny-j was
        # off by one cell).
        ampl[a.ny-1-k] = 1.

    idx = np.flatnonzero(ampl == 1.)
    p = np.zeros(idx.size, dtype={'names':['ampl', 'xpos', 'ypos'],
                                  'formats':['f8','f8','f8']})
    p['ampl'] = 1.
    p['xpos'] = .3
    p['ypos'] = yy[idx]
    a.cfg_pointsources(p)
    return a

def gridofdisks(a):
    """Configure ``a`` with an ndx x ndy grid of disk scatterers of radius 0.3.

    With the current ndx = ndy = 1 this is a single disk centred at (1, 1).
    Returns ``a``.
    """
    ndx = 1
    ndy = 1
    disks = np.zeros(ndx*ndy, dtype={'names':['radius','xpos','ypos'],
                                     'formats':['f8','f8','f8']})
    cells = [(ix, iy) for ix in range(ndx) for iy in range(ndy)]
    for slot, (ix, iy) in enumerate(cells):
        disks[slot] = (.3,
                       2./float(ndx)*float(ix)+1.-float(ndx-1.)/float(ndx),
                       2./float(ndy)*float(iy)+1.-float(ndy-1.)/float(ndy))
    a.cfg_simplediffraction(disks=disks)
    return a

def simplerectangle(a):
    """Configure ``a`` with one rectangular scatterer and return ``a``.

    The upper-left corner is (0.5, 1.5), the width 0.4 and the height 1.
    """
    rect = np.zeros(1, dtype={'names':['x_up', 'y_up', 'width', 'height'],
                              'formats':['f8','f8','f8','f8']})
    rect[0] = (0.5, 1.5, .4, 1)
    a.cfg_simplediffraction(rects=rect)
    return a

def init_tabs(newvalue):
    """Redraw the preview whenever a different tab is selected.

    newvalue: change notification from ``w_tabs.observe`` (not used).
    """
    init_graph(1)
    
def init_graph(newvalue):
    """Redraw the preview: the selected tab's screens over an all-zero field.

    newvalue: change notification passed by the widget observers (not used).
    """
    selected = w_tabs.selected_index
    calc_patch(selected)
    blank = np.zeros((a.nx, a.ny))
    plotwave(blank, 0.)

def init_solver():
    """Create a fresh solver from the current widget settings in the global ``a``.

    The inner grid is w_nx.value x 256 cells with a 50-cell PML on each side.
    Screen positions are converted from physical x (domain length 2) to grid
    indices.
    """
    global a
    # With c1 = 1 a source period lasts `wavelength`, so this many outputs per
    # period give roughly no_frames frames over the whole simulation.
    no_frames = 30.
    mynx = w_nx.value
    myny = 256
    a = WaveSolverInPML(nx=mynx, ny=myny, pml_length=50, CFL=.4,
                        wavelength=w_lambda.value,
                        output=w_lambda.value/w_duration.value*no_frames,
                        slowdown=0, plotcallback=plotwave,
                        sim_duration=w_duration.value,
                        rho1=.7, c1=1., rho2=1, c2=.7,
                        # Pass an int: the solver compares src_cycles with 1 to
                        # select the single-pulse source.
                        src_cycles=int(w_nopulses.value),
                        scr1pos=int(np.floor(w_scr1.value/2.*(mynx-1))),
                        scr2pos=int(np.floor(w_scr2.value/2.*(mynx-1))))

def calc_patch(case):
    """Rebuild the global ``patches`` list of rectangles that mark the screens.

    case: selected tab index. 0 is the half slit, 1 the single slit and 2 the
          double slit; any other value leaves ``patches`` empty.
    Every rectangle is 0.05 wide and starts at x = mypos.
    """
    global patches

    def screen(y0, h):
        # Rectangle(xy, width, height) without an edge colour.
        return mpatches.Rectangle([mypos, y0], 0.05, h, ec="none")

    patches = []
    if case == 0:  # Half slit: one screen from the bottom up to w_ypos
        patches.extend([screen(0., w_ypos.value)])
    elif case == 1:  # Single slit: screens below and above the opening
        half = a.sizey/2.
        slit = w_width.value
        patches.extend([screen(0., half-slit/2),
                        screen(half+slit/2., half-slit/2.)])
    elif case == 2:  # Double slit: bottom, middle and top screens
        half = a.sizey/2.
        dist = w_dist.value
        slit = w_width.value
        patches.extend([screen(0., half-dist/2.-slit/2.),
                        screen(half-dist/2.+slit/2., dist-slit),
                        screen(half+dist/2.+slit/2., half-dist/2.-slit/2.)])

            
def plotall(newvalue):
    """Run a complete simulation for the selected experiment, redrawing as it goes.

    newvalue: the clicked button passed by ``Button.on_click`` (not used).
    """
    init_solver()
    tab = w_tabs.selected_index
    calc_patch(tab)

    if tab == 0:    # Half Slit
        a.cfg_halfslit(ypos=w_ypos.value, pos=mypos)
    elif tab == 1:  # Single Slit
        a.cfg_singleslit(ow=w_width.value, pos=mypos)
    elif tab == 2:  # Double Slit
        a.cfg_doubleslit(ow=w_width.value, dw=w_dist.value, pos=mypos)

    # solvestep() returns True once the simulation time is used up.
    while not a.solvestep():
        time.sleep(0)
    

       
w_top=widgets.HTML(value="""
            <h2> Diffraction Examples</h2>
            <p>This program calculates diffraction of a scalar wave equation in 2 dimensions.
            </p><hr>
            """)
w_fig1 = widgets.HTML()   # rendered figure, filled by plotwave
w_fig2 = widgets.HTML()
w_text1 = widgets.HTML()  # simulation time, filled by plotwave

w_lambda = widgets.FloatSlider(min=.1, max=.3, step=.01, value=.12, description='Wavelength')

w_desc1=widgets.HTML(value="""
            <p>Please choose the width of the slit.</p>
            """)
w_desc2=widgets.HTML(value="""
            <p>Please choose the width of the slit and the distance between the slits.</p>
            """)
# Geometry and screen sliders redraw the preview immediately.
w_width = widgets.FloatSlider(min=.1, max=1., step=.05, value=.3, description='Width')
w_width.observe(init_graph, "value")
w_dist = widgets.FloatSlider(min=.1, max=1., step=.05, value=.2, description='Distance')
w_dist.observe(init_graph, "value")
w_ypos = widgets.FloatSlider(min=.1, max=.9, step=.05, value=.4, description='Position')
w_ypos.observe(init_graph, "value")

# The number of source cycles is an integer count.
w_nopulses = widgets.IntSlider(min=1, max=100, step=1, value=100, description='Pulses #')
w_duration = widgets.FloatSlider(min=1., max=8., step=.1, value=3.0, description='Sim. duration')
w_scr1 = widgets.FloatSlider(min=0.35, max=1.95, step=.05, value=.35, description='Pos. Screen 1')
w_scr1.observe(init_graph, "value")
w_scr2 = widgets.FloatSlider(min=0.35, max=1.95, step=.05, value=0.35, description='Pos. Screen 2')
w_scr2.observe(init_graph, "value")
w_nx = widgets.IntSlider(min=128, max=640, step=128, value=256, description='Domain Length')

# Only read by linesource(); not part of the displayed UI.
w_nosrc = widgets.IntSlider(min=1, max=10, step=1, value=2, description='Sources #')
w_distsrc = widgets.FloatSlider(min=.1, max=1., step=0.1, value=.2, description='Dist. Sources')

# CSS sizes go through `layout`; bare `margin=`/`width=` keyword arguments are
# not widget traits in current ipywidgets.
page1 = widgets.Box([w_ypos], layout=widgets.Layout(margin='10px'))
page2 = widgets.Box([w_desc1, w_width], layout=widgets.Layout(margin='10px'))
page3 = widgets.Box([w_desc2, w_width, w_dist], layout=widgets.Layout(margin='10px'))
page4 = widgets.Box([w_nopulses, w_duration, w_scr1, w_scr2, w_nx],
                    layout=widgets.Layout(margin='10px'))

# Tab order must match the case numbers used by calc_patch and plotall.
w_tabs = widgets.Tab(children=[page1, page2, page3, page4])
w_tabs.set_title(0, 'Halfslit')
w_tabs.set_title(1, 'Single Slit')
w_tabs.set_title(2, 'Double Slit')
w_tabs.set_title(3, 'Settings')
w_tabs.observe(init_tabs, "selected_index")

w_b1 = widgets.Button(description="Start Wave", layout=widgets.Layout(width='100px'))
w_b1.on_click(plotall)
w_b2 = widgets.Button(description="Start Video", layout=widgets.Layout(width='100px'))
w_b2.on_click(dovideo)

w_brow = widgets.HBox([w_b1, w_b2])
w_hbox = widgets.HBox([w_fig1, w_fig2])
w_vbox = widgets.VBox([w_top, w_text1, w_hbox, w_lambda, w_tabs, w_brow])

display.display(w_vbox)
# Create the initial solver and draw the empty preview.
init_solver()
init_graph(1)