## Refactoring: Pythonic way

Original application:


In [1]:
import pandas as pd
from nomad_lab_visualizer import Visualizer


In [2]:
# Load the prepared dataset: one row per formula with embedding coordinates
# (x_emb, y_emb), per-material features and the path to its structure files.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
df = pd.read_pickle("./data/query_archive/df")
df


Unnamed: 0_level_0,x_emb,y_emb,Atomic_number_A,Atomic_number_B,Space_group_number,Atomic_density,Cluster_label,Structure,File,Replicas
Formula,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Ac2Ag4O8,39.728947,30.926718,47.0,89.0,227.0,0.062730,0,./data/query_archive/structures/Ac2Ag4O8,[Ac2Ag4O8_6819.xyz],1
Ac2As4O8,46.165726,27.248566,33.0,89.0,227.0,0.062067,-1,./data/query_archive/structures/Ac2As4O8,[Ac2As4O8_319.xyz],1
Ac2B4O8,58.595741,24.383308,5.0,89.0,227.0,0.092081,0,./data/query_archive/structures/Ac2B4O8,[Ac2B4O8_6978.xyz],1
Ac2Be4O8,60.653465,24.399380,4.0,89.0,227.0,0.090679,0,./data/query_archive/structures/Ac2Be4O8,[Ac2Be4O8_121.xyz],1
Ac2Ca4O8,62.921474,23.055676,20.0,89.0,227.0,0.058643,0,./data/query_archive/structures/Ac2Ca4O8,[Ac2Ca4O8_4736.xyz],1
...,...,...,...,...,...,...,...,...,...,...
Zn2Cr4O8,5.035081,26.554813,24.0,30.0,227.0,0.091459,-1,./data/query_archive/structures/Zn2Cr4O8,[Zn2Cr4O8_7411.xyz],1
Zn2Fe4O8,5.321269,26.232611,26.0,30.0,227.0,0.089538,-1,./data/query_archive/structures/Zn2Fe4O8,[Zn2Fe4O8_5784.xyz],1
Zn2Ir4O8,-2.570099,-27.888716,30.0,77.0,227.0,0.083793,-1,./data/query_archive/structures/Zn2Ir4O8,[Zn2Ir4O8_1811.xyz],1
ZnCuO3,0.301460,11.128940,29.0,30.0,221.0,0.095951,1,./data/query_archive/structures/ZnCuO3,[ZnCuO3_2842.xyz],1


In [3]:
# Baseline: the existing nomad_lab_visualizer widget on the real dataset,
# shown for comparison with the refactor below.
visualizer = Visualizer(
    df,
    # features offered as plot axes
    embedding_features=["x_emb", "y_emb", "Atomic_number_A", "Atomic_number_B"],
    # features shown while hovering over a point
    hover_features=[
        "Atomic_number_A",
        "Atomic_number_B",
        "Space_group_number",
        "Atomic_density",
        "Cluster_label",
        "Replicas",
    ],
    # points with the same Cluster_label end up in the same trace
    target="Cluster_label",
    path_to_structures=True,
    smart_fract=False,
    convex_hull=False,
    #     regr_line_coefs=[0., 1.]
)

visualizer.show()


