# attentive.py

In [1]:
!pip install torch_geometric
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import GRUCell, Linear, Parameter

from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import softmax


class GATEConv(MessagePassing):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        edge_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__(aggr='add', node_dim=0)

        self.dropout = dropout

        self.att_l = Parameter(torch.empty(1, out_channels))
        self.att_r = Parameter(torch.empty(1, in_channels))

        self.lin1 = Linear(in_channels + edge_dim, out_channels, False)
        self.lin2 = Linear(out_channels, out_channels, False)

        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.att_l)
        glorot(self.att_r)
        glorot(self.lin1.weight)
        glorot(self.lin2.weight)
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
        # edge_updater_type: (x: Tensor, edge_attr: Tensor)
        alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

        # propagate_type: (x: Tensor, alpha: Tensor)
        out = self.propagate(edge_index, x=x, alpha=alpha)
        out = out + self.bias
        return out

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
                    index: Tensor, ptr: OptTensor,
                    size_i: Optional[int]) -> Tensor:
        x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))
        alpha_j = (x_j @ self.att_l.t()).squeeze(-1)
        alpha_i = (x_i @ self.att_r.t()).squeeze(-1)
        alpha = alpha_j + alpha_i
        alpha = F.leaky_relu_(alpha)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        return self.lin2(x_j) * alpha.unsqueeze(-1)


class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)

    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        edge_dim: int,
        num_layers: int,
        num_timesteps: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        self.gate_conv = GATEConv(hidden_channels, hidden_channels, edge_dim,
                                  dropout)
        self.gru = GRUCell(hidden_channels, hidden_channels)

        self.atom_convs = torch.nn.ModuleList()
        self.atom_grus = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_conv.explain = False  # Cannot explain global pooling.
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin1.reset_parameters()
        self.gate_conv.reset_parameters()
        self.gru.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,
                batch: Tensor) -> Tensor:
        """"""  # noqa: D419
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.gate_conv(x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.gru(h, x).relu_()

        for conv, gru in zip(self.atom_convs, self.atom_grus):
            h = conv(x, edge_index)
            h = F.elu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'hidden_channels={self.hidden_channels}, '
                f'out_channels={self.out_channels}, '
                f'edge_dim={self.edge_dim}, '
                f'num_layers={self.num_layers}, '
                f'num_timesteps={self.num_timesteps}'
                f')')


Defaulting to user installation because normal site-packages is not writeable


# attentive_fp.py

In [18]:
!pip install rdkit
!pip install optuna
import optuna
from optuna.trial import TrialState
import os.path as osp
from math import sqrt

import torch
import torch.nn.functional as F
from rdkit import Chem

from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
#from attentive_fp import AttentiveFP


class GenFeatures:
    def __init__(self):
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'other'
        ]

        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    def __call__(self, data):
        # Generate AttentiveFP features according to Table 1.
        mol = Chem.MolFromSmiles(data.smiles)

        xs = []
        from rdkit.Chem import Descriptors

    def __call__(self, data):
        # Generate AttentiveFP features according to Table 1.
        mol = Chem.MolFromSmiles(data.smiles)

        xs = []
        for atom in mol.GetAtoms():
            symbol = [0.] * len(self.symbols)
            symbol[self.symbols.index(atom.GetSymbol())] = 1.
            degree = [0.] * 6
            degree[atom.GetDegree()] = 1.
            formal_charge = atom.GetFormalCharge()
            radical_electrons = atom.GetNumRadicalElectrons()
            hybridization = [0.] * len(self.hybridizations)
            hybridization[self.hybridizations.index(
                atom.GetHybridization())] = 1.
            aromaticity = 1. if atom.GetIsAromatic() else 0.
            hydrogens = [0.] * 5
            hydrogens[atom.GetTotalNumHs()] = 1.
            chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
            chirality_type = [0.] * 2
            is_in_ring = 1. if atom.IsInRing() else 0.
            mass = atom.GetMass()
            is_donor = 1. if rdMolDescriptors.CalcNumHBD(mol) > 0 else 0.
            is_acceptor = 1. if rdMolDescriptors.CalcNumHBA(mol) > 0 else 0.
            if atom.HasProp('_CIPCode'):
                chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.

            # Append extra atom features
            x = torch.cat((x, torch.tensor([is_in_ring, mass, is_donor, is_acceptor])))

            x = torch.tensor(symbol + degree + [formal_charge] +
                             [radical_electrons] + hybridization +
                             [aromaticity] + hydrogens + [chirality] +
                             chirality_type)
            xs.append(x)

        data.x = torch.stack(xs, dim=0)

        edge_indices = []
        edge_attrs = []
        for bond in mol.GetBonds():
            bond_dir = 0.
            if bond.GetBondDir() in [BondDir.ENDUPRIGHT, BondDir.ENDDOWNRIGHT]:
                bond_dir = 1.

            if conf is not None:
                idx1, idx2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                pos1 = conf.GetAtomPosition(idx1)
                pos2 = conf.GetAtomPosition(idx2)
                bond_length = float((pos1 - pos2).Length())
            else:
                bond_length = 0.
            edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]]
            edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]]

            bond_type = bond.GetBondType()
            single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
            double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
            triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
            aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
            conjugation = 1. if bond.GetIsConjugated() else 0.
            ring = 1. if bond.IsInRing() else 0.
            stereo = [0.] * 4
            stereo[self.stereos.index(bond.GetStereo())] = 1.

            edge_attr = torch.tensor([bond_length, bond_dir] +
                [single, double, triple, aromatic, conjugation] + stereo)

            edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
        else:
            data.edge_index = torch.tensor(edge_indices).t().contiguous()
            data.edge_attr = torch.stack(edge_attrs, dim=0)

        return data

