# Task Settings

In [1]:
import os
os.environ['OPENAI_API_KEY']="YOUR_API_KEY"


from utils import *
from prompts import Prompts, TASK_LABLES, TAGS

dataset_name = 'pandemic'
test_model = "experimental:gpt-4o-mini"
eval_model = "gpt-4o" 
iteration = 1
date = '0919'
total_steps=5   
epoch=1
batch_size=3

# Initialize

In [2]:
cm_labels = TASK_LABLES[dataset_name]
tags = TAGS[dataset_name]
CAUSAL_SYSTEM = Prompts[dataset_name]['CAUSAL_SYSTEM']
CAUSAL_SYSTEM_CONSTRAINT = Prompts[dataset_name]['CAUSAL_SYSTEM_CONSTRAINT']
SYSTEM = Prompts[dataset_name]['SYSTEM']

llm_api_eval = tg.get_engine(engine_name=eval_model)
llm_api_test = tg.get_engine(engine_name=test_model, cache=False)
tg.set_backward_engine(llm_api_eval, override=True)

train_set, val_set, test_set_ori, eval_fn = load_task(dataset_name, evaluation_api=llm_api_eval, prompt_col="organized_prompt")
train_loader = tg.tasks.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Add Tag
col = "organized_prompt" if dataset_name == 'swiss' else "prompt"
train_set.data[col] = train_set.data[col].apply(lambda x: f"{tags[0]}{x}{tags[1]}")
val_set.data[col] = val_set.data[col].apply(lambda x: f"{tags[0]}{x}{tags[1]}")
test_set_ori.data[col] = test_set_ori.data[col].apply(lambda x: f"{tags[0]}{x}{tags[1]}")
print("Train/Val/Test Set Lengths: ", len(train_set), len(val_set), len(test_set_ori))

Train/Val/Test Set Lengths:  100 100 100


In [3]:
system_prompt, causal_prompt, model, causal_model, optimizer, optimizer_causal = init(SYSTEM, CAUSAL_SYSTEM, llm_api_test, llm_api_eval, CAUSAL_SYSTEM_CONSTRAINT)
results = {"test_f1": [], "prompt": [], "validation_f1": [], 'system_prompt':[], 'causal_prompt': []}

In [4]:
import time
import copy
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed

NUM_WORKERS = 1

def run_one_worker(worker_id: int):

    local_test_set = copy.deepcopy(test_set)
    local_test_set_ori = copy.deepcopy(test_set_ori)
    local_val_set = copy.deepcopy(val_set)
    local_train_loader = train_loader

    val_performance = -float('inf')
    test_performance = -float('inf')
    final_results = None
    all_val_f1s = []
    all_test_f1s = []

    local_test_set.data[col] = local_test_set_ori.data[col].apply(
        lambda x: f"<!-- {time.time()} (w{worker_id}) -->, {x}"
    )

    for cur_iter in range(iteration):
        print(f"[Worker {worker_id}] [Iteration {cur_iter+1}/{iteration}] begin")
        output_json = (
            f"res/{date}_{dataset_name}_{test_model.split('/')[-1].split(':')[-1]}_"
            f"w{worker_id}_it{cur_iter+1}.json"
        )
        initialize_json_file(output_json)

        system_prompt, causal_prompt, model, causal_model, optimizer, optimizer_causal = init(
            SYSTEM, CAUSAL_SYSTEM, llm_api_test, llm_api_eval, CAUSAL_SYSTEM_CONSTRAINT
        )

        results, test_res, val_res = init_eval(
            local_val_set, local_test_set, eval_fn, model, causal_model,
            system_prompt, causal_prompt, cm_labels, iters=ITERS    
        )

        results = run_training(
            local_train_loader, local_val_set, local_test_set, eval_fn,
            model, causal_model, system_prompt, causal_prompt,
            optimizer, optimizer_causal, results, cm_labels,
            output_json=output_json, epoch=epoch, steps=total_steps, iters=ITERS
        )

        all_val_f1s.append(results['validation_f1'])
        all_test_f1s.append(results['test_f1'])

        cur_val = results['validation_f1'][-1]
        cur_test = results['test_f1'][-1]
        if cur_val > val_performance:
            val_performance = cur_val
            test_performance = cur_test
            final_results = results

        print(f"[Worker {worker_id}] [Iteration {cur_iter+1}] "
              f"val_best={val_performance:.4f}, test_at_best={test_performance:.4f}")

    return {
        'best_test_f1': test_performance,
        'val_f1s': all_val_f1s,
        'test_f1s': all_test_f1s,
        'worker_id': worker_id,
    }


