In [24]:
import ipywidgets as widgets
from abtem.visualize.interactive.canvas import Canvas
from abtem.visualize.interactive.artists import ScatterArtist
import numpy as np

# `IPython.display` is the public home of these helpers; importing `display`
# from `IPython.core.display` is deprecated (IPython >= 7.14) and emits a warning.
from IPython.display import display, HTML
# Widen the notebook container, presumably so the canvases fit side by side.
# NOTE(review): `.container` targets the classic Notebook UI; likely no effect in JupyterLab.
display(HTML("<style>.container { width:100% !important; }</style>"))

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb, to_hex

In [2]:
from ipyevents import Event
from scipy.spatial import KDTree

In [3]:
class EditPointTags:
    """Click tool that adds or removes point tags on a canvas.

    A plain click adds a tag at the clicked position (in domain coordinates);
    an Alt+click deletes the tag nearest to the clicked position.
    """

    def __init__(self, tags, data_fields=None):
        """
        Parameters
        ----------
        tags : object
            Tag collection; ``activate`` uses its ``x``, ``y``, ``add_tags`` and
            ``delete_tags`` members.
        data_fields : dict, optional
            Values attached to every newly added tag, keyed by field name.
        """
        self._tags = tags

        if data_fields is None:
            data_fields = {}

        self._data_fields = data_fields
        self._label = None
        self._event = Event(watched_events=['click'])

    @property
    def label(self):
        """Optional label for this tool; ``None`` until assigned."""
        # getattr default also covers instances created without running __init__.
        return getattr(self, '_label', None)

    @label.setter
    def label(self, value):
        self._label = value

    def activate(self, canvas):
        """Start listening for clicks on ``canvas``.

        ``canvas`` must provide ``pixel_to_domain(x_px, y_px)``.
        NOTE(review): it is also used as the ipyevents source; the main cell uses
        ``canvas2.figure`` as a source instead, so confirm ``canvas`` is a DOM widget.
        """
        self._event.source = canvas
        # Drop any handler left by an earlier activate() so each click is handled once.
        self._event.reset_callbacks()

        def handle_event(event):
            x, y = canvas.pixel_to_domain(event['offsetX'], event['offsetY'])
            tags = self._tags

            if event['altKey'] is True:
                positions = np.array([tags.x, tags.y]).T
                if len(positions) == 0:
                    # Nothing to delete; a KD-tree over zero points is meaningless.
                    return
                _, indices = KDTree(positions).query([[x, y]])
                tags.delete_tags(indices)
            else:
                tags.add_tags([x], [y], {key: [value] for key, value in self._data_fields.items()})

        self._event.on_dom_event(handle_event)

    def deactivate(self, canvas):
        """Stop handling clicks (``canvas`` is accepted for interface symmetry)."""
        self._event.reset_callbacks()


In [86]:
# Number of Fourier coefficients, i.e. samples of the coarse real-space signal.
coefs_num = 64
# Zero-padding factor used to draw an interpolated real-space curve.
pad_factor = 10

# Coarse sample positions 0.0, 1.0, ..., coefs_num - 1.
r = np.arange(coefs_num, dtype=float)
# Display x-coordinates for the padded curve.
# NOTE(review): spans [0, 5] rather than [0, coefs_num) - presumably a display scale; confirm.
r_pad = np.linspace(0.0, 5.0, num=pad_factor * coefs_num)
# Frequencies (cycles per sample) sorted ascending; zero frequency sits at index coefs_num // 2.
q = np.fft.fftshift(np.fft.fftfreq(coefs_num, d=1.0))

In [123]:
canvas1, canvas2 = Canvas(), Canvas()

# --- Real space: interpolated signal; its y values are assigned by set_sig. ---
real_space_artist = ScatterArtist(gl=True)
real_space_artist._mark.default_size = 5
real_space_artist.x = r_pad

canvas1.artists = {'artist': real_space_artist}
canvas1.x_limits = [0, r_pad.max()]
canvas1.y_limits = [-1.1, 1.1]


