In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
import pandas as pd
import numpy as np
import torch
import torch_geometric as pyg
from tqdm.auto import *

from deepgd.model import Generator
from deepgd.data import GraphDrawingData
from deepgd.datasets import  RomeDataset

from deepgd.metrics import Stress

In [23]:
# Seed NumPy and PyTorch (CPU generator and every CUDA device) with the same
# value so repeated runs draw the same random numbers.
SEED = 721
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [24]:
# Choose the compute device. Both backends are probed in order and the later
# successful probe wins, so the priority is CUDA > MPS > CPU.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
if torch.cuda.is_available():
    device = "cuda"

In [25]:
# Optimisation hyperparameters used by the cells below.
batch_size = 32  # graphs per mini-batch in the data loaders
lr = 1e-3        # learning rate handed to the AdamW optimizer
decay = 0.998    # gamma of the exponential learning-rate scheduler

====================================================================================

##### ErdosRenyi Dataset

In [7]:
from deepgd.datasets import  ErdosRenyiDataset
from torch_geometric.data import DataLoader
import torch
import networkx as nx
import numpy as np

# NOTE(review): the recorded run of this cell failed inside the dataset
# constructor with "AttributeError: 'NoneType' object has no attribute
# '_transforms'". The Rome cell calls GraphDrawingData.set_optional_fields(...)
# before building its dataset -- possibly that call is required here too; confirm.
dataset = ErdosRenyiDataset(
    root="assets/erdorenyi",
    node_sizes=[20, 40, 80],
    probabilities=[0.2, 0.4, 0.6, 0.8],
    num_graphs_per_combination=10,
    datatype=GraphDrawingData,
)

# Materialise the graphs first and build the loader from that same list.
# Objects yielded while iterating a PyG dataset are typically created (or
# copied) per access, so a `pos` assigned to them would not be visible to a
# loader that re-reads the dataset. This mirrors how the Rome layouts are
# attached via `datalist = list(dataset)`.
erdos_renyi_graphs = list(dataset)
for data in erdos_renyi_graphs:
    # Uniform random 2-D starting position for every node.
    data.pos = torch.rand((data.num_nodes, 2))

# Use the shared `batch_size` setting and the non-deprecated loader location.
loader = pyg.loader.DataLoader(erdos_renyi_graphs, batch_size=batch_size, shuffle=True)


Transform graphs:   0%|          | 0/1 [00:00<?, ?it/s]

AttributeError: 'NoneType' object has no attribute '_transforms'

##### Rome Dataset

In [26]:
# Configure the optional fields of GraphDrawingData before the dataset is built.
GraphDrawingData.set_optional_fields(["edge_pair_metaindex", "face", "rng"])

# The index is the first column of a header-less text file.
rome_index = pd.read_csv("assets/rome_index.txt", header=None)[0]
dataset = RomeDataset(index=rome_index)
# NOTE(review): the recorded output reports 118 graphs and mentions
# "erdos_renyi_*" graph names -- verify this really loads the Rome collection.
print(len(dataset))

# Precomputed layouts, later indexed by graph position in the dataset.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load this
# file from a trusted source.
layouts = np.load("assets/layouts/pmds.npy", allow_pickle=True)

inside init
raw_file_names
inside _parse_metadata


  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
Processing...


processed_file_names
inside process
inside generate
processed_file_names
processed_file_names
processed_file_names
processed_file_names


Done!
  self.data, self.slices = torch.load(self.data_path)


Transform graphs:   0%|          | 0/118 [00:00<?, ?it/s]

Graph erdos_renyi_20_0.2_graph_4 not found in data_dict
Graph erdos_renyi_20_0.2_graph_7 not found in data_dict
118


====================================================================================

In [27]:

# Generator architecture settings, collected in one mapping and expanded
# into Generator.Params as keyword arguments.
generator_config = dict(
    num_blocks=11,
    block_depth=3,
    block_width=8,
    block_output_dim=8,
    edge_net_depth=2,
    edge_net_width=16,
    edge_attr_dim=2,
    node_attr_dim=2,
)
params = Generator.Params(**generator_config)
model = Generator(params=params).to(device)

