In [None]:
from typing import List, Literal

import deepchem as dc
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.utils.data import DataLoader
from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.nn import GCNConv, global_mean_pool

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with transformers dependency. No module named 'transformers'
cannot import name 'HuggingFaceModel' from 'deepchem.models.torch_models' (/home/mori/miniforge3/envs/torch/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


In [4]:
class MolGraphDataset(Dataset):
    def __init__(self, smiles_list, transform=None):
        super().__init__()
        self.smiles_list = smiles_list
        self.featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
        self.transform = transform

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smi = self.smiles_list[idx]
        mol = Chem.MolFromSmiles(smi)
        mol = AllChem.AddHs(mol)
        dgraph = self.featurizer.featurize([mol])[0]
        pyg_data = self._dc_to_pyg(dgraph)
        return pyg_data

    def _dc_to_pyg(self, dgraph):
        node_feats = torch.tensor(dgraph.node_features, dtype=torch.float)
        edge_index = torch.tensor(dgraph.edge_index, dtype=torch.long)
        edge_feats = torch.tensor(dgraph.edge_features, dtype=torch.float)

        data = Data(x=node_feats, edge_index=edge_index, edge_attr=edge_feats)
        return data

In [5]:
def random_drop_edge(pyg_data: Data, drop_prob: float = 0.1):
    edge_index = pyg_data.edge_index
    edge_attr = pyg_data.edge_attr

    num_edges = edge_index.size(1)
    mask = torch.rand(num_edges) > drop_prob
    edge_index_new = edge_index[:, mask]
    edge_attr_new = edge_attr[mask]

    aug_data = Data(
        x=pyg_data.x.clone(), edge_index=edge_index_new, edge_attr=edge_attr_new
    )
    return aug_data

In [6]:
def random_mask_node(pyg_data: Data, mask_prob: float = 0.1):
    data = Data(
        x=pyg_data.x.clone(),
        edge_index=pyg_data.edge_index.clone(),
        edge_attr=pyg_data.edge_attr.clone(),
    )
    node_mask = torch.rand(data.x.size(0)) < mask_prob
    data.x[node_mask] = 0.0
    return data

In [7]:
def augmentation(
    pyg_data: Data,
    mode: List[Literal["drop_edge", "mask_node"]],
    edge_drop_prob: float = 0.1,
    node_mask_prob: float = 0.1,
):
    data = pyg_data
    if mode in ["drop_edge"]:
        data = random_drop_edge(data, drop_prob=edge_drop_prob)
    if mode in ["mask_node"]:
        data = random_mask_node(data, mask_prob=node_mask_prob)
    return data

In [8]:
def get_collate_fn(
    augment_mode: List[Literal["drop_edge", "mask_node"]] = ["drop_edge", "mask_node"],
    edge_drop_prob=0.1,
    node_mask_prob=0.1,
):
    def collate_fn(batch):
        aug1_list = []
        aug2_list = []
        for pyg_data in batch:
            aug1 = augmentation(
                pyg_data,
                augment_mode,
                edge_drop_prob=edge_drop_prob,
                node_mask_prob=node_mask_prob,
            )
            aug2 = augmentation(
                pyg_data,
                augment_mode,
                edge_drop_prob=edge_drop_prob,
                node_mask_prob=node_mask_prob,
            )
            aug1_list.append(aug1)
            aug2_list.append(aug2)
        return aug1_list, aug2_list

    return collate_fn

In [9]:
class GNNEncoder(nn.Module):
    def __init__(
        self, num_node_features=20, num_edge_features=11, hidden_dim=64, out_dim=64
    ):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.proj_head = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, edge_attr, batch_index):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.conv2(h, edge_index)
        h = F.relu(h)
        h_pool = global_mean_pool(h, batch_index)
        z = self.proj_head(h_pool)
        return z

In [10]:
def info_nce_loss(z1, z2, temperature=0.1):
    z1_norm = F.normalize(z1, p=2, dim=1)
    z2_norm = F.normalize(z2, p=2, dim=1)
    sim_matrix = torch.matmul(z1_norm, z2_norm.t())
    sim_matrix = sim_matrix / temperature

    # InfoNCE ロスの計算
    logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
    sim_matrix_exp = torch.exp(sim_matrix - logits_max.detach())
    pos = torch.diag(sim_matrix_exp)
    denom = torch.sum(sim_matrix_exp, dim=1)
    loss = -torch.log(pos / denom)
    return loss.mean()

