In [None]:
import IPython
from IPython.display import HTML, Markdown, display
import vertexai
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
app = IPython.Application.instance()
PROJECT_ID = "fit-union-449302-s5"
LOCATION = "us-central1"  # @param {type:"string"}

if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    raise ValueError("Please set your PROJECT_ID")

vertexai.init(project=PROJECT_ID, location=LOCATION)

from anthropic import AnthropicVertex
from google.auth import default, transport
import logging
import csv
from datetime import datetime
import openai
from vertexai.evaluation import (
    EvalTask,
    MetricPromptTemplateExamples,
    PairwiseMetric,
    PointwiseMetric,
    PointwiseMetricPromptTemplate,
)
from vertexai.generative_models import GenerativeModel
import vertexai

from vertexai.generative_models import GenerativeModel, Part, Image

import logging
import warnings

import pandas as pd
import sys
import os
import random
from tqdm import tqdm

from utils import evaluation

import info


logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

model = GenerativeModel(
    "gemini-1.5-pro-002",
    generation_config={
      # "temperature": 0.6, 
      "max_output_tokens": 256, 
      "top_k": 1},
)

In [20]:
def extract_matrix(text):
  first_index = text.find("```")
  final_index = text.rfind("```")
  # print(text[first_index+3:final_index])
  if first_index == -1 or final_index == -1:
    raise ValueError("Matrix not found in the response")
    return None
  matrix = eval(text[first_index+3:final_index])
  return matrix

def calculate_metrics(ground_truth_matrix, predicted_matrices):
    """
    Calculates average accuracy, precision, recall, F1 score, TPR, FPR, and SHD 
    across multiple predicted matrices compared to the ground truth matrix.

    Args:
        ground_truth_matrix (numpy.ndarray): Ground truth binary matrix.
        predicted_matrices (list of numpy.ndarray): List of predicted binary matrices.

    Returns:
        dict: A dictionary containing average metrics.
    """
    metrics = {
        "accuracy": [],
        "precision": [],
        "f1": [],
        "tpr": [],
        "fpr": [],
        "shd": []
    }

    ground_truth_flat = ground_truth_matrix.flatten()

    for predicted_matrix in predicted_matrices:
        predicted_matrix = np.array(predicted_matrix)
        predicted_flat = predicted_matrix.flatten()

        # Calculate confusion matrix components
        tn, fp, fn, tp = confusion_matrix(ground_truth_flat, predicted_flat, labels=[0, 1]).ravel()

        # Accuracy
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        metrics["accuracy"].append(accuracy)

        # Precision
        precision = precision_score(ground_truth_flat, predicted_flat)
        metrics["precision"].append(precision)

        # Recall (True Positive Rate - TPR)
        recall = recall_score(ground_truth_flat, predicted_flat)
        metrics["tpr"].append(recall)

        # F1 Score
        f1 = f1_score(ground_truth_flat, predicted_flat)
        metrics["f1"].append(f1)

        # False Positive Rate (FPR)
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
        metrics["fpr"].append(fpr)

        # Structural Hamming Distance (SHD)
        shd = np.sum(ground_truth_matrix != predicted_matrix)
        metrics["shd"].append(shd)

    # Average metrics
    average_metrics = {key: np.mean(value) for key, value in metrics.items()}
    return average_metrics


