### Import libraries and define node class

In [19]:
import networkx as nx
import random
import statistics

class ZollmanNode:
    """Plain holder for one node's belief state.

    Attributes:
        beliefProb: The belief value supplied by the caller, stored as given.
        pEH: Optional second value supplied by the caller; ``None`` when
            omitted.

    No validation or conversion is applied to either value.
    """

    def __init__(self, beliefProb, pEH=None):
        # Store both constructor arguments unchanged.
        self.beliefProb, self.pEH = beliefProb, pEH

### Define Function to Display Graph

In [53]:
import plotly.graph_objects as go

def plot_network(graph):
    """Draw *graph* as an interactive Plotly figure and display it.

    Node coordinates are computed with ``nx.spring_layout`` (no seed is
    passed, so the layout may differ between calls). Edges are drawn as thin
    grey line segments and nodes as markers labelled with their node
    identifier.

    Args:
        graph: A NetworkX graph whose nodes and edges should be drawn.

    Returns:
        None. The figure is displayed through ``fig.show()``.
    """
    # Compute an (x, y) position for every node.
    pos = nx.spring_layout(graph)

    # Build the edge coordinates. A ``None`` entry after each pair breaks the
    # line so that every edge is rendered as its own segment.
    edge_x = []
    edge_y = []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    # Build the node coordinates and labels. Only the node identifiers are
    # needed, so node attribute data is not requested.
    node_x = []
    node_y = []
    node_text = []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    # Create node trace
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=node_text,
        marker=dict(
            size=10,
            color=[],  # no per-node colours are supplied
        ),
        textposition="top center"
    )

    # Create edge trace
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        mode='lines',
        line=dict(width=0.5, color='#888'),
        hoverinfo='none'
    )

    # Create figure. The title font size is set through ``title.font``; the
    # older ``titlefont_size`` shorthand is deprecated and is rejected by
    # recent Plotly releases.
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(
                text='Interactive Network Visualization',
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002)],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    # Display the interactive graph directly in the notebook
    fig.show()

### Define graph initialization function

