In [9]:
from torch_geometric.datasets import Planetoid, CitationFull, NELL
from Proposed.proposed_dataset import ProposedDataset
import torch_geometric.transforms as T
from torch_geometric.data import InMemoryDataset, Data
import torch
from torch_sparse import SparseTensor


# Root directory handed to the Planetoid, NELL and ProposedDataset loaders below.
DATASET_ROOT_FOLDER = "Datasets"
# Separate root directory used only by InMemoryNellDataset.
DATASET_ROOT_FOLDER_FOR_NELL = "Datasets/NELL"

def get_citeseer_dataset():
    """Load the Planetoid ``CiteSeer`` dataset and redraw its masks.

    Returns:
        Whatever ``random_split`` returns for the loaded dataset (called with
        its default arguments).
    """
    citeseer = Planetoid(root=DATASET_ROOT_FOLDER, name="CiteSeer", split='random')
    return random_split(citeseer)

def get_cora_dataset():
    """Load the Planetoid ``Cora`` dataset and redraw its masks.

    Returns:
        Whatever ``random_split`` returns for the loaded dataset (called with
        its default arguments).
    """
    cora = Planetoid(root=DATASET_ROOT_FOLDER, name="Cora", split='random')
    return random_split(cora)


def get_pubmed_dataset():
    """Load the Planetoid ``PubMed`` dataset and redraw its masks.

    Returns:
        Whatever ``random_split`` returns for the loaded dataset (called with
        its default arguments).
    """
    pubmed = Planetoid(root=DATASET_ROOT_FOLDER, name="PubMed", split='random')
    return random_split(pubmed)


def get_nell_dataset():
    """Return the NELL dataset with a ``RandomNodeSplit`` transform attached.

    The dataset is returned exactly as loaded from ``DATASET_ROOT_FOLDER``.
    Dense feature conversion and label relabeling are not done here; they
    happen in ``InMemoryNellDataset.process``, which calls this function.

    Returns:
        The ``NELL`` dataset instance.
    """
    transform = T.RandomNodeSplit(split='random')
    return NELL(root=DATASET_ROOT_FOLDER, transform=transform)

def get_in_memeory_nell_dataset():
    """Load ``InMemoryNellDataset`` and redraw its masks with ``is_nell=True``.

    Returns:
        Whatever ``random_split`` returns for the loaded dataset.
    """
    nell = InMemoryNellDataset(root=DATASET_ROOT_FOLDER_FOR_NELL)
    return random_split(nell, is_nell=True)


def get_proposed_dataset():
    """Build and return ``ProposedDataset`` rooted at ``DATASET_ROOT_FOLDER``.

    Unlike the other loaders in this cell, no ``random_split`` is applied.
    """
    return ProposedDataset(root=DATASET_ROOT_FOLDER)


def random_split(data, num_train_per_class: int = 20, num_val: int = 500, is_nell=False,
                 generator=None):
    """Redraw the train/val/test node masks of ``data`` in place and return it.

    For every class ``c`` in ``range(data.num_classes)``, up to
    ``num_train_per_class`` randomly chosen nodes with ``data.y == c`` are
    marked as training nodes. The remaining nodes are shuffled; the first
    ``num_val`` of them become validation nodes and all the others become
    test nodes.

    Args:
        data: Object exposing ``y``, ``num_classes`` and the boolean tensors
            ``train_mask``, ``val_mask`` and ``test_mask``. The three masks
            are overwritten in place.
        num_train_per_class: Training nodes drawn per class. Ignored when
            ``is_nell`` is true.
        num_val: Number of validation nodes.
        is_nell: When true, 2 training nodes per class are drawn regardless
            of ``num_train_per_class``.
        generator: Optional ``torch.Generator`` forwarded to every
            ``torch.randperm`` call so the split can be reproduced. ``None``
            (the default) keeps using the global RNG, as before.

    Returns:
        The same ``data`` object, with updated masks.
    """
    # The NELL override does not depend on the class, so resolve it once.
    per_class = 2 if is_nell else num_train_per_class

    data.train_mask.fill_(False)
    for c in range(data.num_classes):
        idx = (data.y == c).nonzero(as_tuple=False).view(-1)
        # Slicing past the end is harmless: small classes contribute all nodes.
        idx = idx[torch.randperm(idx.size(0), generator=generator)[:per_class]]
        data.train_mask[idx] = True

    # Everything not picked for training is shuffled and divided into val/test.
    remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)
    remaining = remaining[torch.randperm(remaining.size(0), generator=generator)]

    data.val_mask.fill_(False)
    data.val_mask[remaining[:num_val]] = True

    data.test_mask.fill_(False)
    data.test_mask[remaining[num_val:]] = True

    return data




import torch
from torch_geometric.data import InMemoryDataset, Data


class InMemoryNellDataset(InMemoryDataset):
    """NELL graph cached on disk with dense features and contiguous labels.

    ``process`` builds a single ``Data`` object from ``get_nell_dataset()``:
    the features are densified, the labels that occur are renumbered to
    ``0..K-1``, and all-``False`` ``train_mask`` / ``test_mask`` /
    ``val_mask`` tensors are attached (to be filled in later, e.g. by
    ``random_split``). The collated result is stored at
    ``processed_paths[0]`` and read back in ``__init__``.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Counterpart of the torch.save(self.collate(...)) call in process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Raw file names reported to the base class; ``process`` does not read them."""
        return ['file.edges', 'file.x']

    @property
    def processed_file_names(self):
        """Name of the single processed file written by ``process``."""
        return ['data.pt']

    def download(self):
        """Nothing is downloaded here; the data is obtained inside ``process``."""
        pass

    def process(self):
        """Build the dense, relabeled NELL graph and save it to ``processed_paths[0]``."""
        # Index the source dataset once rather than on every attribute read.
        graph = get_nell_dataset()[0]

        # Renumber the labels that occur to 0..K-1 in ascending order of the
        # original value. return_inverse produces a new tensor, so the source
        # labels are not mutated and already-relabeled entries cannot be hit
        # again by a later replacement.
        _, y = torch.unique(graph.y, return_inverse=True)

        data = Data(x=graph.x.to_dense(),
                    edge_index=graph.edge_index,
                    y=y)

        # Empty masks; this class itself never sets any entry to True.
        num_nodes = graph.num_nodes
        data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
        data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)

        torch.save(self.collate([data]), self.processed_paths[0])

In [10]:
# Load the cached NELL dataset and redraw its split masks (is_nell=True path).
in_memory_nell_dataset = get_in_memeory_nell_dataset()

Processing...
Done!


In [11]:
# Inspect the one graph stored in the dataset.
in_memory_nell_dataset[0]

Data(x=[65755, 61278], edge_index=[2, 251550], y=[65755], train_mask=[65755], test_mask=[65755], val_mask=[65755])

In [12]:
# List the distinct label values present after the relabeling done in process().
in_memory_nell_dataset[0].y.unique()

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104])