# Boost up AI 2025: 신약 개발 경진대회 (GNN method)
- 참가팀: dalcw
- 참가자: 문성수
- 소속: 전남대학교(휴학)

## Install Library
- torch
- torch_geometric
- pandas
- numpy
- tqdm
- rdkit

In [1]:
# library import
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski, rdMolDescriptors, MACCSkeys, AllChem, Crippen, QED
from rdkit import RDLogger
from rdkit.ML.Descriptors import MoleculeDescriptors

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Sequential, ReLU, LayerNorm, Dropout
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, global_mean_pool, GINConv, AttentionalAggregation
from torch_geometric.loader import DataLoader

import numpy as np
import pandas as pd
import tqdm

# 모든 경고 끄기
RDLogger.DisableLog('rdApp.*')
device = "cuda:0" if torch.cuda.is_available() else "cpu"

## Dataset load
- 대회에서 제공하는 데이터만을 이용함

In [2]:
# train
train = pd.read_csv("./main_data/train.csv")

## Graph neural network method
- using GIN model

## Feature extraction

In [3]:
# 특징 추출기
# 그래프 구조의 분자를 특징 단위로 추출하기
def atom_to_feature(atom):
    g_charge = float(atom.GetProp('_GasteigerCharge')) if atom.HasProp('_GasteigerCharge') else 0.0
    return [
        atom.GetAtomicNum(),
        int(atom.GetIsAromatic()),
        int(atom.GetHybridization() == Chem.HybridizationType.SP),
        int(atom.GetHybridization() == Chem.HybridizationType.SP2),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D),
        int(atom.GetHybridization() == Chem.HybridizationType.SP3D2),
        atom.GetFormalCharge(),
        int(atom.IsInRing()),
        atom.GetTotalNumHs(),
        atom.GetDegree(),
        atom.GetImplicitValence(),
        atom.GetNumExplicitHs(),
        atom.GetNumImplicitHs(),
        atom.GetMass(),
        int(atom.GetIsotope()),
        int(atom.GetChiralTag()),
        int(atom.GetNoImplicit()),
        int(atom.HasProp("_CIPCode")),
        g_charge,  # Gasteiger 전하
        atom.GetNumRadicalElectrons(),  # 라디칼 전자 수
        atom.GetTotalValence()
    ]


def global_feature_extractor(mol):
    features = [
        Descriptors.MolWt(mol),
        Crippen.MolLogP(mol),
        Descriptors.TPSA(mol),
        Lipinski.NumRotatableBonds(mol),
        Lipinski.NumHDonors(mol),
        Lipinski.NumHAcceptors(mol),
        rdMolDescriptors.CalcNumAromaticRings(mol),
        rdMolDescriptors.CalcNumRings(mol),
        rdMolDescriptors.CalcFractionCSP3(mol),
        Descriptors.HeavyAtomCount(mol),
        rdMolDescriptors.CalcLabuteASA(mol),                       # 접근 가능 표면적
        Descriptors.MolMR(mol),                                    # 몰 굴절률
        rdMolDescriptors.CalcExactMolWt(mol),                      # 정밀 분자량
        Descriptors.NumValenceElectrons(mol),                      # 원자가 전자 수
        len([a for a in mol.GetAtoms() if a.GetSymbol() == 'P'])   # 인(P) 원자 수
    ]

    return features

In [4]:
# 그래프 생성 + 전역 정보 추출
def smiles_to_graph_data(row, train=True):
    mol = Chem.MolFromSmiles(row["Canonical_Smiles"])
    if mol is None:
        raise ValueError("Invalid SMILES")

    # 노드 피처
    x = torch.tensor([atom_to_feature(atom) for atom in mol.GetAtoms()], dtype=torch.float)
        
    # 엣지 정보
    edge_index = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_index += [[i, j], [j, i]]
    edge_index = torch.tensor(edge_index, dtype=torch.long).T

    # 정답값
    if train:
        y = torch.tensor(row["Inhibition"], dtype=torch.float)
    else:
        y = 0

    # 전역 피처
    global_x = global_feature_extractor(mol)
    global_x = torch.tensor([global_feature_extractor(mol)], dtype=torch.float)

    return Data(x=x, glo_x=global_x, edge_index=edge_index, y=y)
    
