# Classification and regression on the solubility data

Compares TabPFN and XGBoost baselines with a fine-tuned GPT-3 model on the solubility dataset.

In [1]:
# Automatically reload edited imported modules (e.g. the gpt3forchem package) before each cell runs.
%reload_ext autoreload
%autoreload 2

In [3]:
import time

from fastcore.utils import save_pickle
from sklearn.model_selection import train_test_split

from gpt3forchem.api_wrappers import extract_prediction, fine_tune, query_gpt3, extract_regression_prediction
from gpt3forchem.data import get_solubility_data
from gpt3forchem.input import create_prompts_solubility, _SOLUBILITY_FEATURES, encode_categorical_value
from gpt3forchem.output import get_regression_metrics
from pycm import ConfusionMatrix


from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier
from xgboost import XGBClassifier

import matplotlib.pyplot as plt

# NOTE(review): 'science' and 'nature' are presumably styles from the third-party SciencePlots
# package; recent SciencePlots releases also need `import scienceplots` before they resolve — confirm.
plt.style.use(['science', 'nature'])

## Classification

### TabPFN baseline

In [4]:
data = get_solubility_data()


def _stratified_split(train_size):
    """Split `data` into (train, test) with `train_size` training rows, stratified on `Solubility_cat`, seed 42."""
    return train_test_split(
        data,
        train_size=train_size,
        random_state=42,
        stratify=data['Solubility_cat'],
    )


train_data_100, test_data_100 = _stratified_split(100)
train_data_500, test_data_500 = _stratified_split(500)

In [5]:
def run_tabpfn_baseline(train_data, test_data, device='cpu', n_ensemble_configurations=32):
    """Fit a TabPFN classifier on the solubility features and evaluate it on a held-out split.

    Args:
        train_data: DataFrame with the `_SOLUBILITY_FEATURES` columns and a `Solubility_cat` label column.
        test_data: DataFrame with the same columns, used for evaluation.
        device: device passed to `TabPFNClassifier` (default 'cpu').
        n_ensemble_configurations: forwarded to `TabPFNClassifier` as `N_ensemble_configurations` (default 32).

    Returns:
        (cm, y_eval): a pycm `ConfusionMatrix` of true vs. predicted `Solubility_cat` labels
        on `test_data`, and the predicted labels themselves.
    """
    X_train, y_train = train_data[_SOLUBILITY_FEATURES], train_data['Solubility_cat']
    X_test, y_test = test_data[_SOLUBILITY_FEATURES], test_data['Solubility_cat']
    classifier = TabPFNClassifier(device=device, N_ensemble_configurations=n_ensemble_configurations)

    classifier.fit(X_train, y_train)
    # Only the predicted labels are needed; the winning-class probabilities were
    # previously requested and then discarded.
    y_eval = classifier.predict(X_test)
    cm = ConfusionMatrix(y_test.values, y_eval)

    return cm, y_eval

In [6]:
# TabPFN evaluated with 100 and with 500 stratified training rows.
cm_100, y_eval_100 = run_tabpfn_baseline(train_data=train_data_100, test_data=test_data_100)
cm_500, y_eval_500 = run_tabpfn_baseline(train_data=train_data_500, test_data=test_data_500)

Using style prior: True
Using cpu:0 device
Using a Transformer with 25.82 M parameters


In [None]:
# TabPFN, 100 training rows: print() renders pycm's confusion matrix followed by summary statistics.
print(cm_100)

Predict          large            medium           small            very large       very small       
Actual
large            3656             528              14               714              0                

medium           639              1165             208              163              0                

small            45               184              357              67               0                

very large       530              16               0                1521             0                

very small       0                1                12               1                0                





Overall Statistics : 

95% CI                                                            (0.6729,0.69132)
ACC Macro                                                         0.87284
ARI                                                               0.32011
AUNP                                                              0.75498
AUNU                                  

In [None]:
# TabPFN, 500 training rows (note: a later XGBoost cell reassigns cm_500).
print(cm_500)

Predict          large            medium           small            very large       very small       
Actual
large            3821             527              12               352              0                

medium           518              1346             163              59               0                

small            37               237              348              5                0                

very large       668              18               3                1294             0                

very small       0                2                11               0                0                





Overall Statistics : 

95% CI                                                            (0.71371,0.73179)
ACC Macro                                                         0.8891
ARI                                                               0.38371
AUNP                                                              0.77799
AUNU                                  

## XGBoost

In [None]:
def run_xgboost_baseline(train_data, test_data, n_estimators=5000):
    """Fit an XGBoost classifier on the solubility features and evaluate it on a held-out split.

    Labels are the `Solubility_cat` categories mapped through `encode_categorical_value`.

    Args:
        train_data: DataFrame with the `_SOLUBILITY_FEATURES` columns and a `Solubility_cat` column.
        test_data: DataFrame with the same columns, used for evaluation.
        n_estimators: number of boosting rounds passed to `XGBClassifier` (default 5000).

    Returns:
        (cm, y_eval): a pycm `ConfusionMatrix` of true vs. predicted encoded labels on
        `test_data`, and the predicted encoded labels.
    """
    X_train = train_data[_SOLUBILITY_FEATURES]
    y_train = train_data['Solubility_cat'].map(encode_categorical_value)
    X_test = test_data[_SOLUBILITY_FEATURES]
    y_test = test_data['Solubility_cat'].map(encode_categorical_value)

    # Recent XGBoost versions require training labels to be exactly 0..k-1. A small
    # stratified training split can miss a rare category, which leaves a gap in the
    # encoded labels. Re-index the classes actually present for fitting, then map the
    # predictions back. When every class is present and already 0..k-1, this mapping
    # is the identity.
    classes = y_train.drop_duplicates().sort_values().to_numpy()
    class_to_index = {label: index for index, label in enumerate(classes)}

    classifier = XGBClassifier(n_estimators=n_estimators)
    classifier.fit(X_train, y_train.map(class_to_index))
    y_eval = classes[classifier.predict(X_test).astype(int)]
    cm = ConfusionMatrix(y_test.values, y_eval)

    return cm, y_eval

