In [None]:
from typing import Callable, List, Optional
import torch
import pickle
from torch_geometric.data import Data, InMemoryDataset, Dataset, download_url
from torch_geometric.utils import index_sort
import os

In [None]:
class IMCAG(InMemoryDataset):
    """In-memory PyG dataset built from pickled collections of graphs.

    The files named in ``raw_file_names`` are downloaded from ``url`` into
    ``raw_dir``. Each file is unpickled and its items (presumably
    ``torch_geometric.data.Data`` graphs -- confirm against the producer of
    the pickles) are concatenated, optionally filtered / pre-transformed,
    collated, and cached as a single ``processed_dir/data.pt`` file.
    """

    # Base URL of the raw files (placeholder -- point it at the real repo).
    # File names are appended with '/' in download().
    url = 'https://github.com/username/repos_name/raw/branchname'

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """
        Args:
            root: Root directory; PyG places the raw/ and processed/
                folders under it.
            transform: Optional callable, passed through to InMemoryDataset.
            pre_transform: Optional callable applied to every graph in
                ``process`` before collation.
            pre_filter: Optional predicate applied to every graph in
                ``process``; graphs for which it is falsy are dropped.
        """
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # may refuse pickled Data objects; on newer PyG prefer self.load(...) or
        # pass weights_only=False for trusted files -- confirm installed versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # The name of the files in the self.raw_dir folder that must be present in order to skip downloading.
        return ['data_1.pickle', 'data_2.pickle']

    @property
    def processed_file_names(self):
        # The name of the files in the self.processed_dir folder that must be present in order to skip processing.
        return ['data.pt']

    def download(self):
        """Fetch every file in ``raw_file_names`` from ``url`` into ``raw_dir``."""
        base = self.url.rstrip('/')
        for name in self.raw_file_names:
            # URLs always use '/'; os.path.join would insert '\\' on Windows and
            # make download_url derive a wrong local file name.
            download_url(f'{base}/{name}', self.raw_dir)

    def load_pickle(self, path: str):
        """Return the object unpickled from ``path``.

        SECURITY: unpickling can execute arbitrary code. The raw files come
        from a remote URL, so only use this with a trusted source.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def process(self):
        """Load all raw pickles, filter/transform the graphs, collate, and save."""
        data_list = []
        # Use the declared raw files (deterministic order) instead of listing
        # raw_dir, so stray files are ignored and no CWD-relative checks occur.
        for raw_path in self.raw_paths:
            data_list.extend(self.load_pickle(raw_path))

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
class CAG(Dataset):
    """On-disk PyG dataset storing one processed file ``data_{idx}.pt`` per graph.

    The files named in ``raw_file_names`` are downloaded from ``url`` into
    ``raw_dir``. Each raw file is unpickled and iterated; every item
    (presumably a ``torch_geometric.data.Data`` graph -- confirm against the
    producer of the pickles) is optionally filtered / pre-transformed and
    saved to its own file. ``get(idx)`` loads ``data_{idx}.pt`` back.
    """

    # Base URL of the raw files (placeholder -- point it at the real repo).
    # File names are appended with '/' in download().
    url = 'https://github.com/username/repos_name/raw/branchname'

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        # The name of the files in the self.raw_dir folder that must be present in order to skip downloading.
        return ['data_1.pickle', 'data_2.pickle']

    @staticmethod
    def _graph_index(name):
        """Return ``idx`` if ``name`` is ``'data_{idx}.pt'``, otherwise None."""
        stem, ext = os.path.splitext(name)
        prefix, _, digits = stem.partition('_')
        if ext == '.pt' and prefix == 'data' and digits.isdecimal():
            return int(digits)
        return None

    @property
    def processed_file_names(self):
        """Per-graph files ``data_{idx}.pt`` currently present in ``processed_dir``.

        The number of graphs is only known after processing, so the list is
        read from disk, sorted by index. Before processing (directory missing
        or empty) it is empty, which PyG treats as "not yet processed".
        """
        if not os.path.isdir(self.processed_dir):
            return []
        indexed = []
        for name in os.listdir(self.processed_dir):
            idx = self._graph_index(name)
            if idx is not None:
                indexed.append((idx, name))
        return [name for _, name in sorted(indexed)]

    def download(self):
        """Fetch every file in ``raw_file_names`` from ``url`` into ``raw_dir``."""
        base = self.url.rstrip('/')
        for name in self.raw_file_names:
            # URLs always use '/'; os.path.join would insert '\\' on Windows and
            # make download_url derive a wrong local file name.
            download_url(f'{base}/{name}', self.raw_dir)

    def load_pickle(self, path: str):
        """Return the object unpickled from ``path``.

        SECURITY: unpickling can execute arbitrary code. The raw files come
        from a remote URL, so only use this with a trusted source.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def process(self):
        """Save each graph from every raw pickle as ``data_{idx}.pt`` (idx from 0)."""
        idx = 0
        for raw_path in self.raw_paths:
            for graph in self.load_pickle(raw_path):
                # pre_filter / pre_transform act on individual graphs (as in
                # IMCAG.process), not on the whole unpickled collection.
                if self.pre_filter is not None and not self.pre_filter(graph):
                    continue
                if self.pre_transform is not None:
                    graph = self.pre_transform(graph)
                torch.save(graph, os.path.join(self.processed_dir, f'data_{idx}.pt'))
                idx += 1

    def len(self):
        """Number of processed graph files on disk."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the graph stored in ``data_{idx}.pt``."""
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # may refuse pickled Data objects; pass weights_only=False for trusted
        # files if needed -- confirm installed versions.
        return torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))