In [26]:
import collections
import random
import json

import numpy as np
import pandas
import pymatgen.io.ase
from pymatgen import symmetry

import milad
from milad.play import asetools
from milad import invariants
from milad import reconstruct
from milad import zernike

import qm9_utils

In [2]:
from schnetpack import datasets

qm9data = datasets.QM9('data/qm9.db', download=True)
len(qm9data)

133885

In [3]:
with open('data/qm9_subset.json', 'r') as subset:
    test_set = json.load(subset)
test_set = {int(key): value for key, value in test_set.items()}

In [58]:
def get_point_groups(
    qm9data, 
    test_set,
    ignore_species=False
):
    results = collections.defaultdict(dict)
    adapter = pymatgen.io.ase.AseAtomsAdaptor()
    
    for size, ids in test_set.items():
        for qm9id in ids:
            molecule = qm9data.get_atoms(idx=qm9id)
            if ignore_species:
                molecule.numbers[:] = 1.

            asetools.prepare_molecule(molecule)

            mol = adapter.get_molecule(molecule)
            mol.get_centered_molecule = lambda: mol
            
            analyser = symmetry.analyzer.PointGroupAnalyzer(mol)
            pg = analyser.get_pointgroup()

            results[size][qm9id] = pg
    
    return results

In [59]:
get_point_groups(qm9data, test_set, ignore_species=False)

defaultdict(dict,
            {5: {0: Td, 23: C*v, 26: Cs},
             4: {1: C3v, 3: D*h, 5: C2v},
             3: {2: C2v, 4: C*v},
             8: {6: D3d, 19: C2, 31: C1},
             6: {7: Cs, 9: C3v, 11: Cs},
             7: {8: C3v, 10: Cs, 16: C2v},
             11: {12: C2v, 44: Cs, 64: C1},
             9: {13: Cs, 14: C2v, 15: D3h},
             10: {17: C2v, 28: D3h, 29: Cs},
             14: {20: C3v, 38: C2h, 101: C1},
             12: {21: C1, 39: Cs, 40: Cs},
             17: {53: Td, 82: C1, 132: C2v},
             15: {54: Cs, 80: C1, 83: C1},
             13: {62: Cs, 67: Cs, 71: Cs},
             16: {218: C3v, 223: Cs, 226: Cs},
             18: {225: Cs, 228: C1, 229: Cs},
             20: {227: Cs, 273: C1, 290: C2h},
             19: {1081: C3v, 1083: Cs, 1087: C1},
             21: {1091: Cs, 1094: C1, 1095: Cs},
             23: {1093: Cs, 1103: Cs, 1129: C2},
             22: {5796: Cs, 5809: C1, 5812: C2h},
             26: {5805: C1, 5810: D3, 5850: C1}

In [42]:
def order_by_pointgroup(results: dict):
    # Expecting results to be a dict of [size, dict] where the second dict
    # contains [id, pointgroup] pairs
    ordered = collections.defaultdict(list)
    for size, entry in results.items():
        for qm9id, pg in entry.items():
            ordered[str(pg)].append(qm9id)
    return ordered

In [45]:
order_by_pointgroup(res)

defaultdict(list,
            {'Td': [0, 53],
             'C*v': [23, 4],
             'Cs': [26,
              7,
              11,
              10,
              44,
              13,
              29,
              39,
              40,
              54,
              62,
              67,
              71,
              223,
              226,
              225,
              229,
              227,
              1083,
              1091,
              1095,
              1093,
              1103,
              5796,
              5807,
              57349],
             'C3v': [1, 9, 8, 20, 218, 1081],
             'D*h': [3],
             'C2v': [5, 2, 16, 12, 14, 17, 132],
             'D3d': [6],
             'C2': [19, 1129, 36945, 36959, 57517, 58098],
             'C1': [31,
              64,
              101,
              21,
              82,
              80,
              83,
              228,
              273,
              1087,
              1094,
              

In [60]:
get_point_groups(qm9data, test_set, ignore_species=False)

defaultdict(dict,
            {5: {0: Td, 23: C*v, 26: Cs},
             4: {1: C3v, 3: D*h, 5: C2v},
             3: {2: C2v, 4: C*v},
             8: {6: D3d, 19: C2, 31: C1},
             6: {7: Cs, 9: C3v, 11: Cs},
             7: {8: C3v, 10: Cs, 16: C2v},
             11: {12: C2v, 44: Cs, 64: C1},
             9: {13: Cs, 14: C2v, 15: D3h},
             10: {17: C2v, 28: D3h, 29: Cs},
             14: {20: C3v, 38: C2h, 101: C1},
             12: {21: C1, 39: Cs, 40: Cs},
             17: {53: Td, 82: C1, 132: C2v},
             15: {54: Cs, 80: C1, 83: C1},
             13: {62: Cs, 67: Cs, 71: Cs},
             16: {218: C3v, 223: Cs, 226: Cs},
             18: {225: Cs, 228: C1, 229: Cs},
             20: {227: Cs, 273: C1, 290: C2h},
             19: {1081: C3v, 1083: Cs, 1087: C1},
             21: {1091: Cs, 1094: C1, 1095: Cs},
             23: {1093: Cs, 1103: Cs, 1129: C2},
             22: {5796: Cs, 5809: C1, 5812: C2h},
             26: {5805: C1, 5810: D3, 5850: C1}

In [61]:
get_point_groups(qm9data, test_set, ignore_species=True)

defaultdict(dict,
            {5: {0: Td, 23: D*h, 26: Cs},
             4: {1: C3v, 3: D*h, 5: C2v},
             3: {2: C2v, 4: D*h},
             8: {6: D3d, 19: C2, 31: C1},
             6: {7: Cs, 9: C3v, 11: C2h},
             7: {8: C3v, 10: Cs, 16: C2v},
             11: {12: C2v, 44: Cs, 64: C1},
             9: {13: Cs, 14: C2v, 15: D3h},
             10: {17: C2v, 28: D3h, 29: Cs},
             14: {20: C3v, 38: C2h, 101: C1},
             12: {21: C1, 39: Cs, 40: Cs},
             17: {53: Td, 82: C1, 132: C2v},
             15: {54: Cs, 80: C1, 83: C1},
             13: {62: Cs, 67: Cs, 71: Cs},
             16: {218: C3v, 223: Cs, 226: Cs},
             18: {225: Cs, 228: C1, 229: Cs},
             20: {227: Cs, 273: C1, 290: C2h},
             19: {1081: C3v, 1083: Cs, 1087: C1},
             21: {1091: Cs, 1094: C1, 1095: Cs},
             23: {1093: Cs, 1103: Cs, 1129: C2},
             22: {5796: Cs, 5809: C1, 5812: C2h},
             26: {5805: C1, 5810: D3, 5850: C1