# Running the cell below will incur API usage charges. Refer to our paper for detailed cost

In [5]:
NUM_WORKERS = 3
ITERS = 1
total_steps=5
batch_size=3

EGO_res = []

with ThreadPoolExecutor(max_workers=NUM_WORKERS) as ex:
    futures = [ex.submit(run_one_worker, i) for i in range(NUM_WORKERS)]
    for fut in as_completed(futures):
        res = fut.result()
        EGO_res.append(res['best_test_f1'])
        print(f"[Main] Worker {res['worker_id']} done. Best test_f1={res['best_test_f1']:.4f}")

print("EGO_res (best test F1 per worker):", EGO_res)

[Worker 0] [Iteration 1/1] begin[Worker 1] [Iteration 1/1] begin

[Worker 2] [Iteration 1/1] begin


Accuracy: 0.3300: 100%|██████████| 100/100 [00:14<00:00,  6.79it/s]
Accuracy: 0.3700: 100%|██████████| 100/100 [00:15<00:00,  6.34it/s]
Accuracy: 0.3200: 100%|██████████| 100/100 [00:16<00:00,  6.10it/s]
Accuracy: 0.4700: 100%|██████████| 100/100 [00:16<00:00,  6.19it/s]


SCG_val_f1: 0.47891035970830037, SCG_test_f1:0.37069647287038593


0it [00:00, ?it/s]


Epoch 0, Step 0


Accuracy: 0.4800: 100%|██████████| 100/100 [00:16<00:00,  5.94it/s]


SCG_val_f1: 0.4683242147922999, SCG_test_f1:0.2951439283986454


0it [00:00, ?it/s]


Epoch 0, Step 0


Accuracy: 0.4500: 100%|██████████| 100/100 [00:28<00:00,  3.49it/s]


SCG_val_f1: 0.42588531347498354, SCG_test_f1:0.3052992435971159


0it [00:00, ?it/s]


Epoch 0, Step 0


Accuracy: 0.4600: 100%|██████████| 100/100 [00:32<00:00,  3.03it/s]


Skip Test
[System Validation] F1: 0.4431, Previous F1: 0.4683
[System Validation CM]:
[[ 0 11  6  1  0]
 [ 0 12  4  2  0]
 [ 0  2 12  6  1]
 [ 0  2  1  8  2]
 [ 0  1  2  5 14]]


Accuracy: 0.4900: 100%|██████████| 100/100 [00:32<00:00,  3.12it/s]
Accuracy: 0.4231:  26%|██▌       | 26/100 [00:16<00:11,  6.57it/s]

[System Validation] F1: 0.4790, Previous F1: 0.4789
[System Validation CM]:
[[ 3 11  3  1  0]
 [ 1 13  3  3  0]
 [ 0  1 15  4  1]
 [ 0  0  3  8  6]
 [ 0  0  2 11 10]]


Accuracy: 0.5000: 100%|██████████| 100/100 [00:24<00:00,  4.07it/s]


[System Validation] F1: 0.4872, Previous F1: 0.4259
[System Validation CM]:
[[ 2  6  8  1  0]
 [ 0 11  7  2  0]
 [ 0  4 12  5  1]
 [ 0  0  2 11  4]
 [ 0  0  1  6 14]]


Accuracy: 0.4000: 100%|██████████| 100/100 [00:22<00:00,  4.51it/s]
Accuracy: 0.5714:  14%|█▍        | 14/100 [00:18<00:22,  3.77it/s]

