In [None]:
%aiida

In [None]:
import numpy as np
import base64
import warnings
import numpy as  np

import ipywidgets as ipw
from IPython.display import display, clear_output
import nglview
from ase import Atoms
import time ## needed to order picked atoms for angles

from traitlets import Instance, Int, Set, Dict, Union,Unicode, default, link, observe, validate
from aiida.orm import Node

In [None]:
# Aspect ratios for molecule atoms and for the remaining atoms
# (presumably ball-and-stick `aspectRatio` values — NOTE(review): confirm usage).
MOL_ASPECT, REST_ASPECT = 3.0, 10.0
def find_ranges(iterable):
    """Yield runs of consecutive integers from an ascending iterable.

    A run of length one is yielded as the bare number; a longer run is yielded as a
    ``(first, last)`` tuple. For example ``[1, 3, 4, 5]`` yields ``1`` and then ``(3, 5)``.
    The input is expected to be sorted ascending (callers pass ``sorted(...)``).
    """
    from itertools import groupby
    # Consecutive integers share the same ``value - position`` difference; this is the
    # grouping ``more_itertools.consecutive_groups`` performs, done with the stdlib so the
    # function no longer needs a third-party package.
    for _key, run in groupby(enumerate(iterable), key=lambda pair: pair[1] - pair[0]):
        group = [value for _position, value in run]
        if len(group) == 1:
            yield group[0]
        else:
            yield group[0], group[-1]
def set_to_string_range(selection, shift=0):
    """Convert a set like {1, 3, 4, 5} into a string like '1 3..5'.

    :param selection: iterable of integer indices; duplicates are ignored.
    :param shift: value added to every index in the output, e.g. ``shift=1`` when a
        user interface numbers atoms from 1 instead of 0.
    :return: space-separated singles and inclusive ``first..last`` ranges, ascending.
    """
    parts = []
    for rng in find_ranges(sorted(set(selection))):
        if isinstance(rng, tuple):
            parts.append("{}..{}".format(rng[0] + shift, rng[1] + shift))
        else:
            # Test for the tuple instead of ``int`` so numpy integers are handled too.
            parts.append(str(rng + shift))
    return " ".join(parts)


def string_range_to_set(strng, shift=0):
    """Convert a string like '1 3..5' into a set like {1, 3, 4, 5}.

    :param strng: whitespace-separated tokens, each a non-negative integer or an
        inclusive range ``start..end``.
    :param shift: value added to every parsed index, e.g. ``shift=-1`` when a user
        interface numbers atoms from 1 instead of 0.
    :return: tuple ``(indices, ok)``; any malformed token gives ``(set(), False)``.
    """
    indices = set()
    for token in strng.split():
        if token.isdecimal():
            # isdecimal (unlike isdigit) only accepts characters int() can parse, so
            # input such as '²' is reported as bad syntax instead of raising ValueError.
            indices.add(int(token) + shift)
        elif '..' in token:
            try:
                start, end = token.split('..')
                indices.update(i + shift for i in range(int(start), int(end) + 1))
            except ValueError:
                return set(), False
        else:
            return set(), False
    return indices, True


class CopyToClipboardButton(ipw.Button):
    """Button that copies the string held in its ``value`` trait to the clipboard."""

    value = Unicode(allow_none=True)  # Traitlet that contains a string to copy.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        super().on_click(self.copy_to_clipboard)

    def copy_to_clipboard(self, change=None):  # pylint:disable=unused-argument
        """Copy ``self.value`` to the clipboard by displaying a JavaScript snippet.

        Does nothing when ``value`` is empty or None. Relies on
        ``document.execCommand('copy')``: works in Chrome, but not in Firefox.
        """
        if not self.value:  # If no value provided - do nothing.
            return
        import json
        from IPython.display import Javascript, display
        # json.dumps yields a correctly quoted/escaped JS string literal, so quotes,
        # backslashes or newlines in the value cannot break (or inject into) the script.
        javas = Javascript("""
           function copyStringToClipboard (str) {{
               // Create new element
               var el = document.createElement('textarea');
               // Set value (string to be copied)
               el.value = str;
               // Set non-editable to avoid focus and move outside of view
               el.setAttribute('readonly', '');
               el.style.position = 'absolute';
               el.style.left = '-9999px';
               document.body.appendChild(el);
               // Select text inside element
               el.select();
               // Copy text to clipboard
               document.execCommand('copy');
               // Remove temporary element
               document.body.removeChild(el);
            }}
            copyStringToClipboard({selection});
       """.format(selection=json.dumps(self.value)))
        display(javas)

