%load_ext autoreload
%autoreload 2


import sys, os
from SimPEG import *
from simpegEM1D import (
    EM1D, EM1DSurveyTD, Utils1D, get_vertical_discretization_time, 
    set_mesh_1d, skytem_HM_2015
)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots
from matplotlib.patches import Rectangle
import ipysheet
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
from IPython.display import clear_output,display
from ipywidgets import (
    interactive,
    IntSlider,
    widget,
    FloatText,
    FloatSlider,
    FloatLogSlider,
    Checkbox,
    ToggleButtons,
    Button,
    Output
)
import ipywidgets as widgets
  
# --- Model / plotting constants ---------------------------------------------
# Resistivity bounds (ohm-m) for the layer sliders and the plot x-limits.
rhomin=0.1
rhomax=1000.
# Starting value (ohm-m) for new resistivity sliders and their linked cells.
rho_default=50.
# NOTE(review): zmin, zmax and rho_half are not referenced in the visible part
# of this file — confirm whether they are still needed.
zmin=0.
zmax=350
rho_half=30.
# Height of the transmitter/receiver above the topography (topo z=100 below);
# presumably metres — also used for the dashed line in plot_mesh.
tx_alt= 30.

# 31 log-spaced time channels from 1e-5 to 1e-2.
time = np.logspace(-5, -2, 31)
# Layer thicknesses for the 1D mesh, derived from the time channels.
# NOTE(review): `facter_tmax` is spelled as the simpegEM1D keyword is assumed to
# be spelled — confirm against the installed version before "fixing" it.
hz = get_vertical_discretization_time(time, facter_tmax=0.3, factor_tmin=10., n_layer=19)
mesh1D = set_mesh_1d(hz)
# Negated mesh node positions (all but the last), passed to the survey as layer depths.
depth = -mesh1D.gridN[:-1]
# Negated cell-centre positions; compared against layer tops in define_sigma_layers.
LocSigZ = -mesh1D.gridCC
# Time-domain survey: source and receiver at the same location, dBz/dt data,
# step-off waveform, circular-loop source.
TDsurvey = EM1DSurveyTD(
    rx_location = np.array([0., 0., 100.+tx_alt]),
    src_location = np.array([0., 0., 100.+tx_alt]),
    topo = np.r_[0., 0., 100.],
    depth = depth,
    rx_type = 'dBzdt',
    wave_type = 'stepoff',
    src_type = 'CircularLoop',
    a = 13.,
    I = 1.,
    time = time,
    base_frequency = 25.,
    use_lowpass_filter=False,
    high_cut_frequency=210*1e3        
)
# NOTE(review): sig_half, sig_blk, chi_half and expmap are not used in the
# visible part of this file (sig_half is shadowed by function parameters).
sig_half = 1./20.
sig_blk = sig_half * 20.
chi_half = 0.
expmap = Maps.ExpMap(mesh1D)


# decorator used to block function printing to the console
def blockPrinting(func):
    """Return a wrapper around *func* that discards everything it prints.

    stdout is redirected to ``os.devnull`` only for the duration of the call.
    The previous ``sys.stdout`` is always restored, even if *func* raises. This
    matters in a notebook, where ``sys.stdout`` is not ``sys.__stdout__``.
    The return value of *func* is passed through unchanged.
    """
    import contextlib
    import functools

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        # `with` closes the devnull handle after every call, and
        # redirect_stdout puts back whatever stdout was before the call.
        with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
            return func(*args, **kwargs)

    return func_wrapper

def define_sigma_layers(sig_half,sig_lay,layers,cell_z=None):
    """Build a per-cell conductivity model from layer tops.

    Parameters
    ----------
    sig_half : float
        Conductivity given to every cell not covered by a layer.
    sig_lay : sequence of float
        Conductivity of each layer, same length as ``layers``.
    layers : sequence of float
        Layer-top elevations (negative below surface), ordered from shallowest
        to deepest. The last layer extends to the bottom of the mesh.
    cell_z : array_like, optional
        Cell-centre elevations to classify. Defaults to the module-level
        ``LocSigZ`` with ``TDsurvey.n_layer`` cells.

    Returns
    -------
    numpy.ndarray
        One conductivity per cell.
    """
    if cell_z is None:
        cell_z = LocSigZ
        n_cells = TDsurvey.n_layer
    else:
        cell_z = np.asarray(cell_z, dtype=float)
        n_cells = cell_z.size
    sig = np.ones(n_cells)*sig_half
    n_lay = len(layers)
    for i, top in enumerate(layers):
        # Cells whose centre lies below this layer's top...
        ind = top > cell_z
        if i < n_lay - 1:
            # ...but not below the next layer's top.
            ind = ind & (layers[i+1] <= cell_z)
        sig[ind] = sig_lay[i]
    return sig
        
    
