In [1]:
import pandas as pd
from rdkit import Chem
import torch

from chemprop import data, nn, featurizers, models

In [2]:
# Load the PCQM4Mv2 table and pull out the two columns this notebook uses:
# the SMILES strings (model inputs) and the "homolumogap" column (the single
# regression target), each as a plain Python list.
df = pd.read_csv("../../data/pcqm4mv2/data.csv")
smis, ys = (df[column].tolist() for column in ("smiles", "homolumogap"))

In [3]:
%%time
mols = [Chem.MolFromSmiles(smi) for smi in smis]

[14:43:58] Conflicting single bond directions around double bond at index 13.
[14:43:58]   BondStereo set to STEREONONE and single bond directions set to NONE.
[14:46:41] Conflicting single bond directions around double bond at index 11.
[14:46:41]   BondStereo set to STEREONONE and single bond directions set to NONE.


CPU times: user 3min 33s, sys: 12 s, total: 3min 45s
Wall time: 3min 45s


In [4]:
%%time
datapoints = [data.MoleculeDatapoint(mol, [y]) for mol, y in zip(mols, ys)]


CPU times: user 6.89 s, sys: 540 ms, total: 7.43 s
Wall time: 7.43 s


In [5]:
%%time
# Wrap the list of datapoints in a chemprop MoleculeDataset; this dataset is
# what the dataloader is built from later in the notebook.
dataset = data.MoleculeDataset(datapoints)

CPU times: user 1.36 s, sys: 90 ms, total: 1.45 s
Wall time: 1.45 s


In [6]:
%%time
# Turn on the dataset's cache via a property assignment. Judging by the ~17 min
# wall time recorded for this cell, the setter does substantial work up front
# (presumably featurizing and storing every molecule so later iteration is
# cheap). NOTE(review): confirm against chemprop's MoleculeDataset.cache.
dataset.cache = True

CPU times: user 17min 29s, sys: 14.7 s, total: 17min 44s
Wall time: 17min 45s


In [7]:
%%time
dataloader = data.build_dataloader(dataset, seed=0, num_workers=0, batch_size=64)

CPU times: user 4.24 ms, sys: 0 ns, total: 4.24 ms
Wall time: 3.77 ms


In [8]:
%%timeit
for batch in dataloader:
    pass

50.2 s ± 174 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%time
mp = nn.BondMessagePassing()
agg = nn.NormAggregation()
ffn = nn.RegressionFFN()
model = models.MPNN(mp, agg, ffn)

CPU times: user 4.64 ms, sys: 0 ns, total: 4.64 ms
Wall time: 4.29 ms


In [10]:
# Use the GPU when torch can see one, otherwise fall back to the CPU, and
# move the model's parameters there. The returned module is displayed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (X_d_transform): Identity()
)

In [11]:
# Adam over every model parameter with a fixed learning rate of 1e-4.
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

In [12]:
%%time
model.train()
for batch in dataloader:
    opt.zero_grad()

    bmg, V_d, X_d, targets, weights, lt_mask, gt_mask = batch
    bmg.V = bmg.V.to(device)
    bmg.E = bmg.E.to(device)
    bmg.edge_index = bmg.edge_index.to(device)
    bmg.rev_edge_index = bmg.rev_edge_index.to(device)
    bmg.batch = bmg.batch.to(device)
    V_d = V_d.to(device) if V_d is not None else None
    X_d = X_d.to(device) if X_d is not None else None
    targets = targets.to(device)
    weights = weights.to(device)
    lt_mask = lt_mask.to(device) if lt_mask is not None else None
    gt_mask = gt_mask.to(device) if gt_mask is not None else None
    
    mask = targets.isfinite()
    targets = targets.nan_to_num(nan=0.0)

    Z = model.fingerprint(bmg, V_d, X_d)
    preds = model.predictor.train_step(Z)
    loss = model.criterion(preds, targets, mask, weights, lt_mask, gt_mask)

    loss.backward()
    opt.step()

CPU times: user 3min 41s, sys: 12.4 s, total: 3min 53s
Wall time: 3min 53s
