In [None]:
# Enter user inputs for data directory and API key

import ipywidgets as widgets
from IPython.display import display

# Directory containing the *.json documents read by the "Load data" cell below.
data_widget = widgets.Text(
    placeholder='Type in the data directory here',
    description='Location:',
    disabled=False,
)

# Password widget: masks the key on screen (rendered notebook, screenshots, screen-shares).
# The typed key is still read through `.value`, exactly as with a Text widget.
api_widget = widgets.Password(
    placeholder='Type in your API key here',
    description='API Key:',
    disabled=False,
)

display(data_widget, api_widget)

In [None]:
# Load data

import os
import json
from typing import Generator

# Normalise the typed path: trim stray whitespace from copy/paste and expand a leading "~".
data_dir = os.path.expanduser(data_widget.value.strip())
# Fail here with a clear message. Otherwise the error only appears once the lazy
# generator below is first iterated, far away in the agent loop.
if not os.path.isdir(data_dir):
    raise FileNotFoundError(f"Data directory not found: {data_dir!r} - set it in the 'Location:' widget")

def load_data(data_dir: str) -> Generator[str, None, None]:
    """Yield one formatted text document per ``*.json`` file in ``data_dir``.

    Files are visited in sorted filename order, so the sequence of documents is
    reproducible across machines. ``os.listdir`` itself returns entries in
    arbitrary order.

    Each JSON file must contain ``title.eng`` and ``summary`` fields. The
    yielded string has the form ``"Title: <title>\\n - \\nContent:\\n<summary>"``.

    Args:
        data_dir: Directory to scan (non-recursively).

    Yields:
        The formatted title/summary text of each JSON document.

    Raises:
        KeyError: if a JSON file lacks ``title.eng`` or ``summary``. The
            message names the offending file.
        json.JSONDecodeError: if a ``.json`` file is not valid JSON.
    """
    for filename in sorted(os.listdir(data_dir)):
        path = os.path.join(data_dir, filename)
        # Skip non-JSON entries, and any directory whose name happens to end in ".json".
        if not filename.endswith('.json') or not os.path.isfile(path):
            continue
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) could
        # garble or reject non-ASCII text. Close the file before yielding, so no handle
        # stays open while the consumer works on the document.
        with open(path, encoding='utf-8') as f:
            record = json.load(f)
        try:
            title = record['title']['eng']
            summary = record['summary']
        except KeyError as exc:
            raise KeyError(f"{path}: missing expected field {exc}") from exc
        yield f"Title: {title}\n - \nContent:\n{summary}"

# Materialise the documents as a list. A generator is exhausted after one pass, so
# re-running the discovery loop cell would otherwise silently process nothing.
# This also surfaces missing or malformed files here, not midway through the agent run.
docs = list(load_data(data_dir))
print(f"Loaded {len(docs)} document(s) from {data_dir}")


In [None]:
# Create LLM engine

from smolagents import LiteLLMModel
from typing import Dict

# Model identifier passed through LiteLLM; change it here to try a different model.
MODEL_ID = "o3-mini-2025-01-31"

# Strip surrounding whitespace. Keys pasted into the widget often carry a trailing
# space or newline, which makes every request fail authentication.
base_model = LiteLLMModel(model_id=MODEL_ID, api_key=api_widget.value.strip())

# Type alias for a message mapping string keys to string values.
# NOTE(review): not referenced elsewhere in this notebook - confirm it is still needed.
Message = Dict[str, str]

In [None]:
# Import agent

# Alternative discovery agents - to switch, uncomment one import/constructor pair:
# from causal_world_modelling_agent.agents.causal_discovery.atomic_discovery_agent import AtomicDiscoveryAgentFactory
# discovery_manager = AtomicDiscoveryAgentFactory().createAgent(base_model)
# from causal_world_modelling_agent.agents.causal_discovery.self_iterative_agent import SelfIterativeDiscoveryAgentFactory
# discovery_manager = SelfIterativeDiscoveryAgentFactory(num_iterations=20, graph_save_path="../data/results/causal_graph_2025-02-26_23-07-34.gml").createAgent(base_model)

# Active choice: the RAG-based atomic discovery agent. Later cells rely on it
# exposing a 'graph_retriever' tool (update_graph / get_graph).
from causal_world_modelling_agent.agents.causal_discovery.atomic_rag_agent import AtomicRAGDiscoveryAgentFactory