def plot_mesh(sig,mesh1D,ax=None):
    """Plot a layered resistivity model with air, ground and transmitter labels.

    Parameters
    ----------
    sig : array_like
        Per-cell conductivity; the plot shows resistivity ``1./sig``.
    mesh1D : mesh
        1D mesh handed to ``Utils1D.plotLayer`` together with ``1./sig``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 5 x 8 figure is created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on (y axis inverted).
    """
    import matplotlib.patheffects as pe
    if ax is None:
        _, ax = plt.subplots(1,1, figsize=(5, 8))
    Utils1D.plotLayer(1./sig, mesh1D, showlayers=True,xlim=(rhomin*.9,rhomax*1.1),label='True resistivity',ax=ax)
    ax.set_ylim(-50,350)
    xlims = ax.get_xlim()
    # Grey band from y=-50 to y=0 (above the ground surface) standing for air.
    rect = Rectangle((xlims[0],-50),xlims[1]-xlims[0],50,fc=[.8,.8,.8])
    # Annotation text is anchored at x=2000 in data coordinates, with an arrow
    # pointing back to the right-hand edge of the axes.
    ax.annotate('Ground surface', xy=(xlims[1],0),  xycoords='data',
                fontsize=14,xytext=(2000.,40), textcoords='data',
                arrowprops=dict(arrowstyle="->"),
                horizontalalignment='left', verticalalignment='top',
                )
    ax.annotate('Helicopter transmitter height', xy=(xlims[1],-tx_alt),  xycoords='data',
                fontsize=14,xytext=(2000.,-tx_alt-40), textcoords='data',
                arrowprops=dict(arrowstyle="->"),
                horizontalalignment='left', verticalalignment='top',
                )
    # Draw on `ax` itself: plt.hlines targets the *current* axes, which is not
    # necessarily the one the caller passed in.
    ax.hlines(-tx_alt, xlims[0], xlims[1], colors='k', linestyles='dashed')
    ax.text(0.5, 0.94, "Air",transform=ax.transAxes,
            size=22,ha='center', va='center',
            path_effects=[pe.withStroke(linewidth=8, foreground=[.8,.8,.8])])
    ax.add_patch(rect)
    # y grows downwards so depth reads top-to-bottom.
    ax.invert_yaxis()
    ax.set_xlabel(r'Resistivity $\rho$ (ohm-m)')
    return ax



