In [1]:
import os
import sys
# Put the parent of the current working directory first on the import path;
# presumably that is where the local `datasets` module imported below lives — TODO confirm.
sys.path.insert(0, os.path.dirname(os.getcwd()))

import torch
from datasets import QH9Stable, QH9Dynamic

  from .autonotebook import tqdm as notebook_tqdm


### Here are the statistics of the datasets

In [2]:
def get_hamiltonian_size(molecule_atoms):
    """Return the orbital count of a molecule given its atomic numbers.

    Every atom with atomic number <= 2 contributes 2 orbitals and every other
    atom contributes 14.

    NOTE(review): the per-atom counts 2 and 14 are presumably tied to the
    basis set used for the dataset — confirm before reusing elsewhere.

    Args:
        molecule_atoms: atomic numbers of the molecule; must support
            elementwise comparison and ``.sum()``.

    Returns:
        The total number of orbitals, as produced by the ``.sum()`` reductions
        of the input type.
    """
    heavy_atom_count = (molecule_atoms > 2).sum()
    light_atom_count = (molecule_atoms <= 2).sum()
    return heavy_atom_count * 14 + light_atom_count * 2


def get_dataset_statistic(ori_dataset):
    """Compute per-split size statistics for a dataset.

    For each of the 'train', 'val' and 'test' splits, the samples selected by
    ``ori_dataset.<split>_mask`` are scanned once, and the mean, min, max and
    median are computed for three per-sample quantities:

    * ``num_node``: ``data.num_nodes``
    * ``num_electronics``: ``data.atoms.sum()``
    * ``hamiltonian_size``: ``get_hamiltonian_size(data.atoms)``

    Args:
        ori_dataset: dataset exposing ``train_mask``, ``val_mask`` and
            ``test_mask`` attributes and supporting indexing by such a mask;
            every sample must provide ``num_nodes`` and ``atoms``.

    Returns:
        dict: ``{split_name: {'<quantity>_<stat>': float}}`` with the stats
        ``mean``, ``min``, ``max`` and ``median`` for each quantity above.
        A split without samples gets NaN for every statistic and triggers a
        warning, instead of failing inside the min/max/median reductions.
    """
    import warnings

    def _summarize(values):
        """Return (mean, min, max, median) of ``values`` as Python floats."""
        if len(values) == 0:
            # min/max/median are undefined for zero elements and raise in torch.
            nan = float('nan')
            return nan, nan, nan, nan
        as_tensor = torch.tensor(values).float()
        # torch.median returns the lower of the two middle values when the
        # number of elements is even (it does not average them).
        return (as_tensor.mean().item(), as_tensor.min().item(),
                as_tensor.max().item(), as_tensor.median().item())

    stat_names = ('mean', 'min', 'max', 'median')
    statistic_info = {}
    for split_name in ('train', 'val', 'test'):
        dataset = ori_dataset[getattr(ori_dataset, f'{split_name}_mask')]

        # Collect all three quantities in a single pass over the split.
        per_sample = {'num_node': [], 'num_electronics': [], 'hamiltonian_size': []}
        for data in dataset:
            per_sample['num_node'].append(data.num_nodes)
            per_sample['num_electronics'].append(data.atoms.sum())
            per_sample['hamiltonian_size'].append(get_hamiltonian_size(data.atoms))

        if not per_sample['num_node']:
            warnings.warn(f"split '{split_name}' contains no samples; its statistics are NaN")

        split_info = {}
        for quantity, values in per_sample.items():
            for stat_name, stat_value in zip(stat_names, _summarize(values)):
                split_info[f'{quantity}_{stat_name}'] = stat_value
        statistic_info[split_name] = split_info

    return statistic_info

In [4]:
# Load QH9Stable with the random split and compute its per-split statistics.
# The dataset root is the `datasets` folder next to the current working
# directory; os.path.dirname matches how the import path is set up in the
# first cell and stays an absolute path even when cwd is directly under the
# filesystem root.
dataset_stable_random = QH9Stable(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    split='random',
)
dataset_stable_random_statistic = get_dataset_statistic(dataset_stable_random)

In [3]:
# Load QH9Stable with the size_ood split and compute its per-split statistics.
# The dataset root is the `datasets` folder next to the current working
# directory; os.path.dirname stays an absolute path even when cwd is directly
# under the filesystem root.
dataset_stable_ood = QH9Stable(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    split='size_ood',
)
dataset_stable_ood_statistic = get_dataset_statistic(dataset_stable_ood)

Processing...
  0%|          | 113M/30.5G [00:19<22:44, 22.3MB/s]

In [None]:
# Load QH9Dynamic (version 100k, geometry split) and compute its per-split
# statistics. The dataset root is the `datasets` folder next to the current
# working directory; os.path.dirname stays an absolute path even when cwd is
# directly under the filesystem root.
dataset_dynamic_geo_100k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='100k',
    split='geometry',
)
dataset_dynamic_geo_100k_statistic = get_dataset_statistic(dataset_dynamic_geo_100k)

In [None]:
# Load QH9Dynamic (version 100k, mol split) and compute its per-split
# statistics. The dataset root is the `datasets` folder next to the current
# working directory; os.path.dirname stays an absolute path even when cwd is
# directly under the filesystem root.
dataset_dynamic_mol_100k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='100k',
    split='mol',
)
dataset_dynamic_mol_100k_statistic = get_dataset_statistic(dataset_dynamic_mol_100k)