rag_factory = AtomicRAGDiscoveryAgentFactory()
discovery_manager = rag_factory.createAgent(base_model)

In [None]:
# Show the discovery agent's system prompt. Printed raw so its newlines render as line breaks.
prompt_text = discovery_manager.system_prompt
print(prompt_text)

In [None]:
# Run causal discovery on each document, and merge each resulting graph into the
# agent's 'graph_retriever' tool (the RAG agent's graph store).
n_processed = 0
for doc in docs:
    causal_graph = discovery_manager.run(doc)
    discovery_manager.tools['graph_retriever'].update_graph(causal_graph)
    n_processed += 1

print(f"Processed {n_processed} document(s)")
if n_processed == 0:
    # Either the directory had no *.json files, or `docs` is an already-consumed generator.
    print("Warning: no documents were processed - check the data directory, "
          "or re-run the 'Load data' cell (a generator of docs can only be iterated once).")

In [None]:
# Plot causal graph

import os
import networkx as nx
import matplotlib.pyplot as plt
import datetime

LAYOUT_SEED = 42                 # seeds community detection and all layouts, so reruns give the same picture
RESULTS_DIR = "../data/results"  # where the .gml graph and .png figure are written

# Extract the accumulated causal graph from the retriever tool
causal_graph = discovery_manager.tools['graph_retriever'].get_graph()

# Detect communities on the undirected view (Louvain is randomised, so seed it)
partitions = list(nx.algorithms.community.louvain_communities(causal_graph.to_undirected(), seed=LAYOUT_SEED))
print(f"Number of partitions: {len(partitions)}")

# One colour per partition (tab20 has 20 colours; cycle if there are more partitions)
colors = [plt.cm.tab20(i % 20) for i in range(len(partitions))]

# Clustered layout: first place one point per community by laying out a small
# "supergraph" (a cycle with one node per community). Then lay out each community
# around its point. The original code laid out the full causal graph here instead,
# so the centres were positions of arbitrary nodes and could nearly coincide.
supergraph = nx.cycle_graph(len(partitions))
community_centers = list(nx.spring_layout(supergraph, scale=5, seed=LAYOUT_SEED).values())

pos = {}
for community_center, partition in zip(community_centers, partitions):
    pos.update(nx.spring_layout(causal_graph.subgraph(partition), center=community_center, seed=LAYOUT_SEED))

# Plot graph on a fresh, explicit figure
fig, ax = plt.subplots(figsize=(12, 12))
nx.draw_networkx_edges(causal_graph, pos, ax=ax, alpha=0.5)

# Draw nodes in their partition's colour, and label one central node per partition
for i, partition in enumerate(partitions):
    nx.draw_networkx_nodes(causal_graph, pos, nodelist=list(partition), node_color=[colors[i]],
                           node_size=100, alpha=0.7, ax=ax)
    community = causal_graph.subgraph(partition).to_undirected()
    # nx.center raises on a disconnected graph, and Louvain communities are not
    # guaranteed to be connected. Use the centre of the largest connected component.
    largest_component = max(nx.connected_components(community), key=len)
    centroid = nx.center(community.subgraph(largest_component))
    if centroid:
        ax.text(*pos[centroid[0]], s=centroid[0],
                bbox=dict(facecolor='white', alpha=0.5), clip_on=True)

ax.set_title(f"Causal graph: {causal_graph.number_of_nodes()} variables in {len(partitions)} communities")

# Save graph (GML, with non-string attributes stringised) and the figure
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
os.makedirs(RESULTS_DIR, exist_ok=True)
nx.write_gml(causal_graph, os.path.join(RESULTS_DIR, f"causal_graph_{current_time}.gml"), stringizer=str)
fig.savefig(os.path.join(RESULTS_DIR, f"causal_graph_{current_time}.png"), bbox_inches='tight')

In [None]:
from causal_world_modelling_agent.agents.causal_inference.causal_inference_agent import CausalInferenceAgentFactory

# Build the causal inference agent on the same LLM backend used for discovery.
inference_factory = CausalInferenceAgentFactory()
inference_agent = inference_factory.createAgent(base_model)

In [None]:
import networkx as nx

