In [3]:
import os
import shutil
import math 

import schnetpack as spk
from schnetpack.datasets import QM9
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl
import numpy as np
from ase import Atoms

from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
from exclusion_transformer import ExclusionTransformer

from collections import defaultdict


In [4]:
atoms = {
    "carbon" : 6,
    "hydrogen" : 1,
    "oxygen" : 8,
    "nitrogen" : 7,
    "fluorine" : 9,
}

In [5]:
print(f"Getting True Values..")

qm9data = QM9(
    "qm9.db",
    batch_size=20,
    property_units={QM9.U0: 'eV'},
    num_workers=4,
    split_file="./splits/split.npz",
    pin_memory=False, # set to false, when not using a GPU
    load_properties=[QM9.U0], #only load U0 property
)
qm9data.prepare_data()
qm9data.setup()

INFO:root:Downloading GDB-9 atom references...


Getting True Values..


INFO:root:Done.
INFO:root:Downloading GDB-9 data...
INFO:root:Done.
INFO:root:Extracting files...
INFO:root:Done.
INFO:root:Parse xyz files...
INFO:root:Write atoms to db...
INFO:root:Done.


In [6]:
def get_desired_atom_indices(
        qm9dataset, 
        split_path='./splits/split.npz', 
        desired_atom="nitrogen"
    ):
     
        
    if desired_atom not in atoms:
        raise ValueError(f"Please select one of the following atoms: {', '.join(atoms.keys())}.")
    
    if not os.path.isfile(split_path):
        raise ValueError(f'{split_path} does not exist. ')

    desired_indices = {"train" : [], "val" : [], "test" : []}
    remaining_indices = {"train" : [], "val" : [], "test" : []}
    atomic_number = atoms[desired_atom]
    
    dataset_names = ["train", "val", "test"]
    datasets = [qm9dataset.train_dataset, qm9dataset.val_dataset, qm9dataset.test_dataset]
    for dataset_name, dataset in zip(dataset_names, datasets):
        try:
            for data in dataset:
                if atomic_number in data["_atomic_numbers"]:
                    desired_indices[dataset_name].append(data["_idx"].item())
                else:
                    remaining_indices[dataset_name].append(data["_idx"].item())
        except:
            pass

        print(f'\nTotal number of molecules that include {desired_atom} atom in {dataset_name} set is {len(desired_indices[dataset_name]):,}.')
        print(f'Total number of molecules that include {desired_atom} atom in {dataset_name} set is {len(remaining_indices[dataset_name]):,}.')

#     print('Recreating split file')
    
    split_file = np.load('./splits/split.npz')
    train_idx = split_file['train_idx']
    val_idx = split_file['val_idx']
    test_idx = split_file['test_idx']
    
    dataset_names = ["train", "val", "test"]
    indicies = [train_idx, val_idx, test_idx]

    for dataset_name, indx in zip(dataset_names, indicies):    
        remaining_indices[dataset_name] = np.array(list(set(remaining_indices[dataset_name]).intersection(set(indx))))
        desired_indices[dataset_name]   = np.array(list(set(desired_indices[dataset_name]).intersection(set(indx))))
        np.random.shuffle(remaining_indices[dataset_name])
        np.random.shuffle(desired_indices[dataset_name])
    
    return desired_indices, remaining_indices



In [14]:
def select_different_subsets(        
        qm9dataset, 
        split_path='./splits/split.npz', 
        desired_atom="nitrogen"
    ):
    
    desired_indices, remaining_indices = get_desired_atom_indices(qm9dataset, split_path=split_path, desired_atom=desired_atom)
        
    percentages = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
    desired_by_percentage = {}
    remaining_by_percentage = {}
    for perc in percentages:
        desired_by_percentage[str(perc)] = {}
        remaining_by_percentage[str(perc)] = {}

    desired_by_percentage['1.0'] = {
            "train" : desired_indices['train'], 
            "val" : desired_indices['val'],
            "test" : desired_indices['test']
        }


    remaining_by_percentage['1.0'] = {
            "train" : remaining_indices['train'], 
            "val" : remaining_indices['val'],
            "test" : remaining_indices['test']
        }

    for i in range(len(percentages)-1):
        superset_perc = percentages[i]
        subset_perc = percentages[i+1]
        
        for ds_name in ("train", "val", "test"):
            total_desired_len = len(desired_indices[ds_name])
            total_remaining_len = len(remaining_indices[ds_name])
            desired_by_percentage[str(subset_perc)][ds_name] = np.random.choice( desired_by_percentage[str(superset_perc)][ds_name], size=int(subset_perc*total_desired_len))
            remaining_by_percentage[str(subset_perc)][ds_name] = np.random.choice( remaining_by_percentage[str(superset_perc)][ds_name], size=int(subset_perc*total_remaining_len))
    
    return desired_by_percentage, remaining_by_percentage

In [9]:
desired_by_percentage, remaining_by_percentage = select_different_subsets(qm9data)


Total number of molecules that include nitrogen atom in train set is 67,341.
Total number of molecules that include nitrogen atom in train set is 42,659.

Total number of molecules that include nitrogen atom in val set is 6,088.
Total number of molecules that include nitrogen atom in val set is 3,912.

Total number of molecules that include nitrogen atom in test set is 6,688.
Total number of molecules that include nitrogen atom in test set is 4,143.


