In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from deepgd.model import Generator
from deepgd.data import GraphDrawingData
from deepgd.datasets import  RomeDataset

from deepgd.metrics import Stress

In [5]:
# Choose the compute device. Start from the CPU and upgrade to an accelerator
# when one is available; CUDA is checked last so it overrides MPS if both
# backends report availability.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
if torch.cuda.is_available():
    device = "cuda"

  return torch._C._cuda_getDeviceCount() > 0


In [6]:
# Training hyper-parameters used by the data loaders, optimiser and scheduler.
batch_size = 32  # graphs per mini-batch
lr = 1e-3        # initial learning rate passed to AdamW
decay = 0.998    # ExponentialLR gamma: LR is multiplied by this on each scheduler step

====================================================================================

##### Erdős–Rényi Dataset

In [9]:
from deepgd.datasets import ErdosRenyiDataset

# NOTE(review): the recorded run of this cell failed during dataset
# construction (AttributeError on `_transforms`). The Rome cell calls
# GraphDrawingData.set_optional_fields(...) before building its dataset;
# confirm whether the same setup is required here.
dataset = ErdosRenyiDataset(
    root="path/to/dataset",
    node_sizes=[20, 40, 80],
    probabilities=[0.2, 0.4, 0.6, 0.8],
    num_graphs_per_combination=10,
    datatype=GraphDrawingData
)

# Materialise the graphs into a plain list before attaching attributes, the
# same way the Rome pipeline does. Iterating the dataset object directly may
# hand out temporary copies, in which case an assigned `pos` would be lost.
er_datalist = list(dataset)

# Give every graph a random initial 2-D position for each of its nodes.
for data in er_datalist:
    data.pos = torch.rand((data.num_nodes, 2))

# Batch the graphs with the PyG loader. The bare name `DataLoader` is never
# imported in this notebook, so it is referenced through `pyg.loader` here.
loader = pyg.loader.DataLoader(er_datalist, batch_size=batch_size, shuffle=True)


Processing...
Done!
Transform graphs:   0%|          | 0/1 [00:00<?, ?it/s]


AttributeError: 'NoneType' object has no attribute '_transforms'

##### Rome Dataset

In [None]:
# Configure GraphDrawingData before the dataset is built.
# NOTE(review): presumably this marks the listed fields as optional so that
# they need not be computed/stored — confirm against deepgd.data.
GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])

# The first (and only header-less) column of the index file selects the graphs.
rome_index = pd.read_csv("assets/rome_index.txt", header=None)[0]
dataset = RomeDataset(index=rome_index)

# Precomputed layouts, indexed in the same order as the dataset (see the cell
# that assigns `layouts[i]` to the i-th graph).
# NOTE: allow_pickle=True unpickles arbitrary objects — only load trusted files.
layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)

====================================================================================

In [5]:

# Architecture hyper-parameters of the generator network.
# `edge_attr_dim=2` matches the two per-pair features built by
# get_edge_features (distance and inverse squared distance).
params = Generator.Params(
    num_blocks=11,
    block_depth=3,
    block_width=8,
    block_output_dim=8,
    edge_net_depth=2,
    edge_net_width=16,
    edge_attr_dim=2,
    node_attr_dim=2,
)

# Build the network and move it to the selected device.
generator = Generator(params=params)
model = generator.to(device)

# Training objective: maps each metric (callable) to its weight in the total
# loss. Only stress is active; the remaining candidates are kept for reference.
criteria = {
    Stress(): 1,
    # dgd.EdgeVar(): 0,
    # dgd.Occlusion(): 0,
    # dgd.IncidentAngle(): 0,
    # dgd.TSNEScore(): 0,
}

Downloading https://www.graphdrawing.org/download/rome-graphml.tgz
Extracting datasets/Rome/raw/rome-graphml.tgz
Processing...


Loading graphs:   0%|          | 0/11534 [00:00<?, ?it/s]

Done!


Transform graphs:   0%|          | 0/11531 [00:00<?, ?it/s]

In [7]:
# AdamW over all model parameters; the learning rate is multiplied by `decay`
# every time `scheduler.step()` is called.
optim = torch.optim.AdamW(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=decay)

In [8]:
# Materialise the dataset so attributes assigned below stay attached to the
# graph objects that are later fed to the loaders.
datalist = list(dataset)

# Attach the precomputed layout of graph i as its float32 node positions.
for graph_idx, graph in enumerate(datalist):
    layout = layouts[graph_idx]
    graph.pos = torch.tensor(layout).float()

In [9]:
# Fixed positional split of the graph list:
#   [0, 10000)      -> training
#   [10000, 11000)  -> test
#   [11000, end)    -> validation
train_graphs = datalist[:10000]
test_graphs = datalist[10000:11000]
val_graphs = datalist[11000:]

# Only the training loader shuffles; evaluation order stays deterministic.
train_loader = pyg.loader.DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
val_loader = pyg.loader.DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
test_loader = pyg.loader.DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