def Gemini_infer(setting,):
    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("Gemini_infer")
    strategy = setting["strategy"]
    # print(strategy)
    scene_name = setting.get("scene_name", "default")
    csv_file = f"./causal_Gemini_res/{scene_name}/{strategy}.csv"

    # Initialize model with flexible configuration
    model = GenerativeModel(
        setting.get("model_name", "default-model-name"),
        generation_config=setting.get("generation_config", {
          # "temperature": 1, 
          "max_output_tokens": 1000, 
          "top_k": 1}),
    )
    imgs = [Part.from_image(Image.load_from_file(path)) for path in setting["image_path"]]
    if setting["strategy"] == "few_shot":
        # print(setting.keys())
        prompt = imgs + [Part.from_text(
            setting["system_info"] +
            setting["few_shot_examples"] +
            setting["scene_info"] +
            setting["matrix"] +
            setting["matrix_info"]
        )]
        text_prompt = setting["system_info"] + setting["few_shot_examples"] + setting["scene_info"] + setting["matrix"] + setting["matrix_info"]
    else:
      prompt = imgs + [Part.from_text(
          setting["system_info"] +
          setting["scene_info"] +
          setting["matrix"] +
          setting["matrix_info"]
      )]
      text_prompt = setting["system_info"] + setting["scene_info"] + setting["matrix"] + setting["matrix_info"]
    # print(text_prompt)
    
    fleg = True
    while fleg:
      try:
        response = model.generate_content(prompt)
        fleg = False
      except Exception as e:
        # print(e)
        fleg = True

    # try:
    #     response = model.generate_content(prompt)
    # except Exception as e:
    #     logger.error("Model failed to generate content: %s", str(e))
    #     raise
    # print(response.text)
    # print()
    try:
        matrix = extract_matrix(response.text)
    except ValueError as e:
        logger.error("Matrix extraction failed: %s", str(e))
        

    # Log settings, prompt, response, and matrix
    # logger.info("Settings: %s", setting)
    # logger.info("Extracted Matrix: %s", matrix)

    # Get the current timestamp
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Write the results to a CSV file
    headers = ["time", "imgs", "text_prompt", "response", "matrix"]
    row = {
        "time": current_time,
        "imgs": str(setting["image_path"]),
        "text_prompt": str(text_prompt),
        "response": response.text,
        "matrix": str(matrix),
    }

    # Ensure the CSV file is created with headers if it doesn't exist
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)
    # print(csv_file)
    try:
        with open(csv_file, mode="a", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=headers)
            if file.tell() == 0:  # Write headers if the file is empty
                writer.writeheader()
            writer.writerow(row)
    except Exception as e:
        logger.error("Failed to write to CSV: %s", str(e))
        raise

    return response.text, matrix
  
def compose_content(dict_info):
  num_of_v = len(dict_info["variables"])
  variables = dict_info["variables"]
  content = ''
  for i,v in enumerate(variables):
    content += f"{i+1}. {variables[v]}\n"
  content = f"There are {num_of_v} variables: \n{content}.\n" 
  content += "Please fill this causality adjacency matrix:\n"
  return content

In [5]:
def get_few_shot_samples(all_scenes,scene_name):
    few_shot_samples=[]
    for key in all_scenes:
        if key!=scene_name:
            #if all_scenes[key] has attribute sample_result
            if "sample_result" in all_scenes[key].keys():

                few_shot_samples.append(all_scenes[key]["sample_result"])

    return few_shot_samples

def get_few_shot_prompt(few_shot_samples,num_samples=3):
    #randomly select num_samples from few_shot_samples, max= 4, or you can add samples into info.py
    few_shot_samples=random.sample(few_shot_samples,num_samples)

    #create prompt
    prompt=""
    i=1
    for sample in few_shot_samples:
        prompt+= f"Example {i}:"
        prompt+=sample
        prompt+="\n"
        i+=1

    prompt= prompt+ "Based on the examples above, answer the following questions:\n"
    return prompt

In [16]:
def get_causal_discovery(scene_name, folder, strategy):
  scene = info.scene()
  scene_info_dict = scene.get_scene(scene_name)

  img_path = folder

  files = os.listdir(img_path)
  files = [img_path + file for file in files]
  results = []
  few_shot_samples = None
  for i in tqdm(range(10)):
    imgs_path = random.sample(files, 10)
    matrix = scene_info_dict['adjacency_matrix']
    matrix = str(matrix).replace("1", "_,").replace("0", "_,").replace("_,]", '_]')
    matrix = matrix.replace("_]", "_],")
    matrix = "```\n" + matrix + "\n```"
    scene_info = compose_content(scene_info_dict)
    matrix_info = ".\nIn the matrix, matrix[i][j] = 1 means variable i causes variable j, matrix[i][j] = 0 means there is not direct causal relationship."
    if strategy == "explicit":
      system_info = "You are a causal discovery expert. Your objective is to analyze the provided images and identify any causal relationships between the variables. Use the identified relationships to complete the causality adjacency matrix and provide a brief explanation supporting your conclusions."
    elif strategy == "basic" or strategy == "few_shot":
      system_info = "Analyze the provided images and identify causal relationships between the variables. Complete the causality adjacency matrix based on the identified relationships and briefly explain your conclusions."
    elif strategy == "CoT":
      system_info = "Analyze the provided images and identify causal relationships between the variables. Let's think step by step and then complete the causality adjacency matrix based on the identified relationships. Based on your thoughts, give a brief explanation of the conclusions."
    
    if strategy == "few_shot":
      all_scenes = scene.get_all_scenes()
      few_shot_samples = get_few_shot_samples(all_scenes, scene_name)
      few_shot_samples = get_few_shot_prompt(few_shot_samples)
      # scene_info = scene_info + few_shot_samples
      # print(scene_info)
    setting = {
    "image_path" : imgs_path,
    "system_info" : system_info,
    "scene_info" : scene_info,
    "matrix" : matrix,
    "matrix_info" : matrix_info,
    "model_name" : "gemini-1.5-pro-002",
    "strategy" : strategy,
    "scene_name": scene_name,
    "few_shot_examples": few_shot_samples
    }

    text, matrix = Gemini_infer(setting)
    results.append(matrix)
    
  average_metrics = calculate_metrics(scene_info_dict['adjacency_matrix'], results)
  print(average_metrics)
  return average_metrics
    




