In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from deepgd.model import Generator
from deepgd.data import GraphDrawingData
from deepgd.datasets import  RomeDataset

from deepgd.metrics import Stress

In [5]:
np.random.seed(721)
torch.manual_seed(721)
torch.cuda.manual_seed_all(721)

In [6]:
device = "cpu"
for backend, device_name in {
    torch.backends.mps: "mps",
    torch.cuda: "cuda",
}.items():
    if backend.is_available():
        device = device_name

  return torch._C._cuda_getDeviceCount() > 0


In [7]:
batch_size = 32
lr = 0.001
decay = 0.998

====================================================================================

##### ErdosRenyi Dataset

##### Rome Dataset

In [8]:
GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])
dataset = RomeDataset(
    index=pd.read_csv("assets/rome_index.txt", header=None)[0],
)
print(len(dataset))
layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)

inside init
raw_file_names
inside _parse_metadata


  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  self.data, self.slices = torch.load(self.data_path)


processed_file_names
processed_file_names
processed_file_names


Transform graphs:   0%|          | 0/347 [00:00<?, ?it/s]

data dictionary size: 347
Graph erdos_renyi_20_0.2_graph_0 not found
Graph erdos_renyi_20_0.2_graph_1 not found
Graph erdos_renyi_20_0.2_graph_6 not found
347


====================================================================================

In [9]:

params = Generator.Params(
    num_blocks=11,
    block_depth=3,
    block_width=8,
    block_output_dim=8,
    edge_net_depth=2,
    edge_net_width=16,
    edge_attr_dim=2,
    node_attr_dim=2,
)
model = Generator(
    params=params,
).to(device)
criteria = {
    Stress(): 1,
    # dgd.EdgeVar(): 0,
    # dgd.Occlusion(): 0,
    # dgd.IncidentAngle(): 0,
    # dgd.TSNEScore(): 0,
}

In [10]:
optim = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=decay)

In [11]:
datalist = list(dataset)
for i, data in enumerate(datalist):
    data.pos = torch.tensor(layouts[i]).float()

In [12]:
train_loader = pyg.loader.DataLoader(datalist[:10000], batch_size=batch_size, shuffle=True)
val_loader = pyg.loader.DataLoader(datalist[11000:], batch_size=batch_size, shuffle=False)
# test_loader = pyg.loader.DataLoader(datalist[10000:11000], batch_size=batch_size, shuffle=False)
test_loader = pyg.loader.DataLoader(datalist[:300], batch_size=batch_size, shuffle=False)

In [29]:
import networkx as nx
def generate_init_pos(batch):
    # pos = torch.rand_like(batch.pos)
    print("batch.G",batch.G)
    combined_graph = batch.G[0]
    # Combine each subsequent graph
    for g in batch.G[1:]:
        combined_graph = nx.disjoint_union(combined_graph, g)
    pos = rescale_by_stress(
        # pos = torch.tensor(list(nx.spectral_layout(nx.compose(*batch.G)).values()), dtype=torch.float32),
        pos = torch.tensor(list(nx.spectral_layout(combined_graph).values()), dtype=torch.float32),
        apsp=batch.apsp_attr,
        edge_index=batch.perm_index,
        batch_index=batch.batch,
    )
    print(" shape after rescale_by_stress:", pos.shape)
    return pos


def get_edge_features(all_pair_shortest_path):
    return torch.cat([
        all_pair_shortest_path[:, None],
        1 / all_pair_shortest_path[:, None].square()
    ], dim=-1)

# def rescale_by_stress(pos, apsp, edge_index, batch_index):  
#     src_pos, dst_pos = pos[edge_index[0]], pos[edge_index[1]]
#     dist = (dst_pos - src_pos).norm(dim=1)
#     u_over_d = dist / apsp
    
#     # Ensuring we cover all indices that pos might reference
#     max_index = pos.size(0)
#     scatterd_u_over_d_2 = pyg.utils.scatter(u_over_d ** 2, batch_index[edge_index[0]], dim_size=max_index)
#     scatterd_u_over_d = pyg.utils.scatter(u_over_d, batch_index[edge_index[0]], dim_size=max_index)

#     scale = scatterd_u_over_d_2 / scatterd_u_over_d
    
