In [None]:
import os
import re
import glob
from enum import Enum
from datetime import datetime
from typing import List, Set, Any
from dataclasses import dataclass
from pprint import pprint
import pandas as pd

In [None]:
class VertexType(str, Enum):
    """Vertex type can only be `V1` or `V2`.

    Members are also plain strings, so ``VertexType.V1 == "V1"`` holds.
    """

    V1 = "V1"
    V2 = "V2"

    @classmethod
    def _missing_(cls, value):
        """Reject every value that is not one of the members.

        ``Enum`` calls this hook when a lookup such as ``VertexType("V3")``
        finds no member. A user-defined ``__new__`` only runs while the members
        themselves are being created, so the validation has to live here for
        its message to ever reach the caller.

        Raises:
            ValueError: always.
        """
        raise ValueError("vertex type can only be `V1` or `V2`")

    def getother(self: "VertexType") -> "VertexType":
        """Return the other vertex type"""
        # There are exactly two members, so "not V1" can only mean V2.
        return VertexType.V2 if self is VertexType.V1 else VertexType.V1


assert VertexType("V1") == VertexType.V1
assert VertexType("V2").getother() == VertexType.V1


@dataclass
class ClusterItem:
    """class of keeping track of cluster information

    Attributes:
        community_id: numeric id of the community.
        member: the vertices that belong to the community.
        type: vertex type of the community, or ``None`` when not given.
    """

    community_id: int
    member: Set[Any]
    # Quoted so the annotation is not evaluated when the class is created; the
    # bare `X | None` form needs Python 3.10+.
    type: "VertexType | None" = None  # type of the cluster

    def __hash__(self) -> int:
        """Hash on ``(community_id, type)``; the mutable ``member`` set is left out.

        Hashing the pair (instead of concatenating strings) also works while
        ``type`` is still ``None``, where ``str + None`` raised ``TypeError``.
        """
        return hash((self.community_id, self.type))

    def __repr__(self) -> str:
        """Show all members for `V1` clusters, otherwise up to five members followed by "..."."""
        if self.type == VertexType.V1:
            return str(self.member)
        return str(list(self.member)[0:5] + ["..."])


@dataclass
class CoClusterItem:
    """class of keeping track of co-cluster information

    Attributes:
        cocluster_id: numeric id of the co-cluster.
        first: cluster on the first side of the pairing.
        second: cluster on the second side of the pairing.
    """

    cocluster_id: int
    first: ClusterItem
    second: ClusterItem

    def __hash__(self) -> int:
        """Hash on the id and both clusters (each cluster contributes its own hash)."""
        return hash((self.cocluster_id, self.first, self.second))

    @staticmethod
    def _overlap_over_union(left: Set[Any], right: Set[Any]) -> float:
        """Return ``|left & right| / |left | right|``, or 0.0 when both sets are empty."""
        union_size = len(left.union(right))
        if union_size == 0:
            # Nothing on either side: report "no similarity" instead of dividing by zero.
            return 0.0
        return len(left.intersection(right)) / union_size

    def similarity(self, other: "CoClusterItem") -> float:
        """Calculate the similarity between two coclusters

        The result is the average of the overlap-over-union of the two `first`
        member sets and of the two `second` member sets, so it lies in [0, 1].
        A side whose member sets are both empty contributes 0.0.
        """
        first_similarity = self._overlap_over_union(self.first.member, other.first.member)
        second_similarity = self._overlap_over_union(self.second.member, other.second.member)
        return (first_similarity + second_similarity) / 2


def test_CoclusterItem():
    """Check `CoClusterItem.similarity` against hand-computed values."""
    letters_abc = ClusterItem(1, {"A", "B", "C"})
    letters_def = ClusterItem(2, {"D", "E", "F"})
    digits_low = ClusterItem(3, {1, 2, 3})
    digits_high = ClusterItem(4, {4, 5, 6})

    reference = CoClusterItem(1, letters_abc, digits_low)
    expectations = [
        # nothing shared on either side
        (CoClusterItem(2, letters_def, digits_high), 0.0),
        # identical first side, disjoint second side
        (CoClusterItem(3, letters_abc, digits_high), 0.5),
    ]
    for candidate, expected in expectations:
        assert reference.similarity(candidate) == expected


test_CoclusterItem()


@dataclass
class CommunityResult:
    """Clusters and co-clusters that make up one community detection result."""

    # the clusters belonging to this result
    clusters: Set[ClusterItem]
    # the co-clusters (pairings of two clusters) belonging to this result
    coclusters: Set[CoClusterItem]


