In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import time
import os
from scipy.interpolate import interp1d, griddata
from sklearn.decomposition import PCA
from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import Element

import matplotlib.pyplot as plt
from plot_imports import *
from utils import get_species, load_data, build_data, standardXAS, clusterXAS

# Fix the seed of NumPy's legacy global random generator so that any step in this
# notebook that draws from np.random gives the same result on every Restart & Run All.
seed = 14
np.random.seed(seed)

In [None]:
# Hex colors indexed by class label; reused for scatter points, mean spectra and shaded regions.
cols = ['#6A71B4', '#F7B744']
cmap = colors.ListedColormap(cols)

# One transparent colormap per class color, used when drawing the gaussian blobs.
tmaps = [make_transparent_cmap(hex_color, 0, 0.6) for hex_color in cols]

def gaussian(x, y, cx, cy, r):
    """Evaluate an isotropic, unnormalized 2-D Gaussian bump.

    Parameters
    ----------
    x, y : scalar or array-like
        Coordinates at which to evaluate the bump (broadcast against each other).
    cx, cy : scalar
        Coordinates of the bump's center.
    r : scalar
        Width (standard deviation) of the bump.

    Returns
    -------
    Scalar or ndarray with values in (0, 1]; equals 1 exactly at (cx, cy).
    """
    squared_distance = (x - cx)**2 + (y - cy)**2
    return np.exp(-squared_distance / (2*r*r))

def smooth_data(x, window_radius):
    """Smooth a 1-D signal with a normalized Hann window.

    The signal is mirror-padded at both ends (the edge sample itself is not repeated)
    before convolution, so the result has exactly as many samples as the input.

    Parameters
    ----------
    x : 1-D array-like
        Signal to smooth. Must contain at least one sample.
    window_radius : int
        Half-width of the window; the full window has ``2*window_radius + 1`` samples.
        A radius of 0 returns the signal unchanged.

    Returns
    -------
    ndarray
        Smoothed signal with the same length as ``x``.

    Raises
    ------
    ValueError
        If ``window_radius`` is negative or ``x`` is empty.
    """
    if window_radius < 0:
        raise ValueError('window_radius must be non-negative')

    x = np.asarray(x)
    window_len = 2*window_radius + 1
    w = np.hanning(window_len)

    # 'reflect' padding matches slicing x[window_len-1:0:-1] / x[-2:-window_len-1:-1] when
    # the signal is at least one window long, and keeps reflecting when it is shorter, so
    # the padded signal always gains window_len-1 samples per side and the output length
    # below is always len(x).
    s = np.pad(x, window_len - 1, mode='reflect')
    y = np.convolve(s, w/w.sum(), mode='valid')

    # 'valid' convolution leaves window_radius extra samples on each side; trim them.
    return y[window_radius:x.shape[-1] + window_radius]

In [None]:
# Path of the merged dataset, relative to the notebook's working directory.
data_path = 'data/data_merged.csv'

# Make sure the output folder for saved figures exists. exist_ok=True replaces the
# separate existence check, so there is no check-then-create race and the cell can be
# re-run freely.
os.makedirs('images', exist_ok=True)

In [None]:
# load data
num_classes = 2
data, species = load_data(data_path, num_classes)

# the 'tag' column flags samples that are set aside from the train/valid/test pool
tag = data['tag']

# exclude select samples: rows tagged 'INC'
data_inc = data.loc[tag == 'INC'].reset_index(drop=True)
print('number of inconclusive samples:', len(data_inc))
get_species(data_inc)

# ... and rows tagged 'WEYL'
data_weyl = data.loc[tag == 'WEYL'].reset_index(drop=True)
print('number of weyl semimetal samples:', len(data_weyl))
get_species(data_weyl)

# everything without a tag is kept for training, validation and testing
data = data.loc[tag.isna()].reset_index(drop=True)
print('number of train/valid/test samples:', len(data))

species_data = get_species(data)

