https://www.repository.cam.ac.uk/handle/1810/307452

The first time you run this notebook, uncomment and execute the following cell

In [None]:
#!wget https://www.repository.cam.ac.uk/bitstream/handle/1810/307452/Carbon_GAP_20.tgz
#!tar -xzvf Carbon_GAP_20.tgz

In [18]:
from ase.io import iread
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from scipy.spatial import distance_matrix
import numpy as np
import dgl
from ase.calculators.mopac import MOPAC
from ase import Atoms

In [18]:
# check with dummy example
# potential energy
# see what software caarbon gap paper uses to compute

In [None]:
# create torch dataset like in class
# choose neural net (one round of mpnn and mlp predictor hw8 network) and train (use cs machine and wandb)
# read paper more throughly, what they do and what they look at with this dataset

In [None]:
# dataset[4782][0] is a single point

In [128]:
class GraphDataset(Dataset):
    def __init__(self):
        self.xyz = []
        self.E = []
        self.graph = []
        for mol in iread('Carbon_GAP_20/Carbon_GAP_20_Training_Set.xyz'):
            self.xyz.append(mol.get_positions())
            self.E.append(mol.get_potential_energy())
    
    def nearest_neighbors(self, g, m, k):
        '''
            g --> (3) one coordinate used as reference point
            m --> (x,3) whole molecule geometry
            k --> (1) number of nearest neighbors to be found
            - assumes g is in m so the first closest neighbor is excluded
            - takes the max amount of neighbors if k is greater than total atoms in a molecule
            - if a molecule is also a single atom, it will be its own neighbor
            - calculates the actual neighbors, the distances, and their indices in the list of atoms
        '''
        if k >= len(m):
            k = len(m)-1
        
        dist = distance_matrix([g], m)
        
        if len(m)==1:
        # if single atom, itself will be its neighbor
            k = 1
            indices = np.argpartition(dist[0], range(k))[:k] 
        else:
            indices = np.argpartition(dist[0], range(k+1))[1:k+1] # excludes first closest neighbor (itself)
        
        k_nearest = []
        k_dist = []
        for idx in indices:
            k_nearest.append(m[idx])
            k_dist.append(dist[0][idx])
        return k_nearest, k_dist, indices

    def featurize_atoms(self, molecule):
    # featurize atoms
        c = Atoms('C', positions=[[0, 0, 0]])
        c.calc = MOPAC(label='C', task='PM7 1SCF UHF')
        energy = c.get_potential_energy()
        
        c_e = []
        for atom in range(len(molecule)):
            c_e.append(energy)
        return {'energy': torch.tensor(c_e)}
    
    def xyz_to_graph(self, molecule, k, node_featurizer, edge_featurizer):
        '''
            molecule --> (x,3) whole molecule geometry
            k --> (1) number of nearest neighbors to be found
            - creates a graph of the molecule where each atom is connected to its k nearest neighbors
            - featurizes the nodes with the energy and the edges with distance
        '''    
        src = []
        dest = []
        ndist = []
        for atom in range(len(molecule)):
            nbhd, dist, idx = self.nearest_neighbors(molecule[atom], molecule, k)
            for i in range(len(nbhd)):
                src.append(atom)
                dest.append(idx[i])
                ndist.append(dist[i])
        g = dgl.graph((torch.tensor(src), torch.tensor(dest)))

        if node_featurizer is not None:
            g.ndata.update(node_featurizer(molecule))

        if edge_featurizer is True:
            g.edata.update({'length': torch.tensor(ndist)})

        return g
    
    def process(self, k):
    # make graph for each molecule
        tmp = []
        counter=0
        for xyz in self.xyz:
            if counter%1000 == 0:
                print(counter)
            counter+=1
            tmp.append(self.xyz_to_graph(xyz, k, node_featurizer=self.featurize_atoms, edge_featurizer=True))
        self.graph = tmp
         
    def __len__(self):
        return len(self.E)

    def __getitem__(self, idx):
        return self.graph[idx], self.E[idx]

In [129]:
graph_dataset = GraphDataset()

In [130]:
graph_dataset.process(3)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000


In [139]:
graph_dataset[3]

(Graph(num_nodes=20, num_edges=60,
       ndata_schemes={'energy': Scheme(shape=(), dtype=torch.float32)}
       edata_schemes={'length': Scheme(shape=(), dtype=torch.float64)}),
 -130.06626593)