In [None]:
def train_contrastive(
    smiles_list,
    epochs=10,
    batch_size=32,
    lr=1e-3,
    augmentation_mode: List[Literal["drop_edge", "mask_node"]] = [
        "drop_edge",
        "mask_node",
    ],
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2) Dataset & Dataloader 準備
    dataset = MolGraphDataset(smiles_list)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=get_collate_fn(augment_mode=augmentation_mode),
    )

    # 3) GNNエンコーダ作成
    encoder = GNNEncoder(num_node_features=30, num_edge_features=11)
    encoder.to(device)  # GPU/CPUに転送

    # 4) Optimizer
    optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)

    # 5) 学習ループ
    for epoch in range(1, epochs + 1):
        encoder.train()
        total_loss = 0.0
        for i, (aug1_list, aug2_list) in enumerate(dataloader):
            batch_aug1 = Batch.from_data_list(aug1_list).to(device)
            batch_aug2 = Batch.from_data_list(aug2_list).to(device)
            # 順伝搬
            z1 = encoder(
                batch_aug1.x,
                batch_aug1.edge_index,
                batch_aug1.edge_attr,
                batch_aug1.batch,
            )
            z2 = encoder(
                batch_aug2.x,
                batch_aug2.edge_index,
                batch_aug2.edge_attr,
                batch_aug2.batch,
            )

            # InfoNCE loss 計算
            loss = info_nce_loss(z1, z2, temperature=0.1)

            # 逆伝搬 & パラメータ更新
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}] | Batch [{i}/{len(dataloader)}] | Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch}/{epochs}] | Loss: {avg_loss:.4f}")

    torch.save(encoder.state_dict(), "pretrained_gnn.pt")

In [1]:
!wget https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip

--2024-12-22 21:04:28--  https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip
Resolving deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)... 3.5.160.186, 3.5.160.162, 52.219.220.178, ...
Connecting to deepchemdata.s3-us-west-1.amazonaws.com (deepchemdata.s3-us-west-1.amazonaws.com)|3.5.160.186|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 130376753 (124M) [application/zip]
Saving to: ‘pubchem_10m.txt.zip’


2024-12-22 21:04:44 (8.07 MB/s) - ‘pubchem_10m.txt.zip’ saved [130376753/130376753]



In [15]:
import io
import zipfile

zip_file_path = "pubchem_10m.txt.zip"
lines_to_read = 1000
smiles_list = []
with zipfile.ZipFile(zip_file_path, "r") as zip_file:
    file_names = zip_file.namelist()
    for file_name in file_names:
        if file_name.endswith(".txt"):
            with zip_file.open(file_name) as file:
                with io.TextIOWrapper(file, encoding="utf-8") as text_file:
                    for i, line in enumerate(text_file):
                        if i >= lines_to_read:
                            break
                        smiles_list.append(line.strip())

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_encoder = train_contrastive(
    smiles_list=smiles_list,
    epochs=20,
    batch_size=32,
    lr=1e-3,
    augmentation_mode=["drop_edge", "mask_node"],
)

Using device: cpu
Epoch [1/10] | Loss: 0.2442
Epoch [2/10] | Loss: 0.0461
Epoch [3/10] | Loss: 0.0329
Epoch [4/10] | Loss: 0.0327
Epoch [5/10] | Loss: 0.0288
Epoch [6/10] | Loss: 0.0230
Epoch [7/10] | Loss: 0.0269
Epoch [8/10] | Loss: 0.0256
Epoch [9/10] | Loss: 0.0154
Epoch [10/10] | Loss: 0.0147


In [17]:
pretrained_encoder = GNNEncoder(num_node_features=30, num_edge_features=11)
pretrained_encoder.load_state_dict(torch.load("pretrained_gnn.pt"))

  pretrained_encoder.load_state_dict(torch.load("pretrained_gnn.pt"))


<All keys matched successfully>