In [2]:
%load_ext autoreload

In [3]:
%autoreload
from rdkit import Chem
from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.datasets import MoleculeDataset
from chemprop.data import build_dataloader
from chemprop.featurizers.molgraph.molecule import SimpleMoleculeMolGraphFeaturizer
from chemprop.data.collate import BatchMolGraph
import numpy as np
import torch

In [1]:
smi = ["CCCCCCCO", "CCCCN", "CCCCCCOCCC"]
ys = [1.4, 2.5, 1.6]

In [6]:
mol_datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smi, ys)]

In [7]:
featurizer = SimpleMoleculeMolGraphFeaturizer()
mol_dataset = MoleculeDataset(mol_datapoints, featurizer=featurizer)

In [8]:
mol_dataset

MoleculeDataset(data=[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7ddac17b2e30>, y=1.4, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CCCCCCCO', V_f=None, E_f=None, V_d=None), MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7ddac17b3060>, y=2.5, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CCCCN', V_f=None, E_f=None, V_d=None), MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7ddac17b2650>, y=1.6, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CCCCCCOCCC', V_f=None, E_f=None, V_d=None)], featurizer=SimpleMoleculeMolGraphFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7dd967c7c350>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7dd967c0aad0>))

In [9]:
bmg = BatchMolGraph([mol.mg for mol in mol_dataset])

In [11]:
print(sum([len(x) for x in smi]))
print(bmg.V.shape)
print(bmg.V_w.shape)
print(bmg.E.shape)
print(bmg.E_w.shape)

23
torch.Size([23, 72])
torch.Size([23])
torch.Size([40, 14])
torch.Size([40])


index_torch produces a tensor of shape n_atoms x descriptors where each descriptor is the batch number relating to the molecule

dim_size = the number of batches in the batch mol graph

In [31]:
index_torch = bmg.batch.unsqueeze(1).repeat(1, bmg.V.shape[1])
dim_size = bmg.batch.max().int() + 1
print(index_torch)
print(index_torch.shape)
print(dim_size)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2]])
torch.Size([23, 72])
tensor(3, dtype=torch.int32)


The following expression (Aggregation function) combines all of the atom features for each Mol into a single n_descriptor represenatation (here 72)

In [46]:
H = torch.zeros(dim_size, bmg.V.shape[1], dtype=bmg.V.dtype, device=bmg.V.device).scatter_reduce_(
    0, index_torch, bmg.V, reduce="mean", include_self=False
)

In [47]:
H.shape

torch.Size([3, 72])

In [48]:
H.shape[0]

3

In [35]:
bmg.degree_of_poly.shape

torch.Size([3])

In [50]:
degree_of_polys = bmg.degree_of_poly[: H.shape[0]]

In [52]:
degree_of_polys

tensor([1., 1., 1.])

In [51]:
degree_of_polys.shape

torch.Size([3])

In [53]:
torch.mul(degree_of_polys.unsqueeze(1), H)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [24]:
t1 = torch.ones((477, 300))
t2 = torch.ones((477))

In [27]:
t3 = torch.mul(t2.unsqueeze(1), t1)

In [None]:
M = torch.cat([bmg.V[bmg.edge_index[0]], bmg.E], dim=1)
print(M)
print(M.shape)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([40, 86])


In [13]:
torch.set_printoptions(threshold=10)
# For a small molecule the weighted tensor should be equal to the initial Tensor
Wi = torch.cat([bmg.V[bmg.edge_index[0]], bmg.E], dim=1)
w_M = torch.mul(bmg.E_w[bmg.edge_index[0]].unsqueeze(1), Wi)
equal = torch.equal(Wi, w_M)
print(f"Tensors are equal?: {equal}")
print(w_M.shape)

Tensors are equal?: True
torch.Size([40, 86])


In [14]:
mol_to_feat = Chem.MolFromSmiles(smi[0])
featurizer = SimpleMoleculeMolGraphFeaturizer()
feat_mol = featurizer(mol_to_feat)

