In [None]:
%aiida

In [None]:
import ase
import ase.io
#import ase.visualize
import numpy as np

import matplotlib.pyplot as plt

import tb_mean_field_hubbard as tbmfh

import ipywidgets as ipw
from IPython.display import clear_output

# AiiDAlab imports
from aiidalab_widgets_base.utils import string_range_to_list, list_to_string_range
from aiidalab_widgets_base import StructureManagerWidget,BasicStructureEditor, StructureBrowserWidget, StructureUploadWidget

# Local imports.
from viewers.cdxml2gnr import CdxmlUpload2GnrWidget

In [None]:
# Structure selection: upload a file, browse the AiiDA database or convert a
# CDXML drawing into a GNR; the basic editor allows in-place modifications.
style = {'description_width': '120px'}
structure_selector = StructureManagerWidget(
    importers=[
        StructureUploadWidget(title="From computer"),
        StructureBrowserWidget(title="AiiDA database"),
        CdxmlUpload2GnrWidget(title="CDXML to GNR"),
    ],
    editors=[BasicStructureEditor()],
    storable=False,
    node_class='StructureData',
)

# Free-text fields listing the atom IDs that start with spin up / spin down.
spin_u, spin_d = (
    ipw.Text(placeholder='1..10 15', description=label,
             style=style, layout={'width': '60%'})
    for label in ('IDs atoms spin U', 'IDs atoms spin D')
)

# `outs` receives messages and model summaries, `outs1` the orbital viewer output.
outs, outs1 = ipw.Output(), ipw.Output()
display(structure_selector)
display(ipw.VBox([spin_u, spin_d, outs]))

# Notebook-wide state shared by the callbacks in the following cells:
# the currently accepted geometry and the list of prepared models.
geom = None
mfh_models = None

In [None]:
## Set the initial spins according to the tags of the selected structure.
def on_structure_change(c=None):
    """Synchronise the notebook state with the structure chosen in the selector.

    Accepts only structures made of C and H atoms. For an accepted structure
    the spin-up / spin-down ID fields are filled from the atom tags (tag 1 ->
    spin up, tag 2 -> spin down) and the structure becomes the global ``geom``.
    A rejected or removed structure leaves ``geom`` set to ``None``.

    Args:
        c: traitlets change notification (unused).
    """
    global geom
    with outs:
        clear_output()

    new_geom = structure_selector.structure
    # Detach the previous geometry first: filling the spin fields below fires
    # on_spin_change, which would otherwise write the new spin layout onto the
    # *old* geometry (and raise IndexError when the atom counts differ).
    geom = None
    if not new_geom:
        return

    forbidden = sorted({atom.symbol for atom in new_geom} - {'C', 'H'})
    if forbidden:
        spin_u.value = ''
        spin_d.value = ''
        # Resetting the selector re-enters this callback, which clears `outs`;
        # print the message afterwards so that it stays visible.
        structure_selector.structure = None
        with outs:
            for symbol in forbidden:
                print("Element ", symbol, ' not allowed')
        return

    tags = new_geom.get_tags()
    spin_d.value = list_to_string_range(np.where(tags == 2)[0].tolist())
    spin_u.value = list_to_string_range(np.where(tags == 1)[0].tolist())
    geom = new_geom


structure_selector.observe(on_structure_change, names='structure')

In [None]:
def on_spin_change(c=None):
    """Write the spin-up / spin-down ID fields back into the tags of ``geom``.

    Atoms listed in ``spin_u`` get tag 1, atoms listed in ``spin_d`` get tag 2
    and every other atom gets tag 0, so emptying both fields removes the spin
    guess. Indices outside ``[0, len(geom))`` are ignored instead of raising
    IndexError or wrapping around through negative indexing.

    Args:
        c: traitlets change notification (unused).
    """
    if not geom:
        return
    natoms = len(geom)

    def _valid_indices(text):
        # string_range_to_list returns a tuple whose first item is the index list.
        return [i for i in string_range_to_list(text)[0] if 0 <= i < natoms]

    tags = np.zeros(natoms)
    tags[_valid_indices(spin_u.value)] = 1
    # Spin down is applied last, so it wins for atoms listed in both fields.
    tags[_valid_indices(spin_d.value)] = 2
    geom.set_tags(tags)


