# In [None]:
import json
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from collections import defaultdict, Counter
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.BRep import BRep_Tool
from OCC.Core.TopAbs import TopAbs_FORWARD
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Extend.TopologyUtils import TopologyExplorer
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.GProp import GProp_GProps
from OCC.Core.BRepGProp import brepgprop_SurfaceProperties, brepgprop_LinearProperties
from OCC.Core.GeomAbs import GeomAbs_Plane, GeomAbs_Cylinder, GeomAbs_Cone, GeomAbs_Sphere, GeomAbs_Torus, GeomAbs_Circle
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface, BRepAdaptor_Curve

# Open Plotly figures in the system web browser instead of inline.
pio.renderers.default = "browser"

# INPUT PATHS - CHANGE THESE to point at your own part and result files.
(
    step_file_path,  # STEP model of the part to analyse
    convexity_results_file,  # per-edge convexity results (JSON)
    aag_output_file,  # output path for the AAG JSON
) = (
    r"c:\Users\u0175912\OneDrive - KU Leuven\master's fixturing topic\Parts for Validation\Part4.stp",
    r"c:\Users\u0175912\OneDrive - KU Leuven\master's fixturing topic\Parts for Validation\convexity_resultsPart4.json",
    r"c:\Users\u0175912\OneDrive - KU Leuven\master's fixturing topic\Parts for Validation\aag_resultsPart4.json",
)


def edge_key(edge, tolerance=1e-6):
    """Return a hashable, orientation-independent geometric key for *edge*.

    The key lets the same B-rep edge be recognised when it is reached from
    different faces. Coordinates and radii are snapped to a grid of size
    *tolerance* so tiny floating-point differences do not split one edge into
    two keys.

    Every key includes the (snapped) midpoint of the edge's curve. This is
    needed because two arcs of one circle, or two halves of a split curve,
    share the same end vertices and would otherwise collide. The midpoint of
    the parameter range is the same physical point whichever way the edge is
    traversed, so the key stays orientation-independent.

    * Circular edges: ``("circle", center, radius, end_points, midpoint)``.
    * Any other curve: ``("line", p1, p2, midpoint)`` with ``p1 <= p2``.

    Args:
        edge: An OCC ``TopoDS_Edge``.
        tolerance: Grid size used to quantise coordinates and radii.

    Returns:
        A tuple as described above, or ``None`` if a key cannot be built
        (a non-circular edge without exactly two distinct vertices, or any
        OCC failure while querying the edge).
    """
    def snap(value):
        # Quantise onto the tolerance grid so near-equal values compare equal.
        return round(value / tolerance) * tolerance

    def snap_point(pnt):
        return (snap(pnt.X()), snap(pnt.Y()), snap(pnt.Z()))

    try:
        curve = BRepAdaptor_Curve(edge)
        mid_param = 0.5 * (curve.FirstParameter() + curve.LastParameter())
        midpoint = snap_point(curve.Value(mid_param))
        # Sorting makes the key independent of the edge's orientation.
        end_points = tuple(sorted(
            snap_point(BRep_Tool.Pnt(vertex))
            for vertex in TopologyExplorer(edge).vertices()
        ))

        if curve.GetType() == GeomAbs_Circle:
            circle = curve.Circle()
            return (
                "circle",
                snap_point(circle.Location()),
                snap(circle.Radius()),
                end_points,
                midpoint,
            )

        if len(end_points) != 2:
            return None
        return ("line", end_points[0], end_points[1], midpoint)
    except Exception:
        # Best-effort: edges OCC cannot query simply get no key.
        return None


def load_step_file(filepath):
    """Read a STEP file and return the shape transferred from it.

    Args:
        filepath: Path of the STEP file (``str`` or ``pathlib.Path``).

    Returns:
        The ``TopoDS_Shape`` returned by ``STEPControl_Reader.OneShape()``.

    Raises:
        IOError: If the file cannot be read, or if no shape could be
            transferred from it.
    """
    from OCC.Core.IFSelect import IFSelect_RetDone

    reader = STEPControl_Reader()
    # ReadFile expects a plain string; str() also accepts pathlib.Path input.
    status = reader.ReadFile(str(filepath))
    if status != IFSelect_RetDone:
        raise IOError(f"Failed to read STEP file: {filepath}")
    reader.TransferRoots()
    shape = reader.OneShape()
    # A readable file can still yield nothing transferable; fail loudly here
    # rather than with an obscure error during meshing/exploration later.
    if shape.IsNull():
        raise IOError(f"No shape could be transferred from STEP file: {filepath}")
    return shape