In [None]:
class _StructureDataBaseViewer(ipw.VBox):
    """Base viewer class for AiiDA structure or trajectory objects.

    Subclasses are expected to provide a ``structure`` attribute (used for downloads) and
    may provide ``displayed_structure`` (used as a coordinate fallback for atoms selected
    by typing their indices instead of clicking them).

    :param configure_view: If True, add configuration tabs
    :type configure_view: bool"""
    selection = Set(Int)  # 0-based (global) indices of the selected atoms.
    # Per-component display settings keyed 0..n-1 (keys read below: 'ids', 'name',
    # 'highlight_color', 'highlight_aspectRatio', 'highlight_opacity'); empty means a
    # single plain component.
    vis_dict = Dict()
    DEFAULT_SELECTION_OPACITY = 0.2
    DEFAULT_SELECTION_RADIUS = 6
    DEFAULT_SELECTION_COLOR = 'green'

    def __init__(self, configure_view=True, **kwargs):

        # Info of clicked atoms (NGL 'picked' atom data plus a 'time' stamp in ms), keyed by
        # global index; used to report coordinates, distances, angles and dihedrals.
        self.selection_dict = {}
        # Text summary of the current selection shown in the Selection tab.
        self.selection_info = ''
        # Lookup tables between global indices and (component, local index); filled by
        # _gen_translation_indexes() when vis_dict is used.
        self._translate_i_glob_loc = {}
        self._translate_i_loc_glob = {}

        # Defining viewer box.

        # 1. Nglviewer
        self._viewer = nglview.NGLWidget()
        self._viewer.camera = 'orthographic'
        self._viewer.observe(self._on_atom_click, names='picked')
        self._viewer.stage.set_parameters(mouse_preset='pymol')

        # 2. Camera type.
        camera_type = ipw.ToggleButtons(options={
            'Orthographic': 'orthographic',
            'Perspective': 'perspective'
        },
                                        description='Camera type:',
                                        value='orthographic',
                                        layout={"align_self": "flex-end"},
                                        style={'button_width': '115.5px'},
                                        orientation='vertical')

        def change_camera(change):
            self._viewer.camera = change['new']

        camera_type.observe(change_camera, names="value")
        view_box = ipw.VBox([self._viewer, camera_type])

        # Defining selection tab.

        # 1. Selected atoms.
        self._selected_atoms = ipw.Text(description='Selected atoms:', value='', style={'description_width': 'initial'})

        # 2. Copy to clipboard
        copy_to_clipboard = CopyToClipboardButton(description="Copy to clipboard")
        link((self._selected_atoms, 'value'), (copy_to_clipboard, 'value'))

        # 3. Informing about wrong syntax.
        self.wrong_syntax = ipw.HTML(
            value="""<i class="fa fa-times" style="color:red;font-size:2em;" ></i> wrong syntax""",
            layout={'visibility': 'hidden'})

        # 4. Button to clear selection.
        clear_selection = ipw.Button(description="Clear selection")
        clear_selection.on_click(lambda _: self.set_trait('selection', set()))  # lambda cannot contain assignments

        # 5. Button to apply selection
        apply_selection = ipw.Button(description="Apply selection")
        apply_selection.on_click(self.apply_selection)

        # 6. Information on selected atoms
        self.info = ipw.Output()

        selection_tab = ipw.VBox([
            ipw.HBox([self._selected_atoms, self.wrong_syntax]),
            ipw.HBox([copy_to_clipboard, clear_selection, apply_selection]),
            self.info
        ])

        # Defining appearance tab.

        # 1. Supercell
        self.supercell_x = ipw.BoundedIntText(value=1, min=1, layout={"width": "30px"})
        self.supercell_y = ipw.BoundedIntText(value=1, min=1, layout={"width": "30px"})
        self.supercell_z = ipw.BoundedIntText(value=1, min=1, layout={"width": "30px"})
        supercell_selector = ipw.HBox([
            ipw.HTML(description="Super cell:", layout={"width": "initial"}),
            self.supercell_x,
            self.supercell_y,
            self.supercell_z,
        ])

        # 2. Choose background color.
        background_color = ipw.ColorPicker(description="Background")
        link((background_color, 'value'), (self._viewer, 'background'))
        background_color.value = 'white'

        # 3. Center button.
        center_button = ipw.Button(description="Center")
        center_button.on_click(lambda c: self._viewer.center())

        appearance_tab = ipw.VBox([supercell_selector, background_color, center_button])

        # Defining download tab.

        # 1. Choose download file format.
        self.file_format = ipw.Dropdown(options=['xyz', 'cif'], layout={"width": "200px"}, description="File format:")

        # 2. Download button.
        self.download_btn = ipw.Button(description="Download")
        self.download_btn.on_click(self.download)
        self.download_box = ipw.VBox(
            children=[ipw.Label("Download as file:"),
                      ipw.HBox([self.file_format, self.download_btn])])

        # 3. Screenshot button
        self.screenshot_btn = ipw.Button(description="Screenshot", icon='camera')
        self.screenshot_btn.on_click(lambda _: self._viewer.download_image())
        self.screenshot_box = ipw.VBox(children=[ipw.Label("Create a screenshot:"), self.screenshot_btn])

        download_tab = ipw.VBox([self.download_box, self.screenshot_box])

        # Constructing configuration box
        if configure_view:
            configuration_box = ipw.Tab(layout=ipw.Layout(flex='1 1 auto', width='auto'))
            configuration_box.children = [selection_tab, appearance_tab, download_tab]
            for i, title in enumerate(["Selection", "Appearance", "Download"]):
                configuration_box.set_title(i, title)
            children = [ipw.HBox([view_box, configuration_box])]
            view_box.layout = {'width': "60%"}
        else:
            children = [view_box]

        if 'children' in kwargs:
            children += kwargs.pop('children')

        super().__init__(children, **kwargs)

    def _on_atom_click(self, change=None):  # pylint:disable=unused-argument
        """Toggle the clicked atom in ``selection`` and remember its picked info.

        The info is stored in ``selection_dict`` *before* ``selection`` changes, because
        the ``selection`` observer reads it to build the selection summary.
        """
        picked = self._viewer.picked
        if 'atom1' not in picked:
            return  # did not click on atom
        index = picked['atom1']['index']
        if self.vis_dict:
            # With several components NGL reports a component-local index; map it back.
            key = (picked.get('component'), index)
            if key not in self._translate_i_loc_glob:
                return
            index = self._translate_i_loc_glob[key]

        if index in self.selection:
            remaining = self.selection - {index}
            # Drop info of atoms that are no longer selected before notifying observers.
            for stale in [i for i in self.selection_dict if i not in remaining]:
                del self.selection_dict[stale]
            self.selection = remaining
        else:
            # Copy, so the viewer's own 'picked' value is not mutated; the click time is
            # used to order atoms for angles and dihedrals.
            info = dict(picked['atom1'])
            info['time'] = int(round(time.time() * 1000))
            self.selection_dict[index] = info
            self.selection = self.selection | {index}

    def _gen_translation_indexes(self):
        """Build global <-> (component, local index) lookup tables from ``vis_dict``.

        Local indices follow the iteration order of each component's parsed ``'ids'``
        set; NOTE(review): this must match the order in which the component's atoms were
        added to the viewer.
        """
        self._translate_i_glob_loc = {}
        self._translate_i_loc_glob = {}
        for component in range(len(self.vis_dict)):
            ids = list(string_range_to_set(self.vis_dict[component]['ids'], shift=0)[0])
            for local_index, global_index in enumerate(ids):
                self._translate_i_glob_loc[global_index] = (component, local_index)
                self._translate_i_loc_glob[(component, local_index)] = global_index

    def _translate_glob_loc(self, indexes):
        """Split global atom indices into per-component lists of local indices.

        Indices that belong to no component (e.g. out-of-range values typed into the
        selection field) are ignored.
        """
        all_comp = [[] for _ in range(len(self.vis_dict))]
        for i_g in indexes:
            if i_g in self._translate_i_glob_loc:
                i_c, i_a = self._translate_i_glob_loc[i_g]
                all_comp[i_c].append(i_a)
        return all_comp

    def highlight_atoms(self,
                        vis_list,
                        color=DEFAULT_SELECTION_COLOR,
                        size=DEFAULT_SELECTION_RADIUS,
                        opacity=DEFAULT_SELECTION_OPACITY):
        """Highlight the atoms in ``vis_list`` (global indices; None highlights nothing).

        Without ``vis_dict`` one 'selected_atoms' ball-and-stick representation is drawn
        using ``color``, ``size`` (aspect ratio) and ``opacity``. With ``vis_dict`` every
        component gets a 'highlight_<name>' representation styled from ``vis_dict``; the
        ``color``/``size``/``opacity`` arguments are then not used.
        """
        if not hasattr(self._viewer, "component_0"):
            return

        if not self.vis_dict:
            # vis_dict is a Dict trait defaulting to {} (never None), so test emptiness.
            self._viewer._remove_representations_by_name(repr_name='selected_atoms')  # pylint:disable=protected-access
            self._viewer.add_ball_and_stick(  # pylint:disable=no-member
                name="selected_atoms",
                selection=[] if vis_list is None else sorted(vis_list),
                color=color,
                aspectRatio=size,
                opacity=opacity)
            return

        ncomponents = len(self.vis_dict)
        if vis_list is None:
            per_component = [[] for _ in range(ncomponents)]
        else:
            per_component = self._translate_glob_loc(vis_list)
        for component in range(ncomponents):
            settings = self.vis_dict[component]
            name = 'highlight_' + settings['name']
            self._viewer._remove_representations_by_name(repr_name=name, component=component)  # pylint:disable=protected-access
            self._viewer.add_ball_and_stick(  # pylint:disable=no-member
                name=name,
                selection=per_component[component],
                color=settings['highlight_color'],
                aspectRatio=settings['highlight_aspectRatio'],
                opacity=settings['highlight_opacity'],
                component=component)

    def dihedral(self, xyz):
        """Return the dihedral angle in degrees (range [-180, 180]) of four points.

        :param xyz: array of four 3-vectors, in pick order.
        """
        vec = xyz[:-1] - xyz[1:]
        vec[0] *= -1
        v = np.array([v - (v.dot(vec[1]) / vec[1].dot(vec[1])) * vec[1] for v in [vec[0], vec[2]]])
        v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1, 1)
        vec1 = vec[1] / np.linalg.norm(vec[1])
        x = np.dot(v[0], v[1])
        m = np.cross(v[0], vec1)
        y = np.dot(m, v[1])
        return np.degrees(np.arctan2(y, x))

    def _atom_record(self, index):
        """Return ``(element, xyz, pick_time)`` for a selected atom, or None if unknown.

        Clicked atoms use ``selection_dict``. Other atoms fall back to
        ``displayed_structure`` when the subclass provides it; they get an infinite pick
        time so they sort after clicked atoms.
        """
        if index in self.selection_dict:
            info = self.selection_dict[index]
            xyz = np.array([info['x'], info['y'], info['z']], dtype=float)
            return info['element'], xyz, info.get('time', np.inf)
        structure = getattr(self, 'displayed_structure', None)
        if structure is not None and 0 <= index < len(structure):
            return structure[index].symbol, np.array(structure.positions[index], dtype=float), np.inf
        return None

    @staticmethod
    def _format_vector(vector):
        """Return the components rounded to 2 decimals, separated by single spaces."""
        return ' '.join(str(round(component, 2)) for component in vector)

    def create_selection_info(self):
        """Store a text summary of the current selection in ``self.selection_info``.

        1 atom: id, element and coordinates. 2 atoms: both atoms, distance and centre.
        3 atoms: centre, angle at the second-picked atom and plane normal. 4 atoms: centre
        and dihedral (ordered by pick time). More atoms: count and centre. If coordinates
        are missing for some selected atom, only the count is reported.
        """
        self.selection_info = ''
        if not self.selection:
            return

        nselected = len(self.selection)
        records = {index: self._atom_record(index) for index in self.selection}
        if any(record is None for record in records.values()):
            self.selection_info = str(nselected) + ' atoms selected'
            return

        def atom_line(index):
            element, xyz, _ = records[index]
            return ('id: ' + str(index) + ' ' + element +
                    '  x: ' + str(round(xyz[0], 2)) +
                    '  y: ' + str(round(xyz[1], 2)) +
                    '  z: ' + str(round(xyz[2], 2)))

        positions = [records[index][1] for index in self.selection]
        com_text = ' com: ' + self._format_vector(np.average(positions, axis=0))

        if nselected == 1:
            (only_index,) = self.selection
            self.selection_info = atom_line(only_index)
            return
        if nselected == 2:
            dist = np.linalg.norm(positions[0] - positions[1])
            self.selection_info = ''.join(atom_line(index) + '\n' for index in self.selection)
            self.selection_info += 'distance: ' + str(round(dist, 2)) + com_text
            return

        header = str(nselected) + ' atoms selected, ' + com_text
        # Angles and dihedrals depend on the order in which the atoms were picked.
        by_pick = sorted(self.selection, key=lambda i: (records[i][2], i))
        ordered = [records[index][1] for index in by_pick]

        if nselected == 3:
            v1 = ordered[0] - ordered[1]
            v2 = ordered[2] - ordered[1]
            cosine_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
            # Rounding can push |cos| slightly above 1, where arccos would return nan.
            angle = np.degrees(np.arccos(np.clip(cosine_angle, -1.0, 1.0)))
            cross = np.cross(v2, v1)
            cross_norm = np.linalg.norm(cross)
            normal = cross / cross_norm if cross_norm > 0.0 else np.array([0, 0, 0])
            self.selection_info = (header + '\n' + ' angle: ' + str(round(angle, 2)) +
                                   '\n' + ' normal: (' + self._format_vector(normal) + ')')
        elif nselected == 4:
            dihedral = self.dihedral(np.array(ordered))
            self.selection_info = header + '\n' + '  dihedral: ' + ' ' + str(round(dihedral, 2))
        else:
            self.selection_info = header

    @default('selection')
    def _default_selection(self):
        return set()

    @validate('selection')
    def _validate_selection(self, provided):
        return set(provided['value'])

    @observe('selection')
    def _observe_selection(self, _=None):
        """Refresh highlighting, the text field and the info output after a change."""
        # Forget info of atoms no longer selected (e.g. after "Clear selection" or a new
        # structure) so stale coordinates are never reported.
        for stale in set(self.selection_dict) - self.selection:
            del self.selection_dict[stale]
        self.highlight_atoms(self.selection)
        self._selected_atoms.value = set_to_string_range(self.selection, shift=1)
        self.create_selection_info()
        with self.info:
            clear_output()
            print(self.selection_info)

    def apply_selection(self, _=None):
        """Apply selection specified in the text field (1-based, e.g. '1 3..5')."""
        selection_string = self._selected_atoms.value
        expanded_selection, syntax_ok = string_range_to_set(selection_string, shift=-1)
        self.wrong_syntax.layout.visibility = 'hidden' if syntax_ok else 'visible'
        self.selection = expanded_selection
        self._selected_atoms.value = selection_string  # Keep the old string for further editing.

    def download(self, change=None):  # pylint: disable=unused-argument
        """Prepare a structure for downloading."""
        self._download(payload=self._prepare_payload(), filename='structure.' + self.file_format.value)

    @staticmethod
    def _download(payload, filename):
        """Download payload as a file named as filename."""
        from IPython.display import Javascript
        javas = Javascript("""
            var link = document.createElement('a');
            link.href = "data:;base64,{payload}"
            link.download = "{filename}"
            document.body.appendChild(link);
            link.click();
            document.body.removeChild(link);
            """.format(payload=payload, filename=filename))
        display(javas)

    def _prepare_payload(self, file_format=None):
        """Return ``self.structure`` written in ``file_format`` as a base64 string.

        :param file_format: format name passed to ``write``; defaults to the dropdown value.
        """
        from tempfile import NamedTemporaryFile
        file_format = file_format if file_format else self.file_format.value
        # The context manager closes and deletes the temporary file after it is read.
        with NamedTemporaryFile() as tmp:
            self.structure.write(tmp.name, format=file_format)  # pylint: disable=no-member
            with open(tmp.name, 'rb') as raw:
                return base64.b64encode(raw.read()).decode()

    @property
    def thumbnail(self):
        """Base64-encoded PNG rendering of ``self.structure``."""
        return self._prepare_payload(file_format='png')




