In [None]:
# Input format of the users and interactions files: "csv" or "parquet".
# NOTE: the input paths below end in ".csv"; update them too if you switch
# inputFormat to "parquet".
inputFormat = "csv"

# Storage container and reporting window shared by every path below.
storageRoot = "abfss://output@onastoremzj5lge2pebi4.dfs.core.windows.net"
dateRange = "2022-06-01_to_2022-07-01"

# Input path of user vertices
usersInputPath = f"{storageRoot}/users_{dateRange}.csv"

# Input path of user to user edges
interactionsInputPath = f"{storageRoot}/interactions_{dateRange}.csv"

# Output path for the HTML visualizer; the final page is written to
# "<onaVisualizerOutputPath>.html".
onaVisualizerOutputPath = f"{storageRoot}/onaVisualizer_{dateRange}"

In [None]:
# Load data
def _read_input(path, input_format):
    """Read ``path`` into a Spark DataFrame using the reader for ``input_format``.

    Args:
        path: Storage path of the file to read.
        input_format: Either "csv" or "parquet".

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: If ``input_format`` is neither "csv" nor "parquet".
    """
    if input_format == "csv":
        # CSV files carry a header row; let Spark infer the column types.
        return spark.read.csv(path, header=True, inferSchema=True)
    if input_format == "parquet":
        return spark.read.parquet(path)
    raise ValueError(f"Unsupported input format: {input_format}")


# Read the user data into a DataFrame based on the input format.
# (Previously the parquet branch bound `df` instead of `usersDF`.)
try:
    usersDF = _read_input(usersInputPath, inputFormat)
except Exception as error:
    print(error)
    raise Exception("Users data not loaded") from error

# Read the interactions data into a DataFrame based on the input format
try:
    graphDF = _read_input(interactionsInputPath, inputFormat)
except Exception as error:
    print(error)
    raise Exception("Interactions data not loaded") from error

In [None]:
from pyspark.sql.functions import col, when
import json

In [None]:
# Keep only the columns the visualizer uses.
graphDF = graphDF.select(col("Source"), col("Target"), col("Interactions"))

usersDF = usersDF.select(col("EmailAddress"), col("Country"), col("Department"), col("Title"), col("x"), col("y"))
# Replace missing grouping attributes with "Unknown" so every node has a value
# for each "Group by" option in the visualizer.
for group_column in ("Country", "Department", "Title"):
    usersDF = usersDF.withColumn(
        group_column,
        when(col(group_column).isNull(), "Unknown").otherwise(col(group_column)),
    )


