In [None]:

# Notebook-wide experiment settings. The cells below read individual entries via
# config[...] (environment options, seed) and pass the whole mapping to DRRNAgent.
config = dict(
    window=10,
    embedding_dim=300,
    n_layers_action=1,
    n_layers_state=1,
    n_layers_scorer=1,
    n_layers_lstm=1,
    hidden_dim_action=64,
    hidden_dim_state=512,
    hidden_dim_scorer=512,
    hidden_lstm=128,
    activation="relu",
    emb="sum",
    hc=None,
    unq=False,
    learning_rate=0.0001,
    env_step_limit=20,
    seed=256,
    max_steps=100000,
    update_freq=1,
    log_freq=500,
    eval_freq=1000,
    memory_size=500000,
    encoder_memory_size=10,
    save_memory=0.5,
    memory_path="./encoder_memory/",
    batch_size=256,
    gamma=0.9,
    clip=100,
    game_path="./scenarios",
    wrong_answer=True,
    soft_reward=False,
    reward_scale=1,
    wording=True,
    evaluation="cause",
    document=False,
    reduced=False,
    encoder_type="fasttext",
    train_ratio=0.8,
    test_ratio=0.1,
    test_mode="wording",
    save_path="./models/",
    train_type="normal",
    TAU=0.5,
    pretrain=False,
    llm_assisted=False,
    use_attention=False,
    pretrained_explore=0.3,
    reduce_scenarios=False,
    patient="baby",
    penalty=-0.01,
)


In [None]:
import openai
import os
import json
import torch
import game
import wandb
import argparse
import random
random.seed(config["seed"])
import shutil
import fasttext
import scipy
from scenario_helper import split_single_scenario
from extract_scenarios import scenario_extractor
import numpy as np
from drrn import DRRNAgent
import llm_helper
import re
from test import summarize_ep
import pathlib
import copy

In [None]:
def check_dict_structure(results,required_keys=["train", "test", "val"],sub_keys=["score", "traj_score", "eff_score", "episode"]):
    """Tell whether ``results`` holds a complete entry for every required split.

    Args:
        results: Mapping from split name to a nested mapping of metrics.
        required_keys: Split names that must appear at the top level.
        sub_keys: Names that must appear inside each required split's mapping.

    Returns:
        True only if every required split is present and each one contains
        all of ``sub_keys``; False otherwise.
    """
    # Phase 1: all requested splits must exist before any nested lookup happens.
    if any(split_name not in results for split_name in required_keys):
        return False

    # Phase 2: each split's nested mapping must carry every expected entry.
    return all(
        metric in results[split_name]
        for split_name in required_keys
        for metric in sub_keys
    )
def check_subkeys(sub_dict,sub_keys=["score", "traj_score", "eff_score", "episode"]):
    """Return True when every name in ``sub_keys`` is present in ``sub_dict``."""
    for wanted in sub_keys:
        if wanted not in sub_dict:
            # One missing entry is enough to call the mapping incomplete.
            return False
    return True

In [None]:
def _read_json_file(path):
    """Parse the JSON document at ``path``, closing the file handle before returning."""
    with open(path, "r") as handle:
        return json.load(handle)


def _write_json_file(payload, path):
    """Serialize ``payload`` to ``path`` as 4-space-indented JSON and close the file."""
    with open(path, "w") as handle:
        json.dump(payload, handle, indent=4)


def _build_game_env(eval_mode, patient, subset, split, random_scenarios, **extra_kwargs):
    """Create a ``game.Game`` over ``<game_path>/<eval_mode>/<patient>/<subset>/<split>``.

    All remaining constructor options are read from the notebook-level ``config``;
    ``extra_kwargs`` is forwarded verbatim (used to pass ``penalty`` for one environment).
    """
    return game.Game(path=os.path.join(config["game_path"], eval_mode, patient, subset, str(split)),
                     env_step_limit=config["env_step_limit"],
                     wrong_answer=config["wrong_answer"],
                     emb=config["emb"], hc=config["hc"],
                     embedding_dim=config["embedding_dim"],
                     wording=config["wording"], evaluation=config["evaluation"],
                     random_scenarios=random_scenarios,
                     reward_scale=config["reward_scale"], reduced=config["reduced"],
                     **extra_kwargs)


def evaluate_patient_eval_mode(patient, eval_mode, inference, mode_,split):
    """Evaluate one (patient, eval_mode, inference, split) combination on one data subset.

    Per-episode results are cached as JSON under
    ``./results/<eval_mode>/<patient>/<mode>/`` and reused on later calls; the
    aggregated results are written to
    ``./results/<eval_mode>/<patient>/<patient>_<eval_mode>_<inference>_<split>.json``.
    If that aggregate file already holds a complete entry for ``mode_`` the
    function returns immediately.

    Args:
        patient: Patient name; used in scenario, model and result paths.
        eval_mode: Evaluation mode name; used in scenario, model and result paths.
        inference: One of "choose", "normal", "recommend", "play", "clin_play",
            "clin_choose", "clin_recommend" or "cor". Any other value raises
            ``ValueError`` once the first uncached episode is reached.
        mode_: Subset to evaluate: "train", "val" or "test".
        split: Split identifier; used in scenario, model and result paths.

    Returns:
        None. Results are persisted to disk.
    """
    #################### LOAD ENVIRONMENT ####################
    results_dir = os.path.join('./results/',eval_mode,patient)
    results_file = os.path.join(results_dir,f"{patient}_{eval_mode}_{inference}_{split}.json")
    os.makedirs(results_dir, exist_ok=True)
    if os.path.exists(results_file):
        results = _read_json_file(results_file)
        if check_dict_structure(results,required_keys=[mode_]):
            print(f"Results for {patient} {eval_mode} {inference} {split} already exist")
            return
        else:
            print(f"Results for {patient} {eval_mode} {inference} {split} is incomplete")
    else:
        results = {}
    # env_train is only queried for the state length and scenario count below;
    # it is the only environment that receives the configured penalty.
    env_train = _build_game_env(eval_mode, patient, "train", split, True, penalty=config["penalty"])
    env_train_eval = _build_game_env(eval_mode, patient, "train", split, False)
    env_val = _build_game_env(eval_mode, patient, "val", split, False)
    env_test = _build_game_env(eval_mode, patient, "test", split, False)
    state_dim = env_train.get_state_len()
    total_num = {"train":env_train.get_num_of_scenarios(),
                 "val":env_val.get_num_of_scenarios(),
                 "test":env_test.get_num_of_scenarios()}
    # The summarizer is disabled: the original code built one and then immediately
    # replaced it with None, so nothing is constructed here.
    summarizer = None
    llm_model = "gpt-4-1106-preview"
    llm_max_tokens = 1024
    if inference == "choose":
        topk = 5
        chooser_format = llm_helper.GPTChooses(topk=topk)
        chooser = llm_helper.Chooser(prompt_format=chooser_format,model=llm_model,max_tokens=llm_max_tokens)
    elif inference == "recommend":
        num_of_recs = 5
        recer_format = llm_helper.GPTRecs(num_of_recs=num_of_recs)
        recer = llm_helper.Recommender(prompt_format=recer_format,model=llm_model,max_tokens=llm_max_tokens)
    elif inference == "play":
        player_format = llm_helper.GPTPlays()
        player = llm_helper.Player(prompt_format=player_format,model=llm_model,max_tokens=llm_max_tokens)
    elif inference == "clin_play":
        clin_player_format = llm_helper.CLINPlays()
        clin_player = llm_helper.Player(prompt_format=clin_player_format,model=llm_model,max_tokens=llm_max_tokens)
    elif inference == "clin_choose":
        topk = 5
        clin_chooser_format = llm_helper.CLINChooses(topk=topk)
        clin_chooser = llm_helper.Chooser(prompt_format=clin_chooser_format,model=llm_model,max_tokens=llm_max_tokens)
    elif inference == "clin_recommend":
        num_of_recs = 5
        clin_recer_format = llm_helper.CLINRecs(num_of_recs=num_of_recs)
        clin_recer = llm_helper.Recommender(prompt_format=clin_recer_format,model=llm_model,max_tokens=llm_max_tokens)
    elif inference == "cor":
        num_of_recs = 5
        corer_format = llm_helper.GPTChooses_or_Recs(topk=num_of_recs)
        corer = llm_helper.Chooser_or_Recommender(prompt_format=corer_format,model=llm_model,max_tokens=llm_max_tokens)
    else:
        pass
    #################### LOAD MODEL ####################
    # The pure LLM players do not use the DRRN agent, so no checkpoint is loaded for them.
    if inference not in ["play","clin_play"]:
        agent = DRRNAgent(config, state_dim)
        model = torch.load(os.path.join(f"./models/{patient}_{eval_mode}_{split}", 'best_model.pt'))
        agent.policy_network= model
        agent.target_network = model
    #################### EVALUATE ####################
    for mode in [mode_]:
        prev_exp = None
        if mode == "test":
            if inference == "play" or inference == "clin_play":
                if "train" in results.keys() and "history" in results["train"].keys():
                    prev_exp = results["train"]["history"]
            else:
                # NOTE(review): this path has no <eval_mode>/<patient> sub-directory and no
                # split suffix, unlike the aggregate file written by this function — confirm
                # it points at the intended file.
                prev_results_file = os.path.join('./results/',f"{patient}_{eval_mode}_normal.json")
                if os.path.exists(prev_results_file):
                    prev_results = _read_json_file(prev_results_file)
                    if "train" in prev_results.keys() and "history" in prev_results["train"].keys():
                        prev_exp = prev_results["train"]["history"]
            print(prev_exp)
        #######################
        if mode not in results.keys():
            results[mode]={}
        elif check_subkeys(results[mode]):
            print(f"Results for {patient} {eval_mode} {inference} {mode} already exist")
            continue
        else:
            results[mode]={}
        env = env_val if mode == "val" else env_test if mode == "test" else env_train_eval
        total_score = 0
        total_traj_score = 0
        total_combined = 0
        total_eff_score = 0
        episodes = []
        histories = []
        episode_dir = os.path.join('./results/',eval_mode,patient,mode)
        for i in range(total_num[mode]):
            ep_results = {}
            ep_file = os.path.join(episode_dir,f"{patient}_{eval_mode}_{inference}_{split}_{i}.json")
            if os.path.exists(ep_file):
                ep_results = _read_json_file(ep_file)
                if check_subkeys(ep_results,sub_keys=["score", "traj_score", "eff_score", "episode", "history", "scenario_name"]):
                    print(f"Results for {patient} {eval_mode} {inference} {mode} {i} already exist")
                    score = ep_results["score"]
                    traj_score = ep_results["traj_score"]
                    eff_score = ep_results["eff_score"]
                    # "combined" is not part of the completeness check above, so recompute
                    # it with the same formula used when writing if the cache lacks it.
                    combined = ep_results.get("combined", (score>0) * traj_score)
                    total_score += score>0
                    total_traj_score += traj_score
                    total_eff_score += eff_score
                    total_combined += combined
                    episodes.append(ep_results["episode"])
                    histories.append(ep_results["history"])
                    # Advance the environment past the scenario that was served from cache.
                    env.increase_episodes()
                    continue
            os.makedirs(episode_dir, exist_ok=True)
            if inference == "choose":
                score, episode, traj_score, eff_score, scenario_name,history = evaluate_episode_choose(chooser,topk, agent, env,prev_exp=prev_exp,summarizer=summarizer)
            elif inference == "normal":
                score, episode, traj_score, eff_score, scenario_name,history = evaluate_episode(agent, env, policy="softmax")
            elif inference == "recommend":
                score, episode, traj_score, eff_score, scenario_name,history =  evaluate_episode_rec2(recer,num_of_recs, agent, env,prev_exp=prev_exp,summarizer=summarizer)
            elif inference == "play":
                score, episode, traj_score, eff_score, scenario_name,history = evaluate_episode_play(player, env,summarizer=summarizer,prev_exp=prev_exp)
            elif inference == "clin_play":
                score, episode, traj_score, eff_score, scenario_name,history = evaluate_episode_clin_play(mode,clin_player, env, patient,prev_exp=prev_exp,summarizer=summarizer,eval_mode=eval_mode,split=split)
            elif inference == "clin_choose":
                score, episode, traj_score, eff_score, scenario_name,history = evaluate_episode_clin_choose(mode,clin_chooser,topk, agent, env,patient,prev_exp=prev_exp,summarizer=summarizer,eval_mode=eval_mode,split=split)
            elif inference == "clin_recommend":
                score, episode, traj_score, eff_score, scenario_name,history =  evaluate_episode_clin_rec(mode,clin_recer,num_of_recs, agent, env,patient,prev_exp=prev_exp,summarizer=summarizer,eval_mode=eval_mode,split=split)
            elif inference == "cor":
                score, episode, traj_score, eff_score, scenario_name,history = evaluate_episode_choose_or_rec(corer,num_of_recs, agent, env,prev_exp=prev_exp,summarizer=summarizer)
            else:
                raise ValueError("Inference mode not supported")
            combined = (score>0) * traj_score
            ep_results["score"] = score
            ep_results["combined"] = combined
            ep_results["traj_score"] = traj_score
            ep_results["eff_score"] = eff_score
            ep_results["episode"] = episode
            ep_results["history"] = history
            ep_results["scenario_name"] = scenario_name
            _write_json_file(ep_results, ep_file)
            # An episode counts as solved when its accumulated reward is positive.
            total_score += score>0
            total_traj_score += traj_score
            total_eff_score += eff_score
            total_combined += combined
            episodes.append(episode)
            histories.append(history)
            env.increase_episodes()
        results[mode]["score"] = (total_score/total_num[mode])
        results[mode]["traj_score"] = (total_traj_score/total_num[mode])
        results[mode]["eff_score"] = (total_eff_score/total_num[mode])
        results[mode]["combined"] = (total_combined/total_num[mode])
        results[mode]["episode"] = episodes
        results[mode]["history"] = histories

        _write_json_file(results, results_file)

