In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
from rdkit import Chem
from rdkit.Chem import AllChem

In [15]:
class GCNMoleculeModel(nn.Module):
    def __init__(self, num_node_features):
        super(GCNMoleculeModel, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = torch.mean(x, dim=0)  # Global mean pooling
        x = self.fc(x)
        return x

In [16]:
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)

    node_features = []
    for atom in mol.GetAtoms():
        feature = [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            int(atom.GetIsAromatic()),
        ]
        node_features.append(feature)

    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])

    x = torch.tensor(node_features, dtype=torch.float)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index)

In [17]:
graph = smiles_to_graph("CCC")

In [18]:
graph.x

tensor([[6., 4., 0., 0., 0.],
        [6., 4., 0., 0., 0.],
        [6., 4., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.]])

In [19]:
graph.edge_index

tensor([[ 0,  1,  1,  2,  0,  3,  0,  4,  0,  5,  1,  6,  1,  7,  2,  8,  2,  9,
          2, 10],
        [ 1,  0,  2,  1,  3,  0,  4,  0,  5,  0,  6,  1,  7,  1,  8,  2,  9,  2,
         10,  2]])

In [20]:
smiles_list = [
    "CC(=O)OC1=CC=CC=C1C(=O)O",
    "CCO",
    "C1=CC=NC=C1",
]  # アスピリン、エタノール、ピリジン
values = [0, 1, 2]

dataset = [
    (smiles_to_graph(smiles), torch.tensor([label], dtype=torch.float))
    for smiles, label in zip(smiles_list, values)
]

In [21]:
dataset

[(Data(edge_index=[2, 42], x=[21, 5]), tensor([0.])),
 (Data(edge_index=[2, 16], x=[9, 5]), tensor([1.])),
 (Data(edge_index=[2, 22], x=[11, 5]), tensor([2.]))]

In [22]:
# データローダーの設定
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [23]:
# モデルの初期化
model = GCNMoleculeModel(num_node_features=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [24]:
# 訓練ループ
def train():
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item()}")


# 訓練の実行
for epoch in range(10):
    print(f"Epoch {epoch + 1}")
    train()

Epoch 1
Loss: 5.515108585357666
Loss: 0.07478471100330353
Loss: 0.31003931164741516
Epoch 2
Loss: 0.09409603476524353
Loss: 1.3412928581237793
Loss: 0.8520198464393616
Epoch 3
Loss: 0.679561972618103
Loss: 2.0366554260253906
Loss: 0.013095987029373646
Epoch 4
Loss: 1.8889793157577515
Loss: 0.00034540321212261915
Loss: 0.8559615015983582
Epoch 5
Loss: 0.018706589937210083
Loss: 1.1341453790664673
Loss: 1.1077682971954346
Epoch 6
Loss: 1.1366193294525146
Loss: 0.04141313582658768
Loss: 1.0854483842849731
Epoch 7
Loss: 1.007396936416626
Loss: 1.114715576171875
Loss: 0.03131241351366043
Epoch 8
Loss: 1.0377110242843628
Loss: 0.05386393889784813
Loss: 1.1990697383880615
Epoch 9
Loss: 0.8502244353294373
Loss: 0.08148965239524841
Loss: 1.2866061925888062
Epoch 10
Loss: 0.7985292673110962
Loss: 1.250282883644104
Loss: 0.0704004317522049


  return F.mse_loss(input, target, reduction=self.reduction)


In [27]:
def predict(smiles):
    model.eval()
    data = smiles_to_graph(smiles)
    with torch.no_grad():
        output = model(data)
    return output


# 予測の例
test_smiles = "CCN"  # エチルアミン
predicted_class = predict(test_smiles)
print(f"Predicted value {test_smiles}: {predicted_class}")

Predicted value CCN: tensor([0.7531])
