In [16]:
# Загрузка библиотек
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.svm import SVR

In [54]:
# Ссылки к файлам
links = {
    'elements': r'..\Bandgap prediction\Data\Elements\elements for ML.csv',
    'materials': r'D:\Development\ML for new materials discovery\ML-for-new-materials-discovery\Bandgap prediction\Data\Bandgap oxides and nitrides only.csv',
    'materials_for_predict': r'D:\Development\ML for new materials discovery\ML-for-new-materials-discovery\Bandgap prediction\Data\oxides and nitrides  for predict.csv'
}

### Подготовка данных по элементам

In [18]:
elements = pd.read_csv(filepath_or_buffer=links['elements'], index_col='Symbol')
# Переименовываю признаки
elements.rename(columns={
    'Atomic Number':'Atomic_Number',
    'NUMBER OF Electrons at last orbitale': 'NUMBER_OF_Electrons_at_last_orbitale',
    'NUMBER OF Electrons at before last orbitale':'NUMBER_OF_Electrons_at_before_last_orbitale',
    'NUMBER OF electrones at last level':'NUMBER_OF_electrones_at_last_level',
    'NUMBER OF vacancies at  outer orbitale':'NUMBER_OF_vacancies_at_outer_orbitale',
    'Number of active electrons at inner level':'Number_of_active_electrons_at_inner_level',
    'Max valency':'Max_valency',
    'Atomic Mass':'Atomic_Mass',
    'Atomic radius (pm)':'Atomic_radius',
    'Covalent radius (pm)':'Covalent_radius',
    'Ionization potential (eV)':'Ionization_potential',
    'Electron affinity (KJ/mol)':'Electron_affinity'
    }, inplace=True)
elements.head()

Unnamed: 0_level_0,Atomic_Number,NUMBER_OF_Electrons_at_last_orbitale,NUMBER_OF_Electrons_at_before_last_orbitale,NUMBER_OF_electrones_at_last_level,NUMBER_OF_vacancies_at_outer_orbitale,Number_of_active_electrons_at_inner_level,Max_valency,Atomic_Mass,Electronegativity,Atomic_radius,Covalent_radius,Ionization_potential,Electron_affinity,Period,Group,Block
Symbol,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
Li,3,1,0,1,1,0,1,7.0,0.98,145.0,134,5.392,59.6,2,0,0
B,5,1,2,3,5,0,1,10.81,2.04,98.0,82,8.298,26.7,2,1,1
C,6,2,2,4,4,0,2,12.011,2.55,77.0,77,11.26,153.9,2,2,1
N,7,3,2,5,3,0,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
Na,11,1,0,2,1,0,1,22.989769,0.93,190.0,154,5.139,52.8,3,0,0


In [7]:
# создаю словарь для обозначения групп и блоков
group_notation_dictionary = {i[1]:i[0] for i in enumerate(elements.Group.unique())}  # Группы
block_notation_dictionary = {i[1]:i[0] for i in enumerate(elements.Block.unique())}  # Блоки
print(f'Group\n{group_notation_dictionary}\nBlock{block_notation_dictionary}')

# заменяю обозначение групп на соответствующие числовые индикаторы
elements.Group = list(map(lambda x: group_notation_dictionary[x], elements.Group.values))
elements.Block = list(map(lambda x: block_notation_dictionary[x], elements.Block.values))

Group
{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14, 15: 15}
Block{0: 0, 1: 1, 2: 2}


In [19]:
elements.info()

<class 'pandas.core.frame.DataFrame'>
Index: 63 entries, Li to Ag
Data columns (total 16 columns):
 #   Column                                       Non-Null Count  Dtype  
---  ------                                       --------------  -----  
 0   Atomic_Number                                63 non-null     int64  
 1   NUMBER_OF_Electrons_at_last_orbitale         63 non-null     int64  
 2   NUMBER_OF_Electrons_at_before_last_orbitale  63 non-null     int64  
 3   NUMBER_OF_electrones_at_last_level           63 non-null     int64  
 4   NUMBER_OF_vacancies_at_outer_orbitale        63 non-null     int64  
 5   Number_of_active_electrons_at_inner_level    63 non-null     int64  
 6   Max_valency                                  63 non-null     int64  
 7   Atomic_Mass                                  63 non-null     float64
 8   Electronegativity                            63 non-null     float64
 9   Atomic_radius                                63 non-null     float64
 10  Covalent