In [15]:
def create_new_splits(
        train_desired_atom_perc = 0.8, 
        test_desired_atom_perc  = 0.8, 
        train_remaining_atom_perc = 1.0, 
        test_remaining_atom_perc  = 1.0, 
        verbose = True       
    ):
    
    selected_train = desired_by_percentage[str(train_desired_atom_perc)]["train"]
    selected_val   = desired_by_percentage[str(train_desired_atom_perc)]["val"]
    selected_test  = desired_by_percentage[str(test_desired_atom_perc)]["test"]
    
    remaining_train = remaining_by_percentage[str(train_remaining_atom_perc)]["train"]
    remaining_val   = remaining_by_percentage[str(train_remaining_atom_perc)]["val"]
    remaining_test  = remaining_by_percentage[str(test_remaining_atom_perc)]["test"]
    
    train_idx = np.concatenate((selected_train, remaining_train), axis=0)
    val_idx = np.concatenate((selected_val, remaining_val), axis=0)
    test_idx = np.concatenate((selected_test, remaining_test), axis=0)
    
    np.random.shuffle(train_idx)
    np.random.shuffle(val_idx)
    np.random.shuffle(test_idx)
    
    desired_atom = "nitrogen" # for now harcoded
    name_train_part = f'{int(100*train_desired_atom_perc)}_perc_of_{desired_atom}_{int(100*train_remaining_atom_perc)}_perc_of_remaining'
    name_test_part = f'{int(100*test_desired_atom_perc)}_perc_of_{desired_atom}_{int(100*test_remaining_atom_perc)}_perc_of_remaining'
    
    output_name =  f'splits/split_train_{name_train_part}_test_{name_test_part}.npz'
    np.savez(output_name, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)
    print(f'Split file is saved at {output_name}.')
    
    if verbose:
        print(f'Train and val sets include {int(100*train_desired_atom_perc)}% of the molecules with {desired_atom} and {int(100*train_remaining_atom_perc)}% of the remaining molecules in the original train and val sets, respectively.')
        print(f'Test set includes {int(100*test_desired_atom_perc)}% of the molecules with {desired_atom} and {int(100*test_remaining_atom_perc)}% of the remaining molecules in the original test set.')

        print(f'\nLen train_idx: {len(train_idx):,}')
        print(f'Len molecules with {desired_atom}: {len(selected_train):,} in train set')
        print(f'Len molecules without {desired_atom}: {len(remaining_train):,} in train set')

        print(f'\nLen val_idx: {len(val_idx):,}')
        print(f'Len molecules with {desired_atom}: {len(selected_val):,} in val set')
        print(f'Len molecules without {desired_atom}: {len(remaining_val):,} in val set')

        print(f'\nLen train_idx: {len(test_idx):,}')
        print(f'Len molecules with {desired_atom}: {len(selected_test):,} in train set')
        print(f'Len molecules without {desired_atom}: {len(remaining_test):,} in train set')
    return output_name

In [16]:
split_file = np.load('./splits/split.npz')
print( f'Len of train_idx: {len(split_file["train_idx"]) :,}.' ) 
print( f'Len of test_idx : {len( split_file["test_idx"]) :,}.' ) 
print( f'Len of val_idx  : {len( split_file["val_idx"]) :,}.' ) 

Len of train_idx: 110,000.
Len of test_idx : 10,831.
Len of val_idx  : 10,000.


In [13]:
for train_desired_atom_perc in [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]:
    for test_desired_atom_perc in [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]:
    
        print( f'\ntrain_desired_atom_perc: {int(100*train_desired_atom_perc)} and test_desired_atom_perc: {int(100*test_desired_atom_perc) ' )

        new_split_file_path = create_new_splits(
            train_desired_atom_perc=train_desired_atom_perc, 
            test_desired_atom_perc=test_desired_atom_perc, 
            verbose=False )

        split_file = np.load(new_split_file_path)

        print( f'\nLen of train_idx: {len(split_file["train_idx"]) :,}.' ) 
        print( f'Len of test_idx : {len( split_file["test_idx"]) :,}.' ) 
        print( f'Len of val_idx  : {len( split_file["val_idx"]) :,}.' ) 


train_desired_atom_perc: 100 and test_desired_atom_perc: 100. 
Split file is saved at split_files/split_train_100_perc_of_nitrogen_100_perc_of_remaining_test_100_perc_of_nitrogen_100_perc_of_remaining.npz.

Len of train_idx: 110,000.
Len of test_idx : 10,831.
Len of val_idx  : 10,000.

train_desired_atom_perc: 100 and test_desired_atom_perc: 80. 
Split file is saved at split_files/split_train_100_perc_of_nitrogen_100_perc_of_remaining_test_80_perc_of_nitrogen_100_perc_of_remaining.npz.

Len of train_idx: 110,000.
Len of test_idx : 9,493.
Len of val_idx  : 10,000.

train_desired_atom_perc: 100 and test_desired_atom_perc: 60. 
Split file is saved at split_files/split_train_100_perc_of_nitrogen_100_perc_of_remaining_test_60_perc_of_nitrogen_100_perc_of_remaining.npz.

Len of train_idx: 110,000.
Len of test_idx : 8,155.
Len of val_idx  : 10,000.

train_desired_atom_perc: 100 and test_desired_atom_perc: 40. 
Split file is saved at split_files/split_train_100_perc_of_nitrogen_100_perc_of_re

In [17]:
new_split_file_path = create_new_splits(1, 1, 1, 0, verbose=False )

split_file = np.load(new_split_file_path)

print( f'\nLen of train_idx: {len(split_file["train_idx"]) :,}.' ) 
print( f'Len of test_idx : {len( split_file["test_idx"]) :,}.' ) 
print( f'Len of val_idx  : {len( split_file["val_idx"]) :,}.' ) 

KeyError: '1'