#     # print("pos shape:", pos.shape)
#     # print("scale shape:", scale.shape)
#     # print("adjusted scale shape for operation:", scale[batch_index][:, None].shape)
#     # print("Max edge index:", torch.max(edge_index).item())
#     # print("Number of positions available:", pos.size(0))

#     scaled_pos = pos / scale[:, None]
#     # print("Shape of scaled positions:", scaled_pos.shape)

#     return scaled_pos

def rescale_by_stress(pos, apsp, edge_index):  
    src_pos, dst_pos = pos[edge_index[0]], pos[edge_index[1]]
    dist = (dst_pos - src_pos).norm(dim=1)
    
    # 避免除以零
    mask = apsp != 0
    if mask.sum() == 0:
        return pos  # 如果所有 apsp 都为零，直接返回原始位置
    u_over_d = dist[mask] / apsp[mask]
    
    # 计算缩放因子
    numerator = pyg.utils.scatter(u_over_d ** 2, edge_index[0][mask], dim_size=pos.size(0))
    denominator = pyg.utils.scatter(u_over_d, edge_index[0][mask], dim_size=pos.size(0))
    scale = numerator / denominator
    
    # 避免缩放因子为零
    scale[scale == 0] = 1.0
    
    scaled_pos = pos / scale[:, None]
    return scaled_pos


In [24]:
model.load_state_dict(torch.load("model_359.pt", map_location=device))

  model.load_state_dict(torch.load("model_359.pt", map_location=device))


<All keys matched successfully>

In [25]:
# for epoch in range(1000):
for epoch in range(0):
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        model.zero_grad()
        loss = 0
        for c, w in criteria.items():
            pred = model(
                init_pos=generate_init_pos(batch),
                edge_index=batch.perm_index,
                edge_attr=get_edge_features(batch.apsp_attr),
                batch_index=batch.batch,
            )
            pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        loss.backward()
        optim.step()
        losses.append(loss.item())
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')
    with torch.no_grad():
        model.eval()
        losses = []
        for batch in tqdm(val_loader, disable=True):
            batch = batch.to(device)
            loss = 0
            for c, w in criteria.items():
                pred = model(
                    init_pos=generate_init_pos(batch),
                    edge_index=batch.perm_index,
                    edge_attr=get_edge_features(batch.apsp_attr),
                    batch_index=batch.batch,
                )
                pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
                loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
            losses.append(loss.item())
        print(f'[Epoch {epoch}] Val Loss: {np.mean(losses)}')

================================ DEBUG ================================

In [26]:
with torch.no_grad():
    model.eval()
    losses = []
    for batch in tqdm(test_loader, disable=True):
        print(batch)
        batch = batch.to(device)
        
        # print("Batch size:", batch.batch.size())
        # print("perm index shape:", batch.perm_index.shape)
        # print("apsp_attr shape:", batch.apsp_attr.shape)


        loss = 0
        for c, w in criteria.items():

            pred = model(
                init_pos=generate_init_pos(batch),
                edge_index=batch.perm_index,
                edge_attr=get_edge_features(batch.apsp_attr),
                batch_index=batch.batch,
            )
            pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        losses.append(loss.item())
    print(f'Test Loss: {np.mean(losses)}')

