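## Recovery of species from decoded positions

Starting from the structures previously reconstructed with positions only, we first recover the atomic species while holding the decoded positions fixed. We then optimise again with the recovered species held fixed, and record the RMSD to the original fingerprint for each attempt.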

In [1]:
import random
import json

from ase import visualize
import numpy as np
import pandas

import milad
from milad import atomic
from milad import generate
from milad.play import asetools
from milad import invariants

import qm9_utils

In [2]:
# Seed both the stdlib and the numpy random number generators for some kind of reproducibility
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

In [3]:
from schnetpack import datasets

# Open the QM9 database through schnetpack (download=True: presumably fetched if not already present)
qm9_db_path = 'data/qm9.db'
qm9data = datasets.QM9(qm9_db_path, download=True)
# Display the number of molecules in the database
len(qm9data)

133885

Settings for the reconstruction:

In [5]:
# Number of attempts to reconstruct per structure
num_attempts = 3
# Descriptor cutoff radius. It is used as an absolute distance (passed directly to the
# descriptor, and each centred molecule's maximum radius must be below it), not as a
# multiple of the maximum radius.
cutoff = 5
# Optimiser to use
optimiser = milad.optimisers.StructureOptimiser()
# Keep atoms from coming too close
optimiser.separation_force = atomic.SeparationForce(epsilon=1e-1, cutoff=0.55, power=2)
# Invariants to use
invs = invariants.read(invariants.COMPLEX_INVARIANTS)

Load the subset of QM9 indices to use, grouped by molecule size:

In [6]:
with open('data/qm9_subset.json', 'r') as subset:
    raw_subset = json.load(subset)
# JSON object keys are always strings: turn the size keys back into ints
test_set = {int(size): indices for size, indices in raw_subset.items()}

In [7]:
def random_atom_collection(num_atoms: int, radius: float, numbers=None):
    """Build a milad AtomsCollection with randomly generated positions.

    Positions come from ``generate.random_points_in_sphere(num_atoms, radius, centre=True)``.

    :param num_atoms: number of atoms in the collection
    :param radius: radius passed to the random point generator
    :param numbers: optional atomic numbers to assign; left unset when None
    :return: the new ``milad.atomic.AtomsCollection``
    """
    positions = generate.random_points_in_sphere(num_atoms, radius, centre=True)
    collection = milad.atomic.AtomsCollection(num_atoms, positions=positions)
    if numbers is None:
        return collection

    collection.numbers = numbers
    return collection

In [8]:
# Column names of the results frames: one row per reconstruction attempt
columns = ('QM9 ID', 'Size', 'Attempt #', 'RMSD', 'Result')

# Load results from reconstruction of positions only.
# NOTE(review): read_pickle unpickles arbitrary objects - only load files you produced yourself.
positions_only_path = 'structure_recovery_iterative_no_species.pickle'
no_species = pandas.read_pickle(positions_only_path)

In [9]:
# Show the positions-only result with the lowest RMSD for QM9 molecule 2
subframe = no_species.loc[no_species['QM9 ID'] == 2]
best_rmsd = subframe['RMSD'].min()
subframe.loc[subframe['RMSD'] == best_rmsd].iloc[0]['Result']

StructureOptimisationResult(success=True, message='`gtol` termination condition is satisfied.', value=<milad.atomic.AtomsCollection object at 0x7f5fcea29140>, rmsd=7.474060375117366e-12, n_func_eval=95, n_jac_eval=82, traj=None)

In [10]:
def do_reconstruction(qm9data, test_set, result_set, num_attempts, cutoff: float):
    """Recover the species of QM9 molecules starting from positions-only reconstructions.

    For every QM9 index in ``test_set``, the recorded positions-only results in ``result_set``
    are used as starting points for two masked optimisations. Both target the fingerprint of
    the original molecule:

    1. positions masked to the decoded values, so only the species are recovered;
    2. species (``numbers``) masked to the values found in step 1, then optimised again.

    Relies on the notebook-level ``invs`` (invariants) and ``optimiser``.

    :param qm9data: QM9 dataset providing ``get_atoms(idx=...)``
    :param test_set: mapping of molecule size -> list of QM9 indices (keys may be ints or
        numeric strings)
    :param result_set: DataFrame of positions-only results with at least the
        'QM9 ID', 'Attempt #' and 'Result' columns
    :param num_attempts: maximum number of recorded attempts to process per molecule
        (``None`` processes all of them)
    :param cutoff: cutoff passed to the descriptor; each centred molecule's maximum radius
        must be smaller than this
    :return: list of ``(QM9 ID, size, attempt #, RMSD, result)`` tuples, one per attempt
    :raises AssertionError: if a molecule's maximum radius is not below ``cutoff``
    """
    results = []

    # Let's create the descriptor we're going to use
    descriptor = qm9_utils.create_descriptor(invs, cutoff=cutoff, apply_cutoff=False)

    # Sort numerically so that string keys (e.g. straight from JSON) don't order '10' before '3'
    for size, indices in sorted(test_set.items(), key=lambda item: int(item[0])):
        size = int(size)
        print(f"Size {size}:")
        for idx in indices:
            print(f"Idx={idx}:", end='')

            # Get the system
            system = qm9data.get_atoms(idx=idx)
            max_radius = asetools.prepare_molecule(system)  # Centre the molecule and find the max radius.
            assert max_radius < cutoff, f"QM9 ID {idx}: max radius {max_radius} is not below cutoff {cutoff}"

            # Prepare the milad data type and the fingerprint we want to recover
            milad_molecule = asetools.ase2milad(system)
            fingerprint = descriptor(milad_molecule)

            # The positions-only attempts recorded for this molecule, in attempt order.
            # Iterate over the attempt numbers actually present (rather than assuming they run
            # 0..n-1) and honour `num_attempts` as an upper limit instead of overwriting it.
            attempts = result_set.loc[result_set['QM9 ID'] == idx].sort_values('Attempt #', kind='mergesort')
            if num_attempts is not None:
                attempts = attempts.head(num_attempts)
            if attempts.empty:
                print(" no positions-only results", end='')

            for attempt, decoded_result in zip(attempts['Attempt #'], attempts['Result']):
                print("...", end='')
                decoded = decoded_result.value

                # Stage 1: fix the positions, we only want to get the species
                species_mask = decoded.get_mask()
                species_mask.positions = decoded.positions
                species_result = optimiser.optimise(
                    descriptor=descriptor,
                    target=fingerprint,  # Target the original fingerprint
                    initial=decoded,
                    mask=species_mask,
                )

                # Stage 2: fix the recovered species and optimise again
                recovered = species_result.value
                numbers_mask = recovered.get_mask()
                numbers_mask.numbers = recovered.numbers
                result = optimiser.optimise(
                    descriptor=descriptor,
                    target=fingerprint,  # Target the original fingerprint
                    initial=recovered,
                    mask=numbers_mask,
                )

                print(f'{result.rmsd:5.5}', end='')
                results.append((idx, size, attempt, result.rmsd, result))

            print()

    return results

In [11]:
# Recover species for every molecule in the test set, starting from the positions-only results
results = do_reconstruction(qm9data, test_set, no_species, num_attempts, cutoff=cutoff)
# One row per attempt.
# NOTE(review): despite its name, this frame holds the species-recovery results, not the positions-only ones
no_species_frame = pandas.DataFrame(data=results, columns=columns)

Size 3:
Idx=2:...6.9365e-12...2.6566e-11...2.2313e-11
Idx=4:...1.1914e-11...8.1818e-12...4.7391e-12
Size 4:
Idx=1:...6.0647e-12...6.2785e-11...1.7021e-11
Idx=3:...2.1772e-11...2.7587e-11...1.6058e-11
Idx=5:...0.029376...0.029376...0.029376
Size 5:
Idx=0:...7.1161e-12...7.4478e-12...1.8298e-11
Idx=23:...1.8428e-12...2.6847e-12...1.3268e-12
Idx=26:...6.222e-13...2.2192e-13...7.3494e-13
Size 6:
Idx=7:...3.4428e-13...1.2636e-12...4.6011e-13
Idx=9:...1.0562e-12...9.3905e-13...1.941e-10
Idx=11:...8.6076e-11...0.020546...0.020546
Size 7:
Idx=8:...2.8766e-10...2.8719e-10...8.2393e-13
Idx=10:...2.1087e-08...7.7909e-16...3.9487e-15
Idx=16:...4.2019e-12...7.92e-12...1.1279e-11
Size 8:
Idx=6:...1.7187e-11...4.1334e-11...1.6743e-11
Idx=19:...3.677e-11...4.1313e-11...3.3473e-11
Idx=31:...4.8836e-15...1.8565e-13...1.3666e-15
Size 9:
Idx=13:...1.1527e-10...5.981e-11...1.1548e-10
Idx=14:...7.9692e-12...2.2354e-12...2.7546e-12
Idx=15:...4.2583e-11...5.2244e-11...7.8627e-11
Size 10:
Idx=17:...1.478e-11..

In [12]:
# Persist the species-recovery results
species_recovery_path = 'data/species_recovery_from_decoded_positions.pickle'
no_species_frame.to_pickle(species_recovery_path)