# Social Network Simulation Demo

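This notebook demonstrates the basic functionality of the social network simulation framework: it builds a simulated social network environment, trains a reinforcement learning (DQN) agent on it, and plots the training results.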

In [None]:
import sys
sys.path.append('..')

import numpy as np
import networkx as nx
from src.environment.social_network_env import SocialNetworkEnv
from src.agents.dqn_agent import DQNAgent
from src.training.train import train_network
from src.visualization.network_visualizer import NetworkVisualizer
from src.visualization.propagation_graphs import PropagationVisualizer

## 1. Initialize Environment and Agent

In [None]:
# Environment configuration, forwarded to SocialNetworkEnv as keyword arguments
env_config = dict(
    num_agents=50,
    initial_connections=3,
    max_connections=10,
    personality_dim=5,
)
env = SocialNetworkEnv(**env_config)

# Size the DQN agent from the environment's observation and action spaces:
# the first observation dimension and the number of discrete actions
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n
agent = DQNAgent(state_dim, action_dim)

# Instantiate the network and propagation visualizers
network_vis = NetworkVisualizer()
prop_vis = PropagationVisualizer()

## 2. Visualize Initial Network Structure

In [None]:
# Grab the environment's network as it stands before training.
# NOTE(review): this binds a reference, not a copy -- unless `env.network`
# itself returns a copy, in-place changes made by the environment during
# training will also be visible through `initial_network`. Confirm if a
# true "before" snapshot is needed later.
initial_network = env.network

# Draw the pre-training network
network_vis.visualize_network(initial_network, title="Initial Network Structure")

## 3. Train the Agent

In [None]:
# Settings forwarded to train_network as keyword arguments
training_params = dict(
    num_episodes=500,
    max_steps_per_episode=100,
    eval_frequency=50,
)

# Run training; the returned stats are analyzed in the cells below
training_stats = train_network(env, agent, **training_params)

## 4. Analyze Training Results

In [None]:
import matplotlib.pyplot as plt

# Plot the per-episode training rewards on a single 10x5 figure,
# using Matplotlib's object-oriented interface
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(training_stats['episode_rewards'])
ax.set_title('Training Rewards over Episodes')
ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.grid(True)
plt.show()

# Plot network metrics.
# `metrics` is indexed and iterated below as a sequence of dicts that share
# the keys of the first entry; each key becomes one line on the plot.
metrics = training_stats['network_metrics']
if metrics:
    plt.figure(figsize=(12, 6))
    for key in metrics[0]:
        values = [m[key] for m in metrics]
        plt.plot(values, label=key)
    plt.title('Network Metrics Evolution')
    plt.xlabel('Evaluation Step')
    plt.ylabel('Metric Value')
    plt.legend()
    plt.grid(True)
    plt.show()
else:
    # With no recorded metrics, `metrics[0]` would raise IndexError;
    # report it instead of crashing the cell.
    print('No network metrics were recorded during training; skipping plot.')