Skip Test
[Causal Validation] F1: 0.3999, Previous F1: 0.4683
[Causal Validation CM]:
[[ 4  5  5  3  0]
 [ 0  4 13  2  0]
 [ 0  3  9  7  3]
 [ 0  0  1 10  6]
 [ 0  0  1  7 13]]
Skip Test

Epoch 0, Step 1


Accuracy: 0.4600: 100%|██████████| 100/100 [00:33<00:00,  3.02it/s]


Skip Test
[Causal Validation] F1: 0.4277, Previous F1: 0.4790
[Causal Validation CM]:
[[ 0 10  5  2  0]
 [ 0 10  8  2  0]
 [ 0  4 11  7  0]
 [ 0  1  2 13  1]
 [ 0  0  1 10 12]]


Accuracy: 0.4100: 100%|██████████| 100/100 [00:26<00:00,  3.84it/s]


Skip Test
[Causal Validation] F1: 0.3800, Previous F1: 0.4872
[Causal Validation CM]:
[[ 1  6  9  2  0]
 [ 0  6  9  4  0]
 [ 0  1 13  6  2]
 [ 0  0  1 12  4]
 [ 0  1  0 13  9]]


Accuracy: 0.3300: 100%|██████████| 100/100 [00:23<00:00,  4.18it/s]
1it [02:46, 166.72s/it]

[Test Result] F1: 0.2964

Epoch 0, Step 1


Accuracy: 0.3100: 100%|██████████| 100/100 [00:26<00:00,  3.77it/s]
1it [03:00, 180.21s/it]

[Test Result] F1: 0.3141

Epoch 0, Step 1


Accuracy: 0.4600: 100%|██████████| 100/100 [00:20<00:00,  4.89it/s]


Skip Test
[System Validation] F1: 0.4319, Previous F1: 0.4683
[System Validation CM]:
[[ 1 11  5  1  0]
 [ 0 11  7  2  0]
 [ 0  2 11  7  2]
 [ 0  1  2 12  2]
 [ 0  0  2 10 11]]


Accuracy: 0.3400: 100%|██████████| 100/100 [00:21<00:00,  4.55it/s]
2it [04:20, 129.27s/it]

Skip Test
[Causal Validation] F1: 0.3170, Previous F1: 0.4683
[Causal Validation CM]:
[[ 0 12  5  1  0]
 [ 1  6 10  3  0]
 [ 0  5  5 11  0]
 [ 0  0  1 12  3]
 [ 0  0  2 10 11]]
Skip Test

Epoch 0, Step 2


Accuracy: 0.3000: 100%|██████████| 100/100 [00:22<00:00,  4.39it/s]


Skip Test
[System Validation] F1: 0.3517, Previous F1: 0.4790
[System Validation CM]:
[[ 1  4  7  0  0]
 [ 0  5  3  2  0]
 [ 0  1  4  5  0]
 [ 0  0  2  6  3]
 [ 0  0  1  4 14]]


Accuracy: 0.0000: 100%|██████████| 100/100 [00:26<00:00,  3.83it/s]


Skip Test
[System Validation] F1: 0.0000, Previous F1: 0.4872
[System Validation CM]:
[[0 0 0 0 0]
 [0 0 1 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 1 0 0]]


Accuracy: 0.4700: 100%|██████████| 100/100 [00:24<00:00,  4.03it/s]
2it [05:30, 162.45s/it]

Skip Test
[Causal Validation] F1: 0.4479, Previous F1: 0.4790
[Causal Validation CM]:
[[ 1 11  4  2  0]
 [ 1 10  7  2  0]
 [ 0  4  9  5  3]
 [ 0  1  2 12  1]
 [ 0  0  1  7 15]]
Skip Test

Epoch 0, Step 2


Accuracy: 0.4300: 100%|██████████| 100/100 [00:30<00:00,  3.32it/s]
Accuracy: 0.4516:  62%|██████▏   | 62/100 [00:24<00:06,  5.79it/s]

Skip Test
[Causal Validation] F1: 0.4187, Previous F1: 0.4872
[Causal Validation CM]:
[[ 2  7  8  1  0]
 [ 1  7 10  2  0]
 [ 0  3 11  8  0]
 [ 0  0  3 12  2]
 [ 0  0  3  9 11]]
