In [None]:
%load_ext aiida
%aiida

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ase.visualize.plot import plot_atoms
import time
import os
import copy

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

import ase
from ase import Atoms
from ase.io import read,write
from ase.visualize import view

from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)

import matplotlib as mpl
# Global figure defaults: high-resolution rendering and thicker axes frames.
mpl.rcParams.update({'figure.dpi': 200, 'axes.linewidth': 1.5})


In [None]:
%%javascript
// Override the output area's _should_scroll hook so that it always returns
// false, i.e. long cell outputs are never collapsed into a scrollable box.
// Relies on the classic Notebook frontend's global `IPython` object.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:

# READ CP2K vibrational analysis data from Molden format

def _parse_molden_vibrations(lines, bohr_to_angstrom=0.52917721067):
    """Parse the vibrational sections of Molden-format text.

    Parameters
    ----------
    lines : iterable of bytes or str
        Lines of the Molden file (bytes are decoded with the default codec).
    bohr_to_angstrom : float
        Factor applied to the ``[FR-COORD]`` coordinates.

    Returns
    -------
    tuple
        ``(freq, eq_coords, vibr_displacements, inten)`` as described in
        ``read_cp2k_molden``.
    """
    freq = []
    eq_coords = []            # [element, x, y, z]
    vibr_displacements = []   # [vibration_nr][row][component]
    inten = []

    section = ''
    for raw in lines:
        line = raw.decode() if isinstance(raw, bytes) else raw
        line = line.strip()
        # Blank lines carry no data (and line[0] would raise IndexError on them).
        if not line:
            continue
        # A "[...]" line starts a new section; its name selects the parser below.
        if line.startswith('['):
            section = line.strip('[]').lower()
            continue

        if section == 'freq':
            freq.append(float(line))
        elif section == 'fr-coord':
            el, x, y, z = line.split()
            eq_coords.append([el] + [float(v)*bohr_to_angstrom for v in (x, y, z)])
        elif section == 'fr-norm-coord':
            # "vibration N" opens a new mode; following lines are its displacement rows.
            if line.startswith('vibration'):
                vibr_displacements.append([])
            else:
                vibr_displacements[-1].append([float(v) for v in line.split()])
        elif section == 'int':
            inten.append(float(line))

    return freq, eq_coords, vibr_displacements, inten


def read_cp2k_molden(pk):
    """Read CP2K vibrational-analysis data from the Molden file of an AiiDA node.

    Parameters
    ----------
    pk : int
        Primary key of a node whose ``outputs.retrieved`` folder contains
        ``aiida-VIBRATIONS-1.mol``.

    Returns
    -------
    freq : list of float
        Values of the ``[FREQ]`` section.
    eq_coords : list of [str, float, float, float]
        Element symbol and x, y, z of the ``[FR-COORD]`` section, multiplied by
        the bohr-to-angstrom factor.
    vibr_displacements : list of list of list of float
        ``[vibration][row][component]`` values of the ``[FR-NORM-COORD]``
        section (presumably one row per atom — confirm against CP2K output).
    inten : list of float
        Values of the ``[INT]`` section.
    """
    node = load_node(pk)
    with node.outputs.retrieved.open('aiida-VIBRATIONS-1.mol', mode='rb') as fhandle:
        data = fhandle.readlines()
    return _parse_molden_vibrations(data)
    


# READ QE spectrum from dynmat.x output

def read_dynmat_output(filename):
    """Extract frequencies and intensities from a Quantum ESPRESSO dynmat.x output.

    A line is taken as a data row when it has exactly four whitespace-separated
    fields and the first one is a non-negative integer index. The second field
    is read as the frequency and the fourth as the intensity.

    Returns
    -------
    tuple of (list of float, list of float)
        ``(freq, inten)`` in file order.
    """
    with open(filename) as handle:
        rows = [text.split() for text in handle.readlines()]

    freq, inten = [], []
    for fields in rows:
        if len(fields) != 4 or not fields[0].isdigit():
            continue
        freq.append(float(fields[1]))
        inten.append(float(fields[3]))
    return freq, inten



# Spectra keyed by label: [0 - label, 1 - file, 2 - freq, 3 - intensity]
datasets = dict()
# Vibrational data keyed by label: [0 - freq, 1 - eq_coords, 2 - vibr_displacements, 3 - intensity]
vibr_disp_datasets = dict()

