# Initial experiments with bootstrapped ensemble uncertainty

In [17]:
import pandas as pd 
from gpt3forchem.api_wrappers import fine_tune, query_gpt3, extract_prediction
from gpt3forchem.data import get_waterstability_data
from sklearn.model_selection import train_test_split
from pycm import ConfusionMatrix
import time

import numpy as np

from scipy.stats import mode

In [3]:
data = get_waterstability_data()

In [4]:
train_data, test_data = train_test_split(data, train_size=0.8, random_state=42, stratify=data['stability'])

In [35]:
true = test_data['stability'].apply(lambda x: 0 if x == 'low' else 1.).values

In [5]:
PROMPT_TEMPLATE_water_stability= "How is the water stability of {}"###"
COMPLETION_TEMPLATE_water_stability = "{}@@@"


def generate_water_stability_prompts(
    data: pd.DataFrame
) -> pd.DataFrame:
    prompts = []
    completions = []
    for i, row in data.iterrows():

        prompt = PROMPT_TEMPLATE_water_stability.format(
            row['normalized_names']
        )

        stability = 0 if row['stability'] == 'low' else 1
        completion = COMPLETION_TEMPLATE_water_stability.format(stability)
        prompts.append(prompt)
        completions.append(completion)

    prompts = pd.DataFrame(
        {"prompt": prompts, "completion": completions,}
    )

    return prompts

In [6]:
models = []

for i in range(10):
    # resample the training set with replacement
    train_data_resampled = train_data.sample(n=len(train_data), replace=True)
    prompts = generate_water_stability_prompts(train_data_resampled)
    filename_base = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    train_filename = f"run_files/{filename_base}_train_prompts_water_stability.jsonl"
    valid_filename = f"run_files/{filename_base}_valid_prompts_water_stability.jsonl"
    test_prompts = generate_water_stability_prompts(test_data)

    prompts.to_json(train_filename, orient="records", lines=True)
    test_prompts.to_json(valid_filename, orient="records", lines=True)

    model = fine_tune(train_filename, valid_filename)
    models.append(model)

Traceback (most recent call last):
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/bin/openai", line 8, in <module>
    sys.exit(main())
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/_openai_scripts.py", line 63, in main
    args.func(args)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/cli.py", line 545, in sync
    resp = openai.wandb_logger.WandbLogger.sync(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 74, in sync
    fine_tune_logged = [
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 75, in <listcomp>
    cls._log_fine_tune(
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/python3.9/site-packages/openai/wandb_logger.py", line 125, in _log_fine_tune
    wandb_run = cls._get_wandb_run(run_path)
  File "/Users/kevinmaikjablonka/miniconda3/envs/gpt3/lib/pyth

In [7]:
models

['ada:ft-lsmoepfl-2022-12-13-10-16-48',
 'ada:ft-lsmoepfl-2022-12-13-10-19-46',
 'ada:ft-lsmoepfl-2022-12-13-10-22-38',
 'ada:ft-lsmoepfl-2022-12-13-10-25-33',
 'ada:ft-lsmoepfl-2022-12-13-10-28-25',
 'ada:ft-lsmoepfl-2022-12-13-10-31-14',
 'ada:ft-lsmoepfl-2022-12-13-10-34-08',
 'ada:ft-lsmoepfl-2022-12-13-10-36-58',
 'ada:ft-lsmoepfl-2022-12-13-10-39-51',
 'ada:ft-lsmoepfl-2022-12-13-10-42-42']

In [10]:
overall_predictions = []
overall_completions = []

for model in models:
    completions = query_gpt3(model, test_prompts)
    predictions = []

    for i in range(len(completions['choices'])):
        try:
            pred = int(extract_prediction(completions, i))
        except:
            pred = np.nan
        predictions.append(pred)
    overall_predictions.append(predictions)
    overall_completions.append(completions)

In [12]:
overall_predictions = np.array(overall_predictions)

In [13]:
overall_predictions.shape

(10, 38)

In [28]:
mode_pred = mode(overall_predictions, axis=0).mode.flatten()

  mode_pred = mode(overall_predictions, axis=0).mode.flatten()


In [67]:
std = np.std(overall_predictions, axis=0)

In [70]:
high_conf_mask = std <0.5 

In [71]:
std

array([ 0.        ,  6.        ,         nan,  0.        ,  6.04069532,
        0.45825757,  0.        ,  0.        ,  0.3       , 12.03702621,
        6.        ,  0.48989795,  0.        ,  0.3       ,  0.48989795,
        0.3       ,  0.3       ,  0.45825757,  0.        ,  0.        ,
        0.        ,  0.        ,  6.        ,  0.4       ,  0.        ,
        0.        ,  0.        ,  0.45825757,  0.3       ,  0.45825757,
        0.        ,  0.48989795,  0.3       ,  0.48989795,  0.        ,
        0.        ,  0.        ,  0.        ])

In [72]:
high_conf_mask

array([ True, False, False,  True, False,  True,  True,  True,  True,
       False, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True])

In [73]:
mode_pred

array([1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0.,
       1., 1., 1., 1.])

In [74]:
high_conf_pred= mode_pred[high_conf_mask]

In [75]:
print(ConfusionMatrix(true, mode_pred))

Predict   0.0       1.0       
Actual
0.0       5         3         

1.0       2         28        





Overall Statistics : 

95% CI                                                            (0.76094,0.9759)
ACC Macro                                                         0.86842
ARI                                                               0.46573
AUNP                                                              0.77917
AUNU                                                              0.77917
Bangdiwala B                                                      0.82049
Bennett S                                                         0.73684
CBA                                                               0.76411
CSI                                                               0.58792
Chi-Squared                                                       13.10154
Chi-Squared DF                                                    1
Conditional Entropy                                  

In [76]:
print(ConfusionMatrix(true[high_conf_mask], high_conf_pred))

Predict   0.0       1.0       
Actual
0.0       5         3         

1.0       2         22        





Overall Statistics : 

95% CI                                                            (0.71794,0.96956)
ACC Macro                                                         0.84375
ARI                                                               0.41689
AUNP                                                              0.77083
AUNU                                                              0.77083
Bangdiwala B                                                      0.77591
Bennett S                                                         0.6875
CBA                                                               0.7525
CSI                                                               0.56798
Chi-Squared                                                       10.30095
Chi-Squared DF                                                    1
Conditional Entropy                                   