## Refactoring: Pythonic way

Original application:


In [13]:
import pandas as pd
import ipywidgets as widgets
import plotly.express as px
from plotly import graph_objects as go

from nomad_lab_visualizer import Visualizer


In [14]:
df = pd.read_pickle("./data/query_archive/df")
df


Unnamed: 0_level_0,x_emb,y_emb,Atomic_number_A,Atomic_number_B,Space_group_number,Atomic_density,Cluster_label,Structure,File,Replicas
Formula,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Ac2Ag4O8,39.728947,30.926718,47.0,89.0,227.0,0.062730,0,./data/query_archive/structures/Ac2Ag4O8,[Ac2Ag4O8_6819.xyz],1
Ac2As4O8,46.165726,27.248566,33.0,89.0,227.0,0.062067,-1,./data/query_archive/structures/Ac2As4O8,[Ac2As4O8_319.xyz],1
Ac2B4O8,58.595741,24.383308,5.0,89.0,227.0,0.092081,0,./data/query_archive/structures/Ac2B4O8,[Ac2B4O8_6978.xyz],1
Ac2Be4O8,60.653465,24.399380,4.0,89.0,227.0,0.090679,0,./data/query_archive/structures/Ac2Be4O8,[Ac2Be4O8_121.xyz],1
Ac2Ca4O8,62.921474,23.055676,20.0,89.0,227.0,0.058643,0,./data/query_archive/structures/Ac2Ca4O8,[Ac2Ca4O8_4736.xyz],1
...,...,...,...,...,...,...,...,...,...,...
Zn2Cr4O8,5.035081,26.554813,24.0,30.0,227.0,0.091459,-1,./data/query_archive/structures/Zn2Cr4O8,[Zn2Cr4O8_7411.xyz],1
Zn2Fe4O8,5.321269,26.232611,26.0,30.0,227.0,0.089538,-1,./data/query_archive/structures/Zn2Fe4O8,[Zn2Fe4O8_5784.xyz],1
Zn2Ir4O8,-2.570099,-27.888716,30.0,77.0,227.0,0.083793,-1,./data/query_archive/structures/Zn2Ir4O8,[Zn2Ir4O8_1811.xyz],1
ZnCuO3,0.301460,11.128940,29.0,30.0,221.0,0.095951,1,./data/query_archive/structures/ZnCuO3,[ZnCuO3_2842.xyz],1


In [15]:
visualizer = Visualizer(
    df,
    embedding_features=["x_emb", "y_emb", "Atomic_number_A", "Atomic_number_B"],
    hover_features=[
        "Atomic_number_A",
        "Atomic_number_B",
        "Space_group_number",
        "Atomic_density",
        "Cluster_label",
        "Replicas",
    ],
    target="Cluster_label",
    path_to_structures=True,
    # smart_fract=Fals?Pe,
    # convex_hull=False,
    #     regr_line_coefs=[0., 1.]
)

visualizer.show()
# 