def create_table():
    """Display an editable layer table with a live resistivity-model plot.

    The sheet has one resistivity per row: row 0 is the background and rows
    1.. are layers.
      - Column 0: a log-scaled slider.
      - Column 1: a numeric cell linked to the slider.
      - Column 2: an editable thickness (m).
      - Column 3: a read-only top-of-layer depth equal to the sum of the
        thicknesses in the layer rows above.
    Moving a slider or changing a depth redraws the model in the output area
    below the sheet. "Add Row" appends another layer.

    Returns
    -------
    rho_sliders : list of widgets.FloatLogSlider
        Background slider first, then one slider per layer.
    cell_dep : list of ipysheet cells
        Computed layer-top depth cells, one per layer row.
    """
    def new_rho_logfloatslider(lay_name='1',value=rho_default):
        """Return a log-scaled resistivity slider labelled rho_<lay_name>."""
        return widgets.FloatLogSlider(
                base=10,
                min=np.log10(rhomin),
                max=np.log10(rhomax),
                value=value,
                step=0.1,
                continuous_update=False,
                description=r"$\rho_{" + str(lay_name) + "}$",
                readout=False,
                layout=widgets.Layout(width='250px')
            )

    def insert_rho_sliders(rho_sliders,col=0,row_start=0):
        """Put sliders in column `col` and a linked numeric cell in `col+1`."""
        ipysheet.column(col, rho_sliders,row_start=row_start)
        cells = []
        for i, slider in enumerate(rho_sliders):
            # Start the cell at its slider's value so the two linked widgets
            # agree from the outset. The background slider starts at
            # 10*rho_default, so a fixed rho_default start would disagree.
            cells.append(ipysheet.cell(row_start+i,col+1,slider.value,numeric_format='0.0'))
        for cell, slider in zip(cells, rho_sliders):
            # Browser-side link: editing either widget updates the other.
            widgets.jslink((cell, "value"),(slider,"value"))
        return cells

    def new_thk_cell(row,col=2,thk=30):
        """Return an editable layer-thickness cell."""
        return ipysheet.cell(row,col,thk,numeric_format='0')

    def new_dep_cell(row,col=3,cell_thk=None):
        """Return a read-only cell holding the sum of the `cell_thk` cells."""
        c=ipysheet.cell(row,col,value=0,
                    color='black',
                    background_color='grey',
                    numeric_format='0',
                    read_only=True)
        @ipysheet.calculation(inputs=cell_thk, output=c)
        def add(*a):
            return sum(a)
        return c

    sheet = ipysheet.sheet(rows=4,columns=4,column_width=[100,50,50,50],
                           column_headers=['Resistivity slider','Resistivity value',
                                           r'Thickness (m)','Top of layer depth (m)'],
                          )
    # The background row has no thickness or depth.
    ipysheet.row(0, ['--','--'],column_start=2,column_end=3,
                 read_only=True,color='black', background_color='grey')

    # Thickness/depth cells for layer rows 1-3; each depth sums the
    # thicknesses of the layer rows above it.
    cell_thk = []
    cell_dep = []
    for i in range(1,4):
        cell_thk.append(new_thk_cell(row=i,col=2,thk=30))
        cell_dep.append(new_dep_cell(row=i,col=3,cell_thk=cell_thk[:-1]))

    # Resistivity sliders: background (10x rho_default) then layers 1-3.
    cells_rho = []
    rho_sliders = [new_rho_logfloatslider('background',rho_default*10)]
    for i in range(3):
        rho_sliders.append(new_rho_logfloatslider(str(i+1)))
    cells_rho.extend(insert_rho_sliders(rho_sliders))

    button = widgets.Button(description='Add Row')
    out = widgets.Output()

    def add_row(_):
        """Append a layer row: slider, value cell, thickness and depth."""
        with out:
            sheet.rows += 1
            row = sheet.rows-1
            rho_sliders.append(new_rho_logfloatslider(row))
            cells_rho.extend(insert_rho_sliders((rho_sliders[-1],),row_start=row))
            cell_thk.append(new_thk_cell(row=row,col=2,thk=30))
            cell_dep.append(new_dep_cell(row=row,col=3,cell_thk=cell_thk[:-1]))
            # Existing widgets are already observed; only hook up the new ones.
            rho_sliders[-1].observe(update_plot,'value')
            cell_dep[-1].observe(update_plot,'value')
    button.on_click(add_row)

    def define_sigma_layers(sig_lay,layers):
        '''
        layers  (n,): tops of layers, negative values indicate below surface
        sig_lay (n+1,): conductivity of each layer followed by the background;
                        the final value fills every cell not covered by a layer.
        '''
        sig  = np.ones(TDsurvey.n_layer)*sig_lay[-1]
        for i in range(len(layers)):
            if i==len(layers)-1:
                ind = (layers[i]>LocSigZ)
            else:
                ind = np.where((layers[i]>LocSigZ) & (layers[i+1]<=LocSigZ))
            sig[ind] = sig_lay[i]
        return sig

    def plot_res_model(rhos,lays,ax=None):
        """Plot [background, layer 1, ...] resistivities at the given layer-top depths."""
        layers = -np.asarray(lays)
        # Move the background (first slider) to the end, where define_sigma_layers expects it.
        sigs=1./np.r_[np.asarray(rhos[1:]),rhos[0]]
        sig = define_sigma_layers(sigs,layers)
        ax = plot_mesh(sig,mesh1D,ax=ax)
        plt.show()
        return sig, ax

    output = widgets.Output(layout={'border': '10px solid black',
                                   'height':'600px',
                                   'width': '600px'})

    @output.capture()
    def update_plot(change):
        """Redraw the model from the current slider and depth values."""
        clear_output(True)
        plot_res_model([r.value for r in rho_sliders],[d.value for d in cell_dep])
        show_inline_matplotlib_plots()

    for w in rho_sliders + cell_dep:
        w.observe(update_plot,'value')

    vbox=widgets.VBox([sheet,button,output])
    display(vbox)
    # Draw the starting model so the output area is not empty before the first edit.
    update_plot(None)
    return rho_sliders,cell_dep