In [None]:
# Line-shape functions used to broaden the stick spectrum

# broadening
# Default full width at half maximum of the broadening profiles, in the same
# units as the frequencies being broadened.
fwhm = 5.0

# Exact Gaussian FWHM / sigma ratio, 2*sqrt(2*ln 2) ~= 2.35482
# (the previous hard-coded 2.3548 was slightly off).
_FWHM_PER_SIGMA = 2.0*np.sqrt(2.0*np.log(2.0))


def lorentzian(x, width=None):
    """Area-normalised Lorentzian centred at 0.

    Parameters
    ----------
    x : float or numpy.ndarray
        Offset(s) from the line centre.
    width : float, optional
        FWHM of the profile; defaults to the module-level ``fwhm`` (read at
        call time, so re-assigning ``fwhm`` still takes effect).
    """
    half_width = 0.5*(fwhm if width is None else width)
    return half_width/(np.pi*(x**2 + half_width**2))


def gaussian(x, width=None):
    """Area-normalised Gaussian centred at 0 with the given FWHM.

    Parameters
    ----------
    x : float or numpy.ndarray
        Offset(s) from the line centre.
    width : float, optional
        FWHM of the profile; defaults to the module-level ``fwhm``.
    """
    sigma = (fwhm if width is None else width)/_FWHM_PER_SIGMA
    return np.exp(-x**2/(2*sigma**2))/(sigma*np.sqrt(2*np.pi))


def shape(x, width=None):
    """Line profile used to broaden the spectra (currently the Lorentzian)."""
    return lorentzian(x, width)


In [None]:
def lighten_color(color, amount=1.3):
    """
    Rescale a colour's lightness by multiplying (1 - lightness) by ``amount``.

    ``amount < 1`` lightens the colour, ``amount > 1`` darkens it (so the
    default 1.3 darkens) and ``amount == 1`` leaves it unchanged.

    Input can be anything matplotlib understands as a colour (named colour,
    single-letter code, '#RRGGBB' hex string, 'C0', RGB tuple, ...) or a bare
    hex string without the leading '#'.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)

    Returns
    -------
    numpy.ndarray
        RGB triple with components clipped to [0, 1].
    """
    import matplotlib.colors as mc
    import colorsys
    try:
        rgb = mc.to_rgb(color)
    except ValueError:
        # Accept bare hex strings such as 'F034A3' (the old parser did).
        if not isinstance(color, str):
            raise
        rgb = mc.to_rgb('#' + color.lstrip('#'))
    hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
    new_c = np.array(colorsys.hls_to_rgb(hue, 1 - amount*(1 - lightness), saturation))
    return np.clip(new_c, 0.0, 1.0)

In [None]:
def _draw_projected_atom(ax, pos, dis, col, rad, i, j):
    """Draw one black-rimmed atom disc and its red displacement arrow in the (i, j) plane."""
    p = (pos[i], pos[j])
    d = (dis[i], dis[j])
    ax.add_artist(plt.Circle(p, rad + 0.05, color='black'))
    ax.add_artist(plt.Circle(p, rad, color=col))
    ax.arrow(p[0], p[1], d[0], d[1], head_width=0.5, head_length=0.1, fc='r', ec='r')


def vis_mode(mode, enhance=1.0):
    """Draw the displacement pattern of one vibrational mode in two projections.

    The top panel shows the x-y projection and the bottom panel the x-z
    projection. Each atom is drawn as a disc (carbon large and grey, every
    other element small and white) with an arrow along its displacement.

    Uses the module-level ``chem_symbols``, ``eq_coords`` and
    ``vibr_displacements`` (set when a pk is loaded).

    Parameters
    ----------
    mode : int
        Index into ``vibr_displacements``.
    enhance : float, optional
        Factor applied to the displacement vectors before drawing the arrows.
    """
    ase_atoms = ase.Atoms(chem_symbols, eq_coords)
    plt.figure(figsize=(16, 5))
    ax_xy = plt.subplot(2, 1, 1)
    ax_xz = plt.subplot(2, 1, 2)
    # (axes, coordinate index plotted vertically, padding around that coordinate)
    panels = ((ax_xy, 1, 1.0), (ax_xz, 2, 4.0))

    for i_at, displ in enumerate(vibr_displacements[mode]):
        pos = ase_atoms[i_at].position
        dis = enhance*np.array(displ)
        if ase_atoms[i_at].symbol == 'C':
            col, rad = 'gray', 0.4
        else:
            col, rad = 'white', 0.2
        for ax, j, _ in panels:
            _draw_projected_atom(ax, pos, dis, col, rad, 0, j)

    all_pos = ase_atoms.positions
    for ax, j, pad in panels:
        ax.axis('off')
        ax.set_xlim([np.min(all_pos[:, 0]) - 1.0, np.max(all_pos[:, 0]) + 1.0])
        ax.set_ylim([np.min(all_pos[:, j]) - pad, np.max(all_pos[:, j]) + pad])
        ax.set_aspect('equal')

    plt.show()