def train():
    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)


@torch.no_grad()
def test(loader):
    mse = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())
    return float(torch.cat(mse, dim=0).mean().sqrt())

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [20]:
import random
import numpy as np
import copy
import os.path as osp
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem.rdchem import BondDir
from rdkit.Chem import AllChem

# Compute molecule-level info once
#mol = Chem.AddHs(mol)
#AllChem.EmbedMolecule(mol, AllChem.ETKDG())
#conf = mol.GetConformer() if mol.GetNumConformers() > 0 else None

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(21)

# Use current working directory in Jupyter
notebook_dir = osp.dirname(osp.abspath(''))
path = osp.abspath('../data/AFP_Mol')

#path = '/content/drive/MyDrive/my_project/data/AFP_Mol'
dataset = MoleculeNet(path, name='ESOL', pre_transform=GenFeatures()).shuffle()

N = len(dataset) // 10
val_dataset = dataset[:N]
test_dataset = dataset[N:2 * N]
train_dataset = dataset[2 * N:]

train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=200)
test_loader = DataLoader(test_dataset, batch_size=200)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AttentiveFP(in_channels=39, hidden_channels=250, out_channels=1,
                    edge_dim=10, num_layers=3, num_timesteps=2,
                    dropout=0.2).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.003,
                             weight_decay=0.00001)

for epoch in range(1, 671):
    train_rmse = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    if(epoch%10==0):
      print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')


def objective(trial):
    set_seed(20)  # Ensures reproducibility

    # Suggest hyperparameters
    num_layers = trial.suggest_int("num_layers", 2, 5)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    learning_rate = trial.suggest_loguniform("lr", 1e-4, 1e-2)

    # Define model and optimizer with trial parameters
    model = AttentiveFP(
        in_channels=39,
        hidden_channels=200,
        out_channels=1,
        edge_dim=10,
        num_layers=num_layers,
        num_timesteps=2,
        dropout=0.2
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Early stopping
    best_val = float("inf")
    best_model_state = None
    patience = 20
    counter = 0

    for epoch in range(1, 201):
        train_rmse = train()
        val_rmse = test(test_loader)

        if val_rmse < best_val:
            best_val = val_rmse
            best_model_state = copy.deepcopy(model.state_dict())
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                break

    # Save best model for this trial
    trial.set_user_attr("best_model_state_dict", best_model_state)

    return best_val

# Run study
#study = optuna.create_study(direction="minimize")
#study.optimize(objective, n_trials=25)

# Report best trial
#best_trial = study.best_trial
#print("Best Trial:")
#print(f"  Value: {best_trial.value:.4f}")
#for k, v in best_trial.params.items():
#    print(f"  {k}: {v}")



Epoch: 010, Loss: 1.1568 Val: 1.1498 Test: 1.1486
Epoch: 020, Loss: 0.8902 Val: 0.8492 Test: 0.8891
Epoch: 030, Loss: 0.8121 Val: 0.7784 Test: 0.8777
Epoch: 040, Loss: 0.7622 Val: 0.7534 Test: 0.7831
Epoch: 050, Loss: 0.7305 Val: 0.7528 Test: 0.7221
Epoch: 060, Loss: 0.6625 Val: 0.6728 Test: 0.7889
Epoch: 070, Loss: 0.6547 Val: 0.7016 Test: 0.7414
Epoch: 080, Loss: 0.6312 Val: 0.6972 Test: 0.6891
Epoch: 090, Loss: 0.6103 Val: 0.6567 Test: 0.6414
Epoch: 100, Loss: 0.5933 Val: 0.6645 Test: 0.6653
Epoch: 110, Loss: 0.5614 Val: 0.6585 Test: 0.7025
Epoch: 120, Loss: 0.5637 Val: 0.6277 Test: 0.6091
Epoch: 130, Loss: 0.5553 Val: 0.6233 Test: 0.6273
Epoch: 140, Loss: 0.5498 Val: 0.6414 Test: 0.6698
Epoch: 150, Loss: 0.5177 Val: 0.6127 Test: 0.6372
Epoch: 160, Loss: 0.5011 Val: 0.5924 Test: 0.6420
Epoch: 170, Loss: 0.4980 Val: 0.5505 Test: 0.6035
Epoch: 180, Loss: 0.4833 Val: 0.5764 Test: 0.5724
Epoch: 190, Loss: 0.4614 Val: 0.5517 Test: 0.5809
Epoch: 200, Loss: 0.5016 Val: 0.5645 Test: 0.5713


In [30]:
import random
import numpy as np
import copy
import os.path as osp

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(21)

# Use current working directory in Jupyter
notebook_dir = osp.dirname(osp.abspath(''))
path = osp.abspath('../data/AFP_Mol')

#path = '/content/drive/MyDrive/my_project/data/AFP_Mol'
dataset = MoleculeNet(path, name='FreeSolv', pre_transform=GenFeatures()).shuffle()

N = len(dataset) // 10
val_dataset = dataset[:N]
test_dataset = dataset[N:2 * N]
train_dataset = dataset[2 * N:]

train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=200)
test_loader = DataLoader(test_dataset, batch_size=200)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AttentiveFP(in_channels=39, hidden_channels=250, out_channels=1,
                    edge_dim=10, num_layers=4, num_timesteps=2,
                    dropout=0.2).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001329477429508522,
                             weight_decay=1.8804404706749282e-05)

