In [1]:
import os, sys
from collections import defaultdict
from typing import Callable, Optional

import numpy as np
import pandas as pd
import qcportal as ptl
import torch
from loguru import logger


In [2]:
class QM9Dataset(torch.utils.data.Dataset):
    """PyTorch ``Dataset`` over the QM9 collection hosted on QCArchive.

    The collection is fetched through ``qcportal``. Its HDF5 view is cached in
    ``dataset_file`` so that later runs can load it from disk instead of
    downloading it again. Each item is a ``(geometry, energy)`` pair of tensors.
    """

    def __init__(
        self,
        dataset_file: Optional[str] = None,
        name: str = "",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        method: str = "b3lyp",
    ):
        """
        Parameters
        ----------
        dataset_file : str, optional
            Path of the ``.hdf5`` view to load from or to download into.
            When ``None`` the view is downloaded without being kept at a
            caller-chosen path.
        name : str
            Free-form label stored on the instance.
        transform : callable, optional
            Applied to each geometry tensor in ``__getitem__``.
        target_transform : callable, optional
            Applied to each energy tensor in ``__getitem__``.
        method : str
            Method name passed to ``get_records`` to select the energies.
        """
        self._name = name
        self.dataset_file = dataset_file
        self.transform = transform
        self.target_transform = target_transform
        self.load(dataset_file)
        # Query both tables once; __len__ and __getitem__ read these caches
        # instead of hitting the collection again.
        self.molecules = self.qm9.get_molecules()
        self.records = self.qm9.get_records(method=method)

    def load(self, dataset_file: Optional[str] = None) -> None:
        """Attach the QM9 collection from QCArchive to ``self.qm9``.

        If ``dataset_file`` names an existing file, it is used as the
        collection's local view. Otherwise the view is downloaded from
        QCArchive, into ``dataset_file`` when one is given.

        Raises
        ------
        ValueError
            If ``dataset_file`` is given but does not end in ``.hdf5``.
        Exception
            Whatever ``get_collection`` raised when the collection cannot be
            fetched; it is logged and re-raised.
        """
        # Validate before any network traffic so a bad path fails fast.
        if dataset_file and not dataset_file.endswith(".hdf5"):
            raise ValueError("Input file must be an .hdf5 file.")

        collection, dataset = "Dataset", "QM9"
        qcp_client = ptl.FractalClient()
        try:
            self.qm9 = qcp_client.get_collection(collection, dataset)
        except Exception:
            # Nothing below can work without the collection. Propagate instead
            # of continuing into an unrelated AttributeError on self.qm9.
            logger.error(
                f"Dataset {dataset} is not available in collection {collection}."
            )
            raise

        if dataset_file and os.path.isfile(dataset_file):
            logger.debug(f"Loading from {dataset_file}")
            self.qm9.set_view(dataset_file)
        else:
            logger.debug("Downloading from qcportal")
            # download(local_path) already stores the view at dataset_file.
            # Re-serializing the collection over that same path with to_file()
            # is redundant, and cannot work when dataset_file is None.
            self.qm9.download(dataset_file)
            if dataset_file:
                self.qm9.set_view(dataset_file)

    def __len__(self) -> int:
        """Number of molecules, taken from the table cached in ``__init__``."""
        return len(self.molecules)

    def __getitem__(self, idx: int):
        """Return ``(geometry, energy)`` tensors for the ``idx``-th molecule.

        The energy record is looked up by the molecule's entry name rather
        than by position, so the pair stays consistent even if the two cached
        tables are ordered differently. ``transform`` and ``target_transform``
        are applied when set. An out-of-range ``idx`` raises ``IndexError``.
        """
        entry_name = self.molecules.index[idx]
        # NOTE(review): assumes get_molecules() returns a DataFrame indexed by
        # entry name with the Molecule objects in its first column; confirm.
        molecule = self.molecules.iloc[idx, 0]
        record = self.records.loc[entry_name]
        if isinstance(record, pd.Series):
            # A DataFrame row: the record object sits in its first column.
            record = record.iloc[0]

        geometry = torch.tensor(molecule.geometry)
        # NOTE(review): return_result is the scalar energy only for
        # energy-driver records; confirm for the selected method.
        energy = torch.tensor(record.return_result)

        if self.transform is not None:
            geometry = self.transform(geometry)
        if self.target_transform is not None:
            energy = self.target_transform(energy)

        return geometry, energy


In [3]:
data = QM9Dataset(dataset_file="test.hdf5")  # cached QCArchive view lives next to the notebook

[32m2023-08-01 14:37:25.667[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mload[0m:[36m34[0m - [34m[1mDownloading from qcportal[0m
160MB [01:26, 1.85MB/s]                            


AttributeError: module 'distutils' has no attribute 'version'

In [None]:
len(data)  # number of molecules in the dataset