# Task Settings

In [1]:
import os
os.environ['OPENAI_API_KEY']="YOUR_API_KEY"

from utils import *
from prompts import Prompts, TASK_LABLES, TAGS

dataset_name = 'swiss'                      # 'trafficsafe', 'pamdemic'
test_model = "experimental:gpt-4o-mini"     # Forward Engine
eval_model = "gpt-4o"                       # Backward Engine
iteration = 1
date = '0919'
total_steps=9                               # 5-11
epoch=1                                     # 2, 1
batch_size=3                                # 1-3

# Initialize

In [2]:
cm_labels = TASK_LABLES[dataset_name]
tags = TAGS[dataset_name]
CAUSAL_SYSTEM = Prompts[dataset_name]['CAUSAL_SYSTEM']
CAUSAL_SYSTEM_CONSTRAINT = Prompts[dataset_name]['CAUSAL_SYSTEM_CONSTRAINT']
SYSTEM = Prompts[dataset_name]['SYSTEM']

llm_api_eval = tg.get_engine(engine_name=eval_model)
llm_api_test = tg.get_engine(engine_name=test_model, cache=False)
tg.set_backward_engine(llm_api_eval, override=True)

train_set, val_set, test_set_ori, eval_fn = load_task(dataset_name, evaluation_api=llm_api_eval, prompt_col="organized_prompt")
train_loader = tg.tasks.DataLoader(train_set, batch_size=batch_size, shuffle=True)
col = "organized_prompt" if dataset_name == 'swiss' else "prompt"
train_set.data[col] = train_set.data[col].apply(lambda x: f"{tags[0]}{x}{tags[1]}")
val_set.data[col] = val_set.data[col].apply(lambda x: f"{tags[0]}{x}{tags[1]}")
test_set_ori.data[col] = test_set_ori.data[col].apply(lambda x: f"{tags[0]}{x}{tags[1]}")
print("Train/Val/Test Set Lengths: ", len(train_set), len(val_set), len(test_set_ori))

Train/Val/Test Set Lengths:  168 84 85


In [3]:
import matplotlib.pyplot as plt
from copy import deepcopy

res = []

test_set = deepcopy(test_set_ori)
system_prompt, causal_prompt, model, causal_model, optimizer, optimizer_causal = init(SYSTEM, CAUSAL_SYSTEM, llm_api_test, llm_api_eval, CAUSAL_SYSTEM_CONSTRAINT)
results = {"test_f1": [], "prompt": [], "validation_f1": [], 'system_prompt':[], 'causal_prompt': []}

# Run EGO-Prompt

In [4]:
import time
import copy
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed

NUM_WORKERS = 1 

def run_one_worker(worker_id: int):


    local_test_set = copy.deepcopy(test_set)
    local_test_set_ori = copy.deepcopy(test_set_ori)
    local_val_set = copy.deepcopy(val_set)
    local_train_loader = train_loader

    val_performance = -float('inf')
    test_performance = -float('inf')
    final_results = None
    all_val_f1s = []
    all_test_f1s = []

    local_test_set.data[col] = local_test_set_ori.data[col].apply(
        lambda x: f"<!-- {time.time()} (w{worker_id}) -->, {x}"
    )

    for cur_iter in range(iteration):
        print(f"[Worker {worker_id}] [Iteration {cur_iter+1}/{iteration}] begin")

        output_json = (
            f"res/{date}_{dataset_name}_{test_model.split('/')[-1].split(':')[-1]}_"
            f"w{worker_id}_it{cur_iter+1}.json"
        )
        initialize_json_file(output_json)

        system_prompt, causal_prompt, model, causal_model, optimizer, optimizer_causal = init(
            SYSTEM, CAUSAL_SYSTEM, llm_api_test, llm_api_eval, CAUSAL_SYSTEM_CONSTRAINT
        )

        results, test_res, val_res = init_eval(
            local_val_set, local_test_set, eval_fn, model, causal_model,
            system_prompt, causal_prompt, cm_labels, iters=ITERS    
        )

        results = run_training(
            local_train_loader, local_val_set, local_test_set, eval_fn,
            model, causal_model, system_prompt, causal_prompt,
            optimizer, optimizer_causal, results, cm_labels,
            output_json=output_json, epoch=epoch, steps=total_steps, iters=ITERS
        )

        all_val_f1s.append(results['validation_f1'])
        all_test_f1s.append(results['test_f1'])

        cur_val = results['validation_f1'][-1]
        cur_test = results['test_f1'][-1]
        if cur_val > val_performance:
            val_performance = cur_val
            test_performance = cur_test
            final_results = results

        print(f"[Worker {worker_id}] [Iteration {cur_iter+1}] "
              f"val_best={val_performance:.4f}, test_at_best={test_performance:.4f}")

    return {
        'best_test_f1': test_performance,
        'val_f1s': all_val_f1s,
        'test_f1s': all_test_f1s,
        'worker_id': worker_id,
    }