### evaluation

In [56]:
# get_causal_discovery(scene_name = "Magnets", 
#                      folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_magnet_v3_256P/Real_magnet_v3/",
#                      strategy = "explicit")

get_causal_discovery(scene_name = "Magnets", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_magnet_v3_256P/Real_magnet_v3/",
                     strategy = "basic")

Magnets


100%|██████████| 10/10 [00:58<00:00,  5.87s/it]

{'accuracy': 0.875, 'precision': 1.0, 'f1': 0.5, 'tpr': 0.33333333333333337, 'fpr': 0.0, 'shd': 2.0}





{'accuracy': 0.875,
 'precision': 1.0,
 'f1': 0.5,
 'tpr': 0.33333333333333337,
 'fpr': 0.0,
 'shd': 2.0}

In [None]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_reflection_v2__256P/real_rendered_reflection_256P/"
# get_causal_discovery(scene_name = "Reflection", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "Reflection", 
                     folder = path,
                     strategy = "explicit")

In [129]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_spring_v3_256P/Real_spring_v3_256P/"
# get_causal_discovery(scene_name = "Spring", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "Spring", 
                     folder = path,
                     strategy = "explicit")

Spring


100%|██████████| 10/10 [01:10<00:00,  7.03s/it]

{'accuracy': 0.7555555555555555, 'precision': 0.45, 'f1': 0.45, 'tpr': 0.45, 'fpr': 0.1571428571428571, 'shd': 2.2}





{'accuracy': 0.7555555555555555,
 'precision': 0.45,
 'f1': 0.45,
 'tpr': 0.45,
 'fpr': 0.1571428571428571,
 'shd': 2.2}

In [123]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/"
# get_causal_discovery(scene_name = "Seesaw", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "Seesaw", 
                     folder = path,
                     strategy = "explicit")

Seesaw
['/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/3902.png', '/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/7785.png', '/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/4942.png', '/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/8945.png', '/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/6024.png', '/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/3580.png', '/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/1532.png', '/home/lds/github/Causality-informed-Generation/code1/databas

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:53<00:00,  5.39s/it]

{'accuracy': 0.7259259259259259, 'precision': 0.4, 'f1': 0.4, 'tpr': 0.4, 'fpr': 0.1809523809523809, 'shd': 2.466666666666667}





{'accuracy': 0.7259259259259259,
 'precision': 0.4,
 'f1': 0.4,
 'tpr': 0.4,
 'fpr': 0.1809523809523809,
 'shd': 2.466666666666667}

In [136]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_water_flow_v5_256/Water_flow_scene_render/"
get_causal_discovery(scene_name = "Waterflow", 
                     folder = path,
                     strategy = "basic")

get_causal_discovery(scene_name = "Waterflow", 
                     folder = path,
                     strategy = "explicit")

Waterflow


100%|██████████| 10/10 [01:56<00:00, 11.67s/it]


{'accuracy': 0.9, 'precision': 0.75, 'f1': 0.6499999999999999, 'tpr': 0.575, 'fpr': 0.03809523809523809, 'shd': 2.5}
Waterflow


100%|██████████| 10/10 [01:56<00:00, 11.65s/it]

{'accuracy': 0.8880000000000002, 'precision': 0.8166666666666668, 'f1': 0.5495238095238095, 'tpr': 0.425, 'fpr': 0.023809523809523808, 'shd': 2.8}