# train dataset
train_graphs = []
for _, row in tqdm.tqdm(train.iterrows(), total=len(train)):
    try: train_graphs.append(smiles_to_graph_data(row))
    except: pass

train_loader = DataLoader(train_graphs, batch_size=128, shuffle=True)

100%|██████████| 1681/1681 [00:03<00:00, 442.90it/s]


In [5]:
# 노드 특성 수집
all_node_features = []
for data in train_graphs:
    all_node_features.append(data.x.cpu().numpy())
all_node_features = np.concatenate(all_node_features, axis=0)  # shape (총 노드수, feat_dim)

node_mean = all_node_features.mean(axis=0)
node_std = all_node_features.std(axis=0)

# 전역 특성 수집
all_global_features = []
for data in train_graphs:
    all_global_features.append(data.glo_x.cpu().numpy())
all_global_features = np.stack(all_global_features, axis=0)  # shape (num_graphs, global_feat_dim)

global_mean = all_global_features.mean(axis=0)
global_std = all_global_features.std(axis=0)


def normalize_node_features(x, mean, std):
    return (x - mean) / (std + 1e-8)

def normalize_global_features(x, mean, std):
    return (x - mean) / (std + 1e-8)

for data in train_graphs:
    data.x = torch.tensor(normalize_node_features(data.x.cpu().numpy(), node_mean, node_std), dtype=torch.float32).to(data.x.device)
    data.glo_x = torch.tensor(normalize_global_features(data.glo_x.cpu().numpy(), global_mean, global_std), dtype=torch.float32).to(data.glo_x.device)

## Modeling

In [6]:
class SmileGIN(nn.Module):
    def __init__(self, atom_dim, global_dim, hidden_dim, dropout=0):
        super(SmileGIN, self).__init__()

        # embedding block: graph + global
        self.backbone = nn.ModuleDict({
            "atom_proj": Linear(atom_dim, hidden_dim),
            "gin_layers": nn.ModuleList([
                GINConv(Sequential(
                    Linear(hidden_dim, hidden_dim),
                    ReLU(),
                )) for _ in range(2)
            ]),
            "norm": LayerNorm(hidden_dim),
            "pool": AttentionalAggregation(
                gate_nn=Linear(hidden_dim, 1),
                nn=Sequential(Linear(hidden_dim, hidden_dim), ReLU())
            ),
            "global_embed": nn.Sequential(
                Linear(global_dim, hidden_dim),
                LayerNorm(hidden_dim),
                ReLU(),
            )
        })

        self.embedding = Linear(hidden_dim * 3, 64)

        # head block: fusion + regression
        self.head = nn.Sequential(
            Linear(64, 64),
            nn.ReLU(),
            Linear(64, 1)
        )

    def embedding_layer(self, graph_x, global_x, edge_index, batch):
        # embedding: atom features → GIN → pooling
        x = F.relu(self.backbone["atom_proj"](graph_x))
        for gin in self.backbone["gin_layers"]:
            x = x + gin(x, edge_index)
        x = self.backbone["norm"](x)
        graph_feat = self.backbone["pool"](x, batch)

        # embedding: global features
        global_feat = self.backbone["global_embed"](global_x)

        # fusion
        fusion = graph_feat * global_feat
        merged = torch.cat([graph_feat, global_feat, fusion], dim=1)
        embedding = self.embedding(merged)
        return embedding

    def regression_layer(self, embedding):
        return self.head(embedding)

    def forward(self, graph_x, global_x, edge_index, batch):
        embedding = self.embedding_layer(graph_x, global_x, edge_index, batch)
        output = self.regression_layer(embedding)
        return output