In [20]:
# Редактирую отображение дробных чисел
elements.Electronegativity = [float(re.sub(',', '.', i)) for i in elements.Electronegativity]
elements.Ionization_potential = [float(re.sub(',', '.', i)) for i in elements.Ionization_potential]
elements.Electron_affinity = [float(re.sub(',', '.', i)) for i in elements.Electron_affinity]
elements.Atomic_radius = [float(re.sub(',', '.', i)) for i in elements.Atomic_radius]

### Подготовка данных по соединенийм

In [160]:
# загрузка данных по соединениям
compounds = pd.read_csv(filepath_or_buffer=links['materials'], index_col='Compound', sep=';')
# Переименовываю признак
compounds.rename(columns={'Band gap, eV':'Band_gap'}, inplace=True)
compounds

Unnamed: 0_level_0,Band_gap
Compound,Unnamed: 1_level_1
Mn2SiO4,325
Mn4SiO7,152
Mn7SiO12,315
AlSi2O5,91
Fe2SiO4,78
...,...
Ca2ZnN2,19
CaZn2N2,19
Li3ScN2,29
NaSnN,25


In [161]:
new_compaunds = pd.read_csv(
    filepath_or_buffer=r'D:\Development\ML for new materials discovery\ML-for-new-materials-discovery\Bandgap prediction\Data\New nitrides from Pauling database.csv',
    sep=';', index_col='Compound')

In [162]:
d = pd.concat([compounds, new_compaunds])

In [163]:
indexU = np.unique(d.index)

In [164]:
compounds = d.loc[indexU]
compounds.Band_gap = [float(re.sub(',', '.', i)) for i in compounds.Band_gap]

In [165]:
# регулярное выражение для разделения соединения на элементы
reg_compaund = r'(?P<element_1>[A-Z](?![a-z])+|[A-Z][a-z]+)(?P<number_1>\d+|)(?P<element_2>[A-Z](?![a-z])+|[A-Z][a-z]+)(?P<number_2>\d+|)(?P<element_3>[A-Z](?![a-z])+|[A-Z][a-z]+)(?P<number_3>\d+|)'

In [166]:
data = pd.DataFrame(data=compounds, index=compounds.index)

In [167]:
# тест
comp = compounds.index[0]
r = re.search(reg_compaund, comp).groupdict()
# [i for i in elements.index]
r

{'element_1': 'Ag',
 'number_1': '3',
 'element_2': 'P',
 'number_2': '',
 'element_3': 'O',
 'number_3': '4'}

In [168]:
# добавлю метки для себя 
data.insert(loc=0, column='a', value=0)
data.insert(loc=1,column='b', value=0)
data.insert(loc=2, column='c', value=0)

In [169]:
# Добавляю признаки по каждому элементу (a, b, c) соединения 
for parameter in elements.columns[::-1]:
    data.insert(loc=data.columns.get_loc('a')+1, column='{f_1}_a'.format(f_1=parameter),
                value=[elements.at[re.search(reg_compaund, comp).groupdict()['element_1'], parameter] for comp in data.index])
    data.insert(loc=data.columns.get_loc('b')+1, column='{f_1}_b'.format(f_1=parameter),
                value=[elements.at[re.search(reg_compaund, comp).groupdict()['element_2'], parameter] for comp in data.index])
    data.insert(loc=data.columns.get_loc('c')+1, column='{f_1}_c'.format(f_1=parameter),
                value=[elements.at[re.search(reg_compaund, comp).groupdict()['element_3'], parameter] for comp in data.index])

In [170]:
# Число элементов в соединении
data.insert(loc=data.columns.get_loc('a')+1, column='Number_a',
            value=[re.search(reg_compaund, comp).groupdict()['number_1'] if re.search(reg_compaund, comp).groupdict()['number_1'] != '' else 1  for comp in data.index])
data.insert(loc=data.columns.get_loc('b')+1, column='Number_b',
            value=[re.search(reg_compaund, comp).groupdict()['number_2'] if re.search(reg_compaund, comp).groupdict()['number_2'] != '' else 1  for comp in data.index])