spin_u.observe(on_spin_change, names='value')
spin_d.observe(on_spin_change, names='value')

In [None]:
# CALCULATION PARAMETERS
# Space-separated hopping values t for 1st, 2nd, ... neighbours; t[0] also sets
# the U scale (U = U/t * t[0]). Presumably in eV, since energies are reported in eV.
t1 = ipw.Text(
    description='t for 1st,2nd,.. neigh',
    value='2.7',
    style={'description_width': 'initial'},
    disabled=False,
)

# Total charge and spin multiplicity handed to every new model.
charge = ipw.IntText(
    description='charge',
    value=0,
    style={'description_width': 'initial'},
    disabled=False,
)
multiplicity = ipw.IntText(
    description='multiplicity',
    value=1,
    style={'description_width': 'initial'},
    disabled=False,
)

# Lower and upper bound of the U/t scan performed by "run all TB".
Um = ipw.FloatText(
    description='U/t min',
    value=1,
    style={'description_width': 'initial'},
    disabled=False,
)
UM = ipw.FloatText(
    description='U/t Max',
    value=1,
    style={'description_width': 'initial'},
    disabled=False,
)
def tb_pressed(c=None):
    """Prepare one more mean-field Hubbard model for the current geometry.

    Reads the hopping list from ``t1`` (space-separated floats), plus the charge
    and the multiplicity. Appends a new ``tbmfh.MeanFieldHubbardModel`` to the
    global ``mfh_models`` list, then redraws the summary of every prepared
    configuration in ``outs``. An unparsable or empty ``t`` resets the field to
    its default value and aborts.

    Args:
        c: the clicked button (unused).
    """
    global mfh_models
    if not geom:
        with outs:
            print("Select a valid geometry")
        return

    try:
        t = [float(value) for value in t1.value.split()]
    except ValueError:
        t = []
    if not t:
        # An empty list would create a model without hoppings and make the run
        # step fail on t[0].
        t1.value = '2.7'
        with outs:
            print('wrong t value')
        return

    if mfh_models is None:
        mfh_models = []
    # Hand over a snapshot: the spin fields rewrite the tags of `geom` in place,
    # and each configuration should keep the spin guess it was prepared with.
    # NOTE(review): only matters if MeanFieldHubbardModel keeps a reference to
    # the geometry it receives — confirm.
    mfh_models.append(
        tbmfh.MeanFieldHubbardModel(geom.copy(), t, charge.value, multiplicity.value)
    )

    with outs:
        clear_output()
        print("List of all ", len(mfh_models), " configurations to be computed")
        for number, model in enumerate(mfh_models, start=1):
            print("#", number)
            model.print_parameters()
            model.visualize_spin_guess()
        
def rtb_pressed(c=None):
    """Run every prepared model over the requested range of U/t ratios.

    U/t runs from ``Um`` to ``UM`` in steps of 0.1; a 0.001 tolerance keeps
    ``UM`` itself when it lies on that grid. A zero ``UM`` means "only ``Um``".
    For each ratio U = (U/t) * t[0], with t[0] the first hopping of ``t1``.
    With a single ratio, each configuration's report is printed. When several
    configurations exist, their energies relative to configuration #1 are
    plotted in meV against U/t.

    Args:
        c: the clicked button (unused).
    """
    if not mfh_models:
        return

    try:
        t = [float(value) for value in t1.value.split()]
    except ValueError:
        t = []
    if not t:
        t1.value = '2.7'
        with outs:
            print('wrong t value')
        return

    if not UM.value:
        # An unset (zero) upper bound collapses the scan to U/t min only.
        UM.value = Um.value + 0.001
    u_t_ratios = np.arange(Um.value, UM.value + 0.001, 0.1)
    if len(u_t_ratios) == 0:
        # UM < Um used to produce an empty scan and a silent no-op.
        with outs:
            print('U/t Max must not be smaller than U/t min')
        return
    single_ratio = len(u_t_ratios) == 1

    with outs:
        energies = []
        for number, model in enumerate(mfh_models, start=1):
            model_energies = []
            for ut_ratio in u_t_ratios:
                if single_ratio:
                    print("results configuration #", number)
                # Per-run plots are only requested when a single U is computed.
                model.run_mfh(u=ut_ratio * t[0], print_iter=False, plot=single_ratio)
                if single_ratio:
                    model.report(num_orb=0)
                model_energies.append(model.energy)
            energies.append(model_energies)

        if len(mfh_models) > 1:
            reference = np.array(energies[0])
            fig, ax = plt.subplots()
            for number, model_energies in enumerate(energies, start=1):
                ax.plot(u_t_ratios, 1000 * (np.array(model_energies) - reference),
                        label='#' + str(number), marker='o', linestyle='-')
            ax.set_xlabel("U/t")
            ax.set_ylabel("Energies wrt #1 [meV]")
            ax.legend()
            plt.show()
                    