model = SmileGIN(atom_dim=22, global_dim=15, hidden_dim=64).to(device)

model(
    train_graphs[0].x.to(device),
    train_graphs[0].glo_x.to(device),
    train_graphs[0].edge_index.to(device),
    torch.zeros(train_graphs[0].num_nodes, dtype=torch.long).to(device)
)

tensor([[0.1481]], device='cuda:0', grad_fn=<AddmmBackward0>)

## Training

In [7]:
def mixup_embeddings(h, y, alpha=0.3):
    """ h: (B, D), y: (B,) """
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(h.size(0))
    h2 = h[index]
    y2 = y[index]
    h_mix = lam * h + (1 - lam) * h2
    y_mix = lam * y + (1 - lam) * y2
    return h_mix, y_mix

In [8]:
# optimizer
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

# epoch
epochs = 100

for epoch in range(epochs):
    total_loss = 0

    for data in tqdm.tqdm(train_loader):
        # data
        graph_x = data.x.to(device)
        global_x = data.glo_x.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)    
        y = data.y.to(device)

        embeddings = model.embedding_layer(graph_x, global_x, edge_index, batch)

        # mixup
        embeddings_mix, y_mix = mixup_embeddings(embeddings, y)

        # prediction layer
        predict = model.regression_layer(embeddings_mix)

        # parameter update
        optimizer.zero_grad()

        loss = F.mse_loss(predict.squeeze(-1), y_mix.to(torch.float))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # scheduler.step()
    print(f"[Train Loss - Epoch: {epoch+1}]\t{total_loss / len(train_loader):.5f}")

# model save
torch.save(model.state_dict(), "./parameter/gnn.pt")

100%|██████████| 14/14 [00:00<00:00, 36.54it/s]


[Train Loss - Epoch: 1]	1695.50245


100%|██████████| 14/14 [00:00<00:00, 60.24it/s]


[Train Loss - Epoch: 2]	1621.55353


100%|██████████| 14/14 [00:00<00:00, 61.68it/s]


[Train Loss - Epoch: 3]	1612.43089


100%|██████████| 14/14 [00:00<00:00, 60.13it/s]


[Train Loss - Epoch: 4]	1672.84671


100%|██████████| 14/14 [00:00<00:00, 59.57it/s]


[Train Loss - Epoch: 5]	1564.35744


100%|██████████| 14/14 [00:00<00:00, 58.28it/s]


[Train Loss - Epoch: 6]	1484.45120


100%|██████████| 14/14 [00:00<00:00, 61.73it/s]


[Train Loss - Epoch: 7]	1420.81787


100%|██████████| 14/14 [00:00<00:00, 58.85it/s]


[Train Loss - Epoch: 8]	1334.91029


100%|██████████| 14/14 [00:00<00:00, 59.43it/s]


[Train Loss - Epoch: 9]	1074.10828


100%|██████████| 14/14 [00:00<00:00, 61.32it/s]


[Train Loss - Epoch: 10]	979.87333


100%|██████████| 14/14 [00:00<00:00, 64.01it/s]


[Train Loss - Epoch: 11]	777.06402


100%|██████████| 14/14 [00:00<00:00, 62.64it/s]


[Train Loss - Epoch: 12]	649.78654


100%|██████████| 14/14 [00:00<00:00, 65.87it/s]


[Train Loss - Epoch: 13]	568.40991


100%|██████████| 14/14 [00:00<00:00, 64.46it/s]


[Train Loss - Epoch: 14]	508.96485


100%|██████████| 14/14 [00:00<00:00, 60.51it/s]


[Train Loss - Epoch: 15]	552.86294


100%|██████████| 14/14 [00:00<00:00, 64.36it/s]


[Train Loss - Epoch: 16]	512.42367


100%|██████████| 14/14 [00:00<00:00, 60.68it/s]


[Train Loss - Epoch: 17]	499.01110


100%|██████████| 14/14 [00:00<00:00, 60.87it/s]