# --- Fourier space: one point per frequency in q, starting at zero amplitude. ---
def _make_fourier_artist(size, colors):
    """Build a WebGL scatter artist over ``q`` with all amplitudes set to zero."""
    artist = ScatterArtist(gl=True)
    artist._mark.default_size = size
    artist.x = q
    artist.y = np.zeros(len(q))
    artist._mark.colors = colors
    return artist


# Per-point coloured markers; get_color later encodes each coefficient in these colours.
fft_artist = _make_fourier_artist(10, ['#000000'] * len(q))
# Smaller white markers at the same positions.
fft_artist2 = _make_fourier_artist(5, ['white'])

canvas2.artists = {'artist2': fft_artist2, 'artist1': fft_artist}
canvas2.x_limits = [-.55, .55]
canvas2.y_limits = [-2, 40]

# NOTE(review): this widget is given no source and no callback in the visible
# code; it appears to be unused.
event = Event(watched_events=['click', 'mousemove'])

# Mouse events on the Fourier-space figure; handle_event is registered on it later.
event_obj = Event(
    source=canvas2.figure,
    watched_events=['click', 'mousemove', 'mouseup', 'mousedown'],
)

# Interaction state shared by the callbacks lives as attributes on the widget.
event_obj._selected = None           # index of the coefficient being dragged, or None
event_obj._phase = np.zeros(len(q))  # phase (radians) of every Fourier coefficient
event_obj._a = None                  # NOTE(review): never read in the visible code

def get_color(amplitude, phase):
    """Map Fourier coefficients to hex colours.

    Hue encodes the phase (one turn of 2*pi spans the colour wheel), saturation
    is 1, and brightness is the square root of the magnitude normalised to the
    largest magnitude.

    Parameters
    ----------
    amplitude : array_like
        One amplitude per coefficient. Its magnitude is used, so points dragged
        below zero still get a brightness in [0, 1] instead of NaN.
    phase : array_like
        Phase in radians for each coefficient, same length as ``amplitude``.

    Returns
    -------
    list of str
        One ``'#rrggbb'`` string per coefficient (empty list for empty input).
    """
    magnitude = np.abs(np.asarray(amplitude, dtype=float))
    phase = np.asarray(phase, dtype=float)
    if magnitude.size == 0:
        return []

    hue = np.mod(phase / (2.0 * np.pi), 1)
    saturation = np.ones_like(hue)
    peak = magnitude.max()
    # An all-zero spectrum would give 0/0 = NaN, which to_hex cannot format; draw it black.
    value = np.sqrt(magnitude / peak) if peak > 0 else np.zeros_like(magnitude)

    rgb = hsv_to_rgb(np.stack((hue, saturation, value), axis=-1))
    return [to_hex(color) for color in rgb]

def handle_event(event):
    """Let the user drag Fourier coefficients on the Fourier-space canvas.

    Registered on ``event_obj`` (ipyevents DOM events):

    * ``mousedown`` selects the coefficient nearest to the cursor.
      NOTE(review): the distance mixes x and y data units, whose axis ranges
      differ by roughly 40x, so picking favours points at a similar height.
    * ``mousemove`` while a coefficient is selected sets its amplitude to the
      cursor's y value and its phase to 10x the horizontal offset from the point.
      The coefficient at the opposite frequency gets the same amplitude and the
      negated phase, keeping the spectrum Hermitian (i.e. a real signal).
    * ``mouseup`` releases the selection.
    """
    mouse_x, mouse_y = canvas2.pixel_to_domain(event['offsetX'], event['offsetY'])
    kind = event['type']

    if kind == 'mousedown':
        positions = np.array([fft_artist.x, fft_artist.y]).T
        _, indices = KDTree(positions).query([[mouse_x, mouse_y]])
        event_obj._selected = indices[0]
        event_obj._mouse_x = mouse_x
        event_obj._mouse_y = mouse_y

    elif kind == 'mousemove':
        i = event_obj._selected
        if i is None:
            return

        x = fft_artist.x
        y = fft_artist.y.copy()
        phase = event_obj._phase.copy()

        # Index of the opposite frequency -q[i] in fftshift order. DC maps to
        # itself, and so does the Nyquist bin at index 0 for even lengths (the
        # old "half + (half - i)" formula indexed out of bounds there).
        half = coefs_num // 2
        mirror = (2 * half - i) % coefs_num

        y[i] = mouse_y
        y[mirror] = mouse_y
        phase[i] = -(x[i] - mouse_x) * 10
        if mirror != i:
            phase[mirror] = -phase[i]

        # Store the phase first: assigning fft_artist.y fires update_canvas1,
        # which reads event_obj._phase.
        event_obj._phase = phase
        fft_artist.y = y
        fft_artist2.y = y
        fft_artist._mark.colors = get_color(y, phase)

    elif kind == 'mouseup':
        event_obj._selected = None

