In [None]:
import numpy as np
import os, glob
import ntpath
import seekpath
import pandas as pd

from pathlib import Path
from ase.io import read
from natsort import natsorted
import matplotlib.pyplot as plt

import sumo
from sumo.plotting.bs_plotter import SBSPlotter
from sumo.plotting.dos_plotter import SDOSPlotter
from sumo.electronic_structure.effective_mass import get_fitting_data, fit_effective_mass

from pymatgen.core import Structure
from pymatgen.electronic_structure import core
from pymatgen.electronic_structure.core import Spin, Orbital
from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter
from pymatgen.electronic_structure.bandstructure import BandStructure, Kpoint, BandStructureSymmLine

In [None]:
# Change here: the directories and the structure-specific file names.
# NOTE(review): these are absolute paths on one machine; adapt them before running.
bs_dir = "/Users/beatrizmourino/Documents/epfl/ta/ch-359/w9/calc/bs/exp"
cif_dir = "/Users/beatrizmourino/Documents/epfl/ta/ch-359/w9/structures"
res_dir = "/Users/beatrizmourino/Documents/epfl/ta/ch-359/w9/results/exp"

# Natural (numeric-aware) ascending sort, so that the k-points and band data
# are read in the correct order.
all_csv = natsorted(glob.glob(os.path.join(bs_dir, "*.csv")))

bs_file = f"{bs_dir}/UTSA80_bg.bs"
out = f"{bs_dir}/UTSA80bs.out"

cif = f"{cif_dir}/UTSA80_CO.cif"  # use the seekpath structure

In [None]:
def get_homo(out):
    """Read the number of occupied orbitals from a calculation output file.

    The file is scanned for the first line containing the text
    'Number of occupied orbitals'; the last whitespace-separated token of
    that line is parsed as an integer and printed.

    Args:
        out: Path of the text output file to scan.

    Returns:
        int: The parsed orbital count plus 2.
        NOTE(review): the +2 offset is presumably meant to turn the count into
        a column index of the band CSV files - confirm against their layout.

    Raises:
        ValueError: If no line reports the number of occupied orbitals.
            Previously the function returned None in that case, which NumPy
            silently accepts as an index (it adds a new axis).
    """
    with open(out, 'r') as fh:
        # Iterate lazily; only the first matching line is needed.
        for line in fh:
            if 'Number of occupied orbitals' in line:
                homo = int(line.split()[-1])
                print(f"number of occupied orbitals {homo}")
                return homo + 2
    raise ValueError(f"'Number of occupied orbitals' not found in {out!r}")

def get_pmg_bands(kpoints, bands, pymatgen_structure, labels, vbm):
    """Build a ``BandStructureSymmLine`` from raw k-points and band energies.

    Args:
        kpoints: K-point collection; passed on as ``kpoints[0:]``.
        bands: Array of band energies; its transpose is stored as the
            ``Spin.up`` eigenvalues.
        pymatgen_structure: Structure whose reciprocal lattice is used and
            which is attached to the band structure.
        labels: ``(index, label)`` pairs forwarded to ``make_labels_dict``.
        vbm: Reference energy; the Fermi level is set to ``vbm + 1e-9``.

    Returns:
        The constructed ``BandStructureSymmLine`` object.
    """
    # Place the Fermi level a tiny amount above the supplied vbm value.
    fermi_level = vbm + 1e-9
    spin_up_eigenvalues = {Spin.up: bands.T}

    return BandStructureSymmLine(
        kpoints=kpoints[0:],
        eigenvals=spin_up_eigenvalues,
        lattice=pymatgen_structure.lattice.reciprocal_lattice,
        efermi=fermi_level,
        labels_dict=make_labels_dict(kpoints, labels),
        structure=pymatgen_structure,
    )

def plot_bandstructure(bs_pmg, name, outdir=res_dir):
    """Plot a band structure with sumo and save it as '<name>_bands.tiff'.

    Args:
        bs_pmg: Band-structure object handed to ``SBSPlotter``.
        name: Used as the figure title and as the prefix of the file name.
        outdir: Directory the TIFF is written to. The default is the value
            ``res_dir`` had when this function was defined.

    Returns:
        None. The figure is written to disk and all figures are closed.
    """
    bsplotter = SBSPlotter(bs_pmg)
    # Named band_plot rather than plt so it does not shadow the notebook-wide
    # matplotlib.pyplot alias inside this function.
    band_plot = bsplotter.get_plot()
    # The title has to be set before saving; it used to be set after
    # savefig() and just before close(), so it never reached the file.
    band_plot.title(name)
    band_plot.savefig(os.path.join(outdir, '{}_bands.tiff'.format(name)), bbox_inches='tight')
    band_plot.close('all')

