# **Splitting statistics**

In [1]:
import os
import shutil
from matminer.datasets.dataset_retrieval import load_dataset, get_all_dataset_info
from matminer.datasets import get_available_datasets

In [2]:
import pandas as pd
# Fetch the full perovskites benchmark table and display its summary statistics.
df_full = load_dataset("matbench_perovskites")
df_full.describe()  # type: ignore

Unnamed: 0,e_form
count,18928.0
mean,1.470932
std,0.742502
min,-0.64
25%,0.96
50%,1.36
75%,1.84
max,5.16


# **70-20-10 splits**

#### **Split types covered: Composition, Chemsys, Sgnum, Pointgroup, Element, and Periodic Table (PT) Groups**

In [3]:
# Load the dataset and reshape it into the layout used by the rest of the notebook:
# an id column, the structure, the target value, and the composition.
df = load_dataset('matbench_perovskites')
# Synthetic row identifiers "mat0", "mat1", ... stored under the dataset name.
df['matbench_perovskites'] = [f"mat{i}" for i in range(len(df))]
df = df.rename(columns={'e_form': 'TARGET'})
df = df[['matbench_perovskites', 'structure', 'TARGET']]
df['composition'] = df['structure'].apply(lambda x: x.composition)
# Map each identifier to the dict form of its structure. Zipping the two columns
# avoids building a Series per row the way DataFrame.iterrows() does.
mat_struct_dict = dict(
    zip(
        df['matbench_perovskites'],
        (structure.as_dict() for structure in df['structure']),
    )
)


### Convert the structures in the dataframe to dictionaries

In [4]:
# Replace every structure object with its dict form. Entries that are already dicts
# are passed through unchanged so that re-running this cell does not raise.
df['structure'] = df['structure'].map(
    lambda x: x if isinstance(x, dict) else x.as_dict()
)

In [5]:

from MatFold import MatFold
# Build the splitter from the prepared dataframe and the id -> structure-dict mapping.
mf = MatFold(
    df,
    mat_struct_dict,
    write_data_checksums=False,
)

### **Composition splits**

In [6]:
# 70/20/10 train/validation/test split, with validation and test both grouped by composition.
split_type = "composition"
split_fractions = {
    "train_fraction": 0.7,
    "validation_fraction": 0.2,
    "test_fraction": 0.1,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation=split_type,
    split_type_test=split_type,
    n_test_min=1,
    verbose=False,
    **split_fractions,
)

print("Composition split done.")


Composition split done.


In [10]:
from collections import Counter

# Row counts of the three partitions, printed on a single line.
print("Train, Validation, Test sizes:")
print(*(len(part) for part in (train_df, val_df, test_df)))

# Number of distinct compositions found in each partition.
for heading, part in (
    ("Train unique compositions:", train_df),
    ("Val unique compositions:", val_df),
    ("Test unique compositions:", test_df),
):
    print(heading, len(Counter(part['composition'])))

Train, Validation, Test sizes:
13247 3798 1883
Train unique compositions: 6753
Val unique compositions: 1929
Test unique compositions: 964


### **Chemsys splits**

In [7]:
# 70/20/10 train/validation/test split, with validation and test both grouped by chemsys.
split_type = "chemsys"
split_fractions = {
    "train_fraction": 0.7,
    "validation_fraction": 0.2,
    "test_fraction": 0.1,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation=split_type,
    split_type_test=split_type,
    n_test_min=1,
    verbose=False,
    **split_fractions,
)

print("Chemsys split done.")


Chemsys split done.


### **Sgnum split**

In [12]:
# Attempt a 70/20/10 split grouped by space-group number. The recorded run of this
# call raised, so the error is caught and displayed instead of halting a full
# notebook run. On failure train_df/val_df/test_df keep whatever an earlier cell
# assigned to them.
try:
    train_df, val_df, test_df = mf.create_train_validation_test_splits(
        split_type_validation='sgnum',
        split_type_test='sgnum',
        train_fraction=0.7,
        validation_fraction=0.2,
        test_fraction=0.1,
        n_test_min=1,
        verbose=False
    )
except Exception as err:
    # Show the failure in the same "<Type>: <message>" form a traceback ends with.
    print(f"{type(err).__name__}: {err}")


Exception: Error. Either train (len=18928) or test (len=0) set is empty and split cannot be created.

#### Sgnum split outright fails for a 70-20-10 split, even though (as you can see from the sgnum distribution below) the groups are certainly large enough to make the split.

In [19]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