rho_sliders,cell_dep=create_table()

# Snapshot of the current table values.
# NOTE(review): create_table puts the background slider first, so `[:-1]`
# drops the deepest layer rather than the background — confirm intent.
rhos  =[r.value for r in rho_sliders][:-1]
lays = [d.value for d in cell_dep]
print(rhos,lays)


def define_sigma_layers(sig_lay,layers,cell_z=None):
    '''
    Build a per-cell conductivity model from layer tops (no separate background).

    layers  (n,): tops of layers, negative values indicate below surface,
                  ordered from shallowest to deepest; the last layer extends
                  to the bottom of the mesh.
    sig_lay (n,): conductivity of each layer, same size as layers.
    cell_z  (m,), optional: cell-centre elevations to classify; defaults to
                  LocSigZ with TDsurvey.n_layer cells.

    Cells above the first layer top keep a conductivity of 1.0.
    Returns an array with one conductivity per cell.
    '''
    if cell_z is None:
        cell_z = LocSigZ
        n_cells = TDsurvey.n_layer
    else:
        cell_z = np.asarray(cell_z, dtype=float)
        n_cells = cell_z.size
    sig  = np.ones(n_cells)*1.0
    n_lay = len(layers)
    for i, top in enumerate(layers):
        # Cells whose centre lies below this layer's top...
        ind = top > cell_z
        if i < n_lay - 1:
            # ...but not below the next layer's top.
            ind = ind & (layers[i+1] <= cell_z)
        sig[ind] = sig_lay[i]
    return sig

def plot_res_model(rhos,lays,ax=None):
    """Plot the layered model from per-layer resistivities and layer-top depths.

    rhos : resistivity (ohm-m) of each layer, same length as lays.
    lays : layer-top depths (positive down); negated to elevations before use.
    ax   : optional matplotlib Axes forwarded to plot_mesh.

    Returns (sig, ax): the per-cell conductivity and the axes drawn on.
    """
    layers = -np.asarray(lays)
    sig_lay = 1./np.asarray(rhos)
    sig = define_sigma_layers(sig_lay,layers)
    ax = plot_mesh(sig,mesh1D,ax=ax)
    plt.show()
    return sig, ax


# Conductivities and layer-top elevations from the table values read above
# (rhos and lays); neither `sigs` nor `layers` was defined at module level.
sigs = 1./np.asarray(rhos)
layers = -np.asarray(lays)
sig=define_sigma_layers(sigs,layers)
ax = plot_mesh(sig,mesh1D,ax=None)
plt.show()

### Remove the background value (every row of the table is a layer)