In [None]:
def evaluate_episode_rec(recer, num_of_recs, agent, env, policy="softmax",prev_exp=None,summarizer=None):
    """Play one scenario where an LLM recommends actions and the DRRN samples among them.

    At every step with more than one valid action, ``recer.rec`` is asked for a
    reply of the form ``<reason>###<one recommendation per line>``. Recommendations
    that parse to a valid action are scored with the agent's action values and one
    is sampled (near-greedy during posttest). A malformed reply, or a reply with
    no valid recommendation, counts as a failed attempt; after four failed
    attempts the agent's highest-valued valid action is used instead.

    Args:
        recer: Recommender exposing ``rec(...)``.
        num_of_recs: Unused here; kept for interface compatibility.
        agent: DRRN agent used for state encoding and action values.
        env: Game environment.
        policy: Policy name forwarded to ``agent.act``.
        prev_exp: Previous experience forwarded to the recommender.
        summarizer: Optional summarizer forwarded to the recommender.

    Returns:
        Tuple ``(score, episode, traj_score, eff_score, scenario_name, history)``.
    """
    max_tries = 4  # replace_closest is only switched on for the last attempt (tries > 3)
    episode = []
    step = 0
    score = 0
    done = False
    agent.reset_dictionaries()
    history = []
    reasons = []
    ob, valid_acts, hc = env.reset()
    history.append(ob[1])
    valid_subjects = env.scenario["subjects"]
    valid_topics = env.scenario["topics"]
    valid_causes = env.scenario["causes"]
    scenario_name = env.scenario["name"]
    subject = env.scenario["characters"][0]
    problem = find_phrase(ob[1])[0]
    state = agent.create_state(update_sentence=ob, hc=hc)
    posttest = False
    # Bind trace up front so the final metrics do not hit an unbound name when the
    # episode terminates on its very first step.
    trace = getattr(env, "trace", [])
    while not done:
        transition = [env.scenario["name"], step, ob[1], ]
        valid_ids = agent.encode_actions(valid_acts)
        _, _, action_values, _ = agent.act(
            [state], [valid_ids], policy=policy, eval_mode=True, action_strs=valid_acts, temperature= 1)
        if len(valid_acts)>1:
            reason = ""
            action_str = None
            for tries in range(1, max_tries + 1):
                response = recer.rec(history, subject, problem, valid_subjects, valid_topics, valid_causes, posttest=posttest,prev_exp=prev_exp,summarizer=summarizer)
                if posttest:
                    print(response)
                parts = response.split("###")
                if len(parts) != 2:
                    # Malformed reply: treat as a failed attempt instead of crashing.
                    print(response)
                    continue
                reason,recs = parts[0].strip("\n").strip(),parts[1].strip("\n").strip()

                recs_idxs = []
                valid_responses = []
                for line in recs.split("\n"):
                    parsed = parse_string_to_dict((line.split(". ")[-1]),valid_subjects,valid_topics,valid_causes,replace_closest=tries>3)
                    if parsed in valid_acts:
                        recs_idxs.append(valid_acts.index(parsed))
                        valid_responses.append(valid_acts[recs_idxs[-1]])

                if len(recs_idxs)>0:
                    # Very low temperature in posttest makes the sampling effectively greedy.
                    act_probs = (softmax(action_values[0][recs_idxs],temperature=0.001 if posttest else 1))
                    chosen_act_idx = torch.multinomial(act_probs, num_samples=1).item()
                    print(reason, recs)
                    action_str = valid_responses[chosen_act_idx]
                    print(action_str)
                    break
            if action_str is None:
                print("OUT OF TRIES")
                # No usable recommendation: fall back to the agent's best-valued valid
                # action rather than reusing a stale (or undefined) action.
                action_str = valid_acts[int(torch.argmax(action_values[0]).item())]
                print(action_str)
            reasons.append(reason)
        else:
            action_str = valid_acts[0]
            reasons.append("")
        history.append(action_str["sentence"])
        state_update, rew, done, valid_acts, hc, _ = env.step(ob, action_str)
        if not done:
            if state_update[0] == "interaction":
                history.append(".".join(state_update[1].split(".")[2:]))
            else:
                history.append(".".join(state_update[1].split(".")[-1:]))
            # Only refreshed while the episode is running, as in the original code.
            trace = env.trace
        posttest = state_update[0] == "posttest"
        ob = state_update
        score += rew
        step += 1
        transition += [action_str, rew, score]
        episode.append(transition)
        state = agent.create_state(
            update_sentence=ob, hc=hc, previous_state=state)

    present_actions = env.scenario["present_actions"]
    hits = sum(a in trace for a in present_actions)
    traj_score = hits / len(present_actions)
    # An empty trace (episode ended before any trace was recorded) scores 0 efficiency.
    eff_score = hits / len(trace) if len(trace) > 0 else 0.0
    agent.reset_dictionaries()
    return score, episode, traj_score,eff_score, scenario_name,history