VBox(children=(VBox(children=(HBox(children=(VBox(children=(Dropdown(description='x-axis', layout=Layout(width…

Starting the refactor:


In [4]:
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import ipywidgets as widgets
import plotly.graph_objects as go
import plotly.express as px

from IPython.display import display, Markdown, FileLink

from itertools import cycle

from nomad_lab_visualizer.updates import (
    marker_style_updates,
    fract_change_updates,
    update_hover_variables,
)

from nomad_lab_visualizer.view_structure import view_structure_r, view_structure_l
from nomad_lab_visualizer.smart_fract import smart_fract_make
from nomad_lab_visualizer.instantiate_widgets import instantiate_widgets
from nomad_lab_visualizer.batch_update import batch_update


# TODO:
# - [ ] GUI option for target
# - [ ] switch the convex hull on and off
# - [ ] smart_fract as a preprocessing step
# - [ ] move options out
#
# The Python API is the same as the 3Dmol GLViewer API: http://3dmol.csb.pitt.edu/doc/$3Dmol.GLViewer.html#toc0


In [61]:
# Generate a small synthetic frame to prototype the refactored widgets on.
# A fixed seed keeps the data (and every output derived from it) identical
# across "Restart & Run All".

N = 10
rng = np.random.default_rng(seed=0)

df = pd.DataFrame(rng.standard_normal((N, 4)), columns=list("ABCD"))
# "E": integer categories in [0, 10); "F": integer categories in {-1, 0, 1}
df["E"] = pd.Series(rng.integers(0, 10, size=N), dtype="category")
df["F"] = pd.Series(rng.integers(-1, 2, size=N), dtype="category")
df
# df.dtypes


Unnamed: 0,A,B,C,D,E,F
0,-0.287319,-0.99654,-1.203591,-0.337469,0,-1
1,-2.540634,-0.334815,0.150543,1.47318,7,1
2,0.821525,-0.579491,-1.925279,-1.296644,1,-1
3,0.800383,-0.893136,-0.641342,-0.197291,4,-1
4,-1.336998,0.285731,0.575695,-0.201369,7,0
5,-0.330901,0.718104,0.562255,0.953785,4,0
6,-1.03738,-0.652701,2.280855,1.692344,0,-1
7,-0.208266,-0.783505,1.674454,0.162349,5,1
8,0.338125,-1.026481,0.167935,1.112484,3,0
9,-0.820795,0.115481,0.244547,0.809821,4,-1


In [69]:
# Split the frame into one sub-frame per value of column "F" (value -> sub-frame)
dict(iter(df.groupby("F")))


{-1:           A         B         C         D  E  F
 0 -0.287319 -0.996540 -1.203591 -0.337469  0 -1
 2  0.821525 -0.579491 -1.925279 -1.296644  1 -1
 3  0.800383 -0.893136 -0.641342 -0.197291  4 -1
 6 -1.037380 -0.652701  2.280855  1.692344  0 -1
 9 -0.820795  0.115481  0.244547  0.809821  4 -1,
 0:           A         B         C         D  E  F
 4 -1.336998  0.285731  0.575695 -0.201369  7  0
 5 -0.330901  0.718104  0.562255  0.953785  4  0
 8  0.338125 -1.026481  0.167935  1.112484  3  0,
 1:           A         B         C         D  E  F
 1 -2.540634 -0.334815  0.150543  1.473180  7  1
 7 -0.208266 -0.783505  1.674454  0.162349  5  1}

In [70]:
# Display the synthetic frame again for reference
df

Unnamed: 0,A,B,C,D,E,F
0,-0.287319,-0.99654,-1.203591,-0.337469,0,-1
1,-2.540634,-0.334815,0.150543,1.47318,7,1
2,0.821525,-0.579491,-1.925279,-1.296644,1,-1
3,0.800383,-0.893136,-0.641342,-0.197291,4,-1
4,-1.336998,0.285731,0.575695,-0.201369,7,0
5,-0.330901,0.718104,0.562255,0.953785,4,0
6,-1.03738,-0.652701,2.280855,1.692344,0,-1
7,-0.208266,-0.783505,1.674454,0.162349,5,1
8,0.338125,-1.026481,0.167935,1.112484,3,0
9,-0.820795,0.115481,0.244547,0.809821,4,-1


In [6]:
class Viewer(widgets.DOMWidget):
    """Placeholder for a custom DOM widget; no behavior yet."""


In [7]:
#
# df
# new_df = resample(df, .5)
# viewer = Visualiser(new_df)
#


In [8]:
# Minimal plotly-express scatter of y = x**2 for x = 0..4
fig = px.scatter(x=list(range(5)), y=[v ** 2 for v in range(5)])
fig.show()


In [9]:
# Inspect the first trace and its x data (plotly properties work as attributes or keys)
fig.data[0], fig.data[0].x


(Scatter({
     'hovertemplate': 'x=%{x}<br>y=%{y}<extra></extra>',
     'legendgroup': '',
     'marker': {'color': '#636efa', 'symbol': 'circle'},
     'mode': 'markers',
     'name': '',
     'orientation': 'v',
     'showlegend': False,
     'x': array([0, 1, 2, 3, 4]),
     'xaxis': 'x',
     'y': array([ 0,  1,  4,  9, 16]),
     'yaxis': 'y'
 }),
 array([0, 1, 2, 3, 4]))

In [13]:
def resample(data):
    """Placeholder resampling step: hands back ``data`` itself, unchanged."""
    resampled = data
    return resampled


def Visualize(
    data: pd.DataFrame,
    embedding_features: list[str],
    hover_features: list[str],
    target: str,
    smart_frac: float,
    convex_hull: bool,
    regr_line_,
    path_to_structures: list[str],
):
    """Draft of a function-style visualizer API (stub: does nothing, returns None).

    data: pandas dataframe containing all data to be visualized
    embedding_features: list of features used for embedding
    hover_features: list of features shown while hovering
    target: feature used to create traces (same target value - same trace)
    smart_frac: fraction of points selected to maximize visualization of the data distribution
    convex_hull: whether a convex hull is drawn around each trace
    regr_line_: coefficients of a regression line
    path_to_structures: path to a directory that contains all 'xyz' structures to be visualized
        NOTE(review): annotated as list[str], while the original application passes
        ``path_to_structures=True`` — confirm the intended type.
    """

    # NOTE: the two helpers below are local functions that are never called or
    # returned; they only sketch the methods this API is meant to grow (see the
    # commented-out usage below).
    def add_convex_hull(self):
        """
        convex hull is drawn around each trace
        """

    def add_regression_line(self, coefs: list[float]):
        """
        coefs: coeffs of a regression line
        """


# ????df = resample(df, fraction, target)

# visualiser = Visualize(df)
# visualiser.add_comvex_hull()
# visualiser.add_regression_line()


In [14]:
class FigureWidget:
    """Placeholder for the plotting part of the refactored visualizer (unrelated to go.FigureWidget)."""


class AtomisticViewer:
    """Placeholder for the atomic-structure viewer of the refactored visualizer."""


class SettingsWidget:
    """Placeholder for the settings panel of the refactored visualizer."""


class Visualizer:
    """Planned composite widget made of a figure, a structure viewer and a settings panel.

    NOTE(review): this definition shadows the ``Visualizer`` imported from
    ``nomad_lab_visualizer`` in the first code cell; re-running the original
    application cell after this one would pick up this stub instead.
    """

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used yet.
        self.figure, self.viewer, self.settings = (
            FigureWidget(),
            AtomisticViewer(),
            SettingsWidget(),
        )

        # TODO: link events together between different widgets


In [16]:
def make(
    data: pd.DataFrame,
    embedding_features: list[str],
    hover_features: list[str],
    target: str,
    smart_frac: float,
    convex_hull: bool,
    regr_line_coefs: list[float],
    path_to_structures: list[str],
):
    """Draft of a single-call visualizer builder (stub: does nothing, returns None).

    data: pandas dataframe containing all data to be visualized
    embedding_features: list of features used for embedding
    hover_features: list of features shown while hovering
    target: feature used to create traces (same target value - same trace)
    smart_frac: fraction of points selected to maximize visualization of the data distribution
    convex_hull: whether a convex hull is drawn around each trace
    regr_line_coefs: coeffs of a regression line
    path_to_structures: path to a directory that contains all 'xyz' structures to be visualized
        NOTE(review): annotated as list[str], while the original application passes
        ``path_to_structures=True`` — confirm the intended type.
    """


In [17]:
# #  properties = Properties()
# # Visualiser()
#
# default_style = {}
#
#
# class WidgetLayout:
#     pass
#
#
# class NewVisualizer(WidgetLayout):
#     def __init__(
#         self,
#         data: pd.,
#         *args, **kwargs) -> None:
#         pass
#
#
# widget = NewVisualizer(
#     data = np.zeros(3,10),
#     columns = []
#
# )
#


In [None]:
# constants: option lists offered by the settings widgets
#
# Options that several controls share are defined once below and copied into
# each public name, so every public list is still an independent object
# (mutating one option list never affects another).

# named colors shared by the hull, regression-line and font color options
_NAMED_COLORS = ("Black", "Blue", "Cyan", "Green", "Grey", "Orange", "Red", "Yellow")
# dash styles shared by the regression-line and hull options
_DASH_STYLES = ("dash", "solid", "dot", "longdash", "dashdot", "longdashdot")

# list of possible marker symbols
symbols_list = [
    "circle",
    "circle-open",
    "circle-dot",
    "circle-open-dot",
    "circle-cross",
    "circle-x",
    "square",
    "square-open",
    "square-dot",
    "square-open-dot",
    "square-cross",
    "square-x",
    "diamond",
    "diamond-open",
    "diamond-dot",
    "diamond-open-dot",
    "diamond-cross",
    "diamond-x",
    "triangle-up",
    "triangle-up-open",
    "triangle-up-dot",
    "triangle-up-open-dot",
    "triangle-down",
    "triangle-down-open",
    "triangle-down-dot",
    "triangle-down-open-dot",
]
# list of possible colors of the hulls
color_hull = list(_NAMED_COLORS)
# list of possible colors of the regression line
color_line = list(_NAMED_COLORS)
# list of possible dash types for the regression line
line_dashs = list(_DASH_STYLES)
# list of possible dash types for the hulls
hull_dashs = list(_DASH_STYLES)
# list of possible font families
font_families = [
    "Arial",
    "Courier New",
    "Helvetica",
    "Open Sans",
    "Times New Roman",
    "Verdana",
]
# list of possible font colors
font_color = list(_NAMED_COLORS)
# list of possible discrete palette colors
discrete_palette_colors = [
    "Plotly",
    "D3",
    "G10",
    "T10",
    "Alphabet",
    "Dark24",
    "Light24",
    "Set1",
    "Pastel1",
    "Dark2",
    "Set2",
    "Pastel2",
    "Set3",
    "Antique",
    "Bold",
    "Pastel",
    "Prism",
    "Safe",
    "Vivid",
]
# list of possible continuous gradient colors
continuous_gradient_colors = px.colors.named_colorscales()


In [18]:

class Config:
    """Default appearance settings of the visualizer.

    Each value is only the initial setting; the settings widgets are meant to
    let the user change it.
    """

    # figure-wide defaults
    bg_color = "rgba(229,236,246, 0.5)"  # default background color
    font_size = 12  # size of all fonts

    # markers
    marker_size = 7  # size of every marker
    cross_size = 15  # size of the crosses
    # min/max marker size used when marker size represents a feature value
    min_value_markerfeat = 4
    max_value_markerfeat = 20

    # convex hull: width, dash style and color
    hull_width = 1
    hull_dash = "solid"
    hull_color = "Grey"

    # regression line: width, dash style and color
    line_width = 1
    line_dash = "dash"
    line_color = "Black"
    

In [27]:
# embedding_features = ['A', 'B', 'C']
# hover_features = ['AA', 'BB', 'CC']
# feat_x = 'A'
# feat_y = 'B'
# fracture = 1.0
# 
# widget_feature_x = widgets.Dropdown(
#     description="x-axis",
#     options=embedding_features,
#     value=feat_x,
#     layout=widgets.Layout(width="250px"),
# )
# 
# widget_feature_y = widgets.Dropdown(
#     description="y-axis",
#     options=embedding_features,
#     value=feat_y,
#     layout=widgets.Layout(width="250px"),
# )
# 
# widget_fracture = widgets.BoundedFloatText(
#     min=0,
#     max=1.,
#     # step=0.01,
#     value=fracture,
#     layout=widgets.Layout(left="98px", width="60px"),
# )
# 
# widget_facture_label = widgets.Label(
#     value="Fraction: ", layout=widgets.Layout(left="95px")
# )
# 
# widget_feature_color = widgets.Dropdown(
#     description="Color",
#     options=["Default color"] + hover_features,
#     value="Default color",
#     layout=widgets.Layout(width="250px"),
# )
# 
# widget_feature_color_type = widgets.RadioButtons(
#     options=["Gradient", "Discrete"],
#     value="Gradient",
#     layout=widgets.Layout(width="140px", left="90px"),
# )
# 
# widget_feature_color_list = widgets.Dropdown(
#     options=px.colors.named_colorscales(),
#     value="viridis",
#     layout=widgets.Layout(width="65px", height="35px", left="40px"),
# )
# 
# widget_feature_marker = widgets.Dropdown(
#     description="Marker",
#     options=["Default size"] + hover_features,
#     value="Default size",
#     layout=widgets.Layout(width="250px"),
# )
# widget_feature_marker_minvalue = widgets.BoundedFloatText(
#     min=0,
#     # max=self.max_value_markerfeat,
#     step=1,
#     # value=self.min_value_markerfeat,
#     layout=widgets.Layout(left="91px", width="60px", height="10px"),
# )
# widget_feature_marker_minvalue_label = widgets.Label(
#     value="Min value: ", layout=widgets.Layout(left="94px", width="70px")
# )
# widget_feature_marker_maxvalue = widgets.BoundedFloatText(
#     # min=self.min_value_markerfeat,
#     step=1,
#     # value=self.max_value_markerfeat,
#     layout=widgets.Layout(left="91px", width="60px"),
# )
# widget_feature_marker_maxvalue_label = widgets.Label(
#     value="Max value: ", layout=widgets.Layout(left="94px", width="70px")
# )
# 
# 
# box_feat = widgets.VBox(
#     [
#         widgets.HBox(
#             [
#                 widgets.VBox(
#                     [
#                         widget_feature_x,
#                         widget_feature_y,
#                         widgets.HBox([widget_facture_label, widget_fracture]),
#                     ]
#                 ),
#                 widgets.VBox(
#                     [
#                         widget_feature_color,
#                         widgets.HBox(
#                             [widget_feature_color_type, widget_feature_color_list],
#                             layout=widgets.Layout(top="10px"),
#                         ),
#                     ]
#                 ),
#                 widgets.VBox(
#                     [
#                         widget_feature_marker,
#                         widgets.VBox(
#                             [
#                                 widgets.HBox(
#                                     [
#                                         widget_feature_marker_minvalue_label,
#                                         widget_feature_marker_minvalue,
#                                     ],
#                                 ),
#                                 widgets.HBox(
#                                     [
#                                         widget_feature_marker_maxvalue_label,
#                                         widget_feature_marker_maxvalue,
#                                     ],
#                                 ),
#                             ]
#                         ),
#                     ]
#                 ),
#             ]
#         ),
#     ]
# )
# 
# box_feat


VBox(children=(HBox(children=(VBox(children=(Dropdown(description='x-axis', layout=Layout(width='250px'), opti…

In [52]:
from ipywidgets import GridspecLayout


class Settings(GridspecLayout):
    """3x3 grid of controls choosing what the scatter plot shows.

    Column 0: x-axis feature, y-axis feature, fraction of points.
    Column 1: color feature, gradient/discrete toggle and colorscale.
    Column 2: marker-size feature, min and max marker size.

    Every control is kept as a ``widget_*`` attribute so other widgets can
    observe or link to it later.
    """

    def __init__(
        self,
        embedding_features,
        hover_features,
        feature_x,
        feature_y,
        fracture,
        *,
        marker_min_value=4,
        marker_max_value=20,
        **kwargs,
    ):
        """Build the controls and place them on the grid.

        embedding_features: features selectable on the x- and y-axis
        hover_features: iterable of features selectable for color and marker size
        feature_x, feature_y: initially selected axis features; must be members
            of ``embedding_features`` (the dropdowns reject other values)
        fracture: initial fraction of points, kept within [0, 1] by the widget
        marker_min_value, marker_max_value: initial min/max marker size used when
            marker size represents a feature (4 and 20 mirror
            ``Config.min_value_markerfeat`` / ``Config.max_value_markerfeat``)
        **kwargs: forwarded to ``GridspecLayout``; ``height`` defaults to "140px"
        """
        height = kwargs.pop("height", "140px")
        super().__init__(3, 3, **kwargs)

        # -- column 0: axes and fraction of points ---------------------------
        self.widget_feature_x = widgets.Dropdown(
            description="x-axis",
            options=embedding_features,
            value=feature_x,
            layout=widgets.Layout(width="250px"),
        )
        self.widget_feature_y = widgets.Dropdown(
            description="y-axis",
            options=embedding_features,
            value=feature_y,
            layout=widgets.Layout(width="250px"),
        )
        self.widget_fraction = widgets.BoundedFloatText(
            min=0,
            max=1.0,
            value=fracture,
            layout=widgets.Layout(left="98px", width="60px"),
        )
        self.widget_fraction_label = widgets.Label(
            value="Fraction: ", layout=widgets.Layout(left="95px")
        )

        # -- column 1: color -------------------------------------------------
        self.widget_feature_color = widgets.Dropdown(
            description="Color",
            # unpacking accepts any iterable of feature names, not only lists
            options=["Default color", *hover_features],
            value="Default color",
            layout=widgets.Layout(width="250px"),
        )
        self.widget_feature_color_type = widgets.RadioButtons(
            options=["Gradient", "Discrete"],
            value="Gradient",
            layout=widgets.Layout(width="140px", left="90px"),
        )
        self.widget_feature_color_list = widgets.Dropdown(
            options=px.colors.named_colorscales(),
            value="viridis",
            layout=widgets.Layout(width="65px", height="35px", left="40px"),
        )

        # -- column 2: marker size -------------------------------------------
        self.widget_feature_marker = widgets.Dropdown(
            description="Marker",
            options=["Default size", *hover_features],
            value="Default size",
            layout=widgets.Layout(width="250px"),
        )
        # Explicit starting values: BoundedFloatText would otherwise start at 0,
        # i.e. a maximum marker size of 0.
        self.widget_feature_marker_minvalue = widgets.BoundedFloatText(
            min=0,
            step=1,
            value=marker_min_value,
            layout=widgets.Layout(left="91px", width="60px", height="10px"),
        )
        self.widget_feature_marker_minvalue_label = widgets.Label(
            value="Min value: ", layout=widgets.Layout(left="94px", width="70px")
        )
        self.widget_feature_marker_maxvalue = widgets.BoundedFloatText(
            step=1,
            value=marker_max_value,
            layout=widgets.Layout(left="91px", width="60px"),
        )
        self.widget_feature_marker_maxvalue_label = widgets.Label(
            value="Max value: ", layout=widgets.Layout(left="94px", width="70px")
        )

        # -- placement on the 3x3 grid ---------------------------------------
        self[0, 0] = self.widget_feature_x
        self[1, 0] = self.widget_feature_y
        self[2, 0] = widgets.Box([self.widget_fraction_label, self.widget_fraction])

        self[0, 1] = self.widget_feature_color
        # the color-type toggle and the colorscale share rows 1-2 of column 1
        self[1:, 1] = widgets.HBox(
            [self.widget_feature_color_type, self.widget_feature_color_list],
            layout=widgets.Layout(top="10px"),
        )

        self[0, 2] = self.widget_feature_marker
        self[1, 2] = widgets.Box(
            [
                self.widget_feature_marker_minvalue_label,
                self.widget_feature_marker_minvalue,
            ]
        )
        self[2, 2] = widgets.Box(
            [
                self.widget_feature_marker_maxvalue_label,
                self.widget_feature_marker_maxvalue,
            ]
        )

        self.layout.height = height
        # self.layout.top = "30px"


# Demo inputs for the Settings widget
embedding_features = ["A", "B", "C"]
hover_features = ["AA", "BB", "CC"]
feature_x, feature_y = "A", "B"
fracture = 1.0

settings = Settings(
    embedding_features=embedding_features,
    hover_features=hover_features,
    feature_x=feature_x,
    feature_y=feature_y,
    fracture=fracture,
)
settings


Settings(children=(Dropdown(description='x-axis', layout=Layout(grid_area='widget001', width='250px'), options…

In [57]:
# An empty plotly FigureWidget (no traces yet)
fig = go.FigureWidget()
fig

FigureWidget({
    'data': [], 'layout': {'template': '...'}
})

In [125]:
# extract features: one x-array and one y-array per distinct value of `target`

target = "F"
feature_x = "A"
feature_y = "B"

# labels in order of first appearance; x[i] and y[i] belong to labels[i]
labels = df[target].unique().tolist()

x = []
y = []
for label in labels:
    # select rows by the configured target column rather than a hard-coded name
    mask = df[target] == label
    x.append(df.loc[mask, feature_x].to_numpy())
    y.append(df.loc[mask, feature_y].to_numpy())


In [144]:
# TODO dict for   
# TODO: default: points = np.column_stack((x, y))

from scipy.spatial import ConvexHull, convex_hull_plot_2d

class Figure(go.FigureWidget):
    """Scatter-plot widget with one marker trace per label.

    ``x``, ``y`` and ``labels`` are parallel sequences: ``x[i]`` and ``y[i]``
    hold the coordinates of every point whose label is ``labels[i]``.
    """

    def __init__(self, x, y, labels, layout=None, **kwargs):
        # Only underscore-prefixed attributes may be set on a plotly figure
        # before/after init without being treated as figure properties.
        self._x = x
        self._y = y
        self._labels = labels

        # Handles to the traces added by the helper methods below.
        self._regression_trace = None
        self._convex_hull_traces = []

        super().__init__(None, layout, **kwargs)

        # All permanent layout settings are defined here
        self.update_layout(
            hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"),
            width=800,
            height=400,
            margin=dict(l=50, r=50, b=70, t=20, pad=4),
        )
        self.update_xaxes(
            ticks="outside", tickwidth=1, ticklen=10, linewidth=1, linecolor="black"
        )
        self.update_yaxes(
            ticks="outside", tickwidth=1, ticklen=10, linewidth=1, linecolor="black"
        )

        for xs, ys, label in zip(x, y, labels):
            self.add_trace(go.Scatter(x=xs, y=ys, name=str(label), mode="markers"))

    def add_regression_line(self, coeffs):
        """Add an (as yet empty) trace named "Line" for the regression line.

        NOTE(review): ``coeffs`` is not used yet, so the trace has no data; the
        coefficient convention (e.g. ``[intercept, slope]``) still has to be decided.
        """
        self.add_trace(go.Scatter(name="Line"))
        # plotly stores a validated copy of the trace passed to add_trace, so
        # keep a handle to the trace that actually lives in the figure.
        self._regression_trace = self.data[-1]

    def add_convex_hull(self):
        """Draw the convex hull of each label group as a closed line trace.

        Groups with fewer than 3 points, or whose points have no 2-D hull
        (e.g. all collinear), are skipped.
        """
        for xs, ys, label in zip(self._x, self._y, self._labels):
            if len(xs) < 3:
                continue

            points = np.column_stack((xs, ys))
            try:
                hull = ConvexHull(points)
            except RuntimeError:
                # scipy raises QhullError (a RuntimeError subclass) for degenerate input
                continue

            # 2-D hull vertices come in counterclockwise order; repeating the
            # first vertex closes the polygon.
            closed = np.append(hull.vertices, hull.vertices[0])
            self.add_trace(
                go.Scatter(
                    x=points[closed, 0],
                    y=points[closed, 1],
                    name=f"hull {label}",
                    mode="lines",
                )
            )
            self._convex_hull_traces.append(self.data[-1])

    # Backward-compatible alias for the original (misspelled) method name.
    add_complex_hull = add_convex_hull


# Build the per-label scatter figure and overlay each group's convex hull
fig = Figure(x=x, y=y, labels=labels)
fig.add_complex_hull()
fig


Figure({
    'data': [{'mode': 'markers',
              'name': '-1',
              'type': 'scatter',
       …

In [139]:

# Scratch check of the hull computation for the first label group
points = np.column_stack((x[0], y[0]))
hull = ConvexHull(points)
inds = np.concatenate((hull.vertices, hull.vertices[:1]))  # first vertex repeated to close the loop

# self.add_trace(go.Scatter(x=points[,0], y=points[hull.vertices,1]))
# for simplex in hull.simplices:
#     self.add_trace(go.Scatter(x=points[simplex, 0], y=points[simplex, 1]))
