## Reconstruction of structures from reconstructed moments

This notebook takes a set of structures from the QM9 dataset, computes the corresponding invariants for each one, and then tries to reconstruct the original molecule from those invariants, by first finding peaks in the reconstruction from moments.

In [1]:
import random
import json

import numpy as np
import pandas

import milad
from milad.play import asetools
from milad import invariants
from milad import reconstruct

import qm9_utils

In [2]:
# Seed both global RNGs from a single constant so that a fresh
# "Restart Kernel -> Run All" draws the same random numbers.
# `random` is used for the species choices in do_reconstruction; the random
# initial structures presumably use numpy's global RNG -- TODO confirm.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

In [3]:
# NOTE(review): this import would normally live in the imports cell at the top.
from schnetpack import datasets

# Path of the QM9 database file handed to schnetpack.
# download=True presumably fetches the data when the file is missing -- TODO confirm.
qm9_db_path = 'data/qm9.db'
qm9data = datasets.QM9(qm9_db_path, download=True)

# Show the number of entries in the dataset
len(qm9data)

133885

The settings for our reconstruction

In [4]:
# How many independent random starting structures to try per molecule
num_attempts = 3

# Cutoff passed to qm9_utils.create_descriptor (which is called with
# apply_cutoff=False); units are presumably Angstrom -- TODO confirm
cutoff = 5

# The invariants the descriptor is built from
invs = invariants.read(invariants.COMPLEX_INVARIANTS)

In [5]:
# Optimiser instances.
# NOTE(review): none of these three objects is referenced again in the
# visible cells of this notebook -- confirm whether they are still needed.
structure_optimiser = milad.optimisers.StructureOptimiser()
moments_optimiser = milad.optimisers.MomentsOptimiser()
lstsq_optimiser = milad.optimisers.LeastSquaresOptimiser()

Let's load the subset of indices from the QM9 database we want to use

In [6]:
# JSON object keys are always strings, so turn them back into ints while
# loading. The resulting dict maps size -> list of QM9 indices (this is how
# do_reconstruction iterates over it).
with open('data/qm9_subset.json', 'r') as subset_file:
    test_set = {
        int(size): indices
        for size, indices in json.load(subset_file).items()
    }

In [7]:
# Column labels for the results frames; the order matches the tuples
# appended by do_reconstruction: (idx, size, attempt, rmsd, result).
columns = (
    'QM9 ID',
    'Size',
    'Attempt #',
    'RMSD',
    'Result',
)

In [15]:
def do_reconstruction(
    qm9data,
    test_set,
    num_attempts: int,
    with_species: bool,
    species=None,
):
    """Try to recover structures from their invariant fingerprints.

    For every dataset index in ``test_set`` the structure is fetched, its
    fingerprint is computed with the descriptor, and ``num_attempts``
    reconstructions are run with ``reconstruct.find_iteratively``, each one
    starting from a fresh random atom collection. Progress is printed as it goes.

    The descriptor is built from the notebook-level ``invs`` and ``cutoff``
    settings, so those must be defined before calling this function.

    Args:
        qm9data: dataset object providing ``get_atoms(idx=...)``.
        test_set: mapping of size -> iterable of dataset indices. Sizes are
            processed from largest to smallest.
        num_attempts: number of reconstruction attempts per structure.
        with_species: if True, the species are part of the reconstruction and
            the starting species are drawn at random from ``species``. If False,
            all atomic numbers are set to ``1.`` for both target and start.
        species: sequence of candidate atomic numbers to draw from. Only used,
            and only required, when ``with_species`` is True.

    Returns:
        A list of ``(idx, size, attempt, rmsd, result)`` tuples, one per
        attempt, where ``result`` is the object returned by
        ``reconstruct.find_iteratively``.

    Raises:
        ValueError: if ``with_species`` is True but ``species`` is None or empty.
    """
    # Fail early with a clear message instead of deep inside the loop, where
    # random.choices would otherwise choke on a missing/empty population.
    if with_species and (species is None or len(species) == 0):
        raise ValueError("species must be a non-empty sequence when with_species=True")

    results = []

    # Let's create the descriptor we're going to use
    descriptor = qm9_utils.create_descriptor(invs, cutoff, apply_cutoff=False)

    for size, indices in sorted(test_set.items(), reverse=True):
        size = int(size)
        print(f"Size {size}:")
        for idx in indices:
            print(f"Idx={idx}:", end='')

            # Get the system
            system = qm9data.get_atoms(idx=idx)
            # Centre the molecule and find the max radius.
            max_radius = asetools.prepare_molecule(system)

            # Prepare the milad data type
            milad_molecule = asetools.ase2milad(system)
            num_atoms = milad_molecule.num_atoms

            if not with_species:
                # Species all fixed
                milad_molecule.numbers = 1.

            # Calculate the ground truth
            fingerprint = descriptor(milad_molecule)

            for attempt in range(num_attempts):
                print("...", end='')

                # Random starting guess inside a sphere of the molecule's radius
                initial = milad.atomic.random_atom_collection_in_sphere(
                    num_atoms, max_radius, centre=True)
                # Set the species appropriately
                if with_species:
                    initial.numbers = np.array(random.choices(species, k=num_atoms))
                else:
                    initial.numbers = 1.

                # Finally, optimise the structure wrt to the fingerprint
                result = reconstruct.find_iteratively(
                    descriptor,
                    fingerprint,  # Target the fingerprint
                    num_atoms,
                    initial,
                    find_species=with_species,
                    verbose=False,
                )

                print(f'{result.rmsd:5.5}', end='')

                results.append((idx, size, attempt, result.rmsd, result))

            # End the progress line for this structure
            print()

    return results

