In [16]:
import gc
from collections import OrderedDict
from pathlib import Path
from typing import Callable, Dict

import fire
import numpy as np
import pandas as pd
from crabnet.crabnet_ import CrabNet
from fastcore.xtras import save_pickle
from loguru import logger
from numpy.typing import ArrayLike
from pycm import ConfusionMatrix
from pymatgen.core import Composition
from sklearn.model_selection import train_test_split

from gptchem.data import get_hea_phase_data
from gptchem.evaluator import evaluate_classification

# Enable loguru output for the "gptchem" logger namespace so that messages
# emitted by that package are shown in the notebook.
# NOTE(review): presumably gptchem disables its logger by default — confirm.
logger.enable("gptchem")

# Repeat count for the experiment. It is defined here but not referenced by
# any of the cells shown in this notebook.
NUM_REPEATS = 10

# Training-set sizes passed to train_test_split, in evaluation order. The order
# (with 10 last) fixes the row order of the results table, so it is kept as is.
LEARNING_CURVE_POINTS = [20, 50, 100, 200, 10]

# Number of rows held out for testing in every split.
TEST_SIZE = 250

In [7]:
class RuleBaseline:
    """Training-free baseline that labels an input by a single string rule.

    An item is assigned class 1 when it contains a "." character and class 0
    otherwise. In this notebook the items are the strings of the "Alloy"
    column and the labels are compared against "phase_binary_encoded".
    """

    def predict(self, X: ArrayLike) -> ArrayLike:
        """Apply the "contains a dot" rule to every element of ``X``.

        Args:
            X (ArrayLike): Iterable of strings (or other containers that
                support ``"." in item``).

        Returns:
            ArrayLike: A plain list of ints with one entry per input element:
                1 if the element contains ".", else 0.
        """
        # int(True) == 1 and int(False) == 0, so this yields the 0/1 labels.
        return [int("." in item) for item in X]

In [3]:
# Load the phase dataset through gptchem's helper. Later cells read its
# "Alloy" and "phase_binary_encoded" columns.
# NOTE(review): "hea" presumably stands for high-entropy alloys — confirm.
data = get_hea_phase_data()

In [5]:
# Flag every alloy whose composition string contains a literal "." — the same
# rule RuleBaseline.predict applies. The vectorised .str accessor replaces the
# per-row Python lambda. regex=False is essential: interpreted as a regular
# expression, "." would match any character and flag every non-empty string.
data["dot_in_comp"] = data["Alloy"].str.contains(".", regex=False)

In [13]:
# Evaluate the rule baseline on one stratified hold-out split per entry of
# LEARNING_CURVE_POINTS and collect the metrics returned by
# evaluate_classification, in that order.
classifier = RuleBaseline()
results = []

# The stratification labels are the same for every split, so look them up once.
stratify_labels = data["phase_binary_encoded"]

for train_size in LEARNING_CURVE_POINTS:
    # The rule has nothing to fit, so `train` is never used; train_size only
    # influences which rows end up in the TEST_SIZE-row test set.
    train, test = train_test_split(
        data,
        stratify=stratify_labels,
        test_size=TEST_SIZE,
        train_size=train_size,
        random_state=3245,
    )

    predictions = classifier.predict(test["Alloy"])
    res = evaluate_classification(test["phase_binary_encoded"], predictions)
    results.append(res)

In [17]:
# Turn the list of per-split metric results into a table with one row per
# entry of LEARNING_CURVE_POINTS, in loop order.
# NOTE(review): this rebinds `results` from a list to a DataFrame; re-run the
# evaluation cell first if the raw list is needed again.
results = pd.DataFrame(results)

In [18]:
# Show the metrics table; row i corresponds to LEARNING_CURVE_POINTS[i].
results

Unnamed: 0,accuracy,acc_macro,racc,kappa,confusion_matrix,f1_macro,f1_micro,frac_valid,all_y_true,all_y_pred,valid_indices,might_have_rounded_floats
0,0.848,0.848,0.5,0.696,"((0, {0: 94, 1: 31}), (1, {0: 7, 1: 118}))",0.846586,0.848,1.0,"[1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, ...","[1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",False
1,0.844,0.844,0.5,0.688,"((0, {0: 93, 1: 32}), (1, {0: 7, 1: 118}))",0.842424,0.844,1.0,"[1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...","[0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",False
2,0.844,0.844,0.5,0.688,"((0, {0: 93, 1: 32}), (1, {0: 7, 1: 118}))",0.842424,0.844,1.0,"[0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, ...","[0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",False
3,0.856,0.856,0.5,0.712,"((0, {0: 94, 1: 31}), (1, {0: 5, 1: 120}))",0.854425,0.856,1.0,"[1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, ...","[1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",False
4,0.852,0.852,0.5,0.704,"((0, {0: 95, 1: 30}), (1, {0: 7, 1: 118}))",0.850737,0.852,1.0,"[0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, ...","[1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",False


In [19]:
# Average each metric over all splits. The table also holds non-numeric
# columns (confusion_matrix, all_y_true, all_y_pred, valid_indices).
# numeric_only=True excludes them explicitly instead of relying on older
# pandas versions silently dropping them; newer versions raise instead.
# Boolean columns such as might_have_rounded_floats are still averaged.
results.mean(numeric_only=True)

accuracy                     0.848800
acc_macro                    0.848800
racc                         0.500000
kappa                        0.697600
f1_macro                     0.847319
f1_micro                     0.848800
frac_valid                   1.000000
might_have_rounded_floats    0.000000
dtype: float64