[Train Loss - Epoch: 18]	502.10459


100%|██████████| 14/14 [00:00<00:00, 59.16it/s]


[Train Loss - Epoch: 19]	518.16650


100%|██████████| 14/14 [00:00<00:00, 59.76it/s]


[Train Loss - Epoch: 20]	515.92695


100%|██████████| 14/14 [00:00<00:00, 69.99it/s]


[Train Loss - Epoch: 21]	517.74540


100%|██████████| 14/14 [00:00<00:00, 66.45it/s]


[Train Loss - Epoch: 22]	499.19716


100%|██████████| 14/14 [00:00<00:00, 60.41it/s]


[Train Loss - Epoch: 23]	522.26166


100%|██████████| 14/14 [00:00<00:00, 64.57it/s]


[Train Loss - Epoch: 24]	554.78002


100%|██████████| 14/14 [00:00<00:00, 63.12it/s]


[Train Loss - Epoch: 25]	491.99649


100%|██████████| 14/14 [00:00<00:00, 53.28it/s]


[Train Loss - Epoch: 26]	468.84478


100%|██████████| 14/14 [00:00<00:00, 61.74it/s]


[Train Loss - Epoch: 27]	495.28319


100%|██████████| 14/14 [00:00<00:00, 61.80it/s]


[Train Loss - Epoch: 28]	478.78910


100%|██████████| 14/14 [00:00<00:00, 61.97it/s]


[Train Loss - Epoch: 29]	469.11237


100%|██████████| 14/14 [00:00<00:00, 60.92it/s]


[Train Loss - Epoch: 30]	478.00648


100%|██████████| 14/14 [00:00<00:00, 61.48it/s]


[Train Loss - Epoch: 31]	487.11269


100%|██████████| 14/14 [00:00<00:00, 63.73it/s]


[Train Loss - Epoch: 32]	500.85899


100%|██████████| 14/14 [00:00<00:00, 63.11it/s]


[Train Loss - Epoch: 33]	487.26521


100%|██████████| 14/14 [00:00<00:00, 61.72it/s]


[Train Loss - Epoch: 34]	476.51468


100%|██████████| 14/14 [00:00<00:00, 64.03it/s]


[Train Loss - Epoch: 35]	492.71873


100%|██████████| 14/14 [00:00<00:00, 68.70it/s]


[Train Loss - Epoch: 36]	462.92436


100%|██████████| 14/14 [00:00<00:00, 63.15it/s]


[Train Loss - Epoch: 37]	480.69007


100%|██████████| 14/14 [00:00<00:00, 62.52it/s]


[Train Loss - Epoch: 38]	461.29638


100%|██████████| 14/14 [00:00<00:00, 62.43it/s]


[Train Loss - Epoch: 39]	491.11789


100%|██████████| 14/14 [00:00<00:00, 61.28it/s]


[Train Loss - Epoch: 40]	469.83847


100%|██████████| 14/14 [00:00<00:00, 61.65it/s]


[Train Loss - Epoch: 41]	477.71583


100%|██████████| 14/14 [00:00<00:00, 60.70it/s]


[Train Loss - Epoch: 42]	447.92171


100%|██████████| 14/14 [00:00<00:00, 64.27it/s]


[Train Loss - Epoch: 43]	472.41763


100%|██████████| 14/14 [00:00<00:00, 63.35it/s]


[Train Loss - Epoch: 44]	513.33366


100%|██████████| 14/14 [00:00<00:00, 59.45it/s]


[Train Loss - Epoch: 45]	464.89739


100%|██████████| 14/14 [00:00<00:00, 64.64it/s]


[Train Loss - Epoch: 46]	458.56215


100%|██████████| 14/14 [00:00<00:00, 60.87it/s]


[Train Loss - Epoch: 47]	500.52063


100%|██████████| 14/14 [00:00<00:00, 61.86it/s]


[Train Loss - Epoch: 48]	483.97919


