https://www.repository.cam.ac.uk/handle/1810/307452

The first time you run this notebook, uncomment and execute the following cell

In [1]:
#!wget https://www.repository.cam.ac.uk/bitstream/handle/1810/307452/Carbon_GAP_20.tgz
#!tar -xzvf Carbon_GAP_20.tgz

In [13]:
from ase.io import iread
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from scipy.spatial import distance_matrix
import numpy as np
import dgl
from ase.calculators.mopac import MOPAC
from ase import Atoms

In [3]:
# check with dummy example
# potential energy
# see what software caarbon gap paper uses to compute

In [4]:
# create torch dataset like in class
# choose neural net (one round of mpnn and mlp predictor hw8 network) and train (use cs machine and wandb)
# read paper more throughly, what they do and what they look at with this dataset

In [5]:
# dataset[4782][0] is a single point

In [65]:
class GraphDataset(Dataset):
    def __init__(self):
        self.xyz = []
        self.E = []
        self.graph = []
        for mol in iread('Carbon_GAP_20/Carbon_GAP_20_Training_Set.xyz'):
            self.xyz.append(mol.get_positions())
            self.E.append(mol.get_potential_energy())
    
    def nearest_neighbors(self, g, m, k):
        '''
            g --> (3) one coordinate used as reference point
            m --> (x,3) whole molecule geometry
            k --> (1) number of nearest neighbors to be found
            - assumes g is in m so the first closest neighbor is excluded
            - takes the max amount of neighbors if k is greater than total atoms in a molecule
            - if a molecule is also a single atom, it will be its own neighbor
            - calculates the actual neighbors, the distances, and their indices in the list of atoms
        '''
        if k >= len(m):
            k = len(m)-1
        
        dist = distance_matrix([g], m)
        
        if len(m)==1:
        # if single atom, itself will be its neighbor
            k = 1
            indices = np.argpartition(dist[0], range(k))[:k] 
        else:
            indices = np.argpartition(dist[0], range(k+1))[1:k+1] # excludes first closest neighbor (itself)
        
        k_nearest = []
        k_dist = []
        for idx in indices:
            k_nearest.append(m[idx])
            k_dist.append(dist[0][idx])
        return k_nearest, k_dist, indices

    def featurize_atoms(self, molecule):
    # featurize atoms
        c = Atoms('C', positions=[[0, 0, 0]])
        c.calc = MOPAC(label='C', task='PM7 1SCF UHF')
        energy = c.get_potential_energy()
        
        c_e = []
        for atom in range(len(molecule)):
            c_e.append(energy)
        return {'energy': torch.tensor(c_e)}
    
    def xyz_to_graph(self, molecule, k, node_featurizer, edge_featurizer):
        '''
            molecule --> (x,3) whole molecule geometry
            k --> (1) number of nearest neighbors to be found
            - creates a graph of the molecule where each atom is connected to its k nearest neighbors
            - featurizes the nodes with the energy and the edges with distance
        '''    
        src = []
        dest = []
        ndist = []
        for atom in range(len(molecule)):
            nbhd, dist, idx = self.nearest_neighbors(molecule[atom], molecule, k)
            for i in range(len(nbhd)):
                src.append(atom)
                dest.append(idx[i])
                ndist.append(dist[i])
        g = dgl.graph((torch.tensor(src), torch.tensor(dest)))

        if node_featurizer is not None:
            g.ndata.update(node_featurizer(molecule))

        if edge_featurizer is True:
            g.edata.update({'length': torch.tensor(ndist)})

        return g
    
    def process(self, k):
    # make graph for each molecule
        tmp = []
        counter=0
        for xyz in self.xyz:
            if counter%1000 == 0:
                print(counter)
            counter+=1
            tmp.append(self.xyz_to_graph(xyz, k, node_featurizer=self.featurize_atoms, edge_featurizer=True))
        self.graph = tmp
         
    def __len__(self):
        return len(self.E)

    def __getitem__(self, idx):
        return self.graph[idx], self.E[idx]

In [66]:
graph_dataset = GraphDataset()

In [67]:
graph_dataset.process(3)

0
1000
2000
3000
4000
5000
6000


In [68]:
graph_dataset[3]

(Graph(num_nodes=20, num_edges=60,
       ndata_schemes={'energy': Scheme(shape=(), dtype=torch.float32)}
       edata_schemes={'length': Scheme(shape=(), dtype=torch.float64)}),
 -130.06626593)

In [69]:
from tqdm.notebook import tqdm
from dgl.dataloading import GraphDataLoader
from dgllife.model.gnn.mpnn import MPNNGNN
from dgllife.model.readout.mlp_readout import MLPNodeReadout

In [70]:
dataloader = GraphDataLoader(graph_dataset,batch_size=32)

In [71]:
class Model(nn.Module):
    def __init__(self, 
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=1291,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6):
        
        super(Model, self).__init__()
        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        
        self.readout = MLPNodeReadout(node_feats=node_in_feats, hidden_feats=edge_hidden_feats, graph_feats=node_out_feats)
        
        self.predict = nn.Sequential(
            nn.Linear(node_out_feats, node_out_feats),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(node_out_feats, n_tasks)
        )
        
    def forward(self, g, nodes, edges):
        node_feats = self.gnn(g, nodes, edges)
        graph_feats = self.readout(g, nodes)
        return self.predict(graph_feats)

In [72]:
model = Model(1291,3865)

In [73]:
def train(epochs):
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
    for epoch in tqdm(range(epochs)):
        model.train()
        running_loss = 0.
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            atoms = batch_x.ndata['energy']
            edges = batch_x.edata['length']
            print(batch_x)
            print(batch_y)
            print(edges)
            y_pred = model(batch_x, atoms, edges)
            mse = ((y_pred.reshape(-1) - batch_y)**2).sum()
            running_loss += mse.item()
            mse.backward()
            optimizer.step()
            
        running_loss /= len(dataloader)
        print("Train loss: ", running_loss)

In [74]:
train(5) # model is expecting double tensor but got float tensor

  0%|          | 0/5 [00:00<?, ?it/s]

Graph(num_nodes=1291, num_edges=3865,
      ndata_schemes={'energy': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={'length': Scheme(shape=(), dtype=torch.float64)})
tensor([ -767.8094,    -5.6269,  -429.0779,  -130.0663,  -215.3576,  -436.4863,
          -78.7024,  -113.8423,  -311.4249,  -183.0554,   -53.8655,  -385.6450,
         -442.7742,   -31.1861,   -79.6305,    -4.3007, -1057.8439,  -272.3579,
         -323.8642,  -573.0015,  -177.7945,  -225.1224,  -162.3694,  -131.2183,
         -798.4573,  -380.8149,  -178.7317,  -121.5088,  -186.4133,   -77.1495,
         -181.6423,  -115.6125], dtype=torch.float64)
tensor([1.4237, 1.4828, 1.6471,  ..., 1.5054, 1.5547, 1.5803],
       dtype=torch.float64)


RuntimeError: expected scalar type Double but found Float