# Testing Out-of-Distribution Performance of Materials ML Models with MatFold

We will use the same Matminer/Matbench dataset as in lecture 11, but this time we will create an out-of-distribution (OOD) test set using `MatFold`.

In [None]:
from matminer.datasets.dataset_retrieval import load_dataset
import pandas as pd
from MatFold import MatFold

# Fetch the perovskites dataset through Matminer's dataset loader and preview
# its first rows (the trailing expression is echoed by the notebook).
dataset_name = 'matbench_perovskites'
df_full = load_dataset(dataset_name)
df_full.head()  # type: ignore

In [None]:
# Work on a random subsample of the full dataset to keep featurization and
# model training fast. `n_sample` is the single place to change the size;
# the fixed `random_state` makes the subsample reproducible.
n_sample = 2000
df = df_full.sample(n=n_sample, random_state=17)  # subsample for speed   # type: ignore
df.describe()

For MatFold, we will need a dictionary that maps an index key for each material ("mat0", "mat1", ...) to its bulk structure, stored as the dictionary representation (`as_dict()`) of the Pymatgen `Structure` object.

In [None]:
from pymatgen.core import Structure

# Give every material a unique string label: "mat0", "mat1", ...
df['mat_index'] = [f'mat{i}' for i in range(len(df))]

# Keep only the label, the structure and the target, with the label first.
df = df[['mat_index', 'structure', 'e_form']]

# Map each label to the dictionary representation of its structure.
mat_struct_dict = {
    label: structure.as_dict()
    for label, structure in zip(df['mat_index'], df['structure'])
}
print(df.head())

We will use the `EwaldSumMatrix` featurizer again.

In [None]:
from dscribe.descriptors import EwaldSumMatrix
from pymatgen.io.ase import AseAtomsAdaptor
import numpy as np

# Largest number of sites in any structure of the dataset; this fixes the
# size of the descriptor (0 if the dataframe is empty).
n_max = max((len(structure) for structure in df['structure']), default=0)
print(n_max)

# Ewald sum matrix descriptor using the "eigenspectrum" permutation option.
ews = EwaldSumMatrix(n_atoms_max=n_max, permutation="eigenspectrum")

# The descriptor is fed ASE atoms objects, so convert each structure first.
ase_structures = list(map(AseAtomsAdaptor.get_atoms, df['structure']))
ews_matrices = np.array(ews.create(ase_structures))

In [None]:
# One named column per descriptor component: ews_0, ews_1, ...
n_features = ews_matrices.shape[1]
ews_columns = [f'ews_{i}' for i in range(n_features)]

# Wrap the feature matrix in a frame that shares df's index so rows line up,
# then attach it to the non-structure columns (label and target).
ews_frame = pd.DataFrame(ews_matrices, index=df.index, columns=ews_columns)
df_featurized_ews = pd.concat([df.drop(columns=['structure']), ews_frame], axis=1)
print(df_featurized_ews.head())

MatFold can create out-of-distribution splits based on these categories:

- `index` (or `random`)
- `structureid` (or `structure`)
- `composition` (or `comp`)
- `chemsys` (or `chemicalsystem`)
- `sgnum` (or `spacegroup`, `spacegroupnumber`)
- `pointgroup` (or `pg`, `pointgroupsymbol`, `pgsymbol`)
- `crystalsys` (or `crystalsystem`)
- `elements` (or `elems`)
- `periodictablerows` (or `ptrows`)
- `periodictablegroups` (or `ptgroups`)

First, we create an instance of the `MatFold` class by supplying the full Pandas dataframe (that contains the features) and the dictionary we created earlier that contains the bulk Pymatgen structures. Then we can use the function `split_statistics` to look at the distributions of different categories in the dataset.

In [None]:
# Build the MatFold splitter from the featurized dataframe and the dictionary
# of serialized structures created earlier.
mf = MatFold(df_featurized_ews, mat_struct_dict)

# A notebook cell only echoes its last expression, so the result of the first
# call used to be silently discarded. Print each one explicitly instead.
# NOTE(review): assumes `split_statistics` returns the statistics (rather than
# printing them itself) -- confirm against the MatFold documentation.
for split_type in ("elements", "crystalsys"):
    print(f"{split_type}: {mf.split_statistics(split_type)}")

In [None]:
# Fractions of the data assigned to each of the three subsets.
split_fractions = {
    'train_fraction': 0.7,
    'validation_fraction': 0.2,
    'test_fraction': 0.1,
}

# NOTE(review): the two positional arguments are presumably the split type
# for the validation set ("random") and for the test set ("elements") --
# confirm against the MatFold API.
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    "random", "elements", n_test_min=4, **split_fractions
)

Now, we fit the `StandardScaler` on the training features only and then apply that same scaling to the validation and test sets. In lecture 11 we scaled the entire dataset at once, which is not ideal because it can cause data leakage (information from the validation and test sets "leaking" into training through the scaling statistics). For an entirely random split this effect is minor, but for an OOD test set the leakage may be more severe.

In [None]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()