In [9]:
# Positions-only reconstruction: every atom is given the same species.
# `species` is not used in this mode; None is passed explicitly so the call
# matches do_reconstruction's signature.
results = do_reconstruction(
    qm9data,
    test_set,
    num_attempts,
    with_species=False,
    species=None,
)
no_species_frame = pandas.DataFrame(results, columns=columns)

Size 29:
Idx=57517:...9.9235e-10...1.823e-09...3.4432e-09
Idx=58098:...1.4395e-08...1.2601e-08...1.127e-08
Idx=58182:...2.4972e-11...3.0174e-12...0.0089994
Size 27:
Idx=42138:...1.9077e-07...1.846e-07...1.8451e-07
Idx=57349:...0.032259...0.032371...3.5915e-08
Idx=57419:...0.043991...0.00059066...1.1894e-12
Size 26:
Idx=5805:...2.2202e-13...0.006529...1.4644e-12
Idx=5810:...5.8025e-08...3.206e-08...2.1161e-08
Idx=5850:...0.07384...0.035701...0.042092
Size 25:
Idx=36927:...0.0073437...0.013412...5.5853e-13
Idx=36945:...3.6285e-08...1.3966e-09...4.6081e-10
Idx=36959:...2.0911e-05...0.00024014...2.1334e-05
Size 24:
Idx=5806:...4.8795e-15...1.858e-14...1.2294e-12
Idx=5807:...5.5151e-08...1.4444e-11...2.2878e-10
Idx=5808:...0.0040641...0.0040353...9.2665e-14
Size 23:
Idx=1093:...0.1855...2.8843e-09...0.18552
Idx=1103:...7.6492e-10...1.2016e-09...7.8261e-08
Idx=1129:...6.0579e-09...7.31e-09...3.3898e-09
Size 22:
Idx=5796:...0.22505...0.22873...0.22506
Idx=5809:...6.5509e-13...0.025523...0.015

In [10]:
# Save the positions-only results to disk.
# NOTE(review): the 'Result' column holds the objects returned by
# reconstruct.find_iteratively, so loading this pickle presumably requires
# milad to be importable -- TODO confirm.
no_species_frame.to_pickle('structure_recovery_iterative_no_species.pickle')

In [16]:
# Reconstruction where the species are recovered too: the starting species
# are drawn at random from qm9_utils.species.
results = do_reconstruction(
    qm9data=qm9data,
    test_set=test_set,
    num_attempts=num_attempts,
    with_species=True,
    species=qm9_utils.species,
)
with_species_frame = pandas.DataFrame(data=results, columns=columns)

Size 29:
Idx=57517:...0.006377...0.013607...5.287e-05
Idx=58098:...0.0063496...0.006336...0.0057216
Idx=58182:...0.016476...0.0064672...0.01497
Size 27:
Idx=42138:...0.0075256...0.00014443...0.00018096
Idx=57349:...0.046303...0.012052...0.033991
Idx=57419:...0.0068903...0.030577...0.030936
Size 26:
Idx=5805:...0.00038371...0.00021989...0.0068823
Idx=5810:...0.0067983...5.5864e-05...0.0060944
Idx=5850:...0.045572...0.0070716...0.064043
Size 25:
Idx=36927:...0.0079092...0.00072325...0.017752
Idx=36945:...3.3114e-10...0.0077763...4.8575e-10
Idx=36959:...0.0080262...0.016316...0.0078236
Size 24:
Idx=5806:...0.00043183...0.015271...0.008502
Idx=5807:...0.0075291...0.00015913...0.030838
Idx=5808:...0.007524...0.007463...0.0076204
Size 23:
Idx=1093:...0.0078143...0.00014179...3.6722e-09
Idx=1103:...0.0074871...4.8379e-05...0.015134
Idx=1129:...3.7512e-09...0.0078103...0.0077937
Size 22:
Idx=5796:...0.15956...0.17293...0.00037842
Idx=5809:...0.00080488...0.0080117...0.012038
Idx=5812:...0.0077

In [17]:
# Save the with-species results to disk.
# NOTE(review): the 'Result' column holds the objects returned by
# reconstruct.find_iteratively, so loading this pickle presumably requires
# milad to be importable -- TODO confirm.
with_species_frame.to_pickle('structure_recovery_iterative_with_species.pickle')