def get_sgnum_with_tolerance(structure, tol=0.1):
    """Return the space-group number of ``structure``, or ``None`` on failure.

    Args:
        structure: A pymatgen ``Structure`` or its ``as_dict()`` form. The dict
            form is accepted because an earlier cell replaces
            ``df['structure']`` with dictionaries, which ``SpacegroupAnalyzer``
            cannot consume directly.
        tol: Passed to ``SpacegroupAnalyzer`` as ``symprec``.

    Returns:
        The result of ``SpacegroupAnalyzer.get_space_group_number()``, or
        ``None`` if the conversion or the symmetry analysis raised.
    """
    # Local import so the function works regardless of which cells ran before it.
    from pymatgen.core import Structure

    try:
        if isinstance(structure, dict):
            structure = Structure.from_dict(structure)
        sga = SpacegroupAnalyzer(structure, symprec=tol)
        return sga.get_space_group_number()
    except Exception:
        # Best effort: a None result lets the caller drop the row.
        return None

# Label each row with its space-group number (tolerance 0.1); None marks a failed analysis.
df['sgnum'] = df['structure'].map(lambda struct: get_sgnum_with_tolerance(struct, tol=0.1))
# Keep only the rows whose space group was determined, then store the numbers as ints.
df = df[df['sgnum'].notna()]
df['sgnum'] = df['sgnum'].astype(int)

# Show how many rows fall into each space group.
sgnum_counts = df['sgnum'].value_counts()
print(sgnum_counts)

sgnum
123    5629
25     4310
99     3639
221    3115
47     2235
Name: count, dtype: int64


#### Change split to 0.3, 0.3, 0.4 to show it works with different ratios

In [30]:
# 30/30/40 train/validation/test split grouped by space-group number.
split_fractions = {
    "train_fraction": 0.3,
    "validation_fraction": 0.3,
    "test_fraction": 0.4,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='sgnum',
    split_type_test='sgnum',
    n_test_min=1,
    verbose=False,
    **split_fractions,
)


### **Pointgroup split**

### Pointgroup split fails outright for 70/20/10 but succeeds when the split fractions are different. Judging from the group statistics alone, however, these two categories (sgnum and pointgroup) should be the most likely to allow a 70/20/10 split.

In [13]:
# Attempt a 70/20/10 split grouped by point group. The recorded run of this call
# raised, so the error is caught and displayed instead of halting a full notebook
# run. On failure train_df/val_df/test_df keep whatever an earlier cell assigned
# to them.
try:
    train_df, val_df, test_df = mf.create_train_validation_test_splits(
        split_type_validation='pointgroup',
        split_type_test='pointgroup',
        train_fraction=0.7,
        validation_fraction=0.2,
        test_fraction=0.1,
        n_test_min=1,
        verbose=False)
except Exception as err:
    # Show the failure in the same "<Type>: <message>" form a traceback ends with.
    print(f"{type(err).__name__}: {err}")

Exception: Error. Either train (len=18928) or test (len=0) set is empty and split cannot be created.

In [72]:
import pandas as pd

# No earlier cell creates a 'pointgroup' column, so derive it here when it is
# missing; otherwise the lookups below raise KeyError on a fresh kernel.
if 'pointgroup' not in df.columns:
    from pymatgen.core import Structure
    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    def _pointgroup_symbol(structure, tol=0.1):
        """Return the point-group symbol of ``structure``, or ``None`` on failure.

        Accepts a pymatgen ``Structure`` or its ``as_dict()`` form. ``tol`` is
        passed to ``SpacegroupAnalyzer`` as ``symprec``; 0.1 is the value this
        notebook uses for the space-group numbers.
        NOTE(review): confirm this matches the tolerance MatFold applies
        internally for its 'pointgroup' split.
        """
        try:
            if isinstance(structure, dict):
                structure = Structure.from_dict(structure)
            return SpacegroupAnalyzer(structure, symprec=tol).get_point_group_symbol()
        except Exception:
            # Best effort: failures surface as missing entries in the check below.
            return None

    df['pointgroup'] = df['structure'].apply(_pointgroup_symbol)

# Check for missing pointgroup entries
missing_count = df['pointgroup'].isna().sum()
if missing_count > 0:
    print(f"Warning: There are {missing_count} missing pointgroup entries in the dataset.")
else:
    print("No missing pointgroup entries detected.")

# Print distribution of pointgroups
pg_counts = df['pointgroup'].value_counts()
print("\nPointgroup distribution:")
print(pg_counts)


No missing pointgroup entries detected.

Pointgroup distribution:
pointgroup
4/mmm    5629
mm2      4310
4mm      3639
m-3m     3115
mmm      2235
Name: count, dtype: int64


