# In [58]:
import os
import torch
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data, InMemoryDataset
from pathlib import Path
from src.datasets.spectre_dataset import SpectreGraphDataModule, SpectreDatasetInfos
from src.datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos
from torch_geometric.loader import DataLoader

# Load the full dataset, split it into train/val/test with a fixed seed, and
# write one raw file per split.
file_path='/home/bakirkhon/DiGress/training_dataset/training_dataset.pt'
raw_path='/home/bakirkhon/DiGress/training_dataset/raw'
processed_path='/home/bakirkhon/DiGress/training_dataset/processed'
# NOTE(review): torch.load unpickles the file -- only load trusted data.
all_graphs=torch.load(file_path)

# Split: 20% test, then 80% / 20% of the remainder for train / validation.
num_graphs=len(all_graphs)
test_len=int(round(num_graphs*0.2))
# Round the product itself. Rounding the integer difference before multiplying
# is a no-op and lets int() silently truncate (e.g. 61.6 -> 61 instead of 62).
train_len=int(round((num_graphs-test_len)*0.8))
val_len=num_graphs-test_len-train_len

# Dedicated seeded generator so the permutation is reproducible.
g_cpu=torch.Generator()
g_cpu.manual_seed(0)

indices=torch.randperm(num_graphs,generator=g_cpu)
train_indices=indices[:train_len]
val_indices=indices[train_len:train_len+val_len]
test_indices=indices[train_len+val_len:]

# Python sets give O(1) membership tests; `i in tensor` scans the whole tensor
# on every check, which makes the loop below quadratic in the dataset size.
train_index_set=set(train_indices.tolist())
val_index_set=set(val_indices.tolist())
test_index_set=set(test_indices.tolist())


def _to_float_tensor(value):
    """Return `value` as a float32 tensor without copy-constructing a tensor."""
    if isinstance(value, torch.Tensor):
        # torch.tensor(existing_tensor) emits a UserWarning; detach + cast instead.
        return value.detach().to(torch.float)
    return torch.tensor(value, dtype=torch.float)


train_data=[]
val_data=[]
test_data=[]

# Iterate in the original order so every split keeps the dataset's ordering.
for i, graph in enumerate(all_graphs):
    graph['X'] = _to_float_tensor(graph['X'])
    graph['E'] = _to_float_tensor(graph['E'])
    if i in train_index_set:
        train_data.append(graph)
    elif i in val_index_set:
        val_data.append(graph)
    elif i in test_index_set:
        test_data.append(graph)
    else:
        raise ValueError(f'Index {i} not in any split')

# Make sure the target directory exists before writing the split files.
os.makedirs(raw_path, exist_ok=True)
torch.save(train_data, os.path.join(raw_path, 'train.pt'))
torch.save(val_data, os.path.join(raw_path, 'val.pt'))
torch.save(test_data, os.path.join(raw_path,'test.pt'))

# Process
def process(dataset: str):
    """Convert one raw split into a collated PyG file and return the result.

    Reads ``<raw_path>/<dataset>.pt`` (a list of dicts with keys ``'X'`` and
    ``'E'``), builds one ``Data`` object per graph, collates them with
    ``InMemoryDataset.collate`` and saves the collated object to
    ``<processed_path>/<dataset>.pt``.

    Args:
        dataset: Split name, used as the file stem (e.g. ``'train'``).

    Returns:
        Whatever ``InMemoryDataset.collate`` returns for the split (the same
        object that is written to disk). Previously nothing was returned even
        though callers bind the result.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    raw_dataset=torch.load(os.path.join(raw_path,f'{dataset}.pt'))

    data_list=[]
    for graph in raw_dataset:
        X=graph['X']
        E=graph['E']
        n=X.shape[0]
        # A pair counts as an edge when its feature vector has a positive sum.
        # Presumably E is (n, n, num_edge_types) -- TODO confirm.
        adjacency=(E.sum(-1)>0).float()
        # first row=source nodes, second row=destination nodes
        edge_index, _=dense_to_sparse(adjacency)
        edge_attr=E[edge_index[0],edge_index[1],:]
        num_nodes=n*torch.ones(1,dtype=torch.long)
        data_list.append(Data(x=X,edge_index=edge_index,edge_attr=edge_attr,n_nodes=num_nodes))

    collated=InMemoryDataset.collate(data_list)
    # Make sure the target directory exists before writing.
    os.makedirs(processed_path, exist_ok=True)
    torch.save(collated,os.path.join(processed_path,f'{dataset}.pt'))
    return collated

# Convert each raw split file into its collated file under processed_path.
train_processed=process('train')
val_processed=process('val')
test_processed=process('test')

# Computes the empirical distribution of graph sizes (number of nodes)
def node_counts(dataset: str, max_nodes_possible=50):
    """Return the normalised histogram of graph sizes for one processed split.

    Args:
        dataset: Split name; ``<processed_path>/<dataset>.pt`` is loaded and
            unpacked as ``(data, slices)``.
        max_nodes_possible: Initial histogram length. The histogram grows on
            demand, so larger graphs no longer trigger an IndexError.

    Returns:
        1-D float tensor ``p`` where ``p[k]`` is the fraction of graphs with
        ``k`` nodes, truncated after the largest observed size.

    Raises:
        ValueError: If the split contains no graphs.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    processed_data,slices=torch.load(os.path.join(processed_path,f'{dataset}.pt'))

    # Wrap into a dataset
    class DummyDataset(InMemoryDataset):
        def __init__(self, data, slices):
            super().__init__()
            self.data=data
            self.slices = slices

    wrapped=DummyDataset(processed_data, slices)
    loader=DataLoader(wrapped, batch_size=512)

    all_counts=torch.zeros(max_nodes_possible)
    for data in loader:
        # data.batch maps every node to its graph, so the per-graph
        # multiplicities are the graph sizes.
        _, sizes=torch.unique(data.batch,return_counts=True)
        # hist[k] = number of graphs in this batch that have k nodes.
        hist=torch.bincount(sizes)
        if hist.numel()>all_counts.numel():
            padding=all_counts.new_zeros(hist.numel()-all_counts.numel())
            all_counts=torch.cat([all_counts,padding])
        all_counts[:hist.numel()]+=hist

    occupied=all_counts.nonzero()
    if occupied.numel()==0:
        raise ValueError(f'No graphs found in processed split {dataset!r}')
    # Trim the tail after the largest observed size, then normalise.
    max_index=int(occupied.max())
    all_counts=all_counts[:max_index+1]
    return all_counts/all_counts.sum()

# Graph-size distributions of the training and validation splits.
train_nodes=node_counts('train')
val_nodes=node_counts('val')

# def node_types(dataset: str):
#     num_classes=None
#     processed_data,slices=torch.load(os.path.join(processed_path,f'{dataset}.pt'))
#     # Wrap into a dataset
#     class DummyDataset(InMemoryDataset):
#         def __init__(self, data, slices):
#             super().__init__()
#             self.data=data
#             self.slices = slices

#     dataset=DummyDataset(processed_data, slices)
#     loader=DataLoader(dataset, batch_size=512)
#     for data in loader:
#         num_classes=data.x.shape[1]
#         break

#     counts=torch.zeros(num_classes)

#     for i, data in enumerate(loader):
#         counts+=data.x.sum(dim=0)
    
#     counts=counts/counts.sum()
#     return counts

# train_node_types=node_types('train')

def edge_counts(dataset: str):
    """Compute, print and return the edge-type distribution of one split.

    Entry 0 of the result counts ordered node pairs (u != v) that carry no
    edge; entries 1.. are the column sums of ``edge_attr`` for those columns.
    The vector is normalised to sum to 1.

    Args:
        dataset: Split name; ``<processed_path>/<dataset>.pt`` is loaded and
            unpacked as ``(data, slices)``.

    Returns:
        1-D float tensor with one entry per ``edge_attr`` column. Previously
        the distribution was only printed, so callers received ``None``.

    Raises:
        ValueError: If the split contains no graphs.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    processed_data,slices=torch.load(os.path.join(processed_path,f'{dataset}.pt'))

    # Wrap into a dataset
    class DummyDataset(InMemoryDataset):
        def __init__(self, data, slices):
            super().__init__()
            self.data=data
            self.slices = slices

    wrapped=DummyDataset(processed_data, slices)
    loader=DataLoader(wrapped, batch_size=512)

    d=None
    for data in loader:
        if d is None:
            # Size the accumulator from the first batch instead of running a
            # separate pass over the loader just to read the class count.
            d=torch.zeros(data.edge_attr.shape[1],dtype=torch.float)

        # Per-graph node counts in this batch.
        _, counts=torch.unique(data.batch,return_counts=True)
        # Ordered pairs of distinct nodes, summed over the graphs in the batch.
        all_pairs=(counts*(counts-1)).sum()

        num_edges=data.edge_index.shape[1]
        num_non_edges=all_pairs-num_edges
        assert num_non_edges>=0, 'more stored edges than ordered node pairs'

        edge_types=data.edge_attr.sum(dim=0)
        d[0]+=num_non_edges
        d[1:]+=edge_types[1:]

    if d is None:
        raise ValueError(f'No graphs found in processed split {dataset!r}')

    d=d/d.sum()
    print(d)
    return d

# Edge-type distribution of the training split (edge_counts prints it).
train_edges=edge_counts('train')

    

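# Cell output. The first line is presumably the source-line echo of a warning
# whose header is not shown; the tensor is what edge_counts('train') printed.
#   graph['X'] = torch.tensor(graph['X'], dtype=torch.float)
#
#
# tensor([0.9300, 0.0404, 0.0295])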
