In [2]:
import os
import pickle

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder
from torch import nn
from torch.nn import BatchNorm1d, ReLU, Linear, Sequential
from torch.utils.data import DataLoader
from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, global_add_pool
from torch_geometric.transforms import NormalizeFeatures

In [6]:
class SuperPixDGL(torch.utils.data.Dataset):
    """Superpixel graph dataset for a single split, served as DGL graphs.

    On construction the pickled superpixel data for ``split`` is read from
    ``<data_dir>/sample`` and per-sample arrays are prepared (node features,
    and, for the coordinate-based formats, adjacency matrix / edge lists /
    edge weights). The DGL graph itself is built lazily in ``__getitem__``.

    NOTE(review): this class uses ``dgl`` as well as the helpers
    ``compute_adjacency_matrix_images`` and ``compute_edges_list``; none of
    them is imported or defined in the visible part of this notebook, so they
    must already be available in the kernel before the class is used.
    """

    def __init__(self,
                 data_dir,
                 dataset,
                 split,
                 graph_format='edge_wt_only_coord'):
        """Load the pickled data for ``split`` and prepare per-graph arrays.

        Args:
            data_dir: Directory containing the ``sample/`` folder with the
                ``COCO_500sp_<split>*.pkl`` files.
            dataset: Accepted for interface compatibility; not used by the
                code in this class.
            split: Split name, substituted into the pickle file names.
            graph_format: One of ``'edge_wt_only_coord'``,
                ``'edge_wt_coord_feat'`` or ``'edge_wt_region_boundary'``.

        Raises:
            AssertionError: If ``graph_format`` is not one of the supported values.
        """
        assert graph_format in ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary'], \
            "unsupported graph_format: %r" % (graph_format,)
        self.split = split
        self.graph_format = graph_format
        self.graph_lists = []  # initialised empty; not filled by the code in this class

        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point `data_dir` at data you trust.
        with open(os.path.join(data_dir, 'sample/COCO_500sp_%s_superpixels.pkl' % split), 'rb') as f:
            self.superpixels = pickle.load(f)

        # This pickle holds a 2-tuple: (labels, per-sample superpixel data).
        with open(os.path.join(data_dir, 'sample/COCO_500sp_%s.pkl' % split), 'rb') as f:
            self.labels, self.sp_data = pickle.load(f)
            self.sp_node_labels = self.labels

        # The region-boundary graphs are only needed (and only loaded) for that format.
        if graph_format == 'edge_wt_region_boundary':
            with open(os.path.join(data_dir, 'sample/COCO_500sp_%s_rag_boundary_graphs.pkl' % split), 'rb') as f:
                self.region_boundary_graphs = pickle.load(f)

        self.n_samples = len(self.labels)

        self._prepare()

    def _prepare(self):
        """Build the per-sample node features and (format-dependent) edge data.

        Fills ``self.node_features``, ``self.edge_features``,
        ``self.Adj_matrices`` and ``self.edges_lists`` with one entry per
        sample of ``self.sp_data``. For ``'edge_wt_region_boundary'`` the three
        edge-related entries are ``None`` because the graph structure comes
        from ``self.region_boundary_graphs`` in ``__getitem__``.

        Raises:
            ValueError: If ``self.graph_format`` is not a supported value.
        """
        print("preparing %d graphs for the %s set..." % (self.n_samples, self.split.upper()))
        self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], []

        # `img_size` is never set by this class. When an instance does carry it,
        # coordinates are normalised by it; otherwise they are used as stored.
        img_size = getattr(self, 'img_size', None)

        for sample in self.sp_data:
            mean_px, coord = sample[:2]

            if img_size is not None:
                coord = coord / img_size

            if self.graph_format == 'edge_wt_coord_feat':
                A = compute_adjacency_matrix_images(coord, mean_px)  # using super-pixel locations + features
                edges_list, edge_values_list = compute_edges_list(A)
            elif self.graph_format == 'edge_wt_only_coord':
                A = compute_adjacency_matrix_images(coord, mean_px, False)  # using only super-pixel locations
                edges_list, edge_values_list = compute_edges_list(A)
            elif self.graph_format == 'edge_wt_region_boundary':
                A, edges_list, edge_values_list = None, None, None
            else:
                raise ValueError("unsupported graph_format: %r" % (self.graph_format,))

            N_nodes = mean_px.shape[0]

            # Node feature row = flattened mean pixel values followed by the 2-D coordinate.
            mean_px = mean_px.reshape(N_nodes, -1)
            coord = coord.reshape(N_nodes, 2)
            x = np.concatenate((mean_px, coord), axis=1)

            if edge_values_list is not None:
                edge_values_list = edge_values_list.reshape(-1)

            self.node_features.append(x)
            self.edge_features.append(edge_values_list)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """Build and return the DGL graph and labels for sample ``idx``.

        For ``'edge_wt_region_boundary'`` the graph is created from the stored
        region-boundary graph (edge feature = stacked ``weight`` and ``count``)
        and its dense adjacency matrix is cached in ``self.Adj_matrices[idx]``.
        For the other formats the edges come from ``self.edges_lists[idx]``
        with self-loops removed; no edge features are attached in that case.

        Returns:
            tuple: ``(g, self.sp_node_labels[idx])`` where ``g.ndata['feat']``
            holds the node features as a ``torch.Tensor``.
        """
        n_nodes = self.node_features[idx].shape[0]

        if self.graph_format == 'edge_wt_region_boundary':
            if n_nodes == 1:
                # handling for 1 node where the self loop would be the only edge
                # since, VOC Superpixels has few samples (5 samples) with only 1 node
                g = dgl.DGLGraph()
                g.add_nodes(n_nodes)
                g = dgl.add_self_loop(g)
                # dummy edge feat since no actual edge present
                g.edata['feat'] = torch.zeros(1, 2)  # 1 edge and 2 feat dim
            else:
                g = dgl.from_networkx(self.region_boundary_graphs[idx].to_directed(),
                                      edge_attrs=['weight', 'count'])
                g.edata['feat'] = torch.stack((g.edata['weight'], g.edata['count']), -1)
                del g.edata['weight'], g.edata['count']
            # Cache the dense adjacency of the graph that was just built.
            self.Adj_matrices[idx] = g.adjacency_matrix().to_dense().numpy()
        else:
            g = dgl.DGLGraph()
            g.add_nodes(n_nodes)
            for src, dsts in enumerate(self.edges_lists[idx]):
                # Drop the self-loop entry from each node's neighbour list.
                g.add_edges(src, dsts[dsts != src])

        g.ndata['feat'] = torch.Tensor(self.node_features[idx])

        return g, self.sp_node_labels[idx]

In [9]:
# Load the pre-processed peptides-functional file. The path can be overridden
# with the PEPTIDES_FUNC_PROCESSED environment variable so the notebook is not
# tied to one machine; the default is the original location.
# The file is expected to hold a 2-tuple, unpacked here as (data, val_dict).
# NOTE(review): `val_dict` is presumably the PyG `slices` dict — confirm.
# NOTE(review): torch.load unpickles the file (only load trusted data); on
# PyTorch versions where `weights_only` defaults to True this call may need
# `weights_only=False` — confirm against the installed version.
PEPTIDES_FUNC_PROCESSED = os.environ.get(
    "PEPTIDES_FUNC_PROCESSED",
    "/home/zluo/nn/lrgb/datasets/peptides-functional/processed/geometric_data_processed.pt",
)
data, val_dict = torch.load(PEPTIDES_FUNC_PROCESSED)

In [12]:
# Display the repr of the loaded object (field names with their sizes).
data

Data(edge_index=[2, 4773974], edge_attr=[4773974, 3], x=[2344859, 9], y=[15535, 10])