In [1]:
import os
import sys

# Make the project's `src` directory importable. The relative path is resolved
# by os.path.abspath against the current working directory, so this cell
# assumes the notebook is launched from its own folder (two levels below the
# directory that contains `src`).
sys.path.append(os.path.abspath("../../src"))

# Import only the helpers this notebook uses from the packaged module.
# The names left commented out are not imported here:
# `generate_prompt_dataframe`, `create_domain_dict` and
# `verbalize_causal_mechanism` are redefined in later cells of this notebook;
# `append_dfs` is not used in the cells shown here.
from causalalign.dataset_creation import (
    expand_df_by_task_queries,
    # generate_prompt_dataframe,
    expand_domain_to_dataframe,
    graph_structures,
    inference_tasks_rw17,
    # create_domain_dict,
    verbalize_domain_intro,
    # verbalize_causal_mechanism,
    verbalize_inference_task,
    # append_dfs,
)

# Simple confirmation that the import above succeeded.
print("Dataset creation module imported successfully!")


Dataset creation module imported successfully!


In [2]:
def create_domain_dict(
    domain_name, introduction, variables_config, graph_type="collider"
):
    """
    Assemble the nested domain dictionary used by the prompt-building helpers.

    Parameters:
    -----------
    domain_name : str
        Name of the domain (e.g., "economy", "sociology").

    introduction : str
        Introductory text for the domain.

    variables_config : dict
        Maps each variable key (e.g. "X", "Y", "Z") to its configuration dict.
        Required keys per variable:
            "name"      - short variable name, e.g. "interest rates"
            "detailed"  - one-sentence description of the variable
            "p_value"   - dict with keys "1" and "0" giving the value labels
        Optional keys per variable:
            "m_value"      - dict with keys "1" and "0"; when omitted it is
                             derived from "p_value" with the "1"/"0" labels
                             swapped
            "explanations" - dict of explanation texts (e.g. keys "p_p",
                             "p_m", "m_p", "m_m")

    graph_type : str
        Type of causal graph (collider, fork, chain). Stored as given.

    Returns:
    --------
    dict
        {"domain_name", "introduction", "variables", "graph_type"}, where
        variables[var_key] contains "<var_key>_name", "<var_key>_detailed",
        "p_value", "m_value" and, when supplied, "explanations". The value
        dicts are shallow copies of the ones in `variables_config`.

    Raises:
    -------
    ValueError
        If a variable's configuration lacks "name", "detailed" or "p_value".
    """
    mandatory_fields = ("name", "detailed", "p_value")
    variables = {}

    for var_key, config in variables_config.items():
        # Report the first missing mandatory field, in declaration order.
        absent = [field for field in mandatory_fields if field not in config]
        if absent:
            raise ValueError(
                f"Missing required field '{absent[0]}' for variable {var_key}"
            )

        p_value = config["p_value"]
        entry = {
            f"{var_key}_name": config["name"],
            f"{var_key}_detailed": config["detailed"],
            "p_value": p_value.copy(),
        }

        # Explicit m_value wins; otherwise mirror p_value with 1/0 swapped.
        entry["m_value"] = (
            config["m_value"].copy()
            if "m_value" in config
            else {"1": p_value["0"], "0": p_value["1"]}
        )

        if "explanations" in config:
            entry["explanations"] = config["explanations"].copy()

        variables[var_key] = entry

    return {
        "domain_name": domain_name,
        "introduction": introduction,
        "variables": variables,
        "graph_type": graph_type,
    }