def create_table():
    """Display an editable layer table (no background row) with a live plot.

    Every sheet row is a layer.
      - Column 0: a log-scaled resistivity slider.
      - Column 1: a numeric cell linked to the slider.
      - Column 2: an editable thickness (m).
      - Column 3: a read-only top-of-layer depth equal to the sum of the
        thicknesses in the rows above.
    The model is drawn once on creation and redrawn whenever a slider or depth
    changes. "Add Row" appends a layer.

    Returns
    -------
    rho_sliders : list of widgets.FloatLogSlider
        One slider per layer, top to bottom.
    cell_dep : list of ipysheet cells
        Computed layer-top depth cells, one per layer.
    """
    def new_rho_logfloatslider(lay_name='1',value=rho_default):
        """Return a log-scaled resistivity slider labelled rho_<lay_name>."""
        return widgets.FloatLogSlider(
                base=10,
                min=np.log10(rhomin),
                max=np.log10(rhomax),
                value=value,
                step=0.1,
                continuous_update=False,
                description=r"$\rho_{" + str(lay_name) + "}$",
                readout=False,
                layout=widgets.Layout(width='250px')
            )

    def insert_rho_sliders(rho_sliders,col=0,row_start=0):
        """Put sliders in column `col` and a linked numeric cell in `col+1`."""
        ipysheet.column(col, rho_sliders,row_start=row_start)
        cells = []
        for i, slider in enumerate(rho_sliders):
            # Start the cell at its slider's value so the linked pair agrees.
            cells.append(ipysheet.cell(row_start+i,col+1,slider.value,numeric_format='0.0'))
        for cell, slider in zip(cells, rho_sliders):
            # Browser-side link: editing either widget updates the other.
            widgets.jslink((cell, "value"),(slider,"value"))
        return cells

    def new_thk_cell(row,col=2,thk=30):
        """Return an editable layer-thickness cell."""
        return ipysheet.cell(row,col,thk,numeric_format='0')

    def new_dep_cell(row,col=3,cell_thk=None):
        """Return a read-only cell holding the sum of the `cell_thk` cells."""
        c=ipysheet.cell(row,col,value=0,
                    color='black',
                    background_color='grey',
                    numeric_format='0',
                    read_only=True)
        @ipysheet.calculation(inputs=cell_thk, output=c)
        def add(*a):
            return sum(a)
        return c

    sheet = ipysheet.sheet(rows=4,columns=4,column_width=[100,50,50,50],
                           column_headers=['Resistivity slider','Resistivity value',
                                           r'Thickness (m)','Top of layer depth (m)'],
                          )

    # Thickness/depth cells for rows 0-3; each depth sums the thicknesses
    # of the rows above it (so the first layer starts at depth 0).
    cell_thk = []
    cell_dep = []
    for i in range(0,4):
        cell_thk.append(new_thk_cell(row=i,col=2,thk=30))
        cell_dep.append(new_dep_cell(row=i,col=3,cell_thk=cell_thk[:-1]))

    # One resistivity slider per layer, labelled rho_1 .. rho_4.
    cells_rho = []
    rho_sliders = [new_rho_logfloatslider(str(i+1)) for i in range(4)]
    cells_rho.extend(insert_rho_sliders(rho_sliders))

    button = widgets.Button(description='Add Row')
    out = widgets.Output()

    def add_row(_):
        """Append a layer row: slider, value cell, thickness and depth."""
        with out:
            sheet.rows += 1
            row = sheet.rows-1
            # Row r is labelled rho_{r+1}; continue that numbering rather than
            # repeating the previous row's label.
            rho_sliders.append(new_rho_logfloatslider(row+1))
            cells_rho.extend(insert_rho_sliders((rho_sliders[-1],),row_start=row))
            cell_thk.append(new_thk_cell(row=row,col=2,thk=30))
            cell_dep.append(new_dep_cell(row=row,col=3,cell_thk=cell_thk[:-1]))
            # Existing widgets are already observed; only hook up the new ones.
            rho_sliders[-1].observe(update_plot,'value')
            cell_dep[-1].observe(update_plot,'value')
    button.on_click(add_row)

    output = widgets.Output(layout={'border': '1px solid black',
                                   'height':'600px',
                                   'width': '600px'})
    # Initial drawing of the model.
    with output:
        plot_res_model([r.value for r in rho_sliders],[d.value for d in cell_dep])
        show_inline_matplotlib_plots()

    @output.capture()
    def update_plot(change):
        """Redraw the model from the current slider and depth values."""
        clear_output(True)
        plot_res_model([r.value for r in rho_sliders],[d.value for d in cell_dep])
        show_inline_matplotlib_plots()

    for w in rho_sliders + cell_dep:
        w.observe(update_plot,'value')

    vbox=widgets.VBox([sheet,button,output])
    display(vbox)
    return rho_sliders,cell_dep


rho_sliders,cell_dep= create_table()