VBox(children=(VBox(children=(HBox(children=(VBox(children=(Dropdown(description='x-axis', layout=Layout(width…

Strting to refactor:


In [16]:
from ipywidgets import GridspecLayout


class Settings(GridspecLayout):
    def __init__(
        self,
        embedding_features,
        hover_features,
        feature_x,
        feature_y,
        fracture,
        **kwargs
    ):
        super().__init__(3, 3)

        widget_feature_x = widgets.Dropdown(
            description="x-axis",
            options=embedding_features,
            value=feature_x,
            layout=widgets.Layout(width="250px"),
        )

        widget_feature_y = widgets.Dropdown(
            description="y-axis",
            options=embedding_features,
            value=feature_y,
            layout=widgets.Layout(width="250px"),
        )

        widget_fracture = widgets.BoundedFloatText(
            min=0,
            max=1.0,
            # step=0.01,
            value=fracture,
            layout=widgets.Layout(left="98px", width="60px"),
        )

        widget_facture_label = widgets.Label(
            value="Fraction: ", layout=widgets.Layout(left="95px")
        )

        widget_feature_color = widgets.Dropdown(
            description="Color",
            options=["Default color"] + hover_features,
            value="Default color",
            layout=widgets.Layout(width="250px"),
        )

        widget_feature_color_type = widgets.RadioButtons(
            options=["Gradient", "Discrete"],
            value="Gradient",
            layout=widgets.Layout(width="140px", left="90px"),
        )

        widget_feature_color_list = widgets.Dropdown(
            options=px.colors.named_colorscales(),
            value="viridis",
            layout=widgets.Layout(width="65px", height="35px", left="40px"),
        )

        widget_feature_marker = widgets.Dropdown(
            description="Marker",
            options=["Default size"] + hover_features,
            value="Default size",
            layout=widgets.Layout(width="250px"),
        )
        widget_feature_marker_minvalue = widgets.BoundedFloatText(
            min=0,
            # max=self.max_value_markerfeat,
            step=1,
            # value=self.min_value_markerfeat,
            layout=widgets.Layout(left="91px", width="60px", height="10px"),
        )
        widget_feature_marker_minvalue_label = widgets.Label(
            value="Min value: ", layout=widgets.Layout(left="94px", width="70px")
        )
        widget_feature_marker_maxvalue = widgets.BoundedFloatText(
            # min=self.min_value_markerfeat,
            step=1,
            # value=self.max_value_markerfeat,
            layout=widgets.Layout(left="91px", width="60px"),
        )
        widget_feature_marker_maxvalue_label = widgets.Label(
            value="Max value: ", layout=widgets.Layout(left="94px", width="70px")
        )

        self[0, 0] = widget_feature_x
        self[1, 0] = widget_feature_y
        self[2, 0] = widgets.Box([widget_facture_label, widget_fracture])

        self[0, 1] = widget_feature_color
        self[1:, 1] = widgets.HBox(
            [widget_feature_color_type, widget_feature_color_list],
            layout=widgets.Layout(top="10px"),
        )

        self[0, 2] = widget_feature_marker
        self[1, 2] = widgets.Box(
            [
                widget_feature_marker_minvalue_label,
                widget_feature_marker_minvalue,
            ]
        )

        self[2, 2] = widgets.Box(
            [
                widget_feature_marker_maxvalue_label,
                widget_feature_marker_maxvalue,
            ]
        )

        self.layout.height = "140px"
        # self.layout.top = "30px"


In [17]:
class SettingsList(widgets.Box):
    def __init__(
        self,
        embedding_features,
        hover_features,
        feature_x,
        feature_y,
        fracture,
        **kwargs
    ):
        widget_feature_x = widgets.Dropdown(
            description="x-axis",
            options=embedding_features,
            value=feature_x,
            layout=widgets.Layout(width="250px"),
        )

        widget_feature_y = widgets.Dropdown(
            description="y-axis",
            options=embedding_features,
            value=feature_y,
            layout=widgets.Layout(width="250px"),
        )

        widget_fracture = widgets.BoundedFloatText(
            min=0,
            max=1.0,
            # step=0.01,
            value=fracture,
            layout=widgets.Layout(left="98px", width="60px"),
        )

        widget_feature_color = widgets.Dropdown(
            description="Color",
            options=["Default color"] + hover_features,
            value="Default color",
            layout=widgets.Layout(width="250px"),
        )

        widget_feature_color_type = widgets.RadioButtons(
            options=["Gradient", "Discrete"],
            value="Gradient",
            layout=widgets.Layout(width="140px", left="90px"),
        )

        widget_feature_color_list = widgets.Dropdown(
            options=px.colors.named_colorscales(),
            value="viridis",
            layout=widgets.Layout(width="65px", height="35px", left="40px"),
        )

        widget_feature_marker = widgets.Dropdown(
            description="Marker",
            options=["Default size"] + hover_features,
            value="Default size",
            layout=widgets.Layout(width="250px"),
        )
        widget_feature_marker_minvalue = widgets.BoundedFloatText(
            min=0,
            # max=self.max_value_markerfeat,
            step=1,
            # value=self.min_value_markerfeat,
            layout=widgets.Layout(left="91px", width="60px", height="10px"),
        )
        widget_feature_marker_minvalue_label = widgets.Label(
            value="Min value: ", layout=widgets.Layout(left="94px", width="70px")
        )
        widget_feature_marker_maxvalue = widgets.BoundedFloatText(
            # min=self.min_value_markerfeat,
            step=1,
            # value=self.max_value_markerfeat,
            layout=widgets.Layout(left="91px", width="60px"),
        )
        widget_feature_marker_maxvalue_label = widgets.Label(
            value="Max value: ", layout=widgets.Layout(left="94px", width="70px")
        )
        super().__init__(
            children=[
                widget_feature_x,
                widget_feature_y,
                widgets.Box(
                    [
                        widgets.Label(
                            value="Fraction: ", layout=widgets.Layout(left="95px")
                        ),
                        widget_fracture,
                    ]
                ),
                widget_feature_color,
                widgets.HBox(
                    [widget_feature_color_type, widget_feature_color_list],
                    layout=widgets.Layout(top="10px"),
                ),
                widget_feature_marker,
                widgets.Box(
                    [
                        widget_feature_marker_minvalue_label,
                        widget_feature_marker_minvalue,
                    ]
                ),
                widgets.Box(
                    [
                        widget_feature_marker_maxvalue_label,
                        widget_feature_marker_maxvalue,
                    ]
                ),
            ],
            layout=widgets.Layout(
                display="flex",
                flex_flow="column",
                border="solid 2px",
                align_items="stretch",
                # width="50%",
            ),
        )
        # self.layout.height = "140px"
        # self.layout.top = "30px"


embedding_features = ["A", "B", "C"]
hover_features = ["AA", "BB", "CC"]
feature_x = "A"
feature_y = "B"
fracture = 1.0

settings = SettingsList(
    embedding_features, hover_features, feature_x, feature_y, fracture
)
settings


SettingsList(children=(Dropdown(description='x-axis', layout=Layout(width='250px'), options=('A', 'B', 'C'), v…

In [18]:
# extract features

target = 'F'
feature_x = 'A'
feature_y = 'B'

labels = df[target].unique().tolist()

x = []
y = []
for label in labels:
    mask = df['F']==label
    x.append(df[feature_x][mask].to_numpy())
    y.append(df[feature_y][mask].to_numpy())


KeyError: 'F'

In [19]:
# TODO dict for   
# TODO: default: points = np.column_stack((x, y))

from scipy.spatial import ConvexHull, convex_hull_plot_2d

class Figure(go.FigureWidget):
    def __init__(self, x, y, labels, layout=None, **kwargs):

        self._x = x
        self._y = y
        self._labels = labels

        self._regression_trace = None
        self._complex_hull_traces = None

        super().__init__(None, layout, **kwargs)

        # All permanent layout settings are defined here
        self.update_layout(
            hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"),
            width=800,
            height=400,
            margin=dict(l=50, r=50, b=70, t=20, pad=4),
        )
        self.update_xaxes(
            ticks="outside", tickwidth=1, ticklen=10, linewidth=1, linecolor="black"
        )
        self.update_yaxes(
            ticks="outside", tickwidth=1, ticklen=10, linewidth=1, linecolor="black"
        )

        for (x, y, label) in zip(x, y, labels):
            self.add_trace(go.Scatter(x=x, y=y, name=label, mode="markers"))

    def add_regression_line(self, p1, p2):
        """
        Note: solution of the intersection of the line and the boundary box
        Arguments:
        - p1: point in 2d
        - p2: point in 2d
        """

        self._regression_trace = go.Scatter(x = [p1[0], p2[0]],y = [p1[1], p2[1]], name="Line", mode="lines")
        self.add_trace(self._regression_trace)

    def add_complex_hull(self):
        
        for (x, y, label) in zip(self._x, self._y, self._labels):
            if len(x) < 3: continue

            points = np.column_stack((x, y))
            hull = ConvexHull(points)
            
            inds = np.append(hull.vertices, hull.vertices[0])
            # TODO: use the same color as the datapoints
            self.add_trace(go.Scatter(x=x[inds], y=y[inds], name=f'{label} (hull)', mode="lines"))

            # for simplex in hull.simplices:
            #     self.add_trace(go.Scatter(x=points[simplex, 0], y=points[simplex, 1]))


fig = Figure(x, y, labels)
fig.add_complex_hull()

fig.add_regression_line([-2,-2], [2,3])
fig


NameError: name 'x' is not defined

## Example for collapsable settings

In [6]:
b1 = widgets.Button(description="|||", layout=widgets.Layout(width="40px"))

# create some control elements
int_slider = widgets.IntSlider(value=1, min=0, max=10, step=1, description="freq")
text_xlabel = widgets.Text(value="", description="xlabel", continuous_update=False)
text_ylabel = widgets.Text(value="", description="ylabel", continuous_update=False)

text_xlabel.value = "x"
text_ylabel.value = "y"

controls = widgets.VBox([int_slider, text_xlabel, text_ylabel])

def b1_clicked(_):
    if settings.layout.display == "block" or settings.layout.display is None:
       settings.layout.display = "none"
    else:
       settings.layout.display = "block"


b1.on_click(b1_clicked)

# display(widgets.HBox([widgets.VBox([b1, controls]), fig]))
display(widgets.HBox([widgets.VBox([b1, settings]), fig]))


NameError: name 'widgets' is not defined

In [21]:
# create some control elements
int_slider = widgets.IntSlider(value=1, min=0, max=10, step=1, description='freq')
text_xlabel = widgets.Text(value='', description='xlabel', continuous_update=False)
text_ylabel = widgets.Text(value='', description='ylabel', continuous_update=False)
 
text_xlabel.value = 'x'
text_ylabel.value = 'y'

controls = widgets.VBox([int_slider, text_xlabel, text_ylabel])

left = widgets.Accordion([controls])
left.set_title(0, '|||')


app = widgets.HBox([
   left ,
   fig,
])

app

HBox(children=(Accordion(children=(VBox(children=(IntSlider(value=1, description='freq', max=10), Text(value='…

In [1]:
class MyView(widgets.Box):
    def __init__(self):
        
        children = []
        layout = widgets.Layout(
            display="flex",
            flex_flow="column",
            border="solid 2px",
            align_items="stretch",
            # width="50%",
        )

        super().__init__(children, layout=layout)


NameError: name 'widgets' is not defined

In [23]:
w = MyView()
w

MyView(layout=Layout(align_items='stretch', border='solid 2px', display='flex', flex_flow='column'))

In [25]:
class FigureWidget:
    pass


class AtomisticViewer:
    pass


class SettingsWidget:
    pass


class Visualizer:
    def __init__(self, *args, **kwargs):

        self.figure = FigureWidget()
        self.viewer = AtomisticViewer()
        self.settings = SettingsWidget()

        # TODO: link events together between different widgets

def make(
    data: pd.DataFrame,
    embedding_features: list[str],
    hover_features: list[str],
    targets: list[str],
    smart_frac: float,
    convex_hull: bool,
    regr_line_coefs: list[float],
    path_to_structures: list[str],
):
    """
    df: pandas dataframe containing all data to be visualized
    embedding_features: list of features used for embedding
    hover_features: list of features shown while hovering
    target: feature used to create traces (same target value - same trace)
    smart_frac: fraction of points is selected to maximize visualization of data distribution
    convex_hull: convex hull is drawn around each trace
    regr_line_coefs: coeffs of a regression line
    path_to_structures: path to a directory that contains all 'xyz' structures to be visualized
    """
    pass


In [26]:
class AtomisticViewerWidget(widgets.HBox):
    pass


In [27]:
structures_list = ["h2o", "co2"]


widget_structure = widgets.Combobox(
    placeholder="",
    description="Structure:",
    options=structures_list,
    # layout=widgets.Layout(width="200px"),
)

widget_perv_button = widgets.Button(
    description="<", layout=widgets.Layout(width="50px")
)
widget_next_button = widgets.Button(
    description=">", layout=widgets.Layout(width="50px")
)

widget_label = widgets.Label('1/6', layout=widgets.Layout(width="50px", display="flex", justify_content="center"))

output = widgets.Output(layout = widgets.Layout(width="400px", height="350px"))


widgets.VBox(
    [
        widgets.HBox(
            [
                widget_structure,
                widget_perv_button,
                widget_label,
                widget_next_button
            ]
        ),
        output,
    ]
)


VBox(children=(HBox(children=(Combobox(value='', description='Structure:', options=('h2o', 'co2'), placeholder…

In [28]:
import ipywidgets as widgets
import py3Dmol


class py3DmolWidget(widgets.Output):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layout = widgets.Layout(width="400px", height="350px")
        # self._viewer = py3Dmol.view()
        self._viewer = py3Dmol.view(width="auto", height=400)
        # with self:
            # self._viewer.show()
            # self._viewer.resize()

    def load_structure(self, xyz: str):
        self._viewer.clear()
        # self._viewer.removeAllModels()
        self._viewer.addModel(xyz, "xyz")
        self._viewer.zoomTo()
        self._viewer.setStyle(
            {
                "stick": {"colorscheme": "Jmol"},
                "sphere": {"radius": 0.5, "colorscheme": "Jmol"},
            }
        )
        self._viewer.setBackgroundColor("white")
        self._viewer.setProjection("orthographic")

        with self:
            self._viewer.update()


In [29]:
w = py3DmolWidget()


In [30]:
w

py3DmolWidget(layout=Layout(height='350px', width='400px'))

In [31]:
w1 = py3DmolWidget()
w1

In [None]:
filename = 'data/query_archive/structures/Ac2Ag4O8/Ac2Ag4O8_6819.xyz'
# filename = "data/query_archive/structures/BaO3Y/BaO3Y_1248.xyz"
with open(filename, 'r') as file:
        xyz = file.read()

w.load_structure(xyz)

In [None]:
filenames = [
    "data/query_archive/structures/BaO3Y/BaO3Y_1248.xyz",
    "data/query_archive/structures/BaO3Y/BaO3Y_184.xyz",
    "data/query_archive/structures/BaO3Y/BaO3Y_194.xyz",
    "data/query_archive/structures/BaO3Y/BaO3Y_226.xyz",
    "data/query_archive/structures/BaO3Y/BaO3Y_895.xyz",
]
load_structure(filenames)

In [None]:
n = 5
c = itertools.cycle(range(n))


In [None]:
(c)

In [None]:
for (i,_) in zip(c, range(10)):
    print(f'{i+1}/{n}')