# Finding the Optimal Lattice Constants for Bulk Support Materials
***
Notes:

  This will break when I put in the data that also has a c lattice constant.

  Probably just make two plots (one for only a and one for both a and c lattice constants).

***

# Notebook Setup

## Import Modules

In [1]:
import os
import sys

import numpy as np
import pandas as pd

from scipy.interpolate import interp1d

import plotly.plotly as py
import plotly.graph_objs as go

import pickle

## Script Inputs

In [2]:
# Columns whose unique value combinations define one data series
# (passed to DataFrame.groupby in the processing cell).
groupby_cols = ["support_metal", "crystal_structure", "spinpol"]

# Properties whose values are joined with "_" to label each series.
# Holds the same column names as groupby_cols, as an independent list.
prop_name_list = list(groupby_cols)
# support_metal




# smart_format_dict = [
# 
#     [
#         {"support_metal": "Ni"},
#         {"color2": "red"},
#         ],

#     [
#         {"support_metal": "Rh"},
#         {"color2": "blue"},
#         ],

#     [
#         {"support_metal": "Co"},
#         {"color2": "black"},
#         ],

#     [
#         {"support_metal": "Ru"},
#         {"color2": "green"},
#         ],

#     [
#         {"support_metal": "Mo"},
#         {"color2": "orange"},
#         ],

#     [
#         {"support_metal": "W"},
#         {"color2": "grey"},
#         ],

#     ]

# Marker colour assigned to each support metal when plotting.
system_color_dict = dict(
    Ni="red",
    Rh="blue",
    Co="black",
    Ru="green",
    Mo="orange",
    W="grey",
)

## Load Data

In [3]:
def _load_job_dataframe(file_name, directory="."):
    """Unpickle and return one job dataframe stored in *directory*.

    NOTE(security): unpickling can execute arbitrary code. Only load
    pickle files that you generated yourself or otherwise trust.
    """
    df_path = os.path.join(directory, file_name)
    with open(df_path, "rb") as fle:
        # encoding="latin1" is presumably needed because the pickles were
        # written under Python 2 -- TODO confirm.
        return pickle.load(fle, encoding="latin1")


df_m_ni_co_mo_1 = _load_job_dataframe(
    "job_dataframe_bulk_opt_ni_co_mo_01.pickle")
df_m_ni_co_mo_2 = _load_job_dataframe(
    "job_dataframe_bulk_opt_ni_co_mo_02.pickle")
df_m_ru_rh_w_1 = _load_job_dataframe(
    "job_dataframe_bulk_opt_ru_rh_w_01.pickle")

frames = [
    df_m_ni_co_mo_1,
    df_m_ni_co_mo_2,
    df_m_ru_rh_w_1,
    ]

df_m = pd.concat(frames)

# Getting rid of duplicate columns (the first occurrence is kept)
df_m = df_m.loc[:, ~df_m.columns.duplicated()]

# Missing values become the string 'nan' so they can be filtered with
# plain ==/!= comparisons below.
df_m = df_m.replace(np.nan, 'nan', regex=True)

# Drop jobs that have no electronic energy
df_m = df_m[df_m["elec_energy"] != "nan"]

# Keep only rows WITHOUT a c lattice constant, i.e. structures described
# by `a` alone. Swap the two lines below to select the structures that
# do have a c lattice constant (e.g. HCP) instead.
df_m = df_m[df_m["c"] == "nan"]
# df_m = df_m[df_m["c"] != "nan"]

# Optional filters, useful when debugging a single system:
# df_m = df_m[df_m["support_metal"] == "Mo"]
# df_m = df_m[df_m["spinpol"] == False]

# Processing Data

## Process DFT Jobs

In [4]:
# Build the plotly traces. For every group defined by `groupby_cols`,
# up to three traces are appended to `data`:
#   1. the raw (a, relative energy) points,
#   2. a cubic interpolation through those points,
#   3. a star marker at the minimum of the interpolation.
data = []
groupby = df_m.groupby(groupby_cols)

# Marker symbol distinguishes spin-polarised from non-spin-polarised jobs.
spinpol_symbol_dict = {True: "circle", False: "triangle-up"}

