# In [None]:
import sys
from loguru import logger
from bs4 import BeautifulSoup
import pandas as pd
import regex as re
import networkx as nx
import matplotlib.pyplot as plt

# Load both input tables in one pass: `df` holds the relationship rows
# (entity_1, entity_2, relationship) and `types_df` holds the per-entity
# type lookup (entity_name, entity_type).
df, types_df = (
    pd.read_csv(csv_path)
    for csv_path in ("relationship_data.csv", "entity_types.csv")
)

# Preview the first rows of the type lookup table.
types_df.head()

# In [None]:
# Build a directed graph: every row of `df` becomes one edge pointing from
# entity_1 to entity_2 and carrying the row's relationship as an attribute.
G = nx.DiGraph()

# Resolve entity types through a single name -> type mapping instead of
# filtering `types_df` once per node. Rows with a missing name are dropped
# (a missing name can never equal an entity, so they never matched before
# either), and when a name is listed more than once the first row wins,
# which is what taking element [0] of the filtered values did.
type_by_name = (
    types_df.dropna(subset=['entity_name'])
    .drop_duplicates(subset='entity_name', keep='first')
    .set_index('entity_name')['entity_type']
    .to_dict()
)

# dict.fromkeys de-duplicates like a set but keeps first-appearance order, so
# nodes are inserted in the same order on every run (a set's iteration order
# varies with string hash randomization).
entities = dict.fromkeys([*df['entity_1'], *df['entity_2']])

# Add nodes and assign the entity_type property; entities missing from the
# lookup table are tagged 'Unknown'.
for entity in entities:
    G.add_node(entity, entity_type=type_by_name.get(entity, 'Unknown'))

# Add edges to the graph, walking the three columns in lockstep rather than
# materializing a Series per row with iterrows().
for source, target, relationship in zip(df['entity_1'], df['entity_2'], df['relationship']):
    G.add_edge(source, target, relationship=relationship)
                                                                  
# Render G as a node-link diagram: edge color encodes the relationship and
# node shape encodes the entity type.
plt.figure(figsize=(12, 8))
pos = nx.kamada_kawai_layout(G)

# One marker size shared by the node and edge drawing calls, so arrowheads
# are placed for the node size that is actually drawn.
node_size = 500

# Define edge colors based on relationship
edge_colors = {
    'Acting for or on behalf of': 'blue',
    'Leader or official of': 'green',
    'Owned or Controlled By': 'red',
    'Providing support to': 'purple',
}
# Relationships outside the palette (including missing values) fall back to
# a neutral color instead of raising KeyError.
fallback_edge_color = 'gray'

edge_colors_list = [
    edge_colors.get(data.get('relationship'), fallback_edge_color)
    for _, _, data in G.edges(data=True)
]

# Separate nodes by shape. Anything that is neither 'Entity' nor 'Individual'
# (e.g. 'Unknown') goes into its own group so it still gets a marker;
# otherwise its label and edges would be drawn around an empty spot.
entity_nodes = []
individual_nodes = []
other_nodes = []
for node, attr in G.nodes(data=True):
    node_type = attr.get('entity_type')
    if node_type == 'Entity':
        entity_nodes.append(node)
    elif node_type == 'Individual':
        individual_nodes.append(node)
    else:
        other_nodes.append(node)

# Draw nodes with specific shapes
nx.draw_networkx_nodes(G, pos, nodelist=entity_nodes, node_shape='d',
                       node_size=node_size, node_color='skyblue')
nx.draw_networkx_nodes(G, pos, nodelist=individual_nodes, node_shape='o',
                       node_size=node_size, node_color='skyblue')
if other_nodes:
    nx.draw_networkx_nodes(G, pos, nodelist=other_nodes, node_shape='s',
                           node_size=node_size, node_color='lightgray')

# node_size is passed so the arrows of the directed edges stop at the marker
# boundary rather than being sized for the default node size.
nx.draw_networkx_edges(G,
                       pos,
                       edge_color=edge_colors_list,
                       node_size=node_size)

# NOTE(review): font_size=2 is very small; presumably chosen to limit label
# overlap on a dense graph -- confirm it is legible at the target output size.
nx.draw_networkx_labels(G,
                        pos,
                        font_size=2,
                        font_color='black')

# Edge labels are not drawn; the relationship is conveyed by edge color.

plt.title('Network Graph')
plt.show()