## Reconstruction of structures from reconstructed moments

This notebook takes a set of structures from the QM9 dataset, computes the invariants for each one, and then tries to reconstruct the original molecule from those invariants by first finding peaks in the reconstruction from moments.

In [1]:
import random
import json

import numpy as np
import pandas

import milad
from milad.play import asetools
from milad import invariants
from milad import reconstruct
from milad import zernike

import qm9_utils

In [2]:
# Seed both the stdlib and the NumPy global generators with the same value
# so that a fresh Restart & Run All starts from the same random state.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

In [3]:
# NOTE(review): this import sits outside the notebook's main import cell;
# consider moving it there so all dependencies are visible in one place.
from schnetpack import datasets

# Open the QM9 database at a path relative to the notebook's working directory.
# presumably download=True fetches the database when it is not already present
# at that path -- TODO confirm against the schnetpack documentation.
qm9data = datasets.QM9('data/qm9.db', download=True)
# Bare last expression: the notebook displays the number of entries.
len(qm9data)

133885

The settings for our reconstruction

In [4]:
# Number of attempts to reconstruct per structure
num_attempts = 3

# Cutoff handed to qm9_utils.create_descriptor inside do_reconstruction.
# Note that do_reconstruction passes apply_cutoff=False alongside it.
# NOTE(review): units are not stated here (presumably Angstrom) -- confirm.
cutoff = 5

# Invariants to use
invs = invariants.read(invariants.COMPLEX_INVARIANTS) 

Let's load the subset of indices from the QM9 database we want to use

In [6]:
# JSON object keys are always strings, so convert them to int while loading.
# do_reconstruction sorts this mapping by key, and integer keys make that
# ordering numeric rather than lexicographic.
with open('data/qm9_subset.json', 'r') as subset:
    test_set = {int(key): value for key, value in json.load(subset).items()}

In [7]:
# Column labels for the result frames; the order matches the tuples that
# do_reconstruction appends: (idx, size, attempt, rmsd, result).
columns = (
    'QM9 ID',
    'Size',
    'Attempt #',
    'RMSD',
    'Result',
)

In [8]:
def do_reconstruction(
    qm9data,
    test_set,
    num_attempts: int,
    with_species: bool,
    species=None
):
    """Attempt to recover every structure in ``test_set`` from its fingerprint.

    For each dataset index the molecule is fetched, prepared and converted to
    the milad representation, its fingerprint is computed with the descriptor,
    and ``reconstruct.find_iteratively`` is then run ``num_attempts`` times,
    each time from a newly created set of initial atoms. Progress is printed
    as one line per index with one tab-separated RMSD per attempt.

    This function reads the module-level ``invs`` and ``cutoff`` settings.

    Args:
        qm9data: dataset object exposing ``get_atoms(idx=...)``.
        test_set: mapping of molecule size to an iterable of dataset indices.
            Entries are processed in descending key order.
        num_attempts: number of reconstruction attempts per structure.
        with_species: when False, the molecule's ``numbers`` are all set to
            ``1.`` before fingerprinting. The flag is also forwarded as
            ``include_species`` / ``find_species`` to the helpers.
        species: accepted for interface compatibility; not used in this body.

    Returns:
        A list of ``(idx, size, attempt, rmsd, result)`` tuples, one per
        attempt, in the order the attempts were run.
    """
    # Neither the descriptor nor the peak-finding grid query depends on the
    # molecule, so each is built once and shared by every attempt below.
    descriptor = qm9_utils.create_descriptor(invs, cutoff, apply_cutoff=False)
    reconstruction_grid = zernike.ZernikeMoments.get_grid(31)
    grid_query = zernike.ZernikeMoments.create_reconstruction_query(
        reconstruction_grid, invs.max_order)

    records = []
    for size_key, indices in sorted(test_set.items(), reverse=True):
        size = int(size_key)
        print(f"Size {size}:")

        for idx in indices:
            print(f"Idx={idx}:", end='')

            # Fetch the structure; prepare_molecule centres it and reports
            # its maximum radius.
            atoms = qm9data.get_atoms(idx=idx)
            max_radius = asetools.prepare_molecule(atoms)

            # Convert to the milad data type
            molecule = asetools.ase2milad(atoms)
            num_atoms = molecule.num_atoms
            if not with_species:
                # Treat every atom as the same species
                molecule.numbers = 1.

            # Ground-truth fingerprint that each attempt tries to match
            fingerprint = descriptor(molecule)

            for attempt in range(num_attempts):
                print("\t", end='')

                # Starting configuration for this attempt
                starting_atoms = qm9_utils.create_initial_atoms(
                    num_atoms, max_radius, include_species=with_species)

                # Optimise the structure with respect to the target fingerprint
                outcome = reconstruct.find_iteratively(
                    descriptor,
                    fingerprint,
                    starting_atoms,
                    find_species=with_species,
                    grid_query=grid_query,
                    minsep=0.55,
                    verbose=False,
                )

                print(f'{outcome.rmsd:5.5}', end='')
                records.append((idx, size, attempt, outcome.rmsd, outcome))

            # Terminate this index's line of RMSDs
            print()

    return records