In [3]:
# Load QH9Dynamic (version 300k, geometry split), compute its per-split
# statistics and display them. The dataset root is the `datasets` folder next
# to the current working directory; os.path.dirname stays an absolute path
# even when cwd is directly under the filesystem root.
dataset_dynamic_geo_300k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='300k',
    split='geometry',
)
dataset_dynamic_geo_300k_statistic = get_dataset_statistic(dataset_dynamic_geo_300k)
dataset_dynamic_geo_300k_statistic

{'train': {'num_node_mean': 18.03936004638672,
  'num_node_min': 7.0,
  'num_node_max': 27.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.87992095947266,
  'num_electronics_min': 24.0,
  'num_electronics_max': 74.0,
  'num_electronics_median': 66.0,
  'hamiltonian_size_mean': 141.5890655517578,
  'hamiltonian_size_min': 54.0,
  'hamiltonian_size_max': 162.0,
  'hamiltonian_size_median': 144.0},
 'val': {'num_node_mean': 18.03936004638672,
  'num_node_min': 7.0,
  'num_node_max': 27.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.87992095947266,
  'num_electronics_min': 24.0,
  'num_electronics_max': 74.0,
  'num_electronics_median': 66.0,
  'hamiltonian_size_mean': 141.5890655517578,
  'hamiltonian_size_min': 54.0,
  'hamiltonian_size_max': 162.0,
  'hamiltonian_size_median': 144.0},
 'test': {'num_node_mean': 18.03936004638672,
  'num_node_min': 7.0,
  'num_node_max': 27.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.87992095947266,
  'num_electronic

In [3]:
# Load QH9Dynamic (version 300k, mol split), compute its per-split statistics
# and display them. The dataset root is the `datasets` folder next to the
# current working directory; os.path.dirname stays an absolute path even when
# cwd is directly under the filesystem root.
dataset_dynamic_mol_300k = QH9Dynamic(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    version='300k',
    split='mol',
)
dataset_dynamic_mol_300k_statistic = get_dataset_statistic(dataset_dynamic_mol_300k)
dataset_dynamic_mol_300k_statistic

{'train': {'num_node_mean': 18.015846252441406,
  'num_node_min': 7.0,
  'num_node_max': 27.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.91242980957031,
  'num_electronics_min': 24.0,
  'num_electronics_max': 74.0,
  'num_electronics_median': 66.0,
  'hamiltonian_size_mean': 141.58465576171875,
  'hamiltonian_size_min': 54.0,
  'hamiltonian_size_max': 162.0,
  'hamiltonian_size_median': 144.0},
 'val': {'num_node_mean': 18.153846740722656,
  'num_node_min': 10.0,
  'num_node_max': 25.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.71237182617188,
  'num_electronics_min': 34.0,
  'num_electronics_max': 74.0,
  'num_electronics_median': 66.0,
  'hamiltonian_size_mean': 141.17726135253906,
  'hamiltonian_size_min': 72.0,
  'hamiltonian_size_max': 158.0,
  'hamiltonian_size_median': 144.0},
 'test': {'num_node_mean': 18.112957000732422,
  'num_node_min': 9.0,
  'num_node_max': 25.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.7873764038086,
  'num_elect

In [9]:
# Show the statistics dict computed for the QH9Stable random split.
dataset_stable_random_statistic

{'train': {'num_node_mean': 18.023332595825195,
  'num_node_min': 3.0,
  'num_node_max': 29.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.89612579345703,
  'num_electronics_min': 10.0,
  'num_electronics_max': 74.0,
  'num_electronics_median': 66.0,
  'hamiltonian_size_mean': 141.5970001220703,
  'hamiltonian_size_min': 18.0,
  'hamiltonian_size_max': 166.0,
  'hamiltonian_size_median': 144.0},
 'val': {'num_node_mean': 18.026752471923828,
  'num_node_min': 6.0,
  'num_node_max': 29.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.90185546875,
  'num_electronics_min': 18.0,
  'num_electronics_max': 74.0,
  'num_electronics_median': 66.0,
  'hamiltonian_size_mean': 141.6219482421875,
  'hamiltonian_size_min': 36.0,
  'hamiltonian_size_max': 166.0,
  'hamiltonian_size_median': 144.0},
 'test': {'num_node_mean': 18.035158157348633,
  'num_node_min': 4.0,
  'num_node_max': 29.0,
  'num_node_median': 18.0,
  'num_electronics_mean': 65.8647232055664,
  'num_electronics

In [3]:
# Load QH9Stable with the size_ood split, compute its per-split statistics and
# display them. The dataset root is the `datasets` folder next to the current
# working directory; os.path.dirname stays an absolute path even when cwd is
# directly under the filesystem root.
dataset_stable_ood = QH9Stable(
    root=os.path.join(os.path.dirname(os.getcwd()), 'datasets'),
    split='size_ood',
)
dataset_stable_ood_statistic = get_dataset_statistic(dataset_stable_ood)
dataset_stable_ood_statistic

In [16]:
# Inspect the train mask of the size_ood dataset; get_dataset_statistic indexes
# the dataset with this mask to select the training samples.
dataset_stable_ood.train_mask

array([], dtype=int64)