In [1]:
from typing import Iterable
from collections import defaultdict
import pandas as pd
from pymatgen.core import Structure
from functools import partial
from tqdm.auto import tqdm
from pymatgen.analysis.structure_matcher import StructureMatcher
# Register `progress_apply` on pandas objects. `pandas` is a classmethod, so
# call it on the class: `tqdm().pandas()` additionally created an unused,
# never-closed progress bar (the stray "0it" line in the output).
tqdm.pandas()

# Evaluated generated structures; the `structure` column stores pymatgen JSON
# and is parsed into `Structure` objects while reading.
GENERATED_DFT_CSV = (
    "../generated/mp_20/WyckoffTransformer/DiffCSP++10k/CHGNet_free/DFT/"
    "WyFormer-1-MP GGA static.csv.gz"
)
data = pd.read_csv(
    GENERATED_DFT_CSV,
    index_col="material_id",
    converters={"structure": partial(Structure.from_str, fmt="json")},
)

0it [00:00, ?it/s]

In [2]:
# Peek at the first rows of the evaluated structures.
data.head(n=5)

Unnamed: 0_level_0,e_above_hull_corrected,e_uncorrected,e_corrected,structure,entry
material_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
69,0.285988,-34.561289,-34.561289,"[[3.42928824 3.42928833 3.42928825] Tb, [5.361...","{""@module"": ""pymatgen.entries.computed_entries..."
67,0.458247,-9.982895,-11.584896,[[4.05009323e+00 1.04009090e-06 1.07315501e+00...,"{""@module"": ""pymatgen.entries.computed_entries..."
59,-0.033989,-52.869647,-52.869647,"[[0. 0. 0.] Tb, [-5.03120460e-09 3.61987122e+...","{""@module"": ""pymatgen.entries.computed_entries..."
75,0.380581,-26.114099,-25.830099,"[[3.1503182 3.15031867 3.15031783] Mg, [4.753...","{""@module"": ""pymatgen.entries.computed_entries..."
23,0.40029,-55.886481,-58.378481,"[[3.01833352 3.85863256 2.82988216] B, [0.3353...","{""@module"": ""pymatgen.entries.computed_entries..."


In [3]:
def filter_by_unique_structure_chem_sys_index(
    data: pd.DataFrame,
    attempt_supercell: bool = False,
    symmetric: bool = False) -> pd.DataFrame:
    """Drop rows whose structure matches a structure from an earlier row.

    Rows are visited in order. A row is dropped as soon as its ``structure``
    is matched, via ``StructureMatcher.fit``, by any previously visited
    structure with the same set of elements. Otherwise it is kept. Every
    visited structure is remembered for later comparisons, whether or not
    its row was kept.

    Args:
        data: Frame with a ``structure`` column of pymatgen ``Structure``
            objects.
        attempt_supercell: Forwarded to the ``StructureMatcher`` constructor.
        symmetric: Forwarded to ``StructureMatcher.fit``.

    Returns:
        The rows of ``data`` judged unique, in their original order.
    """
    # Every comparison uses the same settings, so build the matcher once
    # instead of once per pairwise comparison.
    matcher = StructureMatcher(attempt_supercell=attempt_supercell)
    seen_by_chem_system = defaultdict(list)
    unique_positions = []
    for position, structure in enumerate(data.structure):
        # Structures consisting of different sets of elements can't match
        # in any way, so only compare within the same chemical system.
        chem_system = frozenset(structure.composition)
        candidates = seen_by_chem_system[chem_system]
        # `any` stops at the first match, like the explicit `break` would.
        if not any(matcher.fit(structure, other, symmetric=symmetric)
                   for other in candidates):
            unique_positions.append(position)
        candidates.append(structure)
    # Select by position so that duplicate index labels cannot pull in
    # extra rows (label-based `.loc` would return all rows sharing a label).
    return data.iloc[unique_positions]

In [4]:
# Deduplicate the evaluated structures with the default matcher settings.
unique = filter_by_unique_structure_chem_sys_index(
    data, attempt_supercell=False, symmetric=False)

In [5]:
# Fraction of evaluated structures that survive deduplication.
unique.shape[0] / data.shape[0]

0.9880494648238595

In [6]:
# Reference training set: parse each CIF string into a pymatgen Structure.
mp_20_train = (
    pd.read_csv('../cdvae/data/mp_20/train.csv', index_col=0)
    .assign(structure=lambda frame: frame.cif.apply(Structure.from_str, fmt="cif"))
)



In [7]:
class NoveltyFilter:
    """Callable that tells whether a structure matches none of a reference set.

    Reference structures are bucketed by their set of elements, so a query is
    only compared (with ``StructureMatcher.fit``) against references that
    contain exactly the same elements.
    """

    def __init__(self,
        reference_structures: Iterable[Structure],
        attempt_supercell: bool = False,
        symmetric: bool = False):
        """
        Args:
            reference_structures: Structures a query must not match to be
                considered novel.
            attempt_supercell: Forwarded to the ``StructureMatcher``
                constructor.
            symmetric: Forwarded to ``StructureMatcher.fit``.
        """
        self.attempt_supercell = attempt_supercell
        self.symmetric = symmetric
        # Maps frozenset of elements -> list of reference structures.
        self.reference = defaultdict(list)
        for structure in reference_structures:
            chem_system = frozenset(structure.composition)
            self.reference[chem_system].append(structure)

    def __call__(self, structure: Structure) -> bool:
        """
        Returns True if the structure is novel, i.e. not matching any of the reference structures.

        A structure whose set of elements has no references is novel without
        any matching being done.
        """
        chem_system = frozenset(structure.composition)
        # Membership test first: indexing the defaultdict with an unseen key
        # would insert an empty bucket.
        if chem_system not in self.reference:
            return True
        # One matcher per query instead of one per reference comparison. It is
        # built from the current attributes, so changes to them still apply.
        matcher = StructureMatcher(attempt_supercell=self.attempt_supercell)
        return not any(
            matcher.fit(structure, reference_structure, symmetric=self.symmetric)
            for reference_structure in self.reference[chem_system])

In [8]:
# Novelty is judged against the training structures loaded above.
novelty_filter = NoveltyFilter(
    reference_structures=mp_20_train.structure,
    attempt_supercell=False,
    symmetric=False,
)

In [9]:
# Keep only rows whose structure matches no training structure.
novel = data.loc[data.structure.apply(novelty_filter)]

In [10]:
# Fraction of evaluated structures that are novel.
novel.shape[0] / data.shape[0]

0.9178011015275902

In [11]:
novel_unique = filter_by_unique_structure_chem_sys_index(
    novel, attempt_supercell=False, symmetric=False)
# Fraction of all evaluated structures that are both novel and unique.
novel_unique.shape[0] / data.shape[0]

0.9114621219993765

In [12]:
def is_sun(record: pd.Series,
           max_e_above_hull: float = 0.0,
           min_elements: int = 2) -> bool:
    """Stability and composition part of the SUN criterion for one row.

    Uniqueness and novelty are NOT checked here. Apply this only to rows that
    were already filtered for them.

    Args:
        record: Row with an ``e_above_hull_corrected`` value and a
            ``structure`` whose ``composition`` iterates over its elements.
        max_e_above_hull: The energy above hull must be strictly below this
            value. The default 0.0 keeps only structures below the hull.
        min_elements: Minimum number of distinct elements. The default 2
            excludes elemental structures.

    Returns:
        True if the row passes both checks.
    """
    # Guard first so the composition is only inspected for stable rows.
    # A NaN energy compares False and is therefore rejected.
    if not record.e_above_hull_corrected < max_e_above_hull:
        return False
    return len(set(record.structure.composition)) >= min_elements

In [13]:
# Presumably the number of structures originally generated (cf. "DiffCSP++10k"
# in the data path) - TODO confirm. Structures missing from `data` then count
# as non-SUN in the second rate.
N_GENERATED = 10_000
sun = novel_unique[novel_unique.apply(is_sun, axis=1)]
print(len(sun) / len(data))     # SUN rate among evaluated structures
print(len(sun) / N_GENERATED)   # SUN rate among all generated structures

0.039280889535487896
0.0378