In [3]:
def verbalize_variables_section(domain_dict, row, entity_plural=None):
    """
    Describe every variable of the domain using its counterbalanced value labels.

    For each variable in ``domain_dict["variables"]`` the text contains the
    variable's detailed description followed by
    "Some <entity_plural> have <value for 1> <name>. Others have <value for 0> <name>. ".

    Parameters:
    -----------
    domain_dict : dict
        The domain dictionary; ``domain_dict["variables"][var_key]`` must hold
        "<var_key>_name", "<var_key>_detailed", "p_value" and "m_value".
    row : pd.Series (or any mapping)
        Must provide "<var_key>_cntbl" for every variable. The value "p"
        selects the variable's "p_value" labels; any other value selects
        "m_value".
    entity_plural : str, optional
        Plural noun for the things being described ("economies",
        "societies", ...). When omitted, ``domain_dict["entity_plural"]`` is
        used if present, otherwise "economies" (the wording this function has
        always produced).
        NOTE(review): if you store "entity_plural" in the domain dict, confirm
        that the other helpers consuming that dict tolerate the extra key.

    Returns:
    --------
    str
        A single leading space followed by one description per variable, in
        the iteration order of ``domain_dict["variables"]``. Returns " " when
        there are no variables.
    """
    if entity_plural is None:
        entity_plural = domain_dict.get("entity_plural", "economies")

    descriptions = []
    for var_key, var_details in domain_dict["variables"].items():
        name = var_details[f"{var_key}_name"]
        detailed = var_details[f"{var_key}_detailed"]

        # Pick the label set that matches this row's counterbalance condition.
        cntbl = row[f"{var_key}_cntbl"]
        value_dict = var_details["p_value"] if cntbl == "p" else var_details["m_value"]
        value_1 = value_dict["1"]
        value_0 = value_dict["0"]

        descriptions.append(
            f"{detailed} Some {entity_plural} have {value_1} {name}. "
            f"Others have {value_0} {name}. "
        )

    # The leading space separates this section from the preceding intro text.
    return " " + "".join(descriptions)

In [4]:
def _causal_link_sentence(row, cause_key, effect_key):
    """Return "<cause sense> <cause name> causes <effect sense> <effect name>." for one link."""
    cause_sense = row[f"{cause_key}_sense"]
    cause_name = row[cause_key]
    effect_sense = row[f"{effect_key}_sense"]
    effect_name = row[effect_key]
    return f"{cause_sense} {cause_name} causes {effect_sense} {effect_name}."


def _causal_link_explanation(domain_dict, row, cause_key, effect_key):
    """Return the cause's explanation for this row prefixed by a space, or "" if none is defined."""
    # Explanations are keyed "<cause counterbalance>_<effect counterbalance>", e.g. "p_m".
    explanation_key = f"{row[f'{cause_key}_cntbl']}_{row[f'{effect_key}_cntbl']}"
    explanations = domain_dict["variables"][cause_key].get("explanations", {})
    if explanation_key in explanations:
        return " " + explanations[explanation_key]
    return ""


def verbalize_causal_mechanism(domain_dict, row, graph_type, graph_structures):
    """
    Create a description of the causal relationships, with explanations when available.

    Only the "collider" graph (X -> Z <- Y) is verbalized. For it the text is
    the X -> Z sentence, X's explanation, the Y -> Z sentence and Y's
    explanation, where each explanation is looked up in
    ``domain_dict["variables"][<cause>]["explanations"]`` under the key
    "<cause cntbl>_<Z cntbl>" and omitted when absent.

    Parameters:
    -----------
    domain_dict : dict
        The domain dictionary (needs the "variables" entry for "X" and "Y").
    row : pd.Series (or any mapping)
        For the collider graph it must provide "X", "Y", "Z" (the displayed
        variable names), "X_sense", "Y_sense", "Z_sense" and "X_cntbl",
        "Y_cntbl", "Z_cntbl".
    graph_type : str
        Type of causal graph (collider, fork, chain).
    graph_structures : dict (or any container supporting ``in``)
        Graph structure templates; only used to check that `graph_type` is known.

    Returns:
    --------
    str
        "" when `graph_type` is not in `graph_structures`;
        " " (with a UserWarning) when `graph_type` is known but has no
        verbalization implemented here;
        otherwise the collider text, which starts with a single space.

    Notes:
    ------
    The exact whitespace of the collider text (including the double space in
    front of X's explanation) is kept as it has always been generated, so that
    previously generated prompts remain reproducible byte for byte.
    """
    if graph_type not in graph_structures:
        return ""

    if graph_type != "collider":
        # Known graph type without a verbalization yet: keep returning the
        # blank section, but make the gap visible instead of silently
        # producing prompts that lack any causal description.
        import warnings

        warnings.warn(
            f"verbalize_causal_mechanism has no verbalization for graph type "
            f"{graph_type!r}; the causal-relationship section will be empty.",
            UserWarning,
            stacklevel=2,
        )
        return " "

    x_z_relation = _causal_link_sentence(row, "X", "Z")
    x_z_explanation = _causal_link_explanation(domain_dict, row, "X", "Z")
    y_z_relation = _causal_link_sentence(row, "Y", "Z")
    y_z_explanation = _causal_link_explanation(domain_dict, row, "Y", "Z")

    # Leading space separates this section from the preceding text. The extra
    # space after the X -> Z sentence is intentional (see Notes).
    return (
        " "
        + f"{x_z_relation} {x_z_explanation} "
        + f"{y_z_relation}{y_z_explanation} "
    )