GraphDrawingDataBatch(G=[32], perm_index=[2, 12160], edge_metaindex=[4178], apsp_attr=[12160], perm_weight=[12160], aggr_metaindex=[12160], pos=[1535, 2], name=[32], n=[32], m=[32], edge_pair_metaindex=[2, 73450], num_nodes=640, batch=[640], ptr=[33])
batch.G [<networkx.classes.digraph.DiGraph object at 0x7f7c8f054280>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055030>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055e40>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055ed0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055f60>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055ff0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056080>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056110>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0561a0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056230>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0562c0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056350>, <networ

TypeError: rescale_by_stress() got an unexpected keyword argument 'batch_index'

In [27]:
import numpy as np
import matplotlib.pyplot as plt
def visualize_graph_v2(batch, perm_index, apsp_attr, pos=None):
    print("inside visualize_graph_v2", batch.G)
    total = 0
    for G in batch.G:
        total += G.number_of_nodes()
    print("total",total)
    print("dimension of shape", pos.shape)

    x = 0
    for i in range(len(batch.G)):
        print(i, pos[x:x+batch.G[i].number_of_nodes()])
        nx.draw_networkx(batch.G[i], pos=pos[x:x+batch.G[i].number_of_nodes()], with_labels=False, cmap="Set2")
        # print("print(pos): ----",pos.shape)
        # pos = pos[batch.G[i].number_of_nodes():]
        # print("print(pos): ----",pos.shape)
        x += batch.G[i].number_of_nodes()
        break

    plt.show()
    # return g2
    return batch.G[0]

with torch.no_grad():
    model.eval()
    losses = []
    
    for batch in tqdm(test_loader, disable=True):
        # print(batch)
        batch = batch.to(device)
        
        # print("Batch size:", batch.batch.size())
        # print("perm index shape:", batch.perm_index.shape)
        # print("apsp_attr shape:", batch.apsp_attr.shape)

        loss = 0
        for c, w in criteria.items():
            pred = model(
                init_pos=generate_init_pos(batch),
                edge_index=batch.perm_index,
                edge_attr=get_edge_features(batch.apsp_attr),
                batch_index=batch.batch,
            )
            # print("pred: ", pred)
            pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)

        losses.append(loss.item())
        
        visualize_graph_v2(batch, batch.perm_index.cpu().numpy(), batch.apsp_attr.cpu().numpy(), pred)

    print(f'Test Loss: {np.mean(losses)}')


batch.G [<networkx.classes.digraph.DiGraph object at 0x7f7c8f054280>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055030>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055e40>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055ed0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055f60>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055ff0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056080>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056110>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0561a0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056230>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0562c0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056350>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0563e0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056470>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056500>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056590>, <networkx.class

TypeError: rescale_by_stress() got an unexpected keyword argument 'batch_index'

In [28]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import math

def visualize_graph_v2(batch, perm_index, apsp_attr, pos=None):
    print("inside visualize_graph_v2", batch.G)
    total = 0
    for G in batch.G:
        total += G.number_of_nodes()
    print("total", total)
    print("dimension of shape", pos.shape)

    x = 0
    num_graphs = len(batch.G)
    num_cols = 4
    num_rows = math.ceil(num_graphs / num_cols)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*5, num_rows*5))
    axes = axes.flatten() 

    for i in range(num_graphs):
        ax = axes[i]
        graph_pos = pos[x:x + batch.G[i].number_of_nodes()]

        if torch.isnan(graph_pos).any():
            print(f"Warning: NaN values found in graph {i} positions! Replacing NaN values with (0,0).")
            graph_pos = torch.nan_to_num(graph_pos, nan=0.0)

        print(i, graph_pos)
        nx.draw_networkx(batch.G[i], pos=graph_pos.numpy(), with_labels=False, cmap="Set2", ax=ax)
        ax.set_axis_off()
        x += batch.G[i].number_of_nodes()

    for j in range(i+1, len(axes)):
        axes[j].set_visible(False)

    plt.tight_layout()
    plt.show()

    return batch.G[0]



# Testing the modificatio n inside the main loop
with torch.no_grad():
    model.eval()
    losses = []

    for batch in tqdm(test_loader, disable=True):
        batch = batch.to(device)

        loss = 0
        for c, w in criteria.items():
            pred = model(
                init_pos=generate_init_pos(batch),
                edge_index=batch.perm_index,
                edge_attr=get_edge_features(batch.apsp_attr),
                batch_index=batch.batch,
            )
            pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            loss += w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)

        break
        losses.append(loss.item())

        visualize_graph_v2(batch, batch.perm_index.cpu().numpy(), batch.apsp_attr.cpu().numpy(), pred)

    print(f'Test Loss: {np.mean(losses)}')



batch.G [<networkx.classes.digraph.DiGraph object at 0x7f7c8f054280>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055030>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055e40>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055ed0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055f60>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f055ff0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056080>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056110>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0561a0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056230>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0562c0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056350>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f0563e0>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056470>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056500>, <networkx.classes.digraph.DiGraph object at 0x7f7c8f056590>, <networkx.class

TypeError: rescale_by_stress() got an unexpected keyword argument 'batch_index'