In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import pandas as pd 
from pymatgen.core import Structure
from glob import glob
from pathlib import Path

from gptchem.gpt_classifier import GPTClassifier
from gptchem.tuner import Tuner

from gptchem.evaluator import evaluate_classification

from sklearn.dummy import DummyClassifier

In [4]:
all_data = []
cu1_cifs = glob('cu1_cifs/*.cif')
cu2_cifs = glob('cu2_cifs/*.cif')

In [5]:
for cif in cu1_cifs:
    structure = Structure.from_file(cif)
    all_data.append({
        'composition': structure.composition.reduced_formula,
        'oxidation_state': 1,
        'cif': Path(cif).name
    })



In [6]:
for cif in cu2_cifs:
    structure = Structure.from_file(cif)
    all_data.append({
        'composition': structure.composition.reduced_formula,
        'oxidation_state': 2,
        'cif': Path(cif).name
    })



In [7]:
df = pd.DataFrame(all_data)

In [8]:
df.head()

Unnamed: 0,composition,oxidation_state,cif
0,Cu4H16C11(IN)4,1,MAQJAS.cif
1,CuH11C5N3O,1,CUSLUY.cif
2,CuH21C28IN4,1,WEXTEB.cif
3,CuC29N4,1,JARMEU.cif
4,CuPH26C28S2(N2F3)2,1,LEBCIF.cif


In [9]:
df.tail()

Unnamed: 0,composition,oxidation_state,cif
10592,CuH14C18N4O5,2,POTWUQ.cif
10593,CuH42C32N4O11,2,CIBRIP.cif
10594,CuH30C22(NO)8,2,PEJVEH.cif
10595,CuH10C12(NO3)2,2,YEGSAE02.cif
10596,Cu3H12(C7N13)2,2,HORBOG.cif


In [10]:
from sklearn.model_selection import train_test_split

In [11]:
train, test = train_test_split(df, train_size=100, test_size=200, random_state=42, stratify=df['oxidation_state'])

In [12]:
dummy = DummyClassifier(strategy='majority')

In [13]:
len(train)

8477

In [18]:
import openai
openai.api_key = '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'

In [19]:
classifier = GPTClassifier('oxidation state', Tuner(n_epochs=8, learning_rate_multiplier=0.02, wandb_sync=False))

In [20]:
classifier.fit(train['cif'].values, train['oxidation_state'].values)

Upload progress: 100%|██████████| 12.1k/12.1k [00:00<00:00, 4.50Mit/s]


Uploaded file from /Users/kevinmaikjablonka/git/kjappelbaum/gptchem/experiments/03_classification/oxidation_states/out/20230407_161155/train.jsonl: file-7ZDCIloaHDuAOEx7YLY4kLvy


In [22]:
predictions = classifier.predict(test['cif'].values)

In [24]:
res = evaluate_classification(test['oxidation_state'].values, predictions)

In [25]:
res

