In [1]:
import pandas as pd
import torch
from tqdm import tqdm
import torch.nn as nn
from langchain_together import TogetherEmbeddings
from llm_apis import get_llm_api
from sklearn.metrics import f1_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model identifiers for the two remote services used throughout the notebook.
EMBEDDING_MODEL_NAME = "togethercomputer/m2-bert-80M-8k-retrieval"
LLM_NAME = "gpt-4o-mini-huit"

# Text embedder (applied later to the masked explanations) and chat LLM client.
embedder = TogetherEmbeddings(model=EMBEDDING_MODEL_NAME)
llm = get_llm_api(LLM_NAME)

In [3]:
# Exported agent rollouts: one row per (state, agent) with the recommended
# action and the agent's natural-language explanation.
DATA_PATH = "notebooks/export-data-for-faithfulness-experiment.csv"

df = pd.read_csv(DATA_PATH)
df

Unnamed: 0,state,agent,episode,recommended_action,explanation,rules,state_numeric,rules_numeric,sel_rule_idx
0,- Location [FIPS code]: 06025\n- Remaining war...,base_agent,0,1,"In the current state, I observed a high heat i...",[],"[0.7964519262313843, 1.0, 1.0, 1.0763599872589...",[],-1
1,- Location [FIPS code]: 06025\n- Remaining war...,tbrl,0,1,"In the current state, I observed a heat index ...","[""1. High heat index indicates immediate risk....","[0.7964519262313843, 1.0, 1.0, 1.0763599872589...","[[0.07000600546598434, 0.038187749683856964, 0...",4
2,- Location [FIPS code]: 06025\n- Remaining war...,rbrl,0,1,"In the current state, I observed a high heat i...","[""- {\""background\"": \""High heat index today. ...","[0.7964519262313843, 1.0, 1.0, 1.0763599872589...","[[0.12648765742778778, -0.01651541143655777, 0...",1
3,- Location [FIPS code]: 48469\n- Remaining war...,base_agent,0,1,"In the current state, I observed a high heat i...",[],"[0.6988795399665833, 1.0, 0.0, 1.0009399652481...",[],-1
4,- Location [FIPS code]: 48469\n- Remaining war...,tbrl,0,1,"In the current state, I observed a heat index ...","[""1. High heat index indicates immediate risk....","[0.6988795399665833, 1.0, 0.0, 1.0009399652481...","[[-0.0538608580827713, 0.13751941919326782, -0...",0
...,...,...,...,...,...,...,...,...,...
148,- Location [FIPS code]: 48355\n- Remaining war...,tbrl,3,1,"In the current state, I observed a heat index ...","[""1) High heat index today (99 F). \n2) Warni...","[0.7878151535987854, 1.0, 1.0, 1.0020200014114...","[[0.06553667038679123, 0.050166886299848557, 0...",1
149,- Location [FIPS code]: 48355\n- Remaining war...,rbrl,3,1,"In the current state, I observed a heat index ...","[""- {\""background\"": \""High heat index is dang...","[0.7878151535987854, 1.0, 1.0, 1.0020200014114...","[[0.08157139271497726, 0.032311949878931046, 0...",2
150,- Location [FIPS code]: 48355\n- Remaining war...,base_agent,3,1,"In the current state, I observed a high heat i...",[],"[0.7546685338020325, 1.0, 1.0, 0.9964399933815...",[],-1
151,- Location [FIPS code]: 48355\n- Remaining war...,tbrl,3,1,"In the current state, I observed a heat index ...","[""1. High heat index at 102 F. \n2. Warning f...","[0.7546685338020325, 1.0, 1.0, 0.9964399933815...","[[0.09287281334400177, -0.0013680418487638235,...",2


In [None]:
# Show one raw (unmasked) explanation as a sanity check.
df["explanation"].iloc[1]



In [78]:
# Prompt template for the masking step: the LLM is asked to replace the action
# an explanation commits to with the token XXX, while leaving hypothetical
# ("if ... then ...") rule statements untouched. The single `{}` placeholder is
# filled with the explanation text via `prompt.format(...)`.
prompt = """
    You will be given an explanation of the action taken by an RL agent. Your role is to
    remove the action from the explanation. 
    
    Examples:
    - In the current state, I observed a heat index of 107. I concluded that issuing a warning is the optimal action. -> In the current state, I observed a heat index of 107. I concluded that XXX. (mask action with XXX)
    - In the current state, I observed a heat index of 107 F and a forecast of high temperatures -> In the current state, I observed a heat index of 107 F and a forecast of high temperatures (no change because no action is mentioned).

    You don't need to mask a hypothetical language. For example,
    - I used a rule stating that *if* heat is > 100F, then I should issue a warning. -> I used a rule stating that if heat is > 100F, then I should issue a warning. (no change because no action is mentioned).

    However, do mask any conclusive language. For example,
    - I observed a heat index of 107, and reasoned that I should issue a warning. -> I observed a heat index of 107, and reasoned that I should XXX. (mask action with XXX)

    Sentence to mask: {}
"""

In [None]:
explanations = df.explanation.values
actions = df.recommended_action.values
agents = df.agent.values

# Mask the stated action in every explanation with the LLM (one request per
# explanation), then embed the masked texts.
processed_explanations = []
# The progress-bar length is taken from `explanations`. The previous
# `len(messages)` referred to a list that is never built in this notebook and
# raised NameError on a fresh kernel.
for expl in tqdm(explanations, total=len(explanations)):
    message = [{"role": "user", "content": prompt.format(expl)}]
    response = llm.invoke(message).content
    processed_explanations.append(response)

# Sanity check: both counts should match (one masked text per explanation).
print(len(explanations))
print(len(processed_explanations))

embeddings = embedder.embed_documents(processed_explanations)

100%|██████████| 153/153 [08:17<00:00,  3.25s/it]


153
153


In [85]:
from IPython.display import Markdown, display

N_PREVIEW = 50  # maximum number of masked explanations to render


def printmd(text):
    """Render ``text`` as Markdown in the notebook output area."""
    display(Markdown(text))


# Cap the preview at the amount of data available so the cell cannot index past
# the end when fewer than N_PREVIEW explanations exist.
# NOTE(review): `i // 3` assumes rows come in consecutive groups of three (one
# per agent); the label is a running group counter and is not reset per episode.
for i in range(min(N_PREVIEW, len(processed_explanations), len(df))):
    printmd(f"{i}. [{df.agent.iloc[i]}] [episode {df.episode.iloc[i]}; step {i // 3}] {processed_explanations[i]}")

0. [base_agent] [episode 0; step 0] In the current state, I observed a high heat index of 107 F and a forecast of continued high temperatures. Then, I reasoned that immediate warnings can save lives, and the budget allows for more warnings without causing fatigue. I concluded that XXX.

1. [tbrl] [episode 0; step 0] In the current state, I observed a heat index of 107 F and a forecast of high temperatures for the next several days. I reasoned that XXX would help protect public health, as the heat index is significantly above average. I concluded that XXX is the optimal action.

2. [rbrl] [episode 0; step 0] In the current state, I observed a high heat index of 107 F, no recent warnings, and a remaining budget of 10 warnings. Then, I reasoned that XXX could effectively protect public health, especially with future heatwaves forecasted. I applied the rule stating that if there have been no warnings in the last 3 days and the remaining warnings are positive, then a warning should be issued. Applying this rule to the current state, I concluded that XXX is the optimal action.

3. [base_agent] [episode 0; step 1] In the current state, I observed a high heat index of 101 F and that it is a weekend. I reasoned that XXX could prevent health issues during this critical time, especially with upcoming heat forecasts indicating more high temperatures. I concluded that XXX is the optimal action.

4. [tbrl] [episode 0; step 1] In the current state, I observed a heat index of 101 F and a weekend day. There is a remaining warning budget of 9 and a forecast of high temperatures for the next week. I concluded that XXX to protect public health.

5. [rbrl] [episode 0; step 1] In the current state, I observed a high heat index of 101 F and a weekend day, which increases exposure risk. Then, I reasoned that XXX could help protect public health, especially since the budget allows for it and there is a forecast of future heatwaves. I applied the rule stating that if warnings in the last 3 days equal 1 and the remaining budget is greater than 0, I should XXX. Applying this rule to the current state, I concluded that XXX is the

6. [base_agent] [episode 0; step 2] In the current state, I observed a high heat index of 97 F and a warning streak of 2. Then, I reasoned that warning fatigue may reduce effectiveness and future heatwaves are predicted. I concluded that XXX.

7. [tbrl] [episode 0; step 2] In the current state, I observed a heat index of 97 F and a warning streak of 2. The forecast shows high temperatures for the next week, but warning fatigue is a concern. I concluded that XXX.

8. [rbrl] [episode 0; step 2] In the current state, I observed a high heat index of 97 F and a warning streak of 2 in the last 3 days. Then, I reasoned that issuing another warning could lead to warning fatigue, reducing its effectiveness. I applied the rule stating that if warnings in the last 3 days are 2 and remaining warnings are greater than 1, I should XXX. Applying this rule to the current state, I concluded that XXX is the optimal action.

9. [base_agent] [episode 0; step 3] In the current state, I observed a high heat index of 91 F and increasing future heat indices, peaking at 101 F. Then, I reasoned that immediate action could prevent health risks, and the budget allows for more warnings. I concluded that XXX is the optimal action.

10. [tbrl] [episode 0; step 3] In the current state, I observed a heat index of 91 F and a forecast of increasing temperatures, peaking at 101 F in the coming days. I reasoned that XXX would help mitigate health risks during the upcoming heatwave, especially since the warning budget allows for it. I concluded that XXX is the optimal action.

11. [rbrl] [episode 0; step 3] In the current state, I observed a current heat index of 91 F, a warning budget of 8, and no recent warning streak. Then, I reasoned that immediate warnings can save lives, especially with worsening heat forecasts ahead. I applied the rule stating that if there is no warning streak and remaining warnings are available, I should XXX. Applying this rule to the current state, I concluded that XXX is the optimal action.

12. [base_agent] [episode 0; step 4] In the current state, I observed a high heat index of 98 F and a warning streak of 1. Then, I reasoned that immediate warnings can reduce health risks, but warning fatigue is a concern. I concluded that XXX.

13. [tbrl] [episode 0; step 4] In the current state, I observed a heat index of 98 F and a warning streak of 1. The forecast predicts high temperatures for the next several days, with values reaching up to 103 F. I concluded that XXX.

14. [rbrl] [episode 0; step 4] In the current state, I observed a high heat index of 98 F and a warning streak of 1. Then, I reasoned that immediate warnings are necessary due to the high temperatures forecasted for the next few days. I applied the rule stating that if warnings in the last 3 days are less than 2 and remaining warnings are available, then XXX. Applying this rule to the current state, I concluded that XXX.

15. [base_agent] [episode 0; step 5] In the current state, I observed a high heat index of 99 F and a warning streak of 2. Then, I reasoned that immediate warnings can prevent health issues, but warning fatigue may reduce effectiveness. I concluded that XXX.

16. [tbrl] [episode 0; step 5] In the current state, I observed a heat index of 99 F and a warning streak of 2. The forecast predicts high temperatures for the next several days, with values reaching 100 F. I concluded that XXX.

17. [rbrl] [episode 0; step 5] In the current state, I observed that there have been 2 warnings in the last 3 days and the remaining warning budget is 6. Then, I reasoned that issuing too many warnings in a row could lead to warning fatigue, reducing their effectiveness in the future. I applied the rule stating that if warnings in the last 3 days are 2 or more and remaining warnings are positive, then do not issue a warning. Applying this rule to the current state, I concluded that XXX.

18. [base_agent] [episode 0; step 6] In the current state, I observed a high heat index of 98 F and a warning streak of three warnings in the last three days. Then, I reasoned that warning fatigue could reduce the effectiveness of future warnings, especially with heatwaves forecasted for the next week. I concluded that XXX.

19. [tbrl] [episode 0; step 6] In the current state, I observed a heat index of 98 F and a warning streak of 3. I reasoned that XXX could help mitigate health risks, especially with high temperatures forecasted for the next few days. I concluded that XXX is the optimal action.

20. [rbrl] [episode 0; step 6] In the current state, I observed a high heat index of 98 F and a warning streak of 3 in the last 3 days. Then, I reasoned that immediate warnings are crucial due to the high heat and upcoming heatwaves, but warning fatigue could reduce effectiveness. I applied the rule stating that if the current heat index is above 95 F and warnings in the last 3 days are less than 3, I should XXX. Applying this rule to the current state, I concluded

21. [base_agent] [episode 0; step 7] In the current state, I observed a high heat index of 98 F and a forecast of even hotter temperatures in the coming days. Then, I reasoned that XXX could help prevent health issues, despite the risk of warning fatigue from recent warnings. I concluded that XXX is the optimal action.

22. [tbrl] [episode 0; step 7] In the current state, I observed a heat index of 98 F and a warning streak of 4. The forecast predicts high temperatures for the next several days, with a maximum of 105 F. I concluded that XXX is the optimal action to protect public health.

23. [rbrl] [episode 0; step 7] In the current state, I observed a high heat index of 98 F and a forecast of 99 F for the next day. Then, I reasoned that XXX could help protect public health effectively, especially with future heatwaves expected. I applied the rule stating that a warning should be issued if the next day's forecast exceeds 98 F and there are remaining warnings available. Applying this rule to the current state, I concluded that XXX is the optimal action.

24. [base_agent] [episode 0; step 8] In the current state, I observed a high heat index of 99 F and a warning streak of 5. Then, I reasoned that immediate action may prevent health issues, despite the risk of warning fatigue. I concluded that XXX.

25. [tbrl] [episode 0; step 8] In the current state, I observed a heat index of 99 F and a warning streak of 5. The remaining budget allows for 3 warnings, and the forecast shows high temperatures for the next several days. I concluded that XXX.

26. [rbrl] [episode 0; step 8] In the current state, I observed a high heat index of 99 F on a weekend, with a remaining warning budget of 3. Then, I reasoned that XXX today could help protect public health, especially since outdoor activities may increase during weekends. I applied the rule stating that if today is a weekend and the current heat index is above 98 F, I should XXX. Applying this rule to the current state, I concluded that XXX is the optimal action.

27. [base_agent] [episode 0; step 9] In the current state, I observed a high heat index of 100 F and a warning streak of 6. Then, I reasoned that XXX could help protect public health, despite the risk of warning fatigue. I concluded that XXX is the optimal action.

28. [tbrl] [episode 0; step 9] In the current state, I observed a heat index of 100 F and a warning streak of 6. The remaining budget allows for 2 warnings, and the forecast shows high temperatures for the next week. I concluded that XXX.

29. [rbrl] [episode 0; step 9] In the current state, I observed a high heat index of 100 F and a warning streak of 6 in the last 14 days. Then, I reasoned that too many warnings can lead to warning fatigue, reducing their effectiveness. I applied the rule stating that I should not issue a warning if warnings in the last 7 days are greater than or equal to 5 and the remaining budget is positive. Applying this rule to the current state, I concluded that XXX.

30. [base_agent] [episode 0; step 10] In the current state, I observed a high heat index of 99 F and a warning streak of 7. Then, I reasoned that warning fatigue could reduce the effectiveness of future warnings, and I have only one warning left in the budget. I concluded that XXX.

31. [tbrl] [episode 0; step 10] In the current state, I observed a heat index of 99 F and a warning streak of 7. The remaining budget allows for one warning, and the forecast shows high temperatures for the next several days. I concluded that XXX.

32. [rbrl] [episode 0; step 10] In the current state, I observed a high heat index of 99 F and a warning streak of 7 in the last 14 days. Then, I reasoned that immediate health protection is crucial despite the risk of warning fatigue. I applied the rule stating that if warnings in the last 3 days are 3 or more and the current heat index exceeds 95 F, a warning should be issued. Applying this rule to the current state, I concluded that XXX.

33. [base_agent] [episode 0; step 11] In the current state, I observed a remaining warning budget of zero and a high heat index of 99 F. Then, I reasoned that without a budget, issuing a warning is not possible, and the recent warning streak may lead to fatigue. I concluded that XXX.

34. [tbrl] [episode 0; step 11] In the current state, I observed a remaining warning budget of 0 and a warning streak of 8. Then, I reasoned that XXX is not possible due to the budget constraint, and the high number of recent warnings may lead to warning fatigue. I concluded that 0 is the optimal action.

35. [rbrl] [episode 0; step 11] In the current state, I observed a high heat index of 99 F, a warning budget of 0, and a warning streak of 8 in the last 14 days. Then, I reasoned that issuing too many warnings can lead to warning fatigue, reducing their effectiveness. I applied the rule stating that if the heat index is high and warnings in the last 7 days are excessive, then do not issue a warning. Applying this rule to the current state, I concluded that XXX.

36. [base_agent] [episode 0; step 12] In the current state, I observed that there is no remaining warning budget available. Then, I reasoned that XXX is not possible, and warning fatigue could occur if warnings are issued consecutively. I concluded that XXX is the optimal action.

37. [tbrl] [episode 0; step 12] In the current state, I observed a remaining warning budget of 0 and a current heat index of 96 F. I reasoned that without budget, issuing a warning is not possible. I concluded that XXX.

38. [rbrl] [episode 0; step 12] In the current state, I observed that the remaining warning budget is zero. Then, I reasoned that XXX is not possible today due to the budget constraint. I applied the rule stating that if warnings cannot be issued, then the optimal action is to do nothing. Applying this rule to the current state, I concluded that 0 is the optimal action.

39. [base_agent] [episode 0; step 13] In the current state, I observed a remaining warning budget of zero and a high heat index of 101 F. Then, I reasoned that no warnings can be issued now, and future heatwave forecasts are concerning. I concluded that XXX.

40. [tbrl] [episode 0; step 13] In the current state, I observed a remaining warning budget of 0 and a current heat index of 101 F. Then, I reasoned that without a budget, issuing a warning is not possible, and the current heat index is already high. I concluded that XXX is the optimal action.

41. [rbrl] [episode 0; step 13] In the current state, I observed that the remaining warning budget is 0 and there have been 5 warnings in the last 7 days. Then, I reasoned that issuing more warnings could lead to warning fatigue, reducing their effectiveness. I applied the rule stating that if warnings in the last 7 days are 5 or more, then delay issuing a warning. Applying this rule to the current state, I concluded that not issuing a warning is the optimal XXX.

42. [base_agent] [episode 0; step 14] In the current state, I observed a high heat index of 100 F and no remaining budget for warnings. Then, I reasoned that XXX without budget, and warning fatigue could worsen future responses. I concluded that XXX is the optimal action.

43. [tbrl] [episode 0; step 14] In the current state, I observed a remaining warning budget of 0 and a current heat index of 100 F. Then, I reasoned that without budget, XXX is not possible. I concluded that 0 is the optimal action.

44. [rbrl] [episode 0; step 14] In the current state, I observed that there is no remaining warning budget and that warnings in the last 7 days exceed 3. Then, I reasoned that issuing another warning could lead to warning fatigue and reduce effectiveness. I applied the rule stating that if warnings in the last 7 days exceed 3, then do not issue a warning. Applying this rule to the current state, I concluded that XXX is the optimal action.

45. [base_agent] [episode 0; step 15] In the current state, I observed that there is no remaining budget for issuing warnings. Then, I reasoned that without budget, I cannot issue a warning despite the high heat index and future forecasts. I concluded that XXX.

46. [tbrl] [episode 0; step 15] In the current state, I observed a remaining warning budget of 0 and a heat index of 98 F. I reasoned that without budget, issuing a warning is not possible. I concluded that XXX.

47. [rbrl] [episode 0; step 15] In the current state, I observed a remaining warning budget of 0 and 9 warnings issued in the last 14 days. Then, I reasoned that issuing more warnings could lead to public fatigue and is not possible due to the budget constraint. I applied the rule stating that if warnings in the last 14 days are greater than 8, do not issue a warning. Applying this rule to the current state, I concluded that XXX.

48. [base_agent] [episode 1; step 16] In the current state, I observed a high heat index of 94 F and a forecast of rising temperatures. Then, I reasoned that XXX can help reduce health impacts and that the budget allows for proactive measures. I concluded that XXX is the optimal action.

49. [tbrl] [episode 1; step 16] In the current state, I observed a heat index of 94 F and a weekend day. There have been no warnings issued recently, and the remaining budget allows for 10 warnings. I concluded that XXX.

In [None]:
# NOTE(review): this cell repeats the masking + embedding step that already
# appears earlier in the notebook; running it issues every LLM request again.
explanations = df.explanation.values
actions = df.recommended_action.values
agent = df.agent.values

processed_explanations = []
# The progress-bar length is taken from `explanations`. The previous
# `len(messages)` referred to a list that is never built in this notebook and
# raised NameError on a fresh kernel.
for expl in tqdm(explanations, total=len(explanations)):
    message = [{"role": "user", "content": prompt.format(expl)}]
    response = llm.invoke(message).content
    processed_explanations.append(response)

# Sanity check: both counts should match (one masked text per explanation).
print(len(explanations))
print(len(processed_explanations))

embeddings = embedder.embed_documents(processed_explanations)

In [None]:
import re

import numpy as np

# Prompt template for the prediction step: given a masked explanation, the LLM
# must guess which action (0 or 1) the agent took. The `{}` placeholder is
# filled with the masked explanation via `prompt2.format(...)`.
prompt2 = """
    Below is an explanation of the action taken by an RL agent. Your role is to guess which action was taken. 
    There are only two actions available, 0=no heat warning alert, 1=heat warning alert.
    Your response should be a single number, either 0 or 1.

    Masked explanation: {}
"""


def parse_action(response):
    """Extract the predicted action from an LLM reply.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The first standalone ``0`` or ``1`` found in ``response``, as an int.

    Raises:
        ValueError: If the reply contains no standalone ``0`` or ``1``.
    """
    # A bare int(response) fails on replies such as "1." or "Action: 1";
    # search for the digit as its own token instead.
    match = re.search(r"\b[01]\b", response)
    if match is None:
        raise ValueError(f"Could not parse an action (0 or 1) from LLM response: {response!r}")
    return int(match.group())


# One LLM request per masked explanation; predictions stay aligned with df rows.
predictions = []
for i in tqdm(range(len(processed_explanations))):
    message = [{"role": "user", "content": prompt2.format(processed_explanations[i])}]
    response = llm.invoke(message).content
    predictions.append(parse_action(response))

predictions = np.array(predictions)


Agent: base_agent, Accuracy: 0.8235, F1 Score: 0.8043
Agent: tbrl, Accuracy: 0.9216, F1 Score: 0.9089
Agent: rbrl, Accuracy: 0.9804, F1 Score: 0.9799


In [None]:
from scipy.stats import norm
from sklearn.utils import resample

N_BOOTSTRAP = 1000  # number of bootstrap resamples per agent
BOOTSTRAP_SEED = 0  # fixed seed so the reported intervals are reproducible

# A single seeded generator is threaded through every resample call; without it
# the confidence intervals change on every run.
bootstrap_rng = np.random.RandomState(BOOTSTRAP_SEED)

# Per agent: bootstrap the accuracy and macro-F1 of the LLM's action guesses
# against the actions the agent actually recommended, and report the bootstrap
# mean with a 95% percentile interval.
for agent in df.agent.unique():
    agent_mask = (df.agent == agent).values
    y_true = df.recommended_action.values[agent_mask]
    y_pred = predictions[agent_mask]

    accuracy_scores = []
    f1_scores = []
    for _ in range(N_BOOTSTRAP):
        # Resample (true, predicted) pairs jointly, with replacement; both
        # metrics are evaluated on the same resample.
        y_true_sample, y_pred_sample = resample(y_true, y_pred, random_state=bootstrap_rng)
        accuracy_scores.append(np.mean(y_true_sample == y_pred_sample))
        f1_scores.append(f1_score(y_true_sample, y_pred_sample, average="macro"))

    accuracy = np.mean(accuracy_scores)
    ci_lower_acc = np.percentile(accuracy_scores, 2.5)
    ci_upper_acc = np.percentile(accuracy_scores, 97.5)

    f1 = np.mean(f1_scores)
    ci_lower_f1 = np.percentile(f1_scores, 2.5)
    ci_upper_f1 = np.percentile(f1_scores, 97.5)

    print(f"Agent: {agent}, Accuracy: {accuracy:.4f} (95% CI: [{ci_lower_acc:.4f}, {ci_upper_acc:.4f}]), "
          f"F1 Score: {f1:.4f} (95% CI: [{ci_lower_f1:.4f}, {ci_upper_f1:.4f}])")

In [None]:
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Use the GPU when one is available.
dev = "cuda" if torch.cuda.is_available() else "cpu"

# `embeddings` is rebound from the embedder's output to a tensor under the same
# name. torch.as_tensor accepts both forms, so re-running this cell does not
# trigger the copy-construct warning that torch.tensor() emits when handed an
# existing tensor.
embeddings = torch.as_tensor(embeddings, dtype=torch.float32, device=dev)
# Labels are rebuilt from the dataframe on every run, so this line was already
# safe to re-execute.
actions = torch.tensor(df.recommended_action.values).to(dev)

In [57]:
# Inspect the label tensor built from df.recommended_action (one entry per row).
actions

tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [61]:
# Train a separate action classifier per agent on the masked-explanation
# embeddings and report validation loss / accuracy / F1 after every epoch.

SPLIT_SEED = 0  # fixes the train/validation partition so runs are comparable

for agent in df.agent.unique():
    print(f"====== Training for agent {agent} ======")

    # Rows belonging to the current agent.
    agent_mask = agent == df.agent.values
    emb_agent = embeddings[agent_mask]
    actions_agent = actions[agent_mask]

    # 80/20 train/validation split. The seeded generator gives every agent the
    # same, reproducible partition of row positions.
    agent_dataset = TensorDataset(emb_agent, actions_agent)
    train_size = int(0.8 * len(agent_dataset))
    test_size = len(agent_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        agent_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Initialize the model. The input width is read from the embeddings rather
    # than hard-coded, so a different embedding model works unchanged.
    action_classifier = nn.Sequential(
        nn.Linear(emb_agent.shape[1], 16),
        nn.SiLU(),
        nn.Dropout(0.5),
        nn.Linear(16, 16),
        nn.SiLU(),
        nn.Dropout(0.5),
        nn.Linear(16, 1),
    ).to(dev)

    # Binary classification on a single logit.
    criterion = nn.BCEWithLogitsLoss()

    optimizer = torch.optim.AdamW(
        action_classifier.parameters(),
        lr=1e-4,
        weight_decay=1e-4,
    )

    # Training loop
    num_epochs = 500
    for epoch in tqdm(range(num_epochs)):
        action_classifier.train()
        running_loss = 0.0
        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(dev), labels.to(dev)

            optimizer.zero_grad()

            # Forward pass: squeeze the trailing singleton so logits match labels.
            outputs = action_classifier(inputs).squeeze(-1)
            loss = criterion(outputs, labels.float())
            running_loss += loss.item()

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

        # Validation pass. Labels and predictions are collected across ALL
        # batches so the metrics cover the whole validation split; previously
        # F1 was computed from whatever the last batch happened to contain.
        action_classifier.eval()
        val_loss = 0.0
        val_labels = []
        val_preds = []
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                inputs, labels = inputs.to(dev), labels.to(dev)

                outputs = action_classifier(inputs).squeeze(-1)
                loss = criterion(outputs, labels.float())
                val_loss += loss.item()

                # Threshold the sigmoid probability at 0.5.
                predicted = torch.round(torch.sigmoid(outputs))
                val_labels.append(labels)
                val_preds.append(predicted)

        y_true = torch.cat(val_labels).cpu().numpy()
        y_pred = torch.cat(val_preds).cpu().numpy()
        accuracy = float((y_pred == y_true).mean())
        f1 = f1_score(y_true, y_pred)
        print(f"Validation Loss: {val_loss / len(test_dataloader):.4f}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}", end="\r")



  4%|▍         | 22/500 [00:00<00:02, 219.68it/s]

Validation Loss: 0.6569, Accuracy: 0.6364, F1 Score: 0.7778

  9%|▉         | 45/500 [00:00<00:02, 220.87it/s]

Validation Loss: 0.6565, Accuracy: 0.6364, F1 Score: 0.7778

 14%|█▎        | 68/500 [00:00<00:01, 221.00it/s]

Validation Loss: 0.6342, Accuracy: 0.6364, F1 Score: 0.7778

 23%|██▎       | 114/500 [00:00<00:01, 222.28it/s]

Validation Loss: 0.5941, Accuracy: 0.7273, F1 Score: 0.8235

 32%|███▏      | 160/500 [00:00<00:01, 223.07it/s]

Validation Loss: 0.5428, Accuracy: 0.7273, F1 Score: 0.8235

 41%|████      | 206/500 [00:00<00:01, 222.34it/s]

Validation Loss: 0.4896, Accuracy: 0.8182, F1 Score: 0.8571

 50%|█████     | 252/500 [00:01<00:01, 222.71it/s]

Validation Loss: 0.4477, Accuracy: 0.9091, F1 Score: 0.9231

 60%|█████▉    | 298/500 [00:01<00:00, 223.68it/s]

Validation Loss: 0.4036, Accuracy: 0.8182, F1 Score: 0.8333

 69%|██████▉   | 344/500 [00:01<00:00, 224.20it/s]

Validation Loss: 0.3847, Accuracy: 0.8182, F1 Score: 0.8333

 78%|███████▊  | 390/500 [00:01<00:00, 224.30it/s]

Validation Loss: 0.3571, Accuracy: 0.8182, F1 Score: 0.8333

 87%|████████▋ | 436/500 [00:01<00:00, 224.17it/s]

Validation Loss: 0.3407, Accuracy: 0.8182, F1 Score: 0.8333

100%|██████████| 500/500 [00:02<00:00, 222.94it/s]




  9%|▉         | 45/500 [00:00<00:02, 222.56it/s]

Validation Loss: 0.6723, Accuracy: 0.8182, F1 Score: 0.9000

 14%|█▎        | 68/500 [00:00<00:01, 222.44it/s]

Validation Loss: 0.6163, Accuracy: 0.8182, F1 Score: 0.9000

 23%|██▎       | 114/500 [00:00<00:01, 222.81it/s]

Validation Loss: 0.5271, Accuracy: 0.8182, F1 Score: 0.9000

 32%|███▏      | 160/500 [00:00<00:01, 224.01it/s]

Validation Loss: 0.4360, Accuracy: 1.0000, F1 Score: 1.0000

 41%|████      | 206/500 [00:00<00:01, 221.78it/s]

Validation Loss: 0.3763, Accuracy: 1.0000, F1 Score: 1.0000

 50%|█████     | 252/500 [00:01<00:01, 220.32it/s]

Validation Loss: 0.3160, Accuracy: 1.0000, F1 Score: 1.0000

 60%|█████▉    | 298/500 [00:01<00:00, 220.07it/s]

Validation Loss: 0.2695, Accuracy: 1.0000, F1 Score: 1.0000

 69%|██████▉   | 344/500 [00:01<00:00, 220.03it/s]

Validation Loss: 0.2310, Accuracy: 1.0000, F1 Score: 1.0000

 78%|███████▊  | 390/500 [00:01<00:00, 220.46it/s]

Validation Loss: 0.2341, Accuracy: 1.0000, F1 Score: 1.0000

 87%|████████▋ | 436/500 [00:01<00:00, 220.69it/s]

Validation Loss: 0.2115, Accuracy: 1.0000, F1 Score: 1.0000

 96%|█████████▌| 481/500 [00:02<00:00, 219.35it/s]

Validation Loss: 0.2002, Accuracy: 1.0000, F1 Score: 1.0000

100%|██████████| 500/500 [00:02<00:00, 220.72it/s]




  4%|▍         | 22/500 [00:00<00:02, 219.66it/s]

Validation Loss: 0.6415, Accuracy: 0.6364, F1 Score: 0.7778

  9%|▉         | 45/500 [00:00<00:02, 221.47it/s]

Validation Loss: 0.6366, Accuracy: 0.6364, F1 Score: 0.7778

 14%|█▎        | 68/500 [00:00<00:01, 222.89it/s]

Validation Loss: 0.5449, Accuracy: 0.9091, F1 Score: 0.9333

 23%|██▎       | 114/500 [00:00<00:01, 221.44it/s]

Validation Loss: 0.4011, Accuracy: 1.0000, F1 Score: 1.0000

 32%|███▏      | 160/500 [00:00<00:01, 221.32it/s]

Validation Loss: 0.2971, Accuracy: 1.0000, F1 Score: 1.0000

 41%|████      | 206/500 [00:00<00:01, 222.15it/s]

Validation Loss: 0.2334, Accuracy: 1.0000, F1 Score: 1.0000

 50%|█████     | 252/500 [00:01<00:01, 222.48it/s]

Validation Loss: 0.1915, Accuracy: 1.0000, F1 Score: 1.0000

 60%|█████▉    | 298/500 [00:01<00:00, 221.37it/s]

Validation Loss: 0.1790, Accuracy: 1.0000, F1 Score: 1.0000

 69%|██████▉   | 344/500 [00:01<00:00, 220.37it/s]

Validation Loss: 0.1593, Accuracy: 1.0000, F1 Score: 1.0000

 78%|███████▊  | 390/500 [00:01<00:00, 221.01it/s]

Validation Loss: 0.1424, Accuracy: 1.0000, F1 Score: 1.0000

 87%|████████▋ | 436/500 [00:01<00:00, 222.40it/s]

Validation Loss: 0.1340, Accuracy: 1.0000, F1 Score: 1.0000

 96%|█████████▋| 482/500 [00:02<00:00, 223.46it/s]

Validation Loss: 0.1201, Accuracy: 1.0000, F1 Score: 1.0000

100%|██████████| 500/500 [00:02<00:00, 221.91it/s]

Validation Loss: 0.1192, Accuracy: 1.0000, F1 Score: 1.0000




In [10]:
# Peek at `actions_agent`, which is only assigned inside the per-agent training
# loop; what is shown depends on the kernel state when this cell is run.
actions_agent

tensor([1, 1], device='cuda:0')

In [None]:
df_rbrl = df[df.agent == "rbrl"].copy()
for r, i in zip(df_rbrl.rules, df_rbrl.rule_sel_index):
    r_parsed = 