# Read the logo bytes; `with` closes the file handle once they are read.
with open("../../docs/images/heli_wlogo_border_blur-01.png", "rb") as logo_file:
    image = logo_file.read()
widgets.Image(
    value=image,
    format='png',
    width=100,
    height=150,
)

f,ax = plt.subplots(1)
ax = plot_mesh(sig,mesh1D,ax)

from ipywidgets import interactive
import matplotlib.pyplot as plt
import numpy as np

def f(m, b):
    """Draw the line y = m*x + b on figure 2 over x in [-10, 10], y clipped to [-5, 5]."""
    xs = np.linspace(-10, 10, num=1000)
    ys = m * xs + b
    plt.figure(2)
    plt.plot(xs, ys)
    plt.ylim(-5, 5)
    plt.show()

# Slider ranges: slope m in [-2, 2]; intercept b in [-3, 3] with step 0.5.
interactive_plot = interactive(f, m=(-2.0, 2.0), b=(-3, 3, 0.5))
# The interactive widget's last child is its output area; give it a fixed height.
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot

    



def update_res_model(rho_vals,dep_vals):
    """Display an "Update model" button that redraws the model into an output area.

    Parameters
    ----------
    rho_vals : sequence
        Resistivities passed to ``plot_res_model`` as ``rhos``. Items may be
        plain numbers or widgets/cells exposing ``.value``. Widget values are
        read at click time, so the plot follows later edits.
    dep_vals : sequence
        Layer-top depths passed as ``lays``; same conventions as ``rho_vals``.

    Returns
    -------
    (Button, Output)
        The button and the output area, both already displayed.
    """
    button = Button(description="Update model")
    output = Output()
    display(button, output)

    def current(items):
        # Plain numbers have no ``value`` attribute and are used as-is.
        return [getattr(item, 'value', item) for item in items]

    def on_button_clicked(b):
        with output:
            clear_output(True)
            print('plotting again')
            plot_res_model(rhos=current(rho_vals),lays=current(dep_vals))
            show_inline_matplotlib_plots()

    button.on_click(on_button_clicked)

    show_inline_matplotlib_plots()
    return button,output




# update_res_model already displays its button and output area, so the output
# is not displayed a second time here. `cell_dep` is the depth-cell list
# returned by create_table above.
button,output = update_res_model([c.value for c in rho_sliders],
             [c.value for c in cell_dep])


cell_dep[-1].value = 120

c1 = cell_dep[0]
c2 = cell_dep[1]

import traitlets


def func(change):
    """Print each change notification received from an observed widget."""
    print(change)


# Report every trait change on the first depth cell.
c1.observe(func,traitlets.All)

# NOTE(review): traitlets.link needs two (object, trait_name) pairs, e.g.
# traitlets.link((c1, 'value'), (c2, 'value')). Calling it with no arguments
# raises TypeError, so no link is created here — confirm which traits were
# meant to be linked.

# Map each slider's description to the slider and hand them to widgets.interact.
# NOTE(review): `sliders` and `viewer` are not defined in the visible part of
# this file — confirm where they come from. Because of `**slider_dict`, each
# description becomes a keyword argument, so the descriptions must match
# viewer's parameter names.
slider_dict = {}
for s in sliders:
    slider_dict[s.description]=s
widgets.interact(viewer, **slider_dict)

def define_sigma_layers(sig_lay,layers,cell_z=None):
    '''
    Build a per-cell conductivity model from layer tops.

    layers  (n,): tops of layers, negative values indicate below surface,
                  ordered from shallowest to deepest; the last layer extends
                  to the bottom of the mesh.
    sig_lay (n+1,) or (n,): conductivity of each layer. The final value is also
                  used for every cell not covered by a layer (the background).
                  Pass n+1 values to give the background its own conductivity.
    cell_z  (m,), optional: cell-centre elevations to classify; defaults to
                  LocSigZ with TDsurvey.n_layer cells.

    Returns an array with one conductivity per cell.
    '''
    if cell_z is None:
        cell_z = LocSigZ
        n_cells = TDsurvey.n_layer
    else:
        cell_z = np.asarray(cell_z, dtype=float)
        n_cells = cell_z.size
    sig  = np.ones(n_cells)*sig_lay[-1]
    n_lay = len(layers)
    for i, top in enumerate(layers):
        # Cells whose centre lies below this layer's top...
        ind = top > cell_z
        if i < n_lay - 1:
            # ...but not below the next layer's top.
            ind = ind & (layers[i+1] <= cell_z)
        sig[ind] = sig_lay[i]
    return sig