In [None]:
def evaluate_episode_rec2(recer, num_of_recs, agent, env, policy="softmax",prev_exp=None,summarizer=None):
    """Play one scenario with LLM recommendations; in posttest only the first valid one is kept.

    At every step with more than one valid action, ``recer.rec`` is asked for a
    reply of the form ``<reason>###<one recommendation per line>``. Outside
    posttest, one of the valid recommendations is sampled according to the
    agent's action values. During posttest the candidate set is reduced to the
    first valid recommendation. A malformed reply, or a reply with no valid
    recommendation, counts as a failed attempt; after four failed attempts the
    agent's highest-valued valid action is used instead.

    Args:
        recer: Recommender exposing ``rec(...)``.
        num_of_recs: Unused here; kept for interface compatibility.
        agent: DRRN agent used for state encoding and action values.
        env: Game environment.
        policy: Policy name forwarded to ``agent.act``.
        prev_exp: Previous experience forwarded to the recommender.
        summarizer: Optional summarizer forwarded to the recommender.

    Returns:
        Tuple ``(score, episode, traj_score, eff_score, scenario_name, history)``.
    """
    max_tries = 4  # replace_closest is only switched on for the last attempt (tries > 3)
    episode = []
    step = 0
    score = 0
    done = False
    agent.reset_dictionaries()
    history = []
    reasons = []
    ob, valid_acts, hc = env.reset()
    history.append(ob[1])
    valid_subjects = env.scenario["subjects"]
    valid_topics = env.scenario["topics"]
    valid_causes = env.scenario["causes"]
    scenario_name = env.scenario["name"]
    subject = env.scenario["characters"][0]
    problem = find_phrase(ob[1])[0]
    state = agent.create_state(update_sentence=ob, hc=hc)
    posttest = False
    # Bind trace up front so the final metrics do not hit an unbound name when the
    # episode terminates on its very first step.
    trace = getattr(env, "trace", [])
    while not done:
        transition = [env.scenario["name"], step, ob[1], ]
        valid_ids = agent.encode_actions(valid_acts)
        _, _, action_values, _ = agent.act(
            [state], [valid_ids], policy=policy, eval_mode=True, action_strs=valid_acts, temperature= 1)
        if len(valid_acts)>1:
            reason = ""
            action_str = None
            for tries in range(1, max_tries + 1):
                response = recer.rec(history, subject, problem, valid_subjects, valid_topics, valid_causes, posttest=posttest,prev_exp=prev_exp,summarizer=summarizer)
                if posttest:
                    print(response)
                parts = response.split("###")
                if len(parts) != 2:
                    # Malformed reply: treat as a failed attempt instead of crashing.
                    print(response)
                    continue
                reason,recs = parts[0].strip("\n").strip(),parts[1].strip("\n").strip()

                recs_idxs = []
                valid_responses = []
                for line in recs.split("\n"):
                    parsed = parse_string_to_dict((line.split(". ")[-1]),valid_subjects,valid_topics,valid_causes,replace_closest=tries>3)
                    if parsed in valid_acts:
                        recs_idxs.append(valid_acts.index(parsed))
                        valid_responses.append(valid_acts[recs_idxs[-1]])

                if len(recs_idxs)>0:
                    if posttest:
                        # Keep only the first valid recommendation; index 0 of
                        # valid_responses still refers to that same action.
                        recs_idxs = [recs_idxs[0]]
                    act_probs = (softmax(action_values[0][recs_idxs],temperature=0.001 if posttest else 1))
                    chosen_act_idx = torch.multinomial(act_probs, num_samples=1).item()
                    print(reason, recs)
                    action_str = valid_responses[chosen_act_idx]
                    print(action_str)
                    break
            if action_str is None:
                print("OUT OF TRIES")
                # No usable recommendation: fall back to the agent's best-valued valid
                # action rather than reusing a stale (or undefined) action.
                action_str = valid_acts[int(torch.argmax(action_values[0]).item())]
                print(action_str)
            reasons.append(reason)
        else:
            action_str = valid_acts[0]
            reasons.append("")
        history.append(action_str["sentence"])
        state_update, rew, done, valid_acts, hc, _ = env.step(ob, action_str)
        if not done:
            if state_update[0] == "interaction":
                history.append(".".join(state_update[1].split(".")[2:]))
            else:
                history.append(".".join(state_update[1].split(".")[-1:]))
            # Only refreshed while the episode is running, as in the original code.
            trace = env.trace
        posttest = state_update[0] == "posttest"
        ob = state_update
        score += rew
        step += 1
        transition += [action_str, rew, score]
        episode.append(transition)
        state = agent.create_state(
            update_sentence=ob, hc=hc, previous_state=state)

    present_actions = env.scenario["present_actions"]
    hits = sum(a in trace for a in present_actions)
    traj_score = hits / len(present_actions)
    # An empty trace (episode ended before any trace was recorded) scores 0 efficiency.
    eff_score = hits / len(trace) if len(trace) > 0 else 0.0
    agent.reset_dictionaries()
    return score, episode, traj_score,eff_score, scenario_name,history

In [None]:
def evaluate_episode_clin_rec(mode,recer, num_of_recs, agent, env,patient,eval_mode, policy="softmax",prev_exp=None,summarizer=None,split=0):
    """Play one scenario up to three times with LLM recommendations and a persisted memory.

    Each attempt is stored as ``<episodeIdx>.json`` under
    ``./results/memory/<eval_mode>/<mode>/<split>/rec/<patient>/<scenario>``.
    The ``"summary"`` of the previous attempt's file (if any) is passed to the
    recommender. After an unsuccessful attempt ``summarize_ep`` is called before
    retrying; the loop stops as soon as an attempt ends with ``score == 1``.

    The recommender reply is expected as ``[<learning id>$$$]<reason>###<recs>``.
    A malformed reply, or one with no valid recommendation, counts as a failed
    attempt; after four failed attempts the agent's highest-valued valid action
    is used instead.

    Args:
        mode: Data subset name; used in the memory path and passed to ``summarize_ep``.
        recer: Recommender exposing ``rec(...)``.
        num_of_recs: Unused here; kept for interface compatibility.
        agent: DRRN agent used for state encoding and action values.
        env: Game environment.
        patient: Patient name; used as the task name.
        eval_mode: Evaluation mode name; used in the memory path.
        policy: Policy name forwarded to ``agent.act``.
        prev_exp: Previous experience forwarded to the recommender.
        summarizer: Optional summarizer forwarded to the recommender.
        split: Split identifier; used in the memory path.

    Returns:
        Tuple ``(score, episode, traj_score, eff_score, scenario_name, history)``
        for the last attempt played.
    """
    max_tries = 4  # replace_closest is only switched on for the last attempt (tries > 3)
    score = 0
    for ep in range(3):
        if score == 1:
            break
        episode = []
        step = 0
        score = 0
        done = False
        agent.reset_dictionaries()
        history = []
        reasons = []
        learning_ids = []
        saved_history = []
        ob, valid_acts, hc = env.reset()
        history.append(ob[1])
        valid_subjects = env.scenario["subjects"]
        valid_topics = env.scenario["topics"]
        valid_causes = env.scenario["causes"]
        scenario_name = env.scenario["name"]
        subject = env.scenario["characters"][0]
        problem = find_phrase(ob[1])[0]
        state = agent.create_state(update_sentence=ob, hc=hc)
        posttest = False
        # Bind trace up front so the final metrics do not hit an unbound name when
        # the episode terminates on its very first step.
        trace = getattr(env, "trace", [])
        TaskDescription = f"Find the cause behind the {subject}'s {problem}"
        task, sub_task = patient, scenario_name
        save_path = f"./results/memory/{eval_mode}/{mode}/{split}/rec/{task}/{sub_task}"
        if not os.path.exists(save_path):
            pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
            episodeIdx = 0
        else:
            episodeIdx = len(os.listdir(save_path))
        episodeIdx = max(episodeIdx, ep)
        summary = ""
        if episodeIdx > 0:
            with open(f"{save_path}/{episodeIdx - 1}.json", "r") as prev_file:
                # The file written below carries no "summary" key (a successful attempt
                # exits before summarize_ep is called), so tolerate its absence.
                summary = json.load(prev_file).get("summary", "")
        file_name = f"{save_path}/{episodeIdx}.json"
        while not done:
            transition = [env.scenario["name"], step, ob[1], ]
            valid_ids = agent.encode_actions(valid_acts)
            _, _, action_values, _ = agent.act(
                [state], [valid_ids], policy=policy, eval_mode=True, action_strs=valid_acts, temperature= 1)
            # Reset per step so a single-action step never reuses values from an earlier step.
            reason = ""
            learning_id = ""
            if len(valid_acts)>1:
                action_str = None
                # A bounded loop: parse failures consume an attempt too, so a stream of
                # malformed replies can no longer retry forever.
                for tries in range(1, max_tries + 1):
                    response = recer.rec(history, subject, problem, valid_subjects, valid_topics, valid_causes, summary, posttest=posttest,prev_exp=prev_exp,summarizer=summarizer)
                    print(response)
                    try:
                        ### split and remove empty strings and spaces
                        response_split = []
                        for x in response.split("$$$"):
                            x = x.strip("\n").strip()
                            if len(x)>0:
                                response_split.append(x)
                        if len(response_split) == 2:
                            learning_id,body = response_split
                        else:
                            learning_id,body = "",response_split[0]
                        learning_id = learning_id.strip("\n").strip()
                        reason,recs = body.split("###")
                    except Exception as e:
                        print(e)
                        print(response)
                        continue
                    reason,recs = reason.strip("\n").strip(),recs.strip("\n").strip()

                    recs_idxs = []
                    valid_responses = []
                    for line in recs.split("\n"):
                        parsed = parse_string_to_dict((line.split(". ")[-1]),valid_subjects,valid_topics,valid_causes,replace_closest=tries>3)
                        if parsed in valid_acts:
                            recs_idxs.append(valid_acts.index(parsed))
                            valid_responses.append(valid_acts[recs_idxs[-1]])

                    if len(recs_idxs)>0:
                        # Very low temperature in posttest makes the sampling effectively greedy.
                        act_probs = (softmax(action_values[0][recs_idxs],temperature=0.001 if posttest else 1))
                        chosen_act_idx = torch.multinomial(act_probs, num_samples=1).item()
                        print(reason, recs)
                        action_str = valid_responses[chosen_act_idx]
                        print(action_str)
                        break
                if action_str is None:
                    print("OUT OF TRIES")
                    # No usable recommendation: fall back to the agent's best-valued
                    # valid action rather than a stale (or undefined) one.
                    action_str = valid_acts[int(torch.argmax(action_values[0]).item())]
                    print(action_str)
            else:
                action_str = valid_acts[0]
            learning_ids.append(learning_id)
            reasons.append(reason)
            saved_history.append({
                "observation": history[-1],
                "rationale": reason,
                "action": action_str["sentence"],
            })
            history.append(action_str["sentence"])
            state_update, rew, done, valid_acts, hc, _ = env.step(ob, action_str)
            if not done:
                if state_update[0] == "interaction":
                    history.append(".".join(state_update[1].split(".")[2:]))
                else:
                    history.append(".".join(state_update[1].split(".")[-1:]))
                # Only refreshed while the episode is running, as in the original code.
                trace = env.trace
            posttest = state_update[0] == "posttest"
            ob = state_update
            score += rew
            step += 1
            transition += [action_str, rew, score]
            episode.append(transition)
            state = agent.create_state(
                update_sentence=ob, hc=hc, previous_state=state)

        present_actions = env.scenario["present_actions"]
        hits = sum(a in trace for a in present_actions)
        traj_score = hits / len(present_actions)
        # An empty trace (episode ended before any trace was recorded) scores 0 efficiency.
        eff_score = hits / len(trace) if len(trace) > 0 else 0.0
        agent.reset_dictionaries()
        env.reset()
        data = dict()
        data["taskDescription"] = TaskDescription
        data["episodeIdx"] = episodeIdx
        data["history"] = saved_history
        data["finalScore"] = score
        data["finalTrajScore"] = traj_score
        data["finalEffScore"] = eff_score
        print(data)
        with open(file_name, "w") as out_file:
            json.dump(data, out_file)
        if score == 1:
            break
        summarize_ep(task, sub_task,inference="rec",mode=mode,eval_mode=eval_mode,split=split)
    return score, episode, traj_score,eff_score, scenario_name,history