# interp1d(kind="cubic") needs at least this many distinct x values.
min_points_cubic = 4

for group_key_i, group_i in groupby:
    group_i = group_i.sort_values(by=["a"])

    # Energy relative to the lowest-energy job of this group.
    group_i["elec_energy_norm"] = 0. + \
        group_i["elec_energy"] - \
        group_i["elec_energy"].min()

    a_list = group_i["a"].tolist()
    energy_list = group_i["elec_energy_norm"].tolist()

    # Trace label: the first row's property values joined by "_".
    name_i = "_".join(
        str(group_i[prop_i].tolist()[0]) for prop_i in prop_name_list)

    # Colour by support metal; None lets plotly choose its default when
    # the metal has no entry in system_color_dict.
    color_i = None
    support_metals = set(group_i["support_metal"].tolist())
    if len(support_metals) == 1:
        support_metal = next(iter(support_metals))
        color_i = system_color_dict.get(support_metal)

    # Reset on every iteration so a symbol from a previous group can never
    # leak into this one; None lets plotly choose its default symbol.
    shape_i = None
    spinpol_values = set(group_i["spinpol"].tolist())
    if len(spinpol_values) == 1:
        spinpol = next(iter(spinpol_values))
        shape_i = spinpol_symbol_dict.get(spinpol)

    # ###################################################
    # Raw data points
    # ###################################################

    data.append(go.Scatter(
        x=a_list,
        y=energy_list,
        mode='markers',
        name=name_i,
        marker=dict(
            size=8,
            color=color_i,
            symbol=shape_i,
            )
        ))

    # ###################################################
    # Cubic interpolation and its minimum
    # ###################################################

    n_unique_a = len(set(a_list))
    if n_unique_a < min_points_cubic or n_unique_a != len(a_list):
        # A cubic interp1d cannot be built from fewer than 4 points or
        # from repeated x values; keep the raw points and move on.
        print(
            "Skipping interpolation for " + name_i +
            ": need >= 4 distinct 'a' values without duplicates")
        continue

    # a_list is sorted, so its end points are the min and max.
    x_axis_interp = np.linspace(a_list[0], a_list[-1], 1000)
    interp_func = interp1d(a_list, energy_list, kind="cubic")
    interp_out = interp_func(x_axis_interp)

    interp_i = go.Scatter(
        x=x_axis_interp,
        y=interp_out,
        mode='lines',
        )

    data.append(interp_i)

    # Lowest point of the interpolated curve ~ optimal lattice constant.
    min_ind = np.argmin(interp_out)
    min_i = go.Scatter(
        x=[x_axis_interp[min_ind]],
        y=[interp_out[min_ind]],
        mode='markers',
        marker=dict(
            size=10,
            color="red",
            symbol="star",
            )
        )

    data.append(min_i)

In [5]:
# Display the last group left over from the processing loop (sanity check).
group_i

