This repository has been archived by the owner on Apr 9, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
85 lines (74 loc) · 2.81 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Note: utils file code is taken from: https://github.com/deepfindr/gnn-project/blob/main/dataset.py
"""
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score
def get_node_features(mol):
"""
This will return a matrix / 2d array of the shape
[Number of Nodes, Node Feature size]
"""
all_node_feats = []
for atom in mol.GetAtoms():
node_feats = []
# Feature 1: Atomic number
node_feats.append(atom.GetAtomicNum())
# Feature 2: Atom degree
node_feats.append(atom.GetDegree())
# Feature 3: Formal charge
node_feats.append(atom.GetFormalCharge())
# Feature 4: Hybridization
node_feats.append(atom.GetHybridization())
# Feature 5: Aromaticity
node_feats.append(atom.GetIsAromatic())
# Feature 6: Total Num Hs
node_feats.append(atom.GetTotalNumHs())
# Feature 7: Radical Electrons
node_feats.append(atom.GetNumRadicalElectrons())
# Feature 8: In Ring
node_feats.append(atom.IsInRing())
# Feature 9: Chirality
node_feats.append(atom.GetChiralTag())
# Append node features to matrix
all_node_feats.append(node_feats)
all_node_feats = np.asarray(all_node_feats)
return torch.tensor(all_node_feats, dtype=torch.float)
def get_edge_features(mol):
"""
This will return a matrix / 2d array of the shape
[Number of edges, Edge Feature size]
"""
all_edge_feats = []
for bond in mol.GetBonds():
edge_feats = []
# Feature 1: Bond type (as double)
edge_feats.append(bond.GetBondTypeAsDouble())
# Feature 2: Rings
edge_feats.append(bond.IsInRing())
# Append node features to matrix (twice, per direction)
all_edge_feats += [edge_feats, edge_feats]
all_edge_feats = np.asarray(all_edge_feats)
return torch.tensor(all_edge_feats, dtype=torch.float)
def get_adjacency_info(mol):
"""
We could also use rdmolops.GetAdjacencyMatrix(mol)
but we want to be sure that the order of the indices
matches the order of the edge features
"""
edge_indices = []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
edge_indices += [[i, j], [j, i]]
edge_indices = torch.tensor(edge_indices)
edge_indices = edge_indices.t().to(torch.long).view(2, -1)
return edge_indices
def calculate_metrics(y_pred, y_true):
print(f"\n Confusion matrix: \n {confusion_matrix(y_pred, y_true)}")
print(f"F1 Score: {f1_score(y_true, y_pred)}")
print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
prec = precision_score(y_true, y_pred)
rec = recall_score(y_true, y_pred)
print(f"Precision: {prec}")
print(f"Recall: {rec}")