data.insert(loc=data.columns.get_loc('c')+1, column='Number_c',
            value=[re.search(reg_compaund, comp).groupdict()['number_3'] if re.search(reg_compaund, comp).groupdict()['number_3'] != '' else 1  for comp in data.index])

In [171]:
# Удаляю метки
data.drop(columns=['a', 'b', 'c'], inplace=True)

In [172]:
# Сохраняю данные
...

### Соединения для прогнозирования


In [173]:
# загружаю данные 
compounds = pd.read_csv(filepath_or_buffer=links['materials_for_predict'], index_col='Compound', )
compounds.drop(columns=compounds.columns, inplace=True)

In [174]:
# метки для удобства
compounds.insert(loc=0, column='a', value=0)
compounds.insert(loc=1,column='b', value=0)
compounds.insert(loc=2, column='c', value=0)

In [175]:
compounds

Unnamed: 0_level_0,a,b,c
Compound,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Zn3WN4,0,0,0
ZnGeN2,0,0,0
ZnSiN2,0,0,0
YWN3,0,0,0
ZnSnN2,0,0,0
...,...,...,...
ZrAl3N4,0,0,0
Mn3AlN3,0,0,0
Mn3InN3,0,0,0
Si6Mo3N11,0,0,0


In [176]:
compounds

Unnamed: 0_level_0,a,b,c
Compound,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Zn3WN4,0,0,0
ZnGeN2,0,0,0
ZnSiN2,0,0,0
YWN3,0,0,0
ZnSnN2,0,0,0
...,...,...,...
ZrAl3N4,0,0,0
Mn3AlN3,0,0,0
Mn3InN3,0,0,0
Si6Mo3N11,0,0,0


In [177]:
# добавляю признаки по элементам соединения
for parameter in elements.columns[::-1]:
    compounds.insert(loc=compounds.columns.get_loc('a')+1, column='{f_1}_a'.format(f_1=parameter),
                value=[elements.at[re.search(reg_compaund, comp).groupdict()['element_1'], parameter] for comp in compounds.index])
    compounds.insert(loc=compounds.columns.get_loc('b')+1, column='{f_1}_b'.format(f_1=parameter),
                value=[elements.at[re.search(reg_compaund, comp).groupdict()['element_2'], parameter] for comp in compounds.index])
    compounds.insert(loc=compounds.columns.get_loc('c')+1, column='{f_1}_c'.format(f_1=parameter),
                value=[elements.at[re.search(reg_compaund, comp).groupdict()['element_3'], parameter] for comp in compounds.index])

In [178]:
# число атомов элемента в соединении
compounds.insert(loc=compounds.columns.get_loc('a')+1, column='Number_a',
            value=[re.search(reg_compaund, comp).groupdict()['number_1'] if re.search(reg_compaund, comp).groupdict()['number_1'] != '' else 1  for comp in compounds.index])
compounds.insert(loc=compounds.columns.get_loc('b')+1, column='Number_b',
            value=[re.search(reg_compaund, comp).groupdict()['number_2'] if re.search(reg_compaund, comp).groupdict()['number_2'] != '' else 1  for comp in compounds.index])
compounds.insert(loc=compounds.columns.get_loc('c')+1, column='Number_c',
            value=[re.search(reg_compaund, comp).groupdict()['number_3'] if re.search(reg_compaund, comp).groupdict()['number_3'] != '' else 1  for comp in compounds.index])

In [179]:
compounds.drop(columns=['a', 'b', 'c'], inplace=True)
compounds

Unnamed: 0_level_0,Number_a,Atomic_Number_a,NUMBER_OF_Electrons_at_last_orbitale_a,NUMBER_OF_Electrons_at_before_last_orbitale_a,NUMBER_OF_electrones_at_last_level_a,NUMBER_OF_vacancies_at_outer_orbitale_a,Number_of_active_electrons_at_inner_level_a,Max_valency_a,Atomic_Mass_a,Electronegativity_a,...,Max_valency_c,Atomic_Mass_c,Electronegativity_c,Atomic_radius_c,Covalent_radius_c,Ionization_potential_c,Electron_affinity_c,Period_c,Group_c,Block_c
Compound,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Zn3WN4,3,30,2,0,2,0,0,2,65.40000,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
ZnGeN2,1,30,2,0,2,0,0,2,65.40000,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
ZnSiN2,1,30,2,0,2,0,0,2,65.40000,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
YWN3,1,39,2,0,2,0,1,3,88.90500,1.22,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
ZnSnN2,1,30,2,0,2,0,0,2,65.40000,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZrAl3N4,1,40,2,0,2,0,2,4,91.22000,1.33,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
Mn3AlN3,3,25,2,0,2,0,5,7,54.93804,1.55,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
Mn3InN3,3,25,2,0,2,0,5,7,54.93804,1.55,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
Si6Mo3N11,6,14,2,2,4,4,0,2,28.08500,1.90,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1