Skip Test

Epoch 0, Step 2


Accuracy: 0.4400: 100%|██████████| 100/100 [00:50<00:00,  2.00it/s]


Skip Test
[System Validation] F1: 0.4131, Previous F1: 0.4683
[System Validation CM]:
[[ 0  7  6  4  0]
 [ 0  8  9  3  0]
 [ 0  3  8  8  1]
 [ 0  0  3 12  2]
 [ 0  0  2  5 16]]


Accuracy: 0.4100: 100%|██████████| 100/100 [00:27<00:00,  3.68it/s]
Accuracy: 0.4271:  96%|█████████▌| 96/100 [00:29<00:07,  1.89s/it]

Skip Test
[System Validation] F1: 0.3814, Previous F1: 0.4790
[System Validation CM]:
[[ 0  9  9  0  0]
 [ 1  5 13  1  0]
 [ 0  4 11  7  0]
 [ 0  1  2 12  2]
 [ 0  0  2  8 13]]


Accuracy: 0.4300: 100%|██████████| 100/100 [00:37<00:00,  2.63it/s]
3it [07:13, 149.12s/it]

Skip Test
[Causal Validation] F1: 0.4049, Previous F1: 0.4683
[Causal Validation CM]:
[[ 2  5  8  2  0]
 [ 0  3 15  2  0]
 [ 0  2 15  4  1]
 [ 0  0  2 12  2]
 [ 0  1  1  9 11]]
Skip Test

Epoch 0, Step 3


Accuracy: 0.5000: 100%|██████████| 100/100 [00:24<00:00,  4.07it/s]


[System Validation] F1: 0.4926, Previous F1: 0.4872
[System Validation CM]:
[[ 4  9  4  0  1]
 [ 1 11  8  0  0]
 [ 0  3 13  4  2]
 [ 0  0  2 11  3]
 [ 0  0  2 10 11]]


Accuracy: 0.4900: 100%|██████████| 100/100 [00:21<00:00,  4.58it/s]
3it [07:49, 151.88s/it]

Skip Test
[Causal Validation] F1: 0.4720, Previous F1: 0.4790
[Causal Validation CM]:
[[ 1  9  7  1  0]
 [ 0 10  8  2  0]
 [ 0  2 12  8  0]
 [ 0  1  2 12  1]
 [ 0  0  2  6 14]]
Skip Test

Epoch 0, Step 3


Accuracy: 0.4700: 100%|██████████| 100/100 [00:25<00:00,  3.88it/s]


Skip Test
[Causal Validation] F1: 0.4676, Previous F1: 0.4926
[Causal Validation CM]:
[[ 6  7  5  0  0]
 [ 7  6  5  2  0]
 [ 1  3  9  6  3]
 [ 0  1  3 10  3]
 [ 0  2  0  5 16]]


Accuracy: 0.3800: 100%|██████████| 100/100 [00:30<00:00,  3.33it/s]
3it [08:43, 177.68s/it]

[Test Result] F1: 0.3585

Epoch 0, Step 3


Accuracy: 0.4700: 100%|██████████| 100/100 [00:26<00:00,  3.77it/s]
Accuracy: 0.5104:  96%|█████████▌| 96/100 [00:20<00:00,  4.33it/s]

[System Validation] F1: 0.4712, Previous F1: 0.4683
[System Validation CM]:
[[ 3  6  7  1  0]
 [ 0  9  7  3  0]
 [ 0  2 11  9  0]
 [ 0  0  3 10  4]
 [ 0  0  2  7 14]]


Accuracy: 0.5000: 100%|██████████| 100/100 [00:22<00:00,  4.40it/s]


Skip Test
[System Validation] F1: 0.4628, Previous F1: 0.4790
[System Validation CM]:
[[ 0 10  6  2  0]
 [ 2 12  2  4  0]
 [ 0  2 13  5  2]
 [ 0  0  3 12  2]
 [ 0  0  2  8 13]]