def mesh_shape_for_visualization(shape, linear_deflection=0.1):
    """Triangulate *shape* in place so its faces carry a display mesh.

    The mesher object is only constructed, never kept or queried. The
    triangulation it produces is read back later through
    ``BRep_Tool.Triangulation`` (see ``extract_mesh_data``).

    Args:
        shape: The ``TopoDS_Shape`` to mesh.
        linear_deflection: Linear deflection passed to
            ``BRepMesh_IncrementalMesh`` (smaller values give a finer mesh).
    """
    # Construction alone does the work; the returned object is discarded.
    BRepMesh_IncrementalMesh(shape, linear_deflection)
    return None


def extract_mesh_data(shape):
    """Merge the triangulations of all faces of *shape* into a single mesh.

    Faces that have no triangulation are skipped. Node coordinates are
    transformed by each face's location. Triangle winding is flipped for faces
    whose orientation is not ``TopAbs_FORWARD``, so winding follows the face
    orientation.

    Returns:
        ``(vertices, triangles)``. ``vertices`` is a list of ``[x, y, z]``
        lists. ``triangles`` is a list of ``[i, j, k]`` 0-based indices into
        ``vertices``.
    """
    all_vertices = []
    all_triangles = []
    for face in TopologyExplorer(shape).faces():
        location = TopLoc_Location()
        mesh = BRep_Tool.Triangulation(face, location)
        if mesh is None:
            continue
        trsf = location.Transformation()

        # Global indices of this face's nodes, in OCC's (1-based) node order.
        first_global = len(all_vertices)
        node_total = mesh.NbNodes()
        face_indices = list(range(first_global, first_global + node_total))

        for node_no in range(1, node_total + 1):
            point = mesh.Node(node_no)
            point.Transform(trsf)
            all_vertices.append([point.X(), point.Y(), point.Z()])

        is_forward = face.Orientation() == TopAbs_FORWARD
        for tri_no in range(1, mesh.NbTriangles() + 1):
            a, b, c = mesh.Triangle(tri_no).Get()
            corner_order = (a, b, c) if is_forward else (a, c, b)
            all_triangles.append([face_indices[n - 1] for n in corner_order])
    return all_vertices, all_triangles


def get_surface_type_name(surface_type):
    """Return a readable label for an OCC ``GeomAbs`` surface type.

    Plane, cylinder, cone, sphere and torus map to their names. Any other
    surface type maps to ``"Other"``.
    """
    labels = {
        GeomAbs_Plane: "Plane",
        GeomAbs_Cylinder: "Cylinder",
        GeomAbs_Cone: "Cone",
        GeomAbs_Sphere: "Sphere",
        GeomAbs_Torus: "Torus",
    }
    try:
        return labels[surface_type]
    except KeyError:
        return "Other"


def save_aag_results(builder, filename):
    """Write the AAG held by *builder* to *filename* as indented JSON.

    The JSON object has two lists:
    * ``"faces"``: one entry per face, with ``index``, ``type`` and ``center``.
    * ``"edges"``: one entry per graph link, with ``face1``, ``face2``,
      ``edge_idx``, ``classification``, ``edge_type`` and ``edge_length``.

    Args:
        builder: Object exposing ``faces``, ``face_types``, ``face_centers``
            and ``get_graph_edges()`` (e.g. ``DirectAAGBuilder``).
        filename: Destination path; an existing file is overwritten.
    """
    graph_edges, _classifications, edge_info = builder.get_graph_edges()

    face_records = []
    for idx in range(len(builder.faces)):
        face_records.append({
            "index": idx,
            "type": builder.face_types[idx],
            "center": builder.face_centers[idx],
        })

    edge_records = []
    for pos, pair in enumerate(graph_edges):
        meta = edge_info[pos]
        edge_records.append({
            "face1": pair[0],
            "face2": pair[1],
            "edge_idx": meta['edge_idx'],
            "classification": meta['classification'],
            "edge_type": meta['edge_type'],
            "edge_length": meta['edge_length'],
        })

    payload = {"faces": face_records, "edges": edge_records}
    with open(filename, "w") as out:
        json.dump(payload, out, indent=2)
    print(f"AAG results saved to {filename}")