In [180]:
# Сохраняю данные
...

In [181]:
data

Unnamed: 0_level_0,Number_a,Atomic_Number_a,NUMBER_OF_Electrons_at_last_orbitale_a,NUMBER_OF_Electrons_at_before_last_orbitale_a,NUMBER_OF_electrones_at_last_level_a,NUMBER_OF_vacancies_at_outer_orbitale_a,Number_of_active_electrons_at_inner_level_a,Max_valency_a,Atomic_Mass_a,Electronegativity_a,...,Atomic_Mass_c,Electronegativity_c,Atomic_radius_c,Covalent_radius_c,Ionization_potential_c,Electron_affinity_c,Period_c,Group_c,Block_c,Band_gap
Compound,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ag3PO4,3,47,1,0,1,1,0,3,107.868,1.93,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,2.42
Ag3VO4,3,47,1,0,1,1,0,3,107.868,1.93,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,2.35
Ag3VO4,3,47,1,0,1,1,0,3,107.868,1.93,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,2.20
AgBiO3,1,47,1,0,1,1,0,3,107.868,1.93,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,2.50
AgCoO2,1,47,1,0,1,1,0,3,107.868,1.93,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,4.15
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Zr2SeN2,2,40,2,0,2,0,2,4,91.220,1.33,...,14.0067,3.04,92.0,75,14.48,7.0,2,3,1,3.21
ZrMo2O8,1,40,2,0,2,0,2,4,91.220,1.33,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,2.62
ZrMo2O8,1,40,2,0,2,0,2,4,91.220,1.33,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,3.59
ZrSiO4,1,40,2,0,2,0,2,4,91.220,1.33,...,15.9994,3.44,48.0,66,13.56,141.0,2,5,1,5.90


### Обучение

In [182]:
y = data.Band_gap
X = data.drop(columns='Band_gap')

In [183]:
# MinMaxScaler
scaler = MinMaxScaler()
X_norm = pd.DataFrame(data=scaler.fit_transform(X), columns=X.columns,  index=X.index)

In [184]:
# выделю train и test
X_train, X_test, y_train, y_test = train_test_split(X_norm, y, random_state=2)

In [185]:
# тест 1. "Прямой" predict
# Обучение
svr = SVR(C=1.0, epsilon=0.2)
svr.fit(X=X_train, y=y_train)

SVR(C=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.2, gamma='scale',
    kernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=False)

In [186]:
# Предсказание
predict = svr.predict(X=X_test)
predict