Accuracy: 0.4900: 100%|██████████| 100/100 [00:22<00:00,  4.50it/s]
4it [09:57, 142.28s/it]

Skip Test
[Causal Validation] F1: 0.4725, Previous F1: 0.4790
[Causal Validation CM]:
[[ 2 11  4  1  0]
 [ 1 12  4  3  0]
 [ 1  3 10  6  1]
 [ 0  0  2 13  2]
 [ 0  0  2  9 12]]
Skip Test

Epoch 0, Step 4


Accuracy: 0.4700: 100%|██████████| 100/100 [00:26<00:00,  3.77it/s]


[Causal Validation] F1: 0.4723, Previous F1: 0.4712
[Causal Validation CM]:
[[ 4  5  6  3  0]
 [ 1  7  8  4  0]
 [ 1  1 11  9  0]
 [ 0  0  3 12  2]
 [ 0  0  2  8 13]]


Accuracy: 0.4300: 100%|██████████| 100/100 [00:19<00:00,  5.05it/s]
4it [10:17, 162.93s/it]

[Test Result] F1: 0.3948

Epoch 0, Step 4


Accuracy: 0.4000: 100%|██████████| 100/100 [00:25<00:00,  3.89it/s]


Skip Test
[System Validation] F1: 0.3922, Previous F1: 0.4926
[System Validation CM]:
[[ 2  7  8  1  0]
 [ 3  7  7  3  0]
 [ 0  4 10  6  2]
 [ 0  0  3  9  4]
 [ 0  0  3  8 12]]


Accuracy: 0.4600: 100%|██████████| 100/100 [00:25<00:00,  3.97it/s]
4it [10:57, 160.64s/it]

Skip Test
[Causal Validation] F1: 0.4588, Previous F1: 0.4926
[Causal Validation CM]:
[[ 9  7  1  1  0]
 [ 5  8  4  3  0]
 [ 3  5  7  5  2]
 [ 0  0  5  6  6]
 [ 0  1  3  3 16]]
Skip Test

Epoch 0, Step 4


Accuracy: 0.4500: 100%|██████████| 100/100 [00:23<00:00,  4.30it/s]
Accuracy: 0.0000:  98%|█████████▊| 98/100 [00:28<00:00,  3.77it/s]

Skip Test
[System Validation] F1: 0.4319, Previous F1: 0.4723
[System Validation CM]:
[[ 2  7  8  1  0]
 [ 0 10  6  4  0]
 [ 0  4  6 11  1]
 [ 0  0  1 14  2]
 [ 0  0  1  9 13]]


Accuracy: 0.0000: 100%|██████████| 100/100 [00:29<00:00,  3.41it/s]