n = 200 # number of energy bins
nc = 20 # maximum number of principal components
evar = 0.8 # minimum cumulative explained variance

In [None]:
# load computed train/valid/test split
def _read_indices(path):
    """Read a text file holding one integer sample index per line.

    Blank or whitespace-only lines (for example a trailing empty line at the end of the
    file) are skipped rather than raising ValueError from int('').

    Parameters
    ----------
    path : str
        Path of the index file.

    Returns
    -------
    list of int
        The indices in file order.
    """
    with open(path, 'r') as f:
        # int() ignores surrounding whitespace, so the newline needs no explicit stripping
        return [int(line) for line in f if line.strip()]

idx_train = _read_indices('data/idx_train.txt')
idx_valid = _read_indices('data/idx_valid.txt')
idx_test = _read_indices('data/idx_test.txt')

In [None]:
# one Element per atomic number Z = 1..118, in order of increasing Z
species_all = [Element.from_Z(atomic_number) for atomic_number in range(1, 119)]

In [None]:
# combine validation and test sets into a single test set
# NOTE(review): idx_test is overwritten with idx_valid + idx_test, so running this cell a
# second time without first re-running the cell that reads the split files appends the
# validation indices again. Re-run the split-loading cell before re-running this one.
idx_test = np.hstack([idx_valid, idx_test])

# name handed to build_data as its last argument
# (presumably the column of `data` that holds the spectra - confirm in utils.build_data)
col = 'spectra_abs'
x_train_tot, y_train_tot, e_train_tot, type_encoding = build_data(data, species_all, idx_train, col)
x_test_tot, y_test_tot, e_test_tot, _ = build_data(data, species_all, idx_test, col)
# flatten the label arrays to 1-D so they can be indexed with boolean masks later on
y_train_tot = y_train_tot.ravel()
y_test_tot = y_test_tot.ravel()

print('Training size:',len(x_train_tot))
print('Testing size:',len(x_test_tot))

# standardize data
# ne is the size of the last axis of x_train_tot; the plotting cells index that axis with
# an element's position in type_encoding.
ne = x_train_tot.shape[-1]
sdx = standardXAS(ne)
# The return values are discarded here.
# NOTE(review): this only has an effect if standardXAS modifies its inputs in place -
# confirm in utils.standardXAS.
sdx.fit_transform(x_train_tot, e_train_tot)
sdx.transform(x_test_tot, e_test_tot)

In [None]:
# get periodic table data: one (specie, row, column) record per element, where row and
# column are the 1-based grid position used when drawing the periodic table

# put La and Ac in rows 6 and 7 for display
display_rows = {'La': 6, 'Ac': 7}

elem_records = []
for atomic_number in range(1, 119):
    element = Element.from_Z(atomic_number)
    symbol = str(element)

    if symbol in display_rows:
        row, column = display_rows[symbol], element.group
    elif element.row > 7:
        # shift the group of row 8 and row 9 by -1 for display
        row, column = element.row, element.group - 1
    else:
        row, column = element.row, element.group

    elem_records.append({'specie': symbol, 'row': row, 'column': column})

elem_data = pd.DataFrame(elem_records, columns=['specie', 'row', 'column'])

In [None]:
# plot element-specific spectra by class
# One panel per periodic-table slot: for every element with more than one training
# sample, draw the per-class mean spectrum with a smoothed +/- 1 std band (train and test
# samples pooled); elements without enough data get a grayed-out frame.
nrow = 9
ncol = 18
smooth_radius = 11 # half-width of the window used to smooth the std band

prop.set_size(14)