def make_labels_dict(kpoints, labels):
    """Map high-symmetry labels to the k-points they refer to.

    Args:
        kpoints: Indexable collection of k-points.
        labels: Iterable of ``(index, label)`` pairs, where ``index`` selects
            an entry of ``kpoints``.

    Returns:
        dict: ``label -> kpoints[index]``. The label 'GAMMA' is stored under
        the LaTeX-style key (a backslash followed by 'Gamma'). If a label
        occurs more than once, the last occurrence wins.
    """
    label_dict = {}

    for index, label in labels:
        # Raw string: the plain literal contained an invalid escape sequence,
        # which newer Python versions warn about. The key value is unchanged.
        key = r'\Gamma' if label == 'GAMMA' else label
        label_dict[key] = kpoints[index]

    return label_dict

def path_leaf(path):
    """Return the last component of ``path``, ignoring a trailing separator.

    Splitting is done with ``ntpath``, so both '/' and backslash separators
    are recognised.
    """
    parent, leaf = ntpath.split(path)
    if leaf:
        return leaf
    # The path ended with a separator: take the last component of the parent.
    return ntpath.basename(parent)

def get_label(out):
    """Collect ``(k-point index, label)`` pairs from a calculation output file.

    Two kinds of lines are used:

    * 'KPOINTS| Number of k-points in set' - the last token minus 1 becomes
      the spacing ``npoints`` between consecutive labelled points.
    * 'KPOINTS| Special' - the fourth token from the end is the label.

    Index assignment, in order of precedence:

    1. first label -> index 0;
    2. second label -> first index + ``npoints``;
    3. a label equal to the previously stored one is skipped (from the third
       entry on);
    4. third label -> previous index + 1;
    5. later labels -> previous index + 1 when the last decimal digit of the
       previous index equals that of the index three entries back, otherwise
       previous index + ``npoints``.

    NOTE(review): rule 5 detects a "completed set" through the last digit of
    the indices only - confirm that this holds for the spacing in use.

    Args:
        out: Path of the text output file.

    Returns:
        list[tuple[int, str]]: The collected pairs (empty if none were found).
    """
    labels = []
    with open(out, 'r') as fh:
        for line in fh:
            if 'KPOINTS| Number of k-points in set' in line:
                # Number of steps between two consecutive special points.
                npoints = int(line.split()[-1]) - 1
            if 'KPOINTS| Special' not in line:
                continue

            label = str(line.split()[-4])
            found = len(labels)

            if found == 0:
                index = 0
            elif found == 1:
                index = labels[0][0] + npoints
            elif label == labels[-1][-1]:
                # Repeated label: do not store duplicates.
                continue
            elif found == 2:
                index = labels[-1][0] + 1
            elif int(repr(labels[-1][0])[-1]) == int(repr(labels[-3][0])[-1]):
                # A set is complete: the next point is just one step further.
                index = labels[-1][0] + 1
            else:
                index = labels[-1][0] + npoints

            labels.append((index, label))
    return labels

def call_sumo(bs, carrier='hole'):
    """Fit effective masses at a band edge with sumo and return the lightest.

    Arguments:
        bs: Pymatgen BandStructureSymmLine object
        carrier: 'hole' to fit at the valence band maximum (``bs.get_vbm()``)
            or 'electron' to fit at the conduction band minimum
            (``bs.get_cbm()``).
    Returns:
        (float) mass of lightest band of that type, i.e. the smallest
        absolute value among the fitted masses
    Raises:
        ValueError: if ``carrier`` is neither 'hole' nor 'electron'
            (this used to surface later as an unbound-variable error).
    """
    if carrier == 'hole':
        extreme = bs.get_vbm()
    elif carrier == 'electron':
        extreme = bs.get_cbm()
    else:
        raise ValueError(f"carrier must be 'hole' or 'electron', got {carrier!r}")

    masses = []
    for spin in [Spin.up, Spin.down]:
        for b_ind in extreme['band_index'][spin]:
            # Only the first k-point index of the extremum and the first
            # fitting segment returned by sumo are used.
            # NOTE(review): the 5 is presumably the number of sample points
            # for the fit - confirm against sumo's get_fitting_data signature.
            fit_data = get_fitting_data(bs, spin,
                b_ind, extreme['kpoint_index'][0], 5)[0]
            masses.append(fit_effective_mass(fit_data['distances'],
                fit_data['energies']))

    return min(abs(mass) for mass in masses)
    