def plot_res_model(rhos,lays,ax=None):
    """Plot the model from table values and return the per-cell conductivity.

    rhos : [background, layer 1, layer 2, ...] resistivities (ohm-m).
    lays : layer-top depths (positive down), one per layer.
    ax   : optional matplotlib Axes forwarded to plot_mesh (a new figure is
           created when omitted, as before).
    """
    layers = -np.asarray(lays)
    # Reorder to [layer 1, ..., layer n, background]: define_sigma_layers uses
    # the last value as the background.
    sigs=1./np.r_[np.asarray(rhos[1:]),rhos[0]]
    sig = define_sigma_layers(sigs,layers)
    plot_mesh(sig,mesh1D,ax=ax)
    plt.show()
    return sig

# rho_list=[c.value for c in rho_sliders[1:]]
# rho_half=rho_sliders[0].value
# layers=[c.value for c in cell_deps]
# plot_res_model(rho_list,layers)




# Rebuild the layer table with the most recent create_table definition above.
# It returns the resistivity sliders and the computed layer-top depth cells;
# this also binds the `cell_deps` name used by later cells.
rho_sliders,cell_deps = create_table()

# kwargs = {'rho_list':sliders,'rho_half':sliders[0],lay_tops=fixed()
# app = widgets.interactive(plot_res_model,rhos=fixed([c.value for c in rho_sliders]),lays=fixed([c.value for c in cell_deps]))
# display(app)



from traitlets import HasTraits, observe, Instance, Int

class A(HasTraits):
    """Toy HasTraits object that prints every change to its ``value`` trait."""
    value = Int()

    def __init__(self, val):
        # Initialise HasTraits first, then assign, so a non-default starting
        # value is reported by the observer below.
        super().__init__()
        self.value = val

    @observe('value')
    def func(self, change):
        """Print the change notification received for ``value``."""
        print(change)
        
class App(HasTraits):
    """Container whose ``myA`` trait defaults to an ``A`` built as ``A(2)``."""
    myA = Instance(A, args=(2,))
    

# Scratch checks of the observer on A, then of the table widgets.
app = App()
# myA starts as A(2), so presumably no change notification here.
app.myA.value = 2

# A real change: A.func prints the notification.
app.myA.value = 0

dir(app.myA)

# Current resistivity slider values.
[c.value for c in rho_sliders]

# Number of sliders vs. number of depth cells in the current table.
print(len(rho_sliders),len(cell_deps))

def plot_res_model_tbl(rho_list,rho_half,lay_tops):
    """Plot the model from per-layer resistivities plus a half-space resistivity.

    rho_list : resistivity (ohm-m) of each layer, one per entry of lay_tops.
    rho_half : resistivity (ohm-m) of the background half-space.
    lay_tops : layer-top depths (positive down).

    Returns the per-cell conductivity.
    """
    layers = -np.asarray(lay_tops)
    sig_lay = 1./np.asarray(rho_list, dtype=float)
    sig_half = 1./rho_half
    # define_sigma_layers(sig_lay, layers) uses its last entry as the background.
    sig = define_sigma_layers(np.r_[sig_lay, sig_half],layers)
    plot_mesh(sig,mesh1D)
    plt.show()
    return sig
    
    
# Keyword arguments for plot_res_model_tbl. Layer tops are wrapped in fixed()
# so interact() passes them through instead of building a widget for them.
# NOTE(review): `sliders` is not defined in the visible part of this file, and
# interact() treats a list value as Dropdown options rather than as a list of
# widgets — confirm the intended wiring.
kwargs = {'rho_list':sliders[1:],'rho_half':sliders[0],
          'lay_tops':widgets.fixed([c.value for c in cell_deps])}
widgets.interact(plot_res_model_tbl, **kwargs)