100%|██████████| 14/14 [00:00<00:00, 62.41it/s]


[Train Loss - Epoch: 49]	514.15308


100%|██████████| 14/14 [00:00<00:00, 61.37it/s]


[Train Loss - Epoch: 50]	439.78538


100%|██████████| 14/14 [00:00<00:00, 61.30it/s]


[Train Loss - Epoch: 51]	482.11544


100%|██████████| 14/14 [00:00<00:00, 61.01it/s]


[Train Loss - Epoch: 52]	464.33586


100%|██████████| 14/14 [00:00<00:00, 62.28it/s]


[Train Loss - Epoch: 53]	448.25654


100%|██████████| 14/14 [00:00<00:00, 60.75it/s]


[Train Loss - Epoch: 54]	458.27455


100%|██████████| 14/14 [00:00<00:00, 64.14it/s]


[Train Loss - Epoch: 55]	453.18012


100%|██████████| 14/14 [00:00<00:00, 63.64it/s]


[Train Loss - Epoch: 56]	472.16274


100%|██████████| 14/14 [00:00<00:00, 66.11it/s]


[Train Loss - Epoch: 57]	382.72807


100%|██████████| 14/14 [00:00<00:00, 64.38it/s]


[Train Loss - Epoch: 58]	465.68879


100%|██████████| 14/14 [00:00<00:00, 62.30it/s]


[Train Loss - Epoch: 59]	487.44297


100%|██████████| 14/14 [00:00<00:00, 62.80it/s]


[Train Loss - Epoch: 60]	504.66940


100%|██████████| 14/14 [00:00<00:00, 60.68it/s]


[Train Loss - Epoch: 61]	468.39469


100%|██████████| 14/14 [00:00<00:00, 62.15it/s]


[Train Loss - Epoch: 62]	417.96581


100%|██████████| 14/14 [00:00<00:00, 62.46it/s]


[Train Loss - Epoch: 63]	424.82277


100%|██████████| 14/14 [00:00<00:00, 60.49it/s]


[Train Loss - Epoch: 64]	441.79819


100%|██████████| 14/14 [00:00<00:00, 64.35it/s]


[Train Loss - Epoch: 65]	471.29223


100%|██████████| 14/14 [00:00<00:00, 64.84it/s]


[Train Loss - Epoch: 66]	464.29193


100%|██████████| 14/14 [00:00<00:00, 63.64it/s]


[Train Loss - Epoch: 67]	395.57597


100%|██████████| 14/14 [00:00<00:00, 65.05it/s]


[Train Loss - Epoch: 68]	404.22004


100%|██████████| 14/14 [00:00<00:00, 58.74it/s]


[Train Loss - Epoch: 69]	401.50018


100%|██████████| 14/14 [00:00<00:00, 62.36it/s]


[Train Loss - Epoch: 70]	460.45662


100%|██████████| 14/14 [00:00<00:00, 61.72it/s]


[Train Loss - Epoch: 71]	436.44714


100%|██████████| 14/14 [00:00<00:00, 61.95it/s]


[Train Loss - Epoch: 72]	483.95712


100%|██████████| 14/14 [00:00<00:00, 61.25it/s]


[Train Loss - Epoch: 73]	464.57282


100%|██████████| 14/14 [00:00<00:00, 60.69it/s]


[Train Loss - Epoch: 74]	461.69831


100%|██████████| 14/14 [00:00<00:00, 62.06it/s]


[Train Loss - Epoch: 75]	459.20649


100%|██████████| 14/14 [00:00<00:00, 63.07it/s]


[Train Loss - Epoch: 76]	425.43730


100%|██████████| 14/14 [00:00<00:00, 64.43it/s]


[Train Loss - Epoch: 77]	447.51850


100%|██████████| 14/14 [00:00<00:00, 71.97it/s]


[Train Loss - Epoch: 78]	404.87502


100%|██████████| 14/14 [00:00<00:00, 58.29it/s]


