In [1]:
from torch_geometric.datasets import ZINC
import torch
import matplotlib.animation
import matplotlib.pyplot as plt
import networkx as nx
import torch_geometric
from tqdm import tqdm
from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv

from imitation.model.graph_diffusion import DiffusionOrderingNetwork, DenoisingNetwork
from benchmarks.GraphARM.utils import NodeMasking
from benchmarks.GraphARM.grapharm import GraphARM

from imitation.dataset.robomimic_graph_dataset import RobomimicGraphDataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# custom graph dataset for robomimic
dataset = RobomimicGraphDataset(dataset_path="/home/caio/workspace/GraphDiffusionImitate/data/square/ph/low_dim_v141.hdf5",
                                action_keys=['robot0_joint_vel'],
                                pred_horizon=4,
                                obs_horizon=2,
                                action_horizon=2,
                                object_state_sizes=[
                                    {"name": "nut_pos", "size": 3},
                                    {"name": "nut_quat", "size": 4},
                                    {"name": "nut_to_eef_pos", "size": 3},
                                    {"name": "nut_to_eef_quat", "size": 4}
                                ],
                                object_state_keys=["nut_pos"],
                                num_objects = 1,
                                mode="end-effector"
                                # mode = "joint-space"
                                )
dataloader = torch_geometric.data.DataLoader(dataset,
                                             batch_size=1, # does not work with batch_size > 1
                                             shuffle=True)

Processing...




100%|██████████| 200/200 [00:32<00:00,  6.18it/s]
Done!


In [23]:
# VAE model
from torch_geometric.nn import  VGAE
import torch.nn as nn
import torch.nn.functional as F

in_channels = dataset[0].x.size(-1)
out_channels = 2


model = VGAE(encoder=Sequential('x, edge_index', [
    (GCNConv(in_channels, 8), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(8, 4), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(4, out_channels),
]),
    decoder = Sequential('x, edge_index', [
        (GCNConv(out_channels, 4), 'x, edge_index -> x'),
        ReLU(inplace=True),
        (GCNConv(4, 8), 'x, edge_index -> x'),
        ReLU(inplace=True),
        (GCNConv(8, in_channels), 'x, edge_index -> x'),
    ]))
model

VGAE(
  (encoder): Sequential(
    (0): GCNConv(4, 8)
    (1): ReLU(inplace=True)
    (2): GCNConv(8, 4)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=4, out_features=2, bias=True)
  )
  (decoder): Sequential(
    (0): GCNConv(2, 4)
    (1): ReLU(inplace=True)
    (2): GCNConv(4, 8)
    (3): ReLU(inplace=True)
    (4): GCNConv(8, 4)
  )
)

In [24]:

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# reconcstruction loss + KL divergence
def loss_function(preds, labels, mu, logvar, n_nodes, norm):
    cost = norm * F.binary_cross_entropy_with_logits(preds, labels)
    KLD = -0.5 / n_nodes * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2)))
    return cost + KLD

with tqdm(range(2)) as pbar:
    for epoch in pbar:
        for batch in dataloader:
            batch = batch.to('cpu')
            optimizer.zero_grad()
            # print(batch.x.float())
            # print(batch.edge_index.int())
            out = model.encoder(batch.x.float(), batch.edge_index.int())
            # print(out)
            out = model.decoder(out, batch.edge_index.int())
            loss = loss_function(out, batch.x.float(), out[0], out[1], batch.num_nodes, 1)
            loss.backward()
            nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
            optimizer.step()
            # save model
            pbar.set_description(f"Loss: {loss.item():.4f}")
        torch.save(model.state_dict(), "model.pt")

Loss: 0.6865: 100%|██████████| 2/2 [01:43<00:00, 51.80s/it]


In [25]:
# sample from latent space
x_latent = torch.randn((8,out_channels))
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6],
                           [1, 2, 3, 4, 5, 6, 7]])
sample = model.decode(x_latent, edge_index)
sample = sample.detach().numpy()
sample

array([[ 0.01632299, -0.31006137, -0.00354026, -0.24675733],
       [ 0.01859851, -0.28351384, -0.01827491, -0.28042677],
       [-0.11494303, -0.18433139, -0.11911406, -0.17048085],
       [-0.2286244 , -0.26207915, -0.24127063, -0.09794625],
       [-0.26696587, -0.30530635, -0.29923028, -0.11033972],
       [-0.2731024 , -0.2253134 , -0.2775925 , -0.14660078],
       [-0.2634911 , -0.12398656, -0.23169023, -0.17857623],
       [-0.23135492, -0.10476802, -0.1891623 , -0.16147882]],
      dtype=float32)

In [60]:
dataset.x = torch.tensor(sample)
dataset.edge_index = edge_index
dataset.edge_attr = torch.tensor([[1.0]]*len(edge_index[0]))

In [61]:
diff_ord_net = DiffusionOrderingNetwork(node_feature_dim=dataset[0].x.shape[1],
                                        num_edge_types=3,
                                        num_layers=3,
                                        out_channels=1)

# masker = NodeMasking(dataset)


denoising_net = DenoisingNetwork(
    node_feature_dim=dataset[0].x.shape[1],
    edge_feature_dim=dataset.num_edge_features,
    num_edge_types=3,
    num_layers=7,
    # hidden_dim=32
)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device}")

grapharm = GraphARM(
    dataset=dataset,
    denoising_network=denoising_net,
    diffusion_ordering_network=diff_ord_net,
    device=device
)


Using device cuda




In [62]:
graph = dataset[0]
new_node_type, new_connections = grapharm.predict_new_node(graph, sampling_method="sample", preprocess=False)

AttributeError: 'NoneType' object has no attribute 'reshape'