@dataclass
class CommunityResultTime:
    """A community detection result together with its country and date range."""

    # the clusters / co-clusters of this result
    community: CommunityResult
    # country the result belongs to
    country: str
    # start of the date range covered by the result
    time_from: datetime
    # end of the date range covered by the result
    time_to: datetime
    # NOTE(review): meaning not visible from this file; defaults to 1 — confirm
    biLouvian_order: int = 1

In [None]:
# Edge-list files to process; the result files are looked up next to each of
# them under the same base name.
matching_files = sorted(glob.glob("checkpoints/LT_bipartite_*.csv"))

# base name:         ...bipartite_<country>_from<YYYY-MM-DD>...to<YYYY-MM-DD>
file_regex = r"bipartite_(?P<country>\w+)_from(?P<date_from>\d{4}-\d\d-\d\d).*to(?P<date_to>\d{4}-\d\d-\d\d)"
# communities file:  "Community <id>[<vertex type>]: <vertexes>"
comm_regex = r"^Community (?P<community_id>\d+)\[(?P<vertex_type>V\d+)\]: (?P<vertexes>.*)$"
# co-cluster file:   "CoCluster <id>:<vertex type>(<a_id>)-<b_id>"
clus_regex = r"^CoCluster (?P<cocluster_id>\d+):(?P<vertex_type>V\d+)\((?P<a_id>\d+)\)-(?P<b_id>\d+)$"

results: List[CommunityResultTime] = []
for edgelist in matching_files:
    edgelist = os.path.splitext(edgelist)[0]
    print(edgelist)

    # Country and dates come from the file name. Without a match there is
    # nothing to attach the result to, so skip the file instead of silently
    # reusing the values parsed from the previous one.
    file_match = re.search(file_regex, edgelist)
    if file_match is None:
        print(f"skipping {edgelist}: file name does not match the expected pattern")
        continue
    country = file_match.group("country")
    date_from = file_match.group("date_from")
    date_to = file_match.group("date_to")

    community_file = edgelist + "_ResultsCommunities.txt"
    if not os.path.exists(community_file):
        continue
    with open(community_file, "r") as f:
        text = f.read()

    clusters: List[ClusterItem] = []
    for match in re.finditer(comm_regex, text, re.MULTILINE):
        # Only the text before the first ", " is kept and then split on ",".
        # NOTE(review): confirm this matches the layout of the communities file.
        vertexes = match.group("vertexes").split(", ")[0].split(",")
        clusters.append(
            ClusterItem(
                community_id=int(match.group("community_id")),
                member=set(vertexes),
                type=VertexType(match.group("vertex_type")),
            )
        )

    # Index by id for O(1) lookups; `setdefault` keeps the first cluster seen
    # for an id, exactly like a linear scan that stops at the first hit.
    # NOTE(review): the lookup ignores the vertex type — this presumes ids are
    # unique across V1 and V2; confirm.
    cluster_by_id = {}
    for cluster in clusters:
        cluster_by_id.setdefault(cluster.community_id, cluster)

    with open(f"{edgelist}_ResultsCoClusterCommunities.txt", "r") as f:
        text = f.read()

    coclusters: List[CoClusterItem] = []
    for match in re.finditer(clus_regex, text, re.MULTILINE):
        a_id = match.group("a_id")
        b_id = match.group("b_id")

        try:
            a = cluster_by_id[int(a_id)]
        except KeyError:
            raise ValueError(f"cannot find cluster with id {a_id}") from None
        try:
            b = cluster_by_id[int(b_id)]
        except KeyError:
            raise ValueError(f"cannot find cluster with id {b_id}") from None

        # Lines written from the V2 point of view are flipped so both kinds of
        # line produce the same (first, second) orientation.
        if VertexType(match.group("vertex_type")) is not VertexType.V1:
            a, b = b, a

        coclusters.append(
            CoClusterItem(
                cocluster_id=int(match.group("cocluster_id")),
                first=a,
                second=b,
            )
        )

    comm_result = CommunityResult(
        clusters=set(clusters),
        coclusters=set(coclusters),
    )

    comm_time_result = CommunityResultTime(
        community=comm_result,
        country=country,
        time_from=datetime.strptime(date_from, "%Y-%m-%d"),
        time_to=datetime.strptime(date_to, "%Y-%m-%d"),
    )
    results.append(comm_time_result)

    # Sanity check: every (first, second) pairing is expected to appear exactly
    # twice, i.e. half of the rows duplicate an earlier row. Building the frame
    # from the ids directly also works when no co-cluster was parsed.
    coclusters_df = pd.DataFrame(
        {
            "cocluster_id": [item.cocluster_id for item in coclusters],
            "first_id": [item.first.community_id for item in coclusters],
            "second_id": [item.second.community_id for item in coclusters],
        }
    )
    assert coclusters_df.drop(columns="cocluster_id").duplicated().sum() == len(coclusters_df) / 2
    # print(coclusters)

