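## Recovery of species from decoded positions

Starting from the positions-only reconstructions (structures recovered without fitting species), this notebook first keeps each reconstructed geometry fixed and optimises the species against the fingerprint of the original QM9 molecule. It then re-optimises the positions with those species fixed. The refined results are saved to `data/species_recovery_from_decoded_positions.pickle`.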

In [1]:
import json
import random

from ase import visualize
import numpy as np
import pandas
from schnetpack import datasets

import milad
from milad import generate
from milad.play import asetools
from milad import invariants

import qm9_utils

In [2]:
# Seed both the stdlib and NumPy global RNGs so that stochastic steps are
# (best-effort) reproducible across runs. Change SEED here, not inline.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)

In [3]:
# QM9 as packaged by schnetpack (`datasets` is imported in the top import cell).
# download=True lets schnetpack fetch the database if data/qm9.db is not present yet.
qm9data = datasets.QM9('data/qm9.db', download=True)
len(qm9data)

133885

Settings for the reconstruction:

In [4]:
# Number of attempts to reconstruct per structure
num_attempts = 3
# Absolute cutoff radius for the descriptor, in the same length units as the atomic
# positions. It is compared directly against each molecule's maximum radius (which
# must be smaller), so it is NOT a multiple of the maximum radius.
cutoff = 5
# Optimiser used for both refinement stages (species, then positions)
optimiser = milad.optimisers.StructureOptimiser()
# Invariants used to build the descriptor
invs = invariants.read(invariants.COMPLEX_INVARIANTS)

Load the subset of QM9 indices we want to test, grouped by molecule size.

In [5]:
# The JSON maps molecule size -> list of QM9 indices. JSON object keys are always
# strings, so restore integer sizes after loading.
with open('data/qm9_subset.json', 'r') as subset:
    raw_subset = json.load(subset)
test_set = {int(size): indices for size, indices in raw_subset.items()}

In [6]:
def random_atom_collection(num_atoms: int, radius: float, numbers=None):
    """Create a milad ``AtomsCollection`` with randomly placed atoms.

    Args:
        num_atoms: number of atoms in the collection.
        radius: radius of the sphere the random positions are drawn in
            (via ``generate.random_points_in_sphere(..., centre=True)``).
        numbers: optional value assigned to the collection's ``numbers``; when
            ``None`` the collection keeps whatever ``AtomsCollection`` sets by default.

    Returns:
        The new ``milad.atomic.AtomsCollection``.
    """
    positions = generate.random_points_in_sphere(num_atoms, radius, centre=True)
    collection = milad.atomic.AtomsCollection(num_atoms, positions=positions)
    if numbers is None:
        return collection

    collection.numbers = numbers
    return collection

In [7]:
# Column names for the per-attempt result tuples (QM9 id, size, attempt, rmsd, result)
columns = ('QM9 ID', 'Size', 'Attempt #', 'RMSD', 'Result')

# Results from the earlier reconstruction of positions only (species not fitted).
# NOTE(review): this path is relative to the working directory while the other data
# files live under data/ — confirm where the pickle is expected to be.
# read_pickle unpickles arbitrary objects: only load files this project produced.
no_species_path = 'structure_recovery_iterative_no_species.pickle'
no_species = pandas.read_pickle(no_species_path)

In [8]:
# Peek at the lowest-RMSD positions-only result for QM9 structure 2
subframe = no_species[no_species['QM9 ID'] == 2]
lowest_rmsd = subframe['RMSD'].min()
subframe[subframe['RMSD'] == lowest_rmsd]['Result'].iloc[0]

StructureOptimisationResult(success=True, message='`gtol` termination condition is satisfied.', value=<milad.atomic.AtomsCollection object at 0x7fbcf4ab5a40>, rmsd=3.0097241787904286e-10, n_func_eval=54, n_jac_eval=48, traj=None)

In [9]:
def _recover_species(descriptor, fingerprint, reconstructed):
    """Fit species with positions fixed, then refit positions with those species fixed.

    Uses the module-level ``optimiser``. Returns the result of the second optimisation.
    """
    # Stage 1: fix the positions, we only want to get the species
    mask = reconstructed.get_mask()
    mask.positions = reconstructed.positions
    result = optimiser.optimise(
        descriptor=descriptor,
        # Target the original fingerprint
        target=fingerprint,
        initial=reconstructed,
        mask=mask,
    )

    # Stage 2: fix the species just found and re-optimise the positions
    mask = result.value.get_mask()
    mask.numbers = result.value.numbers
    return optimiser.optimise(
        descriptor=descriptor,
        # Target the original fingerprint
        target=fingerprint,
        initial=result.value,
        mask=mask,
    )


def do_reconstruction(qm9data, test_set, result_set, num_attempts, cutoff: float):
    """Recover species for structures whose positions were reconstructed earlier.

    For each QM9 index in ``test_set``, every stored positions-only reconstruction in
    ``result_set`` is refined against the fingerprint of the original (centred)
    molecule. The refinement has two stages: first the species with positions fixed,
    then the positions with those species fixed. Uses the module-level ``invs`` and
    ``optimiser``.

    Args:
        qm9data: QM9 dataset; only ``get_atoms(idx=...)`` is used.
        test_set: mapping of molecule size -> iterable of QM9 indices.
        result_set: DataFrame of earlier results with at least the columns
            ``'QM9 ID'``, ``'Attempt #'`` and ``'Result'`` (``Result.value`` is the
            reconstructed structure).
        num_attempts: maximum number of stored attempts to refine per structure, taken
            in ascending ``'Attempt #'`` order. ``None`` refines all stored attempts.
        cutoff: absolute cutoff radius for the descriptor; each molecule's maximum
            radius must be below it.

    Returns:
        List of ``(qm9_id, size, attempt, rmsd, result)`` tuples, one per refined attempt.

    Raises:
        AssertionError: if a molecule does not fit inside ``cutoff``.
    """
    results = []

    # Let's create the descriptor we're going to use
    descriptor = qm9_utils.create_descriptor(invs, cutoff=cutoff, apply_cutoff=False)

    for size, indices in sorted(test_set.items()):
        size = int(size)
        print(f"Size {size}:")
        for idx in indices:
            print(f"Idx={idx}:", end='')

            # Centre the molecule, find its max radius and compute the target fingerprint
            system = qm9data.get_atoms(idx=idx)
            max_radius = asetools.prepare_molecule(system)
            assert max_radius < cutoff, f"QM9 {idx}: max radius {max_radius} >= cutoff {cutoff}"
            fingerprint = descriptor(asetools.ase2milad(system))

            # Stored attempts for this structure, in attempt order. Iterate over the rows
            # actually present instead of assuming attempts are numbered 0..n-1.
            attempts = result_set.loc[result_set['QM9 ID'] == idx].sort_values(
                'Attempt #', kind='mergesort')
            if num_attempts is not None:
                attempts = attempts.head(num_attempts)

            for attempt, previous in zip(attempts['Attempt #'], attempts['Result']):
                print("...", end='')
                result = _recover_species(descriptor, fingerprint, previous.value)
                print(f'{result.rmsd:5.5}', end='')
                results.append((idx, size, int(attempt), result.rmsd, result))

            print()

    return results

In [10]:
# Refine the stored positions-only attempts and tabulate one row per attempt.
# NOTE: despite its name, no_species_frame holds the species-recovery results.
results = do_reconstruction(
    qm9data=qm9data,
    test_set=test_set,
    result_set=no_species,
    num_attempts=num_attempts,
    cutoff=cutoff,
)
no_species_frame = pandas.DataFrame(data=results, columns=columns)

Size 3:
Idx=2:...3.7185e-10...2.4183e-10...1.1963e-09
Idx=4:...9.4263e-10...6.3768e-10...7.0026e-10
Size 4:
Idx=1:...1.0015e-09...1.5294e-09...1.7431e-09
Idx=3:...1.069e-09...3.2806e-09...1.5106e-09
Idx=5:...0.00053155...0.00053155...0.00053155
Size 5:
Idx=0:...5.1461e-10...1.4424e-09...1.6949e-09
Idx=23:...2.1787e-10...2.5019e-10...6.4246e-10
Idx=26:...3.1493e-11...1.1068e-10...1.2488e-10
Size 6:
Idx=7:...1.4459e-09...3.3543e-09...6.1241e-09
Idx=9:...1.5327e-09...1.074e-09...1.0124e-09
Idx=11:...1.5868e-08...0.00034236...6.8772e-10
Size 7:
Idx=8:...3.6816e-10...4.3432e-10...4.1785e-10
Idx=10:...6.831e-10...1.372e-09...7.6467e-10
Idx=16:...2.6496e-10...7.7901e-10...7.4204e-10
Size 8:
Idx=6:...6.8154e-10...2.3183e-10...4.5636e-10
Idx=19:...1.9786e-09...3.1215e-09...3.9771e-09
Idx=31:...2.6181e-15...5.8936e-11...7.8847e-16
Size 9:
Idx=13:...3.7577e-10...1.7976e-09...4.7625e-10
Idx=14:...1.8923e-10...8.6453e-10...3.1239e-10
Idx=15:...1.3037e-09...4.6281e-10...2.0626e-09
Size 10:
Idx=17:..

In [11]:
# Persist the species-recovery results (one row per refined attempt)
pandas.to_pickle(no_species_frame, 'data/species_recovery_from_decoded_positions.pickle')

In [12]:
# Positions-only reconstruction for one structure/attempt (before species recovery)
idx, attempt = 4, 0
subframe = no_species.loc[no_species['QM9 ID'] == idx]
result = subframe[subframe['Attempt #'] == attempt]['Result'].iloc[0]
print(result.value.numbers)
visualize.view(asetools.milad2ase(result.value))

[1. 1. 1.]


In [13]:
# Reference: the original QM9 molecule at the same index, as stored in the database
# (not passed through asetools.prepare_molecule in this cell)
system = qm9data.get_atoms(idx=idx)
visualize.view(system)

In [14]:
# The same structure/attempt after species recovery
is_selected = (no_species_frame['QM9 ID'] == idx) & (no_species_frame['Attempt #'] == attempt)
subframe = no_species_frame[is_selected]
result = subframe['Result'].iloc[0]
visualize.view(asetools.milad2ase(result.value))

In [15]:
# Species (`numbers`) after recovery for the same idx/attempt; compare with the
# positions-only numbers printed earlier
result.value.numbers

array([6., 7., 1.])