In [None]:
def evaluate_episode_clin_choose(mode,chooser, topk, agent, env,patient,eval_mode, policy="softmax",prev_exp=None,summarizer=None,split=0):
    """Play one scenario up to three times, letting an LLM choose among the DRRN's top actions.

    Each attempt is stored as ``<episodeIdx>.json`` under
    ``./results/memory/<eval_mode>/<mode>/<split>/choose/<patient>/<scenario>``.
    The ``"summary"`` of the previous attempt's file (if any) is passed to the
    chooser. After an unsuccessful attempt ``summarize_ep`` is called before
    retrying; the loop stops as soon as an attempt ends with ``score == 1``.

    At each step the valid actions are ranked by the agent's action values; the
    best ``topk`` (best two during posttest) are offered to the chooser, whose
    reply is expected as ``[<learning id>$$$]<reason>###<1-based choice>``. After
    three replies without a valid choice, the top-ranked action is taken.

    Args:
        mode: Data subset name; used in the memory path and passed to ``summarize_ep``.
        chooser: Chooser exposing ``choose(...)``.
        topk: Number of top-ranked actions offered outside posttest.
        agent: DRRN agent used for state encoding and action values.
        env: Game environment.
        patient: Patient name; used as the task name.
        eval_mode: Evaluation mode name; used in the memory path.
        policy: Policy name forwarded to ``agent.act``.
        prev_exp: Previous experience forwarded to the chooser.
        summarizer: Optional summarizer forwarded to the chooser.
        split: Split identifier; used in the memory path.

    Returns:
        Tuple ``(score, episode, traj_score, eff_score, scenario_name, history)``
        for the last attempt played.
    """
    max_tries = 3  # the original loop gave up before issuing a fourth request
    score = 0
    for ep in range(3):
        if score == 1:
            break
        episode = []
        step = 0
        score = 0
        done = False
        agent.reset_dictionaries()
        history = []
        reasons = []
        learning_ids = []
        saved_history = []
        ob, valid_acts, hc = env.reset()
        history.append(ob[1])
        scenario_name = env.scenario["name"]
        subject = env.scenario["characters"][0]
        problem = find_phrase(ob[1])[0]
        state = agent.create_state(update_sentence=ob, hc=hc)
        posttest = False
        # Bind trace up front so the final metrics do not hit an unbound name when
        # the episode terminates on its very first step.
        trace = getattr(env, "trace", [])
        TaskDescription = f"Find the cause behind the {subject}'s {problem}"
        task, sub_task = patient, scenario_name
        save_path = f"./results/memory/{eval_mode}/{mode}/{split}/choose/{task}/{sub_task}"
        if not os.path.exists(save_path):
            pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
            episodeIdx = 0
        else:
            episodeIdx = len(os.listdir(save_path))
        episodeIdx = max(episodeIdx, ep)
        summary = ""
        if episodeIdx > 0:
            with open(f"{save_path}/{episodeIdx - 1}.json", "r") as prev_file:
                # The file written below carries no "summary" key (a successful attempt
                # exits before summarize_ep is called), so tolerate its absence.
                summary = json.load(prev_file).get("summary", "")
        file_name = f"{save_path}/{episodeIdx}.json"
        while not done:
            transition = [env.scenario["name"], step, ob[1], ]
            valid_ids = agent.encode_actions(valid_acts)
            _, _, action_values, _ = agent.act(
                [state], [valid_ids], policy=policy, eval_mode=True, action_strs=valid_acts, temperature= 1)
            # Ascending argsort read backwards yields the best-valued actions first.
            sorted_idxs = np.argsort(action_values[0].detach().cpu().numpy())
            if not posttest:
                choices_idxs = sorted_idxs[-1:-topk-1:-1]
            else:
                choices_idxs = sorted_idxs[-1:-3:-1]
            choices = [valid_acts[i]["sentence"] for i in choices_idxs]
            print(choices)
            # Reset per step so a single-action step never reuses values from an earlier step.
            reason = ""
            learning_id = ""
            chosen_action = 1
            if len(valid_acts)>1:
                is_valid = False
                raw_choice = None
                for _ in range(max_tries):
                    response = chooser.choose(history, subject, problem, choices,summary, posttest=posttest,prev_exp=prev_exp,summarizer=summarizer)
                    try:
                        response_split = []
                        for x in response.split("$$$"):
                            x = x.strip("\n").strip()
                            if len(x)>0:
                                response_split.append(x)
                        if len(response_split) == 2:
                            learning_id,body = response_split
                        else:
                            learning_id,body = "",response_split[0]
                        learning_id = learning_id.strip("\n").strip()
                        reason,raw_choice = body.split("###")
                    except Exception as e:
                        # Empty or malformed reply: count it as a failed attempt.
                        print(e)
                        print(response)
                        continue
                    reason,raw_choice = reason.strip("\n").strip(),raw_choice.strip("\n").strip()
                    if raw_choice.isnumeric() and 0 < int(raw_choice) <= len(choices_idxs):
                        chosen_action = int(raw_choice)
                        is_valid = True
                        break
                if not is_valid:
                    print(choices)
                    print(raw_choice)
                    print(reason)
                    print('Out of tries')
                    # Fall back to the top-ranked action.
                    chosen_action = 1
                print(chosen_action)
                print(reason)
            action_str = (valid_acts[choices_idxs[chosen_action-1]])
            learning_ids.append(learning_id)
            reasons.append(reason)
            saved_history.append({
                "observation": history[-1],
                "rationale": reason,
                "action": action_str["sentence"],
            })
            history.append(action_str["sentence"])
            state_update, rew, done, valid_acts, hc, _ = env.step(ob, action_str)
            if not done:
                if state_update[0] == "interaction":
                    history.append(".".join(state_update[1].split(".")[2:]))
                else:
                    history.append(".".join(state_update[1].split(".")[-1:]))
                # Only refreshed while the episode is running, as in the original code.
                trace = env.trace
            posttest = state_update[0] == "posttest"
            ob = state_update
            score += rew
            step += 1
            transition += [action_str, rew, score]
            episode.append(transition)
            state = agent.create_state(
                update_sentence=ob, hc=hc, previous_state=state)
        present_actions = env.scenario["present_actions"]
        hits = sum(a in trace for a in present_actions)
        traj_score = hits / len(present_actions)
        # An empty trace (episode ended before any trace was recorded) scores 0 efficiency.
        eff_score = hits / len(trace) if len(trace) > 0 else 0.0
        agent.reset_dictionaries()
        env.reset()
        data = dict()
        data["taskDescription"] = TaskDescription
        data["episodeIdx"] = episodeIdx
        data["history"] = saved_history
        data["finalScore"] = score
        data["finalTrajScore"] = traj_score
        data["finalEffScore"] = eff_score
        print(data)
        with open(file_name, "w") as out_file:
            json.dump(data, out_file)
        if score == 1:
            break
        summarize_ep(task, sub_task,inference="choose",mode=mode,eval_mode=eval_mode,split=split)
    return score, episode, traj_score, eff_score, scenario_name,history

In [None]:

