In [36]:
import torch
import torch_geometric
from torch_geometric import edge_index
from torch_geometric.transforms import BaseTransform, Compose
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.nn.aggr import SumAggregation

import matplotlib.pyplot as plt
import numpy as np

import os
from pathlib import Path

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Newly created tensors are placed on the selected device by default.
torch.set_default_device(device)
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
HERE = Path("/Users/USER/PycharmProjects/molecule")
# Directory handed to the dataset loader as its root.
DATA = HERE.joinpath("data")

  _C._set_default_tensor_type(t)


TypeError: invalid type object: only floating-point types are supported as the default type

In [5]:
def num_heavy_atoms(qm9_data: "Data") -> int:
    """
    Count the number of heavy atoms in a torch geometric
    Data object.

    Args:
        qm9_data: Object whose ``z`` attribute is a tensor of the atomic
            numbers of the molecule's atoms.

    Returns:
        The number of entries of ``z`` that differ from 1, as a plain
        Python ``int``.
    """
    # every atom with atomic number other than 1 is heavy; ``.item()``
    # unwraps the 0-d tensor so the declared ``int`` return type holds
    return int((qm9_data.z != 1).sum().item())

# Build the QM9 dataset under DATA, keeping only the small molecules.
qm9_dataset = QM9(
    DATA,
    # Keep a molecule only when it has fewer than 9 heavy atoms
    pre_filter=lambda data: num_heavy_atoms(data) < 9,
    # Always redo the processing step instead of reusing a processed copy
    force_reload=True,
)

# Track the largest atom count (length of z) seen across the dataset.
mx = 0
for d in qm9_dataset:
    atom_count = d.z.shape[0]
    if atom_count > mx:
        mx = atom_count

# Peek at the first molecule; only the final expression is displayed by the cell.
data_sample = qm9_dataset[0]
data_sample.pos.round(decimals=3)
data_sample.y
len(qm9_dataset)

Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


21800

In [6]:
import periodictable as pt