array([2.86003629, 2.98205284, 2.92919126, 3.91841041, 2.11491468,
       3.82328655, 3.5064166 , 3.53146145, 2.67873058, 3.54718392,
       2.92324633, 4.11547251, 2.10003985, 4.28084808, 2.38251016,
       2.59846132, 2.55827876, 3.33048515, 2.78575339, 4.26201675,
       3.52405948, 2.32023277, 3.69428947, 3.03443172, 4.29294916,
       4.56693733, 2.67172013, 3.1532527 , 3.54227312, 2.52723082,
       2.15464948, 2.75322615, 2.27820248, 3.29945919, 3.75872941,
       3.10450291, 2.82578165, 2.92604247, 3.15962216, 3.3124531 ,
       1.92395326, 3.54395152, 2.73225458, 4.55176422, 3.76523723,
       4.46271381, 2.30716909, 2.41257563, 3.79107418, 2.45419977,
       2.8408571 , 3.22188809, 3.33375838, 3.78802614, 4.09062649,
       2.07932402, 2.31267811, 2.70111655, 3.25945818, 2.92357454,
       3.07151394, 2.98104127, 2.44345487, 3.33786562, 3.43902959,
       2.87651371, 2.22477318, 3.81919736, 3.29944972, 2.31410073,
       2.71371568, 3.21761277, 2.61204523, 2.32059068, 2.17080

In [187]:
# Настоящие метки
y_test.to_numpy().reshape((y_test.to_numpy().shape[0]))

array([2.18, 2.77, 2.3 , 6.52, 2.65, 5.88, 4.45, 3.15, 2.58, 2.46, 2.67,
       3.34, 0.94, 4.5 , 2.8 , 3.4 , 1.86, 4.  , 3.5 , 4.6 , 3.56, 3.1 ,
       5.9 , 3.1 , 4.4 , 3.3 , 1.29, 2.04, 3.63, 2.11, 1.9 , 1.1 , 2.  ,
       3.68, 4.1 , 4.1 , 1.22, 1.7 , 1.3 , 2.44, 2.5 , 3.75, 3.37, 6.  ,
       4.43, 4.56, 3.18, 2.3 , 3.3 , 1.85, 2.43, 3.4 , 3.85, 6.5 , 3.35,
       2.31, 2.25, 1.42, 4.15, 2.66, 2.2 , 3.45, 2.05, 4.7 , 3.14, 1.63,
       1.95, 3.1 , 3.08, 2.45, 1.78, 3.5 , 2.26, 2.34, 2.15, 1.55, 3.6 ,
       1.88, 2.73, 3.16, 1.5 , 3.02, 2.7 , 4.4 , 2.85, 2.6 , 2.4 , 2.37])

In [188]:
 # Оценка через метрику r2_score
from sklearn.metrics import r2_score
r2 = r2_score(y_pred=predict, y_true=y_test.to_numpy().reshape((y_test.to_numpy().shape[0])))
r2

0.4448100033296768

In [189]:
data_for_predict = compounds.copy()
data_for_predict.head()

Unnamed: 0_level_0,Number_a,Atomic_Number_a,NUMBER_OF_Electrons_at_last_orbitale_a,NUMBER_OF_Electrons_at_before_last_orbitale_a,NUMBER_OF_electrones_at_last_level_a,NUMBER_OF_vacancies_at_outer_orbitale_a,Number_of_active_electrons_at_inner_level_a,Max_valency_a,Atomic_Mass_a,Electronegativity_a,...,Max_valency_c,Atomic_Mass_c,Electronegativity_c,Atomic_radius_c,Covalent_radius_c,Ionization_potential_c,Electron_affinity_c,Period_c,Group_c,Block_c
Compound,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Zn3WN4,3,30,2,0,2,0,0,2,65.4,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
ZnGeN2,1,30,2,0,2,0,0,2,65.4,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
ZnSiN2,1,30,2,0,2,0,0,2,65.4,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
YWN3,1,39,2,0,2,0,1,3,88.905,1.22,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1
ZnSnN2,1,30,2,0,2,0,0,2,65.4,1.65,...,5,14.0067,3.04,92.0,75,14.48,7.0,2,3,1


In [190]:
# MinMaxScaler
data_for_predict = pd.DataFrame(data=scaler.transform(data_for_predict), columns=data_for_predict.columns, index=data_for_predict.index)

In [191]:
# Выполняю предсказание
target_prediction = svr.predict(data_for_predict)

In [192]:
# Создам DataFrame где в качестве индексов используется соединение, а predict - предсказанное значение ширнины запрещенной зоны
res = pd.DataFrame(data=target_prediction, columns=['predict'], index=data_for_predict.index)
res

Unnamed: 0_level_0,predict
Compound,Unnamed: 1_level_1
Zn3WN4,2.810620
ZnGeN2,3.192602
ZnSiN2,3.493014
YWN3,3.079290
ZnSnN2,2.834424
...,...
ZrAl3N4,3.368139
Mn3AlN3,2.751236
Mn3InN3,2.197631
Si6Mo3N11,2.510315


In [193]:
# Поиск максимально точного показания по тестовой выборке

traintest_data = data.copy()
predict_data = compounds.copy()
# Выделение таргета и данных
y = traintest_data.Band_gap
X = traintest_data.drop(columns='Band_gap')

In [194]:
# MinMaxScaler
scaler = MinMaxScaler()
scaler.fit(X)
X_scaler = pd.DataFrame(data=scaler.transform(X), columns=X.columns, index=X.index)
data_for_predict = pd.DataFrame(data=scaler.transform(predict_data), columns=predict_data.columns, index=predict_data.index)

In [195]:
svr = SVR(C=1.0, epsilon=0.2)

# списки результатов
list_r2_score = []
list_traintest_predict = []
list_predict = []
list_x_test = []

# число тестов
number_tests = 100

for i in range(number_tests):
    # выделяю train и test
    X_train, X_test, y_train, y_test = train_test_split(X_scaler, y)
    list_x_test.append(X_test)

    #svr
    svr.fit(X=X_train, y=y_train)
    traintest_predict = svr.predict(X_test)
    list_traintest_predict.append(traintest_predict)

    # r2_score
    list_r2_score.append(r2_score(y_pred=traintest_predict, y_true=y_test))

    predict = svr.predict(data_for_predict)
    list_predict.append(predict)

In [196]:
# Максимальный показатель метрики r2
max(list_r2_score)

0.5120429331171861

In [197]:
# Найду предсказания соответствующие данному показатель по метрике
res = pd.DataFrame(data=list_predict[list_r2_score.index(max(list_r2_score))], columns=['predict'], index=data_for_predict.index)
res.head()

Unnamed: 0_level_0,predict
Compound,Unnamed: 1_level_1
Zn3WN4,2.397837
ZnGeN2,2.955551
ZnSiN2,3.450701
YWN3,2.835522
ZnSnN2,2.363861


In [198]:
# также все 100 циклов
res_2 = pd.DataFrame(index=data_for_predict.index)
for i, j in enumerate(list_predict):
    res_2[str(i)] = j
res_2

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,90,91,92,93,94,95,96,97,98,99
Compound,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Zn3WN4,2.322176,2.638852,2.607643,2.502913,2.452892,2.683900,2.697728,2.346845,2.485909,2.435717,...,2.299152,2.541424,2.528888,2.936352,2.350653,2.478482,2.540830,2.421195,2.637967,2.539810
ZnGeN2,2.708463,3.070761,3.147814,2.878765,2.700092,2.853963,3.211775,2.786111,3.212250,2.919568,...,3.058503,3.135924,2.864406,3.109748,2.807269,3.165345,2.437955,2.983957,2.979054,2.959085
ZnSiN2,3.215985,3.551035,3.643322,3.426854,3.078791,3.244972,3.646124,3.301886,3.745683,3.435228,...,3.531833,3.669192,3.230426,3.477173,3.342415,3.608619,2.862645,3.537615,3.468028,3.459979
YWN3,2.658022,2.848461,2.952472,2.818248,2.931489,2.975159,3.034802,2.721267,3.013225,2.817514,...,2.592410,2.733378,2.753102,3.023340,2.757530,2.938508,2.761303,2.837414,3.115358,2.906917
ZnSnN2,2.047465,2.501140,2.539263,2.283810,2.273100,2.417796,2.676342,2.151396,2.600229,2.326971,...,2.496417,2.488465,2.441985,2.707753,2.211895,2.686483,1.985034,2.335648,2.421370,2.368964
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZrAl3N4,3.238984,3.477772,3.476685,3.486746,3.280530,3.278803,3.619693,3.428919,3.753678,3.531622,...,3.447207,3.659252,3.258713,3.450579,3.634991,3.519527,3.238920,3.700644,3.454986,3.425742
Mn3AlN3,2.776114,3.062033,3.102980,3.121176,2.590253,2.705786,3.059335,2.903740,3.113983,3.067523,...,2.867994,3.224414,2.808267,2.921132,3.104611,2.997751,2.696038,3.204122,2.878904,2.971901
Mn3InN3,1.793628,2.230995,2.252815,2.138090,1.952836,2.058017,2.194599,1.946852,2.208954,2.162112,...,1.998231,2.270447,2.142667,2.269498,2.188867,2.261419,1.922380,2.214149,2.029217,2.111645
Si6Mo3N11,2.521626,2.574712,2.614652,2.737048,2.407029,2.317844,2.482905,2.059983,2.687371,2.404882,...,2.274989,2.292353,2.303664,2.549951,2.673496,2.450009,2.271200,2.544798,2.603997,2.676725


In [208]:
res_2.to_csv(r'D:\Development\ML for new materials discovery\ML-for-new-materials-discovery\E4\result.csv')