In [1]:
import json
import networkx as nx
import torch
from torch_geometric.utils import from_networkx
from sklearn.preprocessing import StandardScaler
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

In [4]:
class Building:
    def __init__(self, name='', label='', position=None, bbox=None):
        self.node_type = 'Building'
        self.name = name
        self.label = label
        self.position = position if position else [0.0, 0.0, 0.0]
        self.bbox = bbox if bbox else [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

class Level:
    def __init__(self, level_index=0, num_regions=0, label='', position=None, bbox=None):
        self.node_type = 'Level'
        self.level_index = level_index
        self.num_regions = num_regions
        self.label = label
        self.position = position if position else [0.0, 0.0, 0.0]
        self.bbox = bbox if bbox else [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

class Room:
    def __init__(self, region_index=0, level_index=0, label='', position=None, bbox=None, height=0.0, objects=None, portals=None):
        self.node_type = 'Room'
        self.region_index = region_index
        self.level_index = level_index
        self.label = label
        self.position = position if position else [0.0, 0.0, 0.0]
        self.bbox = bbox if bbox else [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        self.height = height
        self.objects = objects if objects else []
        self.portals = portals if portals else []

class Object:
    def __init__(self, object_type='', object_index=0, region_index=0, category_index=0, position=None, axis_directions_a0=None, axis_directions_a1=None, radii_r=None, additional_values=None, category_details=None, segments=None):
        self.node_type = 'Object'
        self.object_type = object_type
        self.object_index = object_index
        self.region_index = region_index
        self.category_index = category_index
        self.position = position if position else [0.0, 0.0, 0.0]
        self.axis_directions_a0 = axis_directions_a0 if axis_directions_a0 else [0.0, 0.0, 0.0]
        self.axis_directions_a1 = axis_directions_a1 if axis_directions_a1 else [0.0, 0.0, 0.0]
        self.radii_r = radii_r if radii_r else [0.0, 0.0, 0.0]
        self.additional_values = additional_values if additional_values else [0.0] * 8
        self.category_details = category_details if category_details else {}
        self.segments = segments if segments else []

class Portal:
    def __init__(self, portal_index=0, regions=None, label='', bbox=None):
        self.node_type = 'Portal'
        self.portal_index = portal_index
        self.regions = regions if regions else [0, 0]
        self.label = label
        self.bbox = bbox if bbox else [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

class Panorama:
    def __init__(self, name='', panorama_index=0, region_index=0, position=None, images=None):
        self.node_type = 'Panorama'
        self.name = name
        self.panorama_index = panorama_index
        self.region_index = region_index
        self.position = position if position else [0.0, 0.0, 0.0]
        self.images = images if images else []

class Image:
    def __init__(self, image_index=0, panorama_index=0, name='', camera_index=0, yaw_index=0, extrinsics=None, intrinsics=None, width=0, height=0, position=None):
        self.node_type = 'Image'
        self.image_index = image_index
        self.panorama_index = panorama_index
        self.name = name
        self.camera_index = camera_index
        self.yaw_index = yaw_index
        self.extrinsics = extrinsics if extrinsics else [[0.0]*4]*4
        self.intrinsics = intrinsics if intrinsics else [[0.0]*3]*3
        self.width = width
        self.height = height
        self.position = position if position else [0.0, 0.0, 0.0]

def load_json_to_classes(json_file):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    building = Building(**data['building'])

    levels = [Level(**level) for level in data['levels']]

    rooms = []
    for room_data in data['rooms']:
        objects = [Object(
            object_type=obj['object_type'],
            object_index=obj['object_index'],
            region_index=obj['region_index'],
            category_index=obj['category_index'],
            position=obj['position'],
            axis_directions_a0=obj['axis_directions(a0)'],
            axis_directions_a1=obj['axis_directions(a1)'],
            radii_r=obj['radii(r)'],
            additional_values=obj['additional_values'],
            category_details=obj['category_details'],
            segments=obj['segments']
        ) for obj in room_data['objects']]
        portals = [Portal(**portal) for portal in room_data['portals']]
        room = Room(
            objects=objects,
            portals=portals,
            **{k: v for k, v in room_data.items() if k not in ['objects', 'portals']}
        )
        rooms.append(room)

    panoramas = []
    for panorama_data in data['panoramas']:
        images = [Image(**image) for image in panorama_data['images']]
        panorama = Panorama(
            images=images,
            **{k: v for k, v in panorama_data.items() if k != 'images'}
        )
        panoramas.append(panorama)
    
    return building, levels, rooms, panoramas

def normalize_positions(building, rooms):
    for room in rooms:
        room.position = [coord / (building.position[i] if building.position[i] != 0 else 1) for i, coord in enumerate(room.position)]
        for obj in room.objects:
            obj.position = [coord / (room.position[i] if room.position[i] != 0 else 1) for i, coord in enumerate(obj.position)]
    return building, rooms

In [5]:
def build_graph(building, levels, rooms, panoramas):
    G = nx.Graph()
    
    def add_node_with_attributes(G, node_id, attributes):
        G.add_node(node_id, **attributes.__dict__)
    
    add_node_with_attributes(G, "building", building)

    for level in levels:
        level_id = f"level_{level.level_index}"
        add_node_with_attributes(G, level_id, level)
        G.add_edge("building", level_id, type='has_level')

    for room in rooms:
        room_id = f"room_{room.region_index}"
        add_node_with_attributes(G, room_id, room)
        
        level_id = f"level_{room.level_index}"
        G.add_edge(level_id, room_id, type='has_room')
        
        for obj in room.objects:
            obj_id = f"object_{obj.object_index}"
            add_node_with_attributes(G, obj_id, obj)
            G.add_edge(room_id, obj_id, type='contains')
        
        for portal in room.portals:
            portal_id = f"portal_{portal.portal_index}"
            add_node_with_attributes(G, portal_id, portal)
            region0 = f"room_{portal.regions[0]}"
            region1 = f"room_{portal.regions[1]}"
            G.add_edge(region0, region1, type='connected_to')
    
    for panorama in panoramas:
        panorama_id = f"panorama_{panorama.panorama_index}"
        add_node_with_attributes(G, panorama_id, panorama)
        
        room_id = f"room_{panorama.region_index}"
        G.add_edge(room_id, panorama_id, type='has_panorama')
        
        for image in panorama.images:
            image_id = f"image_{image.image_index}"
            add_node_with_attributes(G, image_id, image)
            G.add_edge(panorama_id, image_id, type='contains')

    return G

# 转换为PyG数据
def graph_to_pyg_data(G):
    node_attributes_set = set(G.nodes[next(iter(G.nodes))].keys())
    for node in G.nodes:
        if set(G.nodes[node].keys()) != node_attributes_set:
            raise ValueError(f"Node {node} attributes {set(G.nodes[node].keys())} do not match {node_attributes_set}")
    
    data = from_networkx(G)
    data.x = torch.tensor([G.nodes[node]['position'] + G.nodes[node].get('bbox', [0, 0, 0, 0, 0, 0]) for node in G.nodes], dtype=torch.float)
    data.y = torch.tensor([G.nodes[node]['category_index'] if 'category_index' in G.nodes[node] else 0 for node in G.nodes], dtype=torch.long)
    
    return data

def normalize_features(data):
    scaler = StandardScaler()
    data.x = torch.tensor(scaler.fit_transform(data.x), dtype=torch.float)
    return data

def normalize_edge_features(data):
    if data.edge_attr is not None:
        scaler = StandardScaler()
        data.edge_attr = torch.tensor(scaler.fit_transform(data.edge_attr), dtype=torch.float)
    return data

def add_train_test_masks(data, train_ratio=0.8):
    num_nodes = data.num_nodes
    indices = torch.randperm(num_nodes)
    train_size = int(train_ratio * num_nodes)
    train_indices = indices[:train_size]
    test_indices = indices[train_size:]

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_indices] = True
    test_mask[test_indices] = True

    data.train_mask = train_mask
    data.test_mask = test_mask

json_file = '17DRP5sb8fy.json'
building, levels, rooms, panoramas = load_json_to_classes(json_file)
building, rooms = normalize_positions(building, rooms)
G = build_graph(building, levels, rooms, panoramas)
data = graph_to_pyg_data(G)
data = normalize_features(data)
data = normalize_edge_features(data)
add_train_test_masks(data, train_ratio=0.8)

ValueError: Node room_0 attributes {'region_index', 'level_index', 'node_type', 'position', 'objects', 'height', 'bbox', 'portals', 'label'} do not match {'node_type', 'name', 'position', 'bbox', 'label'}

In [None]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

def train_and_evaluate(model, data, optimizer, epochs=20):
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        _, pred = model(data.x, data.edge_index).max(dim=1)
        correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
        acc = correct / int(data.test_mask.sum())
        print(f'Epoch {epoch+1}, Loss: {loss.item()}, Test Accuracy: {acc}')

In [None]:
torch.manual_seed(42)

# 创建模型和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(in_channels=data.x.size(1), hidden_channels=16, out_channels=2).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 训练和评估模型
train_and_evaluate(model, data, optimizer)

# 检查标签分布
train_labels = data.y[data.train_mask]
test_labels = data.y[data.test_mask]
print(f"Training labels distribution: {torch.bincount(train_labels)}")
print(f"Testing labels distribution: {torch.bincount(test_labels)}")