def adjacency_matrix(z: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
    """
    Build a dense bond-value adjacency matrix for one molecule.

    Args:
        z: Per-atom tensor; only its first dimension (the atom count) is used.
        edge_index: ``(2, num_edges)`` integer tensor of (source, target)
            atom indices.
        edge_attr: ``(num_edges, >=3)`` tensor. Its first three columns are
            weighted by 1, 2 and 3 and summed into the bond value
            (presumably a one-hot single/double/triple encoding — confirm
            against the dataset). Any further columns are ignored.

    Returns:
        ``(num_atoms, num_atoms)`` tensor with 1 on the diagonal and the bond
        value at ``[source, target]`` for every edge.

    Raises:
        IndexError: If there are edges but ``edge_attr`` has fewer than
            three columns.
    """
    # Self-connection for every atom. The diagonal has to be indexed by atom:
    # indexing it by edge position runs out of bounds as soon as a molecule
    # has more (directed) edges than atoms.
    am = torch.eye(z.shape[0])
    if edge_index.shape[1] == 0:
        return am
    if edge_attr.dim() < 2 or edge_attr.shape[1] < 3:
        raise IndexError("edge_attr needs at least 3 columns to derive a bond value")

    # bond value = 1 * col0 + 2 * col1 + 3 * col2, computed for all edges at once
    orders = edge_attr.new_tensor([1, 2, 3])
    bond = (edge_attr[:, :3] * orders).sum(dim=1)
    # ``.to(am)`` matches the dtype and device of the output matrix
    am[edge_index[0], edge_index[1]] = bond.to(am)
    return am
    
def distance_matrix(z: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """
    Compute the pairwise Euclidean distances between the atoms of a molecule.

    Args:
        z: Per-atom tensor; only its first dimension (the atom count) is used.
        pos: Tensor whose first ``z.shape[0]`` rows are the atom coordinates.

    Returns:
        ``(num_atoms, num_atoms)`` symmetric tensor of distances with a zero
        diagonal, created with ``torch.zeros`` (default dtype and device).

    Raises:
        IndexError: If ``pos`` has fewer rows than ``z`` has atoms.
    """
    n = z.shape[0]
    dm = torch.zeros(n, n)
    if n == 0:
        return dm
    if pos.shape[0] < n:
        raise IndexError("pos has fewer rows than z has atoms")

    # One broadcasted subtraction replaces the double Python loop and its
    # per-element ``.item()`` calls; ``detach`` keeps the result a plain value
    # tensor, as the ``.item()``-based version produced.
    coords = pos[:n].detach().reshape(n, -1)
    diff = coords.unsqueeze(1) - coords.unsqueeze(0)
    # copy_ converts to dm's dtype/device, so the output type is unchanged
    dm.copy_(diff.pow(2).sum(dim=-1).sqrt())
    return dm

def display_molecule(pos: torch.Tensor, z: torch.Tensor):
    """
    Draw a 3D scatter plot of a molecule with each atom labelled by its
    element symbol.

    Args:
        pos: ``(num_atoms, 3)`` tensor of atom coordinates (columns are used
            as the x, y and z plot axes).
        z: Tensor of atomic numbers, one per atom; used to look up the
            element symbol in ``periodictable``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # Tensor.numpy() only accepts CPU tensors that do not track gradients,
    # so move/detach first — tensors may live on the GPU when CUDA is the
    # default device.
    coord = pos.detach().cpu().numpy()
    ax.scatter(coord[:, 0], coord[:, 1], coord[:, 2])
    # A single host transfer of all atomic numbers instead of one per atom.
    atomic_numbers = z.detach().cpu().reshape(-1).tolist()
    for i, atomic_number in enumerate(atomic_numbers):
        ax.text(coord[i, 0], coord[i, 1], coord[i, 2], "{}".format(pt.elements[atomic_number].symbol))

# display_molecule(data_sample.pos, data_sample.z)

In [68]:
import torch.nn as nn

# Load the cached pair dataset onto the device selected at the top of the
# notebook; forcing `cuda(0)` here would crash on CPU-only machines.
# NOTE: torch.load unpickles the file, so only load files you created yourself.
data_list = torch.load('data_pair.pt', map_location=device)

# data_list = []
# cnt = 0
# 
# for dt in qm9_dataset:
#     pos = dt.pos
#     dist = torch.tensor([(pos[i]-pos[j]).pow(2).sum().sqrt().item() for i in range(dt.z.shape[0]) for j in range(i)])
#     pairs = torch.tensor([(i, j) for i in range(dt.z.shape[0]) for j in range(i)])
#     
#     # node level
#     pad1 = nn.ZeroPad1d((0, mx*(mx-1)//2-dist.shape[0]))
#     pad2 = nn.ZeroPad2d((0, 0, 0, mx*(mx-1)//2-dist.shape[0]))
#     data_list.append(Data(x=dt.z.reshape([-1, 1]).float(), edge_index=dt.edge_index, y=pad1(dist), pairs=pad2(pairs)))
#     
#     cnt += 1
#     if cnt % 1000 == 0:
#         print(cnt)
#     # graph level
#     # dist_l = pad(dist_m).flatten()
#     # data_list.append(Data(x=dt.z.reshape([-1, 1]).float(), edge_index=dt.edge_index, y=dist_l))

# Largest number of nodes (rows of x) in any graph of the loaded dataset.
mx = 0
for d in data_list:
    mx = max(mx, d.x.shape[0])

# Contiguous 80% / 10% / 10% split. Slice ends are exclusive, so every split
# starts exactly where the previous one stopped; adding 1 to the start would
# silently drop one sample at each boundary.
l = len(data_list)
train_end = l*4//5
val_end = l*9//10
train_l = data_list[:train_end]
val_l = data_list[train_end:val_end]
test_l = data_list[val_end:]

batch = 1024
# The shuffling generator must live on the same device as the default device
# configured above, so reuse `device` instead of hard-coding 'cuda'.
train_loader = DataLoader(train_l, batch_size=batch, shuffle=True, generator=torch.Generator(device=device))
# Evaluation does not benefit from shuffling; a fixed order keeps the
# reported validation/test losses comparable between runs.
val_loader = DataLoader(val_l, batch_size=batch, shuffle=False)
test_loader = DataLoader(test_l, batch_size=batch, shuffle=False)

In [55]:
# Persist the pair dataset to "data_pair.pt", the same file name the loading
# cell reads back with torch.load.
torch.save(data_list, "data_pair.pt")

In [69]:
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_geometric.nn import GraphConv, GINConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
import torch.nn.functional as F

class GINPair(torch.nn.Module):
    """
    GIN encoder that embeds the nodes of a graph and scores pairs of nodes.

    ``forward`` produces one embedding per node; ``predict`` merges the
    embeddings of two nodes with the configured ``combine_method`` and maps
    the result to a single value through a linear layer.

    Args:
        dim_in: Number of input features per node.
        dim_hidden: Width of every hidden layer and of the node embeddings.
        combine_method: How two node embeddings are merged; one of
            ``COMBINE_METHODS``.

    Raises:
        ValueError: If ``combine_method`` is not one of ``COMBINE_METHODS``.
    """

    # Supported ways of merging two node embeddings in ``combine``.
    COMBINE_METHODS = ('concat', 'sum', 'product', 'diff', 'average', 'max')

    def __init__(self, dim_in: int, dim_hidden: int, combine_method='concat'):
        super().__init__()
        # Fail at construction time: an unknown method would otherwise make
        # ``combine`` return None and surface as an obscure error inside ``fc``.
        if combine_method not in self.COMBINE_METHODS:
            raise ValueError(
                f"Unknown combine_method {combine_method!r}; expected one of {self.COMBINE_METHODS}"
            )
        self.conv1 = GINConv(self._mlp(dim_in, dim_hidden))
        # conv2 is constructed (and therefore part of the parameters) even
        # though ``forward`` does not apply it.
        self.conv2 = GINConv(self._mlp(dim_hidden, dim_hidden))
        self.conv3 = GINConv(self._mlp(dim_hidden, dim_hidden))
        self.combine_method = combine_method
        # Concatenation doubles the feature width; every other method keeps it.
        if combine_method == 'concat':
            self.fc = nn.Linear(2 * dim_hidden, 1)
        else:
            self.fc = nn.Linear(dim_hidden, 1)

    @staticmethod
    def _mlp(dim_from: int, dim_to: int) -> Sequential:
        """Build the Linear-BatchNorm-ReLU-Linear-ReLU network wrapped by each GINConv."""
        return Sequential(
            Linear(dim_from, dim_to), BatchNorm1d(dim_to), ReLU(), Linear(dim_to, dim_to), ReLU()
        )

    def forward(self, x: torch.Tensor, edge_idx: torch.Tensor) -> torch.Tensor:
        """
        Compute node embeddings.

        Args:
            x: Node feature matrix with ``dim_in`` columns.
            edge_idx: Edge index tensor passed to the GIN convolutions.

        Returns:
            One ``dim_hidden``-wide embedding per node, after conv1 and conv3
            (conv2 is not applied).
        """
        output = self.conv1(x, edge_idx)
        output = output.relu()
        output = self.conv3(output, edge_idx)
        output = output.relu()
        return output

    def combine(self, node1_emb, node2_emb):
        """
        Merge two batches of node embeddings according to ``combine_method``.

        Args:
            node1_emb: Embeddings of the first node of every pair.
            node2_emb: Embeddings of the second node of every pair.

        Returns:
            The merged embeddings: twice as wide for ``'concat'``, otherwise
            the same shape as the inputs.

        Raises:
            ValueError: If ``combine_method`` is not a supported method.
        """
        if self.combine_method == 'concat':
            return torch.cat([node1_emb, node2_emb], dim=1)
        elif self.combine_method == 'sum':
            return node1_emb + node2_emb
        elif self.combine_method == 'product':
            return node1_emb * node2_emb
        elif self.combine_method == 'diff':
            return torch.abs(node1_emb - node2_emb)
        elif self.combine_method == 'average':
            return (node1_emb + node2_emb) / 2
        elif self.combine_method == 'max':
            combined_emb, _ = torch.max(torch.stack([node1_emb, node2_emb]), dim=0)
            return combined_emb
        # Reached only if combine_method was changed after construction.
        raise ValueError(
            f"Unknown combine_method {self.combine_method!r}; expected one of {self.COMBINE_METHODS}"
        )

    def predict(self, node1_emb, node2_emb):
        """
        Predict one value per node pair from the two nodes' embeddings.

        Returns:
            Tensor with one row per pair and a single column.
        """
        combined_emb = self.combine(node1_emb, node2_emb)
        return self.fc(combined_emb)

In [72]:
def mask(data_y):
    """Return a float32 tensor holding 1.0 where ``data_y`` is non-zero and 0.0 elsewhere."""
    is_nonzero = data_y.ne(0)
    return is_nonzero.to(torch.float32)

def train(model, lr):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                      lr=lr,
                                      weight_decay=5e-4)
    
    model.train()
    loss = 0
    cnt = 0
    for d in train_loader:
        optimizer.zero_grad()
        node_embed = model(d.x, d.edge_index)
        # node1 = torch.tensor([i for i in range(d.x.shape[0]) for _ in range(i)])
        # node2 = torch.tensor([j for i in range(d.x.shape[0]) for j in range(i)])
        node1 = torch.tensor([pair[0] for pair in d.pairs])
        node2 = torch.tensor([pair[1] for pair in d.pairs])
        node1_emb = node_embed[node1]
        node2_emb = node_embed[node2]
        
        output = model.predict(node1_emb, node2_emb).squeeze()
        l = criterion(output, d.y).mean()
        loss += l / len(train_loader)
        l.backward()
        optimizer.step()
        
        cnt += 1
        print(f"{cnt}/{len(train_loader)}")
    return loss, model

def validate(model):
    criterion = torch.nn.MSELoss()
    
    model.eval()
    v_loss = 0
    for d in val_loader:
        node_embed = model(d.x, d.edge_index)
        node1 = torch.tensor([pair[0] for pair in d.pairs])
        node2 = torch.tensor([pair[1] for pair in d.pairs])
        node1_emb = node_embed[node1]
        node2_emb = node_embed[node2]
        output = model.predict(node1_emb, node2_emb).squeeze()
        l = criterion(output, d.y).mean()
        v_loss += l / len(val_loader)
    return v_loss
        
@torch.no_grad()
def test(model):
    criterion = torch.nn.MSELoss()
    t_loss = 0
    for d in test_loader:
        output = model(d.x, d.edge_index)
        l = (criterion(output, d.y) * mask(d.y)).mean()
        t_loss += l / len(test_loader)
    return t_loss

def train_epoch(epoch, model, lr):
    for e in range(epoch):
        e_loss, model = train(model, lr)
        v_loss = validate(model)
        
        if e % 2 == 0:
            print(f"Epoch: {e}, train_loss: {e_loss.item()}, val_loss: {v_loss.item()}")

In [74]:
# Number of passes over the training set.
epochs = 20
# Fresh pair model: one input feature per node, 64-wide embeddings,
# pair embeddings merged by concatenation.
model = GINPair(dim_in=1, dim_hidden=64, combine_method='concat')
train_epoch(epoch=epochs, model=model, lr=0.001)

1/18
2/18
3/18
4/18
5/18
6/18
7/18
8/18
9/18
10/18
11/18
12/18
13/18
14/18
15/18
16/18
17/18
18/18
Epoch: 0, train_loss: 2.758960723876953, val_loss: 3.3880295753479004
1/18


KeyboardInterrupt: 

In [None]:
# TODO: exclude the zero-padded pair entries from the loss calculation
"""
Epoch: 0, Train loss: 0.03952614217996597, Val loss: 0.04244199022650719
Epoch: 2, Train loss: 0.03944431617856026, Val loss: 0.04346185922622681
Epoch: 4, Train loss: 0.0394737534224987, Val loss: 0.04332466423511505
Epoch: 6, Train loss: 0.03944838419556618, Val loss: 0.04295289143919945
Epoch: 8, Train loss: 0.03944433107972145, Val loss: 0.042868148535490036
Time: 10m 21s 67ms
"""

"""
Epoch: 0, train_loss: 2.819444417953491, val_loss: 1.3333452939987183
Epoch: 2, train_loss: 1.0538444519042969, val_loss: 0.8541600108146667
Epoch: 4, train_loss: 1.0235944986343384, val_loss: 0.8158113956451416
Epoch: 6, train_loss: 1.0137405395507812, val_loss: 0.8021357655525208
Epoch: 8, train_loss: 1.0096156597137451, val_loss: 0.7891163229942322
Epoch: 10, train_loss: 1.0031167268753052, val_loss: 0.7599336504936218
Epoch: 12, train_loss: 1.0010273456573486, val_loss: 0.8162444233894348
Epoch: 14, train_loss: 0.9971020221710205, val_loss: 0.7471166849136353
Epoch: 16, train_loss: 0.9963588714599609, val_loss: 0.7429596185684204
Epoch: 18, train_loss: 0.9939806461334229, val_loss: 0.7843343615531921
"""
"""
Epoch: 0, train_loss: 2.607553482055664, val_loss: 1.5287598371505737
Epoch: 2, train_loss: 0.9871993064880371, val_loss: 0.8298434019088745
Epoch: 4, train_loss: 0.9566546678543091, val_loss: 0.8694849610328674
Epoch: 6, train_loss: 0.9433390498161316, val_loss: 0.7554209232330322
Epoch: 8, train_loss: 0.9362509250640869, val_loss: 0.8846598267555237
Epoch: 10, train_loss: 0.9299172163009644, val_loss: 0.8427674770355225
Epoch: 12, train_loss: 0.926454484462738, val_loss: 0.7863412499427795
Epoch: 14, train_loss: 0.9211176633834839, val_loss: 1.0939804315567017
Epoch: 16, train_loss: 0.917726457118988, val_loss: 0.8830863237380981
Epoch: 18, train_loss: 0.9147141575813293, val_loss: 0.7525851130485535
Epoch: 20, train_loss: 0.9135162234306335, val_loss: 0.7463304400444031
Epoch: 22, train_loss: 0.9099758267402649, val_loss: 0.760979950428009
Epoch: 24, train_loss: 0.907063901424408, val_loss: 0.7801485061645508
Epoch: 26, train_loss: 0.905415415763855, val_loss: 0.7539505362510681
Epoch: 28, train_loss: 0.905021607875824, val_loss: 0.8408609628677368
Epoch: 30, train_loss: 0.9012361764907837, val_loss: 0.9827206134796143
Epoch: 32, train_loss: 0.8998003005981445, val_loss: 0.9063634872436523
Epoch: 34, train_loss: 0.8969922661781311, val_loss: 0.783204197883606
Epoch: 36, train_loss: 0.8964552283287048, val_loss: 0.7484853267669678
Epoch: 38, train_loss: 0.8951149582862854, val_loss: 0.7687615156173706
Epoch: 40, train_loss: 0.8946443200111389, val_loss: 0.8196905851364136
Epoch: 42, train_loss: 0.892460286617279, val_loss: 0.7178381681442261
Epoch: 44, train_loss: 0.8909822106361389, val_loss: 1.2139568328857422
Epoch: 46, train_loss: 0.8888286352157593, val_loss: 0.7084630727767944
Epoch: 48, train_loss: 0.8880488276481628, val_loss: 0.7248218059539795
"""
# Epoch: 46, train_loss: 0.8888286352157593, val_loss: 0.7084630727767944