fig, ax = plt.subplots(nrow, ncol, figsize=(2*ncol,2*nrow))
for i in range(nrow):
    for j in range(ncol):
        panel = ax[i,j]
        elem = elem_data.loc[(elem_data['row']==i+1) & (elem_data['column']==j+1), 'specie'].tolist()

        if not elem:
            # invalid element: nothing occupies this (row, column) slot
            panel.remove()
            continue

        # valid element
        elem = elem[0]
        k = type_encoding.index(elem)
        panel.set_xticks([]); panel.set_yticks([])

        # samples for which this element's entry in the element encoding equals 1
        mask_train = e_train_tot[:,k].ravel() == 1
        mask_test = e_test_tot[:,k].ravel() == 1
        ncn = int(np.sum(mask_train))

        if ncn <= 1:
            # element not in data
            for side in ('bottom', 'top', 'right', 'left'):
                panel.spines[side].set_color('gray')
            panel.text(0.1, 0.8, elem, color='gray', fontproperties=prop, transform=panel.transAxes)
            continue

        # element in data: pool train and test spectra and labels
        # (boolean-mask indexing already returns fresh arrays, so no extra copy is needed)
        x_data = np.vstack([x_train_tot[mask_train,:,k], x_test_tot[mask_test,:,k]])
        y_data = np.hstack([y_train_tot[mask_train], y_test_tot[mask_test]])

        # x axis taken from the spectra themselves rather than from a constant defined in
        # another cell, so the plot cannot go out of sync with the data
        bins = np.arange(x_data.shape[1])

        for o in range(num_classes):
            in_class = y_data == o
            if not in_class.any():
                continue
            xc = x_data[in_class,:]
            xc_mean = xc.mean(axis=0)
            xc_std = smooth_data(xc.std(axis=0), smooth_radius)
            panel.fill_between(bins, xc_mean + xc_std, xc_mean - xc_std, color=cols[o],
                               alpha=0.4, lw=0)
            panel.plot(bins, xc_mean, color=cols[o], lw=2, zorder=10000)

        panel.text(0.1, 0.8, elem, color='black', fontproperties=prop, transform=panel.transAxes,
                   zorder=500000)

fig.savefig('images/periodic_table_spectra.png', bbox_inches='tight', dpi=400)

In [None]:
# plot element-specific pca and k-means clustering
# One panel per periodic-table slot. For every element with more than one training sample:
#   1. reduce its spectra with PCA (fit on the training samples only),
#   2. cluster the reduced training samples with clusterXAS,
#   3. project everything to 2-D with a second PCA and draw the predicted class regions,
#      a gaussian blob per cluster, the samples, and the cluster centers.
# Elements without enough data get a grayed-out frame.
nrow = 9
ncol = 18

marker_size = 4
prop.set_size(14)