### Change split to 0.3, 0.3, 0.4 to check

In [28]:
# Try n_test_min = 1..9 in order for a 30/30/40 point-group split and stop at the
# first value for which the split call does not raise.
for n in range(1, 10):
    try:
        train_df, val_df, test_df = mf.create_train_validation_test_splits(
            split_type_validation='pointgroup',
            split_type_test='pointgroup',
            train_fraction=0.3,
            validation_fraction=0.3,
            test_fraction=0.4,
            n_test_min=n,
            verbose=False,
        )
    except Exception as e:
        # Report the failure and move on to the next candidate value.
        print(f"Failed with n_test_min={n}: {e}")
    else:
        print(f"Split successful with n_test_min={n}")
        break

Split successful with n_test_min=1


### **Element split**

Doesn't explicitly fail, but ran for more than 20 minutes even with n_test_min=1, so the cell was interrupted manually.

In [68]:
# 70/20/10 train/validation/test split grouped by elements.
# In the recorded run this call was interrupted before it finished.
split_fractions = {
    "train_fraction": 0.7,
    "validation_fraction": 0.2,
    "test_fraction": 0.1,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='elements',
    split_type_test='elements',
    n_test_min=1,
    verbose=False,
    **split_fractions,
)

print("Element split done.")

KeyboardInterrupt: 

### **Periodic Table Groups**

n_test_min=1

In [8]:
# 70/20/10 train/validation/test split grouped by periodic-table groups.
split_fractions = {
    "train_fraction": 0.7,
    "validation_fraction": 0.2,
    "test_fraction": 0.1,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='periodictablegroups',
    split_type_test='periodictablegroups',
    n_test_min=1,
    verbose=False,
    **split_fractions,
)

print("PT groups split done.")

PT groups split done.


# **80-0-20 split**

### **Composition Split**

In [9]:

# 80/0/20 split (validation fraction 0.0) grouped by composition.
split_type = "composition"
split_fractions = {
    "train_fraction": 0.8,
    "validation_fraction": 0.0,
    "test_fraction": 0.2,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation=split_type,
    split_type_test=split_type,
    n_test_min=1,
    verbose=False,
    **split_fractions,
)

print("Composition split done.")

Composition split done.


### **Chemsys Split**

In [10]:

# 80/0/20 split (validation fraction 0.0) grouped by chemsys.
split_type = "chemsys"
split_fractions = {
    "train_fraction": 0.8,
    "validation_fraction": 0.0,
    "test_fraction": 0.2,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation=split_type,
    split_type_test=split_type,
    n_test_min=1,
    verbose=False,
    **split_fractions,
)

print("Chemsys split done.")


Chemsys split done.


### **Sgnum split**

n_test_min=10

In [11]:
# 80/0/20 split (validation fraction 0.0) grouped by space-group number, n_test_min=10.
split_fractions = {
    "train_fraction": 0.8,
    "validation_fraction": 0.0,
    "test_fraction": 0.2,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='sgnum',
    split_type_test='sgnum',
    n_test_min=10,
    verbose=False,
    **split_fractions,
)

### **Pointgroup split**

n_test_min=10

In [12]:
# 80/0/20 split (validation fraction 0.0) grouped by point group, n_test_min=10.
split_fractions = {
    "train_fraction": 0.8,
    "validation_fraction": 0.0,
    "test_fraction": 0.2,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='pointgroup',
    split_type_test='pointgroup',
    n_test_min=10,
    verbose=False,
    **split_fractions,
)

### **Elements split**

n_test_min=6

In [13]:
# 80/0/20 split (validation fraction 0.0) grouped by elements, n_test_min=6.
split_fractions = {
    "train_fraction": 0.8,
    "validation_fraction": 0.0,
    "test_fraction": 0.2,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='elements',
    split_type_test='elements',
    n_test_min=6,
    verbose=False,
    **split_fractions,
)

print("Element split done.")

Element split done.


### **Periodic Table Groups**

n_test_min=2

In [14]:
# 80/0/20 split (validation fraction 0.0) grouped by periodic-table groups, n_test_min=2.
split_fractions = {
    "train_fraction": 0.8,
    "validation_fraction": 0.0,
    "test_fraction": 0.2,
}
train_df, val_df, test_df = mf.create_train_validation_test_splits(
    split_type_validation='periodictablegroups',
    split_type_test='periodictablegroups',
    n_test_min=2,
    verbose=False,
    **split_fractions,
)

print("PT groups split done.")

PT groups split done.
