# In [6]:
### Dataloader for the MoleculeNet dataset

import copy
import logging
import os
import pickle
import random
from math import log2

# from .create_dataset import MoleculesDataset
# from .utils import *

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch.utils.data as data

def get_data(path):
    """Load the serialized dataset stored in the directory ``path``.

    Three files are read from ``path``:

    * ``adjacency_matrices.pkl`` -- unpickled and returned as-is.
    * ``feature_matrices.pkl`` -- unpickled and returned as-is.
    * ``labels.npy`` -- loaded with ``np.load``.

    Args:
        path: Directory containing the three files. May be a ``str`` or any
            ``os.PathLike`` object (e.g. ``pathlib.Path``); a trailing
            separator is tolerated.

    Returns:
        A tuple ``(adj_matrices, feature_matrices, labels)``.

    Raises:
        FileNotFoundError: If any of the three files is missing.
    """
    # os.path.join instead of string concatenation: works for PathLike
    # arguments and never produces an accidental root-relative path when
    # ``path`` is empty.
    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file -- only point this at dataset directories you trust.
    with open(os.path.join(path, 'adjacency_matrices.pkl'), 'rb') as f:
        adj_matrices = pickle.load(f)

    with open(os.path.join(path, 'feature_matrices.pkl'), 'rb') as f:
        feature_matrices = pickle.load(f)

    labels = np.load(os.path.join(path, 'labels.npy'))

    return adj_matrices, feature_matrices, labels

def create_random_split(path):
    """Load the dataset under ``path`` and shuffle it into train/val/test parts.

    Subset sizes are derived from the number of adjacency matrices ``n``:
    ``int(0.8 * n)`` samples for training, ``int(0.1 * n)`` for validation,
    and all remaining samples for testing. The shuffle uses the module-level
    ``random`` generator, so the split is reproducible only if the caller
    seeds ``random`` beforehand.

    Args:
        path: Dataset directory, forwarded unchanged to ``get_data``.

    Returns:
        A 9-tuple of lists, in this order: train adjacency matrices, train
        feature matrices, train labels, then the same three for validation,
        then the same three for test.
    """
    adj_matrices, feature_matrices, labels = get_data(path)

    # Split boundaries (80/10/10, with rounding leftovers going to test).
    total = len(adj_matrices)
    train_end = int(0.8 * total)
    val_end = train_end + int(0.1 * total)

    shuffled = list(range(total))
    random.shuffle(shuffled)

    # Consecutive slices of the shuffled index list define the three subsets.
    subsets = (shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:])

    # Gather (adjacency, features, labels) for each subset, in that order.
    parts = []
    for positions in subsets:
        for source in (adj_matrices, feature_matrices, labels):
            parts.append([source[idx] for idx in positions])

    return tuple(parts)


# Centralized training
def get_dataloader(path, compact=True, normalize_features=False, normalize_adj=False):
    """Build train/val/test ``DataLoader`` objects from a fresh random split.

    Args:
        path: Dataset directory; passed to ``create_random_split`` and to each
            ``MoleculesDataset``.
        compact: Forwarded to ``MoleculesDataset``. Also selects the collator:
            ``WalkForestCollator`` when truthy, ``DefaultCollator`` otherwise.
        normalize_features: Forwarded to whichever collator is selected.
        normalize_adj: Forwarded to ``DefaultCollator`` only (unused when
            ``compact`` is truthy).

    Returns:
        A tuple ``(train_dataloader, val_dataloader, test_dataloader)``. Only
        the training loader shuffles its samples.
    """
    (train_adj, train_feat, train_y,
     val_adj, val_feat, val_y,
     test_adj, test_feat, test_y) = create_random_split(path)

    train_dataset = MoleculesDataset(train_adj, train_feat, train_y, path,
                                     compact=compact, split='train')
    val_dataset = MoleculesDataset(val_adj, val_feat, val_y, path,
                                   compact=compact, split='val')
    test_dataset = MoleculesDataset(test_adj, test_feat, test_y, path,
                                    compact=compact, split='test')

    if compact:
        collator = WalkForestCollator(normalize_features=normalize_features)
    else:
        collator = DefaultCollator(normalize_features=normalize_features,
                                   normalize_adj=normalize_adj)

    def make_loader(dataset, shuffle):
        # Batch size is fixed at 1: each batch is an entire molecule.
        return data.DataLoader(dataset, batch_size=1, shuffle=shuffle,
                               collate_fn=collator, pin_memory=True)

    train_dataloader = make_loader(train_dataset, True)
    val_dataloader = make_loader(val_dataset, False)
    test_dataloader = make_loader(test_dataset, False)

    return train_dataloader, val_dataloader, test_dataloader