In [15]:
print(feat_mol.V.shape)
print(feat_mol.V_w.shape)

(8, 72)
(8,)


In [21]:
from chemprop.data.dataloader import build_dataloader


dataloader = build_dataloader(mol_dataset)

In [23]:
batch = next(iter(dataloader))
bmg, V_d, X_d, *_ = batch

In [24]:
basic_model(bmg, V_d, X_d)

tensor([[0.0224],
        [0.0226],
        [0.0225]], grad_fn=<AddmmBackward0>)

# Polymers

In [99]:
%autoreload
import pandas as pd
from chemprop.data.datapoints import PolymerDatapoint
from chemprop.data.datasets import PolymerDataset
from chemprop.featurizers.molgraph import PolymerMolGraphFeaturizer

In [100]:
smi = "[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-3:0.5:0.5<1-4:0.5:0.5<2-3:0.5:0.5<2-4:0.5:0.5~10"

In [101]:
datapoint = PolymerDatapoint.from_smi(smi, y=1.5)

In [102]:
datapoint

PolymerDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7719e314ea40>, fragment_weights=[0.5, 0.5], edges=['1-3:0.5:0.5', '1-4:0.5:0.5', '2-3:0.5:0.5', '2-4:0.5:0.5~10'], y=1.5, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-3:0.5:0.5<1-4:0.5:0.5<2-3:0.5:0.5<2-4:0.5:0.5~10', V_f=None, E_f=None, V_d=None)

In [103]:
featurizer = PolymerMolGraphFeaturizer()

In [104]:
dp = featurizer(datapoint)

In [105]:
dp.V.shape

(17, 72)

In [106]:
dp.E.shape

(42, 14)

In [107]:
dp.edge_index.shape

(2, 42)

In [108]:
dp.V[dp.edge_index[0]].shape

(42, 72)

In [109]:
V = torch.from_numpy(dp.V).float()
E = torch.from_numpy(dp.E).float()
edge_index = torch.from_numpy(dp.edge_index).long()

In [110]:
torch.cat((V[edge_index[0]], E), dim=1)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.1201,
         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.

In [111]:
polymer_df = pd.read_csv("../tests/data/test_polymer.csv")
poly_strings = polymer_df["smiles"]
targets = polymer_df["EA vs SHE (eV)"]

In [112]:
data = [
    PolymerDatapoint.from_smi(polymer, y=target) for polymer, target in zip(poly_strings, targets)
]

In [113]:
data[:5]

[PolymerDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7719e314e260>, fragment_weights=[0.5, 0.5], edges=['1-3:0.5:0.5', '1-4:0.5:0.5', '2-3:0.5:0.5', '2-4:0.5:0.5'], y=-3.40621031325285, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-3:0.5:0.5<1-4:0.5:0.5<2-3:0.5:0.5<2-4:0.5:0.5', V_f=None, E_f=None, V_d=None),
 PolymerDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7719e314d930>, fragment_weights=[0.5, 0.5], edges=['1-2:0.375:0.375', '1-1:0.375:0.375', '2-2:0.375:0.375', '3-4:0.375:0.375', '3-3:0.375:0.375', '4-4:0.125:0.125', '1-3:0.125:0.125', '1-4:0.125:0.125', '2-3:0.125:0.125', '2-4:0.125:0.125'], y=-2.99190911810533, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-2:0.375:0.375<1-1:0.375:0.375<2-2:0.375:0.375<3-4:0.375:0.375<3-3:0.375:0.375<4-4:0.125:0.125<1-3:0.125:0.125<1-4:0.125:0.125<2-3:0.125:0.125<2

In [114]:
print(PolymerDataset(data))

PolymerDataset(data=[PolymerDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7719e314e260>, fragment_weights=[0.5, 0.5], edges=['1-3:0.5:0.5', '1-4:0.5:0.5', '2-3:0.5:0.5', '2-4:0.5:0.5'], y=-3.40621031325285, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-3:0.5:0.5<1-4:0.5:0.5<2-3:0.5:0.5<2-4:0.5:0.5', V_f=None, E_f=None, V_d=None), PolymerDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7719e314d930>, fragment_weights=[0.5, 0.5], edges=['1-2:0.375:0.375', '1-1:0.375:0.375', '2-2:0.375:0.375', '3-4:0.375:0.375', '3-3:0.375:0.375', '4-4:0.125:0.125', '1-3:0.125:0.125', '1-4:0.125:0.125', '2-3:0.125:0.125', '2-4:0.125:0.125'], y=-2.99190911810533, weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-2:0.375:0.375<1-1:0.375:0.375<2-2:0.375:0.375<3-4:0.375:0.375<3-3:0.375:0.375<4-4:0.125:0.125<1-3:0.125:0.125<1-4:0.125:0.12

In [115]:
featurizer = PolymerMolGraphFeaturizer()

In [116]:
features = featurizer(data[0])

In [117]:
df = pd.read_csv("/home/shortm/Documents/ML/dataset-poly_chemprop.csv")
smiles_column = "poly_chemprop_input"
target_columns = ["EA vs SHE (eV)"]
smis = df.loc[:, smiles_column].values
ys = df.loc[:, target_columns].values

In [118]:
all_data = [PolymerDatapoint.from_smi(smi, y=y) for smi, y in zip(smis, ys)]

In [119]:
torch.set_printoptions(threshold=10_000)
featurizer = PolymerMolGraphFeaturizer()
for i, poly in enumerate(all_data):
    print(poly.name)
    print(f"Number of rules: {len(poly.edges)}")
    dp = featurizer(poly)
    V = torch.from_numpy(dp.V).float()
    E = torch.from_numpy(dp.E).float()
    edge_index = torch.from_numpy(dp.edge_index).long()
    print(
        f"Shapes; V:{V.shape}, E:{E.shape}, edge_index: {edge_index.shape}, V[edge_index[0]]: {V[edge_index[0]].shape}]"
    )
    print(f"{edge_index[0]}")
    # print(V[edge_index[0]])
    try:
        cat = torch.cat((V[edge_index[0]], E), dim=1)
    except RuntimeError:
        # print(V)
        # print(E)
        # for i in range(len(edge_index[0])):
        # print(edge_index[0][i], edge_index[1][i])
        break
    print(cat.shape)

[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-3:0.5:0.5<1-4:0.5:0.5<2-3:0.5:0.5<2-4:0.5:0.5
Number of rules: 4
Shapes; V:torch.Size([17, 72]), E:torch.Size([42, 14]), edge_index: torch.Size([2, 42]), V[edge_index[0]]: torch.Size([42, 72])]
tensor([ 0,  1,  1,  2,  2,  3,  2,  4,  4,  5,  5,  6,  6,  7,  6,  0,  8,  9,
         9, 10,  9, 11, 11, 12, 12, 13, 12, 14, 14, 15, 15, 16, 15,  8,  0,  8,
         0, 14,  4,  8,  4, 14])
torch.Size([42, 86])
[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-2:0.375:0.375<1-1:0.375:0.375<2-2:0.375:0.375<3-4:0.375:0.375<3-3:0.375:0.375<4-4:0.125:0.125<1-3:0.125:0.125<1-4:0.125:0.125<2-3:0.125:0.125<2-4:0.125:0.125
Number of rules: 10
Shapes; V:torch.Size([17, 72]), E:torch.Size([54, 14]), edge_index: torch.Size([2, 54]), V[edge_index[0]]: torch.Size([54, 72])]
tensor([ 0,  1,  1,  2,  2,  3,  2,  4,  4,  5,  5,  6,  6,  7,  6,  0,  8,  9,
         9, 10,  9, 11, 11, 12, 12, 13, 12, 14, 14, 15, 15, 16, 15,  8,  0,  

KeyboardInterrupt: 