Skip Test
[System Validation] F1: 0.0000, Previous F1: 0.4790
[System Validation CM]:
[[0 0 1 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]


Accuracy: 0.4700: 100%|██████████| 100/100 [00:20<00:00,  4.79it/s]


Skip Test
[System Validation] F1: 0.4525, Previous F1: 0.4926
[System Validation CM]:
[[ 1 12  4  1  0]
 [ 1 11  5  2  1]
 [ 0  3 11  6  1]
 [ 0  0  5 10  1]
 [ 0  1  1  7 14]]


Accuracy: 0.4500: 100%|██████████| 100/100 [00:22<00:00,  4.42it/s]
5it [12:44, 157.03s/it]

Skip Test
[Causal Validation] F1: 0.4281, Previous F1: 0.4723
[Causal Validation CM]:
[[ 3  7  6  2  0]
 [ 3  7  6  3  1]
 [ 1  2  8  8  3]
 [ 0  2  1  7  7]
 [ 0  0  2  1 20]]
Skip Test

Epoch 0, Step 5


Accuracy: 0.4900: 100%|██████████| 100/100 [00:25<00:00,  3.98it/s]
5it [12:55, 155.25s/it]

Skip Test
[Causal Validation] F1: 0.4662, Previous F1: 0.4790
[Causal Validation CM]:
[[ 2 10  4  2  0]
 [ 0 10  6  4  0]
 [ 0  2  8  9  3]
 [ 0  0  2 12  3]
 [ 0  0  2  4 17]]
Skip Test

Epoch 0, Step 5


Accuracy: 0.2800: 100%|██████████| 100/100 [00:24<00:00,  4.13it/s]
5it [13:31, 158.17s/it]

Skip Test
[Causal Validation] F1: 0.2168, Previous F1: 0.4926
[Causal Validation CM]:
[[ 1 17  0  0  0]
 [ 0 19  0  0  0]
 [ 0 21  0  1  0]
 [ 0 15  0  2  0]
 [ 0 12  0  3  6]]
Skip Test

Epoch 0, Step 5


Accuracy: 0.4500: 100%|██████████| 100/100 [00:25<00:00,  3.97it/s]


Skip Test
[System Validation] F1: 0.4471, Previous F1: 0.4723
[System Validation CM]:
[[ 3 12  3  0  0]
 [ 0 10  8  2  0]
 [ 0  3  9  9  1]
 [ 0  0  1 12  2]
 [ 0  0  1 11 11]]


Accuracy: 0.3700: 100%|██████████| 100/100 [00:30<00:00,  3.28it/s]


Skip Test
[System Validation] F1: 0.3519, Previous F1: 0.4790
[System Validation CM]:
[[ 1  9  6  2  0]
 [ 0  7 10  3  0]
 [ 0  4  6 10  1]
 [ 0  0  1 11  5]
 [ 0  0  3  8 12]]


Accuracy: 0.5200: 100%|██████████| 100/100 [00:31<00:00,  3.18it/s]
Accuracy: 0.1964:  56%|█████▌    | 56/100 [00:18<00:04,  9.16it/s]

[System Validation] F1: 0.5121, Previous F1: 0.4926
[System Validation CM]:
[[ 3 11  3  1  0]
 [ 1  9  9  1  0]
 [ 0  3 12  7  0]
 [ 0  1  1 14  1]
 [ 0  1  2  5 14]]


Accuracy: 0.2000: 100%|██████████| 100/100 [00:24<00:00,  4.00it/s]
5it [15:23, 184.64s/it]


Skip Test
[Causal Validation] F1: 0.1035, Previous F1: 0.4723
[Causal Validation CM]:
[[ 0  0  3 15  0]
 [ 0  1  1 18  0]
 [ 0  0  2 19  0]
 [ 0  0  0 17  0]
 [ 0  0  0 23  0]]
Skip Test
[Worker 0] [Iteration 1] val_best=0.4723, test_at_best=0.3948
[Main] Worker 0 done. Best test_f1=0.3948


Accuracy: 0.3800: 100%|██████████| 100/100 [00:28<00:00,  3.51it/s]
5it [15:48, 189.70s/it]


Skip Test
[Causal Validation] F1: 0.3797, Previous F1: 0.4790
[Causal Validation CM]:
[[ 2  5  8  3  0]
 [ 1  6  6  6  0]
 [ 0  0 10 11  0]
 [ 0  0  1 11  4]
 [ 0  0  1 13  9]]
Skip Test
[Worker 2] [Iteration 1] val_best=0.4790, test_at_best=0.3141
[Main] Worker 2 done. Best test_f1=0.3141


Accuracy: 0.4500: 100%|██████████| 100/100 [00:29<00:00,  3.39it/s]


Skip Test
[Causal Validation] F1: 0.4472, Previous F1: 0.5121
[Causal Validation CM]:
[[ 3 13  0  1  0]
 [ 2 10  4  4  0]
 [ 0  5  8  9  0]
 [ 0  0  2 11  4]
 [ 0  0  2  8 13]]


Accuracy: 0.4100: 100%|██████████| 100/100 [00:32<00:00,  3.08it/s]
5it [16:33, 198.76s/it]

[Test Result] F1: 0.3880
[Worker 1] [Iteration 1] val_best=0.5121, test_at_best=0.3880
[Main] Worker 1 done. Best test_f1=0.3880
EGO_res (best test F1 per worker): [0.394791520401759, 0.3141095077981786, 0.3879919437062294]