{'accuracy': 0.8880000000000002,
 'precision': 0.8166666666666668,
 'f1': 0.5495238095238095,
 'tpr': 0.425,
 'fpr': 0.023809523809523808,
 'shd': 2.8}

In [137]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_parabola_v4_512x256/generated_images/"
get_causal_discovery(scene_name = "Parabola", 
                     folder = path,
                     strategy = "basic")

get_causal_discovery(scene_name = "Parabola", 
                     folder = path,
                     strategy = "explicit")

Parabola


100%|██████████| 10/10 [01:37<00:00,  9.71s/it]


{'accuracy': 0.78125, 'precision': 0.5566666666666666, 'f1': 0.5738095238095238, 'tpr': 0.6, 'fpr': 0.15833333333333335, 'shd': 3.5}
Parabola


100%|██████████| 10/10 [01:34<00:00,  9.47s/it]

{'accuracy': 0.7625, 'precision': 0.5333333333333333, 'f1': 0.5, 'tpr': 0.475, 'fpr': 0.14166666666666666, 'shd': 3.8}





{'accuracy': 0.7625,
 'precision': 0.5333333333333333,
 'f1': 0.5,
 'tpr': 0.475,
 'fpr': 0.14166666666666666,
 'shd': 3.8}

In [139]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_pendulum_v5_256P/Real_pendulum/"
get_causal_discovery(scene_name = "Pendulum", 
                     folder = path,
                     strategy = "basic")

get_causal_discovery(scene_name = "Pendulum", 
                     folder = path,
                     strategy = "explicit")

Pendulum


100%|██████████| 10/10 [01:36<00:00,  9.68s/it]


{'accuracy': 0.76, 'precision': 0.5572727272727273, 'f1': 0.5164304812834224, 'tpr': 0.5, 'fpr': 0.15789473684210525, 'shd': 6.0}
Pendulum


100%|██████████| 10/10 [01:47<00:00, 10.73s/it]

{'accuracy': 0.752, 'precision': 0.5666666666666667, 'f1': 0.39222222222222214, 'tpr': 0.3166666666666667, 'fpr': 0.11052631578947367, 'shd': 6.2}





{'accuracy': 0.752,
 'precision': 0.5666666666666667,
 'f1': 0.39222222222222214,
 'tpr': 0.3166666666666667,
 'fpr': 0.11052631578947367,
 'shd': 6.2}

In [141]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_convex_len_v3_512x256/convex_len_render_images/"
# get_causal_discovery(scene_name = "Convex", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "Convex", 
                     folder = path,
                     strategy = "explicit")

Convex


100%|██████████| 10/10 [01:32<00:00,  9.28s/it]

{'accuracy': 0.7777777777777779, 'precision': 0.7666666666666666, 'f1': 0.6266666666666667, 'tpr': 0.5333333333333334, 'fpr': 0.09999999999999999, 'shd': 2.0}





{'accuracy': 0.7777777777777779,
 'precision': 0.7666666666666666,
 'f1': 0.6266666666666667,
 'tpr': 0.5333333333333334,
 'fpr': 0.09999999999999999,
 'shd': 2.0}

---
### hypothetic


In [4]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_linear/"
print("linear")
# get_causal_discovery(scene_name = "V2", 
#                      folder = path,
#                      strategy = "basic")

# get_causal_discovery(scene_name = "V2", 
#                      folder = path,
#                      strategy = "explicit")


print("Nonlinear")
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_nonlinear/"
# get_causal_discovery(scene_name = "V2", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "V2", 
                     folder = path,
                     strategy = "explicit")

linear
Nonlinear


NameError: name 'get_causal_discovery' is not defined

In [68]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypo_v3_v_structure_256/"
print("linear")
# get_causal_discovery(scene_name = "V3_V", 
#                      folder = path,
#                      strategy = "basic")

# get_causal_discovery(scene_name = "V3_V", 
#                      folder = path,
#                      strategy = "explicit")


print("Nonlinear")
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V3_nonlinear_vstructure/"
get_causal_discovery(scene_name = "V3_V", 
                     folder = path,
                     strategy = "basic")

# get_causal_discovery(scene_name = "V3_V", 
#                      folder = path,
#                      strategy = "explicit")

linear
Nonlinear
V3_V


100%|██████████| 10/10 [00:37<00:00,  3.79s/it]

{'accuracy': 0.7777777777777778, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 2.0}