# Running the cell below will incur API usage charges. Refer to our paper for detailed cost

In [5]:
NUM_WORKERS = 3
ITERS = 1
total_steps=5

EGO_res = []


with ThreadPoolExecutor(max_workers=NUM_WORKERS) as ex:
    futures = [ex.submit(run_one_worker, i) for i in range(NUM_WORKERS)]
    for fut in as_completed(futures):
        res = fut.result()
        EGO_res.append(res['best_test_f1'])
        print(f"[Main] Worker {res['worker_id']} done. Best test_f1={res['best_test_f1']:.4f}")

print("EGO_res (best test F1 per worker):", EGO_res)

[Worker 0] [Iteration 1/1] begin[Worker 1] [Iteration 1/1] begin

[Worker 2] [Iteration 1/1] begin


Accuracy: 0.1647: 100%|██████████| 85/85 [00:17<00:00,  4.89it/s]
Accuracy: 0.2235: 100%|██████████| 85/85 [00:19<00:00,  4.41it/s]
Accuracy: 0.2471: 100%|██████████| 85/85 [00:30<00:00,  2.83it/s]
Accuracy: 0.2381: 100%|██████████| 84/84 [00:20<00:00,  4.10it/s]
Accuracy: 0.1364:  24%|██▍       | 20/84 [00:07<00:07,  8.33it/s]

SCG_val_f1: 0.3182539682539683, SCG_test_f1:0.20487520070062765


Accuracy: 0.1304:  27%|██▋       | 23/84 [00:07<00:06,  9.95it/s]


Epoch 0, Step 0


Accuracy: 0.2381: 100%|██████████| 84/84 [00:22<00:00,  3.80it/s]


SCG_val_f1: 0.31224489795918364, SCG_test_f1:0.3009764229578471


0it [00:00, ?it/s]


Epoch 0, Step 0


Accuracy: 0.1786: 100%|██████████| 84/84 [00:15<00:00,  5.52it/s]


SCG_val_f1: 0.24574175824175828, SCG_test_f1:0.2765328353563648


0it [00:00, ?it/s]


Epoch 0, Step 0


Accuracy: 0.3690: 100%|██████████| 84/84 [00:16<00:00,  5.21it/s]


[System Validation] F1: 0.3978, Previous F1: 0.2457
[System Validation CM]:
[[15  5  7]
 [15 11  2]
 [ 6  5  5]]


Accuracy: 0.3810: 100%|██████████| 84/84 [00:14<00:00,  5.92it/s]


[Causal Validation] F1: 0.4046, Previous F1: 0.3978
[Causal Validation CM]:
[[11  6 11]
 [13 15  3]
 [ 6  5  6]]


Accuracy: 0.4524: 100%|██████████| 84/84 [00:37<00:00,  2.26it/s]


[System Validation] F1: 0.4528, Previous F1: 0.3183
[System Validation CM]:
[[18  8  7]
 [16 15  2]
 [ 9  3  5]]


Accuracy: 0.3929: 100%|██████████| 84/84 [00:23<00:00,  3.54it/s]


[System Validation] F1: 0.4419, Previous F1: 0.3122
[System Validation CM]:
[[13  5 10]
 [ 8 13  3]
 [ 6  2  7]]


Accuracy: 0.4471: 100%|██████████| 85/85 [00:44<00:00,  1.91it/s]
Accuracy: 0.8000:   6%|▌         | 5/84 [00:10<01:35,  1.20s/it]

[Test Result] F1: 0.4484

Epoch 0, Step 1


Accuracy: 0.4286: 100%|██████████| 84/84 [00:21<00:00,  3.85it/s]


Skip Test
[Causal Validation] F1: 0.4217, Previous F1: 0.4528
[Causal Validation CM]:
[[19  6  8]
 [18 13  2]
 [ 9  5  4]]