# Hand-specified example causal graph (COVID-19 / oil-price shocks -> FBM KLCI index level),
# passed as the `causal_graph` argument to the inference agent below.
node_data = {
    'COVID-19 pandemic severity': {
        'description': 'Severity of the COVID-19 outbreak affecting market conditions.',
        'type': 'string',
        'values': ['low', 'moderate', 'high'],
    },
    'Oil price war impact': {
        'description': 'Impact of the oil price war on economic conditions and market sentiment.',
        'type': 'string',
        'values': ['low', 'moderate', 'high'],
    },
    'Investor sentiment': {
        'description': 'Overall market sentiment of investors, which reacts to economic and financial uncertainties.',
        'type': 'string',
        'values': ['optimistic', 'neutral', 'pessimistic'],
    },
    'FBM KLCI index level': {
        'description': 'The current level of the FBM KLCI stock index as observed during trading.',
        'type': 'float',
        'values': 'Real numbers representing index levels',
    },
    'Government intervention': {
        'description': 'Government action (or inaction) in terms of fiscal stimulus and liquidity injection.',
        'type': 'string',
        'values': ['active', 'inactive'],
    },
    'Credit conditions': {
        'description': 'The state of credit lines and broader credit conditions affecting bank stocks and SMEs.',
        'type': 'string',
        'values': ['normal', 'stressed'],
    },
}
G = nx.DiGraph()
for key, value in node_data.items():
    # Each node also stores its own name as an attribute, next to description/type/values
    G.add_node(key, **{"name": key, **value})

edge_data = [
    ('COVID-19 pandemic severity', 'Investor sentiment', {
        'description': 'Increased COVID-19 severity deteriorates investor sentiment.',
        'contextual_information': 'High outbreak risk reduces investor confidence.',
        'type': 'direct', 'strength': 'strong', 'confidence': 'high', 'function': None,
    }),
    ('Oil price war impact', 'Investor sentiment', {
        'description': 'The oil price war increases economic uncertainty, negatively affecting investor sentiment.',
        'contextual_information': 'Heightened oil price conflicts add to the risk environment.',
        'type': 'direct', 'strength': 'strong', 'confidence': 'high', 'function': None,
    }),
    ('Investor sentiment', 'FBM KLCI index level', {
        'description': 'Negative investor sentiment leads to selling pressure and market decline.',
        'contextual_information': 'Continued pessimism in the market has led to lower index levels.',
        'type': 'direct', 'strength': 'moderate', 'confidence': 'high', 'function': None,
    }),
    ('Government intervention', 'Credit conditions', {
        'description': 'Lack of active government intervention worsens credit conditions.',
        'contextual_information': 'Insufficient fiscal measures leave credit lines stressed.',
        'type': 'indirect', 'strength': 'moderate', 'confidence': 'medium', 'function': None,
    }),
    ('Credit conditions', 'Investor sentiment', {
        'description': 'Stressed credit conditions further dampen investor sentiment.',
        'contextual_information': 'Liquidity stress from credit issues heightens market fear.',
        'type': 'indirect', 'strength': 'moderate', 'confidence': 'medium', 'function': None,
    }),
    ('Credit conditions', 'FBM KLCI index level', {
        'description': 'Poor credit conditions contribute to market declines.',
        'contextual_information': 'Deterioration in credit may lead to defaults and further index drops.',
        'type': 'indirect', 'strength': 'moderate', 'confidence': 'medium', 'function': None,
    }),
]
for source, target, attrs in edge_data:
    G.add_edge(source, target, **attrs)

# nx.draw's default spring layout is randomly initialised; fix the seed so the figure is stable across reruns.
graph_pos = nx.spring_layout(G, seed=42)
nx.draw(G, graph_pos, with_labels=True, font_weight='bold')

In [None]:
# Query the inference agent: effect on the KLCI index level, given that
# COVID-19 pandemic severity is observed to be "low".
inference_task = "Compute the causal effect of the variables in the context of the COVID-19 pandemic"
inference_observations = [
    {"name": "COVID-19 pandemic severity", "current_value": "low"},
]
inference_args = {
    "causal_graph": G,
    "target_variable": "FBM KLCI index level",
    "observations": inference_observations,
}
causal_effect, inferred_graph = inference_agent.run(task=inference_task, additional_args=inference_args)

In [None]:
# Dump every node of the inferred graph: the node name, then one tab-indented line per attribute.
for node_name, attributes in inferred_graph.nodes(data=True):
    print(node_name)
    attribute_lines = "".join(f"\t{attr}: {val}\n" for attr, val in attributes.items())
    print(attribute_lines, end="")