def get_bs_feats(pmg_bands):
    """Summarise a band structure as a dictionary of descriptors.

    Queries the band-structure object for its direct gap, its band-gap
    dictionary ('direct', 'energy', 'transition') and the energies of the
    valence band maximum and conduction band minimum, prints the latter two,
    and adds the hole and electron effective masses from ``call_sumo``.

    Args:
        pmg_bands: Band-structure object providing ``get_direct_band_gap``,
            ``get_band_gap``, ``get_vbm`` and ``get_cbm``.

    Returns:
        dict: The descriptors keyed by name.
    """
    feats = {}

    feats['direct_band_gap_energy'] = pmg_bands.get_direct_band_gap()

    gap = pmg_bands.get_band_gap()
    feats['direct_band_gap'] = gap['direct']
    feats['band_gap_energy'] = gap['energy']
    feats['band_gap_transition'] = gap['transition']

    vbm = pmg_bands.get_vbm()['energy']
    cbm = pmg_bands.get_cbm()['energy']
    print(f'vbm: {vbm}, cbm: {cbm}')
    feats['vbm'] = vbm
    feats['cbm'] = cbm

    # Effective masses are fitted last, after the band edges were reported.
    feats['effective_mass_hole'] = call_sumo(pmg_bands, 'hole')
    feats['effective_mass_electron'] = call_sumo(pmg_bands, 'electron')

    return feats

In [None]:
# get_homo(out) parses the output file and returns the occupied-orbital count plus 2.
homo = get_homo(out)
# Exercise: assign `lumo` before running the next cell, which uses it as a
# column index into each loaded band CSV table.
lumo = # TODO (exercise): complete here

In [None]:
# Scan every band CSV for the highest value in the `homo` column and the
# lowest value in the `lumo` column.
if not all_csv:
    # glob returns an empty list when bs_dir is wrong or contains no CSV
    # files; fail here with a clear message instead of max()'s
    # "arg is an empty sequence".
    raise FileNotFoundError(f"no *.csv band files found in {bs_dir}")

homo_max = []
lumo_min = []

for csv_path in all_csv:
    band_table = np.loadtxt(csv_path, dtype=float)
    homo_max.append(max(band_table[:, homo]))
    lumo_min.append(min(band_table[:, lumo]))

vbm = max(homo_max)  # this will be used to shift the plot; vbm will be set as 0
cbm = min(lumo_min)

print('vbm', vbm)
print('cbm', cbm)

In [None]:
# Load the structure from the CIF file configured above.
structure = Structure.from_file(cif)
# Exercise: `name` is later passed to plot_bandstructure, where it becomes the
# figure title and the prefix of the saved file name.
name = # TODO (exercise): how can you get the name of the structure? Hint: you can use the split() method.
print(name)

In [None]:
# Gather the k-points (columns 0-2) and band energies (columns 4 onward) of
# all CSV files, skipping rows whose k-point repeats the previously stored one.
bands_file = []
kps = []
bds = []

for csv_file in all_csv:
    table = np.loadtxt(csv_file, dtype=float)
    bands_file.append(table)
    for row in table:
        # A file may start with the k-point the previous one ended on.
        if kps and list(row[0:3]) == list(kps[-1]):
            continue
        kps.append(row[0:3])
        bds.append(np.array(row[4:]))

kpoints = np.stack(kps, axis=0)
bands = np.stack(bds, axis=0)
print(len(kpoints), len(bands))

In [None]:
# Exercise: build the list of (k-point index, label) pairs.
labels = # TODO (exercise): fill here to get the labels; hint: look at the function definitions above
# Only the last expression of a cell is displayed, so this length is not shown
# unless it is printed.
len(labels)
labels

In [None]:
# Exercise: get_pmg_bands is defined as
# get_pmg_bands(kpoints, bands, pymatgen_structure, labels, vbm).
pmg_bands = get_pmg_bands() # TODO (exercise): fill in the arguments

In [None]:
pmg_bands.get_cbm() # you can also try get_vbm(); explore the other features of the pmg bands object

In [None]:
plot_bandstructure(pmg_bands, name) # stop here after running this line and think about what you expect for the effective masses

In [None]:
# Compute the band-structure descriptors and store them as a two-column
# (key, value) table in the results directory.
results = get_bs_feats(pmg_bands)
df_results = pd.DataFrame(list(results.items()))
df_results.to_csv(f"{res_dir}/bands_data", index=False)
print(df_results)