# Loss terms mapped to their weights; the weighted terms are summed into the
# total loss. Only Stress is active. Previously listed but disabled (weight 0):
# dgd.EdgeVar, dgd.Occlusion, dgd.IncidentAngle, dgd.TSNEScore.
criteria = {Stress(): 1}

In [28]:
# AdamW over all model parameters; the scheduler multiplies the learning rate
# by `decay` on every scheduler.step() call.
optim = torch.optim.AdamW(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=decay)

In [30]:
# Materialise the dataset once so the `pos` assigned below stays attached to
# the very objects the data loaders will batch.
datalist = list(dataset)

# layouts[i] is used as the node positions of graph i, so the two collections
# must line up. Check this here: a mismatch otherwise only surfaces much later
# as an opaque tensor-size error inside the model / rescaling code.
if len(layouts) < len(datalist):
    raise ValueError(
        f"Only {len(layouts)} layouts available for {len(datalist)} graphs"
    )

for i, data in enumerate(datalist):
    layout = torch.tensor(layouts[i]).float()
    if layout.shape[0] != data.num_nodes:
        raise ValueError(
            f"Layout {i} has {layout.shape[0]} rows but graph {i} has "
            f"{data.num_nodes} nodes"
        )
    data.pos = layout

In [31]:
# Split boundaries: graphs [0, TRAIN_END) train, [VAL_START, end) validate.
TRAIN_END, VAL_START = 10000, 11000

num_graphs = len(datalist)
if num_graphs > VAL_START:
    train_end, val_start = TRAIN_END, VAL_START
else:
    # With fewer graphs than the fixed boundaries assume, the validation slice
    # would be empty (and its mean loss NaN). Fall back to an 80/10/10 split.
    train_end, val_start = int(0.8 * num_graphs), int(0.9 * num_graphs)

train_loader = pyg.loader.DataLoader(datalist[:train_end], batch_size=batch_size, shuffle=True)
val_loader = pyg.loader.DataLoader(datalist[val_start:], batch_size=batch_size, shuffle=False)
# Held-out alternative for the test loader:
# test_loader = pyg.loader.DataLoader(datalist[train_end:val_start], batch_size=batch_size, shuffle=False)
# The active test loader covers the whole dataset, training graphs included.
test_loader = pyg.loader.DataLoader(datalist, batch_size=batch_size, shuffle=False)

In [36]:
def generate_init_pos(batch):
    """Return the batch's stored positions after `rescale_by_stress`.

    Reads `batch.pos`, `batch.apsp_attr`, `batch.perm_index` and `batch.batch`
    and forwards them to `rescale_by_stress`.
    """
    # Disabled alternative: random initialisation via torch.rand_like(batch.pos).
    return rescale_by_stress(
        batch.pos,
        batch.apsp_attr,
        batch.perm_index,
        batch.batch,
    )

def get_edge_features(all_pair_shortest_path):
    """Build two features per entry: the value d itself and 1 / d**2.

    A new axis is inserted at position 1 and the two columns are concatenated
    along the last dimension, so a length-N vector yields an (N, 2) tensor.
    """
    dist_col = all_pair_shortest_path.unsqueeze(1)
    inv_sq_col = 1 / dist_col.square()
    return torch.cat((dist_col, inv_sq_col), dim=-1)