Unnamed: 0,Job,a,c,crystal_structure,max_revision,path,revision_number,spinpol,support_metal,elec_energy,...,nbands,noncollinear,outdir,output,parflags,printensemble,pw,sigma,xc,elec_energy_norm
3280,<dft_job_automat.job_setup.Job object at 0x7fd...,3.1019,,bcc,1,data/W/bcc/True/3.1019-nan/_1,1,True,W,-2183.56,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0868074
3282,<dft_job_automat.job_setup.Job object at 0x7fd...,3.10514,,bcc,1,data/W/bcc/True/3.10514-nan/_1,1,True,W,-2183.57,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0788863
3284,<dft_job_automat.job_setup.Job object at 0x7fd...,3.10839,,bcc,1,data/W/bcc/True/3.10839-nan/_1,1,True,W,-2183.58,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0714582
3286,<dft_job_automat.job_setup.Job object at 0x7fd...,3.11164,,bcc,1,data/W/bcc/True/3.11164-nan/_1,1,True,W,-2183.59,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0647498
3288,<dft_job_automat.job_setup.Job object at 0x7fd...,3.11488,,bcc,1,data/W/bcc/True/3.11488-nan/_1,1,True,W,-2183.59,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0591978
3290,<dft_job_automat.job_setup.Job object at 0x7fd...,3.11813,,bcc,1,data/W/bcc/True/3.11813-nan/_1,1,True,W,-2183.6,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0521086
3292,<dft_job_automat.job_setup.Job object at 0x7fd...,3.12137,,bcc,1,data/W/bcc/True/3.12137-nan/_1,1,True,W,-2183.6,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0467013
3294,<dft_job_automat.job_setup.Job object at 0x7fd...,3.12462,,bcc,1,data/W/bcc/True/3.12462-nan/_1,1,True,W,-2183.61,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0411867
3296,<dft_job_automat.job_setup.Job object at 0x7fd...,3.12787,,bcc,1,data/W/bcc/True/3.12787-nan/_1,1,True,W,-2183.61,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0369948
3298,<dft_job_automat.job_setup.Job object at 0x7fd...,3.13111,,bcc,1,data/W/bcc/True/3.13111-nan/_1,1,True,W,-2183.62,...,-50,False,calcdir,{'removesave': True},,True,600,0.005,BEEF-vdW,0.0318404


# Plotting

In [6]:
# Font sizes for axis titles and tick labels.
# NOTE(review): the 4/3 factor presumably converts pt to px -- confirm.
axes_lab_size = 12 * (4./3.)
tick_lab_size = 9 * (4./3.)

# Styling shared by the x and y axes.
common_axis_dict = {

    # "range": y_axis_range,
    "zeroline": False,
    "showline": True,
    "mirror": 'ticks',
    "linecolor": 'black',
    "showgrid": False,

    "titlefont": dict(size=axes_lab_size),
    "tickfont": dict(
        size=tick_lab_size,
        ),
    "ticks": 'inside',
    "tick0": 0,
    "tickcolor": 'black',
    # "dtick": 0.25,
    "ticklen": 2,
    "tickwidth": 1,
    }

layout = {
    # Title typo fixed, and the metal list now covers every system that
    # is loaded and has an entry in the colour mapping.
    "title": "Lattice constant bulk optimization (Ni, Co, Mo, Ru, Rh, W)",
    "titlefont": go.layout.Titlefont(size=24),

    # Each axis starts from the shared styling and adds its own title.
    "xaxis": dict(
        common_axis_dict,
        title="a",
        # range=x_range,
        ),

    # The plotted energies are relative to each group's minimum, so the
    # label says so explicitly.
    "yaxis": dict(
        common_axis_dict,
        title="Relative Electronic Energy (eV)",
        # range=y_range,
        ),

    "font": dict(
        family='Arial',
        # size=18,
        color='black',
        ),

    # "width": 1.5 * 18.7 * 37.795275591,
    # "height": 18.7 * 37.795275591,

    "showlegend": True,

    "legend": dict(
        # x=0.,
        # y=1.8,
        font=dict(
            size=10,
            ),
        ),
    }

In [21]:
# Bundle the traces and the layout into a figure spec and hand it to
# plotly's iplot under the given filename.
# NOTE(review): newer plotly releases moved `plotly.plotly` into the
# separate chart_studio package -- confirm the installed version.
fig = {"data": data, "layout": layout}

py.iplot(fig, filename="__temp__/basic-scatter")

In [8]:
# x_axis_interp
# interp_out = interp1d(
#     a_list,
#     energy_list,
#     kind="cubic",
#     )(x_axis_interp)

In [9]:
# min_ind = np.argmin(interp_out)
# interp_out[min_ind]
# x_axis_interp[min_ind]

In [10]:
# group_i["elec_energy_norm"].min()

# row_min = group_i[group_i['elec_energy_norm'] == group_i['elec_energy_norm'].min()]

# row_min["elec_energy_norm"].iloc[0]
# row_min["a"].iloc[0]


# min_i = go.Scatter(
#     x=[row_min["a"].iloc[0]],
#     y=[row_min["elec_energy_norm"].iloc[0]],
# #     interp1d(
# #         a_list,
# #         energy_list,
# #         kind="cubic",
# #         )(x_axis_interp),