for epoch in range(1, 691):
    train_rmse = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    if(epoch%10==0):
      print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')


def objective(trial):
    set_seed(21)  # Ensures reproducibility

    # Suggest hyperparameters
    num_layers = trial.suggest_int("num_layers", 2, 5)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    learning_rate = trial.suggest_loguniform("lr", 1e-4, 1e-2)

    # Define model and optimizer with trial parameters
    model = AttentiveFP(
        in_channels=39,
        hidden_channels=200,
        out_channels=1,
        edge_dim=10,
        num_layers=num_layers,
        num_timesteps=2,
        dropout=0.2
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Early stopping
    best_val = float("inf")
    best_model_state = None
    patience = 20
    counter = 0

    for epoch in range(1, 201):
        train_rmse = train()
        val_rmse = test(test_loader)

        if val_rmse < best_val:
            best_val = val_rmse
            best_model_state = copy.deepcopy(model.state_dict())
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                break

    # Save best model for this trial
    trial.set_user_attr("best_model_state_dict", best_model_state)

    return best_val

# Run study
#study = optuna.create_study(direction="minimize")
#study.optimize(objective, n_trials=25)

# Report best trial
#best_trial = study.best_trial
#print("Best Trial:")
#print(f"  Value: {best_trial.value:.4f}")
#for k, v in best_trial.params.items():
#    print(f"  {k}: {v}")

Epoch: 010, Loss: 2.9087 Val: 2.7666 Test: 2.4191
Epoch: 020, Loss: 1.6465 Val: 1.7319 Test: 1.6867
Epoch: 030, Loss: 1.2699 Val: 1.3590 Test: 1.2912
Epoch: 040, Loss: 1.1079 Val: 1.1130 Test: 1.3638
Epoch: 050, Loss: 1.0673 Val: 1.1454 Test: 1.3659
Epoch: 060, Loss: 0.9814 Val: 1.2775 Test: 1.4030
Epoch: 070, Loss: 0.9375 Val: 1.0887 Test: 1.2712
Epoch: 080, Loss: 0.8385 Val: 1.2452 Test: 1.1866
Epoch: 090, Loss: 0.9781 Val: 1.0845 Test: 1.2516
Epoch: 100, Loss: 0.8011 Val: 1.0479 Test: 1.1853
Epoch: 110, Loss: 0.7509 Val: 1.0340 Test: 1.2274
Epoch: 120, Loss: 0.8075 Val: 1.0576 Test: 1.1060
Epoch: 130, Loss: 0.6985 Val: 1.0365 Test: 1.0541
Epoch: 140, Loss: 0.7196 Val: 0.9471 Test: 1.2061
Epoch: 150, Loss: 0.7139 Val: 1.1354 Test: 1.0715
Epoch: 160, Loss: 0.6619 Val: 0.9191 Test: 1.0560
Epoch: 170, Loss: 0.5961 Val: 1.0185 Test: 1.1495
Epoch: 180, Loss: 0.6220 Val: 1.1894 Test: 1.0884
Epoch: 190, Loss: 0.6946 Val: 1.0160 Test: 1.1435
Epoch: 200, Loss: 0.6005 Val: 0.9741 Test: 1.0351


In [43]:
import random
import numpy as np
import copy
import os.path as osp

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(21)

# Use current working directory in Jupyter
notebook_dir = osp.dirname(osp.abspath(''))
path = osp.abspath('../data/AFP_Mol')

#path = '/content/drive/MyDrive/my_project/data/AFP_Mol'
dataset = MoleculeNet(path, name='BACE', pre_transform=GenFeatures()).shuffle()

N = len(dataset) // 10
val_dataset = dataset[:N]
test_dataset = dataset[N:2 * N]
train_dataset = dataset[2 * N:]

train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=200)
test_loader = DataLoader(test_dataset, batch_size=200)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1,
                    edge_dim=10, num_layers=3, num_timesteps=2,
                    dropout=0.2).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.008762276062649685,
                             weight_decay=1.1650462882147405e-06)

