In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
from rdkit import Chem
from rdkit.Chem import AllChem

In [2]:
class GCNMoleculeModel(nn.Module):
    def __init__(self, num_node_features):
        super(GCNMoleculeModel, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = torch.mean(x, dim=0)  # Global mean pooling
        x = self.fc(x)
        return x

In [3]:
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)
    
    node_features = []
    for atom in mol.GetAtoms():
        feature = [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            int(atom.GetIsAromatic())
        ]
        node_features.append(feature)

    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])  # 無向グラフなので両方向を追加

    x = torch.tensor(node_features, dtype=torch.float)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    return Data(x=x, edge_index=edge_index)

In [4]:
graph = smiles_to_graph("CCC")

In [5]:
graph.x

tensor([[6., 4., 0., 0., 0.],
        [6., 4., 0., 0., 0.],
        [6., 4., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.]])

In [6]:
graph.edge_index

tensor([[ 0,  1,  1,  2,  0,  3,  0,  4,  0,  5,  1,  6,  1,  7,  2,  8,  2,  9,
          2, 10],
        [ 1,  0,  2,  1,  3,  0,  4,  0,  5,  0,  6,  1,  7,  1,  8,  2,  9,  2,
         10,  2]])

In [7]:
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "CCO", "C1=CC=NC=C1"]  # アスピリン、エタノール、ピリジン
values = [0, 1, 2] 

dataset = [
    (smiles_to_graph(smiles), torch.tensor([label], dtype=torch.float))
    for smiles, label in zip(smiles_list, values)
]

In [8]:
dataset

[(Data(x=[21, 5], edge_index=[2, 42]), tensor([0.])),
 (Data(x=[9, 5], edge_index=[2, 16]), tensor([1.])),
 (Data(x=[11, 5], edge_index=[2, 22]), tensor([2.]))]

In [9]:
# データローダーの設定
loader = DataLoader(dataset, batch_size=1, shuffle=True)



In [10]:
# モデルの初期化
model = GCNMoleculeModel(num_node_features=5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss() 

In [11]:
# 訓練ループ
def train():
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item()}")

# 訓練の実行
for epoch in range(10):
    print(f"Epoch {epoch + 1}")
    train()

Epoch 1
Loss: 0.426400363445282
Loss: 3.763733386993408
Loss: 0.390403687953949
Epoch 2
Loss: 1.350990653038025
Loss: 1.750903606414795
Loss: 0.0012758751399815083
Epoch 3
Loss: 0.00592241482809186
Loss: 2.2299811840057373
Loss: 0.5152594447135925
Epoch 4
Loss: 1.701269268989563
Loss: 0.021333137527108192
Loss: 0.9832474589347839
Epoch 5
Loss: 1.0436395406723022
Loss: 0.05441497266292572
Loss: 1.1648766994476318
Epoch 6
Loss: 0.054052043706178665
Loss: 1.0604530572891235
Loss: 1.1433470249176025
Epoch 7
Loss: 1.162408471107483
Loss: 0.07574790716171265
Loss: 1.0541080236434937
Epoch 8
Loss: 1.0236179828643799
Loss: 0.05168420076370239
Loss: 1.1822112798690796
Epoch 9
Loss: 0.04413101449608803
Loss: 0.9190360903739929
Loss: 1.202519178390503
Epoch 10
Loss: 0.041047483682632446
Loss: 0.8995616436004639
Loss: 1.2020825147628784


  return F.mse_loss(input, target, reduction=self.reduction)


In [13]:
def predict(smiles):
    model.eval()
    data = smiles_to_graph(smiles)
    with torch.no_grad():
        output = model(data)
    return output.argmax().item()

# 予測の例
test_smiles = "CCN"  # エチルアミン
predicted_class = predict(test_smiles)
print(f"Predicted value {test_smiles}: {predicted_class}")

Predicted value CCN: 0