{'accuracy': 0.72,
 'acc_macro': 0.72,
 'racc': 0.7214,
 'kappa': -0.005025125628140948,
 'confusion_matrix': pycm.ConfusionMatrix(classes: [1, 2]),
 'f1_macro': 0.4791666666666667,
 'f1_micro': 0.72,
 'frac_valid': 1.0,
 'all_y_true': (#200) [2,2,2,2,1,2,2,2,2,2...],
 'all_y_pred': (#200) [2,2,2,1,2,2,1,1,2,2...],
 'valid_indices': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,

Now, build an actual model ...

In [26]:
classifier_composition = GPTClassifier('oxidation state', Tuner(n_epochs=8, learning_rate_multiplier=0.02, wandb_sync=False))

In [27]:
classifier_composition.fit(train['composition'].values, train['oxidation_state'].values)

Upload progress: 100%|██████████| 12.9k/12.9k [00:00<00:00, 12.8Mit/s]


Uploaded file from /Users/kevinmaikjablonka/git/kjappelbaum/gptchem/experiments/03_classification/oxidation_states/out/20230407_165724/train.jsonl: file-q5TJH2jTmY8xZIevKD5bl8MR


In [28]:
predictions_composition = classifier_composition.predict(test['composition'].values)

In [29]:
res_composition = evaluate_classification(test['oxidation_state'].values, predictions_composition)

In [30]:
print(res_composition)

{'accuracy': 0.83, 'acc_macro': 0.83, 'racc': 0.7106, 'kappa': 0.4125777470628886, 'confusion_matrix': pycm.ConfusionMatrix(classes: [1, 2]), 'f1_macro': 0.6987951807228916, 'f1_micro': 0.83, 'frac_valid': 1.0, 'all_y_true': [2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2], 'all_y_pred': [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2,

In [None]:
# Parse BV data

In [32]:
import numpy as np

In [79]:
def parse_logfile(filepath):
    with open(filepath, 'r') as fh: 
        loglines = fh.readlines()
    try:
        if 'state: N' in loglines[-1]: 
            return np.nan
        elif 'state: 1' in loglines[-1]: 
            return 1
        elif 'state: 2' in loglines[-1]:
            return 2
    except Exception:
        return np.nan

In [115]:
cu_1_bv = glob('cu1_cifs/*.log2')
cu_2_bv = glob('cu2_cifs/*.log2')

# assert len(cu_1_bv) == len(cu1_cifs), f'{len(cu_1_bv)} != {len(cu1_cifs)}'
# assert len(cu_2_bv) == len(cu2_cifs), f'{len(cu_2_bv)} != {len(cu2_cifs)}'
all_bv_data = []

for bv in cu_1_bv:
    all_bv_data.append({
        'bv': parse_logfile(bv),
        'name': Path(bv).name.replace('.log2', '')[:5]
    })

for bv in cu_2_bv:
    all_bv_data.append({
        'bv': parse_logfile(bv),
        'name': Path(bv).name.replace('.log2', '')[:5]
    })


In [116]:
df_bv = pd.DataFrame(all_bv_data)

In [117]:
df_bv

Unnamed: 0,bv,name
0,1.0,SUSNA
1,1.0,IFIQO
2,1.0,CAVLO
3,1.0,CUHBU
4,1.0,OLOQO
...,...,...
10615,2.0,EDUDU
10616,2.0,HOSTE
10617,2.0,WOBCE
10618,1.0,SIWPA


In [118]:
oximachine_results = pd.read_csv('cu_predictions.csv')

In [119]:
oximachine_results['name'] = oximachine_results['name'].apply(lambda x: x + '.cif')

In [164]:
df['name'] = df['cif'].apply(lambda x: x.replace('.cif', '')[:5])

In [165]:
df['oxidation_state'].value_counts()

2    8145
1    2452
Name: oxidation_state, dtype: int64

In [166]:
merged_data = df.merge(df_bv, left_on='name', right_on='name', how='inner').merge(oximachine_results, left_on='name', right_on='name', how='inner')

In [167]:
merged_data['oxidation_state'].value_counts()

2    54856
1    15704
Name: oxidation_state, dtype: int64

In [168]:
merged_data.drop_duplicates(subset=['name'], inplace=True)

In [169]:
merged_data['oxidation_state'].value_counts()

2    7223
1    2222
Name: oxidation_state, dtype: int64

In [171]:
import ast

clean_cu_prediction = []
for i, row in merged_data.iterrows():
    try:
        clean_cu_prediction.append(ast.literal_eval(row['cu_pred'])[0])
    except Exception:
        clean_cu_prediction.append(np.nan)
        
merged_data['prediction'] = clean_cu_prediction

In [172]:
merged_data.dropna(subset=['prediction', 'bv'], inplace=True)

In [173]:
merged_data

Unnamed: 0,composition,oxidation_state,cif,name,bv,oximachine,cu_pred,indices,metals,prediction
0,Cu4H16C11(IN)4,1,MAQJAS.cif,MAQJA,1.0,"[1, 1, 1, 1, 1, 1, 1, 1]",[1],"[0, 1, 2, 3, 4, 5, 6, 7]","['Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu']",1.0
1,CuH11C5N3O,1,CUSLUY.cif,CUSLU,1.0,"[1, 1, 1, 1]",[1],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",1.0
2,CuH21C28IN4,1,WEXTEB.cif,WEXTE,1.0,"[1, 1, 1, 1, 1, 1, 1, 1]",[1],"[0, 1, 2, 3, 4, 5, 6, 7]","['Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu']",1.0
3,CuC29N4,1,JARMEU.cif,JARME,1.0,"[1, 1]",[1],"[0, 1]","['Cu', 'Cu']",1.0
11,CuPH26C28S2(N2F3)2,1,LEBCIF.cif,LEBCI,1.0,"[1, 1, 1, 1]",[1],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",1.0
...,...,...,...,...,...,...,...,...,...,...
70555,Gd2Cu3H24(C2O)24,2,WICVAG.cif,WICVA,2.0,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, ...",[2],"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","['Gd', 'Gd', 'Gd', 'Gd', 'Gd', 'Gd', 'Gd', 'Gd...",2.0
70556,CuH14C18N4O5,2,POTWUQ.cif,POTWU,1.0,"[2, 2, 2, 2]",[2],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",2.0
70557,CuH42C32N4O11,2,CIBRIP.cif,CIBRI,2.0,"[2, 2, 2, 2]",[2],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",2.0
70558,CuH30C22(NO)8,2,PEJVEH.cif,PEJVE,2.0,"[2, 2]",[2],"[0, 1]","['Cu', 'Cu']",2.0


In [174]:
merged_data['bv'] = merged_data['bv'].astype(int)
merged_data['prediction'] = merged_data['prediction'].astype(int)

In [175]:
merged_data.to_pickle('merged_data.pkl')

In [176]:
merged_data

Unnamed: 0,composition,oxidation_state,cif,name,bv,oximachine,cu_pred,indices,metals,prediction
0,Cu4H16C11(IN)4,1,MAQJAS.cif,MAQJA,1,"[1, 1, 1, 1, 1, 1, 1, 1]",[1],"[0, 1, 2, 3, 4, 5, 6, 7]","['Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu']",1
1,CuH11C5N3O,1,CUSLUY.cif,CUSLU,1,"[1, 1, 1, 1]",[1],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",1
2,CuH21C28IN4,1,WEXTEB.cif,WEXTE,1,"[1, 1, 1, 1, 1, 1, 1, 1]",[1],"[0, 1, 2, 3, 4, 5, 6, 7]","['Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu', 'Cu']",1
3,CuC29N4,1,JARMEU.cif,JARME,1,"[1, 1]",[1],"[0, 1]","['Cu', 'Cu']",1
11,CuPH26C28S2(N2F3)2,1,LEBCIF.cif,LEBCI,1,"[1, 1, 1, 1]",[1],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",1
...,...,...,...,...,...,...,...,...,...,...
70555,Gd2Cu3H24(C2O)24,2,WICVAG.cif,WICVA,2,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, ...",[2],"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","['Gd', 'Gd', 'Gd', 'Gd', 'Gd', 'Gd', 'Gd', 'Gd...",2
70556,CuH14C18N4O5,2,POTWUQ.cif,POTWU,1,"[2, 2, 2, 2]",[2],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",2
70557,CuH42C32N4O11,2,CIBRIP.cif,CIBRI,2,"[2, 2, 2, 2]",[2],"[0, 1, 2, 3]","['Cu', 'Cu', 'Cu', 'Cu']",2
70558,CuH30C22(NO)8,2,PEJVEH.cif,PEJVE,2,"[2, 2]",[2],"[0, 1]","['Cu', 'Cu']",2


In [177]:
merged_data['oxidation_state'].value_counts()

2    7220
1    2221
Name: oxidation_state, dtype: int64