Accuracy: 0.4824: 100%|██████████| 85/85 [00:20<00:00,  4.21it/s]
Accuracy: 0.2911:  94%|█████████▍| 79/84 [00:20<00:02,  2.49it/s]

[Test Result] F1: 0.4634

Epoch 0, Step 1


Accuracy: 0.2976: 100%|██████████| 84/84 [00:25<00:00,  3.31it/s]


Skip Test
[Causal Validation] F1: 0.3529, Previous F1: 0.4419
[Causal Validation CM]:
[[ 9  5  6]
 [ 9 12  2]
 [ 9  3  4]]


Accuracy: 0.3412: 100%|██████████| 85/85 [00:25<00:00,  3.34it/s]
1it [05:15, 315.73s/it]

[Test Result] F1: 0.3636

Epoch 0, Step 1


Accuracy: 0.3810: 100%|██████████| 84/84 [00:26<00:00,  3.13it/s]


Skip Test
[System Validation] F1: 0.3884, Previous F1: 0.4046
[System Validation CM]:
[[15  9  7]
 [16 12  3]
 [ 8  5  5]]


Accuracy: 0.4048: 100%|██████████| 84/84 [00:23<00:00,  3.57it/s]


Skip Test
[System Validation] F1: 0.3934, Previous F1: 0.4528
[System Validation CM]:
[[19  7  7]
 [22  9  2]
 [ 8  4  6]]


Accuracy: 0.3929: 100%|██████████| 84/84 [00:24<00:00,  3.44it/s]


[Causal Validation] F1: 0.4078, Previous F1: 0.4046
[Causal Validation CM]:
[[14  7 11]
 [15 13  1]
 [ 8  4  6]]


Accuracy: 0.4286: 100%|██████████| 84/84 [00:17<00:00,  4.70it/s]


Skip Test
[System Validation] F1: 0.4394, Previous F1: 0.4419
[System Validation CM]:
[[16  4 13]
 [16 14  2]
 [ 6  5  6]]


Accuracy: 0.4235: 100%|██████████| 85/85 [00:24<00:00,  3.54it/s]
2it [08:47, 266.33s/it]

[Test Result] F1: 0.4325

Epoch 0, Step 2


Accuracy: 0.3333: 100%|██████████| 84/84 [00:17<00:00,  4.90it/s]
2it [09:13, 270.10s/it]

Skip Test
[Causal Validation] F1: 0.3907, Previous F1: 0.4419
[Causal Validation CM]:
[[13  2  7]
 [13  8  2]
 [ 5  1  7]]
Skip Test

Epoch 0, Step 2


Accuracy: 0.3929: 100%|██████████| 84/84 [00:32<00:00,  2.57it/s]
2it [09:19, 278.26s/it]

Skip Test
[Causal Validation] F1: 0.3900, Previous F1: 0.4528
[Causal Validation CM]:
[[17  5 11]
 [20 10  3]
 [ 8  4  6]]
Skip Test

Epoch 0, Step 2


Accuracy: 0.3929: 100%|██████████| 84/84 [00:32<00:00,  2.61it/s]


[System Validation] F1: 0.4198, Previous F1: 0.4078
[System Validation CM]:
[[13  9  8]
 [12 12  2]
 [ 4  5  8]]


Accuracy: 0.4167: 100%|██████████| 84/84 [00:23<00:00,  3.51it/s]


[Causal Validation] F1: 0.4320, Previous F1: 0.4198
[Causal Validation CM]:
[[15  8  8]
 [16 11  3]
 [ 4  3  9]]


Accuracy: 0.3214: 100%|██████████| 84/84 [01:34<00:00,  1.13s/it]


Skip Test
[System Validation] F1: 0.3764, Previous F1: 0.4419
[System Validation CM]:
[[13  1 14]
 [12 10  5]
 [ 3  3  4]]


Accuracy: 0.4286: 100%|██████████| 84/84 [00:27<00:00,  3.09it/s]


[Causal Validation] F1: 0.4620, Previous F1: 0.4419
[Causal Validation CM]:
[[13  9  6]
 [ 8 16  4]
 [ 6  3  7]]


Accuracy: 0.4762: 100%|██████████| 84/84 [02:01<00:00,  1.44s/it]


[System Validation] F1: 0.4679, Previous F1: 0.4528
[System Validation CM]:
[[21  8  4]
 [17 13  3]
 [ 8  4  6]]


