# Try to do classification and regression on the solubility data 

In [39]:
%reload_ext autoreload
%autoreload 2

In [40]:
import time

from fastcore.utils import save_pickle
from sklearn.model_selection import train_test_split

from gpt3forchem.api_wrappers import extract_prediction, fine_tune, query_gpt3, extract_regression_prediction
from gpt3forchem.data import get_solubility_data
from gpt3forchem.input import create_prompts_solubility, _SOLUBILITY_FEATURES, encode_categorical_value
from gpt3forchem.output import get_regression_metrics
from pycm import ConfusionMatrix


from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier
from xgboost import XGBClassifier

import matplotlib.pyplot as plt

plt.style.use(['science', 'nature'])

## Classification

### TabPFN baseline

In [41]:
data = get_solubility_data()
train_data_10, test_data_10 = train_test_split(data, train_size=10, test_size=1000, random_state=42, stratify=data['Solubility_cat'])  
train_data_100, test_data_100 = train_test_split(data, train_size=100, test_size=1000, random_state=42, stratify=data['Solubility_cat'])    
train_data_500, test_data_500 = train_test_split(data, train_size=500, test_size=1000, random_state=42, stratify=data['Solubility_cat'])

In [42]:
def run_tabpfn_baseline(train_data, test_data):
    X_train, y_train = train_data[_SOLUBILITY_FEATURES], train_data['Solubility_cat']
    X_test, y_test = test_data[_SOLUBILITY_FEATURES], test_data['Solubility_cat']
    classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

    classifier.fit(X_train, y_train)
    y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)
    cm = ConfusionMatrix(y_test.values, y_eval)

    return cm, y_eval

In [43]:
cm_10, y_eval_10 = run_tabpfn_baseline(train_data_10, test_data_10)
cm_100, y_eval_100 = run_tabpfn_baseline(train_data_100, test_data_100)
cm_500, y_eval_500 = run_tabpfn_baseline(train_data_500, test_data_500)

Using style prior: True
Using cpu:0 device
Using a Transformer with 25.82 M parameters
Using style prior: True
Using cpu:0 device
Using a Transformer with 25.82 M parameters
Using style prior: True
Using cpu:0 device
Using a Transformer with 25.82 M parameters


In [44]:
print(cm_10)

Predict          large            medium           small            very large       very small       
Actual
large            64               8                10               49               69               

medium           43               22               26               27               82               

small            36               30               26               3                105              

very large       91               8                3                41               57               

very small       19               6                5                1                169              





Overall Statistics : 

95% CI                                                            (0.29304,0.35096)
ACC Macro                                                         0.7288
ARI                                                               0.06762
AUNP                                                              0.57625
AUNU                                  

In [45]:
print(cm_100)

Predict          large            medium           small            very large       very small       
Actual
large            58               39               19               79               5                

medium           50               71               36               31               12               

small            11               56               91               6                36               

very large       14               13               5                165              3                

very small       3                10               28               14               145              





Overall Statistics : 

95% CI                                                            (0.49907,0.56093)
ACC Macro                                                         0.812
ARI                                                               0.24894
AUNP                                                              0.70625
AUNU                                   

In [46]:
print(cm_500)

Predict          large            medium           small            very large       very small       
Actual
large            102              53               8                34               3                

medium           30               110              41               12               7                

small            6                39               129              1                25               

very large       30               15               4                148              3                

very small       1                6                44               5                144              





Overall Statistics : 

95% CI                                                            (0.60313,0.66287)
ACC Macro                                                         0.8532
ARI                                                               0.32972
AUNP                                                              0.77062
AUNU                                  

## XGBoost

In [47]:
def run_xgboost_baseline(train_data, test_data):
    X_train, y_train = train_data[_SOLUBILITY_FEATURES], train_data['Solubility_cat'].map(encode_categorical_value)
    print(train_data['Solubility_cat'].unique())
    X_test, y_test = test_data[_SOLUBILITY_FEATURES], test_data['Solubility_cat'].map(encode_categorical_value) 
    classifier = XGBClassifier(n_estimators=5000)

    classifier.fit(X_train, y_train)
    y_eval = classifier.predict(X_test)
    cm = ConfusionMatrix(y_test.values, y_eval)

    return cm, y_eval

In [48]:
#cm_100, y_eval_100 = run_xgboost_baseline(train_data_100, test_data_100)
cm_500, y_eval_500 = run_xgboost_baseline(train_data_500, test_data_500)

['large' 'small' 'very large' 'very small' 'medium']


In [50]:
print(cm_500)

Predict   0         1         2         3         4         
Actual
0         140       38        8         6         8         

1         40        108       39        11        2         

2         10        40        76        54        20        

3         6         12        59        91        32        

4         10        5         15        38        132       





Overall Statistics : 

95% CI                                                            (0.51615,0.57785)
ACC Macro                                                         0.8188
ARI                                                               0.24115
AUNP                                                              0.71688
AUNU                                                              0.71688
Bangdiwala B                                                      0.31372
Bennett S                                                         0.43375
CBA                                                               0.

## GPT-3