def _to_script_json(json_rows):
    """Serialize a list of JSON row strings for embedding inside an HTML <script>.

    Args:
        json_rows: List of strings, each one row encoded as JSON.

    Returns:
        A JSON array literal whose elements are those strings. "<", ">" and
        "&" are written as \\u escapes.
    """
    payload = json.dumps(json_rows)
    # json.dumps leaves "<", ">" and "&" as-is, so a value such as "</script>"
    # would end the surrounding <script> element early. These characters can
    # only occur inside JSON string literals here, where JavaScript decodes
    # the \u escapes back to the original characters.
    return (
        payload.replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


# Each element is one row serialized to a JSON string by DataFrame.toJSON();
# the page script JSON.parse()s each element back into an object.
users_json_data = _to_script_json(usersDF.toJSON().collect())
graph_json_data = _to_script_json(graphDF.toJSON().collect())

In [None]:
# Standalone HTML page that renders the user interaction graph with vis-network.
# users_json_data / graph_json_data are interpolated into the script below;
# every literal brace in the CSS/JavaScript is doubled because this is an f-string.
html_template = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Network Graph</title>
    <script src="https://unpkg.com/vis-network@9.1.1/dist/vis-network.min.js"></script>
    <script type="text/javascript" src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/select2/4.0.13/js/select2.min.js"></script>
    <link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/select2/4.0.13/css/select2.min.css">
    <style>
        .container {{
            display: flex;
            justify-content: space-between;
            width: 100%;
        }}
        .left-col {{
            flex: 1.25;
            margin: 0 10px;
        }}
        .center-col {{
            flex: 7.75;
            margin: 0 10px;
        }}
        .right-col {{
            flex: 1;
            margin: 0 10px;
        }}
        table {{
            border-collapse: collapse;
            width: 100%;
        }}
        th, td {{
            border: 1px solid black;
            padding: 8px;
            text-align: left;
        }}
        th {{
            background-color: #f2f2f2;
        }}
        #network {{
            width: 100%;
            height: 800px;
            border: 1px solid lightgray;
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="left-col">
            <div>
                <select id="nodeSelect" style="width: 100%;">
                    <option value="">Search for nodes...</option>
                </select>
                <button onclick="displayNodeGraph()">Filter</button>
                <button onclick="clearFilter()">Clear Filter</button>
            </div>
            <div>
                <div id="nodeInfo"></div>
            </div>
        </div>
        <div class="center-col">
            <div id="network"></div>
            <div>
                <button onclick="groupBy('Country')">Group by Country</button>
                <button onclick="groupBy('Department')">Group by Department</button>
                <button onclick="groupBy('Title')">Group by Job Title</button>
            </div>
        </div>
        <div class="right-col">
            <div id="legend"></div>
        </div>
    </div>
    <script>
        // Interpolated from Python: arrays whose elements are JSON-encoded rows.
        const graph_json_data_raw = {graph_json_data};
        const users_json_data_raw = {users_json_data};
        const graph_json_data = graph_json_data_raw.map(JSON.parse);
        const users_json_data = users_json_data_raw.map(JSON.parse);

        // Node and edge definitions handed to vis-network.
        const nodes = [];
        const edges = [];
        // Node id -> node object; de-duplicates nodes and resolves labels.
        const nodeById = new Map();

        // Group name -> randomly generated color, shared by nodes and the legend.
        const groupColors = new Map();

        // EmailAddress -> user row. The first row wins on duplicate addresses,
        // the same result a linear Array.prototype.find would give.
        const usersById = new Map();
        users_json_data.forEach((user) => {{
            if (!usersById.has(user.EmailAddress)) {{
                usersById.set(user.EmailAddress, user);
            }}
        }});

        let network;
        let nodesDataset;
        let edgesDataset;

        // Escape text before it is interpolated into an innerHTML string.
        function escapeHtml(value) {{
            return String(value)
                .replace(/&/g, "&amp;")
                .replace(/</g, "&lt;")
                .replace(/>/g, "&gt;")
                .replace(/"/g, "&quot;")
                .replace(/'/g, "&#39;");
        }}

        // Return the user row for a node id. Ids that appear only in the
        // interactions data get a placeholder with "Unknown" attributes so
        // node creation, grouping, the legend and the info panel still work.
        function getUserData(nodeId) {{
            return usersById.get(nodeId) || {{
                EmailAddress: nodeId,
                Country: "Unknown",
                Department: "Unknown",
                Title: "Unknown",
            }};
        }}

        // Generate a random "#RRGGBB" color.
        function getRandomColor() {{
            const letters = '0123456789ABCDEF';
            let color = '#';
            for (let i = 0; i < 6; i++) {{
                color += letters[Math.floor(Math.random() * 16)];
            }}
            return color;
        }}

        // Get the color assigned to a group, generating one on first use.
        function getGroupColor(groupName) {{
            if (!groupColors.has(groupName)) {{
                groupColors.set(groupName, getRandomColor());
            }}
            return groupColors.get(groupName);
        }}

        // Rebuild the legend with one color swatch per distinct group value.
        function updateLegend(groupType) {{
            const legendDiv = document.getElementById("legend");
            legendDiv.innerHTML = "";

            const groupSet = new Set(nodes.map((node) => getUserData(node.id)[groupType]));

            groupSet.forEach((groupName) => {{
                const legendItem = document.createElement("div");
                legendItem.style.display = "flex";
                legendItem.style.alignItems = "center";
                legendItem.style.marginBottom = "4px";

                const colorBox = document.createElement("div");
                colorBox.style.width = "16px";
                colorBox.style.height = "16px";
                colorBox.style.backgroundColor = getGroupColor(groupName);
                colorBox.style.marginRight = "8px";

                // A text node renders the group name as plain text, never as markup.
                const groupNameText = document.createTextNode(groupName);

                legendItem.appendChild(colorBox);
                legendItem.appendChild(groupNameText);
                legendDiv.appendChild(legendItem);
            }});
        }}

        // Show the selected user's attributes and per-neighbour message totals.
        function displayNodeInfo(nodeId) {{
            const nodeData = getUserData(nodeId);
            const nodeInfoDiv = document.getElementById("nodeInfo");

            // Neighbour id -> totals row, kept in first-seen order.
            const totalsByNeighbour = new Map();
            edges.forEach((edge) => {{
                if (edge.from !== nodeId && edge.to !== nodeId) {{
                    return;
                }}
                const neighbourId = edge.from === nodeId ? edge.to : edge.from;
                if (!totalsByNeighbour.has(neighbourId)) {{
                    totalsByNeighbour.set(neighbourId, {{
                        nodeLabel: nodeById.get(neighbourId).label,
                        messagesSent: 0,
                        messagesReceived: 0,
                    }});
                }}
                const totals = totalsByNeighbour.get(neighbourId);
                const count = parseInt(edge.label, 10);
                // An edge from this user counts as sent, an edge to this user as received.
                if (edge.from === nodeId) {{
                    totals.messagesSent += count;
                }}
                if (edge.to === nodeId) {{
                    totals.messagesReceived += count;
                }}
            }});

            const connectedNodesTable = Array.from(totalsByNeighbour.values())
                .map((entry) => `<tr><td>${{escapeHtml(entry.nodeLabel)}}</td><td>${{entry.messagesReceived}}</td><td>${{entry.messagesSent}}</td></tr>`)
                .join("");

            nodeInfoDiv.innerHTML = `<h3>User Information</h3><p>Id: ${{escapeHtml(nodeId)}}</p><p>Department: ${{escapeHtml(nodeData.Department || "Unknown")}}</p><p>Country: ${{escapeHtml(nodeData.Country || "Unknown")}}</p><p>JobTitle: ${{escapeHtml(nodeData.Title || "Unknown")}}</p><h3>Connected Users</h3><table><thead><tr><th>Connected User</th><th>Messages Received</th><th>Messages Sent</th></tr></thead><tbody>${{connectedNodesTable}}</tbody></table>`;
        }}

        // Populate the node dropdown and turn it into a searchable Select2 box.
        function initializeSelect2() {{
            const nodeSelect = document.getElementById("nodeSelect");
            nodes.forEach((node) => {{
                const option = document.createElement("option");
                option.value = node.id;
                option.textContent = node.label;
                nodeSelect.appendChild(option);
            }});
            $("#nodeSelect").select2({{
                placeholder: "Select a node",
                allowClear: true,
                width: "100%",
            }});
        }}

        // Return the ids of nodes that share an edge with nodeId (either direction).
        function getConnectedNodeIds(nodeId) {{
            const connectedNodeIds = [];
            edges.forEach((edge) => {{
                if (edge.from === nodeId) {{
                    connectedNodeIds.push(edge.to);
                }} else if (edge.to === nodeId) {{
                    connectedNodeIds.push(edge.from);
                }}
            }});
            return connectedNodeIds;
        }}

        // Hide every node except the selected one and its direct neighbours.
        function displayNodeGraph() {{
            const nodeId = document.getElementById("nodeSelect").value;
            if (!nodeId) {{
                return;
            }}
            const visibleIds = new Set(getConnectedNodeIds(nodeId));
            visibleIds.add(nodeId);

            nodes.forEach((node) => {{
                node.hidden = !visibleIds.has(node.id);
            }});

            nodesDataset.update(nodes);
            displayNodeInfo(nodeId);
        }}

        // Show all nodes again and reset the dropdown and info panel.
        function clearFilter() {{
            nodes.forEach((node) => {{
                node.hidden = false;
            }});
            nodesDataset.update(nodes);

            const nodeSelect = document.getElementById("nodeSelect");
            nodeSelect.value = "";
            $("#nodeSelect").trigger("change"); // Update Select2

            document.getElementById("nodeInfo").innerHTML = "";
        }}

        // Re-color nodes by the given user attribute and refresh the legend.
        function groupBy(groupType) {{
            nodes.forEach((node) => {{
                const groupName = getUserData(node.id)[groupType];
                node.group = groupName;
                node.color = getGroupColor(groupName);
            }});
            nodesDataset.update(nodes);
            updateLegend(groupType);
        }}

        // Add a node for nodeId unless it already exists, pinned at the user's
        // x/y coordinates when both parse as finite numbers.
        function addNode(nodeId) {{
            if (nodeById.has(nodeId)) {{
                return;
            }}
            const userData = getUserData(nodeId);
            const node = {{
                id: nodeId,
                label: nodeId,
                group: userData.Country,
                color: getGroupColor(userData.Country),
            }};
            const x = parseFloat(userData.x);
            const y = parseFloat(userData.y);
            if (Number.isFinite(x) && Number.isFinite(y)) {{
                node.x = x;
                node.y = y;
                node.fixed = {{
                    x: true,
                    y: true
                }};
            }}
            nodeById.set(nodeId, node);
            nodes.push(node);
        }}

        // Build nodes and edges from the interactions, then render the graph.
        function createNetworkGraph() {{
            graph_json_data.forEach((entry) => {{
                addNode(entry.Source);
                addNode(entry.Target);
                edges.push({{
                    from: entry.Source,
                    to: entry.Target,
                    label: String(entry.Interactions),
                    arrows: 'to',
                }});
            }});

            // Create a vis.DataSet for nodes and edges
            nodesDataset = new vis.DataSet(nodes);
            edgesDataset = new vis.DataSet(edges);

            // Set up options for the network graph
            const options = {{
                nodes: {{
                    shape: 'dot',
                    scaling: {{
                        min: 10,
                        max: 30,
                    }},
                    font: {{
                        size: 12,
                        face: "Tahoma",
                    }}
                }},
                edges: {{
                    width: 0.15,
                    color: {{
                        inherit: "from"
                    }},
                    font: {{
                        size: 12,
                        align: 'middle'
                    }},
                    arrows: {{
                        to: {{ enabled: true, scaleFactor: 0.7 }},
                    }},
                    smooth: {{ enabled: true, type: 'continuous' }}
                }},
                physics: false,
                interaction: {{
                    hover: true,
                    hideEdgesOnDrag: true,
                    tooltipDelay: 200
                }}
            }};

            // Initialize the network graph and keep it in the page-level variable.
            const container = document.getElementById('network');
            const data = {{
                nodes: nodesDataset,
                edges: edgesDataset
            }};
            network = new vis.Network(container, data, options);
            initializeSelect2();
            groupBy("Country");
        }}
        createNetworkGraph();
    </script>
</body>
</html>
"""

# Write the rendered page through Spark, then replace the Spark output
# directory with a single "<onaVisualizerOutputPath>.html" file.
html_vis_df = spark.createDataFrame([(html_template,)], ["html_content"])

# Save the dataframe as a text file; coalesce(1) yields a single part file.
html_vis_df.coalesce(1).write.mode("overwrite").text(onaVisualizerOutputPath)

Path = sc._gateway.jvm.org.apache.hadoop.fs.Path
# get the part file generated by spark write
fs = Path(onaVisualizerOutputPath).getFileSystem(sc._jsc.hadoopConfiguration())
part_files = fs.globStatus(Path(onaVisualizerOutputPath + "/part*"))
if not part_files:
    raise RuntimeError(f"No part file found under {onaVisualizerOutputPath}")
part_file = part_files[0].getPath()
# set final target path
target_path_ona_visualizer = onaVisualizerOutputPath + ".html"
# move and rename the file
fs.delete(Path(target_path_ona_visualizer), True)
# FileSystem.rename signals failure by returning false rather than raising;
# stop before deleting the output directory, which holds the only copy.
if not fs.rename(part_file, Path(target_path_ona_visualizer)):
    raise RuntimeError(f"Could not move {part_file} to {target_path_ona_visualizer}")
fs.delete(Path(onaVisualizerOutputPath), True)
