# Rocksalt-Zincblende Classification Problem

In [3]:
import math as m
import itertools as it 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
from sklearn import linear_model
from sklearn import preprocessing
from sklearn import metrics
from sklearn import model_selection
import time
import multiprocessing as mp
import ray

# Data Inicialization and Feature Space Generation

In [26]:
def inicializace_dat():
    Z = np.array([
                    3, 4, 5, 6, 7, 8, 9, 
                    11, 12, 13, 14, 15, 16, 
                    17, 19, 20, 29, 30, 31, 32, 
                    33, 34, 35, 37, 38, 47, 48, 
                    49, 50, 51, 52, 53, 55, 56
    ])
    # atomove cislo prvku A, 34 hodnot
    Prvky = np.array([
                 'Li', 'Be', 'B ', 'C ', 'N ',
                 'O ', 'F ', 'Na', 'Mg', 'Al', 'Si',
                 'P ', 'S ', 'Cl', 'K ', 'Ca', 'Cu',
                 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br',
                 'Rb', 'Sr', 'Ag', 'Cd', 'In', 'Sn',
                 'Sb', 'Te', 'I ', 'Cs', 'Ba'
    ])
    # prvky prislusejici danemu atomovemu cislu, 34 hodnot
    IP = np.array([
                  -5.329, -9.459, -8.190, -10.852, -13.585,
                  -16.433, -19.404, -5.223, -8.037,
                  -5.780, -7.758, -9.751, -11.795,
                  -13.902, -4.433, -6.428, -8.389, -10.136,
                  -5.818, -7.567, -9.262, -10.946, -12.650,
                  -4.289, -6.032, -8.058, -9.581, -5.537,
                  -7.043, -8.468, -9.867, -11.257, -4.006,
                  -5.516
    ])
    # ionizacni potencial (IP)
    EA =  np.array([
        -0.698, 0.631, -0.107, -0.872, -1.867, -3.006, -4.273, -0.716,
        0.693, -0.313, -0.993, -1.920, -2.845, -3.971, -0.621, 0.304,
        -1.638, 1.081, -0.108, -0.949, -1.839, -2.751, -3.739, -0.590,
        0.343, -1.667, 0.839, -0.256, -1.039, -1.847, -2.666, -3.513,
        -0.570, 0.278
    ])
    # elektronova afinita (EA)
    EN = np.array([
        3.014, 4.414, 4.149, 5.862, 7.726, 9.720, 11.839, 2.969, 3.672,
        3.046, 4.375, 5.835, 7.320, 8.936, 2.527, 3.062, 5.014, 4.527,
        2.963, 4.258, 5.551, 6.848, 8.194, 2.440, 2.844, 4.862, 4.371,
        2.897, 4.041, 5.158, 6.266, 7.385, 2.288, 2.619
    ])
    # Elektronegativita dle Mullikenovy definice (EN)
    HOMO = np.array([
        -2.874, -5.600, -3.715, -5.416, -7.239, -9.197, -11.294, -2.819,
        -4.782, -2.784, -4.163, -5.596, -7.106, -8.700, -2.426, -3.864,
        -4.856, -6.217, -2.732, -4.046, -5.341, -6.654, -8.001, -2.360,
        -3.641, -4.710, -5.952, -2.697, -3.866, -4.991, -6.109, -7.236,
        -2.220, -3.346
    ])

    LUMO = np.array([
        -0.978, -2.098, 2.248, 1.992, 3.057, 2.541, 1.251, -0.718, -1.358,
        0.695, 0.440, 0.183, 0.642, 0.574, -0.697, -2.133, -0.641, -1.194,
        0.130, 2.175, 0.064, 1.316, 0.708, -0.705, -1.379, -0.479, -1.309,
        0.368, 0.008, 0.105, 0.099, 0.213, -0.548, -2.129
    ])


    # The radii at which the radial probability density of the valence s, p, 
    # and d orbital are respectively maximal.
    r_s = np.array([
        1.652, 1.078, 0.805, 0.644, 0.539, 0.462, 0.406, 1.715, 1.330, 1.092,
        0.938, 0.826, 0.742, 0.679, 2.128, 1.757, 1.197, 1.099, 0.994, 0.917,
        0.847, 0.798, 0.749, 2.240, 1.911, 1.316, 1.232, 1.134, 1.057, 1.001,
        0.945, 0.896, 2.464, 2.149
    ])

    r_p = np.array([
        1.995, 1.211, 0.826, 0.630, 0.511, 0.427, 0.371, 2.597, 1.897, 1.393,
        1.134, 0.966, 0.847, 0.756, 2.443, 2.324, 1.680, 1.547, 1.330, 1.162,
        1.043, 0.952, 0.882, 3.199, 2.548, 1.883, 1.736, 1.498, 1.344, 1.232,
        1.141, 1.071, 3.164, 2.632
    ])

    r_d = np.array([
        6.930, 2.877, 1.946, 1.631, 1.540, 2.219, 1.428, 6.566, 3.171, 1.939,
        1.890, 1.771, 2.366, 1.666, 1.785, 0.679, 2.576, 2.254, 2.163, 2.373,
        2.023, 2.177, 1.869, 1.960, 1.204, 2.968, 2.604, 3.108, 2.030, 2.065,
        1.827, 1.722, 1.974, 1.351
    ])


    dE = np.array([
        -0.059, -0.038, -0.033, -0.022, 0.430, 0.506, 0.495, 0.466, 1.713,
        1.020, 0.879, 2.638, -0.146, -0.133, -0.127, -0.115, -0.178, -0.087,
        -0.055, -0.005, 0.072, 0.219, 0.212, 0.150, 0.668, 0.275, -0.146,
        -0.165, -0.166, -0.168, -0.266, -0.369, -0.361, -0.350, -0.019,
        0.156, 0.152, 0.203, 0.102, 0.275, 0.259, 0.241, 0.433, 0.341, 0.271,
        0.158, 0.202, -0.136, -0.161, -0.164, -0.169, -0.221, -0.369, -0.375,
        -0.381, -0.156, -0.044, -0.030, 0.037, -0.087, 0.070, 0.083, 0.113,
        0.150, 0.170, 0.122, 0.080, 0.016, 0.581, -0.112, -0.152, -0.158,
        -0.165, -0.095, -0.326, -0.350, -0.381, 0.808, 0.450, 0.264, 0.136,
        0.087
    ])
    # dE = E(RS) - E(ZB) ... 82 hodnot pro binární sloučeniny
    AB = np.array([
        'Li-F ', 'Li-Cl', 'Li-Br', 'Li-I ', 'Be-O ', 'Be-S ', 'Be-Se', 'Be-Te',
        'B -N ', 'B -P ', 'B -As', 'C -C ', 'Na-F ', 'Na-Cl', 'Na-Br', 'Na-I ',
        'Mg-O ', 'Mg-S ', 'Mg-Se', 'Mg-Te', 'Al-N ', 'Al-P ', 'Al-As', 'Al-Sb',
        'Si-C ', 'Si-Si', 'K -F ', 'K -Cl', 'K -Br', 'K -I ', 'Ca-O ', 'Ca-S ',
        'Ca-Se', 'Ca-Te', 'Cu-F ', 'Cu-Cl', 'Cu-Br', 'Cu-I ', 'Zn-O ', 'Zn-S ',
        'Zn-Se', 'Zn-Te', 'Ga-N ', 'Ga-P ', 'Ga-As', 'Ga-Sb', 'Ge-Ge', 'Rb-F ',
        'Rb-Cl', 'Rb-Br', 'Rb-I ', 'Sr-O ', 'Sr-S ', 'Sr-Se', 'Sr-Te', 'Ag-F ',
        'Ag-Cl', 'Ag-Br', 'Ag-I ', 'Cd-O ', 'Cd-S ', 'Cd-Se', 'Cd-Te', 'In-N ',
        'In-P ', 'In-As', 'In-Sb', 'Sn-Sn', 'B -Sb', 'Cs-F ', 'Cs-Cl', 'Cs-Br',
        'Cs-I ', 'Ba-O ', 'Ba-S ', 'Ba-Se', 'Ba-Te', 'Ge-C ', 'Sn-C ', 'Ge-Si',
        'Sn-Si', 'Sn-Ge'
    ])
    # Z vektorů dat vytvoření dictionary obsahující listy, které mají prvky svoje vypočtené hodnoty
    # Kodovani pomoci stringu nazev prvku (dva charaktery dlouhy!!!!)
    oniers = {}
    for i in range(len(Prvky)):
        oniers[Prvky[i]]= [Z[i], IP[i], EA[i], EN[i], HOMO[i], LUMO[i], r_s[i], r_p[i], r_d[i]]

    # Data jednotlivych dimeru ulozenych v dictionary listů, celkem 82 listů
    # Tyto listy jsou vlastně matice o 8 radcich a dvou sloupcich
    dimers = {} # inicializace
    temp = [] #inicializace temporary listu
    for i in AB: # pro kazdy dimer
        for j in range(1,9): # vytvori list listů osmi dvojic hodnot (bez Z[i])
            temp.append( [ oniers[i[:2]][j] , oniers[i[3:]][j] ] )
            # [IP, EA, EN, HOMO, LUMO, r_s, r_p, r_d]
        dimers[i] = temp # kodovani pomoci nazvu dimeru
        temp = [] # clearing pro dalsi iteraci
    dE = dE.reshape(-1,1) # restrukturalizace dat
    return dimers, AB, dE
    
