<a href="https://colab.research.google.com/github/gautamankitkumar/ankitgau-ms-report-data/blob/main/notebooks/swap-histogram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install ase
! git clone https://github.com/gautamankitkumar/ankitgau-ms-report-data.git
% cd ankitgau-ms-report-data
% cp data/swap-data.json notebooks/
% cd notebooks

In [1]:
import ternary
import json
import os
import torch

import numpy as np
import matplotlib as mpl
%matplotlib inline
import matplotlib.pyplot as plt

# Used for interactive graph
import plotly.graph_objects as go
from ipywidgets import widgets

from utils.fcc_helpers import cal_nrg
from utils.train_agent import BPNN
from utils.fp_calculator import set_sym
from ase.build import fcc111

In [2]:
# Sampled compositions as (Cu %, Ag %, Au %) rows on a 10 % grid of the ternary
# simplex, ordered by Cu then Ag. Rows where a single element makes up 100 %
# (the three pure-element corners) are excluded.
all_list = np.array([
    [cu, ag, 100 - cu - ag]
    for cu in range(0, 110, 10)
    for ag in range(0, 110 - cu, 10)
    if 100 not in (cu, ag, 100 - cu - ag)
])

In [5]:
# Pre-computed swap energies, one entry per sampled composition (update_point
# shows how the keys are built). JSON text is UTF-8, so decode it explicitly
# instead of relying on the platform's default encoding.
with open('swap-data.json', encoding='utf-8') as fp:
    swap_data = json.load(fp)

%matplotlib notebook
# Interactive ternary scatter with one marker per row of all_list; the columns
# are the Cu, Ag and Au percentages, mapped to the a, b and c axes.
xy_grid_plot = go.FigureWidget([
    go.Scatterternary(
        a=all_list[:, 0],
        b=all_list[:, 1],
        c=all_list[:, 2],
        hovertemplate='Cu %{a}%<br>Ag %{b}%<br>Au %{c}% <extra></extra>',
        mode='markers',
    )
])

xy_grid_plot.update_layout(
    title='Sampled compositions. Click to show histogram for each marker composition. One plot at a time',
    ternary=dict(
        sum=100,
        aaxis=dict(title='Cu'),
        baxis=dict(title='Ag'),
        caxis=dict(title='Au'),
    ),
    hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"),
)

# Handle to the marker trace; the click handler restyles it in place.
scatter = xy_grid_plot.data[0]
# Every marker starts unselected: default color and size 10.
colors = ['#a3a7e4'] * len(all_list)
scatter.marker.update(color=colors, size=[10] * len(all_list))
xy_grid_plot.layout.hovermode = 'closest'