Accuracy: 0.3882: 100%|██████████| 85/85 [00:23<00:00,  3.61it/s]
3it [14:35, 293.71s/it]

[Test Result] F1: 0.4019

Epoch 0, Step 3


Accuracy: 0.4762: 100%|██████████| 84/84 [00:31<00:00,  2.67it/s]


Skip Test
[Causal Validation] F1: 0.4311, Previous F1: 0.4679
[Causal Validation CM]:
[[28  2  3]
 [23  9  1]
 [12  3  3]]


Accuracy: 0.5059: 100%|██████████| 85/85 [00:31<00:00,  2.71it/s]
3it [15:52, 330.64s/it]

[Test Result] F1: 0.4604

Epoch 0, Step 3


Accuracy: 0.4235: 100%|██████████| 85/85 [03:48<00:00,  2.69s/it]
3it [16:55, 367.75s/it]

[Test Result] F1: 0.4343

Epoch 0, Step 3


Accuracy: 0.4286: 100%|██████████| 84/84 [00:23<00:00,  3.61it/s]


[System Validation] F1: 0.4666, Previous F1: 0.4620
[System Validation CM]:
[[16  7  2]
 [ 7 18  0]
 [ 7  4  2]]


Accuracy: 0.3810: 100%|██████████| 84/84 [00:21<00:00,  3.84it/s]


Skip Test
[Causal Validation] F1: 0.4202, Previous F1: 0.4666
[Causal Validation CM]:
[[17  6  1]
 [11 14  0]
 [ 5  3  1]]


Accuracy: 0.4405: 100%|██████████| 84/84 [00:39<00:00,  2.12it/s]
Accuracy: 0.2273:  26%|██▌       | 22/85 [00:12<00:15,  3.98it/s]

Skip Test
[System Validation] F1: 0.4230, Previous F1: 0.4679
[System Validation CM]:
[[22  7  4]
 [22 10  1]
 [ 9  4  5]]


Accuracy: 0.3647: 100%|██████████| 85/85 [00:48<00:00,  1.74it/s]
Accuracy: 0.4737:  23%|██▎       | 19/84 [00:15<00:09,  6.75it/s]

[Test Result] F1: 0.3896

Epoch 0, Step 4


Accuracy: 0.4286: 100%|██████████| 84/84 [00:26<00:00,  3.16it/s]
4it [19:56, 296.59s/it]

Skip Test
[Causal Validation] F1: 0.3890, Previous F1: 0.4679
[Causal Validation CM]:
[[25  5  2]
 [22  9  2]
 [14  2  2]]
Skip Test

Epoch 0, Step 4


Accuracy: 0.4048: 100%|██████████| 84/84 [00:36<00:00,  2.28it/s]


Skip Test
[System Validation] F1: 0.4120, Previous F1: 0.4320
[System Validation CM]:
[[15  8  9]
 [18 12  1]
 [ 6  5  7]]


Accuracy: 0.3810: 100%|██████████| 84/84 [00:18<00:00,  4.50it/s]
4it [20:33, 308.55s/it]

Skip Test
[Causal Validation] F1: 0.4082, Previous F1: 0.4320
[Causal Validation CM]:
[[14  8  8]
 [12 12  2]
 [ 8  3  6]]
Skip Test

Epoch 0, Step 4


Accuracy: 0.5238: 100%|██████████| 84/84 [00:35<00:00,  2.40it/s]
Accuracy: 0.3594:  76%|███████▌  | 64/84 [00:09<00:02,  9.99it/s]

[System Validation] F1: 0.5177, Previous F1: 0.4666
[System Validation CM]:
[[18 12  3]
 [11 21  1]
 [ 9  3  5]]


Accuracy: 0.3810: 100%|██████████| 84/84 [00:17<00:00,  4.85it/s]


Skip Test
[System Validation] F1: 0.3907, Previous F1: 0.4320
[System Validation CM]:
[[12  7 12]
 [18 12  2]
 [ 6  4  8]]


Accuracy: 0.4286: 100%|██████████| 84/84 [00:23<00:00,  3.58it/s]


Skip Test
[System Validation] F1: 0.4284, Previous F1: 0.4679
[System Validation CM]:
[[18  7  7]
 [20 13  0]
 [10  3  5]]


Accuracy: 0.5476: 100%|██████████| 84/84 [00:23<00:00,  3.57it/s]