def set_sig(sig):
    """Draw the real-space signal corresponding to a centred Fourier spectrum.

    The coefficients are zero-padded to ``pad_factor`` times their length before
    the inverse FFT, which interpolates the band-limited signal on a finer grid.
    The result is rescaled by ``pad_factor`` so every ``pad_factor``-th sample
    equals the unpadded inverse FFT. Its real part is assigned to
    ``real_space_artist.y``.

    Parameters
    ----------
    sig : array_like of complex
        Fourier coefficients with zero frequency at index ``len(sig) // 2``
        (``np.fft.fftshift`` order, the same order as ``q``).

    Returns
    -------
    numpy.ndarray
        The padded real-space signal (also stored on ``real_space_artist``).
    """
    sig_fft = np.fft.ifftshift(np.asarray(sig))
    n = sig_fft.size
    n_pad = n * pad_factor
    sig_fft_pad = np.zeros(n_pad, dtype=complex)

    n_pos = (n + 1) // 2  # DC plus the strictly positive frequencies
    n_neg = (n - 1) // 2  # the strictly negative frequencies
    sig_fft_pad[:n_pos] = sig_fft[:n_pos]
    sig_fft_pad[n_pad - n_neg:] = sig_fft[n - n_neg:]

    if n % 2 == 0 and n > 0:
        # For even n the Nyquist bin (index n // 2) stands for both +n/2 and -n/2.
        # Split it evenly between the two padded bins. Dropping it would make
        # the interpolated curve miss the original samples.
        nyquist = n // 2
        if n_pad > n:
            sig_fft_pad[nyquist] = sig_fft[nyquist] / 2
            sig_fft_pad[n_pad - nyquist] = sig_fft[nyquist] / 2
        else:
            sig_fft_pad[nyquist] = sig_fft[nyquist]

    sig_pad = np.real(np.fft.ifft(sig_fft_pad)) * pad_factor
    real_space_artist.y = sig_pad
    return sig_pad
        
def update_canvas1(*args):
    """Redraw the real-space curve from the current Fourier amplitudes and phases.

    Extra positional arguments (e.g. a traitlets change dict) are ignored, so
    this can be registered directly as an observer.
    """
    amplitudes = fft_artist.y
    coefficients = amplitudes * np.exp(1.j * event_obj._phase)
    set_sig(coefficients)