# Index of the currently highlighted marker ([] when none); the click handler
# reads and rebinds it.
color_stack = []
# create our callback function
def update_point(trace, points, selector):
    """Click callback for the composition scatter: toggle selection, redraw histograms.

    At most one marker is selected at a time. Clicking an unselected marker
    highlights it and un-highlights any previous selection. It then fills the
    six histograms of ``main_plot`` with that composition's swap energies from
    ``swap_data``. Clicking the selected marker again deselects it and resets
    the histograms.

    Parameters
    ----------
    trace : plotly trace
        The clicked trace (unused).
    points : plotly ``Points``
        ``points.point_inds`` lists the clicked marker indices. Each one is a
        row index into ``all_list``.
    selector : plotly input-device state
        Unused.

    Side effects
    ------------
    Rebinds the global ``color_stack`` (``[index]`` of the selected marker, or
    ``[]``). Restyles the ``scatter`` markers and rewrites the traces and title
    of ``main_plot``. If a composition has no entry in ``swap_data``, the
    histograms are reset and the title says so, instead of raising inside the
    widget callback.
    """
    global color_stack

    default_color, selected_color = '#a3a7e4', '#bae2be'
    default_size, selected_size = 10, 20
    # Swap types in the trace order of main_plot (data[0] .. data[5]).
    swap_keys = ('AgAu', 'AgCu', 'AuCu', 'AuAg', 'CuAg', 'CuAu')

    # Read the current styling once and update it cumulatively. When one event
    # reports several points, later points must not discard earlier changes.
    new_color = list(scatter.marker.color)
    new_size = list(scatter.marker.size)

    for i in points.point_inds:
        # i is a row index into all_list (0 .. len(all_list) - 1).
        if i in color_stack:
            # Clicking the selected marker deselects it and clears the histograms.
            new_color[i] = default_color
            new_size[i] = default_size
            color_stack.remove(i)
            plot_graph_flag = False
        else:
            # Move the selection from the previous marker (if any) to this one.
            if color_stack:
                new_color[color_stack[0]] = default_color
                new_size[color_stack[0]] = default_size
            new_color[i] = selected_color
            new_size[i] = selected_size
            color_stack = [i]
            plot_graph_flag = True

        # Update clicked marker(s)
        with xy_grid_plot.batch_update():
            scatter.marker.color = new_color
            scatter.marker.size = new_size

        plot_data = None
        if plot_graph_flag:
            comp = all_list[i]
            # swap_data keys are the three composition entries times 15, joined by '_'.
            key_name = '_'.join(str(value * 15) for value in comp)
            plot_data = swap_data.get(key_name)
            composition = f'{comp[0]}% Cu, {comp[1]}% Ag, {comp[2]}% Au'

        # Update the histograms for the clicked composition
        with main_plot.batch_update():
            if plot_data is not None:
                for hist, key in zip(main_plot.data, swap_keys):
                    hist.x = plot_data[key]
                title = f'Histogram of Potential Swaps for {composition}'
            else:
                # Nothing to show: reset every histogram to the [0] placeholder
                # the traces were created with.
                for hist in main_plot.data:
                    hist.x = [0]
                if plot_graph_flag:
                    title = f'No swap data for {composition}'
                else:
                    title = 'No Points selected'
            main_plot.update_layout(go.Layout(title=dict(text=title)))
                
# Register the click handler. It draws into `main_plot`, which is created
# further down in this cell; the figures are only displayed at the end of it.
scatter.on_click(update_point)

from plotly.subplots import make_subplots

# Swap types in row-major subplot order. The names are the keys the click
# handler reads from swap_data, and pair with index_arr below
# (e.g. 'AgAu' -> Ag_sub -> Au_surf).
swap_types = ['AgAu', 'AgCu', 'AuCu', 'AuAg', 'CuAg', 'CuAu']
index_arr = [r'$Ag_{sub} \rightarrow Au_{surf}$',
             r'$Ag_{sub} \rightarrow Cu_{surf}$',
             r'$Au_{sub} \rightarrow Cu_{surf}$',
             r'$Au_{sub} \rightarrow Ag_{surf}$',
             r'$Cu_{sub} \rightarrow Ag_{surf}$',
             r'$Cu_{sub} \rightarrow Au_{surf}$']

# 2 x 3 grid of histograms, one per swap type.
fig1 = make_subplots(
    rows=2, cols=3, subplot_titles=index_arr, shared_xaxes='columns', shared_yaxes='columns'
)

# Each histogram starts with a single placeholder value; clicks replace it.
for idx, trace_name in enumerate(swap_types):
    fig1.add_trace(go.Histogram(x=[0], name=trace_name), row=idx // 3 + 1, col=idx % 3 + 1)

# Label the x axes of the bottom row.
for col_idx in (1, 2, 3):
    fig1.update_xaxes(title_text=r"$\Delta E_{swap} (eV)$", row=2, col=col_idx)

# TODO: draw a black vertical reference line at x = 0 on every subplot (like ax.axvline(0)).
# On plotly >= 4.12: fig1.add_vline(x=0, line_width=3, line_color='black', row='all', col='all')

# A FigureWidget built from an existing Figure takes that figure's layout, so a
# separately passed `layout=` would be dropped. Set the title afterwards instead.
main_plot = go.FigureWidget(fig1)
# shared_xaxes hides the top row's tick labels (xaxis .. xaxis3); show them again.
main_plot.update_layout(
    title=dict(text='Histogram Plots'),
    xaxis_showticklabels=True, xaxis2_showticklabels=True, xaxis3_showticklabels=True,
)

z = widgets.VBox([xy_grid_plot, main_plot])
z

VBox(children=(FigureWidget({
    'data': [{'a': array([ 0,  0,  0,  0,  0,  0,  0,  0,  0, 10, 10, 10, 10, 10…