class DirectAAGBuilder:
    """Build an attributed adjacency graph (AAG) for the faces of a STEP part.

    Graph nodes are the B-rep faces of the shape. Two faces are linked when
    they share an edge. Each link carries the convexity record (classification,
    edge type, length) read from a JSON results file. Records are matched to
    the STEP edges via ``edge_key``.

    Everything is computed eagerly in ``__init__``.
    """

    def __init__(self, step_file, convexity_results_file):
        """Load both inputs and build the graph.

        Args:
            step_file: Path of the STEP model.
            convexity_results_file: Path of a JSON list of per-edge records,
                each with at least ``edge_id`` and ``classification``.
        """
        self.step_file = step_file
        self.convexity_results_file = convexity_results_file
        self.load_step_shape()
        self.load_convexity_results()
        self.build_consistent_edge_mapping()
        self.build_face_edge_topology()
        self.extract_geometry()

    def load_step_shape(self):
        """Read the STEP shape and cache a triangle mesh for visualisation."""
        self.shape = load_step_file(self.step_file)
        mesh_shape_for_visualization(self.shape, linear_deflection=0.1)
        self.vertices, self.triangles = extract_mesh_data(self.shape)

    def load_convexity_results(self):
        """Load the convexity JSON into ``edge_classification``, keyed by ``edge_id``.

        Missing ``edge_type`` / ``edge_length`` fields default to
        ``'Unknown'`` / ``0``. If an ``edge_id`` repeats, the last record wins.
        """
        with open(self.convexity_results_file, 'r') as f:
            self.convexity_results = json.load(f)
        self.edge_classification = {}
        for result in self.convexity_results:
            self.edge_classification[result['edge_id']] = {
                'classification': result['classification'],
                'edge_type': result.get('edge_type', 'Unknown'),
                'edge_length': result.get('edge_length', 0),
            }
        print(f"Loaded convexity results for {len(self.convexity_results)} edges")

    @staticmethod
    def _edge_length(edge):
        """Return the length of *edge*, or ``None`` if OCC cannot compute it."""
        try:
            props = GProp_GProps()
            brepgprop_LinearProperties(edge, props)
            return props.Mass()
        except Exception:
            return None

    def build_consistent_edge_mapping(self, length_tolerance=0.1):
        """Map convexity ``edge_id`` values onto edges of the loaded shape.

        Pass 1 treats an id as the index of the edge in ``TopologyExplorer``
        order. This presumably matches the order used when the convexity file
        was written (TODO confirm). Ids that cannot be placed this way are
        handled in pass 2. There, each id (ascending) is matched to the first
        still-unclaimed edge whose length is within *length_tolerance* of the
        recorded ``edge_length`` and whose key is not yet taken.

        Populates ``self.edges``, ``self.edge_id_map`` (edge key -> convexity
        id) and ``self.convexity_to_current_map`` (convexity id -> index into
        ``self.edges``).

        Args:
            length_tolerance: Maximum absolute length difference accepted by
                the pass-2 length match.
        """
        topo = TopologyExplorer(self.shape)
        self.edges = list(topo.edges())
        self.edge_id_map = {}
        self.convexity_to_current_map = {}
        n_edges = len(self.edges)

        # Pass 1: direct index mapping (negative ids must not wrap around).
        for conv_id in self.edge_classification:
            if not (0 <= conv_id < n_edges):
                continue
            key = edge_key(self.edges[conv_id])
            if key is not None:
                self.edge_id_map[key] = conv_id
                self.convexity_to_current_map[conv_id] = conv_id

        print(f"Direct mapping: {len(self.edge_id_map)} edges mapped")

        # Pass 2: length-based fallback. A set of claimed indices and a length
        # cache avoid rescanning dict values / recomputing lengths per id.
        used_current_ids = set(self.convexity_to_current_map.values())
        length_cache = {}
        unmapped_conv = sorted(
            set(self.edge_classification) - set(self.convexity_to_current_map)
        )
        for conv_id in unmapped_conv:
            conv_length = self.edge_classification[conv_id]['edge_length']
            if not isinstance(conv_length, (int, float)):
                continue  # no numeric length recorded, nothing to match on
            for curr_id, edge in enumerate(self.edges):
                if curr_id in used_current_ids:
                    continue
                if curr_id not in length_cache:
                    length_cache[curr_id] = self._edge_length(edge)
                curr_length = length_cache[curr_id]
                if curr_length is None or not abs(curr_length - conv_length) < length_tolerance:
                    continue
                key = edge_key(edge)
                if key is not None and key not in self.edge_id_map:
                    self.edge_id_map[key] = conv_id
                    self.convexity_to_current_map[conv_id] = curr_id
                    used_current_ids.add(curr_id)
                    break

        final_unmapped = set(self.edge_classification) - set(self.convexity_to_current_map)
        print(f"Final: {len(self.edge_id_map)} mapped, {len(final_unmapped)} unmapped")

    def build_face_edge_topology(self):
        """Link faces that share a mapped edge.

        Fills ``self.faces`` (``TopologyExplorer`` order), ``self.edge_to_faces``
        (convexity id -> distinct face indices) and ``self.face_adjacency``
        (face index -> set of ``(neighbour_face, convexity_id)``). Edges whose
        key is missing or unmapped are ignored.
        """
        topo = TopologyExplorer(self.shape)
        self.faces = list(topo.faces())
        self.face_adjacency = defaultdict(set)
        self.edge_to_faces = defaultdict(list)

        for face_idx, face in enumerate(self.faces):
            for edge in TopologyExplorer(face).edges():
                key = edge_key(edge)
                if key is None:
                    continue
                conv_edge_id = self.edge_id_map.get(key)
                if conv_edge_id is None:
                    continue
                faces_for_edge = self.edge_to_faces[conv_edge_id]
                # A face may reach the same id through two of its own edges
                # (colliding keys); recording it twice would create a (f, f)
                # self-loop in the adjacency graph.
                if face_idx not in faces_for_edge:
                    faces_for_edge.append(face_idx)

        for edge_id, face_list in self.edge_to_faces.items():
            for i, f1 in enumerate(face_list):
                for f2 in face_list[i + 1:]:
                    self.face_adjacency[f1].add((f2, edge_id))
                    self.face_adjacency[f2].add((f1, edge_id))

        print(f"Built adjacency graph with {len(self.face_adjacency)} connected faces")

    def extract_geometry(self):
        """Compute each face's centre of mass and surface-type label.

        Fills ``self.face_centers`` (``[x, y, z]`` lists) and
        ``self.face_types`` (labels from ``get_surface_type_name``, or
        ``"Unknown"`` if the surface adaptor fails). Both lists are aligned
        with ``self.faces``.
        """
        self.face_centers = []
        self.face_types = []
        for face in self.faces:
            props = GProp_GProps()
            brepgprop_SurfaceProperties(face, props)
            center = props.CentreOfMass()
            self.face_centers.append([center.X(), center.Y(), center.Z()])
            try:
                surf_type = BRepAdaptor_Surface(face).GetType()
            except Exception:
                self.face_types.append("Unknown")
            else:
                self.face_types.append(get_surface_type_name(surf_type))

    def get_graph_edges(self):
        """Return one graph link per pair of adjacent faces.

        Faces are visited in ascending index order, and each face's neighbours
        in ascending ``(face, edge_id)`` order, so the result is deterministic.
        When two faces share several edges, the link uses the shared edge with
        the smallest id. Self-adjacencies ``(f, f)`` are skipped.

        Returns:
            ``(graph_edges, edge_classifications, edge_info)``: parallel lists.
            ``graph_edges`` holds ``(f1, f2)`` face-index pairs with ``f1 < f2``.
            ``edge_classifications`` holds classification strings.
            ``edge_info`` holds dicts with ``edge_idx``, ``classification``,
            ``edge_type`` and ``edge_length``. Edges with no convexity record
            are reported as ``'unknown'`` / ``'Unknown'`` / ``0``.
        """
        graph_edges = []
        edge_classifications = []
        edge_info = []
        processed_pairs = set()

        for f1 in sorted(self.face_adjacency):
            for f2, edge_id in sorted(self.face_adjacency[f1]):
                if f1 == f2:
                    continue
                pair = (min(f1, f2), max(f1, f2))
                if pair in processed_pairs:
                    continue
                processed_pairs.add(pair)
                graph_edges.append(pair)
                cls_data = self.edge_classification.get(edge_id, {})
                classification = cls_data.get('classification', 'unknown')
                edge_classifications.append(classification)
                edge_info.append({
                    'edge_idx': edge_id,
                    'classification': classification,
                    'edge_type': cls_data.get('edge_type', 'Unknown'),
                    'edge_length': cls_data.get('edge_length', 0),
                })
        return graph_edges, edge_classifications, edge_info

    def visualize_3d_aag(self, show_mesh=True, mesh_opacity=0.2, node_size=10):
        """Show the AAG in an interactive Plotly 3D figure and print statistics.

        Faces are drawn as markers at their centres, coloured by surface type.
        Adjacency links are drawn as lines, grouped by convexity class.

        Args:
            show_mesh: Also draw the part's triangle mesh.
            mesh_opacity: Opacity of the mesh trace.
            node_size: Marker size for face nodes.
        """
        fig = go.Figure()

        if show_mesh:
            fig.add_trace(go.Mesh3d(
                x=[v[0] for v in self.vertices],
                y=[v[1] for v in self.vertices],
                z=[v[2] for v in self.vertices],
                i=[t[0] for t in self.triangles],
                j=[t[1] for t in self.triangles],
                k=[t[2] for t in self.triangles],
                color='lightblue',
                opacity=mesh_opacity,
                name='3D Model',
                flatshading=True,
            ))

        face_colors = {
            'Plane': 'blue', 'Cylinder': 'red', 'Cone': 'orange',
            'Sphere': 'green', 'Torus': 'purple', 'Other': 'brown', 'Unknown': 'gray'
        }

        # Sorted so the legend order does not depend on string-hash randomisation.
        for ftype in sorted(set(self.face_types)):
            indices = [i for i, ft in enumerate(self.face_types) if ft == ftype]
            xs = [self.face_centers[i][0] for i in indices]
            ys = [self.face_centers[i][1] for i in indices]
            zs = [self.face_centers[i][2] for i in indices]
            fig.add_trace(go.Scatter3d(
                x=xs, y=ys, z=zs, mode='markers+text',
                marker=dict(size=node_size, color=face_colors.get(ftype, 'gray'), line=dict(width=2, color='black')),
                text=[str(i) for i in indices], textposition='middle center',
                textfont=dict(size=8, color='white'), name=f'{ftype} faces ({len(indices)})',
                hovertemplate="Face %{text}<br>Type: "+ftype+"<br>Center: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>"
            ))

        graph_edges, edge_classifications, edge_info = self.get_graph_edges()
        edge_groups = {
            'convex': {'edges': [], 'color': 'blue', 'width': 4, 'name': 'Convex edges'},
            'concave': {'edges': [], 'color': 'red', 'width': 4, 'name': 'Concave edges'},
            'tangent': {'edges': [], 'color': 'green', 'width': 3, 'name': 'Tangent edges'},
            'unknown': {'edges': [], 'color': 'gray', 'width': 2, 'name': 'Unknown edges'}
        }

        for edge, cls, info in zip(graph_edges, edge_classifications, edge_info):
            group_key = cls if cls in edge_groups else 'unknown'
            edge_groups[group_key]['edges'].append((edge, info))

        for group in edge_groups.values():
            if not group['edges']:
                continue
            x_line, y_line, z_line = [], [], []
            hover_text = []
            for (f1, f2), info in group['edges']:
                x1, y1, z1 = self.face_centers[f1]
                x2, y2, z2 = self.face_centers[f2]
                # None breaks the polyline so each link is a separate segment.
                x_line.extend([x1, x2, None])
                y_line.extend([y1, y2, None])
                z_line.extend([z1, z2, None])
                # \u2194 is the two-headed arrow; the literal glyph had been
                # mangled into "â†”" by an encoding round-trip.
                hover_label = (f"Edge {info['edge_idx']}<br>Faces: {f1} \u2194 {f2}<br>"
                               f"Classification: {info['classification']}<br>Type: {info['edge_type']}<br>"
                               f"Length: {info['edge_length']:.2f}")
                hover_text.extend([hover_label, hover_label, None])
            fig.add_trace(go.Scatter3d(
                x=x_line, y=y_line, z=z_line, mode='lines',
                line=dict(color=group['color'], width=group['width']),
                name=f"{group['name']} ({len(group['edges'])})",
                hovertemplate="%{text}<extra></extra>", text=hover_text
            ))

        fig.update_layout(
            title="AAG Graph from STEP File",
            scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z", aspectmode="data",
                       camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))),
            legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
            width=1200, height=900
        )
        fig.show()
        self.print_stats(graph_edges, edge_classifications)

    def print_stats(self, graph_edges, edge_classifications):
        """Print face/edge counts, type distributions and convexity coverage.

        Coverage is the share of graph links whose classification is not
        ``'unknown'``.
        """
        print(f"\nTotal faces: {len(self.faces)}")
        print(f"Total edges in STEP: {len(self.edges)}")
        print(f"Successfully mapped edges: {len(self.edge_id_map)}")
        print(f"Graph edges (face adjacencies): {len(graph_edges)}")
        print("\nFace type distribution:")
        for ftype, count in Counter(self.face_types).items():
            print(f"  {ftype}: {count}")
        print("\nEdge classification distribution:")
        for cls, count in Counter(edge_classifications).items():
            print(f"  {cls}: {count}")
        covered = sum(1 for c in edge_classifications if c != 'unknown')
        coverage = (covered / len(edge_classifications) * 100) if edge_classifications else 0
        print(f"\nConvexity coverage: {covered}/{len(edge_classifications)} ({coverage:.1f}%)")