In [None]:
# Keep only the results whose time_from is 2013-01-01 or later.
earliest_start = datetime(2013, 1, 1)
results = list(filter(lambda res: res.time_from >= earliest_start, results))

In [None]:
# Pretty-print the current contents of `results` for visual inspection.
pprint(results)

# Calculate similarity

In [None]:
import networkx as nx
import numpy as np

# Directed graph whose nodes are co-clusters; an edge links a co-cluster to a
# co-cluster of the next result and carries their similarity as `weight`.
G = nx.DiGraph()

# Layout: x is the index of the result, y is the index of the co-cluster in it.
pos = {}

# NOTE(review): [-3:-1] keeps the third-last and second-last results and leaves
# out the newest one — confirm that is intended. `results` is overwritten, so
# re-running this cell slices the already-sliced list again.
results = sorted(results, key=lambda x: x.time_from)[-3:-1]

# Register every co-cluster once per result. Doing this before adding edges
# keeps the nodes of a result even when the result before it has no co-clusters.
for t, result in enumerate(results):
    for row, cocluster in enumerate(result.community.coclusters):
        G.add_node(cocluster, time_from=result.time_from)
        pos[cocluster] = np.array([t, row])

# calculate the similarity between pairs of coclusters in a and b
for a, b in zip(results[:-1], results[1:]):
    for cocluster_a in a.community.coclusters:
        for cocluster_b in b.community.coclusters:
            sim = cocluster_a.similarity(cocluster_b)
            # only pairs that share at least one member get an edge
            if sim > 0.0:
                G.add_edge(cocluster_a, cocluster_b, weight=sim)

In [None]:
import matplotlib.pyplot as plt

# New 10x6 figure for the graph
plt.figure(figsize=(10, 6))

# Draw nodes and edges with a uniform look
node_style = dict(with_labels=False, node_size=100, node_color="lightblue")
edge_style = dict(edge_color="gray", width=0.5, arrowsize=5)
nx.draw(G, pos, **node_style, **edge_style)

# Label each edge with its weight rounded to two decimals
edge_labels = {edge: round(weight, 2) for edge, weight in nx.get_edge_attributes(G, "weight").items()}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

# Write time_from above each column.
# NOTE(review): the label height is the number of results, not the height of
# the columns — confirm this is where the labels should go.
header_y = len(results)
for column, snapshot in enumerate(results):
    plt.text(column, header_y, snapshot.time_from.strftime("%Y-%m-%d"), ha="center", va="bottom")

# Render the figure
plt.show()

In [None]:
# Clear any 'color' attribute already present on the edges
for _, _, attrs in G.edges(data=True):
    attrs.pop("color", None)

# For every node, mark its heaviest incoming edge in red
for node in G.nodes:
    # incoming edges of the current node, as (source, target, attributes)
    incoming_edges = G.in_edges(node, data=True)

    # `max` returns the first edge among equally heavy ones; None if there are no edges
    strongest = max(incoming_edges, key=lambda edge: edge[2]["weight"], default=None)

    # only edges with a strictly positive weight are highlighted
    if strongest is not None and strongest[2]["weight"] > 0:
        source, target, _ = strongest
        G[source][target]["color"] = "red"

# One color per edge, in edge order; edges without a 'color' are gray
edge_colors = [color for _, _, color in G.edges(data="color", default="gray")]

# New 10x6 figure for the graph
plt.figure(figsize=(10, 6))

# Draw the graph with the highlighted edges
draw_style = dict(
    with_labels=False,
    node_size=100,
    node_color="lightblue",
    edge_color=edge_colors,
    width=0.5,
    arrowsize=5,
)
nx.draw(G, pos, **draw_style)

# Label each edge with its weight rounded to two decimals
edge_labels = {edge: round(weight, 2) for edge, weight in nx.get_edge_attributes(G, "weight").items()}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

# Write time_from above each column.
# NOTE(review): the label height is the number of results, not the height of
# the columns — confirm this is where the labels should go.
header_y = len(results)
for column, snapshot in enumerate(results):
    plt.text(column, header_y, snapshot.time_from.strftime("%Y-%m-%d"), ha="center", va="bottom")

# Render the figure
plt.show()

In [None]:
# TODO: the maximum should be determined differently.