def evaluate_episode_clin_play(mode,clin_player, env, patient,eval_mode,prev_exp=None,summarizer=None,split=0):
    """Let an LLM player attempt one scenario up to three times, persisting each attempt.

    Each attempt resets ``env``, plays until ``env.step`` reports ``done``, writes a
    JSON record of the attempt under ``./results/memory/...`` and, unless the
    attempt scored exactly 1, calls ``summarize_ep`` before the next attempt.

    Args:
        mode: Path component of the results directory; also forwarded to ``summarize_ep``.
        clin_player: Object whose ``play(...)`` returns the LLM response text.
        env: Game environment exposing ``reset``, ``step``, ``scenario`` and ``trace``.
        patient: Used as the ``task`` path component and ``summarize_ep`` argument.
        eval_mode: Path component of the results directory; also forwarded to ``summarize_ep``.
        prev_exp: Forwarded unchanged to ``clin_player.play``.
        summarizer: Forwarded unchanged to ``clin_player.play``.
        split: Path component of the results directory; also forwarded to ``summarize_ep``.

    Returns:
        ``(score, episode, traj_score, eff_score, scenario_name, history)`` of the
        last attempt that was played.
    """
    score = 0
    for ep in range(3):
        # Stop retrying as soon as an attempt reached the score of 1.
        if score == 1:
            break
        episode = []
        step = 0
        score = 0
        done = False
        history = []
        reasons = []
        learning_ids = []
        saved_history = []
        history_update = {}
        ob, valid_acts, hc = env.reset()
        history.append(ob[1])
        valid_subjects = env.scenario["subjects"]
        valid_topics = env.scenario["topics"]
        valid_causes = env.scenario["causes"]
        scenario_name = env.scenario["name"]
        subject = env.scenario["characters"][0]
        # NOTE(review): raises IndexError when the opening observation has no
        # "have"/"has ... ." clause, because find_phrase then returns [].
        problem = find_phrase(ob[1])[0]
        posttest = False
        TaskDescription = f"Find the cause behind the {subject}'s {problem}"
        task, sub_task = patient, scenario_name
        save_path = f"./results/memory/{eval_mode}/{mode}/{split}/play/{task}/{sub_task}"
        # The episode index continues from the number of entries already in the
        # directory, but is never lower than the current attempt number.
        if not os.path.exists(save_path):
            pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
            episodeIdx = 0
        else:
            episodeIdx = len(os.listdir(save_path))
        episodeIdx = max(episodeIdx, ep)
        if episodeIdx > 0:
            # NOTE(review): this function never writes a "summary" key itself;
            # presumably summarize_ep adds it to the previous file - confirm.
            # An attempt that ended with score == 1 skips summarize_ep, so a later
            # call for the same scenario may hit a KeyError here.
            summary = json.load(open(f"{save_path}/{episodeIdx - 1}.json", "r"))["summary"]
        else:
            summary = ""
        file_name = f"{save_path}/{episodeIdx}.json"
        while not done:
            transition = [env.scenario["name"], step, ob[1], ]
            if len(valid_acts)>1:
                is_valid = False
                tries = 0
                while not is_valid:
                    tries += 1
                    response = clin_player.play(history, subject, problem, valid_subjects, valid_topics, valid_causes,summary, posttest=posttest,prev_exp=prev_exp,summarizer=summarizer)
                    # Response layout: "<learning id> $$$ <rationale> ### <actions>".
                    # The "$$$" part is optional; empty fragments are discarded.
                    response_split = []
                    for x in response.split("$$$"):
                        x = x.strip("\n").strip()
                        if len(x)>0:
                            response_split.append(x)
                    if len(response_split) == 2:
                        learning_id,response = response_split
                    else:
                        # NOTE(review): IndexError if the response was empty/whitespace only.
                        learning_id,response = "",response_split[0]
                    learning_id = learning_id.strip("\n").strip()
                    try:
                        reason,action = response.split("###")
                    except Exception as e:
                        # Not exactly one "###": treat the whole text as the action list.
                        print(e)
                        print(response)
                        reason = ""
                        action = response
                    reason,action = reason.strip("\n").strip(),action.strip("\n").strip()
                    # One candidate per line; only the text after the last ". " is
                    # parsed. Closest-match replacement is enabled from the 4th try on.
                    parsed_responses = []
                    for x in action.split("\n"):
                        parsed_responses.append(parse_string_to_dict((x.split(". ")[-1]),valid_subjects,valid_topics,valid_causes,replace_closest=tries>3))
                    valid_responses = []
                    for x in parsed_responses:
                        if x in valid_acts:
                            valid_responses.append(x)

                    if len(valid_responses)>0:
                        # The first candidate that is a currently valid action wins.
                        is_valid = True
                        action_str = valid_responses[0]
                        print(action_str)
                    else:
                        print("No valid responses")
                        print(action)
                        print(parsed_responses)
                        print(len(valid_acts))
                    if not is_valid and tries>3:
                        # NOTE(review): leaving the loop here keeps action_str from the
                        # previous step (or leaves it unbound on the first step), and
                        # that value is then sent to env.step below.
                        print("OUT OF TRIES")
                        break
                # history_update is one dict reused across steps; a deep copy is stored.
                history_update["observation"] = history[-1]
                history_update["rationale"] = reason
                history_update["action"] = action_str["sentence"]
                learning_ids.append(learning_id)
                reasons.append(reason)
                history.append(action_str["sentence"])
                saved_history.append(copy.deepcopy(history_update))
            else:
                # Zero or one valid action: no LLM call, take the first entry.
                action_str = valid_acts[0]
                history.append(action_str["sentence"])
            state_update, rew, done, valid_acts, hc, traj_score = env.step(ob, action_str)
            if not done:
                # For "interaction" updates the first two "."-separated pieces are
                # dropped; for any other update only the last piece is kept.
                if state_update[0] == "interaction":
                    history.append(".".join(state_update[1].split(".")[2:]))
                else:
                    history.append(".".join(state_update[1].split(".")[-1:]))
            posttest = state_update[0] == "posttest"
            if not done:
                # trace is only refreshed while the episode is running.
                # NOTE(review): if the very first step ends the episode, trace is
                # never assigned and the score computation below fails.
                trace = env.trace
            ob = state_update
            score += rew
            step += 1
            transition += [action_str, rew, score]
            episode.append(transition)
        # traj_score: share of the scenario's present_actions found in the trace.
        # eff_score: the same count divided by the trace length
        # (ZeroDivisionError if the captured trace is empty).
        traj_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(env.scenario["present_actions"])
        eff_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(trace)
        env.reset()
        data = dict()
        data["taskDescription"] = TaskDescription
        data["episodeIdx"] = episodeIdx
        data["history"] = saved_history
        data["finalScore"] = score
        data["finalTrajScore"] = traj_score
        data["finalEffScore"] = eff_score
        print(data)
        # The file object is not closed explicitly here.
        json.dump(data, open(file_name, "w"))
        if score == 1:
            break
        # Only reached when the attempt did not score 1; the return value is unused.
        o= summarize_ep(task, sub_task,inference="play",mode=mode,eval_mode=eval_mode,split=split)
    return score, episode, traj_score, eff_score, scenario_name, history

In [None]:
def evaluate_episode_play(player, env,summarizer=None, prev_exp=None):
    """Play one episode in which an LLM player proposes the action at every step.

    Args:
        player: Object whose ``play(...)`` returns text of the form
            ``"<rationale> ### <actions>"``.
        env: Game environment exposing ``reset``, ``step``, ``scenario`` and ``trace``.
        summarizer: Forwarded unchanged to ``player.play``.
        prev_exp: Forwarded unchanged to ``player.play``.

    Returns:
        ``(score, episode, traj_score, eff_score, scenario_name, history)`` where
        ``episode`` holds one ``[scenario, step, observation, action, reward,
        running score]`` list per step and ``history`` is the text transcript.
    """
    episode = []
    step = 0
    score = 0
    done = False
    history = []
    reasons = []
    ob, valid_acts, hc = env.reset()
    history.append(ob[1])
    valid_subjects = env.scenario["subjects"]
    valid_topics = env.scenario["topics"]
    valid_causes = env.scenario["causes"]
    scenario_name = env.scenario["name"]
    subject = env.scenario["characters"][0]
    # NOTE(review): IndexError when the opening observation contains no
    # "have"/"has ... ." clause (find_phrase returns []).
    problem = find_phrase(ob[1])[0]
    posttest = False
    while not done:
        transition = [env.scenario["name"], step, ob[1], ]
        if len(valid_acts)>1:
            is_valid = False
            tries = 0
            while not is_valid:
                tries += 1
                response = player.play(history, subject, problem, valid_subjects, valid_topics, valid_causes, posttest=posttest,prev_exp=prev_exp,summarizer=summarizer)
                try:
                    reason,action = response.split("###")
                except Exception as e:
                    # Not exactly one "###": treat the whole text as the action list.
                    print(e)
                    print(response)
                    reason = ""
                    action = response
                reason,action = reason.strip("\n").strip(),action.strip("\n").strip()

                # One candidate per line; only the text after the last ". " is parsed.
                # Closest-match replacement is enabled from the 4th try on.
                parsed_responses = []
                for x in action.split("\n"):
                    parsed_responses.append(parse_string_to_dict((x.split(". ")[-1]),valid_subjects,valid_topics,valid_causes,replace_closest=tries>3))
                valid_responses = []
                for x in parsed_responses:
                    if x in valid_acts:
                        valid_responses.append(x)

                if len(valid_responses)>0:
                    # The first candidate that is a currently valid action wins.
                    is_valid = True
                    action_str = valid_responses[0]
                    print(action_str)
                else:
                    print("No valid responses")
                    print(action)
                    print(parsed_responses)
                    print(len(valid_acts))
                if not is_valid and tries>3:
                    # NOTE(review): leaving the loop here keeps action_str from the
                    # previous step (or leaves it unbound on the first step), and
                    # that value is then sent to env.step below.
                    print("OUT OF TRIES")
                    break
            reasons.append(reason)
            history.append(action_str["sentence"])
        else:
            # Zero or one valid action: no LLM call, take the first entry.
            action_str = valid_acts[0]
            history.append(action_str["sentence"])
        state_update, rew, done, valid_acts, hc, traj_score = env.step(ob, action_str)
        if not done:
            # For "interaction" updates the first two "."-separated pieces are
            # dropped; for any other update only the last piece is kept.
            if state_update[0] == "interaction":
                history.append(".".join(state_update[1].split(".")[2:]))
            else:
                history.append(".".join(state_update[1].split(".")[-1:]))
        posttest = state_update[0] == "posttest"
        if not done:
            # trace is only refreshed while the episode is running.
            # NOTE(review): if the very first step ends the episode, trace is never
            # assigned and the score computation below fails.
            trace = env.trace
        ob = state_update
        score += rew
        step += 1
        transition += [action_str, rew, score]
        episode.append(transition)

    # traj_score: share of the scenario's present_actions found in the trace.
    # eff_score: the same count divided by the trace length
    # (ZeroDivisionError if the captured trace is empty).
    traj_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(env.scenario["present_actions"])
    eff_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(trace)
    return score, episode, traj_score,eff_score, scenario_name,history

