In [1]:
import sys
import os
import random
from tqdm import tqdm
sys.path.append('/home/lds/github/Causality-informed-Generation/inference/evaluation')

from utils import info
from utils import evaluation

# Google Cloud project and region handed to the Vertex AI SDK further down.
PROJECT_ID = "mimetic-kit-445917-d8"
LOCATION = "us-central1"  # @param {type:"string"}

# Fail fast when the project id is empty or still the template placeholder.
if PROJECT_ID == "[your-project-id]" or not PROJECT_ID:
    raise ValueError("Please set your PROJECT_ID")

 
import logging
import csv
from datetime import datetime
import vertexai

# Point the Vertex AI SDK at the project and location configured above.
vertexai.init(project=PROJECT_ID, location=LOCATION)

from anthropic import AnthropicVertex
from google.auth import default, transport
import openai
from vertexai.evaluation import (
    EvalTask,
    MetricPromptTemplateExamples,
    PairwiseMetric,
    PointwiseMetric,
    PointwiseMetricPromptTemplate,
)
from vertexai.generative_models import GenerativeModel
import vertexai

from vertexai.generative_models import GenerativeModel, Part, Image

# Module-level Gemini model handle.
# NOTE(review): none of the functions visible in this file use this instance;
# Gemini_infer constructs its own model from its `setting` argument.
model = GenerativeModel(
    "gemini-1.5-pro-002",
    generation_config={
        "temperature": 1,
        "max_output_tokens": 500,
        "top_k": 1,
    },
)

