In [1]:
from s1t1_fingerprint import *

dataset_path = "/home/simo/dl/comp2021/samsung_s1t1/data/split_0/data_train.txt"

# Vocabularies that hand out the next free integer id the first time a key is
# looked up: the default factory returns the current size of the dict.
atom_dict = defaultdict(lambda: len(atom_dict))
bond_dict = defaultdict(lambda: len(bond_dict))
# Plain dict that count_fingerprints fills in place.
# NOTE(review): values look like occurrence counts per fingerprint -- confirm
# against s1t1_fingerprint.count_fingerprints.
fingerprint_dict = {}
edge_dict = defaultdict(lambda: len(edge_dict))


RAD = 1  # first argument to count_fingerprints; presumably the radius -- TODO confirm
# Bond type -> class index used for the one-hot edge attributes below.
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
data_list = []

# Read every record up front so the file handle is released before the
# featurisation loop starts.
with open(dataset_path, "r") as f:
    records = f.read().strip().split("\n")

for i, record in enumerate(records):
    # Whitespace-agnostic split: tolerates tabs, repeated spaces and "\r".
    fields = record.split()
    if not fields:
        # Blank line inside the file: nothing to featurise. `i` still counts
        # it, so `idx` keeps pointing at the record's position in the file.
        continue
    if len(fields) < 2:
        raise ValueError(
            f"record {i}: expected '<smiles> <target>', got {record!r}"
        )

    smiles = fields[0]
    target = torch.tensor([float(fields[1])])

    # MolFromSmiles returns None for SMILES it cannot parse; fail here with a
    # message that names the offending record instead of inside AddHs.
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        raise ValueError(f"record {i}: RDKit could not parse SMILES {smiles!r}")
    mol = Chem.AddHs(parsed)
    N = mol.GetNumAtoms()

    # Create node data
    atoms = create_atoms(mol, atom_dict)

    i_jbond_dict = create_ijbonddict(mol, bond_dict)
    fingerprints = count_fingerprints(
        RAD, atoms, i_jbond_dict, fingerprint_dict, edge_dict
    )

    x1 = torch.tensor(fingerprints).long()

    # Create Edge data (with edge index)
    # Every bond is stored twice, once per direction.
    row, col, edge_type = [], [], []

    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [bonds[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(edge_type, dtype=torch.long)
    edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

    # Order the edges by (source, destination): source * N + destination is
    # the row-major position of the edge in an N x N adjacency matrix.
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]
    edge_attr = edge_attr[perm]

    # Target goes from shape (1,) to (1, 1).
    y = target.unsqueeze(0)

    # Named `graph` so it no longer shares a name with the list of file records.
    graph = Data(edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i, x=x1)
    data_list.append(graph)

In [3]:
# Keep only the top-k fingerprints for the actual featurisation: sort the
# entries by their stored value (ascending) and take the last 200, i.e. the
# 200 largest values. Original note: fingerprints missing from this dict fall
# back to "the mol itself" -- NOTE(review): confirm what the fallback is.
fingerprint_dict_sorted = sorted(
    fingerprint_dict.items(), key=lambda item: item[1]
)
fingerprint_dict_topk = dict(fingerprint_dict_sorted[-200:])
fingerprint_dict_topk

{(6, ((7, 0),)): 116,
 (0, ((0, 1), (1, 0), (3, 0))): 116,
 (0, ((0, 0), (6, 0), (6, 0), (8, 0))): 120,
 (0, ((0, 1), (2, 0), (3, 0))): 129,
 (7, ((3, 0), (3, 0))): 133,
 (0, ((0, 0), (0, 0), (0, 0), (7, 0))): 135,
 (0, ((1, 0), (1, 1), (3, 0))): 137,
 (3, ((3, 2), (3, 2), (12, 0))): 139,
 (0, ((2, 1), (3, 0), (6, 0))): 141,
 (0, ((0, 3), (6, 0))): 142,
 (5, ((3, 2), (9, 2))): 144,
 (2, ((14, 1),)): 155,
 (0, ((0, 0), (0, 0), (6, 0), (12, 0))): 156,
 (3, ((3, 2), (5, 2), (9, 2))): 160,
 (3, ((3, 2), (3, 2), (14, 0))): 165,
 (1, ((0, 1), (3, 0))): 169,
 (0, ((0, 0), (0, 0), (1, 0), (3, 0))): 172,
 (0, ((0, 1), (1, 0), (6, 0))): 172,
 (0, ((0, 0), (0, 1), (2, 0))): 176,
 (1, ((1, 0), (6, 0), (6, 0))): 177,
 (3, ((1, 0), (3, 2), (9, 2))): 178,
 (0, ((0, 0), (3, 0), (3, 0), (6, 0))): 179,
 (0, ((0, 0), (0, 1), (7, 0))): 180,
 (0, ((1, 0), (1, 0), (7, 1))): 181,
 (0, ((1, 0), (3, 0), (3, 0), (6, 0))): 186,
 (1, ((0, 0), (0, 1), (15, 0))): 189,
 (2, ((0, 0), (1, 0))): 190,
 (1, ((0, 0), (0, 

In [8]:
# Snapshot the atom vocabulary into a plain dict (same keys, values and order).
# When atom_dict is a defaultdict with a lambda default factory, it cannot be
# pickled directly, so the copy is what gets written to disk.
atom_ndict = dict(atom_dict)

In [12]:
# Persist the atom vocabulary and the top-k fingerprint table with pickle.

import pickle

# (file name, object) pairs, written in this order into the working directory.
pickle_targets = (
    ('atom_dict.pickle', atom_ndict),
    ('fingerprint_dict_topk.pickle', fingerprint_dict_topk),
)

for pickle_path, payload in pickle_targets:
    with open(pickle_path, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)



In [14]:
# Round-trip check: read both pickle files back and display the result.
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # pickle.load can execute arbitrary code; only use on files this notebook wrote.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Note: this rebinds `atom_dict` and `fingerprint_dict` to the reloaded
# objects, so from here on `fingerprint_dict` holds only what was saved in
# fingerprint_dict_topk.pickle.
atom_dict = _read_pickle("atom_dict.pickle")
fingerprint_dict = _read_pickle("fingerprint_dict_topk.pickle")

atom_dict, fingerprint_dict



({'C': 0,
  'N': 1,
  'O': 2,
  ('C', 'aromatic'): 3,
  ('O', 'aromatic'): 4,
  ('N', 'aromatic'): 5,
  'H': 6,
  'S': 7,
  'Cl': 8,
  ('S', 'aromatic'): 9,
  'F': 10,
  'Br': 11,
  'Si': 12,
  'I': 13,
  'P': 14,
  'B': 15,
  ('P', 'aromatic'): 16},
 {(6, ((7, 0),)): 116,
  (0, ((0, 1), (1, 0), (3, 0))): 116,
  (0, ((0, 0), (6, 0), (6, 0), (8, 0))): 120,
  (0, ((0, 1), (2, 0), (3, 0))): 129,
  (7, ((3, 0), (3, 0))): 133,
  (0, ((0, 0), (0, 0), (0, 0), (7, 0))): 135,
  (0, ((1, 0), (1, 1), (3, 0))): 137,
  (3, ((3, 2), (3, 2), (12, 0))): 139,
  (0, ((2, 1), (3, 0), (6, 0))): 141,
  (0, ((0, 3), (6, 0))): 142,
  (5, ((3, 2), (9, 2))): 144,
  (2, ((14, 1),)): 155,
  (0, ((0, 0), (0, 0), (6, 0), (12, 0))): 156,
  (3, ((3, 2), (5, 2), (9, 2))): 160,
  (3, ((3, 2), (3, 2), (14, 0))): 165,
  (1, ((0, 1), (3, 0))): 169,
  (0, ((0, 0), (0, 0), (1, 0), (3, 0))): 172,
  (0, ((0, 1), (1, 0), (6, 0))): 172,
  (0, ((0, 0), (0, 1), (2, 0))): 176,
  (1, ((1, 0), (6, 0), (6, 0))): 177,
  (3, ((1, 0), 