In [None]:
def _broadened_spectrum(freq, inten, xmin, xmax, step=0.01):
    """Return ``(x, spectrum)``: ``shape`` profiles weighted by squared intensities on [xmin, xmax)."""
    x = np.arange(xmin, xmax, step)
    weights = np.asarray(inten, dtype=float)**2
    spectrum = np.zeros(len(x))
    for f, w in zip(freq, weights):
        spectrum += w*shape(x - f)
    return x, spectrum


def _plot_ir_overview(x, spectrum, label="c2h2", xlim=(640, 1060), label_x=680):
    """Plot the spectrum scaled to a unit maximum and flipped as 1 - I/I_max."""
    peak = np.max(spectrum)
    # An all-zero spectrum (no modes in the window) would otherwise give 0/0 -> NaN.
    normalized = spectrum/peak if peak > 0 else np.zeros_like(spectrum)
    flipped = 1.0 - normalized
    top = max(float(np.max(flipped)), 0.0)

    plt.figure(figsize=(2, 2))
    plt.plot(x, flipped, label=label, color='k', linestyle='-', linewidth=1, zorder=10)
    plt.text(label_x, top*1.05 - 0.25, label, va='center', ha='center', fontdict={'size': 8})
    plt.xlim(list(xlim))
    plt.ylim([-0.1, 0.1 + 1.05*top])
    plt.xlabel('Energy [cm$^{-1}$]')
    plt.ylabel('Relative intensity')

    ax = plt.gca()
    ax.axes.yaxis.set_ticks([])
    ax.xaxis.set_minor_locator(AutoMinorLocator(5))
    ax.tick_params(axis="x", which='both', direction='in')
    plt.show()


def _plot_ir_peak_zoom(x, spectrum, freq, inten, rel_threshold=0.10, group_gap=50.0):
    """Plot the spectrum zoomed onto its strong modes, each marked and labelled by index.

    A mode is marked when its frequency lies strictly inside (min(x), max(x))
    and its intensity exceeds ``rel_threshold`` times the largest intensity.
    Labels of a mode within ``group_gap`` of the previously marked mode are
    stacked further down so that they do not overlap.
    """
    plt.figure(figsize=(4, 2))

    max_intensity = np.max(inten) if len(inten) else 0.0
    spec_max = np.max(spectrum)
    x_lo, x_hi = np.min(x), np.max(x)
    shown = []
    y_shift = 0
    f_last = 0.0

    for i, (f, intensity) in enumerate(zip(freq, inten)):
        if not (x_lo < f < x_hi and intensity > rel_threshold*max_intensity):
            continue
        plt.axvline(f, linestyle='--', color='b')
        if f - f_last > group_gap:
            y_shift = 0
        f_last = f
        # Step labels down by 1/20 of the peak height, wrapping back to the top.
        y_pos = (spec_max - 1e-6 - y_shift*spec_max/20) % spec_max if spec_max > 0 else 0.0
        y_shift += 1
        t = plt.text(f, y_pos, i)
        t.set_bbox(dict(facecolor='white', edgecolor='black'))
        shown.append(f)

    plt.plot(x, spectrum, 'r', lw=2.0)
    # Without any marked mode the old 1e6/0 sentinels gave an inverted, huge x-range.
    if shown:
        plt.xlim(min(shown) - 1, max(shown) + 1)
    plt.show()