In [11]:
# Reconstruct with species information included, then tabulate one row per
# attempt using the shared column labels.
results = do_reconstruction(
    qm9data=qm9data,
    test_set=test_set,
    num_attempts=num_attempts,
    with_species=True,
    species=qm9_utils.species,
)
with_species_frame = pandas.DataFrame(data=results, columns=columns)

Size 29:
Idx=57517:	0.032109	0.031845	0.025984
Idx=58098:	0.031155	0.044432	0.04432
Idx=58182:	0.057318	0.099554	0.038621
Size 27:
Idx=42138:	0.049128	0.042086	0.01399
Idx=57349:	3.1079e-10	0.046525	1.209e-05
Idx=57419:	0.05818	0.049564	0.065609
Size 26:
Idx=5805:	0.041691	0.042301	0.079548
Idx=5810:	0.03275	0.026102	0.03303
Idx=5850:	0.060578	0.020377	0.074838
Size 25:
Idx=36927:	0.023619	0.048505	0.045732
Idx=36945:	0.038764	0.061525	0.023479
Idx=36959:	0.039705	0.039692	0.031708
Size 24:
Idx=5806:	0.029976	0.030091	0.044223
Idx=5807:	4.3357e-08	0.037024	1.2116e-13
Idx=5808:	0.031051	0.051669	0.023312
Size 23:
Idx=1093:	0.1642	0.045891	0.046713
Idx=1103:	8.6153e-13	0.021567	7.4633e-08
Idx=1129:	6.8823e-09	2.7889e-09	3.755e-09
Size 22:
Idx=5796:	0.0317	0.50978	0.16201
Idx=5809:	0.047727	0.033303	0.016026
Idx=5812:	1.7314e-11	2.6648e-11	4.5969e-11
Size 21:
Idx=1091:	9.3609e-15	0.0082299	0.016177
Idx=1094:	0.041673	0.075947	0.016619
Idx=1095:	0.033138	0.0084655	0.033669
Size 20:
Idx=227

In [12]:
# Save the with-species results to a pickle file in the working directory.
# The 'Result' column holds the objects returned by reconstruct.find_iteratively,
# so presumably milad must be importable to load this file again -- confirm.
# Only unpickle this file if it comes from a trusted source.
with_species_frame.to_pickle('structure_recovery_iterative_with_species.pickle')

In [9]:
# Reconstruct again with all species fixed to a single value (positions only),
# then tabulate one row per attempt using the shared column labels.
results = do_reconstruction(
    qm9data=qm9data,
    test_set=test_set,
    num_attempts=num_attempts,
    with_species=False,
)
no_species_frame = pandas.DataFrame(data=results, columns=columns)

Size 29:
Idx=57517:	4.0123e-11	3.4408e-11	2.4334e-11
Idx=58098:	1.0915e-08	9.1729e-09	1.4363e-08
Idx=58182:	0.0090292	0.0099889	5.112e-13
Size 27:
Idx=42138:	4.9004e-09	7.3896e-09	6.2858e-08
Idx=57349:	1.1819e-11	2.8756e-11	1.2472e-11
Idx=57419:	0.041954	1.8774e-14	0.020114
Size 26:
Idx=5805:	0.00049622	0.0065494	1.1596e-14
Idx=5810:	3.1883e-08	2.8429e-08	3.226e-10
Idx=5850:	0.024731	0.035814	0.07438
Size 25:
Idx=36927:	0.013411	0.007343	2.4609e-14
Idx=36945:	4.8214e-08	6.4875e-10	1.8972e-10
Idx=36959:	1.4562e-05	0.00012678	1.4385e-10
Size 24:
Idx=5806:	9.5958e-15	7.8089e-15	6.0935e-15
Idx=5807:	5.515e-08	4.069e-08	5.515e-08
Idx=5808:	4.8672e-15	1.3867e-14	1.4423e-14
Size 23:
Idx=1093:	0.18946	2.2135e-09	2.106e-09
Idx=1103:	7.826e-08	7.826e-08	7.826e-08
Idx=1129:	8.5294e-09	6.15e-09	6.0581e-09
Size 22:
Idx=5796:	0.464	0.56905	1.1704e-13
Idx=5809:	2.0931e-14	9.8656e-15	0.024615
Idx=5812:	3.719e-11	9.0807e-12	2.1762e-11
Size 21:
Idx=1091:	7.5987e-15	4.1477e-08	1.5638e-14
Idx=1094:	3.517e

In [10]:
# Save the no-species results to a pickle file in the working directory.
# The 'Result' column holds the objects returned by reconstruct.find_iteratively,
# so presumably milad must be importable to load this file again -- confirm.
# Only unpickle this file if it comes from a trusted source.
no_species_frame.to_pickle('structure_recovery_iterative_no_species.pickle')