{'accuracy': 0.7777777777777778,
 'precision': 0.0,
 'f1': 0.0,
 'tpr': 0.0,
 'fpr': 0.0,
 'shd': 2.0}

In [None]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_linear_v/"
print("linear")
# get_causal_discovery(scene_name = "V4_V", 
#                      folder = path,
#                      strategy = "basic")

# get_causal_discovery(scene_name = "V4_V", 
#                      folder = path,
#                      strategy = "explicit")


print("Nonlinear")
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_nonlinear_v/"
# get_causal_discovery(scene_name = "V4_V", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "V4_V", 
                     folder = path,
                     strategy = "explicit")

In [140]:
path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear/"
print("linear")
# get_causal_discovery(scene_name = "V5", 
#                      folder = path,
#                      strategy = "basic")

# get_causal_discovery(scene_name = "V5", 
#                      folder = path,
#                      strategy = "explicit")


print("Nonlinear")
path = "//home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V5_nonlinear/"
# get_causal_discovery(scene_name = "V5", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "V5", 
                     folder = path,
                     strategy = "explicit")

linear
Nonlinear
V5


100%|██████████| 10/10 [00:57<00:00,  5.73s/it]

{'accuracy': 0.8, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5.0}





{'accuracy': 0.8,
 'precision': 0.0,
 'f1': 0.0,
 'tpr': 0.0,
 'fpr': 0.0,
 'shd': 5.0}

---


In [None]:
# path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v3_fully_connected_linear/"
# print("linear")
# get_causal_discovery(scene_name = "V3_F", 
#                      folder = path,
#                      strategy = "basic")

# get_causal_discovery(scene_name = "V3_F", 
#                      folder = path,
#                      strategy = "explicit")


path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V4_linear_full_connected/"
print("linear")
# get_causal_discovery(scene_name = "V4_F", 
#                      folder = path,
#                      strategy = "basic")

# get_causal_discovery(scene_name = "V4_F", 
#                      folder = path,
#                      strategy = "explicit")



path = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear_full_connected/"
print("linear")
# get_causal_discovery(scene_name = "V5_F", 
#                      folder = path,
#                      strategy = "basic")

get_causal_discovery(scene_name = "V5_F", 
                     folder = path,
                     strategy = "explicit")

---

In [None]:
import pandas as pd
path = "/home/lds/github/Causality-informed-Generation/experiment_gemini_api/causal_Gemini_res/V5_F/basic.csv"
data = pd.read_csv(path)
matrix = data['matrix']
matrix = [eval(i) for i in matrix.tolist()]

g = np.array([
            # Rows and columns correspond to the indices in "variables"
          [0, 1, 1, 0, 1], [0, 0, 1, 0,0], [0, 0, 0, 1,1], 
          [0, 0, 0, 0, 1], [0, 0, 0, 0,0]
          ])
average_metrics = calculate_metrics(g, matrix)
average_metrics


---

### CoT

In [None]:
###
get_causal_discovery(scene_name = "Magnets", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_magnet_v3_256P/Real_magnet_v3/",
                     strategy = "CoT")

In [None]:
get_causal_discovery(scene_name = "Spring", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_spring_v3_256P/Real_spring_v3_256P/",
                     strategy = "CoT")

In [None]:
get_causal_discovery(scene_name = "Convex", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_convex_len_v3_512x256/convex_len_render_images/",
                     strategy = "CoT")

In [5]:
get_causal_discovery(scene_name = "Parabola", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_parabola_v4_512x256/generated_images/",
                     strategy = "CoT")

Parabola


100%|██████████| 10/10 [02:16<00:00, 13.61s/it]

{'accuracy': 0.86875, 'precision': 0.7, 'f1': 0.7158730158730158, 'tpr': 0.75, 'fpr': 0.09166666666666667, 'shd': 2.1}





{'accuracy': 0.86875,
 'precision': 0.7,
 'f1': 0.7158730158730158,
 'tpr': 0.75,
 'fpr': 0.09166666666666667,
 'shd': 2.1}

In [8]:
get_causal_discovery(scene_name = "Seesaw", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/",
                     strategy = "CoT")

Seesaw


100%|██████████| 10/10 [01:41<00:00, 10.13s/it]

{'accuracy': 1.0, 'precision': 1.0, 'f1': 1.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0.0}