def clr_pressed(c=None):
    """Reset the notebook: forget prepared models and restore the default inputs.

    Args:
        c: the clicked button (unused).
    """
    global mfh_models
    # Empty spin-ID fields (each assignment notifies its observers, in this order).
    for spin_field in (spin_u, spin_d):
        spin_field.value = ''
    mfh_models = None
    # Calculation parameters back to their initial widget values.
    Um.value, UM.value = 1, 1
    t1.value = '2.7'
    charge.value, multiplicity.value = 0, 1
    # Wipe both output areas.
    for area in (outs, outs1):
        with area:
            clear_output()

        
# Action buttons: queue another configuration, run all queued ones, reset everything.
tb = ipw.Button(button_style='info', description="prepare additional TB", disabled=False)
rtb = ipw.Button(button_style='info', description="run all TB", disabled=False)
clr = ipw.Button(button_style='info', description="clear all", disabled=False)

tb.on_click(tb_pressed)
rtb.on_click(rtb_pressed)
clr.on_click(clr_pressed)

display(ipw.VBox([
    ipw.HBox([t1, charge, multiplicity]),
    ipw.HBox([Um, UM]),
    tb,
    rtb,
    clr,
]))

In [None]:
# Orbital viewer inputs: spin channel, 1-based calculation number and
# 1-based orbital number.
spinw = ipw.RadioButtons(options=['U', 'D'], value='U', description='spin:', disabled=False)
calc = ipw.IntText(value=1, description='calc #', disabled=False)
index = ipw.IntText(value=1, description='orbital #', disabled=False)
def on_show_click(c=None):
    """Print the energy of one orbital of one prepared calculation and plot it.

    ``calc`` and ``index`` are 1-based (calculation number as listed by
    "prepare additional TB", orbital number). ``spinw`` selects spin 'U'
    (channel 0) or 'D' (channel 1). Invalid selections are reported in
    ``outs1`` instead of raising, or instead of silently selecting the last
    entry through negative indexing when 0 is entered.

    Args:
        c: the clicked button (unused).
    """
    spin = 1 if spinw.value == 'D' else 0
    with outs1:
        clear_output()
        if not mfh_models:
            print("Prepare and run at least one TB calculation first")
            return
        if not 1 <= calc.value <= len(mfh_models):
            print("calc # must be between 1 and %d" % len(mfh_models))
            return
        model = mfh_models[calc.value - 1]
        evals = getattr(model, 'evals', None)
        if evals is None:
            # NOTE(review): assumes `evals` is missing/None until run_mfh ran — confirm.
            print("Run the TB calculations first")
            return
        n_orbitals = len(evals[spin])
        if not 1 <= index.value <= n_orbitals:
            print("orbital # must be between 1 and %d" % n_orbitals)
            return
        print("Calc #%d orbital index: %d, spin %s" % (calc.value, index.value, spinw.value))
        print("Energy: %.6f eV" % evals[spin][index.value - 1])
        model.plot_orbital(mo_index=index.value - 1, spin=spin)
        # The matching eigenvector is model.evecs[spin][index.value - 1]; its
        # elements follow, in order, the atoms of geom / model.ase_geom.


show = ipw.Button(description="show orbital", button_style='info', disabled=False)
show.on_click(on_show_click)
display(ipw.VBox([ipw.HBox([calc, spinw, index, show]), outs1]))