def some_args(*args):
    """Print each positional argument followed by the word 'and'; return None."""
    # A plain loop: the prints are side effects, so no throwaway list is built.
    for arg in args:
        print(arg, 'and')


some_args(np.arange(0,100))






def create_table():
    """Build and display a layer table with a background row.

    Row 0 holds the background resistivity; rows 1.. are layers. Each layer
    row has an editable thickness (m) and a read-only top-of-layer depth equal
    to the sum of the thicknesses of the layer rows above. "Add Row" appends
    a layer.

    Returns
    -------
    rho_sliders : list of widgets.FloatLogSlider
        Background slider first, then one slider per layer.
    cell_dep : list of ipysheet cells
        Computed layer-top depth cells, one per layer row.
    """
    def new_rho_logfloatslider(lay_name='1'):
        """Return a log-scaled resistivity slider labelled rho_<lay_name>."""
        return widgets.FloatLogSlider(
                base=10,
                min=np.log10(rhomin),
                max=np.log10(rhomax),
                value=50.,
                step=0.1,
                continuous_update=False,
                description=r"$\rho_{" + str(lay_name) + "}$",
                readout=False,
                # `Layout` alone is not imported in this file; use the module path.
                layout=widgets.Layout(width='250px')
            )

    def insert_rho_sliders(rho_sliders,col=0,row_start=0):
        """Put sliders in column `col` and a linked numeric cell in `col+1`."""
        ipysheet.column(col, rho_sliders,row_start=row_start)
        cells = []
        for i, slider in enumerate(rho_sliders):
            # Start the cell at its slider's value so the linked pair agrees.
            cells.append(ipysheet.cell(row_start+i,col+1,slider.value,numeric_format='0.0'))
        for cell, slider in zip(cells, rho_sliders):
            # Browser-side link: editing either widget updates the other.
            widgets.jslink((cell, "value"),(slider,"value"))
        return cells

    def new_thk_cell(row,col=2,thk=30):
        """Return an editable layer-thickness cell."""
        return ipysheet.cell(row,col,thk,numeric_format='0')

    def new_dep_cell(row,col=3,cell_thk=None):
        """Return a read-only cell holding the sum of the `cell_thk` cells."""
        c=ipysheet.cell(row,col,value=0,
                    color='black',
                    background_color='grey',
                    numeric_format='0',
                    read_only=True)
        @ipysheet.calculation(inputs=cell_thk, output=c)
        def add(*a):
            return sum(a)
        return c

    sheet = ipysheet.sheet(rows=4,columns=4,column_width=[100,50,50,50],
                           column_headers=['Resistivity slider','Resistivity value',
                                           r'Thickness (m)','Top of layer depth (m)'],
                          )
    # The background row has no thickness or depth.
    ipysheet.row(0, ['--','--'],column_start=2,column_end=3,
                 read_only=True,color='black', background_color='grey')

    # Thickness/depth cells for layer rows 1-3; each depth sums the
    # thicknesses of the layer rows above it.
    cell_thk = []
    cell_dep = []
    for i in range(1,4):
        cell_thk.append(new_thk_cell(row=i,col=2,thk=30))
        cell_dep.append(new_dep_cell(row=i,col=3,cell_thk=cell_thk[:-1]))

    # Resistivity sliders: background first, then layers 1-3.
    cells_rho = []
    rho_sliders = [new_rho_logfloatslider('background')]
    for i in range(3):
        rho_sliders.append(new_rho_logfloatslider(str(i+1)))
    cells_rho.extend(insert_rho_sliders(rho_sliders))

    button = widgets.Button(description='Add Row')
    out = widgets.Output()

    def add_row(_):
        """Append a layer row: slider, value cell, thickness and depth."""
        with out:
            sheet.rows += 1
            row = sheet.rows-1
            rho_sliders.append(new_rho_logfloatslider(row))
            cells_rho.extend(insert_rho_sliders((rho_sliders[-1],),row_start=row))
            cell_thk.append(new_thk_cell(row=row,col=2,thk=30))
            cell_dep.append(new_dep_cell(row=row,col=3,cell_thk=cell_thk[:-1]))
    button.on_click(add_row)

    vbox = widgets.VBox([sheet,button])
    display(vbox)
    return rho_sliders,cell_dep