def feature_space_generation(noise, noised_feature, sigma, dimers, AB, tier0, tier1, tier2, tier3, tier4, tier5):    
    # Nyni definujeme a nagenerujeme mnoziny moznych deskriptoru
    # Jednotlive mnoziny jsou listy floatovych hodnot ulozene jako dictionary a klic je  nazev dimeru ve formatu '__-__'
    # Brute force definice zakladnich mnozin deskriptorů pro kazdy dimer:
    
    # tier - which tier of the descriptor to include
    # Vektor popisující tvar deskriptorů:
    A1 = {}
    A2 = {}
    A3 = {}
    for i in AB:
        A1[i] = [ dimers[i][0][0] , dimers[i][1][0] , dimers[i][0][1] , dimers[i][1][1] ] 
        A2[i] = [ dimers[i][3][0] , dimers[i][4][0] , dimers[i][3][1] , dimers[i][4][1] ]
        A3[i] = [ dimers[i][5][0] , dimers[i][6][0] , dimers[i][7][0] , dimers[i][5][1] , dimers[i][6][1] , dimers[i][7][1] ]
        
    if noise==True:
        gauss = np.random.normal(1, sigma, 1)[0]
        
        if noised_feature==1 or noised_feature==True:
            for i in AB:
                A1[i][0] = A1[i][0]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==2 or noised_feature==True:
            for i in AB:
                A1[i][1] = A1[i][1]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==3 or noised_feature==True:
            for i in AB:
                A1[i][2] = A1[i][2]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==4 or noised_feature==True:
            for i in AB:
                A1[i][3] = A1[i][3]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==5 or noised_feature==True:
            for i in AB:
                A2[i][0] = A2[i][0]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==6 or noised_feature==True:
            for i in AB:
                A2[i][1] = A2[i][1]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==7 or noised_feature==True:
            for i in AB:
                A2[i][2] = A2[i][2]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==8 or noised_feature==True:
            for i in AB:
                A2[i][3] = A2[i][3]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==9 or noised_feature==True:
            for i in AB:
                A3[i][0] = A3[i][0]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==10 or noised_feature==True:
            for i in AB:
                A3[i][1] = A3[i][1]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==11 or noised_feature==True:
            for i in AB:
                A3[i][2] = A3[i][2]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==12 or noised_feature==True:
            for i in AB:
                A3[i][3] = A3[i][3]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==13 or noised_feature==True:
            for i in AB:
                A3[i][4] = A3[i][4]*np.random.normal(1, sigma, 1)[0]
        if noised_feature==14 or noised_feature==True:
            for i in AB:
                A3[i][5] = A3[i][5]*np.random.normal(1, sigma, 1)[0]
            
        
    DD = []
    DD.append('IP(A)')
    DD.append('EA(A)')
    DD.append('IP(B)')
    DD.append('EA(B)')

    DD.append('H(A)')
    DD.append('L(A)')
    DD.append('H(B)')
    DD.append('L(B)')

    DD.append('r_s(A)')
    DD.append('r_p(A)')
    DD.append('r_d(A)')
    DD.append('r_s(B)')
    DD.append('r_p(B)')
    DD.append('r_d(B)')
    
    if tier0==True:####
        DD_A1=DD[:4]
        DD_A2=DD[4:8]
        DD_A3=DD[8:14]

    # Generovani jednoduchych deskriptoru
    DD_B1 = []
    DD_B2 = []
    DD_B3 = []

    DD_C3 = []
    DD_D3 = []
    DD_E3 = []
    
    if tier1==True:####
        DD_dvojice = list( it.combinations( DD_A1 , 2 ) )
        for j in DD_dvojice:
            DD_B1.append('|'+j[0]+'-'+j[1]+'|')
            DD_B1.append('|'+j[0]+'+'+j[1]+'|')

        DD_dvojice = list( it.combinations( DD_A2 , 2 ) )
        for j in DD_dvojice:
            DD_B2.append('|'+j[0]+'-'+j[1]+'|')
            DD_B2.append('|'+j[0]+'+'+j[1]+'|')

        DD_dvojice = list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            DD_B3.append('|'+j[0]+'-'+j[1]+'|')
            DD_B3.append('|'+j[0]+'+'+j[1]+'|')

    DD = DD + DD_B1 + DD_B2 + DD_B3
    
    if tier1==True:####
        for j in DD_A3:
            DD_C3.append(j+'^2')
            
    if tier2==True:####
        DD_dvojice = list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            DD_C3.append('('+j[0]+'+'+j[1]+')^2')
    
    if tier1==True:####
        for j in DD_A3:
            DD_D3.append('exp('+j+')')
            
    if tier2==True:####
        DD_dvojice = list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            DD_D3.append('exp('+j[0]+'+'+j[1]+')')
            
    if tier2==True:####
        for j in DD_A3:
            DD_E3.append('exp('+j+')^2')

    if tier3==True:
        DD_dvojice = list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            DD_E3.append('exp('+j[0]+'+'+j[1]+')^2')

    DD = DD + DD_C3 + DD_D3 + DD_E3


    B1 = {}
    B2 = {}
    B3 = {}
    C3 = {}
    D3 = {}
    E3 = {}
    temp = []

    for i in AB:

        if tier1==True:####
            dvojice = list( it.combinations( A1[i] , 2 ) )
            for j in dvojice:
                temp.append( abs( j[0] - j[1] ) )
                temp.append( abs( j[0] + j[1] ) )

            B1[i] = temp
            temp = []


            dvojice = list( it.combinations( A2[i] , 2 ) )
            for j in dvojice:
                temp.append( abs( j[0] - j[1] ) )
                temp.append( abs( j[0] + j[1] ) )

            B2[i] = temp
            temp = []


            dvojice=list( it.combinations( A3[i] , 2 ) )
            for j in dvojice:
                temp.append( abs( j[0] - j[1] ) )
                temp.append( abs( j[0] + j[1] ) )

            B3[i] = temp
            temp = []
            
        else:
            
            B1[i] = []
            B2[i] = []
            B3[i] = []
            temp = []



        if tier1==True:####
            for j in A3[i]:
                temp.append( (j)**2 )

        if tier2==True:####
            dvojice=list( it.combinations( A3[i] , 2 ) )
            for j in dvojice:
                temp.append( ( j[0] + j[1] )**2 )

        C3[i] = temp
        temp = []

        if tier1==True:####
            for j in A3[i]:
                temp.append( m.exp( j ) )

        if tier2==True:####
            dvojice = list( it.combinations( A3[i] , 2 ) )
            for j in dvojice:
                temp.append( m.exp( j[0] + j[1] ) )

        D3[i] = temp
        temp = []

        if tier2==True:####
            for j in A3[i]:
                temp.append( m.exp( (j)**2 ) )

        if tier3==True:####
            dvojice = list( it.combinations( A3[i] , 2 ) )
            for j in dvojice:
                temp.append( m.exp( ( j[0] + j[1] )**2 ) )

        E3[i] = temp
        temp = []



    G = {}
    temp = []

    DD_G = []


    # A1,B1 ; A2,B2 lomeno A3,C3,D3,E3
    if tier1==True and tier2==False:####                    
        for j in [DD_A1, DD_A2]:
            for l in [DD_A3]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==False:                    
        for j in [DD_A1, DD_A2, DD_B1, DD_B2]:
            for l in [DD_A3]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])
                    
        for j in [DD_A1, DD_A2]:
            for l in [DD_C3[:6], DD_D3[:6]]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==False:            
        for j in [DD_A1, DD_A2, DD_B1, DD_B2]:
            for l in [DD_A3, DD_C3[:6], DD_D3[:6]]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])
                    
        for j in [DD_A1, DD_A2]:
            for l in [DD_C3[6:], DD_D3[6:]]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==False:                    
        for j in [DD_A1, DD_A2, DD_B1, DD_B2]:
            for l in [DD_A3, DD_C3, DD_D3]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])

        for j in [DD_A1, DD_A2, DD_B1, DD_B2]:
            for l in [DD_E3[:6]]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])
                    
        for j in [DD_A1, DD_A2]:
            for l in [DD_E3[6:]]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==True:                       
        for j in [DD_A1, DD_A2, DD_B1, DD_B2]:
            for l in [DD_A3, DD_C3, DD_D3, DD_E3]:
                for k in list(it.product(j,l)):
                    DD_G.append(k[0]+'/'+k[1])

# no tiers:
#    for j in [DD_A1, DD_A2, DD_B1, DD_B2]:
#        for l in [DD_C3, DD_D3, DD_E3]:
#            for k in list(it.product(j,l)):
#                DD_G.append(k[0]+'/'+k[1])
                    

    
    
    
    # A3/D3 a A3/E3
    if tier1==True and tier2==True and tier3==False:
        for j in [DD_D3[:6]]:
            for k in list(it.product(DD_A3,j)):
                DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==False:                
        for j in [DD_D3, DD_E3[:6]]:
            for k in list(it.product(DD_A3,j)):
                DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==True:
        for j in [DD_D3, DD_E3]:
            for k in list(it.product(DD_A3,j)):
                DD_G.append(k[0]+'/'+k[1])

    # B3/D3 a B3/E3
    if tier1==True and tier2==True and tier3==True and tier4==False:
        for j in [DD_D3[:6]]:
            for k in list(it.product(DD_B3,j)):
                DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==False:
        for j in [DD_D3, DD_E3[:6]]:
            for k in list(it.product(DD_B3,j)):
                DD_G.append(k[0]+'/'+k[1])

    elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==True:  
        for j in [DD_D3, DD_E3]:
            for k in list(it.product(DD_B3,j)):
                DD_G.append(k[0]+'/'+k[1])
    
    
# no tiers:
#    for j in [DD_D3, DD_E3]:
#        for k in list(it.product(DD_A3,j)):
#            DD_G.append(k[0]+'/'+k[1])
#
#    for j in [DD_D3, DD_E3]:
#        for k in list(it.product(DD_B3,j)):
#            DD_G.append(k[0]+'/'+k[1])


    # Problemove:

    # A3/A3
    if tier1==True:####
        DD_dvojice=list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            DD_G.append(j[0]+'/'+j[1])
            DD_G.append(j[1]+'/'+j[0])


    # A3/C3
    if tier2==True:####
        DD_dvojice=list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            DD_G.append(j[0]+'/'+'('+j[1]+')'+'^2')
            DD_G.append(j[1]+'/'+'('+j[0]+')'+'^2')
            #DD_G.append(j[0]+'/'+'('+j[0] +'+'+ j[1]+')^2')
            #DD_G.append(j[1]+'/'+'('+j[0] +'+'+ j[1]+')^2')

    if tier3==True:####
        DD_dvojice=list( it.combinations( DD_A3 , 2 ) )
        for j in DD_dvojice:
            #DD_G.append(j[0]+'/'+'('+j[1]+')'+'^2')
            #DD_G.append(j[1]+'/'+'('+j[0]+')'+'^2')
            DD_G.append(j[0]+'/'+'('+j[0] +'+'+ j[1]+')^2')
            DD_G.append(j[1]+'/'+'('+j[0] +'+'+ j[1]+')^2')


    if tier3==True:####
        DD_trojice=list(it.combinations(DD_A3,3))
        for j in DD_trojice:
            for k in [(0,1,2),(1,0,2),(2,0,1)]:
                DD_G.append(j[k[0]]+'/'+'('+j[k[1]] +'+'+ j[k[2]]+')^2')



    # B3/A3:
    if tier2==True:####
        DD_trojice=list(it.combinations(DD_A3,3))
        for j in DD_trojice:
            for k in [(0,1,2),(2,1,0),(0,2,1)]:
                DD_G.append('|'+j[k[0]] +'-'+ j[k[1]]+'| /'+j[k[2]])
    
    if tier2==True:####
        DD_dvojice=list(it.combinations(DD_A3,2))
        for j in DD_dvojice:
            DD_G.append('|'+j[1]+'-'+j[0]+'| /' + j[1])
            DD_G.append('|'+j[0]+'-'+j[1]+'| /' + j[0])

    # B3/C3
    if tier3==True:####
        DD_trojice=list(it.combinations(DD_A3,3))
        for j in DD_trojice:
            for k in [(0,1,2),(2,1,0),(0,2,1)]:
                DD_G.append('|' + j[k[0]] + '-' +  j[k[1]]+'| /'+j[k[2]]+ '^2')
    
    if tier4==True:####
        DD_dvojice=list(it.combinations(DD_A3,2))
        for j in DD_dvojice:
            DD_G.append('|'+j[0]+'-'+j[1]+'| /'+'('+j[0]+'+'+j[1]+')^2')

    if tier3==True:####
        DD_dvojice=list(it.combinations(DD_A3,2))
        for j in DD_dvojice:
            DD_G.append('|'+j[0]+'-'+j[1]+'| /' + j[0]+'^2')
            DD_G.append('|'+j[1]+'-'+j[0]+'| /'+j[1]+'^2')

    if tier4==True:####
        DD_trojice=list(it.combinations(DD_A3,3))
        for j in DD_trojice:
            DD_G.append('|'+j[0]+'-'+j[1]+'| /'+ '('+j[0]+'+'+j[2]+')^2')
            DD_G.append('|'+j[0]+'-'+j[2]+'| /'+ '('+j[0]+'+'+j[1]+')^2')

    if tier4==True:####
        DD_ctverice = list(it.combinations(DD_A3,4))
        for j in DD_ctverice:
            for k in [(0,1,2,3),(0,2,1,3),(0,3,1,2),(2,1,0,3),(3,1,0,2),(2,3,0,1)]:
                #temp.append(abs(j[k[0]]+j[k[1]])/(j[k[2]]+j[k[3]])**2)
                DD_G.append('|'+j[k[0]]+'-'+j[k[1]]+') /'+'('+j[k[2]]+'+'+j[k[3]]+')^2')


    # Zde je TEST: Ciste podily 1/r navic:
    if tier1==True:####
        for j in DD_A3:
            DD_G.append('1/' + j)

    DD =  DD + DD_G