In [20]:
def initialize_graph(type: str, numberNodes: int = 10):
    """Build a graph of ``numberNodes`` scientists with random initial beliefs.

    Every node is named ``"scientist-<index>"`` and carries the attributes
    ``group``, ``beliefProb`` (a fresh ``random.random()`` draw) and ``pEH``
    (initially ``None``).

    Args:
        type: Topology to build. ``'circle'`` links the nodes in a ring;
            ``'wheel'`` builds a ring of ``numberNodes - 1`` nodes plus one
            hub node connected to every ring node; any other value builds a
            graph in which every pair of nodes is connected.
        numberNodes: Total number of nodes. Values of zero or less yield an
            empty graph.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()
    if numberNodes <= 0:
        # Nothing to build. Without this guard the ring-closing edge would
        # make add_edge create phantom nodes (e.g. "scientist--1") that lack
        # the belief attributes.
        return graph

    def add_scientist(index):
        # Create one node with its initial attribute set.
        initial_data = {'group': index % 4, 'beliefProb': random.random(), 'pEH': None}
        graph.add_node(f"scientist-{index}", **initial_data)

    def add_ring(size):
        # Create nodes 0..size-1 and chain each one to its predecessor.
        for index in range(size):
            add_scientist(index)
            if index > 0:
                graph.add_edge(f"scientist-{index - 1}", f"scientist-{index}")
        # Close the ring only when that adds a genuinely new edge: with two
        # nodes the edge already exists, and with one it would be a self-loop.
        if size > 2:
            graph.add_edge(f"scientist-{size - 1}", "scientist-0")

    if type == 'circle':
        add_ring(numberNodes)
    elif type == 'wheel':
        hub = f"scientist-{numberNodes - 1}"
        add_ring(numberNodes - 1)
        # The hub is always placed in group 1.
        graph.add_node(hub, group=1, beliefProb=random.random(), pEH=None)
        for index in range(numberNodes - 1):
            graph.add_edge(f"scientist-{index}", hub)
    else:
        for index in range(numberNodes):
            add_scientist(index)
        # j starts above i, so each unordered pair is visited exactly once.
        for i in range(numberNodes):
            for j in range(i + 1, numberNodes):
                graph.add_edge(f"scientist-{i}", f"scientist-{j}")
    return graph

### Define timestep function and helper functions

In [21]:
def run_experiment(world):
    """Draw one random experiment outcome for the given world.

    Args:
        world: ``'higher'`` selects the range 0.3-0.8; any other value
            selects the range 0.2-0.7.

    Returns:
        A float drawn with ``random.uniform`` from the selected range.
    """
    if world == 'higher':
        lower, upper = 0.3, 0.8
    else:
        lower, upper = 0.2, 0.7
    return random.uniform(lower, upper)

def calculate_posterior(pH, pEH):
    """Return the updated belief given a prior and an evidence value.

    Computes ``pEH * pH / (pEH * pH + (1 - pEH) * (1 - pH))``.

    Args:
        pH: Prior belief value.
        pEH: Evidence value used as the weight on the prior.

    Returns:
        The quotient described above.

    Raises:
        ZeroDivisionError: If the denominator evaluates to zero.
    """
    weighted_prior = pEH * pH
    total_evidence = weighted_prior + (1 - pEH) * (1 - pH)
    return weighted_prior / total_evidence

def timestep(graph):
    """Advance every node of *graph* by one round, updating it in place.

    The round has two passes over the nodes:

    1. Each node whose ``beliefProb`` exceeds 0.5 runs an experiment, stores
       the outcome in ``pEH`` and updates its own ``beliefProb`` with it.
       Every other node gets ``pEH = None`` and keeps its belief.
    2. Each node updates its ``beliefProb`` once for every neighbor that has
       a non-``None`` ``pEH``, then has its ``group`` recomputed.

    Args:
        graph: Graph whose nodes carry ``beliefProb``, ``pEH`` and ``group``
            attributes.

    Returns:
        A list with each node's ``beliefProb`` after the round, in node
        iteration order.
    """
    # Pass 1: experiments and self-updates.
    for _, attrs in graph.nodes(data=True):
        if attrs['beliefProb'] > 0.5:
            outcome = run_experiment('higher')
            attrs['pEH'] = outcome
            attrs['beliefProb'] = calculate_posterior(attrs['beliefProb'], outcome)
        else:
            attrs['pEH'] = None

    # Pass 2: fold in the neighbors' outcomes, then classify each node.
    thresholds = (0, 0.4, 0.6, 1)
    beliefs = []
    for name, attrs in graph.nodes(data=True):
        for other in graph.neighbors(name):
            shared = graph.nodes[other]['pEH']
            if shared is None:
                continue
            attrs['beliefProb'] = calculate_posterior(attrs['beliefProb'], shared)
        # group = number of thresholds the belief strictly exceeds
        attrs['group'] = len([t for t in thresholds if attrs['beliefProb'] > t])
        beliefs.append(attrs['beliefProb'])
    return beliefs

### Define Function to Run Simulation

In [22]:
def run_simulation(graph, num_timesteps=1):
    """Run ``timestep`` repeatedly and record the median belief of each round.

    Args:
        graph: Graph to simulate; it is passed to ``timestep`` once per
            round.
        num_timesteps: Number of rounds to run.

    Returns:
        A list with one entry per round: the median of the values returned
        by ``timestep``, or 0 when a round returns no values.
    """
    median_posteriors = []
    for _ in range(num_timesteps):
        posteriors = timestep(graph)
        # statistics.median raises StatisticsError on empty data and never
        # returns None, so the empty case has to be checked up front.
        median_posteriors.append(statistics.median(posteriors) if posteriors else 0)
    return median_posteriors

### Run Simulation and Display Results

In [54]:
# Example usage: build a 10-node 'circle' graph, run five timesteps on it,
# then draw the resulting graph.
graph = initialize_graph('circle', numberNodes=10)
# The list returned by run_simulation is not stored here.
run_simulation(graph, num_timesteps=5)
plot_network(graph)