def set_preset(change):
    """Load a predefined real-space signal and show it on both canvases.

    Used as the observer of ``preset_dropdown``. It accepts any mapping with a
    ``'new'`` key, e.g. ``{'new': 'sine'}``.

    The signal, sampled on ``r``, is Fourier transformed. Its amplitudes go to
    both Fourier-space artists. Its phases go to ``event_obj._phase`` and to the
    colours of ``fft_artist``. The real-space curve is redrawn with ``set_sig``.

    Raises
    ------
    ValueError
        If ``change['new']`` is not a known preset. This is a subclass of
        ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    presets = {
        'sine': lambda: np.sin((2.0*np.pi*(1.0/16.0))*r),
        'cosine': lambda: np.cos((2.0*np.pi*(1.0/16.0))*r),
        'wavepacket': lambda: (
            (np.cos((2.0*np.pi*(8.0/coefs_num))*r)**2)
            * (np.sin((2.0*np.pi*(0.5/coefs_num))*r)**4)),
        'wavepacket_broad': lambda: (
            (np.cos((2.0*np.pi*(8.0/coefs_num))*r)**2)
            * np.abs(np.sin((2.0*np.pi*(0.5/coefs_num))*r)**1)),
        'wavepacket_narrow': lambda: (
            (np.cos((2.0*np.pi*(8.0/coefs_num))*r)**2)
            * (np.sin((2.0*np.pi*(0.5/coefs_num))*r)**16)),
        'atoms': lambda: np.sin((2.0*np.pi*(4.0/coefs_num))*r)**8,
        'atoms_surface': lambda: (
            (np.sin((2.0*np.pi*(4.0/coefs_num))*r)**8)*(r<coefs_num/2.0)),
        'atoms_defect': lambda: (
            (np.sin((2.0*np.pi*(4.0/coefs_num))*r)**8)*(r<coefs_num/2.0)
            + (np.cos((2.0*np.pi*(4.0/coefs_num))*r)**8)*(r>coefs_num*(7/16))*(r<coefs_num*(15/16))),
        'two_sites': lambda: (
            1.0*np.sin((2.0*np.pi*(2.0/coefs_num))*r)**16
            + 0.5*np.cos((2.0*np.pi*(2.0/coefs_num))*r)**16 - 0.295),
        'two_sites_shift_left': lambda: (
            1.0*np.sin((2.0*np.pi*(2.0/coefs_num))*r)**16
            + 0.5*np.cos((2.0*np.pi*(2.0/coefs_num))*(r+2))**16 - 0.295),
        'two_sites_shift_right': lambda: (
            1.0*np.sin((2.0*np.pi*(2.0/coefs_num))*r)**16
            + 0.5*np.cos((2.0*np.pi*(2.0/coefs_num))*(r-2))**16 - 0.295),
    }

    preset = change['new']
    try:
        make_signal = presets[preset]
    except KeyError:
        raise ValueError("Preset not defined") from None

    # Centre with fftshift so the spectrum lines up with q. This is identical to
    # ifftshift for the even coefs_num used here, and also correct for odd lengths.
    spectrum = np.fft.fftshift(np.fft.fft(make_signal()))
    amplitude = np.abs(spectrum)
    phase = np.angle(spectrum)

    # Store the phase before touching the artists: their 'y' and 'colors'
    # observers call update_canvas1, which reads event_obj._phase.
    event_obj._phase = phase
    fft_artist.y = amplitude
    fft_artist2.y = amplitude
    fft_artist._mark.colors = get_color(amplitude, phase)

    set_sig(fft_artist.y * np.exp(1.j * phase))

# Any change to the Fourier-space points re-renders the real-space curve.
fft_artist.observe(update_canvas1, 'x')
fft_artist.observe(update_canvas1, 'y')
fft_artist._mark.observe(update_canvas1, 'colors')

set_preset({'new': 'sine'})

# Every preset that set_preset implements, in the same order. 'atoms_defect' was
# previously unreachable from the UI. ipywidgets selects the first option
# ('sine') by default, matching the preset loaded above.
preset_dropdown = widgets.Dropdown(description='Preset', options=[
    'sine',
    'cosine',
    'wavepacket',
    'wavepacket_broad',
    'wavepacket_narrow',
    'atoms',
    'atoms_surface',
    'atoms_defect',
    'two_sites',
    'two_sites_shift_left',
    'two_sites_shift_right',
])
preset_dropdown.observe(set_preset, 'value')

canvas1.title = 'Real space'
canvas1.y_label = 'Amplitude'
canvas1.x_label = 'Spatial coordinate'

canvas2.title = 'Fourier space'
canvas2.y_label = 'Amplitude'
canvas2.x_label = 'Frequency'

update_canvas1()
event_obj.on_dom_event(handle_event)
tools = widgets.VBox([preset_dropdown])

widgets.VBox([widgets.HBox([canvas1.widget, canvas2.widget, tools])])

VBox(children=(HBox(children=(VBox(children=(HBox(children=(HBox(layout=Layout(width='50px')), HTML(value="<p …