In [2]:
def extract_matrix(text):
    """Parse the matrix literal found in the first fenced block of ``text``.

    The block is delimited by triple backticks. An optional language tag on
    the opening fence line (for example ``python``) is ignored. The fenced
    content is echoed to stdout and then parsed as a Python literal.

    Args:
        text: Raw model response expected to contain a fenced matrix literal.

    Returns:
        The parsed Python object (for a matrix, a nested list).

    Raises:
        ValueError: If no complete fenced block exists, or its content is not
            a valid Python literal.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    fence = "```"
    start = text.find(fence)
    # Look for the closing fence strictly after the opening one, so a lone
    # fence is reported as "not found" instead of yielding an empty slice.
    end = text.find(fence, start + len(fence)) if start != -1 else -1
    if start == -1 or end == -1:
        raise ValueError("Matrix not found in the response")

    body = text[start + len(fence):end]

    # Drop a language tag such as "python" that shares the opening fence line.
    tag, newline, remainder = body.partition("\n")
    if newline and tag.strip().isalpha():
        body = remainder

    print(body)

    # SECURITY: the text comes from a model response and must be treated as
    # untrusted; literal_eval only accepts literals, unlike eval().
    try:
        return ast.literal_eval(body.strip())
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError("Matrix in the response could not be parsed") from err


def Gemini_infer(setting):
    """Query a Gemini model with images plus text and log the result to CSV.

    Args:
        setting: Dict describing one request. Keys read here:
            ``image_path`` (iterable of image file paths), ``system_info``,
            ``scene_info``, ``matrix``, ``matrix_info`` (strings concatenated
            in that order to form the text prompt), ``scene_name`` and
            ``strategy`` (used in the log file name), and optionally
            ``model_name`` and ``generation_config``.

    Returns:
        A ``(response_text, matrix)`` tuple, where ``matrix`` is whatever
        ``extract_matrix`` parsed out of the response text.

    Raises:
        KeyError: If a required key is missing from ``setting``.
        Exception: Any error from the model call, matrix extraction or CSV
            write is logged and then re-raised unchanged.

    Side effects:
        Appends one row to
        ``gemini_inference_log_<scene_name>_<strategy>.csv`` (a relative
        path), writing the header first when the file is empty.
    """
    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("Gemini_infer")

    # Initialize model with flexible configuration
    model = GenerativeModel(
        setting.get("model_name", "default-model-name"),
        generation_config=setting.get(
            "generation_config",
            {"temperature": 0.1, "max_output_tokens": 1000, "top_k": 1},
        ),
    )
    csv_file = f"gemini_inference_log_{setting['scene_name']}_{setting['strategy']}.csv"

    # Load images
    imgs = [Part.from_image(Image.load_from_file(path)) for path in setting["image_path"]]

    # Build the text once so the prompt sent to the model and the prompt
    # written to the log cannot drift apart.
    text_prompt = (
        setting["system_info"]
        + setting["scene_info"]
        + setting["matrix"]
        + setting["matrix_info"]
    )
    prompt = imgs + [Part.from_text(text_prompt)]

    # Generate response
    try:
        response = model.generate_content(prompt)
    except Exception as e:
        logger.error("Model failed to generate content: %s", str(e))
        raise

    # Extract matrix. The response text is read once inside the try so that a
    # failure while reading it is logged together with extraction failures.
    # SyntaxError is included because a malformed matrix literal can surface
    # as a parser error rather than a ValueError.
    try:
        response_text = response.text
        matrix = extract_matrix(response_text)
    except (ValueError, SyntaxError) as e:
        logger.error("Matrix extraction failed: %s", str(e))
        raise

    # Get the current timestamp
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Write the results to a CSV file
    headers = ["time", "strategy", "imgs", "text_prompt", "response", "matrix"]
    row = {
        "time": current_time,
        "strategy": setting["strategy"],
        "imgs": str(setting["image_path"]),
        "text_prompt": str(text_prompt),
        "response": response_text,
        "matrix": str(matrix),
    }

    # Append mode creates the file when missing. An explicit UTF-8 encoding
    # keeps non-ASCII response text from depending on the platform locale.
    try:
        with open(csv_file, mode="a", newline="", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=headers)
            if file.tell() == 0:  # Write headers if the file is empty
                writer.writeheader()
            writer.writerow(row)
    except Exception as e:
        logger.error("Failed to write to CSV: %s", str(e))
        raise

    return response_text, matrix
  
def compose_content(dict_info):
    """Build the prompt fragment that lists the scene variables.

    Args:
        dict_info: Mapping with a ``"variables"`` entry. That entry is
            iterated, and each iterated key is looked up in it to obtain the
            text shown for the variable.

    Returns:
        A string announcing the number of variables, followed by a numbered
        list (one per line, starting at 1) and a request to fill in the
        causality adjacency matrix.
    """
    variables = dict_info["variables"]
    numbered_lines = "".join(
        f"{position}. {variables[key]}\n"
        for position, key in enumerate(variables, start=1)
    )
    listing = f"There are {len(variables)} variables: \n{numbered_lines}.\n"
    return listing + "Please fill this causality adjacency matrix:\n"


import numpy as np
def cal_TPR_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the True Positive Rate between two adjacency matrices.

    Args:
        ground_truth_matrix1: Reference 0/1 adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted 0/1 adjacency matrix of the same shape.

    Returns:
        ``TP / (TP + FN)``, or ``0.0`` when the ground truth contains no
        positive entry that the prediction marks as 0 or 1.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    actual_edges = truth == 1
    TP = np.sum(actual_edges & (predicted == 1))
    FN = np.sum(actual_edges & (predicted == 0))

    # Avoid 0/0 (NaN plus a RuntimeWarning) when there is nothing to recall.
    if (TP + FN) == 0:
        return 0.0
    return TP / (TP + FN)
  
def cal_FPR_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the False Positive Rate between two adjacency matrices.

    Args:
        ground_truth_matrix1: Reference 0/1 adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted 0/1 adjacency matrix of the same shape.

    Returns:
        ``FP / (FP + TN)``, or ``0.0`` when the ground truth contains no
        zero entry that the prediction marks as 0 or 1.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    absent_edges = truth == 0
    FP = np.sum(absent_edges & (predicted == 1))
    TN = np.sum(absent_edges & (predicted == 0))

    # Avoid 0/0 (NaN plus a RuntimeWarning) when there are no negatives.
    if (FP + TN) == 0:
        return 0.0
    return FP / (FP + TN)
  
def cal_SHD_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the Structural Hamming Distance between two adjacency matrices.

    The distance computed here is the number of positions at which the two
    matrices differ.

    Args:
        ground_truth_matrix1: Reference adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted adjacency matrix of the same shape.

    Returns:
        The count of differing entries.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays; comparing raw lists with `!=`
    # would produce a single boolean instead of an element-wise result.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    return np.sum(truth != predicted)
  
def cal_Accuarcy_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the accuracy between two adjacency matrices.

    Accuracy is the fraction of positions at which the two matrices hold the
    same value.

    Args:
        ground_truth_matrix1: Reference adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted adjacency matrix of the same shape.

    Returns:
        The fraction of matching entries, or ``0.0`` for empty matrices.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    # An empty matrix would otherwise evaluate 0/0.
    if truth.size == 0:
        return 0.0
    return np.sum(truth == predicted) / truth.size
  
def cal_Precision_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the precision between two adjacency matrices.

    Args:
        ground_truth_matrix1: Reference 0/1 adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted 0/1 adjacency matrix of the same shape.

    Returns:
        ``TP / (TP + FP)``, or ``0.0`` when the prediction contains no
        positive entry.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    predicted_edges = predicted == 1
    TP = np.sum((truth == 1) & predicted_edges)
    FP = np.sum((truth == 0) & predicted_edges)

    # No predicted positives: precision is undefined, reported as zero.
    if (TP + FP) == 0:
        return 0.0
    return TP / (TP + FP)
  
def cal_Recall_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the recall between two adjacency matrices.

    Args:
        ground_truth_matrix1: Reference 0/1 adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted 0/1 adjacency matrix of the same shape.

    Returns:
        ``TP / (TP + FN)``, or ``0.0`` when the ground truth contains no
        positive entry that the prediction marks as 0 or 1.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    actual_edges = truth == 1
    TP = np.sum(actual_edges & (predicted == 1))
    FN = np.sum(actual_edges & (predicted == 0))

    # Guard the 0/0 case the same way the precision metric does.
    if (TP + FN) == 0:
        return 0.0
    return TP / (TP + FN)
  
def cal_F1_between_matrix(ground_truth_matrix1, matrix2):
    """Calculate the F1 score between two adjacency matrices.

    Args:
        ground_truth_matrix1: Reference 0/1 adjacency matrix (NumPy array or
            nested list).
        matrix2: Predicted 0/1 adjacency matrix of the same shape.

    Returns:
        The harmonic mean of precision and recall, or ``0.0`` when precision
        or recall is undefined or both are zero.

    Raises:
        ValueError: If the two matrices differ in shape.
    """
    # Accept nested lists as well as arrays.
    truth = np.asarray(ground_truth_matrix1)
    predicted = np.asarray(matrix2)
    if truth.shape != predicted.shape:
        raise ValueError("Matrices must have the same shape")

    TP = np.sum((truth == 1) & (predicted == 1))
    FP = np.sum((truth == 0) & (predicted == 1))
    FN = np.sum((truth == 1) & (predicted == 0))

    # Precision or recall undefined: report zero.
    if (TP + FP) == 0 or (TP + FN) == 0:
        return 0.0
    Precision = TP / (TP + FP)
    Recall = TP / (TP + FN)

    # Check the denominator BEFORE dividing, so the zero case never
    # evaluates 0/0.
    if (Precision + Recall) == 0:
        return 0.0
    return 2 * Precision * Recall / (Precision + Recall)
  

In [3]:
def inference(image_path="/home/lds/github/Causality-informed-Generation/code1/database/Real_magnet_v3/",
              scene_name="Magnets",
              strategy="basic",
              num_rounds=10,
              images_per_round=10):
    """Repeatedly query Gemini with random image subsets of one scene.

    Each round draws ``images_per_round`` distinct entries from
    ``image_path``, builds the prompt for ``scene_name`` and sends it through
    ``Gemini_infer``. Results are not returned.

    Args:
        image_path: Directory whose entries are used as candidate images.
        scene_name: Scene key passed to ``info.scene().get_scene``.
        strategy: Prompt wording to use, either ``"explicit"`` or ``"basic"``.
        num_rounds: Number of model queries to issue (default 10).
        images_per_round: Number of images sampled per query (default 10).

    Raises:
        ValueError: If ``strategy`` is not a known strategy, or the directory
            holds fewer than ``images_per_round`` entries (raised by
            ``random.sample``).
    """
    system_prompts = {
        "explicit": "You are a causal discovery expert. Your objective is to analyze the provided images and identify any causal relationships between the variables. Use the identified relationships to complete the causality adjacency matrix and provide a brief explanation supporting your conclusions.",
        "basic": "Analyze the provided images and identify causal relationships between the variables. Complete the causality adjacency matrix based on the identified relationships and briefly explain your conclusions.",
    }
    # Reject unknown strategies up front; otherwise the system prompt would
    # stay unbound and the failure would only surface as an
    # UnboundLocalError inside the loop.
    if strategy not in system_prompts:
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of {sorted(system_prompts)}"
        )
    system_info = system_prompts[strategy]

    scene = info.scene()
    scene_info_dict = scene.get_scene(scene_name)

    # os.path.join also works when image_path has no trailing separator.
    # Every directory entry is a candidate, whatever its type.
    files = [os.path.join(image_path, name) for name in os.listdir(image_path)]
    scene_info = compose_content(scene_info_dict)

    matrix_info = ".\nMatrix[i][j] = 1: Variable i directly causes variable j. Matrix[i][j] = 0: There is no direct causal relationship between variable i and variable j."

    # Blank out every 0/1 of the ground-truth matrix so the model sees a
    # template to fill in. The template is identical for every round.
    # NOTE(review): the replacements yield a clean "[_, _],"-style template
    # only if str(adjacency_matrix) has no comma separators (e.g. a NumPy
    # array); confirm the type returned by the scene info.
    matrix_template = str(scene_info_dict['adjacency_matrix'])
    matrix_template = matrix_template.replace("1", "_,").replace("0", "_,").replace("_,]", '_]')
    matrix_template = matrix_template.replace("_]", "_],")

    for _ in tqdm(range(num_rounds)):
        imgs_path = random.sample(files, images_per_round)
        setting = {
            "image_path": imgs_path,
            "system_info": system_info,
            "scene_info": scene_info,
            "matrix": matrix_template,
            "matrix_info": matrix_info,
            "model_name": "gemini-1.5-pro-002",
            "scene_name": scene_name,
            "strategy": strategy,
        }
        # The response and parsed matrix are discarded here.
        Gemini_infer(setting)


# 
# Prompting strategies to evaluate; inference() is called once per entry.
stragegies = ["explicit"]

for strategy in stragegies:
    inference(
        image_path="/home/lds/github/Causality-informed-Generation/code1/database/final_dataset/real/Real_magnet_v3_256P/Real_magnet_v3/",
        strategy=strategy,
    )



  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:08<01:12,  8.07s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 20%|██        | 2/10 [00:15<01:03,  7.97s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 30%|███       | 3/10 [00:23<00:53,  7.60s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 40%|████      | 4/10 [00:30<00:45,  7.54s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 50%|█████     | 5/10 [00:39<00:40,  8.19s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 60%|██████    | 6/10 [00:47<00:32,  8.07s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 70%|███████   | 7/10 [00:54<00:22,  7.56s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 80%|████████  | 8/10 [01:00<00:14,  7.17s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



 90%|█████████ | 9/10 [01:08<00:07,  7.48s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]



100%|██████████| 10/10 [01:16<00:00,  7.62s/it]


[[0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]