def rescale_by_stress(pos, apsp, edge_index, batch_index):
    """Divide every graph's positions by a per-graph scale factor.

    For each node pair in `edge_index` the ratio r = ||pos_dst - pos_src|| / apsp
    is computed. Ratios are aggregated per graph with `pyg.utils.scatter`
    (its default reduction -- presumably a sum; confirm), grouped by the graph
    id of the pair's source node. The scale of a graph is
    agg(r**2) / agg(r), and each node position is divided by the scale of the
    graph it belongs to.
    """
    src, dst = edge_index[0], edge_index[1]
    ratio = (pos[dst] - pos[src]).norm(dim=1) / apsp
    pair_graph = batch_index[src]
    ratio_sq_per_graph = pyg.utils.scatter(ratio ** 2, pair_graph)
    ratio_per_graph = pyg.utils.scatter(ratio, pair_graph)
    scale_per_graph = ratio_sq_per_graph / ratio_per_graph
    # Broadcast each graph's scale over its nodes and both coordinates.
    return pos / scale_per_graph[batch_index][:, None]

In [37]:
# Restore the weights from the checkpoint file, mapped onto the selected device.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, and avoids executing arbitrary pickled code.
model.load_state_dict(torch.load("model_359.pt", map_location=device, weights_only=True))

  model.load_state_dict(torch.load("model_359.pt", map_location=device))


<All keys matched successfully>

In [38]:
# Number of training epochs. 0 skips training entirely; the preceding cell
# loads weights from "model_359.pt". Set to e.g. 1000 to train.
num_epochs = 0

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        model.zero_grad()
        # The prediction does not depend on the criterion, so it is computed
        # once per batch instead of once per loss term.
        pred = model(
            init_pos=generate_init_pos(batch),
            edge_index=batch.perm_index,
            edge_attr=get_edge_features(batch.apsp_attr),
            batch_index=batch.batch,
        )
        pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
        # Weighted sum of all loss terms.
        loss = sum(
            w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
            for c, w in criteria.items()
        )
        loss.backward()
        optim.step()
        losses.append(loss.item())
    # One learning-rate decay step per epoch.
    scheduler.step()
    print(f'[Epoch {epoch}] Train Loss: {np.mean(losses)}')

    # ---- validation pass (no gradients) ----
    with torch.no_grad():
        model.eval()
        losses = []
        for batch in tqdm(val_loader, disable=True):
            batch = batch.to(device)
            pred = model(
                init_pos=generate_init_pos(batch),
                edge_index=batch.perm_index,
                edge_attr=get_edge_features(batch.apsp_attr),
                batch_index=batch.batch,
            )
            pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
            loss = sum(
                w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
                for c, w in criteria.items()
            )
            losses.append(loss.item())
        # An empty validation set still reports "nan", but without the
        # RuntimeWarning that np.mean([]) would emit.
        val_loss = np.mean(losses) if losses else float("nan")
        print(f'[Epoch {epoch}] Val Loss: {val_loss}')

================================ DEBUG ================================

NameError: name 'pos' is not defined

In [39]:
# Evaluate the model on the test loader without tracking gradients.
# The earlier debug prints were removed: they read `pos` before it was assigned
# and `scale` / `batch_index`, which only exist inside rescale_by_stress, so the
# cell raised NameError before doing any work.
# NOTE(review): a recorded run failed with a tensor-size mismatch (1535 vs 640);
# likely batch.pos does not have one row per node of the batch -- verify that
# the attached layouts belong to the loaded graphs.
with torch.no_grad():
    model.eval()
    losses = []
    for batch in tqdm(test_loader, disable=True):
        batch = batch.to(device)
        # The prediction is independent of the criterion: compute it once.
        pred = model(
            init_pos=generate_init_pos(batch),
            edge_index=batch.perm_index,
            edge_attr=get_edge_features(batch.apsp_attr),
            batch_index=batch.batch,
        )
        pos = rescale_by_stress(pred, batch.apsp_attr, batch.perm_index, batch.batch)
        # Weighted sum of all loss terms.
        loss = sum(
            w * c(pos, batch.perm_index, batch.apsp_attr, batch.batch)
            for c, w in criteria.items()
        )
        losses.append(loss.item())
    print(f'Test Loss: {np.mean(losses)}')

RuntimeError: The size of tensor a (1535) must match the size of tensor b (640) at non-singleton dimension 0