In [10]:
def generate_init_pos(batch):
    """Return the initial node positions handed to the model for ``batch``.

    The layout stored in ``batch.pos`` is rescaled per graph by
    ``rescale_by_stress`` using the batch's ``apsp_attr``, ``perm_index`` and
    ``batch`` fields.

    Args:
        batch: batched graph object exposing ``pos``, ``apsp_attr``,
            ``perm_index`` and ``batch``.

    Returns:
        The rescaled positions, same shape as ``batch.pos``.
    """
    # Disabled alternative: start from random positions instead, i.e.
    # torch.rand_like(batch.pos).
    return rescale_by_stress(
        pos=batch.pos,
        apsp=batch.apsp_attr,
        edge_index=batch.perm_index,
        batch_index=batch.batch,
    )

def get_edge_features(all_pair_shortest_path):
    """Build per-pair edge features from shortest-path distances.

    Args:
        all_pair_shortest_path: tensor of distances ``d``; a new axis is
            inserted at position 1 before the features are concatenated.
            # assumes a 1-D tensor with one entry per node pair — TODO confirm

    Returns:
        Tensor holding ``d`` and ``1 / d**2`` concatenated along the last
        dimension (shape ``(N, 2)`` for a 1-D input of length ``N``).
    """
    dist = all_pair_shortest_path.unsqueeze(1)
    inv_sq_dist = dist.square().reciprocal()
    return torch.cat((dist, inv_sq_dist), dim=-1)

def rescale_by_stress(pos, apsp, edge_index, batch_index):
    """Uniformly rescale each graph's layout based on its distance ratios.

    For every node pair in ``edge_index`` the ratio ``r = layout distance /
    apsp`` is computed. Each graph's positions are then divided by
    ``sum(r**2) / sum(r)`` taken over that graph's pairs, which is the scale
    ``s`` minimising ``sum((r / s - 1)**2)``.

    Args:
        pos: node positions, indexed by node along dim 0.
        apsp: target distance for each pair in ``edge_index``.
        edge_index: two-row index tensor of (source, destination) node pairs.
        batch_index: graph id of every node.

    Returns:
        ``pos`` divided by the scale of the graph each node belongs to.
    """
    src, dst = edge_index[0], edge_index[1]
    ratio = (pos[dst] - pos[src]).norm(dim=1) / apsp

    # Each pair is attributed to the graph of its source node; the scatter
    # calls reduce the per-pair values to one value per graph.
    pair_graph = batch_index[src]
    ratio_sq_per_graph = pyg.utils.scatter(ratio ** 2, pair_graph)
    ratio_per_graph = pyg.utils.scatter(ratio, pair_graph)

    scale = ratio_sq_per_graph / ratio_per_graph
    # Broadcast the per-graph scale back to nodes and over coordinate columns.
    return pos / scale[batch_index][:, None]

In [12]:
# Restore model weights from a saved checkpoint, remapping tensors onto the
# currently selected device. The final expression reports key matching.
checkpoint_state = torch.load("model_359.pt", map_location=device)
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [None]:
def _weighted_criteria_loss(batch):
    """Run the model once on ``batch`` and return the weighted sum of criteria.

    The forward pass and the stress-based rescaling do not depend on the
    individual criterion, so they are computed a single time and shared by all
    entries of ``criteria`` instead of being repeated per criterion.

    Args:
        batch: batched graph object already moved to ``device``; must expose
            ``pos``, ``perm_index``, ``apsp_attr`` and ``batch``.

    Returns:
        Sum over ``criteria`` of ``weight * criterion(pos, perm_index,
        apsp_attr, batch)``.
    """
    pred = model(
        init_pos=generate_init_pos(batch),
        edge_index=batch.perm_index,
        edge_attr=get_edge_features(batch.apsp_attr),
        batch_index=batch.batch,
    )
    pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
    return sum(
        w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
        for c, w in criteria.items()
    )


num_epochs = 1000

for epoch in range(num_epochs):
    # ---- training pass: one optimiser step per mini-batch ----
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        model.zero_grad()
        loss = _weighted_criteria_loss(batch)
        loss.backward()
        optim.step()
        losses.append(loss.item())
    # The learning rate decays once per epoch, not per batch.
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')

    # ---- validation pass: no gradients, model in eval mode ----
    with torch.no_grad():
        model.eval()
        losses = []
        for batch in tqdm(val_loader, disable=True):
            batch = batch.to(device)
            loss = _weighted_criteria_loss(batch)
            losses.append(loss.item())
        print(f'[Epoch {epoch}] Val Loss: {np.mean(losses)}')

In [15]:
# Evaluate the model on the held-out test split and report the mean batch loss.
with torch.no_grad():
    model.eval()
    losses = []
    for batch in tqdm(test_loader, disable=True):
        batch = batch.to(device)
        # The forward pass and rescaling are independent of the criterion, so
        # they run once per batch and are shared by every criterion below.
        pred = model(
            init_pos=generate_init_pos(batch),
            edge_index=batch.perm_index,
            edge_attr=get_edge_features(batch.apsp_attr),
            batch_index=batch.batch,
        )
        pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
        loss = sum(
            w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
            for c, w in criteria.items()
        )
        losses.append(loss.item())
    print(f'Test Loss: {np.mean(losses)}')

Test Loss: 234.92233180999756