In [None]:
# NOTE(review): the 100-row run is disabled. The TabPFN outputs above show 14 'very small'
# rows in test_data_100 vs 13 in test_data_500, so train_data_100 appears to contain no
# 'very small' rows. That leaves a gap in the encoded labels, which XGBClassifier may reject — confirm.
#cm_100, y_eval_100 = run_xgboost_baseline(train_data_100, test_data_100)
# NOTE(review): this reassigns cm_500 / y_eval_500 and overwrites the TabPFN results computed
# earlier; distinct names (e.g. cm_xgb_500) would keep both available.
cm_500, y_eval_500 = run_xgboost_baseline(train_data_500, test_data_500)

['large' 'very large' 'medium' 'small' 'very small']


In [None]:
# XGBoost, 500 training rows (cm_500 was reassigned by the previous cell).
# Classes are the integer codes produced by encode_categorical_value.
print(cm_500)

Predict    0          1          2          3          4          
Actual
0          1          9          2          1          0          

1          2          345        218        50         12         

2          0          195        1249       554        88         

3          0          19         591        3675       427        

4          0          3          32         658        1290       





Overall Statistics : 

95% CI                                                            (0.68703,0.7056)
ACC Macro                                                         0.87853
ARI                                                               0.3411
AUNP                                                              0.76005
AUNU                                                              0.71765
Bangdiwala B                                                      0.53331
Bennett S                                                         0.6204
CBA                               

## GPT-3

In [None]:
def train_test_gpts(train_data, test_data, repr, regression=False, subsample: int = 50, random_state: int = 42):
    """Fine-tune a GPT-3 model on solubility prompts and evaluate it on held-out prompts.

    Args:
        train_data: training rows passed to `create_prompts_solubility`.
        test_data: held-out rows passed to `create_prompts_solubility`. All of them are
            written to the validation file used for fine-tuning.
        repr: molecular representation forwarded as `representation=` (e.g. "SMILES").
        regression: build regression prompts and parse predictions with
            `extract_regression_prediction` instead of classification prompts.
        subsample: maximum number of test prompts to query the fine-tuned model with;
            `None` queries all of them.
        random_state: seed used to choose the queried subsample.

    Returns:
        (modelname, predictions, completions, true): the fine-tuned model name, the parsed
        predictions, the raw completions returned by `query_gpt3`, and the true labels
        parsed from the queried prompts, in the same order as `predictions`.
    """
    import os

    train_prompts = create_prompts_solubility(train_data, representation=repr, regression=regression)
    test_prompts = create_prompts_solubility(test_data, representation=repr, regression=regression)

    train_size = len(train_prompts)
    test_size = len(test_prompts)

    # Name files after the actual task; previously they always said "classification".
    task = "regression" if regression else "classification"
    filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    train_filename = f"run_files/{filename_base}_train_prompts_solubility_{task}_{train_size}.jsonl"
    valid_filename = f"run_files/{filename_base}_valid_prompts_solubility_{task}_{test_size}.jsonl"

    os.makedirs("run_files", exist_ok=True)
    train_prompts.to_json(train_filename, orient="records", lines=True)
    test_prompts.to_json(valid_filename, orient="records", lines=True)

    modelname = fine_tune(train_filename, valid_filename)

    # Honour `subsample`: only a random subset of the test prompts is sent to the model.
    if subsample is not None and subsample < test_size:
        query_prompts = test_prompts.sample(n=subsample, random_state=random_state)
    else:
        query_prompts = test_prompts

    completions = query_gpt3(modelname, query_prompts)
    # NOTE(review): assumes extract_regression_prediction shares extract_prediction's
    # (completions, index) call signature — confirm.
    extract = extract_regression_prediction if regression else extract_prediction
    predictions = [extract(completions, i) for i in range(len(completions["choices"]))]

    # The label is the completion text before the first '@'.
    true = query_prompts['completion'].str.split('@').str[0].tolist()
    if regression:
        # NOTE(review): assumes regression completions start with a plain number — confirm.
        true = [float(value) for value in true]
    return modelname, predictions, completions, true

In [59]:
# Fine-tune and query a GPT-3 model through the OpenAI API on SMILES prompts
# built from the 500-row split (classification task; requires network access).
model_500, predictions_500, completions_500, true_500 = train_test_gpts(
    train_data_500,
    test_data_500,
    repr="SMILES",
    regression=False,
)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

APIConnectionError: Error communicating with OpenAI

In [58]:
# GPT-3, 500 training rows: true vs. predicted solubility categories.
cm_gpt3_500 = ConfusionMatrix(true_500, predictions_500)
print(cm_gpt3_500)

Predict          large            medium           small            very large       
Actual
large            20               5                0                5                

medium           5                8                0                0                

small            0                0                2                0                

very large       1                0                0                4                





Overall Statistics : 

95% CI                                                            (0.5507,0.8093)
ACC Macro                                                         0.84
ARI                                                               0.22733
AUNP                                                              0.72688
AUNU                                                              0.81698
Bangdiwala B                                                      0.48497
Bennett S                                                         0.57333
CBA      