In [None]:
def evaluate_episode(agent, env, policy):
    """Roll out one episode with the agent selecting every action itself.

    Args:
        agent: Provides ``reset_dictionaries``, ``create_state``,
            ``encode_actions`` and ``act``.
        env: Provides ``reset``, ``step``, ``scenario`` and ``trace``.
        policy: Forwarded to ``agent.act``.

    Returns:
        ``(score, episode, traj_score, eff_score, scenario_name, history)``.
        ``episode`` has one ``[scenario, step, observation, action, reward,
        running score]`` entry per step; ``history`` is the text transcript.

    Raises:
        UnboundLocalError: if the first step already ends the episode, since the
            trace is only captured on non-terminal steps.
    """
    agent.reset_dictionaries()
    ob, valid_acts, hc = env.reset()
    scenario_name = env.scenario["name"]
    history = [ob[1]]
    episode = []
    score = 0
    step = 0
    state = agent.create_state(update_sentence=ob, hc=hc)
    done = False
    while not done:
        record = [env.scenario["name"], step, ob[1]]
        encoded_actions = agent.encode_actions(valid_acts)
        _, picked, action_values, _ = agent.act(
            [state], [encoded_actions], policy=policy, eval_mode=True, action_strs=valid_acts)
        action_values = action_values[0]  # kept for parity; not used afterwards
        action_str = valid_acts[picked[0]]
        next_ob, reward, done, valid_acts, hc, traj_score = env.step(ob, action_str)
        history.append(action_str["sentence"])
        if not done:
            # Remember the trace of the last non-terminal step and log the reply.
            trace = env.trace
            pieces = next_ob[1].split(".")
            kept = pieces[2:] if next_ob[0] == "interaction" else pieces[-1:]
            history.append(".".join(kept))
        ob = next_ob
        score += reward
        step += 1
        record += [action_str, reward, score]
        episode.append(record)
        state = agent.create_state(
            update_sentence=ob, hc=hc, previous_state=state)
    expected = env.scenario["present_actions"]
    # Share of expected actions that were taken ...
    traj_score = sum(a in trace for a in expected) / len(expected)
    # ... and the same count relative to the number of actions taken.
    eff_score = sum(a in trace for a in env.scenario["present_actions"]) / len(trace)
    agent.reset_dictionaries()
    return score, episode, traj_score, eff_score, scenario_name, history

In [None]:
# Scratch check of negative-step slicing: a[-1:-6:-1] walks the list backwards
# from its last element, so for five elements it yields all of them reversed.
# The evaluation functions in this notebook use the same pattern
# (sorted_idxs[-1:-topk-1:-1]) to take the top-k entries of an ascending argsort.
a = [1,2,3,4,5]
a[-1:-6:-1]

In [None]:
def evaluate_episode_choose(chooser, topk, agent, env, policy="softmax",prev_exp=None,summarizer=None):
    """Play one episode where an LLM picks among the agent's top-ranked actions.

    At every step the agent scores all valid actions; the ``topk`` best (top two
    once the environment is in the "posttest" phase) are shown to ``chooser``,
    which must answer ``"<rationale> ### <1-based choice number>"``.  After three
    unusable answers the top-ranked action is taken.

    Args:
        chooser: Object whose ``choose(...)`` returns the LLM response text.
        topk: Number of top-ranked actions offered outside the posttest phase.
        agent: Provides ``reset_dictionaries``, ``create_state``,
            ``encode_actions`` and ``act``.
        env: Provides ``reset``, ``step``, ``scenario`` and ``trace``.
        policy: Forwarded to ``agent.act``.
        prev_exp: Forwarded unchanged to ``chooser.choose``.
        summarizer: Accepted for signature parity with the other evaluators; unused.

    Returns:
        ``(score, episode, traj_score, eff_score, scenario_name, history)``.

    Raises:
        IndexError: if the opening observation has no "have"/"has ... ." clause.
        UnboundLocalError: if the first step already ends the episode, since the
            trace is only captured on non-terminal steps.
    """
    episode = []
    step = 0
    score = 0
    done = False
    agent.reset_dictionaries()
    history = []
    reasons = []
    ob, valid_acts, hc = env.reset()

    history.append(ob[1])
    scenario_name = env.scenario["name"]
    subject = env.scenario["characters"][0]
    problem = find_phrase(ob[1])[0]
    state = agent.create_state(update_sentence=ob, hc=hc)
    posttest = False
    while not done:
        transition = [env.scenario["name"], step, ob[1], ]
        valid_ids = agent.encode_actions(valid_acts)
        _, action_idx, action_values, _ = agent.act(
            [state], [valid_ids], policy=policy, eval_mode=True, action_strs=valid_acts, temperature= 1)
        # argsort is ascending, so the reversed tail slice is "best first".
        sorted_idxs = np.argsort(action_values[0].detach().cpu().numpy())
        if not posttest:
            choices_idxs = sorted_idxs[-1:-topk-1:-1]
        else:
            choices_idxs = sorted_idxs[-1:-3:-1]
        choices = [valid_acts[i]["sentence"] for i in choices_idxs]
        print(choices)
        if len(valid_acts)>1:
            # Reset per step so the out-of-tries path never reads an unbound name
            # or a rationale left over from an earlier step.
            reason = ""
            chosen_action = None
            is_valid = False
            tries = 0
            while not is_valid:
                tries += 1
                if tries>3:
                    print(choices)
                    print(chosen_action)
                    print(reason)
                    print('Out of tries')
                    # Fall back to the agent's top-ranked action.
                    chosen_action = 1
                    break

                response = chooser.choose(history, subject, problem, choices, posttest=posttest,prev_exp=prev_exp)
                try:
                    reason,chosen_action = response.split("###")
                except Exception as e:
                    # Not exactly one "###" separator: ask again.
                    print(e)
                    print(response)
                    continue
                reason,chosen_action = reason.strip("\n").strip(),chosen_action.strip("\n").strip()
                if chosen_action.isnumeric():
                    chosen_action = int(chosen_action)
                    reason = reason.split(": ")[-1]
                    # Accept only a 1-based index into the offered choices.
                    if 0 < chosen_action <= len(choices_idxs):
                        is_valid = True

            reasons.append(reason)
            print(chosen_action)
            print(reason)
        else:
            chosen_action = 1
        action_str = (valid_acts[choices_idxs[chosen_action-1]])
        history.append(action_str["sentence"])
        state_update, rew, done, valid_acts, hc, traj_score = env.step(ob, action_str)
        if not done:
            # For "interaction" updates the first two "."-separated pieces are
            # dropped; for any other update only the last piece is kept.
            if state_update[0] == "interaction":
                history.append(".".join(state_update[1].split(".")[2:]))
            else:
                history.append(".".join(state_update[1].split(".")[-1:]))
        posttest = state_update[0] == "posttest"
        if not done:
            # trace is only refreshed while the episode is still running.
            trace = env.trace
        ob = state_update
        score += rew
        step += 1
        transition += [action_str, rew, score]
        episode.append(transition)
        state = agent.create_state(
            update_sentence=ob, hc=hc, previous_state=state)
    # traj_score: share of the scenario's present_actions found in the trace;
    # eff_score: the same count divided by the trace length.
    traj_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(env.scenario["present_actions"])
    eff_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(trace)
    agent.reset_dictionaries()
    return score, episode, traj_score, eff_score, scenario_name,history