#####################################################
    temp = []
    for i in AB:
        # A1,B1 ; A2,B2 lomeno A3,C3,D3,E3
        if tier1==True and tier2==False:####
            for j in [A1[i], A2[i]]:
                for l in [A3[i]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
        elif tier1==True and tier2==True and tier3==False:
            for j in [A1[i], A2[i], B1[i], B2[i]]:
                for l in [A3[i]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
            for j in [A1[i], A2[i]]:
                for l in [C3[i][:6], D3[i][:6]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
        elif tier1==True and tier2==True and tier3==True and tier4==False:            
            for j in [A1[i], A2[i], B1[i], B2[i]]:
                for l in [A3[i], C3[i][:6], D3[i][:6]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
            for j in [A1[i], A2[i]]:
                for l in [C3[i][6:], D3[i][6:]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
        elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==False:
            for j in [A1[i], A2[i], B1[i], B2[i]]:
                for l in [A3[i], C3[i], D3[i]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
            for j in [A1[i], A2[i], B1[i], B2[i]]:
                for l in [E3[i][:6]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
                        
            for j in [A1[i], A2[i]]:
                for l in [E3[i][6:]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])

        elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==True:   
            for j in [A1[i], A2[i], B1[i], B2[i]]:
                for l in [A3[i], C3[i], D3[i], E3[i]]:
                    for k in list(it.product(j,l)):
                        temp.append(k[0]/k[1])
###################################################
                        
                        
###################################################
        # A3/D3 a A3/E3
        if tier1==True and tier2==True and tier3==False:
            for j in [D3[i][:6]]:
                for k in list(it.product(A3[i],j)):
                    temp.append(k[0]/k[1])
                    
        elif tier1==True and tier2==True and tier3==True and tier4==False:
            for j in [D3[i], E3[i][:6]]:
                for k in list(it.product(A3[i],j)):
                    temp.append(k[0]/k[1])
                
        elif tier1==True and tier2==True and tier3==True and tier4==True:
            for j in [D3[i], E3[i]]:
                for k in list(it.product(A3[i],j)):
                    temp.append(k[0]/k[1])
                
        # B3/D3 a B3/E3
        if tier1==True and tier2==True and tier3==True and tier4==False:
            for j in [D3[i][:6]]:
                for k in list(it.product(B3[i],j)):
                    temp.append(k[0]/k[1])
                    
        elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==False:
            for j in [D3[i], E3[i][:6]]:
                for k in list(it.product(B3[i],j)):
                    temp.append(k[0]/k[1])
                    
        elif tier1==True and tier2==True and tier3==True and tier4==True and tier5==True:  
            for j in [D3[i], E3[i]]:
                for k in list(it.product(B3[i],j)):
                    temp.append(k[0]/k[1])

####################################################
                    
        # PROBLEMOVE PODILY:

        # A3/A3 - vsechny podily jenom ne podily X/X = 1:
        if tier1==True:####
            dvojice=list( it.combinations( A3[i] , 2 ) )
            for j in dvojice:
                temp.append(j[0]/j[1])
                temp.append(j[1]/j[0])


        # A3/C3:
        if tier2==True:####
            dvojice=list(it.combinations(A3[i],2))
            for j in dvojice:
                temp.append(j[0]/(j[1])**2)
                temp.append(j[1]/(j[0])**2)
                #temp.append(j[0]/(j[0] + j[1])**2)
                #temp.append(j[1]/(j[0] + j[1])**2)
        
        if tier3==True:####
            dvojice=list(it.combinations(A3[i],2))
            for j in dvojice:
                #temp.append(j[0]/(j[1])**2)
                #temp.append(j[1]/(j[0])**2)
                temp.append(j[0]/(j[0] + j[1])**2)
                temp.append(j[1]/(j[0] + j[1])**2)                
                
                
        if tier3==True:
            trojice=list(it.combinations(A3[i],3))
            for j in trojice:
                for k in [(0,1,2),(1,0,2),(2,0,1)]:
                    temp.append(j[k[0]]/(j[k[1]] + j[k[2]])**2)


        # B3/A3:
        if tier2==True:
            trojice=list(it.combinations(A3[i],3))
            for j in trojice:
                for k in [(0,1,2),(2,1,0),(0,2,1)]:
                    temp.append(abs(j[k[0]] - j[k[1]])/j[k[2]])

        if tier2==True:
            dvojice=list(it.combinations(A3[i],2))
            for j in dvojice:
                temp.append(abs(1-(j[0]/j[1])))
                temp.append(abs(1-(j[1]/j[0])))


        # B3/C3
        if tier3==True:
            trojice=list(it.combinations(A3[i],3))
            for j in trojice:
                for k in [(0,1,2),(2,1,0),(0,2,1)]:
                    temp.append(abs(j[k[0]] - j[k[1]])/j[k[2]]**2)

        if tier4==True:
            dvojice=list(it.combinations(A3[i],2))
            for j in dvojice:
                temp.append(abs(j[0]-j[1])/(j[0]+j[1])**2)

        if tier3==True:
            dvojice=list(it.combinations(A3[i],2))
            for j in dvojice:
                temp.append(abs(j[0]-j[1])/j[0]**2)
                temp.append(abs(j[1]-j[0])/j[1]**2)

        if tier4==True:
            trojice=list(it.combinations(A3[i],3))
            for j in trojice:
                temp.append(abs(j[0]-j[1])/(j[0]+j[2])**2)
                temp.append(abs(j[0]-j[2])/(j[0]+j[1])**2)

        if tier4==True:
            ctverice = list(it.combinations(A3[i],4))
            for j in ctverice:
                for k in [(0,1,2,3),(0,2,1,3),(0,3,1,2),(2,1,0,3),(3,1,0,2),(2,3,0,1)]:
                    #temp.append(abs(j[k[0]]+j[k[1]])/(j[k[2]]+j[k[3]])**2)
                    temp.append(abs(j[k[0]]-j[k[1]])/(j[k[2]]+j[k[3]])**2)






        # Zde je TEST: Ciste podily 1/r navic:
        if tier1==True:####
            for j in A3[i]:
                temp.append(1/j)


        G[i] = temp
        temp = []



    # F1, F2, F3: ... 44 deskriptoru celkem
    if tier3==True:####
        DD_F1 = []
        for j in list(it.combinations(DD_A1[:2],2)):
            for k in list(it.combinations(DD_A1[2:],2)):
                DD_F1.append( "|" + "|" + j[0] + "-" + j[1] + "|" + "+" + "|" + k[0] + "-" + k[1] + "|" + "|" )
                DD_F1.append( "|" + "|" + j[0] + "-" + j[1] + "|" + "-" + "|" + k[0] + "-" + k[1] + "|" + "|" )
                DD_F1.append( "|" + "|" + j[0] + "+" + j[1] + "|" + "+" + "|" + k[0] + "+" + k[1] + "|" + "|" )
                DD_F1.append( "|" + "|" + j[0] + "+" + j[1] + "|" + "-" + "|" + k[0] + "+" + k[1] + "|" + "|" )

        DD = DD + DD_F1

        DD_F2 = []
        for j in list(it.combinations(DD_A2[:2],2)):
            for k in list(it.combinations(DD_A2[2:],2)):
                DD_F2.append( "|" + "|" + j[0] + "-" + j[1] + "|" + "+" + "|" + k[0] + "-" + k[1] + "|" + "|" )
                DD_F2.append( "|" + "|" + j[0] + "-" + j[1] + "|" + "-" + "|" + k[0] + "-" + k[1] + "|" + "|" )
                DD_F2.append( "|" + "|" + j[0] + "+" + j[1] + "|" + "+" + "|" + k[0] + "+" + k[1] + "|" + "|" )
                DD_F2.append( "|" + "|" + j[0] + "+" + j[1] + "|" + "-" + "|" + k[0] + "+" + k[1] + "|" + "|" )

        DD = DD + DD_F2

        DD_F3 = []
        for j in list(it.combinations(DD_A3[:3],2)):
            for k in list(it.combinations(DD_A3[3:],2)):
                DD_F3.append( "|" + "|" + j[0] + "-" + j[1] + "|" + "+" + "|" + k[0] + "-" + k[1] + "|" + "|" )
                DD_F3.append( "|" + "|" + j[0] + "-" + j[1] + "|" + "-" + "|" + k[0] + "-" + k[1] + "|" + "|" )
                DD_F3.append( "|" + "|" + j[0] + "+" + j[1] + "|" + "+" + "|" + k[0] + "+" + k[1] + "|" + "|" )
                DD_F3.append( "|" + "|" + j[0] + "+" + j[1] + "|" + "-" + "|" + k[0] + "+" + k[1] + "|" + "|" )
        
        DD = DD + DD_F3

    # F1, F2, F3: ... 52 deskriptoru celkem

    F1 = {}
    F2 = {}
    F3 = {}

    temp = []
    if tier3==True:####
        for i in AB:
            for j in list(it.combinations(A1[i][:2],2)):
                for k in list(it.combinations(A1[i][2:],2)):
                    temp.append( abs( abs(j[0]-j[1]) + abs(k[0]-k[1]) ) )
                    temp.append( abs( abs(j[0]-j[1]) - abs(k[0]-k[1]) ) )
                    temp.append( abs( abs(j[0]+j[1]) + abs(k[0]+k[1]) ) )
                    temp.append( abs( abs(j[0]+j[1]) - abs(k[0]+k[1]) ) )

            F1[i] = temp
            temp = []


        for i in AB:
            for j in list(it.combinations(A2[i][:2],2)):
                for k in list(it.combinations(A2[i][2:],2)):
                    temp.append( abs( abs(j[0]-j[1]) + abs(k[0]-k[1]) ) )
                    temp.append( abs( abs(j[0]-j[1]) - abs(k[0]-k[1]) ) )
                    temp.append( abs( abs(j[0]+j[1]) + abs(k[0]+k[1]) ) )
                    temp.append( abs( abs(j[0]+j[1]) - abs(k[0]+k[1]) ) )

            F2[i] = temp
            temp = []


        for i in AB:
            for j in list(it.combinations(A3[i][:3],2)):
                for k in list(it.combinations(A3[i][3:],2)):
                    temp.append( abs( abs(j[0]-j[1]) + abs(k[0]-k[1]) ) )
                    temp.append( abs( abs(j[0]-j[1]) - abs(k[0]-k[1]) ) )
                    temp.append( abs( abs(j[0]+j[1]) + abs(k[0]+k[1]) ) )
                    temp.append( abs( abs(j[0]+j[1]) - abs(k[0]+k[1]) ) )

            F3[i] = temp
            temp = []
    else:
        for i in AB:
            F1[i] = []
            F2[i] = []
            F3[i] = []
            temp = []
        
    # 
    Descriptors = len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+ len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])+len(C3[AB[0]])+len(D3[AB[0]])+len(E3[AB[0]])+len(G[AB[0]])+len(F1[AB[0]])+len(F2[AB[0]])+len(F3[AB[0]])

    # Matice D. Opustíme dictionary a použijeme np.array
    D=np.empty((82,Descriptors),dtype=float)
    for i in range(len(AB)):
        for j in range(len(A1[AB[i]])):
            D[i][j] = A1[AB[i]][j]

        for j in range(len(A2[AB[i]])):
            D[i][j+len(A1[AB[i]])] = A2[AB[i]][j]

        for j in range(len(A3[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])] = A3[AB[i]][j]

        for j in range(len(B1[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])] = B1[AB[i]][j]

        for j in range(len(B2[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])] = B2[AB[i]][j]

        for j in range(len(B3[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])] = B3[AB[i]][j]

        for j in range(len(C3[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])] = C3[AB[i]][j]

        for j in range(len(D3[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])+len(C3[AB[i]])] = D3[AB[i]][j]

        for j in range(len(E3[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])+len(C3[AB[i]])+len(D3[AB[i]])] = E3[AB[i]][j]

        for j in range(len(G[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])+len(C3[AB[i]])+len(D3[AB[i]])+len(E3[AB[i]])] = G[AB[i]][j]

        for j in range(len(F1[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])+len(C3[AB[i]])+len(D3[AB[i]])+len(E3[AB[i]])+len(G[AB[i]])] = F1[AB[i]][j]

        for j in range(len(F2[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])+len(C3[AB[i]])+len(D3[AB[i]])+len(E3[AB[i]])+len(G[AB[i]])+len(F1[AB[i]])] = F2[AB[i]][j]

        for j in range(len(F3[AB[i]])):
            D[i][j+len(A1[AB[i]])+len(A2[AB[i]])+len(A3[AB[i]])+len(B1[AB[i]])+len(B2[AB[i]])+len(B3[AB[i]])+len(C3[AB[i]])+len(D3[AB[i]])+len(E3[AB[i]])+len(G[AB[i]])+len(F1[AB[i]])+len(F2[AB[i]])] = F3[AB[i]][j]


    #print('A1: ', '0 ... ',len(A1[AB[0]])-1)
    #print('A2: ',len(A1[AB[0]]),' ... ',len(A1[AB[0]])+len(A2[AB[0]])-1)
    #print('A3: ',len(A1[AB[0]])+len(A2[AB[0]]), ' ... ', len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])-1)
    #print('B1: ',len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]]), ' ... ',  len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[i]])-1)
    #print('B2: ',len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]]), ' ... ',len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])-1)
    #print('B3: ',len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]]), ' ... ', len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])-1)
    #print('C3: ', len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]]),' ... ', len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])+len(C3[AB[0]])-1)
    #print('D3: ',len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])+len(C3[AB[0]]), ' ... ', len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])+len(C3[AB[0]])+len(D3[AB[0]])-1)
    #print('E3: ',len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])+len(C3[AB[0]])+len(D3[AB[0]]), ' ... ', len(A1[AB[0]])+len(A2[AB[0]])+len(A3[AB[0]])+len(B1[AB[0]])+len(B2[AB[0]])+len(B3[AB[0]])+len(C3[AB[0]])+len(D3[AB[0]])+len(E3[AB[0]])-1)
    
    return D, DD, A1, A2, A3, B1, B2, B3, C3, D3, E3, F1, F2, F3, G

# LASSO+$\ell_{0}$ Minimazition

In [3]:
def LASSO(POCET,LAMBDA_LENGTH,D,dE):
    # LASSO
    # Standardizace dat
    #dE_standard=preprocessing.StandardScaler().fit_transform(dE)
    #D_standard=preprocessing.StandardScaler().fit_transform(D)

    # Normalizace dat
    D_normalized=preprocessing.normalize(D,axis=0)
    #dE_normalized=preprocessing.normalize(dE,axis=0)

    # Hledani nejmensiho lambda, kdy jsou vsechny koeficienty nula. Poté generování 100 hodnot lambda sestupne dle geometricke posloupnosti
    LAMBDA=np.empty(LAMBDA_LENGTH,dtype=float)
    LAMBDA[0]=(1/D.shape[0])*np.max((D_normalized.T).dot(dE))
    q=m.pow(1000,1/(LAMBDA_LENGTH-1))
    for i in range(1,len(LAMBDA)):
        LAMBDA[i]=LAMBDA[i-1]/q
    
    # LASSO fit pro vektor LAMBDA hodnot
    Potencialni=np.zeros((POCET,LAMBDA_LENGTH),dtype=int)-np.ones((POCET,LAMBDA_LENGTH),dtype=int)
    for j in range(1,LAMBDA_LENGTH): ### BEZIME OD JEDNICKY KVULI ZAOKHROUHLOVANI NA HRANICI NULOVYCH KOEFICIENTU
        lasso = linear_model.Lasso(alpha=LAMBDA[j],fit_intercept=True, normalize=True, max_iter=1e05, warm_start=True, positive=False, copy_X=True, precompute=True)
        lasso.fit(D,dE)

        coef = np.absolute(lasso.sparse_coef_.data) # np.absolute? jo nebo ne?

        k=0
        while(k<POCET):
            if np.max(coef)!=0:
                Potencialni[k,j]=lasso.sparse_coef_.indices[np.argmax(coef)]
                coef[np.argmax(coef)] = 0
                k=k+1
            else:
                break

    THETA=set() # 31. je #966.... je mozne ze pridanim dalsich sloupcu ktere maji vysokou korelaci nakonec bude zaclenen do danych triceti, ale momentalne neni. Pridanim mnozin F1*, F2*, F3* se dokazalo
    # mezi tricet nejlepsich dat deskriptor cislo 966.
    for j in range(POCET):
        for i in Potencialni[j,:]:
            if len(THETA)<POCET and i!=-1:
                THETA.add(i)
    # Ordinary Least Squares
    THETA=list(THETA)
    THETA.sort() # prevedeni setu na list a serazeni vzestupne
    #rutina pro kazdou k-tici. spocitat OLS a porovnat nejlepsi MSE s prave spoctenym.

    # Definujeme nejhorsi mozne MSE, tj. MSE pri coef = 0 a intercept = 0:

    MSE = np.ones((4),dtype = float)*metrics.mean_squared_error(dE,np.zeros(dE.shape,dtype = float)) # ctyr dimenzionalni MSE vektor
    MaxAE = np.zeros((4),dtype = float)
    ND = [ [],[],[],[] ]
    # Specialni rutina pro 1D deskriptor
    for i in THETA:
        OMEGA = D[:,i].reshape(-1, 1)
        ols = linear_model.LinearRegression(fit_intercept = True, normalize = True)
        ols.fit(OMEGA, dE)
        #dE_predicted = np.ones((82,1),dtype = float)*ols.intercept_ + np.dot(OMEGA,ols.coef_.T)
        dE_predicted = ols.predict(OMEGA)
        novy_model_MSE = metrics.mean_squared_error(dE, dE_predicted)

        if novy_model_MSE < MSE[0]:
            MSE[0] = novy_model_MSE
            MaxAE[0] = metrics.max_error(dE,dE_predicted)
            ND[0] = [i, ols.coef_, ols.intercept_, m.sqrt(MSE[0]), MaxAE[0]]

    # Rutina pro 2D, ... ,4D deskriptor
    for j in range(2,5): # 2,3,4
        OMEGA = np.zeros((D.shape[0],j),dtype = float)
        jetice = it.combinations(THETA,j)

        for i in jetice:

            for k in range(j):
                OMEGA[:,k] = D[:,i[k]]

            ols = linear_model.LinearRegression(fit_intercept = True, normalize = True)
            ols.fit(OMEGA,dE)
            #dE_predicted = np.ones((82,1),dtype = float)*ols.intercept_ + np.dot(OMEGA,ols.coef_.T) stejne jako radek pod tim
            dE_predicted = ols.predict(OMEGA)
            novy_model_MSE = metrics.mean_squared_error(dE,dE_predicted)

            if novy_model_MSE < MSE[j-1]:
                MSE[j-1] = novy_model_MSE
                MaxAE[j-1] = metrics.max_error(dE,dE_predicted)
                ND[j-1] = [i, ols.coef_, ols.intercept_, m.sqrt(MSE[j-1]), MaxAE[j-1]]
    return ND

# Training

In [76]:
start_time = time.time()
POCET = 35 # 30,35,40,45
LAMBDA_LENGTH = 100

dimers,AB,dE = inicializace_dat()

(D, DD, A1, A2, A3, B1, B2, B3, C3, D3, E3, F1, F2, F3, G) = feature_space_generation(dimers,AB, True, True, True, True, True, True)

ND = LASSO(POCET, LAMBDA_LENGTH, D, dE)


print('| Index deskriptoru ,', 'Vektor koeficientů ,', 'Bias ,', 'RMSE ,', 'MaxAE reziduí| \n')
for i in range(4):
    print(ND[i],'\n')

print('Modely: \n')

print('dE_D1 =','+',round(ND[0][1][0][0],3),'*',DD[ND[0][0]],round(ND[0][2][0],3),'... RMSE =',ND[0][3],'eV', '\n')
print('dE_D2 =','+',round(ND[1][1][0][0],3),'*',DD[ND[1][0][0]],round(ND[1][1][0][1],3),'*',DD[ND[1][0][1]],round(ND[1][2][0],3),'... RMSE =',ND[1][3],'eV', '\n')
print('dE_D3 =','+',round(ND[2][1][0][0],3),'*',DD[ND[2][0][0]],round(ND[2][1][0][1],3),'*',DD[ND[2][0][1]],round(ND[2][1][0][2],3),'*',DD[ND[2][0][2]],round(ND[2][2][0],3),'... RMSE =',ND[2][3],'eV', '\n')
print('dE_D4 =','+',round(ND[3][1][0][0],3),'*',DD[ND[3][0][0]],round(ND[3][1][0][1],3),'*',DD[ND[3][0][1]],'+',round(ND[3][1][0][2],3),'*',DD[ND[3][0][2]],round(ND[3][1][0][3],3),'*',DD[ND[3][0][3]],'+',round(ND[3][2][0],3),'... RMSE =',ND[3][3],'eV', '\n')
elapsed_time = (time.time() - start_time)
print("Elapsed time:", elapsed_time)
#print(DD[ND[0][0]])
#print(DD[ND[1][0][0]],(DD[ND[1][0][1]]))
#print(DD[ND[2][0][0]],(DD[ND[2][0][1]]),(DD[ND[2][0][2]]))
#print(DD[ND[3][0][0]],(DD[ND[3][0][1]]),(DD[ND[3][0][2]]),(DD[ND[3][0][3]]))

# Jak dobre data vzdalenosti dimeru souhlasi s se skutecnymi vzdalenostmi v krystalu? (dimer NaCl vs krystal NaCl)
# Crystalline Atomic Position Estimator (CAPE)

| Index deskriptoru , Vektor koeficientů , Bias , RMSE , MaxAE reziduí| 

[819, array([[0.05462871]]), array([-0.33182968]), 0.13779627590238272, 0.42094868533700636] 

[(966, 2717), array([[ 0.112769  , -1.55825175]]), array([-0.13327281]), 0.09878859313854685, 0.2865286727517662] 

[(966, 2717, 3097), array([[ 0.10818263, -1.80660314, -3.78218416]]), array([-0.02326177]), 0.07559818375102052, 0.24384216831416916] 

[(2540, 2717, 3022, 3111), array([[ 15.13252431,  -1.42169934,   6.18964366, -15.7778313 ]]), array([0.1107227]), 0.06718535854498611, 0.23435703176842132] 

Modely: 

dE_D1 = + 0.055 * |IP(A)+IP(B)|/r_p(A)^2 -0.332 ... RMSE = 0.13779627590238272 eV 

dE_D2 = + 0.113 * |IP(B)-EA(B)|/r_p(A)^2 -1.558 * |r_s(A)-r_p(B)|/exp(r_s(A)) -0.133 ... RMSE = 0.09878859313854685 eV 

dE_D3 = + 0.108 * |IP(B)-EA(B)|/r_p(A)^2 -1.807 * |r_s(A)-r_p(B)|/exp(r_s(A)) -3.782 * |r_s(B)-r_p(B)|/exp(r_d(A)) -0.023 ... RMSE = 0.07559818375102052 eV 

dE_D4 = + 15.133 * r_s(B)/exp(r_p(A)+r_s(B))^2 -

# Some Poking Around

## A routine for finding an approximation of the lambda value when the number of descriptors changes

In [31]:
i=1
(D, DD, A1, A2, A3, B1, B2, B3, C3, D3, E3, F1, F2, F3, G) = feature_space_generation(dimers,AB)
D_normalized=preprocessing.normalize(D,axis=0)
LAMBDA_LENGTH = 100
LAMBDA=np.empty(LAMBDA_LENGTH,dtype=float)
LAMBDA[0]=(1/D.shape[0])*np.max((D_normalized.T).dot(dE))
q=m.pow(1000,1/(LAMBDA_LENGTH-1))

lasso = linear_model.Lasso(alpha=(LAMBDA[0]-0.0001*0),fit_intercept = True, normalize = True, max_iter=1e05)
lasso.fit(D,dE)
while(len(lasso.sparse_coef_.data)<20):
    lasso = linear_model.Lasso(alpha=(LAMBDA[0]-0.0001*i),fit_intercept = True, normalize = True, max_iter=1e05)
    lasso.fit(D,dE)
    i=i+1
print('Nejvetsi: ', np.max(np.absolute(lasso.sparse_coef_)))
print(np.absolute(lasso.sparse_coef_))
print("Lambda hodnota: ",LAMBDA[0]-0.0001*i, "Pocet iteraci: ",i)

A1:  0 ...  3
A2:  4  ...  7
A3:  8  ...  13
B1:  14  ...  25
B2:  26  ...  37
B3:  38  ...  67
C3:  68  ...  88
D3:  89  ...  109
E3:  110  ...  130
Nejvetsi:  68.45144161234207
  (0, 253)	0.03284668326594066
  (0, 553)	0.3736008229337588
  (0, 600)	0.0289746437931465
  (0, 629)	0.17836123653315425
  (0, 866)	0.011344040469060724
  (0, 1021)	0.35604958671834386
  (0, 2151)	0.026917209678096774
  (0, 2235)	0.009289316452989938
  (0, 2557)	0.04544692441729787
  (0, 2562)	6.033270337070058
  (0, 2717)	0.5231661715188844
  (0, 2764)	0.07088919831887554
  (0, 2769)	1.5846098025152748
  (0, 3022)	1.3054188292933626
  (0, 3097)	0.36794175287696324
  (0, 3110)	3.0067631926163414
  (0, 3399)	68.45144161234207
  (0, 3642)	0.126233021543595
  (0, 3902)	0.03090768629494805
  (0, 4125)	0.008963462506717867
Lambda hodnota:  0.0024813419226688907 Pocet iteraci:  441


## Pearson Correlation Coefficient Calculation of Suspicous Columns

In [28]:
print(DD[3110])
print(DD[3111])
print(DD[3112])
print()
#scipy.stats.pearsonr(D[:,3110],D[:,3111])
for i in it.combinations([D[:,3110], D[:,3111], D[:,3112]],2):
    print(scipy.stats.pearsonr(i[0],i[1]))
print()
for i in it.combinations([D[:,2557], D[:,2558], D[:,2559]],2):
    print(scipy.stats.pearsonr(i[0],i[1]))

|r_s(B)-r_p(B)|/exp(r_d(A)+r_s(B))
|r_s(B)-r_p(B)|/exp(r_d(A)+r_p(B))
|r_s(B)-r_p(B)|/exp(r_d(A)+r_d(B))

(0.9969983364853635, 1.147889135842077e-90)
(0.935572814265117, 6.49166249329293e-38)
(0.9168039305067941, 1.238745872964693e-33)

(0.928003508795371, 4.757784126976563e-36)
(0.8310629510601594, 4.407474079391752e-22)
(0.8872337921934383, 1.3175481682461208e-28)


# Cross Validation 10%, 90% (7, 75 split)

In [44]:
@ray.remote(num_return_vals=2)
def cross_validation_LASSO(cross_iter, POCET, LAMBDA_LENGTH, D, dE):
    # Vektory co budou drzet hodnoty pro kazdou cross validaci
    RMSE_CV = np.empty((4,cross_iter),dtype = float)
    MaxAE_CV = np.empty((4,cross_iter),dtype = float)

    for cv in range(cross_iter):
        # X_train = D_CV, y_train = dE_CV
        D_CV, X_test, dE_CV, y_test = model_selection.train_test_split(D, dE, test_size=7, random_state=cv, shuffle = True)
        
        ND = LASSO(POCET, LAMBDA_LENGTH, D_CV, dE_CV)
        # Pro 1D:
        y_predicted = np.ones((y_test.shape[0],1),dtype = float)*ND[0][2] + np.dot(X_test[:,ND[0][0]].reshape(-1, 1),ND[0][1].T)
        RMSE_CV[0,cv] = np.sqrt(metrics.mean_squared_error(y_test,y_predicted))
        MaxAE_CV[0,cv] = metrics.max_error(y_test,y_predicted)
        
        #Pro 2D, 3D, 4D:
        for j in range(2,5): # 2,3,4
            temporary = np.zeros((y_test.shape[0],j),dtype = float)
            for k in range(j):
                temporary[:,k] = X_test[:,ND[j-1][0][k]]
            y_predicted = np.ones((y_test.shape[0],1),dtype = float)*ND[j-1][2] + np.dot(temporary,ND[j-1][1].T)
            RMSE_CV[j-1,cv] = np.sqrt(metrics.mean_squared_error(y_test, y_predicted))
            MaxAE_CV[j-1,cv] = metrics.max_error(y_test, y_predicted)

    return RMSE_CV, MaxAE_CV

In [None]:
print("RMSE_CV: ",np.sum(RMSE_CV, axis = 1)/RMSE_CV.shape[1])
print("MaxAE_CV: ",np.sum(MaxAE_CV, axis = 1)/MaxAE_CV.shape[1])

# Tier Based Cross Validation 10%,90% <-> 7,75 split

In [6]:
ray.init(num_cpus = 16)
dimers, AB, dE = inicializace_dat()
tier0 = [True, False, False, False, False, False]
tier1 = [True, True, False, False, False, False]
tier2 = [True, True, True, False, False, False]
tier3 = [True, True, True, True, False, False]
tier4 = [True, True, True, True, True, False]
tier5 = [True, True, True, True, True, True]

vysledky_cv = {}

POCET = 35 # 30,35,40,45
LAMBDA_LENGTH = 100
cross_iter = 250

tiers = [tier0, tier1, tier2, tier3, tier4, tier5]

# Parallel forcycle:
for i in tiers:
    D, DD, A1, A2, A3, B1, B2, B3, C3, D3, E3, F1, F2, F3, G = feature_space_generation(False, 1, 0.1, dimers, AB, *i)
    #print("Tier", sum(i)-1)
    #print(D.shape)
    #print(len(DD))
    vysledky_cv[tiers.index(i)] = [cross_validation_LASSO(STATE, cross_iter, POCET, LAMBDA_LENGTH, D, dE)]
    
vysledky_cv_real = {}
for i in tiers:
    vysledky_cv_real[tiers.index(i)] = [ray.get(vysledky_cv[tiers.index(i)][0]), ray.get(vysledky_cv[tiers.index(i)][1])]
    
ray.shutdown()

Tier 0
Tier 1
Tier 2
Tier 3
Tier 4
Tier 5




[2m[33m(pid=raylet)[0m F0812 10:56:30.998903  7300 node_manager.cc:521]  Check failed: client_id != gcs_client_->client_table().GetLocalClientId() Exiting because this node manager has mistakenly been marked dead by the monitor.
[2m[33m(pid=raylet)[0m *** Check failure stack trace: ***
[2m[33m(pid=raylet)[0m     @           0x8403fa  google::LogMessage::Fail()
[2m[33m(pid=raylet)[0m     @           0x8417f3  google::LogMessage::SendToLog()
[2m[33m(pid=raylet)[0m     @           0x840122  google::LogMessage::Flush()
[2m[33m(pid=raylet)[0m     @           0x840311  google::LogMessage::~LogMessage()
[2m[33m(pid=raylet)[0m     @           0x5639a2  ray::RayLog::~RayLog()
[2m[33m(pid=raylet)[0m     @           0x47c3c0  ray::raylet::NodeManager::ClientRemoved()
[2m[33m(pid=raylet)[0m     @           0x4d1d3e  ray::gcs::ClientTable::HandleNotification()
[2m[33m(pid=raylet)[0m     @           0x4ee4ab  _ZNSt17_Function_handlerIFvPN3ray3gcs14RedisGcsClientERKNS0_8

TypeError: Attempting to call `get` on the value (array([[0.78443589, 0.17751511, 0.17801821, 0.25575877, 0.14513157,
        0.12093062, 0.78296763, 0.25403334, 0.1956431 , 0.19978634,
        0.46505386, 0.09893942, 0.19176791, 0.46112507, 0.44868458,
        0.23749606, 0.4688745 , 0.19739237, 0.19147791, 0.77971507,
        0.19870987, 0.16383376, 0.46422185, 0.20147323, 0.24425926,
        0.45844946, 0.21881809, 0.18699671, 0.18790385, 0.25027251,
        0.24791167, 0.21511257, 0.18279819, 0.18514324, 0.20920381,
        0.78892918, 0.20199895, 0.78226848, 0.15002011, 0.2603602 ,
        0.16268227, 0.15586586, 0.19070056, 0.15757902, 0.48606024,
        0.24758677, 0.23732191, 0.23073384, 0.15093899, 0.22816392,
        0.19695056, 0.20273389, 0.46748014, 0.12619959, 0.15518765,
        0.18917182, 0.78841666, 0.47251974, 0.24930064, 0.22304574,
        0.78031688, 0.78250023, 0.17623518, 0.15847355, 0.17187909,
        0.17697409, 0.18104637, 0.13210959, 0.16020213, 0.31153892,
        0.48397734, 0.25053477, 0.13645172, 0.17285587, 0.1675015 ,
        0.46953894, 0.12686443, 0.1971109 , 0.21526576, 0.16018668,
        0.17137102, 0.25351592, 0.45138808, 0.18094068, 0.78268163,
        0.18193328, 0.18974339, 0.20178101, 0.21411247, 0.24723391,
        0.18676646, 0.26851502, 0.19728728, 0.28721314, 0.1228141 ,
        0.12655686, 0.17173084, 0.44273046, 0.7875475 , 0.20858377,
        0.79720675, 0.21971115, 0.77461009, 0.22455998, 0.21137445,
        0.20298704, 0.20813034, 0.2842921 , 0.16614787, 0.19142571,
        0.15209848, 0.19385001, 0.21265757, 0.23996341, 0.2619378 ,
        0.24032804, 0.28696206, 0.78171494, 0.48173772, 0.22532409,
        0.20223871, 0.23045528, 0.79339938, 0.29344543, 0.23727131,
        0.10981789, 0.23382622, 0.14519611, 0.19518194, 0.1592459 ,
        0.15990786, 0.14049781, 0.20618317, 0.23239132, 0.25550447,
        0.15676132, 0.20089034, 0.22806422, 0.08123034, 0.78661793,
        0.8006901 , 0.26858157, 0.20902067, 0.24676812, 0.79076539,
        0.29016223, 0.22245656, 0.1632555 , 0.15121136, 0.17308817,
        0.28396557, 0.18278643, 0.2006926 , 0.23765413, 0.12004483,
        0.21640255, 0.1941986 , 0.25063204, 0.22230967, 0.18827178,
        0.21383101, 0.24031091, 0.1872234 , 0.28641039, 0.20056884,
        0.27854468, 0.8035376 , 0.22567579, 0.15653435, 0.09933673,
        0.22390709, 0.19507418, 0.1681237 , 0.18373832, 0.15909487,
        0.12912945, 0.24723527, 0.78134258, 0.20982634, 0.1665076 ,
        0.17606773, 0.13905901, 0.13440901, 0.51913231, 0.78620702,
        0.23442227, 0.20443796, 0.12477869, 0.20035304, 0.2301502 ,
        0.19095814, 0.19712629, 0.19364974, 0.21917486, 0.15588835,
        0.19555373, 0.15262104, 0.18553128, 0.17559508, 0.26455082,
        0.2573638 , 0.22523482, 0.23452183, 0.20672579, 0.2400707 ,
        0.13448474, 0.22516326, 0.46827538, 0.15635596, 0.80177448,
        0.78413621, 0.29584172, 0.21454102, 0.24173437, 0.15087799,
        0.24474503, 0.20233388, 0.21149829, 0.26874663, 0.1956289 ,
        0.18210607, 0.49402979, 0.20803577, 0.2000828 , 0.45902199,
        0.1851385 , 0.23781962, 0.26661407, 0.27959299, 0.24622844,
        0.1606549 , 0.19203241, 0.44804002, 0.1928629 , 0.14618313,
        0.45136442, 0.11273745, 0.22991435, 0.14686452, 0.17436846,
        0.26867089, 0.1589013 , 0.14763514, 0.1970054 , 0.17616023,
        0.78754555, 0.23837525, 0.17677167, 0.20625057, 0.24571001],
       [0.54487058, 0.11100379, 0.09736876, 0.14870549, 0.14361327,
        0.1434736 , 0.54097247, 0.12929922, 0.1625007 , 0.15928404,
        0.28826633, 0.10890301, 0.18935683, 0.31644484, 0.27575341,
        0.18454451, 0.25732902, 0.27191669, 0.1938956 , 0.53764182,
        0.17264684, 0.28598591, 0.26351601, 0.14564743, 0.14011624,
        0.34465967, 0.18576365, 0.12181282, 0.13207101, 0.18759026,
        0.17949468, 0.14721386, 0.11577248, 0.18212499, 0.23039647,
        0.56215353, 0.30915572, 0.54213781, 0.1523422 , 0.17428719,
        0.16555727, 0.10630942, 0.18174511, 0.15365766, 0.25619128,
        0.1881827 , 0.17187623, 0.18794942, 0.09068314, 0.16364583,
        0.29893573, 0.27051355, 0.27740735, 0.30645639, 0.09232693,
        0.16506412, 0.54467794, 0.26845519, 0.16408063, 0.16008768,
        0.55251472, 0.54051402, 0.14394572, 0.12530052, 0.29181152,
        0.18176771, 0.14406227, 0.12672599, 0.17141061, 0.38398271,
        0.29019538, 0.20550536, 0.10844617, 0.19456698, 0.14870944,
        0.26599998, 0.08719672, 0.15916521, 0.17661827, 0.15633091,
        0.28677139, 0.17551449, 0.26038389, 0.29363311, 0.53977202,
        0.12306929, 0.20386481, 0.19551565, 0.15541784, 0.17066305,
        0.12710418, 0.14559971, 0.21282204, 0.21300387, 0.29652956,
        0.17167447, 0.15046937, 0.25430409, 0.54163664, 0.16553623,
        0.55163921, 0.13525977, 0.53944552, 0.12614053, 0.11767574,
        0.18013425, 0.14537774, 0.21127162, 0.26951937, 0.18754755,
        0.13512684, 0.1573745 , 0.1913446 , 0.16341646, 0.1530599 ,
        0.15166653, 0.23268473, 0.55469261, 0.30135459, 0.15216207,
        0.13781216, 0.1348947 , 0.56093221, 0.17254858, 0.13396566,
        0.13348627, 0.16636536, 0.31323621, 0.14580329, 0.1429416 ,
        0.15116186, 0.13712804, 0.13952419, 0.14831994, 0.21610999,
        0.13130987, 0.13734594, 0.15909832, 0.10745924, 0.54266433,
        0.56259068, 0.3289794 , 0.12190675, 0.17112853, 0.54817237,
        0.18507883, 0.1916824 , 0.1803663 , 0.13722483, 0.1824564 ,
        0.23095629, 0.1847103 , 0.20751757, 0.29003791, 0.07990124,
        0.1437856 , 0.30652192, 0.20463008, 0.13523377, 0.14358394,
        0.19463749, 0.1688103 , 0.17566996, 0.18014346, 0.12192761,
        0.20117765, 0.57865467, 0.19866425, 0.12299197, 0.14395037,
        0.15866359, 0.1510237 , 0.13368804, 0.15738554, 0.1708297 ,
        0.17685335, 0.19125859, 0.53983904, 0.18143183, 0.11001096,
        0.15542835, 0.13265689, 0.14563821, 0.26652711, 0.54324609,
        0.18539643, 0.16262812, 0.08970975, 0.20595841, 0.13574101,
        0.12605622, 0.14339822, 0.31702433, 0.13796496, 0.12313902,
        0.15357282, 0.10346237, 0.1444086 , 0.13701479, 0.16203146,
        0.18909627, 0.20113225, 0.18719262, 0.17233838, 0.21132615,
        0.15434898, 0.22517201, 0.29715138, 0.17438084, 0.53446533,
        0.54381737, 0.19982672, 0.14696726, 0.15578091, 0.11504309,
        0.31774314, 0.14079464, 0.15672285, 0.21105452, 0.14013221,
        0.13237705, 0.2880714 , 0.19001276, 0.14163282, 0.25345693,
        0.20383478, 0.18623156, 0.15798177, 0.15786311, 0.19878643,
        0.12970141, 0.29932653, 0.26535009, 0.18863994, 0.09478698,
        0.26939491, 0.08863843, 0.17087716, 0.11644263, 0.13618573,
        0.17337161, 0.19935315, 0.15022092, 0.15651007, 0.10239099,
        0.56054927, 0.30899722, 0.32031271, 0.1903401 , 0.22250541],
       [0.46265854, 0.10355726, 0.07881906, 0.19328541, 0.12063015,
        0.19206115, 0.45767111, 0.07358655, 0.14804397, 0.16374742,
        0.21768979, 0.08419275, 0.26963989, 0.35189922, 0.20098167,
        0.12209946, 0.25445291, 0.3715832 , 0.23292283, 0.45318061,
        0.1284515 , 0.371815  , 0.18766152, 0.12091894, 0.10605542,
        0.37140791, 0.15398105, 0.18751286, 0.09379894, 0.14924727,
        0.14448571, 0.20088218, 0.06998712, 0.13022092, 0.315724  ,
        0.47993938, 0.35927572, 0.45923829, 0.16818766, 0.12270805,
        0.20347917, 0.11260678, 0.15088544, 0.23171141, 0.20522553,
        0.20875184, 0.13181725, 0.21083983, 0.08391302, 0.12071463,
        0.39660838, 0.35697168, 0.20737982, 0.36816367, 0.08953751,
        0.1800872 , 0.46222649, 0.23642464, 0.16145014, 0.08749579,
        0.544065  , 0.45753179, 0.09270358, 0.14383593, 0.39658932,
        0.23085533, 0.09703027, 0.15347623, 0.19349466, 0.51089157,
        0.20552494, 0.2529789 , 0.08389468, 0.2257801 , 0.14017299,
        0.19086082, 0.12012076, 0.22375483, 0.2074991 , 0.11814359,
        0.36476659, 0.0729499 , 0.17146198, 0.4050319 , 0.45758419,
        0.0811216 , 0.23752145, 0.18488193, 0.11327503, 0.13282271,
        0.10752537, 0.07799885, 0.21223646, 0.16565919, 0.3665388 ,
        0.15020086, 0.1435912 , 0.17567019, 0.4612989 , 0.09389339,
        0.47084888, 0.19673305, 0.45868258, 0.16425221, 0.09939501,
        0.13889775, 0.07109054, 0.20260969, 0.37287898, 0.13066276,
        0.10801248, 0.134362  , 0.24010216, 0.13979802, 0.06652563,
        0.20124235, 0.21612211, 0.47278894, 0.26580222, 0.13853807,
        0.1209842 , 0.17006862, 0.4780023 , 0.11012066, 0.17417698,
        0.11180496, 0.13070341, 0.38096829, 0.08333324, 0.07937414,
        0.10593594, 0.16289957, 0.11811362, 0.07646755, 0.20741969,
        0.07753697, 0.13658035, 0.1029455 , 0.08512144, 0.53207043,
        0.48106181, 0.39674882, 0.08359414, 0.15010661, 0.52681595,
        0.11251209, 0.16969239, 0.24168525, 0.09363004, 0.1503006 ,
        0.38078658, 0.23018376, 0.12076233, 0.38755219, 0.09201931,
        0.12409066, 0.39756049, 0.17332047, 0.10931003, 0.12249411,
        0.22217208, 0.17450105, 0.15464472, 0.17183844, 0.07509868,
        0.13821359, 0.49921644, 0.20697955, 0.10741144, 0.13272628,
        0.10215728, 0.13852415, 0.1166223 , 0.18725551, 0.21164149,
        0.1735194 , 0.15362182, 0.4579143 , 0.09502825, 0.09678985,
        0.12510586, 0.07066565, 0.16566692, 0.31618584, 0.46157924,
        0.12674796, 0.18849126, 0.07335376, 0.27628044, 0.17997468,
        0.08522703, 0.19294253, 0.38371361, 0.07297962, 0.13843436,
        0.10659751, 0.09305722, 0.1178909 , 0.16148369, 0.08916334,
        0.11134585, 0.13471351, 0.20714234, 0.2379797 , 0.24626194,
        0.14186683, 0.2313092 , 0.24147971, 0.16496082, 0.43690697,
        0.4653172 , 0.1382883 , 0.09047081, 0.10771946, 0.0654235 ,
        0.46581829, 0.11414567, 0.10414205, 0.3369048 , 0.12426259,
        0.10493464, 0.2480799 , 0.22169786, 0.18128395, 0.16733478,
        0.14364218, 0.12358582, 0.18090801, 0.11313638, 0.14008142,
        0.11238929, 0.3779464 , 0.20055194, 0.17270104, 0.06706748,
        0.20115388, 0.09265819, 0.18237863, 0.10000485, 0.07615271,
        0.13621972, 0.26185043, 0.273479  , 0.08513871, 0.07662335,
        0.54859712, 0.46474332, 0.40289132, 0.20109445, 0.19165656],
       [0.45639193, 0.13978158, 0.09246558, 0.18051407, 0.14018962,
        0.1597556 , 0.46275994, 0.11343115, 0.15804363, 0.22465179,
        0.19825978, 0.09030628, 0.24389388, 0.35314826, 0.20365403,
        0.14502257, 0.25220181, 0.37138394, 0.22001433, 0.43053832,
        0.15002391, 0.35677705, 0.16188824, 0.12876222, 0.11371129,
        0.36846853, 0.1644747 , 0.20646038, 0.12954415, 0.21688141,
        0.16288827, 0.21140068, 0.09008873, 0.14656528, 0.29119587,
        0.45939939, 0.36220601, 0.43892408, 0.14895082, 0.11324814,
        0.23732725, 0.15042754, 0.19140752, 0.23044179, 0.2054178 ,
        0.19125274, 0.13717545, 0.20225473, 0.14053565, 0.18069915,
        0.38064295, 0.35612227, 0.22741836, 0.35808921, 0.09765611,
        0.14665467, 0.43540583, 0.24533397, 0.14807726, 0.11576321,
        0.53258594, 0.43803762, 0.14409183, 0.12145039, 0.37463502,
        0.21515127, 0.10527194, 0.15236389, 0.17946321, 0.51843317,
        0.28073792, 0.23693116, 0.10535752, 0.20841447, 0.19171254,
        0.20115973, 0.13121708, 0.23014995, 0.20185985, 0.12275861,
        0.37883597, 0.09638407, 0.1730159 , 0.39573325, 0.46313752,
        0.09144244, 0.21665489, 0.208492  , 0.14465554, 0.12369656,
        0.13025226, 0.08943193, 0.1879717 , 0.20997   , 0.36989819,
        0.11333687, 0.15524199, 0.17875773, 0.4676811 , 0.06526253,
        0.44963877, 0.18344521, 0.43501242, 0.14340531, 0.12276817,
        0.16094478, 0.11061317, 0.18269784, 0.3675448 , 0.10850617,
        0.14583564, 0.11143861, 0.21669278, 0.16679977, 0.1128963 ,
        0.19389807, 0.20406994, 0.45688724, 0.24971936, 0.13957528,
        0.14440818, 0.15399427, 0.45371723, 0.11362313, 0.19973306,
        0.09144092, 0.13309425, 0.3736496 , 0.07639596, 0.09875693,
        0.16202885, 0.16014367, 0.0918157 , 0.08232397, 0.21296739,
        0.09637632, 0.15525771, 0.14352881, 0.08932841, 0.46362229,
        0.48407227, 0.37758997, 0.07881696, 0.15306613, 0.51965803,
        0.118337  , 0.15253101, 0.22169996, 0.13222078, 0.16814341,
        0.36249392, 0.22226853, 0.08125822, 0.37709433, 0.16636163,
        0.14656262, 0.38294108, 0.24057125, 0.15137996, 0.16127367,
        0.21275216, 0.16227927, 0.17000872, 0.15723155, 0.12321445,
        0.16948375, 0.47664632, 0.19540365, 0.10606754, 0.1379996 ,
        0.12227265, 0.13322129, 0.151265  , 0.18368003, 0.2033518 ,
        0.1703737 , 0.1470717 , 0.4357131 , 0.24206896, 0.10258531,
        0.16687136, 0.14665871, 0.15245246, 0.31157558, 0.46930297,
        0.17391145, 0.18519098, 0.08915628, 0.26773694, 0.19159217,
        0.11548061, 0.18238719, 0.36287631, 0.07771001, 0.13723184,
        0.09214196, 0.11648423, 0.13034781, 0.15864982, 0.12435788,
        0.11636478, 0.14174753, 0.22051304, 0.24333924, 0.2388572 ,
        0.16126402, 0.24362382, 0.20868292, 0.13515621, 0.46063462,
        0.5335585 , 0.14064625, 0.12127384, 0.10874649, 0.08843597,
        0.46783817, 0.10650527, 0.11333342, 0.31003229, 0.14313207,
        0.15282312, 0.24690863, 0.21986021, 0.16378941, 0.17815267,
        0.16972486, 0.15352708, 0.16866158, 0.1125586 , 0.14169046,
        0.18668575, 0.35873514, 0.22727961, 0.14795514, 0.09806729,
        0.21269006, 0.08867574, 0.15716371, 0.10637451, 0.10480805,
        0.14180623, 0.22655082, 0.24613401, 0.13397194, 0.0900087 ,
        0.54089452, 0.45831629, 0.37565014, 0.17968457, 0.15860559]]), array([[2.03157556, 0.28018635, 0.28957611, 0.40628426, 0.27466641,
        0.28046628, 2.03729762, 0.39198882, 0.28560658, 0.26943977,
        1.14778828, 0.14462799, 0.3607758 , 1.14723782, 1.15085503,
        0.40852199, 1.16476014, 0.28999257, 0.33603291, 2.03777433,
        0.35275235, 0.29274848, 1.12921062, 0.4037939 , 0.39494166,
        1.15201547, 0.34644562, 0.42253187, 0.3961698 , 0.3730438 ,
        0.42022009, 0.41755231, 0.35501767, 0.29221079, 0.36500553,
        2.04666815, 0.40651851, 2.03972674, 0.3490886 , 0.41485624,
        0.26401656, 0.26476561, 0.29381275, 0.27421573, 1.16940519,
        0.35329823, 0.37851171, 0.35988944, 0.28447895, 0.39482172,
        0.29595077, 0.36636194, 1.13678561, 0.280926  , 0.26651849,
        0.38549996, 2.03605658, 1.14135152, 0.39820721, 0.3876181 ,
        2.03250718, 2.03614382, 0.3658    , 0.24019258, 0.29365537,
        0.36010425, 0.24931371, 0.20213318, 0.27728896, 0.45353342,
        1.13297163, 0.40640645, 0.26015535, 0.35696548, 0.29296169,
        1.13957458, 0.23661157, 0.39004219, 0.36393131, 0.26136359,
        0.27213691, 0.38142827, 1.13074601, 0.36199655, 2.02977297,
        0.40282057, 0.36251035, 0.28297895, 0.35910782, 0.37806924,
        0.38724216, 0.39102259, 0.35471024, 0.38022134, 0.21849642,
        0.20132934, 0.27797689, 1.14422498, 2.02186156, 0.37030753,
        2.03436993, 0.41532168, 2.03759506, 0.40913402, 0.28782768,
        0.36043288, 0.35963074, 0.39454718, 0.2707861 , 0.40330684,
        0.36548287, 0.36199189, 0.35521957, 0.35691396, 0.42476326,
        0.41342551, 0.41184165, 2.04158377, 1.15629936, 0.34728334,
        0.3614209 , 0.41548365, 2.05424007, 0.40715876, 0.41468508,
        0.18389735, 0.36928129, 0.25686981, 0.3582554 , 0.22983822,
        0.28480572, 0.20203014, 0.38305978, 0.4319146 , 0.34741636,
        0.35088292, 0.35208872, 0.39669662, 0.11171892, 2.02083098,
        2.04093302, 0.4313773 , 0.36765537, 0.3826874 , 2.01657296,
        0.41704709, 0.34131259, 0.27387138, 0.2846775 , 0.35596558,
        0.44296751, 0.26186432, 0.36434673, 0.4063514 , 0.18790401,
        0.35895103, 0.36405428, 0.35453665, 0.37242121, 0.36331263,
        0.35822851, 0.38591893, 0.33525435, 0.40789889, 0.36517915,
        0.39140399, 2.06386804, 0.35568769, 0.20706221, 0.23095899,
        0.3494721 , 0.36572558, 0.29573796, 0.36268364, 0.35652697,
        0.22258124, 0.3859915 , 2.02741845, 0.3704504 , 0.29211127,
        0.28923872, 0.28205882, 0.22218975, 1.17748105, 2.0327613 ,
        0.29395066, 0.29695596, 0.27751579, 0.35522063, 0.4140505 ,
        0.36426906, 0.4207252 , 0.35669875, 0.34943757, 0.23784456,
        0.37073702, 0.26219375, 0.3640063 , 0.37112765, 0.39455259,
        0.39492232, 0.40540147, 0.41181923, 0.38085779, 0.35401695,
        0.27465371, 0.34689361, 1.14943794, 0.25731905, 2.05871932,
        2.03465113, 0.44668981, 0.3715267 , 0.37410383, 0.22737151,
        0.39944171, 0.35788927, 0.34848165, 0.41931641, 0.36389701,
        0.35071376, 1.14623162, 0.35041671, 0.42296885, 1.13264199,
        0.2893396 , 0.40278739, 0.40499775, 0.41654203, 0.3763236 ,
        0.29311933, 0.36481991, 1.14295448, 0.2979001 , 0.35840508,
        1.15170405, 0.18914167, 0.40712983, 0.27326228, 0.28165941,
        0.41732189, 0.28262856, 0.27803733, 0.35765483, 0.38097172,
        2.0484581 , 0.41519854, 0.3521773 , 0.35924681, 0.39954622],
       [1.40347713, 0.15284435, 0.14700394, 0.26509757, 0.29755328,
        0.2454851 , 1.41457078, 0.18330626, 0.30619361, 0.27150814,
        0.62942649, 0.17500635, 0.32817099, 0.57453296, 0.61717549,
        0.29329775, 0.59739008, 0.64360706, 0.32786107, 1.41284876,
        0.25964892, 0.66053106, 0.61771178, 0.29970541, 0.21662308,
        0.585378  , 0.30117526, 0.20006003, 0.21266642, 0.31107869,
        0.3024298 , 0.23956215, 0.18368432, 0.28859974, 0.38170665,
        1.42150688, 0.68022696, 1.40970546, 0.24250353, 0.29133605,
        0.2651771 , 0.20492835, 0.26366267, 0.31280262, 0.61041225,
        0.40000505, 0.25776242, 0.39779179, 0.15205295, 0.29314252,
        0.65880782, 0.66216303, 0.62067943, 0.67417873, 0.13914792,
        0.29738991, 1.41608428, 0.55978655, 0.256035  , 0.26502744,
        1.40372286, 1.41418095, 0.21470932, 0.23449475, 0.66313584,
        0.40450463, 0.26190362, 0.23247377, 0.30055875, 0.73482474,
        0.6215484 , 0.30240354, 0.21399643, 0.39901688, 0.26888602,
        0.62854023, 0.17280815, 0.32726011, 0.40267794, 0.28489086,
        0.67646991, 0.25081087, 0.59795107, 0.63411727, 1.40799251,
        0.28887465, 0.40540333, 0.3030109 , 0.26541445, 0.25571336,
        0.21842482, 0.19337442, 0.40079302, 0.30893815, 0.66245143,
        0.29802099, 0.30785985, 0.61516048, 1.40093656, 0.29579349,
        1.41856781, 0.26479495, 1.40315045, 0.17342039, 0.15836497,
        0.29534095, 0.23523345, 0.39856819, 0.63644575, 0.28317428,
        0.2142442 , 0.25447277, 0.40420471, 0.27113864, 0.24924096,
        0.30341528, 0.40104306, 1.40746612, 0.64289907, 0.26883362,
        0.26516914, 0.25412938, 1.43906678, 0.26116592, 0.22375958,
        0.22133896, 0.27026495, 0.67018655, 0.26197249, 0.2684835 ,
        0.26906857, 0.23426128, 0.18507211, 0.29520679, 0.39604725,
        0.24536493, 0.24105653, 0.29656006, 0.20812376, 1.38772697,
        1.34730035, 0.66849155, 0.20366085, 0.30627595, 1.39305446,
        0.3018612 , 0.25833569, 0.32839146, 0.27095107, 0.29120468,
        0.38045696, 0.34219147, 0.30163285, 0.66045085, 0.13836894,
        0.24471888, 0.6600966 , 0.3117639 , 0.22166986, 0.21425429,
        0.40586525, 0.23511108, 0.23122043, 0.30409352, 0.20000599,
        0.27011987, 1.44530341, 0.40152718, 0.19259484, 0.25087354,
        0.25381896, 0.30011974, 0.26563167, 0.25911307, 0.39863733,
        0.26215485, 0.25890409, 1.40482191, 0.27690563, 0.17902102,
        0.28606095, 0.26941395, 0.25294788, 0.56853503, 1.41141344,
        0.27733038, 0.24388288, 0.15277376, 0.40285402, 0.17206527,
        0.20438977, 0.19193154, 0.67352569, 0.23919483, 0.26353947,
        0.20660527, 0.13869435, 0.26135618, 0.24298786, 0.25526624,
        0.30181232, 0.30022054, 0.30602834, 0.31403871, 0.39744509,
        0.29532551, 0.39121714, 0.62722482, 0.29143227, 1.39066705,
        1.39535893, 0.30143271, 0.21590023, 0.2522827 , 0.17408236,
        0.70270077, 0.24311005, 0.242361  , 0.35570669, 0.30203201,
        0.2451131 , 0.58571335, 0.39802318, 0.29497413, 0.61616892,
        0.28914931, 0.30653108, 0.24736086, 0.19579926, 0.27185357,
        0.23362242, 0.66682801, 0.62318083, 0.29497422, 0.16408143,
        0.63178209, 0.13461289, 0.26434349, 0.25484687, 0.25619823,
        0.30240288, 0.32289594, 0.30923989, 0.26430606, 0.19764707,
        1.42790981, 0.69154954, 0.66339563, 0.4042748 , 0.30708959],
       [1.17842653, 0.14812873, 0.14241382, 0.4107411 , 0.30021369,
        0.40856135, 1.19470931, 0.13622432, 0.30766245, 0.26871843,
        0.42634553, 0.17620261, 0.58074724, 0.87626621, 0.41993146,
        0.18193176, 0.51117936, 0.92469877, 0.57254529, 1.1890298 ,
        0.25519269, 0.93379826, 0.42098593, 0.30135872, 0.17445663,
        0.87846927, 0.30454901, 0.40624697, 0.20014789, 0.3086404 ,
        0.20762257, 0.40612415, 0.09088048, 0.20629543, 0.5746477 ,
        1.17509613, 0.9328863 , 1.1879741 , 0.40171275, 0.19099601,
        0.38943847, 0.19346743, 0.26039491, 0.57261013, 0.35471681,
        0.51182235, 0.25362431, 0.51386018, 0.15487386, 0.21901014,
        0.93483596, 0.93159202, 0.41876241, 0.9306514 , 0.13376981,
        0.40231287, 1.20371326, 0.51829348, 0.401289  , 0.11806166,
        1.35359298, 1.20056179, 0.16510562, 0.31997333, 0.93576327,
        0.52176966, 0.17497343, 0.31076518, 0.39179974, 1.08349549,
        0.42512047, 0.41765854, 0.12654351, 0.51263163, 0.26580975,
        0.42338653, 0.18590175, 0.5764836 , 0.52050846, 0.19760543,
        0.93957729, 0.11511802, 0.41543315, 0.91388491, 1.19269142,
        0.10970939, 0.52134511, 0.30276911, 0.18810169, 0.3096425 ,
        0.2074621 , 0.13560794, 0.51283253, 0.30958205, 0.92371406,
        0.30621227, 0.31020888, 0.41387389, 1.20086649, 0.20380907,
        1.20839882, 0.41210022, 1.19876069, 0.41079245, 0.15094663,
        0.20382895, 0.16534455, 0.51098598, 0.91767775, 0.19567823,
        0.20505926, 0.30543794, 0.52635112, 0.26821242, 0.1105423 ,
        0.41127605, 0.5076735 , 1.17034834, 0.52461854, 0.26671655,
        0.26368238, 0.40550451, 1.20447389, 0.19231546, 0.40629588,
        0.20796695, 0.26465413, 0.9275957 , 0.11546342, 0.11424019,
        0.17511232, 0.313812  , 0.17693886, 0.12601808, 0.509294  ,
        0.1147848 , 0.21216433, 0.20299274, 0.19728898, 1.30541964,
        0.99468731, 0.93589445, 0.17609907, 0.30880447, 1.32155863,
        0.15048704, 0.30269938, 0.57603012, 0.14161905, 0.29403549,
        0.67156971, 0.57660217, 0.21083411, 0.9378764 , 0.12532312,
        0.21556957, 0.93736578, 0.31157863, 0.18305257, 0.20515381,
        0.5183899 , 0.38714638, 0.21967901, 0.40250411, 0.13788329,
        0.26556603, 1.20067884, 0.51047519, 0.19270055, 0.25128288,
        0.14227317, 0.3031676 , 0.26178551, 0.39294713, 0.52229532,
        0.39529794, 0.30427842, 1.19600092, 0.19597016, 0.17890467,
        0.2101619 , 0.14232893, 0.38943383, 0.63249934, 1.20196541,
        0.18374888, 0.39187577, 0.14871981, 0.533682  , 0.40471781,
        0.17650347, 0.40570991, 0.94055776, 0.12900723, 0.26041068,
        0.19302058, 0.13587199, 0.25989204, 0.39784533, 0.13624239,
        0.20789058, 0.30099599, 0.41158681, 0.57099275, 0.51214658,
        0.29672048, 0.50468683, 0.42618097, 0.30696823, 1.09741236,
        1.19813622, 0.30438722, 0.20828073, 0.17542362, 0.11932422,
        1.0509266 , 0.18618724, 0.14672857, 0.67635119, 0.30320078,
        0.1923622 , 0.51080925, 0.52168026, 0.408858  , 0.41642686,
        0.25999341, 0.20937557, 0.40454019, 0.19307836, 0.2668962 ,
        0.20339671, 0.93498313, 0.41175658, 0.39251415, 0.10501677,
        0.42391142, 0.12321966, 0.41359047, 0.17627252, 0.13408421,
        0.30471969, 0.57042945, 0.58379846, 0.13669452, 0.11335158,
        1.35929065, 1.04640319, 0.93936809, 0.5128962 , 0.41662736],
       [1.13465658, 0.24513627, 0.16680251, 0.3488342 , 0.30802795,
        0.33906124, 1.209279  , 0.2382618 , 0.32170567, 0.4148761 ,
        0.4321627 , 0.15153925, 0.51761187, 0.86665129, 0.43032511,
        0.24278834, 0.50196449, 0.91515231, 0.49813253, 1.12591282,
        0.26347352, 0.86901774, 0.27497913, 0.31658767, 0.19974241,
        0.8685186 , 0.31422426, 0.39245586, 0.23636376, 0.45480206,
        0.22294014, 0.3324431 , 0.15621702, 0.22352698, 0.51659735,
        1.12145759, 0.92641185, 1.12366748, 0.33103797, 0.18239347,
        0.39927064, 0.2031855 , 0.32882829, 0.48955354, 0.36670497,
        0.45275759, 0.26405901, 0.45055743, 0.19870464, 0.34152943,
        0.85903172, 0.92331398, 0.42999844, 0.86500446, 0.15776559,
        0.33723284, 1.1398383 , 0.51112402, 0.33658933, 0.21586723,
        1.29576008, 1.13554114, 0.25475164, 0.253172  , 0.87334813,
        0.45328855, 0.15964701, 0.24677614, 0.33167636, 1.07913278,
        0.49465037, 0.3549473 , 0.18878227, 0.45402004, 0.34151271,
        0.43460251, 0.24285063, 0.56877389, 0.45987909, 0.17697636,
        0.93343901, 0.19487993, 0.42359787, 0.90857066, 1.21652755,
        0.14029857, 0.46158106, 0.32050268, 0.31355276, 0.24964442,
        0.25289025, 0.16348113, 0.45233462, 0.33881947, 0.9157146 ,
        0.25082108, 0.32355492, 0.42336245, 1.19574235, 0.09402821,
        1.13940074, 0.35275499, 1.14177832, 0.35027558, 0.16642049,
        0.31253007, 0.24371479, 0.45126352, 0.91022355, 0.18308501,
        0.26340735, 0.24834527, 0.46382027, 0.28298559, 0.21060407,
        0.34384477, 0.44645157, 1.11236529, 0.45562752, 0.27760893,
        0.27413394, 0.34051504, 1.1411746 , 0.16380601, 0.38914243,
        0.19296515, 0.27651085, 0.92370408, 0.11489452, 0.14218743,
        0.3134406 , 0.25202226, 0.15580259, 0.13410825, 0.43588383,
        0.13859473, 0.32747043, 0.17654181, 0.18789785, 1.12504334,
        1.01030251, 0.87219582, 0.15146216, 0.32122404, 1.29480069,
        0.23878252, 0.26302296, 0.49932393, 0.30916923, 0.3044855 ,
        0.66538065, 0.501664  , 0.1510569 , 0.93170627, 0.19816457,
        0.32157558, 0.92974513, 0.4656115 , 0.20189305, 0.24883387,
        0.45998188, 0.31893796, 0.26142942, 0.34316291, 0.24805322,
        0.32870972, 1.15045992, 0.44711963, 0.15848257, 0.26226167,
        0.20050338, 0.31853008, 0.27465403, 0.32488939, 0.46660737,
        0.33510268, 0.26647567, 1.13878544, 0.48539793, 0.15753584,
        0.33752341, 0.3269055 , 0.32465782, 0.62054613, 1.2109102 ,
        0.31489471, 0.32514718, 0.15245657, 0.46985033, 0.39228543,
        0.1916281 , 0.34219903, 0.88378638, 0.11385226, 0.27292118,
        0.16055782, 0.21833584, 0.2728702 , 0.3304855 , 0.24616075,
        0.23532636, 0.31089394, 0.3405956 , 0.56336901, 0.44948475,
        0.30796524, 0.42614199, 0.28565038, 0.25251656, 1.19625041,
        1.35562674, 0.31603334, 0.2482006 , 0.14826618, 0.16267957,
        1.04172647, 0.15068712, 0.19545363, 0.59424117, 0.31502069,
        0.20836201, 0.50394055, 0.45957508, 0.33774127, 0.42469138,
        0.31711626, 0.28581082, 0.34020049, 0.16696205, 0.27735546,
        0.32086427, 0.87532547, 0.43061778, 0.3268605 , 0.15384782,
        0.43636341, 0.14709374, 0.35291384, 0.23440725, 0.23415665,
        0.31710247, 0.5088691 , 0.51838104, 0.30702904, 0.13928975,
        1.2901988 , 1.03595267, 0.891672  , 0.45606945, 0.35520406]])), which is not an ray.ObjectID.

In [23]:
import pickle
tier0 = [True, False, False, False, False, False]
tier1 = [True, True, False, False, False, False]
tier2 = [True, True, True, False, False, False]
tier3 = [True, True, True, True, False, False]
tier4 = [True, True, True, True, True, False]
tier5 = [True, True, True, True, True, True]

tiers = [tier0, tier1, tier2, tier3, tier4, tier5]

loaded = pickle.load( open( "object_id_clean.p", "rb" ) )

for i in tiers:
    print("Tier", sum(i)-1)
    print("RMSE_CV: ", np.round(np.sum(loaded[tiers.index(i)][0][0], axis = 1)/loaded[tiers.index(i)][0][0].shape[1], 2), "eV.")
    print("MaxAE_CV: ", np.round(np.sum(loaded[tiers.index(i)][0][1], axis = 1)/loaded[tiers.index(i)][0][0].shape[1], 2), "eV.")
    print()
    
print("Std of errors:")
for i in tiers:
    print("Tier", sum(i)-1)
    print(np.sqrt(np.sum((loaded[tiers.index(i)][0][0] - np.array([np.sum(loaded[tiers.index(i)][0][0], axis = 1)/loaded[tiers.index(i)][0][0].shape[1]]).T)**2, axis = 1)/loaded[tiers.index(i)][0][0].shape[1]))
    print(np.sqrt(np.sum((loaded[tiers.index(i)][0][1] - np.array([np.sum(loaded[tiers.index(i)][0][1], axis = 1)/loaded[tiers.index(i)][0][1].shape[1]]).T)**2, axis = 1)/loaded[tiers.index(i)][0][1].shape[1]))
    print()

Tier 0
RMSE_CV:  [0.28 0.22 0.2  0.21] eV.
MaxAE_CV:  [0.56 0.44 0.45 0.45] eV.

Tier 1
RMSE_CV:  [0.2  0.22 0.16 0.14] eV.
MaxAE_CV:  [0.42 0.45 0.33 0.29] eV.

Tier 2
RMSE_CV:  [0.15 0.13 0.14 0.09] eV.
MaxAE_CV:  [0.29 0.25 0.28 0.16] eV.

Tier 3
RMSE_CV:  [0.16 0.1  0.08 0.08] eV.
MaxAE_CV:  [0.31 0.17 0.14 0.16] eV.

Tier 4
RMSE_CV:  [0.17 0.13 0.11 0.11] eV.
MaxAE_CV:  [0.34 0.24 0.21 0.23] eV.

Tier 5
RMSE_CV:  [0.19 0.14 0.12 0.1 ] eV.
MaxAE_CV:  [0.38 0.29 0.25 0.21] eV.



In [39]:
# LASSO models with tiered feature spaces
dimers, AB, dE = inicializace_dat()
tier0 = [True, False, False, False, False, False]
tier1 = [True, True, False, False, False, False]
tier2 = [True, True, True, False, False, False]
tier3 = [True, True, True, True, False, False]
tier4 = [True, True, True, True, True, False]
tier5 = [True, True, True, True, True, True]

tiers = [tier0, tier1, tier2, tier3, tier4, tier5]

tier_model_output = {}

POCET = 35 # 30,35,40,45
LAMBDA_LENGTH = 100

for i in tiers:
    D, DD, A1, A2, A3, B1, B2, B3, C3, D3, E3, F1, F2, F3, G = feature_space_generation(False, 1, 0.1, dimers, AB, *i)
    print("Tier", sum(i)-1)
    print(D.shape)
    print(len(DD))
    tier_model_output[tiers.index(i)] = LASSO(POCET, LAMBDA_LENGTH, D, dE)
    
pickle.dump( tier_model_output, open( "tier_model.p", "wb" ) )

Tier 0
(82, 14)
14
Tier 1
(82, 164)
164
Tier 2
(82, 596)
596
Tier 3
(82, 1669)
1669
Tier 4
(82, 3566)
3566
Tier 5
(82, 4376)
4376


In [17]:
#pd.read_csv('RMSE_CV.csv',header=None)

# Sensitivity Analysis

In [18]:
sigma = [0.001, 0.01, 0.03, 0.05, 0.1, 0.13, 0.3] # noise
gauss = np.random.normal(1,sigma,14)
delta = [0, 0.01, 0.03, 0.10, 0.20]

dimers, AB, dE = inicializace_dat()
D_noised, DD, A1, A2, A3, B1, B2, B3, C3, D3, E3, F1, F2, F3, G = feature_space_generation(True, 1, sigma[0], dimers, AB, True, True, True, True, True, True)




dE_noised = dE + np.random.uniform(-delta, delta, dE.shape[0])

In [19]:
len(feature_space_generation(inicializace_dat()[0],inicializace_dat()[1])[1])

A1:  0 ...  3
A2:  4  ...  7
A3:  8  ...  13
B1:  14  ...  25
B2:  26  ...  37
B3:  38  ...  67
C3:  68  ...  88
D3:  89  ...  109
E3:  110  ...  130


4376

In [7]:
gauss

array([1.00056379])

In [8]:
lul = [i for i in range(21)]

In [11]:
len(lul[:6])

6