In [51]:
def train_test_gpts(train_data, test_data, repr, regression=False, subsample:int=50): 
    train_prompts = create_prompts_solubility(train_data, representation=repr, regression=regression)
    test_prompts = create_prompts_solubility(test_data, representation=repr, regression=regression)

    train_size  = len(train_prompts)
    test_size = len(test_prompts)

    filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    train_filename = f"run_files/{filename_base}_train_prompts_solubility_classification_{train_size}.jsonl"
    valid_filename = f"run_files/{filename_base}_valid_prompts_solubility_classification_{test_size}.jsonl"

    train_prompts.to_json(train_filename, orient="records", lines=True)
    test_prompts.to_json(valid_filename, orient="records", lines=True)

    modelname = fine_tune(train_filename, valid_filename)

    completions = query_gpt3(modelname, test_prompts) 
    predictions = [extract_prediction(completions, i) for i, completion in enumerate(completions["choices"])]
    true = [test_prompts.iloc[i]['completion'].split('@')[0] for i in range(len(test_prompts))]
    return modelname, predictions, completions, true

In [52]:
model_500, predictions_500, completions_500, true_500 = train_test_gpts(train_data_500, test_data_500, "SMILES", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [55]:
print(ConfusionMatrix(true_500, predictions_500))

Predict          large            medium           small            very large       very small       
Actual
large            73               51               23               46               7                

medium           50               80               38               20               12               

small            12               50               98               10               30               

very large       44               13               5                128              10               

very small       2                12               45               6                135              





Overall Statistics : 

95% CI                                                            (0.48302,0.54498)
ACC Macro                                                         0.8056
ARI                                                               0.20712
AUNP                                                              0.69625
AUNU                                  

In [56]:
model_500, predictions_500, completions_500, true_500 = train_test_gpts(train_data_500, test_data_500, "Name", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [57]:
print(ConfusionMatrix(true_500, predictions_500))

Predict          large            medium           small            very large       very small       
Actual
large            63               41               23               58               15               

medium           47               44               36               49               24               

small            26               53               71               22               28               

very large       52               31               6                95               16               

very small       10               20               47               21               102              





Overall Statistics : 

95% CI                                                            (0.34499,0.40501)
ACC Macro                                                         0.75
ARI                                                               0.08332
AUNP                                                              0.60938
AUNU                                    

In [58]:
model_500, predictions_500, completions_500, true_500 = train_test_gpts(train_data_500, test_data_500, "selfies", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [59]:
print(ConfusionMatrix(true_500, predictions_500))

Predict          large            medium           small            very large       very small       
Actual
large            71               57               11               46               15               

medium           41               87               40               16               16               

small            10               61               83               7                39               

very large       35               23               4                123              15               

very small       1                15               33               18               133              





Overall Statistics : 

95% CI                                                            (0.46601,0.52799)
ACC Macro                                                         0.7988
ARI                                                               0.18177
AUNP                                                              0.68562
AUNU                                  

In [60]:
model_500, predictions_500, completions_500, true_500 = train_test_gpts(train_data_500, test_data_500, "InChI", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [61]:
print(ConfusionMatrix(true_500, predictions_500))

Predict          large            medium           small            very large       very small       
Actual
large            81               24               28               55               12               

medium           45               60               46               31               18               

small            21               29               77               16               57               

very large       41               19               12               120              8                

very small       3                7                36               14               140              





Overall Statistics : 

95% CI                                                            (0.44704,0.50896)
ACC Macro                                                         0.7912
ARI                                                               0.16577
AUNP                                                              0.67375
AUNU                                  

In [62]:
model_500, predictions_500, completions_500, true_500 = train_test_gpts(train_data_500, test_data_500, "iupac_names", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [63]:
print(ConfusionMatrix(true_500, predictions_500))

Predict          large            medium           small            very large       very small       
Actual
large            68               43               25               50               14               

medium           28               60               55               29               28               

small            16               51               80               10               43               

very large       56               17               7                108              12               

very small       3                19               60               7                111              





Overall Statistics : 

95% CI                                                            (0.39634,0.45766)
ACC Macro                                                         0.7708
ARI                                                               0.12926
AUNP                                                              0.64188
AUNU                                  

In [64]:
model_100, predictions_100, completions_100, true_100 = train_test_gpts(train_data_100, test_data_100, "iupac_names", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [65]:
print(ConfusionMatrix(true_100, predictions_100))

Predict          large            medium           small            very large       very small       
Actual
large            54               24               13               96               13               

medium           22               40               12               110              16               

small            17               26               17               102              38               

very large       58               28               11               94               9                

very small       9                20               13               69               89               





Overall Statistics : 

95% CI                                                            (0.26576,0.32224)
ACC Macro                                                         0.7176
ARI                                                               0.03746
AUNP                                                              0.55875
AUNU                                  

In [66]:
model_10, predictions_10, completions_10, true_10 = train_test_gpts(train_data_10, test_data_10, "iupac_names", regression=False)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [67]:
print(ConfusionMatrix(true_10, predictions_10))

Predict          large            medium           small            very large       very small       
Actual
large            0                0                44               129              27               

medium           0                0                49               126              25               

small            0                0                54               120              26               

very large       0                0                38               136              26               

very small       0                0                74               84               42               





Overall Statistics : 

95% CI                                                            (0.20584,0.25816)
ACC Macro                                                         0.6928
ARI                                                               0.00939
AUNP                                                              0.52
AUNU                                     