# In [None]:
from manim import *
import pandas as pd
import networkx as nx
import numpy as np

class MentalHealthNetworkAnimation(Scene):
    """Animate the month-by-month evolution of a legislation network.

    Reads two CSV files from the working directory:

    * ``Mental_Health_Full_Network_Evolution_Expanded.csv`` with ``Date``,
      ``Source``, ``Target`` and ``RelationType`` columns; each row is drawn
      as a coloured edge from ``Source`` to ``Target``.
    * ``katz_prestige_by_act_version.csv`` with ``Date``, ``Act_Version`` and
      ``Katz_Prestige`` columns; the core act's score is shown in the corner.

    Node positions come from a single spring layout computed over every node
    that ever appears, so nodes do not move between months. The core act is
    pinned to the origin and drawn in red.
    """

    def construct(self):
        df_net = pd.read_csv("Mental_Health_Full_Network_Evolution_Expanded.csv")
        df_katz = pd.read_csv("katz_prestige_by_act_version.csv")

        # Headers may carry stray whitespace; normalise before column access.
        df_net.columns = df_net.columns.str.strip()
        df_katz.columns = df_katz.columns.str.strip()

        df_net["Date"] = pd.to_datetime(df_net["Date"])
        # "YYYY-MM" strings sort chronologically, so sorted() below yields
        # the months in time order.
        df_net["Month"] = df_net["Date"].dt.to_period("M").astype(str)
        df_katz["Date"] = pd.to_datetime(df_katz["Date"])

        core_act = "Mental Health (Compulsory Assessment and Treatment) Act 1992"
        months = sorted(df_net["Month"].unique())
        all_nodes = pd.concat([df_net["Source"], df_net["Target"]]).unique()

        # One layout over all nodes keeps positions stable across months.
        # The layout graph has no edges, so only the node spacing (k) matters.
        layout_graph = nx.Graph()
        layout_graph.add_nodes_from(all_nodes)
        raw_pos = nx.spring_layout(layout_graph, seed=42, k=10.0, dim=2)
        positions = np.array(list(raw_pos.values()))
        min_pos, max_pos = positions.min(axis=0), positions.max(axis=0)
        center = (min_pos + max_pos) / 2
        # Rescale so the layout spans 7 scene units along its widest axis.
        # A zero span (e.g. a single node) would otherwise divide by zero
        # and place every node at NaN.
        span = (max_pos - min_pos).max()
        scale = 7.0 / span if span > 0 else 1.0
        fixed_pos = {node: (raw_pos[node] - center) * scale for node in raw_pos}

        if core_act in fixed_pos:
            fixed_pos[core_act] = np.array([0.0, 0.0])

        # Currently displayed mobjects, keyed by node name or by
        # (source, target, relation) for edges.
        nodes, labels, prestige_labels, edges = {}, {}, {}, {}
        # Empty placeholders so the first month has something to FadeOut.
        time_label = Text("", font_size=1).to_corner(UR)
        katz_label = Text("", font_size=1).next_to(time_label, DOWN, aligned_edge=RIGHT)
        self.add(time_label, katz_label)

        def get_color(relation):
            """Return the edge colour for a RelationType code (GRAY if unknown)."""
            return {
                "R_S": YELLOW,
                "AMD": RED,
                "AMD_S": WHITE,
                "PR": BLUE,
                "PR_S": PURPLE,
                "CIT": GREEN
            }.get(relation, GRAY)

        def get_edge_offset(src, tgt, rel_index, total_rels):
            """Return a perpendicular shift that fans out parallel edges.

            Edges between the same two nodes are spaced 0.04 units apart,
            centred on the straight line joining the nodes. Returns a zero
            vector when both nodes share a position.
            """
            vec = nodes[tgt].get_center() - nodes[src].get_center()
            perpendicular = np.array([-vec[1], vec[0]])
            norm = np.linalg.norm(perpendicular)
            if norm == 0:
                return np.array([0.0, 0.0, 0.0])
            perpendicular /= norm
            offset = (rel_index - (total_rels - 1) / 2) * 0.04
            return np.array([*perpendicular * offset, 0.0])

        def remove_edge(edge_key):
            """Animate an edge out of the scene and stop tracking it."""
            self.play(Uncreate(edges[edge_key]), run_time=0.2)
            self.remove(edges[edge_key])
            del edges[edge_key]

        def remove_node(node):
            """Animate a node out, drop its labels and every incident edge."""
            self.play(Uncreate(nodes[node]), run_time=0.2)
            self.remove(nodes[node], labels[node], prestige_labels.get(node, VGroup()))
            del nodes[node], labels[node]
            prestige_labels.pop(node, None)
            for edge_key in [k for k in edges if node in k[:2]]:
                remove_edge(edge_key)

        for i, month in enumerate(months):
            current_df = df_net[df_net["Month"] == month]
            # NOTE(review): the Katz lookup below matches the date of this
            # month's first row exactly; other dates in the month are ignored.
            current_date = current_df["Date"].iloc[0]
            next_df = df_net[df_net["Month"] == months[i + 1]] if i + 1 < len(months) else pd.DataFrame(columns=df_net.columns)
            katz_df = df_katz[df_katz["Date"] == current_date]

            # Update the year / Katz prestige readout in the corner.
            year_text = f"Year: {current_date.year}"
            katz_score = katz_df[katz_df["Act_Version"] == core_act]["Katz_Prestige"].values
            katz_text = f"Katz Prestige: {katz_score[0]:.2f}" if len(katz_score) > 0 else "Katz Prestige: N/A"

            new_time_label = Text(year_text, font_size=24).to_corner(UR)
            new_katz_label = Text(katz_text, font_size=24).next_to(new_time_label, DOWN, aligned_edge=RIGHT)
            self.play(FadeOut(time_label), FadeOut(katz_label), FadeIn(new_time_label), FadeIn(new_katz_label), run_time=0.3)
            time_label = new_time_label
            katz_label = new_katz_label

            # Sync the edge set to THIS month's rows before drawing it, so
            # edges absent this month do not linger and the final month is
            # not wiped (there is no next month to compare against).
            expected_edges = set(zip(current_df["Source"], current_df["Target"], current_df["RelationType"]))
            for edge_key in [k for k in edges if k not in expected_edges]:
                remove_edge(edge_key)

            for _, row in current_df.iterrows():
                src, tgt, rel = row["Source"], row["Target"], row["RelationType"]

                # An exact "R_S" row removes an already-displayed target.
                # Equality (not a substring test) is required: "R_S" is a
                # substring of "PR_S", which is a separate, drawable relation.
                if rel == "R_S" and tgt in nodes:
                    remove_node(tgt)
                    continue

                for node in (src, tgt):
                    if node not in nodes and node in fixed_pos:
                        pos = np.append(fixed_pos[node], 0.0)
                        color = RED if node == core_act else PINK
                        circle = Circle(radius=0.05, color=color, fill_opacity=1).move_to(pos)
                        label = Text(node, font_size=9).next_to(circle, UP)
                        nodes[node] = circle
                        labels[node] = label
                        self.add(circle, label)

                if src in nodes and tgt in nodes:
                    edge_key = (src, tgt, rel)
                    if edge_key not in edges:
                        # This edge's slot in the fan = number of edges already
                        # drawn between the pair, in either direction.
                        rel_index = sum(
                            1 for k in edges
                            if (k[0], k[1]) == (src, tgt) or (k[1], k[0]) == (src, tgt)
                        )
                        pair_mask = (
                            ((current_df["Source"] == src) & (current_df["Target"] == tgt))
                            | ((current_df["Source"] == tgt) & (current_df["Target"] == src))
                        )
                        total_expected = int(pair_mask.sum())
                        offset = get_edge_offset(src, tgt, rel_index, total_expected)
                        start = nodes[src].get_center() + offset
                        end = nodes[tgt].get_center() + offset
                        edge = Line(start, end, color=get_color(rel), stroke_width=0.5)
                        edges[edge_key] = edge
                        self.add(edge)
                        self.play(Create(edge), run_time=0.2)

            # Remove nodes (and their edges) that do not appear in next
            # month's rows; the core act always stays. Skipped on the last month.
            if not next_df.empty:
                next_nodes = set(next_df["Source"]) | set(next_df["Target"])
                for node in list(nodes):
                    if node not in next_nodes and node != core_act:
                        remove_node(node)

            self.wait(0.3)

        self.wait(0.5)

if __name__ == "__main__":
    # Script entry point: configure Manim in code (instead of via CLI flags)
    # and render the scene directly.
    from manim import config as manim_config

    manim_config.renderer = "opengl"
    manim_config.quality = "high_quality"
    manim_config.disable_caching = True
    manim_config.output_file = "MentalHealthAct_Network_Animation.mp4"

    MentalHealthNetworkAnimation().render()