In [5]:
def generate_prompt_dataframe(
    domain_dict,
    inference_tasks,
    graph_type,
    graph_structures,
    prompt_type="Please provide only a numeric response and no additional information",
    prompt_category="single_numeric_response",
    counterbalance_enabled=True,
):
    """
    Build the prompt table for one domain: one row per expanded task query,
    with the fully verbalized prompt attached.

    The frame is produced by `expand_domain_to_dataframe(domain_dict)` followed
    by `expand_df_by_task_queries(..., inference_tasks)`. Three columns are then
    added:
        "graph"           - `graph_type`, identical for every row
        "prompt"          - domain intro + variables section + causal
                            mechanism + inference task, concatenated in that
                            order for each row
        "prompt_category" - `prompt_category`, identical for every row

    Parameters:
    -----------
    domain_dict : dict
        Domain dictionary (see `create_domain_dict`).
    inference_tasks
        Task specification forwarded to `expand_df_by_task_queries`.
    graph_type : str
        Graph type recorded in the "graph" column and used for the causal text.
    graph_structures : dict
        Graph structure templates forwarded to `verbalize_causal_mechanism`.
    prompt_type : str
        Forwarded to `verbalize_inference_task`.
    prompt_category : str
        Label stored in the "prompt_category" column.
    counterbalance_enabled : bool
        Accepted for interface compatibility; not referenced by this function.

    Returns:
    --------
    The expanded DataFrame with the three columns described above.
    """
    expanded_domain = expand_domain_to_dataframe(domain_dict)
    df = expand_df_by_task_queries(expanded_domain, inference_tasks)

    # The introduction is row-independent, so compute it once.
    domain_intro = verbalize_domain_intro(domain_dict)

    df["graph"] = graph_type

    def _compose_prompt(row):
        """Concatenate the four prompt sections for a single row."""
        variables_part = verbalize_variables_section(domain_dict, row)
        mechanism_part = verbalize_causal_mechanism(
            domain_dict, row, graph_type, graph_structures
        )
        task_part = verbalize_inference_task(
            row, nested_dict=domain_dict, prompt_type=prompt_type
        )
        return domain_intro + variables_part + mechanism_part + task_part

    df["prompt"] = df.apply(_compose_prompt, axis=1)
    df["prompt_category"] = prompt_category
    return df