In [None]:
class StructureDataViewer(_StructureDataBaseViewer):
    """Viewer class for AiiDA structure objects.

    Attributes:
        structure (Atoms, StructureData, CifData): Trait that contains a structure object,
        which was initially provided to the viewer. It can be either directly set to an
        ASE Atoms object or to AiiDA structure object containing `get_ase()` method.

        displayed_structure (Atoms): Trait that contains a structure object that is
        currently displayed (super cell, for example). The trait is generated automatically
        and can't be set outside of the class.

        vis_func (callable or None): optional function that receives the displayed
        structure and returns a ``vis_dict`` splitting the atoms into separately drawn
        viewer components.
    """
    structure = Union([Instance(Atoms), Instance(Node)], allow_none=True)
    displayed_structure = Instance(Atoms, allow_none=True, read_only=True)

    def __init__(self, structure=None, vis_func=None, **kwargs):
        super().__init__(**kwargs)

        self.vis_func = vis_func
        self.structure = structure

        self.supercell_x.observe(self.repeat, names='value')
        self.supercell_y.observe(self.repeat, names='value')
        self.supercell_z.observe(self.repeat, names='value')

    def setup(self, c=None):  # pylint: disable=unused-argument
        """Rebuild the viewer with one component per entry of ``self.vis_func``'s result.

        All existing components are removed first; nothing is added when there is no
        ``vis_func`` or no displayed structure.
        """
        # Delete all old components.
        while hasattr(self._viewer, "component_0"):
            self._viewer.component_0.clear_representations()
            self._viewer.remove_component(self._viewer.component_0.id)

        if self.vis_func is None or self.displayed_structure is None:
            return

        # Use the callback stored on this instance, not a same-named global.
        self.set_trait('vis_dict', self.vis_func(self.displayed_structure))

        for component in range(len(self.vis_dict)):
            settings = self.vis_dict[component]
            rep_indexes = list(string_range_to_set(settings['ids'], shift=0)[0])
            mol = self.displayed_structure[rep_indexes]
            self._viewer.add_component(nglview.ASEStructure(mol), default_representation=False)
            self._viewer.add_ball_and_stick(  # pylint: disable=no-member
                aspectRatio=settings['aspectRatio'], opacity=1.0, component=component)
        self._gen_translation_indexes()
        self._viewer.add_unitcell()  # pylint: disable=no-member
        self._viewer.center()

    def repeat(self, _=None):
        """Display ``structure`` repeated by the super-cell sizes from the Appearance tab."""
        if self.structure is not None:
            self.set_trait(
                'displayed_structure',
                self.structure.repeat([
                    self.supercell_x.value,
                    self.supercell_y.value,
                    self.supercell_z.value,
                ]))

    def orient_z_up(self, _=None):
        """Set a camera looking along z at the centre of mass of the displayed structure.

        Best effort: silently does nothing if there is no structure or it fails.
        """
        if self.displayed_structure is None:
            return
        try:
            cell_z = self.displayed_structure.cell[2, 2]
            com = self.displayed_structure.get_center_of_mass()
            # Flattened 4x4 camera matrix; the z entry uses max(cell z length, 30.0).
            top_z_orientation = [1.0, 0.0, 0.0, 0,
                                 0.0, 1.0, 0.0, 0,
                                 0.0, 0.0, -np.max([cell_z, 30.0]), 0,
                                 -com[0], -com[1], -com[2], 1]
            self._viewer._set_camera_orientation(top_z_orientation)  # pylint: disable=protected-access
        except Exception:  # pylint: disable=broad-except
            # Orientation is cosmetic; never let it break a structure update.
            return

    @validate('structure')
    def _valid_structure(self, change):  # pylint: disable=no-self-use
        """Convert the provided structure to ASE Atoms (None is passed through)."""
        structure = change['value']

        if structure is None:
            return None  # if no structure provided, the rest of the code can be skipped

        if isinstance(structure, Atoms):
            return structure
        if isinstance(structure, Node):
            return structure.get_ase()
        raise ValueError("Unsupported type {}, structure must be one of the following types: "
                         "ASE Atoms object, AiiDA CifData or StructureData.".format(type(structure)))

    @observe('structure')
    def _update_displayed_structure(self, change):
        """Update displayed_structure trait after the structure trait has been modified.

        Redrawing (unit cell, centering, orientation, vis_func setup) is done by the
        ``displayed_structure`` observer, so it is not repeated here.
        """
        if change['new'] is not None:
            self.set_trait('displayed_structure', change['new'].repeat([1, 1, 1]))
        else:
            self.set_trait('displayed_structure', None)

    @observe('displayed_structure')
    def _update_structure_viewer(self, change):
        """Update the view if displayed_structure trait was modified."""
        with self.hold_trait_notifications():
            # Iterate over a copy: remove_component() shrinks this list while looping.
            for comp_id in list(self._viewer._ngl_component_ids):  # pylint: disable=protected-access
                self._viewer.remove_component(comp_id)
            self.selection = set()

            if change['new'] is not None:
                self._viewer.add_component(nglview.ASEStructure(change['new']))
                self._viewer.clear()
                self._viewer.add_ball_and_stick(aspectRatio=4)  # pylint: disable=no-member
                self._viewer.add_unitcell()  # pylint: disable=no-member
                self._viewer.center()
                self.orient_z_up()
                if self.vis_func:
                    self.setup()