In [None]:
def evaluate_episode_choose_or_rec(cor, topk, agent, env, policy="softmax",prev_exp=None,summarizer=None):
    """Play one episode where an LLM either picks a top-k action or recommends its own.

    Each LLM answer has the form ``"<mode> $$$ <rationale> ### <answer>"``:

    * mode ``choose``: ``<answer>`` is a 1-based index into the agent's
      ``topk`` best actions.
    * any other mode: ``<answer>`` lists recommended actions, one per line; those
      that are currently valid are sampled with a softmax over the agent's
      Q-values (near-greedy in the posttest phase).

    Up to four answers are requested per step (closest-match parsing is enabled
    on the last one); if none is usable the agent's top-ranked action is taken.

    Args:
        cor: Object whose ``cor(...)`` returns the LLM response text.
        topk: Number of top-ranked actions offered to the LLM.
        agent: Provides ``reset_dictionaries``, ``create_state``,
            ``encode_actions`` and ``act``.
        env: Provides ``reset``, ``step``, ``scenario`` and ``trace``.
        policy: Forwarded to ``agent.act``.
        prev_exp: Forwarded unchanged to ``cor.cor``.
        summarizer: Accepted for signature parity with the other evaluators; unused.

    Returns:
        ``(score, episode, traj_score, eff_score, scenario_name, history)``.

    Raises:
        IndexError: if the opening observation has no "have"/"has ... ." clause.
        UnboundLocalError: if the first step already ends the episode, since the
            trace is only captured on non-terminal steps.
    """
    max_tries = 4
    episode = []
    step = 0
    score = 0
    done = False
    agent.reset_dictionaries()
    history = []
    reasons = []
    ob, valid_acts, hc = env.reset()
    # Read the scenario vocabularies only after reset(), as the other evaluators
    # do, so they describe the scenario that is actually being played.
    valid_subjects = env.scenario["subjects"]
    valid_topics = env.scenario["topics"]
    valid_causes = env.scenario["causes"]
    scenario_name = env.scenario["name"]

    history.append(ob[1])
    subject = env.scenario["characters"][0]
    problem = find_phrase(ob[1])[0]
    state = agent.create_state(update_sentence=ob, hc=hc)
    posttest = False
    while not done:
        transition = [env.scenario["name"], step, ob[1], ]
        valid_ids = agent.encode_actions(valid_acts)
        _, action_idx, action_values, _ = agent.act(
            [state], [valid_ids], policy=policy, eval_mode=True, action_strs=valid_acts, temperature= 1)
        # argsort is ascending, so the reversed tail slice is "best first".
        sorted_idxs = np.argsort(action_values[0].detach().cpu().numpy())
        choices_idxs = sorted_idxs[-1:-topk-1:-1]
        choices = [valid_acts[i]["sentence"] for i in choices_idxs]
        print(choices)
        # Default (single valid action, or no usable LLM answer): top-ranked action.
        action_str = valid_acts[choices_idxs[0]]
        if len(valid_acts)>1:
            reason = ""
            is_valid = False
            tries = 0
            while not is_valid and tries < max_tries:
                tries += 1
                response = cor.cor(history, subject, problem, valid_subjects, valid_topics, valid_causes, choices, posttest=posttest,prev_exp=prev_exp)
                try:
                    mode, payload = response.split("$$$")
                    attempt_reason, answer = payload.split("###")
                except Exception as e:
                    # Malformed answer (wrong number of separators): ask again.
                    print(e)
                    print(response)
                    continue
                mode = mode.strip("\n").strip()
                reason = attempt_reason.strip("\n").strip()
                answer = answer.strip("\n").strip()
                if mode == "choose":
                    # Accept only a 1-based index into the offered choices.
                    if answer.isnumeric() and 0 < int(answer) <= len(choices_idxs):
                        reason = reason.split(": ")[-1]
                        action_str = valid_acts[choices_idxs[int(answer) - 1]]
                        is_valid = True
                else:
                    try:
                        parsed_responses = [
                            parse_string_to_dict(
                                x.split(". ")[-1], valid_subjects, valid_topics, valid_causes,
                                replace_closest=tries >= max_tries)
                            for x in answer.split("\n")
                        ]
                    except (ValueError, RuntimeError) as e:
                        # The parser rejects malformed ask(...) calls by raising.
                        print(e)
                        print(answer)
                        continue
                    recs_idxs = [valid_acts.index(x) for x in parsed_responses if x in valid_acts]
                    if len(recs_idxs)>0:
                        # Sample among the valid recommendations by Q-value.
                        act_probs = softmax(action_values[0][recs_idxs], temperature=0.001 if posttest else 1)
                        chosen_act_idx = torch.multinomial(act_probs, num_samples=1).item()
                        print(reason, answer)
                        action_str = valid_acts[recs_idxs[chosen_act_idx]]
                        is_valid = True
            if not is_valid:
                print(choices)
                print(reason)
                print('Out of tries')

            reasons.append(reason)
            print(action_str)
            print(reason)
        history.append(action_str["sentence"])
        state_update, rew, done, valid_acts, hc, traj_score = env.step(ob, action_str)
        if not done:
            # For "interaction" updates the first two "."-separated pieces are
            # dropped; for any other update only the last piece is kept.
            if state_update[0] == "interaction":
                history.append(".".join(state_update[1].split(".")[2:]))
            else:
                history.append(".".join(state_update[1].split(".")[-1:]))
        posttest = state_update[0] == "posttest"
        if not done:
            # trace is only refreshed while the episode is still running.
            trace = env.trace
        ob = state_update
        score += rew
        step += 1
        transition += [action_str, rew, score]
        episode.append(transition)
        state = agent.create_state(
            update_sentence=ob, hc=hc, previous_state=state)
    # traj_score: share of the scenario's present_actions found in the trace;
    # eff_score: the same count divided by the trace length.
    traj_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(env.scenario["present_actions"])
    eff_score = sum(
        a in trace for a in env.scenario["present_actions"]) / len(trace)
    agent.reset_dictionaries()
    return score, episode, traj_score, eff_score, scenario_name,history

In [None]:
def rename_folder(old_name, new_name):
    """Rename ``old_name`` to ``new_name`` and print the outcome.

    A missing path or an already existing target is reported on stdout instead
    of being raised; any other ``OSError`` propagates to the caller.
    """
    try:
        os.rename(old_name, new_name)
    except FileNotFoundError:
        print(f"Folder '{old_name}' not found.")
    except FileExistsError:
        print(f"Folder '{new_name}' already exists.")
    else:
        print(f"Folder '{old_name}' renamed to '{new_name}' successfully.")
def remove_folder(folder_path):
    """Delete ``folder_path`` recursively and print the outcome.

    A missing folder or a permission problem is reported on stdout instead of
    being raised; any other error propagates to the caller.
    """
    try:
        shutil.rmtree(folder_path)
    except FileNotFoundError:
        print(f"Folder '{folder_path}' not found.")
    except PermissionError:
        print(f"Permission denied to remove folder '{folder_path}'.")
    else:
        print(f"Folder '{folder_path}' removed successfully.")
def copy_folder(source_folder, destination_folder):
    """Replace ``destination_folder`` with a fresh copy of ``source_folder``.

    An existing destination is deleted first, then the source tree is copied.
    A missing source or a still-existing destination is reported on stdout
    instead of being raised.  Note that the destination is removed before the
    source is touched, so a missing source still leaves the destination deleted.
    """
    try:
        destination_present = os.path.exists(destination_folder)
        if destination_present:
            # Start from a clean slate: copytree refuses to overwrite.
            shutil.rmtree(destination_folder)
            print(f"Deleted folder '{destination_folder}'.")

        shutil.copytree(source_folder, destination_folder)
    except FileNotFoundError:
        print(f"Folder '{source_folder}' not found.")
    except FileExistsError:
        print(f"Folder '{destination_folder}' already exists.")
    else:
        print(f"Folder '{source_folder}' copied to '{destination_folder}' successfully.")

In [None]:
def print_table(scores):
    """Print a fixed-width table of ``score/traj_score`` per data split and inference type.

    Args:
        scores: Nested mapping ``scores[inference][split]`` with the keys
            ``"score"`` and ``"traj_score"``; splits are ``train``, ``val`` and
            ``test``.  The recommend column is read from the legacy key
            ``"recomend"`` when present, otherwise from ``"recommend"``.

    Raises:
        KeyError: if an inference type or split is missing from ``scores``.
    """
    # (column label, accepted keys in lookup order; the last one is the default)
    columns = [
        ("Choose", ("choose",)),
        ("Recommend", ("recomend", "recommend")),
        ("Normal", ("normal",)),
    ]

    # Header: a 15-wide label column followed by 20-wide value columns.
    print(f"{'Mode/Type':<15}" + "".join(f"{label:<20}" for label, _ in columns))

    for mode in ['train', 'val', 'test']:
        row = f"{mode:<15}"
        for _, keys in columns:
            key = next((k for k in keys if k in scores), keys[-1])
            entry = scores[key][mode]
            # Pad the whole "score/traj" cell so it lines up with the header.
            cell = f"{entry['score']}/{entry['traj_score']}"
            row += f"{cell:<20}"
        print(row)

In [None]:
def softmax(q_values, temperature):
    """
    Turn Q-values into action probabilities with a temperature-scaled softmax.

    :param q_values: A tensor of Q-values, one per action.
    :param temperature: Divisor applied to the Q-values before exponentiation;
                        larger values flatten the distribution.
    :return: A tensor of the same shape whose entries sum to 1.
    """
    scaled = q_values / temperature
    # Subtract the maximum before exponentiating to avoid overflow.
    weights = (scaled - scaled.max()).exp()
    return weights / weights.sum()

In [None]:
# Load the fastText model once at notebook scope; `match` uses it for its
# closest-sentence fallback. Presumably the pretrained English 300-dim vectors,
# judging by the file name - confirm.
fasttext_model = fasttext.load_model("./lms/cc.en.300.bin")
def find_all_occurences(list, value):
    """Return, in ascending order, every index at which ``value`` occurs in ``list``."""
    positions = []
    for position, item in enumerate(list):
        if item == value:
            positions.append(position)
    return positions
def match(sentence, valid_sentences,replace_closest=False):
    """Map free text onto one of ``valid_sentences``.

    Resolution order:
      1. the first valid sentence whose lower-cased form equals ``sentence``;
      2. otherwise, among valid sentences that contain ``sentence`` or are
         contained in it (lower-cased), the longest one (first on ties);
      3. otherwise, if ``replace_closest`` is true, the valid sentence whose
         fastText sentence vector has the highest cosine similarity;
      4. otherwise ``sentence`` itself, unchanged.

    ``sentence`` is compared as given; only the valid sentences are lower-cased.
    """
    exact = next((cand for cand in valid_sentences if cand.lower() == sentence), None)
    if exact is not None:
        return exact

    overlapping = [
        pos for pos, cand in enumerate(valid_sentences)
        if sentence in cand.lower() or cand.lower() in sentence
    ]
    if overlapping:
        # max() returns the first maximal element, i.e. the earliest longest match.
        return valid_sentences[max(overlapping, key=lambda pos: len(valid_sentences[pos]))]

    if not replace_closest:
        return sentence

    # Embedding fallback: pick the most similar valid sentence.
    candidate_vecs = [fasttext_model.get_sentence_vector(cand) for cand in valid_sentences]
    query_vec = fasttext_model.get_sentence_vector(sentence)
    similarities = [1 - scipy.spatial.distance.cosine(vec, query_vec) for vec in candidate_vecs]
    return valid_sentences[similarities.index(max(similarities))]