[Train Loss - Epoch: 79]	425.30629


100%|██████████| 14/14 [00:00<00:00, 61.19it/s]


[Train Loss - Epoch: 80]	423.15430


100%|██████████| 14/14 [00:00<00:00, 64.00it/s]


[Train Loss - Epoch: 81]	470.63121


100%|██████████| 14/14 [00:00<00:00, 62.22it/s]


[Train Loss - Epoch: 82]	492.92881


100%|██████████| 14/14 [00:00<00:00, 64.72it/s]


[Train Loss - Epoch: 83]	403.61013


100%|██████████| 14/14 [00:00<00:00, 62.58it/s]


[Train Loss - Epoch: 84]	422.61800


100%|██████████| 14/14 [00:00<00:00, 60.26it/s]


[Train Loss - Epoch: 85]	457.95210


100%|██████████| 14/14 [00:00<00:00, 61.65it/s]


[Train Loss - Epoch: 86]	431.69734


100%|██████████| 14/14 [00:00<00:00, 63.09it/s]


[Train Loss - Epoch: 87]	408.05556


100%|██████████| 14/14 [00:00<00:00, 64.73it/s]


[Train Loss - Epoch: 88]	424.38277


100%|██████████| 14/14 [00:00<00:00, 63.92it/s]


[Train Loss - Epoch: 89]	431.10311


100%|██████████| 14/14 [00:00<00:00, 60.01it/s]


[Train Loss - Epoch: 90]	433.56163


100%|██████████| 14/14 [00:00<00:00, 69.64it/s]


[Train Loss - Epoch: 91]	468.05026


100%|██████████| 14/14 [00:00<00:00, 62.64it/s]


[Train Loss - Epoch: 92]	433.98140


100%|██████████| 14/14 [00:00<00:00, 61.62it/s]


[Train Loss - Epoch: 93]	417.84295


100%|██████████| 14/14 [00:00<00:00, 62.35it/s]


[Train Loss - Epoch: 94]	451.39547


100%|██████████| 14/14 [00:00<00:00, 62.50it/s]


[Train Loss - Epoch: 95]	408.50362


100%|██████████| 14/14 [00:00<00:00, 60.77it/s]


[Train Loss - Epoch: 96]	475.38767


100%|██████████| 14/14 [00:00<00:00, 62.36it/s]


[Train Loss - Epoch: 97]	438.40037


100%|██████████| 14/14 [00:00<00:00, 64.73it/s]


[Train Loss - Epoch: 98]	433.31975


100%|██████████| 14/14 [00:00<00:00, 70.86it/s]


[Train Loss - Epoch: 99]	442.75444


100%|██████████| 14/14 [00:00<00:00, 84.56it/s]

[Train Loss - Epoch: 100]	373.08113





## Test file gen.

In [9]:
# test file gneneration
test = pd.read_csv("./main_data/test.csv")

test_graphs = []
for _, row in tqdm.tqdm(test.iterrows(), total=len(test)):
    try: test_graphs.append(smiles_to_graph_data(row, False))
    except: pass

for data in test_graphs:
    data.x = torch.tensor(normalize_node_features(data.x.cpu().numpy(), node_mean, node_std), dtype=torch.float32).to(data.x.device)
    data.glo_x = torch.tensor(normalize_global_features(data.glo_x.cpu().numpy(), global_mean, global_std), dtype=torch.float32).to(data.glo_x.device)

predicts = []
model.eval()
with torch.no_grad():
    for data in test_graphs:

        preds = model(
                data.x.to(device), 
                data.glo_x.to(device),
                data.edge_index.to(device),
                torch.zeros(data.num_nodes, dtype=torch.long).to(device)
            )

        predicts.append(preds.item())

# submission
sub = pd.read_csv("./main_data/sample_submission.csv")
sub["Inhibition"] = predicts

sub.to_csv("./submission/submission_gnn.csv", index=False)

100%|██████████| 100/100 [00:00<00:00, 478.23it/s]