[Causal Validation] F1: 0.5435, Previous F1: 0.5177
[Causal Validation CM]:
[[21  7  3]
 [13 20  0]
 [ 7  6  5]]


Accuracy: 0.4286: 100%|██████████| 84/84 [00:18<00:00,  4.43it/s]


[Causal Validation] F1: 0.4427, Previous F1: 0.4320
[Causal Validation CM]:
[[ 9 10  9]
 [ 8 19  3]
 [ 6  4  8]]


Accuracy: 0.4643: 100%|██████████| 84/84 [00:21<00:00,  3.84it/s]
Accuracy: 0.4746:  68%|██████▊   | 58/85 [00:16<00:02, 11.90it/s]

Skip Test
[Causal Validation] F1: 0.4355, Previous F1: 0.4679
[Causal Validation CM]:
[[25  6  2]
 [22 10  1]
 [11  3  4]]
Skip Test

Epoch 0, Step 5


Accuracy: 0.5176: 100%|██████████| 85/85 [00:25<00:00,  3.29it/s]
5it [24:15, 289.40s/it]

[Test Result] F1: 0.4879

Epoch 0, Step 5


Accuracy: 0.3765: 100%|██████████| 85/85 [00:25<00:00,  3.34it/s]
5it [24:13, 276.58s/it]

[Test Result] F1: 0.3736

Epoch 0, Step 5


Accuracy: 0.4405: 100%|██████████| 84/84 [00:19<00:00,  4.25it/s]


Skip Test
[System Validation] F1: 0.4397, Previous F1: 0.4427
[System Validation CM]:
[[13 10 10]
 [13 18  2]
 [ 6  6  6]]


Accuracy: 0.4524: 100%|██████████| 84/84 [00:21<00:00,  3.87it/s]


Skip Test
[System Validation] F1: 0.4585, Previous F1: 0.4679
[System Validation CM]:
[[19  8  6]
 [15 14  2]
 [ 9  2  5]]


Accuracy: 0.5476: 100%|██████████| 84/84 [00:14<00:00,  5.86it/s]
Accuracy: 0.0000:   1%|          | 1/84 [00:06<08:27,  6.12s/it]

Skip Test
[System Validation] F1: 0.5365, Previous F1: 0.5435
[System Validation CM]:
[[21  9  3]
 [13 20  0]
 [ 7  6  5]]


Accuracy: 0.4405: 100%|██████████| 84/84 [00:18<00:00,  4.44it/s]


[Causal Validation] F1: 0.4690, Previous F1: 0.4427
[Causal Validation CM]:
[[13  8  9]
 [10 16  2]
 [ 4  4  8]]


Accuracy: 0.4286: 100%|██████████| 84/84 [00:22<00:00,  3.80it/s]
5it [28:21, 340.34s/it]
Accuracy: 0.5181:  96%|█████████▋| 82/85 [00:17<00:00,  3.41it/s]

Skip Test
[Causal Validation] F1: 0.4336, Previous F1: 0.4679
[Causal Validation CM]:
[[15  8 10]
 [16 15  1]
 [ 8  4  6]]
Skip Test
[Worker 1] [Iteration 1] val_best=0.4679, test_at_best=0.4604
[Main] Worker 1 done. Best test_f1=0.4604


Accuracy: 0.5176: 100%|██████████| 85/85 [00:19<00:00,  4.26it/s]
5it [28:16, 339.35s/it]
Accuracy: 0.3478:  26%|██▌       | 22/84 [00:14<00:08,  7.29it/s]

[Test Result] F1: 0.5262
[Worker 2] [Iteration 1] val_best=0.4690, test_at_best=0.5262
[Main] Worker 2 done. Best test_f1=0.5262


Accuracy: 0.5119: 100%|██████████| 84/84 [00:22<00:00,  3.70it/s]
5it [28:29, 341.85s/it]

Skip Test
[Causal Validation] F1: 0.4907, Previous F1: 0.5435
[Causal Validation CM]:
[[18 11  4]
 [11 22  0]
 [ 9  6  3]]
Skip Test
[Worker 0] [Iteration 1] val_best=0.5435, test_at_best=0.4879
[Main] Worker 0 done. Best test_f1=0.4879
EGO_res (best test F1 per worker): [0.46038415366146457, 0.5261764705882354, 0.4878954248366013]