for epoch in range(1, 561):
    train_rmse = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    if(epoch%10==0):
      print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')


def objective(trial):
    set_seed(21)  # Ensures reproducibility

    # Suggest hyperparameters
    num_layers = trial.suggest_int("num_layers", 2, 5)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    learning_rate = trial.suggest_loguniform("lr", 1e-4, 1e-2)

    # Define model and optimizer with trial parameters
    model = AttentiveFP(
        in_channels=39,
        hidden_channels=200,
        out_channels=1,
        edge_dim=10,
        num_layers=num_layers,
        num_timesteps=2,
        dropout=0.2
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Early stopping
    best_val = float("inf")
    best_model_state = None
    patience = 20
    counter = 0

    for epoch in range(1, 201):
        train_rmse = train()
        val_rmse = test(test_loader)

        if val_rmse < best_val:
            best_val = val_rmse
            best_model_state = copy.deepcopy(model.state_dict())
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                break

    # Save best model for this trial
    trial.set_user_attr("best_model_state_dict", best_model_state)

    return best_val

# Run study
#study = optuna.create_study(direction="minimize")
#study.optimize(objective, n_trials=25)

# Report best trial
#best_trial = study.best_trial
#print("Best Trial:")
#print(f"  Value: {best_trial.value:.4f}")
#for k, v in best_trial.params.items():
#    print(f"  {k}: {v}")

Epoch: 010, Loss: 0.4976 Val: 0.4935 Test: 0.4994
Epoch: 020, Loss: 0.4850 Val: 0.4829 Test: 0.4988
Epoch: 030, Loss: 0.4412 Val: 0.4469 Test: 0.4618
Epoch: 040, Loss: 0.4382 Val: 0.4256 Test: 0.4594
Epoch: 050, Loss: 0.4226 Val: 0.4298 Test: 0.4424
Epoch: 060, Loss: 0.4227 Val: 0.4443 Test: 0.4554
Epoch: 070, Loss: 0.4257 Val: 0.4354 Test: 0.4317
Epoch: 080, Loss: 0.4109 Val: 0.4157 Test: 0.4257
Epoch: 090, Loss: 0.4055 Val: 0.4326 Test: 0.4298
Epoch: 100, Loss: 0.4058 Val: 0.4620 Test: 0.4326
Epoch: 110, Loss: 0.3822 Val: 0.4628 Test: 0.4447
Epoch: 120, Loss: 0.3826 Val: 0.4043 Test: 0.4152
Epoch: 130, Loss: 0.3996 Val: 0.4430 Test: 0.4123
Epoch: 140, Loss: 0.4001 Val: 0.4330 Test: 0.4051
Epoch: 150, Loss: 0.3959 Val: 0.4055 Test: 0.4145
Epoch: 160, Loss: 0.3820 Val: 0.4326 Test: 0.4093
Epoch: 170, Loss: 0.3911 Val: 0.4403 Test: 0.4288
Epoch: 180, Loss: 0.3877 Val: 0.3991 Test: 0.4219
Epoch: 190, Loss: 0.3770 Val: 0.4478 Test: 0.4336
Epoch: 200, Loss: 0.3701 Val: 0.4213 Test: 0.3997