{'accuracy': 1.0,
 'precision': 1.0,
 'f1': 1.0,
 'tpr': 1.0,
 'fpr': 0.0,
 'shd': 0.0}

In [15]:
get_causal_discovery(scene_name = "Pendulum", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_pendulum_v5_256P/Real_pendulum/",
                     strategy = "CoT")

Pendulum


 70%|███████   | 7/10 [01:22<00:36, 12.07s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


 90%|█████████ | 9/10 [01:46<00:12, 12.05s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]

{'accuracy': 0.9120000000000001, 'precision': 0.818095238095238, 'f1': 0.8163170163170163, 'tpr': 0.8166666666666667, 'fpr': 0.05789473684210526, 'shd': 2.2}





{'accuracy': 0.9120000000000001,
 'precision': 0.818095238095238,
 'f1': 0.8163170163170163,
 'tpr': 0.8166666666666667,
 'fpr': 0.05789473684210526,
 'shd': 2.2}

In [16]:
get_causal_discovery(scene_name = "Reflection", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_reflection_v2__256P/real_rendered_reflection_256P/",
                     strategy = "CoT")

Reflection


 20%|██        | 2/10 [00:14<00:57,  7.14s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


 30%|███       | 3/10 [00:21<00:49,  7.04s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again 

 40%|████      | 4/10 [00:33<00:53,  8.93s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


 50%|█████     | 5/10 [00:40<00:42,  8.51s/it]

500 Internal error encountered.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


 60%|██████    | 6/10 [00:54<00:40, 10.16s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again 

 70%|███████   | 7/10 [01:31<00:57, 19.07s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again 

 80%|████████  | 8/10 [01:41<00:32, 16.25s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


 90%|█████████ | 9/10 [02:00<00:17, 17.08s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


100%|██████████| 10/10 [02:09<00:00, 12.93s/it]

{'accuracy': 0.925, 'precision': 0.85, 'f1': 0.8999999999999998, 'tpr': 1.0, 'fpr': 0.1, 'shd': 0.3}





{'accuracy': 0.925,
 'precision': 0.85,
 'f1': 0.8999999999999998,
 'tpr': 1.0,
 'fpr': 0.1,
 'shd': 0.3}

In [19]:
get_causal_discovery(scene_name = "Waterflow", 
                     folder = "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_water_flow_v5_256/Water_flow_scene_render/",
                     strategy = "CoT")

Waterflow


  0%|          | 0/10 [00:00<?, ?it/s]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.
429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


 50%|█████     | 5/10 [01:17<01:16, 15.37s/it]

500 Internal error encountered.


 70%|███████   | 7/10 [01:48<00:45, 15.29s/it]

429 Online prediction request quota exceeded for gemini-1.5-pro. Please try again later with backoff.


100%|██████████| 10/10 [02:32<00:00, 15.26s/it]

{'accuracy': 0.96, 'precision': 0.95, 'f1': 0.8595238095238095, 'tpr': 0.8, 'fpr': 0.009523809523809523, 'shd': 1.0}





{'accuracy': 0.96,
 'precision': 0.95,
 'f1': 0.8595238095238095,
 'tpr': 0.8,
 'fpr': 0.009523809523809523,
 'shd': 1.0}

In [4]:
d = {
  # "V3_F":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v3_fully_connected_linear/",
  # 'V4_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V4_linear_full_connected/",
  # 'V5_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear_full_connected/",
  "V2_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_nonlinear/",
  "V3_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V3_nonlinear_vstructure/",
  "V4_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_nonlinear_v/",
  "V5_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V5_nonlinear/",
  "V2": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_linear/",
  "V3_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypo_v3_v_structure_256/",
  "V4_v": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_linear_v/",
  "V5": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear/",
  }

for scene in d:
  # print(scene)
  # if scene == "V2_nonlinear":
  #   folder = d[scene]
  #   scene_name = "V2"
  # else:
    
  get_causal_discovery(scene_name = scene, 
                     folder = d[scene],
                     strategy = "CoT")
  print("----")

berer: V2_nonlinear


100%|██████████| 10/10 [01:15<00:00,  7.57s/it]


{'accuracy': 0.75, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 1.0}
----
berer: V3_V_nonlinear


100%|██████████| 10/10 [02:55<00:00, 17.55s/it]


{'accuracy': 0.7555555555555555, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.02857142857142857, 'shd': 2.2}
----
berer: V4_V_nonlinear


100%|██████████| 10/10 [02:55<00:00, 17.57s/it]


{'accuracy': 0.825, 'precision': 0.2, 'f1': 0.1, 'tpr': 0.06666666666666667, 'fpr': 0.0, 'shd': 2.8}
----
berer: V5_nonlinear


100%|██████████| 10/10 [02:52<00:00, 17.28s/it]


{'accuracy': 0.8, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5.0}
----
berer: V2


100%|██████████| 10/10 [01:56<00:00, 11.68s/it]


{'accuracy': 0.75, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 1.0}
----
berer: V3_V


 30%|███       | 3/10 [00:30<01:11, 10.18s/it]


UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7af9199db9f0>

In [5]:
d = {
  # "V3_F":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v3_fully_connected_linear/",
  # 'V4_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V4_linear_full_connected/",
  # 'V5_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear_full_connected/",
  # "V2_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_nonlinear/",
  # "V3_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V3_nonlinear_vstructure/",
  # "V4_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_nonlinear_v/",
  # "V5_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V5_nonlinear/",
  # "V2": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_linear/",
  "V3_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypo_v3_v_structure_256/",
  "V4_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_linear_v/",
  "V5": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear/",
  }

for scene in d:
  # print(scene)
  # if scene == "V2_nonlinear":
  #   folder = d[scene]
  #   scene_name = "V2"
  # else:
    
  get_causal_discovery(scene_name = scene, 
                     folder = d[scene],
                     strategy = "CoT")
  print("----")

berer: V3_V


100%|██████████| 10/10 [01:30<00:00,  9.00s/it]

{'accuracy': 0.7777777777777778, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 2.0}
----
berer: V4_v





KeyError: 'V4_v'

In [6]:
d = {
  # "V3_F":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v3_fully_connected_linear/",
  # 'V4_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V4_linear_full_connected/",
  # 'V5_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear_full_connected/",
  # "V2_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_nonlinear/",
  # "V3_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V3_nonlinear_vstructure/",
  # "V4_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_nonlinear_v/",
  # "V5_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V5_nonlinear/",
  # "V2": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_linear/",
  # "V3_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypo_v3_v_structure_256/",
  "V4_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_linear_v/",
  "V5": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear/",
  }

for scene in d:
  # print(scene)
  # if scene == "V2_nonlinear":
  #   folder = d[scene]
  #   scene_name = "V2"
  # else:
    
  get_causal_discovery(scene_name = scene, 
                     folder = d[scene],
                     strategy = "CoT")
  print("----")

berer: V4_V


100%|██████████| 10/10 [01:35<00:00,  9.55s/it]


{'accuracy': 0.80625, 'precision': 0.05, 'f1': 0.04, 'tpr': 0.03333333333333333, 'fpr': 0.015384615384615385, 'shd': 3.1}
----
berer: V5


100%|██████████| 10/10 [02:05<00:00, 12.59s/it]

{'accuracy': 0.8, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5.0}
----





#### few_shot

In [21]:
d = {
  "V3_F":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v3_fully_connected_linear/",
  'V4_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V4_linear_full_connected/",
  'V5_F': "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear_full_connected/",
  "V2_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_nonlinear/",
  "V3_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V3_nonlinear_vstructure/",
  "V4_V_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_nonlinear_v/",
  "V5_nonlinear": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_V5_nonlinear/",
  "V2": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v2_linear/",
  "V3_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypo_v3_v_structure_256/",
  "V4_V": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v4_linear_v/",
  "V5": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/hypothetic/data/Hypothetic_v5_linear/",
  }

for scene in d:
  # print(scene)
  # if scene == "V2_nonlinear":
  #   folder = d[scene]
  #   scene_name = "V2"
  # else:
    
  get_causal_discovery(scene_name = scene, 
                     folder = d[scene],
                     strategy = "few_shot")
  print("----")

V3_F


100%|██████████| 10/10 [01:32<00:00,  9.28s/it]


{'accuracy': 0.6666666666666667, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 3.0}
----
V4_F


100%|██████████| 10/10 [01:24<00:00,  8.43s/it]


{'accuracy': 0.6875, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5.0}
----
V5_F


100%|██████████| 10/10 [01:57<00:00, 11.70s/it]


{'accuracy': 0.72, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 7.0}
----
V2_nonlinear


100%|██████████| 10/10 [02:47<00:00, 16.73s/it]


{'accuracy': 0.75, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 1.0}
----
V3_V_nonlinear


100%|██████████| 10/10 [02:56<00:00, 17.62s/it]


{'accuracy': 0.7777777777777778, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 2.0}
----
V4_V_nonlinear


100%|██████████| 10/10 [02:16<00:00, 13.63s/it]


{'accuracy': 0.8125, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 3.0}
----
V5_nonlinear


100%|██████████| 10/10 [01:35<00:00,  9.51s/it]


{'accuracy': 0.8, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5.0}
----
V2


100%|██████████| 10/10 [02:59<00:00, 17.92s/it]


{'accuracy': 0.75, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 1.0}
----
V3_V


100%|██████████| 10/10 [02:15<00:00, 13.56s/it]


{'accuracy': 0.7777777777777778, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 2.0}
----
V4_V


100%|██████████| 10/10 [02:12<00:00, 13.24s/it]


{'accuracy': 0.8125, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 3.0}
----
V5


100%|██████████| 10/10 [01:41<00:00, 10.17s/it]

{'accuracy': 0.8, 'precision': 0.0, 'f1': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5.0}
----





In [24]:
real = {
  "Reflection":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_reflection_v2__256P/real_rendered_reflection_256P/",
  "Spring":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_spring_v3_256P/Real_spring_v3_256P/",
  "Seesaw":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_seesaw_v3_256P/Real_seesaw_v3_256P/",
  "Convex":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_convex_len_v3_512x256/convex_len_render_images/",
  "Magnets":"/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_magnet_v3_256P/Real_magnet_v3/",
  "Parabola": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_parabola_v4_512x256/generated_images/",
  "Pendulum": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_pendulum_v5_256P/Real_pendulum/",
  "Waterflow": "/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_water_flow_v5_256/Water_flow_scene_render/",
}

for scene in real:
  # print(scene)
  # if scene == "V2_nonlinear":
  #   folder = d[scene]
  #   scene_name = "V2"
  # else:
    
  get_causal_discovery(scene_name = scene, 
                     folder = real[scene],
                     strategy = "few_shot")
  print("----")

Reflection


100%|██████████| 10/10 [01:55<00:00, 11.51s/it]


{'accuracy': 1.0, 'precision': 1.0, 'f1': 1.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0.0}
----
Spring


100%|██████████| 10/10 [02:24<00:00, 14.46s/it]


{'accuracy': 0.8444444444444443, 'precision': 0.6666666666666666, 'f1': 0.6799999999999999, 'tpr': 0.7, 'fpr': 0.11428571428571428, 'shd': 1.4}
----
Seesaw


100%|██████████| 10/10 [01:56<00:00, 11.67s/it]


{'accuracy': 1.0, 'precision': 1.0, 'f1': 1.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0.0}
----
Convex


100%|██████████| 10/10 [04:00<00:00, 24.04s/it]


{'accuracy': 0.9111111111111111, 'precision': 0.95, 'f1': 0.8514285714285714, 'tpr': 0.8, 'fpr': 0.03333333333333333, 'shd': 0.8}
----
Magnets


100%|██████████| 10/10 [02:41<00:00, 16.17s/it]


{'accuracy': 0.8375, 'precision': 0.7999999999999999, 'f1': 0.45, 'tpr': 0.33333333333333337, 'fpr': 0.046153846153846156, 'shd': 2.6}
----
Parabola


100%|██████████| 10/10 [02:30<00:00, 15.08s/it]


{'accuracy': 0.8, 'precision': 0.6333333333333334, 'f1': 0.5428571428571428, 'tpr': 0.475, 'fpr': 0.09166666666666667, 'shd': 3.2}
----
Pendulum


100%|██████████| 10/10 [02:40<00:00, 16.08s/it]


{'accuracy': 0.8320000000000001, 'precision': 0.7200000000000001, 'f1': 0.5575757575757576, 'tpr': 0.4666666666666666, 'fpr': 0.05263157894736842, 'shd': 4.2}
----
Waterflow


100%|██████████| 10/10 [02:50<00:00, 17.09s/it]

{'accuracy': 0.9559999999999998, 'precision': 1.0, 'f1': 0.838095238095238, 'tpr': 0.725, 'fpr': 0.0, 'shd': 1.1}
----