#     mode='markers',
#     )

# data.append(min_i)

In [11]:
# from scipy.interpolate import interp1d

# x = np.linspace(0, 10, num=11, endpoint=True)
# y = np.cos(-x**2/9.0)
# # f = interp1d(x, y)
# f2 = interp1d(x, y, kind='cubic')

# f2()

# y=np.interp(
#     x_axis_interp,
#     a_list,
#     energy_list,
#     )

# interp1d(a_list, energy_list, kind='cubic')(x_axis_interp)

In [12]:
#     if no_c_latt_const is False:
#         print("sKLfjaskfsdj")
#         groupby = group_i.groupby(["c"])
#         for name_j, group_j in groupby:            
#             a_list = group_j["a"].tolist()
#             energy_list = group_j["elec_energy_norm"].tolist()

#             name_i = ""
#             for prop_i in prop_name_list:   
#                 name_i += str(group_j[prop_i].tolist()[0]) + "_"

#             name_i = name_i[0:-1]
            
#             data.append(go.Scatter(
#                 x=a_list,
#                 y=energy_list,
#                 mode='markers',
#                 name=name_i,
#                 marker = dict(
#                     size = 8,
#                     color = color_i,
# #                     line = dict(
# #                         width = 2,
# #                         color = 'rgb(0, 0, 0)'
# #                         )
#                     )
#                 ))

In [13]:
# if len(list(set(group_i["spinpol"].tolist()))) == 1:
#     spinpol = list(set(group_i["spinpol"].tolist()))[0]

In [14]:
# no_c_latt_const = all(np.isnan(group_i["c"].tolist()))
# all(
# np.isnan(group_i["c"].tolist())
# )

In [15]:
# x_axis_interp = np.linspace(
#     np.array(a_list).min(),
#     np.array(a_list).max(),
#     300,
#     )

# np.interp(
#     x_axis_interp,
#     a_list,
#     energy_list,
#     )

# interp_i = go.Scatter(
#     x=x_axis_interp,
#     y=np.interp(
#         x_axis_interp,
#         a_list,
#         energy_list,
#         ),

#     mode='lines',

# #     name=name_i,
# #     marker = dict(
# #         size = 8,
# #         color = color_i,
# # #                     line = dict(
# # #                         width = 2,
# # #                         color = 'rgb(0, 0, 0)'
# # #                         )
# #         )

#     )

# # data.append(interp_i)

In [16]:
# c_col = list(set(group_i["c"].tolist()))
# if len(c_col) == 1 and "nan" in c_col:
#     tmp = 42

In [17]:
# df_m

# df_m = df_m.replace(np.nan, 'nan', regex=True)

In [18]:
# a_list = group_i["a"].tolist()

# # if "c" in list()
# # c_list = group_i["c"].tolist()

# energy_list = group_i["elec_energy"].tolist()

# no_c_latt_const = all(np.isnan(group_i["c"].tolist()))

# if no_c_latt_const is False:
#     groupby = df_m.groupby(["c"])
#     for name_j, group_j in groupby:
#         tmp = 42

# name_i = ""
# for prop_i in prop_name_list:   
#     name_i += str(group_i[prop_i].tolist()[0]) + "_"

# name_i = name_i[0:-1]

In [19]:
# if len(list(set(group_i["support_metal"].tolist()))) == 1:
#     support_metal = list(set(group_i["support_metal"].tolist()))[0]

In [20]:
# data_points = go.Scatter3d(
#     x=a_lst,
#     y=b_a_lst,
#     z=elec_energy_lst,
#     mode='markers',
  
#     marker=dict(
#         size=6,
# #         color=elec_energy_lst,                # set color to an array/list of desired values
#         color="grey",                # set color to an array/list of desired values
# #         colorscale='magma',   # choose a colorscale
#         colorscale='Blackbody',   # choose a colorscale
#         opacity=1.
#         )    
#     )

# fig = go.Figure(data=[data_points], layout=layout2)
# py.iplot(fig, filename=pl_dir + '/pl_latt_opt_data_points')