In [6]:
# Configuration of the three economy-domain variables passed to
# `create_domain_dict`.
# - "p_value" / "m_value" give the labels for the variable's "1" and "0"
#   states under the "p" and "m" counterbalance conditions respectively.
# - "explanations" are keyed "<this variable's cntbl>_<Z's cntbl>"; that is the
#   key `verbalize_causal_mechanism` builds when it looks up the text for the
#   X -> Z and Y -> Z links of the collider graph.
economy_config = {
    "X": {
        "name": "interest rates",
        "detailed": "Interest rates are the rates banks charge to loan money.",
        "p_value": {"1": "low", "0": "normal"},
        "m_value": {"1": "high", "0": "normal"},
        "explanations": {
            "p_p": "Low interest rates stimulate economic growth, leading to greater prosperity overall, and allowing more money to be saved for retirement in particular.",
            "p_m": "The good economic times produced by the low interest rates leads to greater confidence and less worry about the future, so people are less concerned about retirement.",
            "m_p": "The high interest rates result in high yields on government bonds, which are especially attractive for retirement savings because they are such a safe investment.",
            "m_m": "A lot of people are making large monthly interest payments on credit card debt, and they have no money left to save for retirement.",
        },
    },
    "Y": {
        "name": "trade deficits",
        "detailed": "A country's trade deficit is the difference between the value of the goods that a country imports and the value of the goods that a country exports.",
        "p_value": {"1": "small", "0": "normal"},
        "m_value": {"1": "large", "0": "normal"},
        "explanations": {
            "p_p": "When the economy is good, people can cover their basic expenses and so have enough money left over to contribute to their retirement accounts.",
            "p_m": "When the economy is good, people are optimistic and so spend rather than save.",
            "m_p": "People become nervous when their economy is no longer competitive enough in the world economy to export products, and begin saving for retirement as a result.",
            "m_m": "The loss of local manufacturing jobs means that there are people out of work, and contributions to retirement accounts decreases.",
        },
    },
    # Z is the common effect in the collider graph; only the explanations of
    # X and Y are read when the causal text is built, so Z defines none.
    "Z": {
        "name": "retirement savings",
        "detailed": "Retirement savings is the money people save for their retirement.",
        "p_value": {"1": "high", "0": "normal"},
        "m_value": {"1": "low", "0": "normal"},
    },
}

# Build the domain dictionary from the configuration above.
economy_domain = create_domain_dict(
    domain_name="economy",
    introduction="Economists seek to describe and predict the regular patterns of economic fluctuation. To do this, they study some important variables or attributes of economies. They also study how these attributes are responsible for producing or causing one another.",
    variables_config=economy_config,
    graph_type="collider",
)

# Generate the prompt table for the collider graph using the imported
# `inference_tasks_rw17` task set and the default prompt type/category.
economy_prompts_df = generate_prompt_dataframe(
    economy_domain, inference_tasks_rw17, "collider", graph_structures
)

In [8]:
# Sanity check: (number of rows, number of columns) of the generated prompt table.
economy_prompts_df.shape

(160, 24)

In [9]:
# Count the rows in each counterbalance condition (`cntbl_cond`), ordered by
# condition label, to check that the conditions are evenly represented.
economy_prompts_df[["cntbl_cond"]].value_counts().sort_index()

cntbl_cond
mmm           20
mmp           20
mpm           20
mpp           20
pmm           20
pmp           20
ppm           20
ppp           20
Name: count, dtype: int64

In [19]:
# Display the complete prompt text of the first row whose cntbl_cond is "ppp".
economy_prompts_df.loc[economy_prompts_df["cntbl_cond"] == "ppp", "prompt"].iloc[0]


"Economists seek to describe and predict the regular patterns of economic fluctuation. To do this, they study some important variables or attributes of economies. They also study how these attributes are responsible for producing or causing one another. Interest rates are the rates banks charge to loan money. Some economies have low interest rates. Others have normal interest rates. A country's trade deficit is the difference between the value of the goods that a country imports and the value of the goods that a country exports. Some economies have small trade deficits. Others have normal trade deficits. Retirement savings is the money people save for their retirement. Some economies have high retirement savings. Others have normal retirement savings.  low interest rates causes high retirement savings.  Low interest rates stimulate economic growth, leading to greater prosperity overall, and allowing more money to be saved for retirement in particular. small trade deficits causes high r