def load_pk(b):
    """Button callback: load the CP2K vibrational data of ``pk_select.value`` and plot it.

    On success, sets the module-level ``datasets``, ``vibr_disp_datasets``,
    ``chem_symbols``, ``eq_coords`` (float32 array) and ``vibr_displacements``
    (array), which the mode-visualisation cells read. It then draws into
    ``ir_out`` a flipped, normalised overview of the broadened spectrum and a
    zoom onto its strongest modes, labelled by mode index.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button (unused; required by ``on_click``).
    """
    global datasets, vibr_disp_datasets, chem_symbols, eq_coords, vibr_displacements

    # Window (in the frequency units of the data) over which the spectrum is broadened.
    xmin, xmax = 600, 1650

    with ir_out:
        clear_output()
        try:
            freq, coords, displacements, inten = read_cp2k_molden(pk_select.value)
        except Exception as exc:
            # load_node / file access raise NotExistent, OSError, ... (never the
            # AssertionError caught before). Report here, inside the Output
            # widget, because prints from a widget callback outside an Output
            # widget may not appear in the notebook.
            print("Incorrect pk.")
            print("%s: %s" % (type(exc).__name__, exc))
            return

        datasets = {'IR': ['label', 'file', freq, inten]}
        vibr_disp_datasets = {'IR': [freq, coords, displacements, inten]}
        chem_symbols = np.array(coords)[:, 0]
        eq_coords = np.array(coords)[:, 1:].astype(np.float32)
        vibr_displacements = np.array(displacements)

        x, spectrum = _broadened_spectrum(freq, inten, xmin, xmax)
        _plot_ir_overview(x, spectrum)
        _plot_ir_peak_zoom(x, spectrum, freq, inten)
    
    
# Widget row for choosing a calculation by its AiiDA pk; load_pk draws its
# figures into the ir_out output area.
style = dict(description_width='50px')
layout = dict(width='70%')

pk_select = ipw.IntText(description='pk', value=0, layout=layout, style=style)
load_pk_btn = ipw.Button(description='Load pk', layout=layout, style=style)
load_pk_btn.on_click(load_pk)

ir_out = ipw.Output()
display(pk_select, load_pk_btn, ir_out)

In [None]:
def on_animate_click(c):
    """Button callback: draw and animate the vibrational mode in ``mode_to_display``.

    Shows the static two-panel picture from ``vis_mode``. It then shows an NGL
    trajectory of one oscillation period, with positions
    ``eq_coords + enhance.value * sin(phase) * vibr_displacements[mode]``.
    Uses the module-level data set by ``load_pk``.

    Parameters
    ----------
    c : ipywidgets.Button
        The clicked button (unused; required by ``on_click``).
    """
    with animation_out:
        clear_output()
        if 'vibr_displacements' not in globals() or len(vibr_displacements) == 0:
            print("No vibrational data loaded yet: press 'Load pk' first.")
            return
        mode = mode_to_display.value
        n_modes = len(vibr_displacements)
        # Reject negative indices too: they would silently wrap to the last modes.
        if not 0 <= mode < n_modes:
            print("Mode %d is out of range: valid modes are 0 to %d." % (mode, n_modes - 1))
            return

        print("Mode %d" % mode)
        vis_mode(mode)

        amplitude = enhance.value
        # One period; endpoint=False avoids repeating the phase-0 frame at 2*pi,
        # which made the looping animation stall. (Named `phase`, not `time`,
        # to avoid shadowing the imported time module.)
        phases = np.linspace(0.0, 2*np.pi, 20, endpoint=False)
        trajectory = [
            ase.Atoms(chem_symbols, eq_coords + amplitude*np.sin(phase)*vibr_displacements[mode])
            for phase in phases
        ]
        nglv = view(trajectory, viewer='ngl')
        display(nglv)


# Animation controls: mode index, displacement amplitude, a trigger button and
# the output area that on_animate_click draws into.
mode_to_display = ipw.IntText(value=1, description="Mode")
enhance = ipw.FloatText(value=2, description='Amplitude')
animation_out = ipw.Output()

animate = ipw.Button(description='Animate')
animate.on_click(on_animate_click)

display(ipw.VBox(children=[mode_to_display, enhance, animate, animation_out]))