fig, ax = plt.subplots(nrow, ncol, figsize=(2*ncol,2*nrow))
for i in range(nrow):
    for j in range(ncol):
        panel = ax[i,j]
        elem = elem_data.loc[(elem_data['row']==i+1) & (elem_data['column']==j+1), 'specie'].tolist()

        if not elem:
            # invalid element: nothing occupies this (row, column) slot
            panel.remove()
            continue

        # valid element
        elem = elem[0]
        k = type_encoding.index(elem)
        panel.set_xticks([]); panel.set_yticks([])

        # samples for which this element's entry in the element encoding equals 1
        mask_train = e_train_tot[:,k].ravel() == 1
        mask_test = e_test_tot[:,k].ravel() == 1
        ncn = int(np.sum(mask_train))
        nct = int(np.sum(mask_test))

        if ncn <= 1:
            # element not in data
            for side in ('bottom', 'top', 'right', 'left'):
                panel.spines[side].set_color('gray')
            panel.text(0.1, 0.8, elem, color='gray', fontproperties=prop, transform=panel.transAxes)
            continue

        # element in data
        # (boolean-mask indexing already returns fresh arrays, so no extra copy is needed)
        x_train = x_train_tot[mask_train,:,k]
        y_train = y_train_tot[mask_train]
        x_test = x_test_tot[mask_test,:,k]
        y_test = y_test_tot[mask_test]
        print("Train count: {:d} Test count: {:d}".format(ncn, nct))

        # perform PCA fit on training data, and transform all data
        mnc = min(nc, ncn)
        pca = PCA(n_components=mnc, svd_solver='full')
        pca.fit(x_train)

        # number of components to achieve target variance; if the threshold is never
        # exceeded, keep every component of the initial fit
        cum_var = pca.explained_variance_ratio_.cumsum()
        reached = np.flatnonzero(cum_var > evar)
        if reached.size:
            mnc = int(reached[0]) + 1

        if mnc < 2: mnc = 2 # minimum of 2 components
        print("Retained {:.4f} explained variance with {:d} components.".format(
            cum_var[:mnc][-1],mnc))

        # refit with optimal components
        pca = PCA(n_components=mnc, svd_solver='full')
        pca.fit(x_train)
        p_train = pca.transform(x_train)

        # An element can have training spectra but none in the test split. Use an empty
        # (0, mnc) array in that case so the rest of the cell has a well-formed,
        # zero-sample test set instead of a scalar placeholder.
        p_test = pca.transform(x_test) if nct > 0 else np.empty((0, mnc))

        # perform k-means clustering
        clx = clusterXAS(n_clusters=2)
        p_filter, y_filter = clx.remove_outliers(p_train, y_train, frac=0.98)
        clx.fit(p_filter, y_filter)

        # pca for visualization
        pca_viz = PCA(n_components=2, svd_solver='full')
        pca_viz.fit(p_train)
        pv_train = pca_viz.transform(p_train)
        # scikit-learn rejects zero-sample input, so skip the transform for an empty test set
        pv_test = pca_viz.transform(p_test) if nct > 0 else np.empty((0, 2))

        # predict on a grid spanning +/- two standard deviations of the widest component
        lim = 2*np.max(np.std(p_filter, axis=0))
        gx, gy = np.meshgrid(np.linspace(-lim,lim,1000), np.linspace(-lim,lim,1000))
        extent = (gx.min(), gx.max(), gy.min(), gy.max())
        grid_points = pca_viz.inverse_transform(np.c_[gx.ravel(), gy.ravel()])
        z = clx.predict(grid_points).reshape(gx.shape)

        # predict on the cluster center locations
        ct = clx.predict(clx.kmeans.cluster_centers_).astype(int)

        # grid plot
        panel.imshow(z, interpolation='nearest', extent=extent, cmap=cmap,
                     alpha=0.15, aspect='auto', origin='lower', vmin=0, vmax=1)

        # gaussian plots
        # project the two cluster centers to the 2-D view and measure every training
        # point's squared distance to each of them
        centers = pca_viz.transform(clx.kmeans.cluster_centers_[:2])
        dists = np.square(pv_train[:,None,:] - centers[None,:,:]).sum(axis=2)

        # blob width = RMS distance of the training points to their nearest center
        inertia = dists.min(axis=1).sum()
        r = np.sqrt(inertia/len(pv_train))
        for c in range(2):
            cx, cy = centers[c]
            panel.imshow(gaussian(gx, gy, cx, cy, r), interpolation='bicubic',
                         extent=extent, cmap=tmaps[ct[c]], aspect='auto',
                         origin='lower', vmin=0, vmax=1)

        # data plot: train and test samples, colored by true class
        for o in range(num_classes):
            for p, c in zip([pv_train, pv_test], [y_train, y_test]):
                in_class = c == o
                if in_class.any():
                    panel.scatter(p[in_class,0], p[in_class,1], color=cols[o], alpha=0.8, s=marker_size,
                                  lw=0, zorder=10000)

        # centers plot
        panel.scatter(centers[:,0], centers[:,1], marker='x', color='w', s=2*marker_size, lw=0.5,
                      zorder=30000)

        panel.set_xlim(-lim,lim)
        panel.set_ylim(-lim,lim)
        panel.text(0.1, 0.8, elem, color='black', fontproperties=prop, transform=panel.transAxes,
                   zorder=500000)

fig.savefig('images/periodic_table_pca.png', bbox_inches='tight', dpi=400)