# Predicting whether alloys are single-phase or multiphase

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules
# (except those excluded via %aimport) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.model_selection import train_test_split

from gptchem.formatter import ClassificationFormatter
from gptchem.querier import Querier
from gptchem.tuner import Tuner
from gptchem.evaluator import evaluate_classification
from gptchem.extractor import ClassificationExtractor
from gptchem.data import get_hea_phase_data

In [3]:
# Load the alloy phase dataset through gptchem's data helper; the preview
# below shows the Alloy, Phase, phase_encoded and phase_binary_encoded columns.
data = get_hea_phase_data()

In [4]:
# Peek at the first rows to confirm the column layout.
data.head()

Unnamed: 0,Alloy,Phase,phase_encoded,phase_binary_encoded
0,Ag0.05Zr0.95,bcc,1,1
1,Al0.15Cr0.85,bcc,1,1
2,Al0.1Fe0.9,bcc,1,1
3,Al0.1Hf0.9,bcc,1,1
4,Al0.1Ti0.9,bcc,1,1


In [5]:
# Class balance of the four-way phase encoding (values 0-3). Note that the
# formatter further down uses 'phase_binary_encoded' as its label column,
# not this one.
data['phase_encoded'].value_counts()

0    627
1    261
2    218
3    146
Name: phase_encoded, dtype: int64

In [6]:
# Turn each alloy into a prompt/completion pair, using the binary phase
# encoding as the label.
# NOTE(review): num_classes=None and qcut=None presumably switch off label
# binning so the labels pass through unchanged -- confirm in gptchem.formatter.
formatter = ClassificationFormatter(
    property_name="phase",
    representation_column="Alloy",
    label_column="phase_binary_encoded",
    qcut=None,
    num_classes=None,
)

In [7]:
# Apply the formatter to the full dataset; the result (displayed in the next
# cell) has prompt, completion, label and representation columns.
formatted = formatter(data)

In [8]:
# Inspect the formatted prompts and completions.
# NOTE(review): this displays the whole frame (pandas truncates it); a
# .head() or .sample() preview would keep the saved output smaller.
formatted

Unnamed: 0,prompt,completion,label,representation
0,What is the phase of Ag0.05Zr0.95?###,1@@@,1,Ag0.05Zr0.95
1,What is the phase of Al0.15Cr0.85?###,1@@@,1,Al0.15Cr0.85
2,What is the phase of Al0.1Fe0.9?###,1@@@,1,Al0.1Fe0.9
3,What is the phase of Al0.1Hf0.9?###,1@@@,1,Al0.1Hf0.9
4,What is the phase of Al0.1Ti0.9?###,1@@@,1,Al0.1Ti0.9
...,...,...,...,...
1247,What is the phase of SrY?###,0@@@,0,SrY
1248,What is the phase of TaTb?###,0@@@,0,TaTb
1249,What is the phase of TaTl?###,0@@@,0,TaTl
1250,What is the phase of TaTm?###,0@@@,0,TaTm


In [9]:
# Stratified split on the label column: 50 rows for fine-tuning and 100 rows
# for evaluation; the remaining rows are left out of both subsets. The fixed
# random_state makes the split reproducible.
train, test = train_test_split(
    formatted,
    stratify=formatted["label"],
    train_size=50,
    test_size=100,
    random_state=42,
)

In [10]:
# Fine-tuning configuration: 'ada' base model, 8 epochs, learning-rate
# multiplier 0.02.
# NOTE(review): wandb_sync=False presumably disables Weights & Biases
# logging -- confirm in gptchem.tuner.
tuner = Tuner(
    base_model="ada",
    n_epochs=8,
    learning_rate_multiplier=0.02,
    wandb_sync=False,
)

In [11]:
# Run the fine-tune on the 50-row training split. The log below shows the
# training data being uploaded as train.jsonl.
# NOTE(review): presumably this submits a remote fine-tuning job and waits for
# it -- expensive to re-run; confirm in gptchem.tuner.
tune_summary = tuner(train)

Upload progress: 100%|██████████| 5.57k/5.57k [00:00<00:00, 3.88Mit/s]


Uploaded file from /Users/kevinmaikjablonka/git/kjappelbaum/gptchem/experiments/03_classification/hea_single_vs_multiphase/out/20230109_151611/train.jsonl: file-YRG9zLE74de5ykSxQzKp4tZm


In [13]:
# Build a querier for the model named by the 'model_name' entry of the
# tuning summary.
querier = Querier(tune_summary['model_name'])

In [14]:
# Query the model for every prompt in the held-out test split.
# NOTE(review): logprobs=2 presumably requests the top-2 token
# log-probabilities per completion -- confirm in gptchem.querier.
completions = querier(test, logprobs=2)

In [15]:
# Convert the raw completions into class predictions.
# NOTE(review): presumably parses the integer label out of each completion
# text -- confirm in gptchem.extractor.
extractor = ClassificationExtractor()
predictions = extractor(completions)

In [16]:
# Score the predictions against the true labels of the test split.
metrics = evaluate_classification(test['label'], predictions)

In [17]:
# Show the computed metrics.
# NOTE(review): the 'valid_indices' list dominates this output; displaying a
# filtered view (or individual entries) would be easier to read.
metrics

{'accuracy': 0.95,
 'acc_macro': 0.95,
 'racc': 0.5,
 'kappa': 0.8999999999999999,
 'confusion_matrix': pycm.ConfusionMatrix(classes: [0, 1]),
 'f1_macro': 0.9499549594635172,
 'f1_micro': 0.95,
 'frac_valid': 1.0,
 'all_y_true': (#100) [0,1,1,0,1,1,1,0,1,1...],
 'all_y_pred': (#100) [0,1,0,1,1,1,1,0,1,0...],
 'valid_indices': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99],
 'might_have_rounded_floats': False}