In [None]:
from apps.surfaces.widgets.ANALYZE_structure import StructureAnalyzer
def vis_func(structure):
    """Split ``structure`` into viewer components for ``StructureDataViewer``.

    For a 'SlabXY' system (as classified by ``StructureAnalyzer``) the molecule atoms
    and the remaining (substrate) atoms become two components with different aspect
    ratios and highlight colours. Any other system is returned as a single component
    instead of failing with an unbound ``vis_dict``.

    :return: dict mapping consecutive component numbers (0, 1, ...) to display
        settings; groups without atoms are omitted.
    """
    analyzer = StructureAnalyzer()
    analyzer.set_trait('structure', structure)
    details = analyzer.details
    all_ids = set(range(len(structure)))

    if details['system_type'] == 'SlabXY':
        # int() so the ids are plain Python integers whatever the analyzer returns.
        mol_ids = {int(i) for molecule in details['all_molecules'] for i in molecule}
        groups = [
            (mol_ids, MOL_ASPECT, 'red', 'molecule'),
            (all_ids - mol_ids, REST_ASPECT, 'green', 'substrate'),
        ]
    else:
        groups = [(all_ids, MOL_ASPECT, 'red', 'structure')]

    vis_dict = {}
    for ids, aspect, highlight_color, name in groups:
        if not ids:
            continue  # a group without atoms has nothing to draw
        vis_dict[len(vis_dict)] = {
            'ids': set_to_string_range(ids, shift=0),
            'aspectRatio': aspect,
            'highlight_aspectRatio': aspect + 0.3,
            'highlight_color': highlight_color,
            'highlight_opacity': 1,
            'name': name,
        }
    return vis_dict

In [None]:
STRUCTURE_PK = 8381  # NOTE(review): database-specific PK — adjust for your AiiDA profile.
a = load_node(STRUCTURE_PK)  # load_node is presumably made available by the %aiida magic.

In [None]:
structure = a.get_ase()  # ASE representation of the loaded node, passed to the viewer below.

In [None]:
v = StructureDataViewer(structure=structure, vis_func=vis_func)  # molecule/substrate split view

In [None]:
display(v)  # Render the viewer widget in the notebook output.