def parse_string_to_dict(input_str,valid_subjects,valid_topics,valid_causes,replace_closest=False):
    """Parse one LLM command such as ``ask(subject, topic)`` into an action dict.

    Recognised commands (case-insensitive, surrounding whitespace ignored):

    * ``ask(<subject>, <topic>)`` - an "interaction"/"discuss" action; subject and
      topic are normalised with ``match`` against the valid lists.
    * ``answer(...)`` - the "interaction"/"solution" action; arguments are ignored.
    * ``choose(<cause>)`` - a "posttest" action whose sentence is the matched cause.

    Any other command yields a dict whose four fields are all empty strings.

    Args:
        input_str: The raw command text.
        valid_subjects: Allowed subjects for ``ask``.
        valid_topics: Allowed topics for ``ask``.
        valid_causes: Allowed causes for ``choose``.
        replace_closest: Forwarded to ``match``.

    Returns:
        A dict with the keys ``type``, ``part``, ``detail`` and ``sentence``.

    Raises:
        ValueError: if ``ask`` does not carry exactly two comma-separated arguments.
    """
    input_str = input_str.lower()
    print(input_str)
    # Split "command(args)" at the first "(" and cut at the last ")".
    parts = input_str.split('(',1)
    command = parts[0].strip()
    args = parts[1].rsplit(')',1)[0] if len(parts) > 1 else ""
    # If several calls were chained ("ask(a,b), ask(c,d)"), keep only the first.
    args = args.split("),")[0]
    result_dict = {
        "type": "",
        "part": "",
        "detail": "",
        "sentence": ""
    }

    if command == "ask":
        result_dict["type"] = "interaction"
        result_dict["part"] = "discuss"
        fields = args.split(',')
        if len(fields) != 2:
            print(fields)
            raise ValueError(
                f"ask() expects exactly two comma-separated arguments (subject, topic); got {fields!r}")
        subject, topic = (field.strip() for field in fields)
        subject = match(subject,valid_subjects,replace_closest=replace_closest)
        topic = match(topic,valid_topics,replace_closest=replace_closest)
        print(subject,topic)
        result_dict["detail"] = ",".join([subject,topic])
        result_dict["sentence"] = f"i want to know about the {subject} 's {topic}."

    elif command == "answer":
        result_dict["type"] = "interaction"
        result_dict["part"] = "solution"
        result_dict["sentence"] = "i want to suggest a solution."

    elif command == "choose":
        result_dict["type"] = "posttest"
        args = match(args,valid_causes,replace_closest=replace_closest)
        result_dict["sentence"] = args

    return result_dict


In [None]:
def find_phrase(text):
    """Return the text between each whole-word "have"/"has" and the next period.

    Matching is case-insensitive and may span line breaks; each phrase is
    stripped of surrounding whitespace.  An empty list is returned when the text
    contains no such clause.
    """
    have_clause = re.compile(r'\b(have|has)\b(.*?)(?=\.)', re.IGNORECASE | re.DOTALL)
    # Group 2 is everything after the verb, up to (not including) the period.
    return [hit.group(2).strip() for hit in have_clause.finditer(text)]

In [None]:
# Sweep configuration read by the evaluation cells of this notebook.
# Only the values assigned here are live; earlier selections are listed in
# comments so they can be restored by hand.
eval_modes = ["wording"]

# Earlier selections:
#   ["choose","recommend","play","normal","clin_choose","clin_recommend","clin_play"]
#   ["normal","recommend","clin_recommend","choose","clin_choose","normal"]
#   ["normal"]
#   ["clin_play"]
inferences = ["clin_play", "clin_recommend", "clin_choose"]

# Earlier subsets:
#   ["mother","eye","gm"]
#   ["baby","skin","gyno","joint","stomachache","throat"]
#   ["throat"]
#   ["baby"]
patients = [
    "baby", "mother", "gm", "skin", "eye",
    "gyno", "joint", "stomachache", "throat",
]

modes = ["test"]    # earlier selection: ["train","val","test"]
splits = [0]        # earlier selection: [0,1,2]

In [None]:
# Sequential sweep: evaluate every (eval mode, patient, inference, mode, split)
# combination, with split varying fastest.
sweep_grid = [
    (e, p, inference, mode, split)
    for e in eval_modes
    for p in patients
    for inference in inferences
    for mode in modes
    for split in splits
]
for e, p, inference, mode, split in sweep_grid:
    print(f"Patient: {p}, Eval mode: {e}, Inference: {inference}, Mode: {mode}, Split: {split}")
    evaluate_patient_eval_mode(p, e, inference, mode_=mode,split=split)


In [None]:
################### DOWNLOAD MODELS ####################
# For every (split, patient, eval mode) that has no local model folder yet,
# fetch the "best-model" artifact through the wandb run and move the downloaded
# directory to ./models/<patient>_<eval mode>_<split>.
run = wandb.init()
for split in splits:
    for p in patients:
        for e in eval_modes:
            target_dir = f"./models/{p}_{e}_{split}"
            if os.path.exists(target_dir):
                print(f"Folder '{target_dir}' already exists.")
                continue
            print(f"Downloading {p}_{e} models")
            artifact = run.use_artifact(f'xxxx/{p}_{e}_not_pretrained_fasttext_cause_sum/best-model:latest', type='model')
            artifact_dir = artifact.download()
            rename_folder(artifact_dir, target_dir)



In [None]:
################### SPLIT SCENARIOS ####################
# Rebuild <game_path>/<eval mode>/<patient>/<train|val|test>/0 from the
# per-patient scenario folders.  Scenarios listing fewer than three subjects are
# first repaired IN PLACE in the source folder: their "subjects" become the keys
# of "question_answers".
for e in eval_modes:
    remove_folder(os.path.join(config["game_path"],e))
    for p in patients:
        dirs = ["train", "val", "test"]
        for d in dirs:
            source_dir = os.path.join(config["game_path"], "patients", p, e, d)
            for f in os.listdir(source_dir):
                json_file = os.path.join(source_dir, f)
                if not os.path.isfile(json_file):
                    continue
                # Finish reading and close the file before it is rewritten.
                with open(json_file, "r") as file:
                    scenario = json.load(file)
                if len(scenario["subjects"])<3:
                    scenario["subjects"] = list(scenario["question_answers"].keys())
                    with open(json_file, "w") as file:
                        json.dump(scenario, file, indent=4, sort_keys=True)
            # Split "0" lives in its own sub-folder of the destination.
            copy_folder(source_dir, os.path.join(config["game_path"],e,p,d,"0"))

In [None]:
################### SPLIT SCENARIOS ####################
# Rebuild <game_path>/<eval mode>/<patient>/<train|val|test> (no split
# sub-folder) from the per-patient scenario folders.  Scenarios listing fewer
# than three subjects are first repaired IN PLACE in the source folder: their
# "subjects" become the keys of "question_answers".
# Note: this cell starts by removing <game_path>/<eval mode> entirely.
for e in eval_modes:
    remove_folder(os.path.join(config["game_path"],e))
    for p in patients:
        dirs = ["train", "val", "test"]
        for d in dirs:
            source_dir = os.path.join(config["game_path"], "patients", p, e, d)
            for f in os.listdir(source_dir):
                json_file = os.path.join(source_dir, f)
                if not os.path.isfile(json_file):
                    continue
                # Finish reading and close the file before it is rewritten.
                with open(json_file, "r") as file:
                    scenario = json.load(file)
                if len(scenario["subjects"])<3:
                    scenario["subjects"] = list(scenario["question_answers"].keys())
                    with open(json_file, "w") as file:
                        json.dump(scenario, file, indent=4, sort_keys=True)
            copy_folder(source_dir, os.path.join(config["game_path"],e,p,d))

In [None]:
# Sequential sweep over every (eval mode, patient, inference, mode, split)
# combination, with split varying fastest.
sweep_grid = [
    (e, p, inference, mode, split)
    for e in eval_modes
    for p in patients
    for inference in inferences
    for mode in modes
    for split in splits
]
for e, p, inference, mode, split in sweep_grid:
    print(f"Patient: {p}, Eval mode: {e}, Inference: {inference}, Mode: {mode}, Split: {split}")
    evaluate_patient_eval_mode(p, e, inference, mode_=mode,split=split)


In [None]:
import concurrent.futures

# Parallel variant of the evaluation sweep.  It relies on eval_modes, patients,
# inferences, modes, splits and evaluate_patient_eval_mode being defined by
# earlier cells.
# NOTE(review): the tasks run in threads of this process; whether
# evaluate_patient_eval_mode is safe to run concurrently is not visible here.

def process_task(e, p, inference,mode, split):
    """Announce and run the evaluation of a single sweep configuration."""
    print(f"Patient: {p}, Eval mode: {e}, Inference: {inference}, Mode: {mode}, Split: {split}")
    evaluate_patient_eval_mode(p, e, inference, mode_=mode,split=split)

# One argument tuple per configuration, with split varying fastest.
tasks = [
    (e, p, inference, mode, split)
    for e in eval_modes
    for p in patients
    for inference in inferences
    for mode in modes
    for split in splits
]

# Two worker threads; raise max_workers only if the workload allows it.
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    # Remember which arguments each future belongs to, for reporting.
    future_to_task = {}
    for task in tasks:
        future_to_task[executor.submit(process_task, *task)] = task

    # Report each task as soon as it finishes, in completion order.
    for future in concurrent.futures.as_completed(future_to_task):
        task = future_to_task[future]
        try:
            result = future.result()  # re-raises whatever the task raised
        except Exception as exc:
            print(f"Task {task} generated an exception: {exc}")
        else:
            print(f"Task {task} completed successfully.")