# Regression targets for each subset.
y_train = train_df['e_form'].to_numpy()
y_val = val_df['e_form'].to_numpy()  # type: ignore
y_test = test_df['e_form'].to_numpy()

# Select every descriptor column by its 'ews_' prefix rather than a
# hard-coded 'ews_0':'ews_4' slice, so no features are silently dropped if
# the descriptor length changes.
feature_columns = [col for col in train_df.columns if str(col).startswith('ews_')]

# Fit the scaler on the training features only, then reuse those statistics
# for validation and test so nothing leaks from the held-out sets.
X_train = scaler.fit_transform(train_df[feature_columns])
X_val = scaler.transform(val_df[feature_columns])  # type: ignore
X_test = scaler.transform(test_df[feature_columns])

print(y_train.shape)
print(X_train.shape)

Let's calculate the baseline metrics for both the validation and test set.

In [None]:
from sklearn.metrics import mean_absolute_error

# Baseline model: always predict the mean target of the training set.
mean_train = y_train.mean()

# Constant prediction vectors, one entry per sample in each held-out set.
test_baseline_pred = np.full(len(y_test), mean_train)
val_baseline_pred = np.full(len(y_val), mean_train)  # type: ignore

baseline_mae_test = mean_absolute_error(y_test, test_baseline_pred)
print(f"Baseline MAE for test-set (predicting mean formation energy): {baseline_mae_test:.4f} eV")
baseline_mae_val = mean_absolute_error(y_val, val_baseline_pred)  # type: ignore
print(f"Baseline MAE for validation-set (predicting mean formation energy): {baseline_mae_val:.4f} eV")

Hyperparameter tuning for the RF model using the validation-set metrics.

In [None]:
from sklearn.ensemble import RandomForestRegressor
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Sweep the number of trees and record the MAE on the training and
# validation sets for each forest size.
n_estimators_list = [10, 50, 100, 500]
train_maes = []
val_maes = []

for n in tqdm(n_estimators_list, desc='Training RF models'):
    # The fixed random_state keeps each forest reproducible across runs.
    rf = RandomForestRegressor(n_estimators=n, random_state=17, n_jobs=1)
    rf.fit(X_train, y_train)
    y_train_pred = rf.predict(X_train)
    y_val_pred = rf.predict(X_val)
    train_maes.append(mean_absolute_error(y_train, y_train_pred))
    val_maes.append(mean_absolute_error(y_val, y_val_pred))  # type: ignore

plt.figure(figsize=(8, 5))
plt.plot(n_estimators_list, train_maes, marker='o', label='Train MAE')
plt.plot(n_estimators_list, val_maes, marker='o', label='Validation MAE')
# Reference line: the validation-set mean baseline. The previous name
# `baseline_mae` was never defined and raised a NameError.
plt.axhline(baseline_mae_val, color='gray', linestyle='--', label='Mean Baseline')
plt.xlabel('Number of Estimators')
plt.ylabel('Mean Absolute Error')
plt.title('Random Forest: MAE vs. Number of Estimators')
plt.legend()
plt.tight_layout()
plt.show()

Final model training and metric evaluation.

In [None]:
from sklearn.metrics import r2_score

# Train the final random forest on the training set.
rf_final = RandomForestRegressor(n_estimators=100, random_state=17, n_jobs=1)
rf_final.fit(X_train, y_train)

# Predictions for all three subsets, in train / validation / test order.
y_train_pred, y_val_pred, y_test_pred = (
    rf_final.predict(features) for features in (X_train, X_val, X_test)
)

# Calculate metrics
mae_train = mean_absolute_error(y_train, y_train_pred)
mae_val = mean_absolute_error(y_val, y_val_pred)
mae_test = mean_absolute_error(y_test, y_test_pred)

r2_train = r2_score(y_train, y_train_pred)
r2_val = r2_score(y_val, y_val_pred)
r2_test = r2_score(y_test, y_test_pred)

fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True)

# End points of the dashed parity (y = x) reference line.
min_val = 0.0
max_val = 5.0

# One parity panel per subset: (title prefix, marker color, truth, prediction, MAE, R2).
parity_panels = [
    ('Train', 'blue', y_train, y_train_pred, mae_train, r2_train),
    ('Validation', 'orange', y_val, y_val_pred, mae_val, r2_val),
    ('Test', 'green', y_test, y_test_pred, mae_test, r2_test),
]

for ax, (subset, color, y_true, y_pred, mae, r2) in zip(axes, parity_panels):
    ax.scatter(y_true, y_pred, alpha=0.5, color=color)
    ax.plot([min_val, max_val], [min_val, max_val], 'k--', lw=2)
    ax.set_title(f'{subset} (MAE={mae:.3f}, R2={r2:.3f})')
    ax.set_xlabel('True Formation Energy')
    ax.set_ylabel('Predicted Formation Energy')
    ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()

### Conclusion
The OOD test set metrics are worse than the ID (random) split:

ID: MAE=0.377 eV, R2=50.0%

OOD: MAE=0.382 eV, R2=30.3%

## Exercise 12.1

Now, try this pipeline again, but for a different OOD test set category (e.g., `sgnum`).