In [1]:
from ase.db import connect
from icet import (ClusterSpace, StructureContainer,
                  CrossValidationEstimator, ClusterExpansion)

# step 1: Load the reference database and the primitive structure

In [2]:
# Open the reference ASE database and pull out the structure that the
# cluster space is built on. Row id 1 is assumed to hold the primitive
# structure (TODO confirm if the database layout ever changes).
db = connect('reference_data.db')
primitive_row = db.get(id=1)
prim = primitive_row.toatoms()  # primitive structure

# step 2: Set up a cluster space based on the primitive structure

In [3]:
# Cutoff radii for pairs, triplets and quadruplets, in that order (the printed
# summary lists orbits of order 2, 3 and 4). Units follow the structure's
# length unit, presumably Angstrom.
cutoff_radii = [13.5, 6.0, 5.5]

# Build the cluster space on the primitive structure, allowing Ag/Pd occupation,
# and print its orbit table.
cs = ClusterSpace(
    atoms=prim,
    cutoffs=cutoff_radii,
    chemical_symbols=['Ag', 'Pd'],
)
print(cs)

 chemical species: ['Ag', 'Pd'] (sublattice A)
 cutoffs: 13.5000 6.0000 5.5000
 total number of orbits: 55
 number of orbits by order: 0= 1  1= 1  2= 25  3= 12  4= 16
--------------------------------------------------------------------------------------------
index | order |  radius  | multiplicity | orbit_index | multi_component_vector | sublattices
--------------------------------------------------------------------------------------------
   0  |   0   |   0.0000 |        1     |      -1     |           .            |      .     
   1  |   1   |   0.0000 |        1     |       0     |          [0]           |      A     
   2  |   2   |   1.4460 |        6     |       1     |         [0, 0]         |     A-A    
   3  |   2   |   2.0450 |        3     |       2     |         [0, 0]         |     A-A    
   4  |   2   |   2.5046 |       12     |       3     |         [0, 0]         |     A-A    
   5  |   2   |   2.8921 |        6     |       4     |         [0, 0]         |     A-A 

# step 3: Parse the input structures and set up a structure container

In [4]:
# Gather every reference structure with at most 8 atoms into a container tied
# to the cluster space. Each entry keeps its database tag and carries its
# mixing energy as the fitting target.
sc = StructureContainer(cluster_space=cs)
for record in db.select('natoms<=8'):
    structure = record.toatoms()
    target_properties = {'mixing_energy': record.mixing_energy}
    sc.add_structure(atoms=structure,
                     user_tag=record.tag,
                     properties=target_properties)
print(sc)

Total number of structures: 625
-------------------------------------------------------------------------
index |       user_tag        | natoms | chemical formula | mixing_energy
-------------------------------------------------------------------------
   0  | Ag                    |     1  | Ag               |      0.000   
   1  | Pd                    |     1  | Pd               |      0.000   
   2  | AgPd_0002             |     2  | AgPd             |     -0.040   
   3  | AgPd_0003             |     3  | AgPd2            |     -0.029   
   4  | AgPd_0004             |     3  | Ag2Pd            |     -0.049   
   5  | AgPd_0005             |     3  | AgPd2            |     -0.018   
   6  | AgPd_0006             |     3  | Ag2Pd            |     -0.056   
   7  | AgPd_0007             |     3  | AgPd2            |     -0.030   
   8  | AgPd_0008             |     3  | Ag2Pd            |     -0.048   
   9  | AgPd_0009             |     4  | AgPd3            |     -0.017   
 ...
 

# step 4: Train the parameters (ECIs) using LASSO with cross-validation

In [5]:
# Fit the cluster-expansion parameters to the mixing energies with LASSO.
# `validate()` runs the cross-validation splits, which produce the reported
# rmse_validation. `train()` then performs the final fit, presumably on the
# full data set (reported as rmse_train_final).
#
# The seed for the k-fold split is passed explicitly. It is the same value the
# summary previously reported as the implicit default (42). Pinning it keeps
# the fold assignment, and hence the validation scores, reproducible even if
# the library default changes.
opt = CrossValidationEstimator(fit_data=sc.get_fit_data(key='mixing_energy'),
                               fit_method='lasso',
                               seed=42)
opt.validate()
opt.train()
print(opt)

alpha_optimal                  : 3.162278e-05
fit_method                     : lasso
n_nonzero_parameters           : 40
n_parameters                   : 55
n_splits                       : 10
n_target_values                : 625
rmse_train                     : 0.00207139
rmse_train_final               : 0.002089748
rmse_validation                : 0.002349935
seed                           : 42
standardize                    : True
validation_method              : k-fold


# step 5: Construct cluster expansion and write it to file

In [6]:
# Combine the cluster space with the trained parameters (ECIs), print the
# resulting expansion, and save it to disk.
eci = opt.parameters
ce = ClusterExpansion(cluster_space=cs,
                      parameters=eci)
print(ce)
ce_path = 'mixing_energy.ce'
ce.write(ce_path)


 chemical species: ['Ag', 'Pd'] (sublattice A)
 cutoffs: 13.5000 6.0000 5.5000
 total number of orbits: 55
 number of orbits by order: 0= 1  1= 1  2= 25  3= 12  4= 16
--------------------------------------------------------------------------------------------------------
index | order |  radius  | multiplicity | orbit_index | multi_component_vector | sublattices |    ECI   
--------------------------------------------------------------------------------------------------------
   0  |   0   |   0.0000 |        1     |      -1     |           .            |      .      |    -0.045
   1  |   1   |   0.0000 |        1     |       0     |          [0]           |      A      |   -0.0353
   2  |   2   |   1.4460 |        6     |       1     |         [0, 0]         |     A-A     |    0.0284
   3  |   2   |   2.0450 |        3     |       2     |         [0, 0]         |     A-A     |    0.0134
   4  |   2   |   2.5046 |       12     |       3     |         [0, 0]         |     A-A     |    