In [1]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load QM9 and carve it into 80% train / 10% validation / 10% test.
dataset = QM9(root='data/QM9')

total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
# The test split absorbs the rounding remainder so the three sizes always sum to total_size.
test_size = total_size - train_size - val_size
print(f"Original dataset - Train size: {train_size}, Val size: {val_size}, Test size: {test_size}")

# Seed the split explicitly. Without a fixed generator every kernel restart draws a
# different partition, so validation/test numbers cannot be compared between runs and
# a checkpoint may later be evaluated on molecules it was trained on.
split_seed = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)


batch_size = 128

# DataLoaders for original dataset (only the training loader shuffles).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


# Pull a single training batch; later cells use `batch` for quick shape checks.
batch = next(iter(train_loader))


Original dataset - Train size: 104664, Val size: 13083, Test size: 13084


In [3]:
def create_adjacency_matrix(batch, molecule_index):
    """Build a dense, symmetric, bond-typed adjacency matrix for one molecule.

    Args:
        batch: Indexable collection of graphs. ``batch[molecule_index]`` must
            expose ``num_nodes``, ``edge_index`` (2 x num_edges) and
            ``edge_attr`` (one row per edge).
        molecule_index: Position of the molecule inside ``batch``.

    Returns:
        A ``(num_nodes, num_nodes)`` tensor created with ``torch.zeros``
        defaults. Entry ``[a, b]`` is 0 when no edge joins ``a`` and ``b``;
        otherwise it is ``argmax(edge_attr row) + 1`` for that edge, written
        to both ``[a, b]`` and ``[b, a]``.

    Notes:
        Assumes each ``edge_attr`` row is a one-hot style bond-type encoding,
        so ``argmax + 1`` is a 1-based bond label -- TODO confirm for datasets
        other than the one used in this notebook.
        If the same atom pair is listed several times with *different*
        attributes, the writes are issued in the same order as a sequential
        per-edge loop, but PyTorch does not guarantee which duplicate wins in
        an indexed assignment.
    """
    # Index the batch once: every ``batch[i]`` on a collated batch may rebuild
    # the per-molecule object, so repeating it three times is wasted work.
    molecule = batch[molecule_index]
    num_nodes = molecule.num_nodes
    adj_matrix = torch.zeros((num_nodes, num_nodes))

    edge_index = molecule.edge_index
    num_edges = edge_index.shape[1]
    if num_edges == 0:
        # Nothing to write; also avoids touching edge_attr, which may be
        # empty or missing for an edge-less graph.
        return adj_matrix

    endpoints = edge_index.to(device=adj_matrix.device, dtype=torch.long)
    src, dst = endpoints[0], endpoints[1]

    # Flatten each edge's attributes before argmax so the result matches
    # ``torch.argmax(edge_attr[i])`` for any per-edge attribute shape.
    bond_types = molecule.edge_attr.reshape(num_edges, -1).argmax(dim=1) + 1
    # Indexed assignment needs matching dtypes (argmax yields int64).
    bond_types = bond_types.to(device=adj_matrix.device, dtype=adj_matrix.dtype)

    # Interleave (src, dst) and (dst, src) so a single vectorised write fills
    # both triangles in the order edge0 fwd, edge0 rev, edge1 fwd, ...
    rows = torch.stack((src, dst), dim=1).reshape(-1)
    cols = torch.stack((dst, src), dim=1).reshape(-1)
    adj_matrix[rows, cols] = bond_types.repeat_interleave(2)

    return adj_matrix

In [4]:
from egnn.egnn import EGNN

class HeatCapacityNet(nn.Module):
    """EGNN backbone followed by a small MLP head that emits one scalar per molecule.

    Args:
        embedding_dim: Width of the atom embedding; also passed to ``EGNN`` as
            both ``message_dim`` and ``hidden_dim``.
        num_layers: Forwarded to ``EGNN`` as ``num_layers``.
        node_dim: Forwarded to ``EGNN`` as ``node_dim``. Falls back to
            ``embedding_dim`` when left as ``None``.
    """

    def __init__(self, embedding_dim, num_layers, node_dim=None):
        super().__init__()
        # Lookup table with 10 rows, indexed by the integer codes passed as `z`.
        self.atom_embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
        egnn_node_dim = embedding_dim if node_dim is None else node_dim
        self.egnn = EGNN(
            num_layers=num_layers,
            node_dim=egnn_node_dim,
            message_dim=embedding_dim,
            hidden_dim=embedding_dim,
        )
        # Lazy layers infer their input width on the first forward pass.
        readout_layers = [nn.LazyLinear(128), nn.ReLU(), nn.LazyLinear(1)]
        self.sol = nn.Sequential(*readout_layers)

    def forward(self, z, r, adj, x=None):
        """Predict a single value for one molecule.

        Args:
            z: Integer codes looked up in ``atom_embedding``; ignored when ``x`` is given.
            r: Coordinates handed to ``EGNN`` unchanged.
            adj: Adjacency matrix handed to ``EGNN`` unchanged.
            x: Optional precomputed node features used instead of the embedding.

        Returns:
            Output of the MLP head applied to the mean over dim 0 of the first
            value returned by ``EGNN`` (presumably one row per atom -- confirm
            against the EGNN implementation).
        """
        node_features = self.atom_embedding(z) if x is None else x
        # EGNN returns a pair; only the first element is used for the readout.
        h, _ = self.egnn(node_features, r, adj)
        pooled = h.mean(dim=0)
        return self.sol(pooled)


In [5]:
# Smoke test: run a freshly initialised model on the first molecule of `batch`
# and display the shape of what comes back.
net = HeatCapacityNet(embedding_dim=256, num_layers=10)
net(
    batch[0].z,
    batch[0].pos,
    create_adjacency_matrix(batch, 0),
).shape

torch.Size([1])

In [6]:
from tqdm import tqdm
import torch
import torch.nn.functional as F

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available, using CPU")

# Training configuration. The epoch count is defined once so the loop bound and the
# progress/log labels cannot drift apart (they previously read "Epoch 100/20").
num_epochs = 100
# Column of `y` used as the regression target.
# NOTE(review): index 11 is listed as the heat capacity c_v in the PyG QM9 target
# table -- confirm against the installed torch_geometric version.
target_index = 11

net = HeatCapacityNet(embedding_dim=128, num_layers=3).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

best_val_mae = float('inf')
best_model_state = None

for e in range(num_epochs):
    # No-op unless the model contains dropout / normalisation layers.
    net.train()
    for batch in tqdm(train_loader, desc=f"Epoch {e+1}/{num_epochs}", leave=False):
        # Split the collated batch into per-molecule graphs once. This is independent of
        # what len()/iteration mean on a Batch in the installed PyG release and avoids
        # rebuilding each molecule on every `batch[i]` access.
        molecules = batch.to_data_list()
        loss = 0
        for i, molecule in enumerate(molecules):
            initial_r = molecule.pos.to(device)
            adj_matrix = create_adjacency_matrix(molecules, i).to(device)
            pred_sol = net(molecule.z.to(device), initial_r, adj_matrix)
            target_sol = molecule.y[0][target_index].to(device)
            loss += F.mse_loss(pred_sol, target_sol.unsqueeze(0))
        # Average over the molecules actually processed, then take one optimiser step.
        loss = loss / len(molecules)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Compute validation MAE
    net.eval()
    # Accumulate on-device and convert to a Python float once per epoch.
    abs_error_sum = torch.zeros((), device=device)
    total_items = 0
    with torch.no_grad():
        for val_batch in tqdm(val_loader, desc="Validation", leave=False):
            val_molecules = val_batch.to_data_list()
            for i, molecule in enumerate(val_molecules):
                initial_r = molecule.pos.to(device)
                adj_matrix = create_adjacency_matrix(val_molecules, i).to(device)
                pred_sol = net(molecule.z.to(device), initial_r, adj_matrix)
                target_sol = molecule.y[0][target_index].to(device)
                abs_error_sum += F.l1_loss(pred_sol, target_sol.unsqueeze(0), reduction='sum')
            # Count exactly the molecules that contributed to the error sum.
            total_items += len(val_molecules)
    val_mae = (abs_error_sum / total_items).item()

    print(f"Epoch {e+1}/{num_epochs}, Validation MAE: {val_mae:.4f}")

    # Save the best model
    if val_mae < best_val_mae:
        best_val_mae = val_mae
        # state_dict() returns tensors that share storage with the live parameters, so
        # keeping the dict itself would be silently overwritten by later optimiser
        # steps. Clone every tensor to freeze a real snapshot of this epoch.
        best_model_state = {
            name: tensor.detach().clone() for name, tensor in net.state_dict().items()
        }

# Save the best model
torch.save(best_model_state, 'best_model.pth')
print(f"Best model saved with validation MAE: {best_val_mae:.4f}")


GPU available: NVIDIA GeForce RTX 4090


                                                             

Epoch 1/20, Validation MAE: 0.1624


                                                             

Epoch 2/20, Validation MAE: 0.1453


                                                             

Epoch 3/20, Validation MAE: 0.1166


                                                             

Epoch 4/20, Validation MAE: 0.1042


                                                             

Epoch 5/20, Validation MAE: 0.1022


                                                             

Epoch 6/20, Validation MAE: 0.1219


                                                             

Epoch 7/20, Validation MAE: 0.0888


                                                             

Epoch 8/20, Validation MAE: 0.1296


                                                             

Epoch 9/20, Validation MAE: 0.0972


                                                              

Epoch 10/20, Validation MAE: 0.1047


                                                              

Epoch 11/20, Validation MAE: 0.0862


                                                              

Epoch 12/20, Validation MAE: 0.0824


                                                              

Epoch 13/20, Validation MAE: 0.0796


                                                              

Epoch 14/20, Validation MAE: 0.0768


                                                              

Epoch 15/20, Validation MAE: 0.0704


                                                              

Epoch 16/20, Validation MAE: 0.0625


                                                              

Epoch 17/20, Validation MAE: 0.0592


                                                              

Epoch 18/20, Validation MAE: 0.0768


                                                              

Epoch 19/20, Validation MAE: 0.0585


                                                              

Epoch 20/20, Validation MAE: 0.0566


                                                              

Epoch 21/20, Validation MAE: 0.0687


                                                              

Epoch 22/20, Validation MAE: 0.0780


                                                              

Epoch 23/20, Validation MAE: 0.0544


                                                              

Epoch 24/20, Validation MAE: 0.0700


                                                              

Epoch 25/20, Validation MAE: 0.0987


                                                              

Epoch 26/20, Validation MAE: 0.0785


                                                              

Epoch 27/20, Validation MAE: 0.0845


                                                              

Epoch 28/20, Validation MAE: 0.0743


                                                              

Epoch 29/20, Validation MAE: 0.0796


                                                              

Epoch 30/20, Validation MAE: 0.0836


                                                              

Epoch 31/20, Validation MAE: 0.0768


                                                              

Epoch 32/20, Validation MAE: 0.0795


                                                              

Epoch 33/20, Validation MAE: 0.0702


                                                              

Epoch 34/20, Validation MAE: 0.1041


                                                              

Epoch 35/20, Validation MAE: 0.0625


                                                              

Epoch 36/20, Validation MAE: 0.0680


                                                              

Epoch 37/20, Validation MAE: 0.0630


                                                              

Epoch 38/20, Validation MAE: 0.0957


                                                              

Epoch 39/20, Validation MAE: 0.0832


                                                              

Epoch 40/20, Validation MAE: 0.0916


                                                              

Epoch 41/20, Validation MAE: 0.0744


                                                              

Epoch 42/20, Validation MAE: 0.0673


                                                              

Epoch 43/20, Validation MAE: 0.0660


                                                              

Epoch 44/20, Validation MAE: 0.0722


                                                              

Epoch 45/20, Validation MAE: 0.0672


                                                              

Epoch 46/20, Validation MAE: 0.0636


                                                              

Epoch 47/20, Validation MAE: 0.0590


                                                              

Epoch 48/20, Validation MAE: 0.0616


                                                              

Epoch 49/20, Validation MAE: 0.0604


                                                              

Epoch 50/20, Validation MAE: 0.0776


                                                              

Epoch 51/20, Validation MAE: 0.0591


                                                              

Epoch 52/20, Validation MAE: 0.0799


                                                              

Epoch 53/20, Validation MAE: 0.0627


                                                              

Epoch 54/20, Validation MAE: 0.0562


                                                              

Epoch 55/20, Validation MAE: 0.0573


                                                              

Epoch 56/20, Validation MAE: 0.0562


                                                              

Epoch 57/20, Validation MAE: 0.0501


                                                              

Epoch 58/20, Validation MAE: 0.0514


                                                              

Epoch 59/20, Validation MAE: 0.0504


                                                              

Epoch 60/20, Validation MAE: 0.0573


                                                              

Epoch 61/20, Validation MAE: 0.0592


                                                              

Epoch 62/20, Validation MAE: 0.0530


                                                              

Epoch 63/20, Validation MAE: 0.0480


                                                              

Epoch 64/20, Validation MAE: 0.0482


                                                              

Epoch 65/20, Validation MAE: 0.0483


                                                              

Epoch 66/20, Validation MAE: 0.0533


                                                              

Epoch 67/20, Validation MAE: 0.0496


                                                              

Epoch 68/20, Validation MAE: 0.0516


                                                              

Epoch 69/20, Validation MAE: 0.0447


                                                              

Epoch 70/20, Validation MAE: 0.0469


                                                              

Epoch 71/20, Validation MAE: 0.0616


                                                              

Epoch 72/20, Validation MAE: 0.0475


                                                              

Epoch 73/20, Validation MAE: 0.0436


                                                              

Epoch 74/20, Validation MAE: 0.0463


                                                              

Epoch 75/20, Validation MAE: 0.0480


                                                              

Epoch 76/20, Validation MAE: 0.1032


                                                              

Epoch 77/20, Validation MAE: 0.0696


                                                              

Epoch 78/20, Validation MAE: 0.0726


                                                              

Epoch 79/20, Validation MAE: 0.0606


                                                              

Epoch 80/20, Validation MAE: 0.0596


                                                              

Epoch 81/20, Validation MAE: 0.0712


                                                              

Epoch 82/20, Validation MAE: 0.0528


                                                              

Epoch 83/20, Validation MAE: 0.0508


                                                              

Epoch 84/20, Validation MAE: 0.0497


                                                              

Epoch 85/20, Validation MAE: 0.0501


                                                              

Epoch 86/20, Validation MAE: 0.0464


                                                              

Epoch 87/20, Validation MAE: 0.0427


                                                              

Epoch 88/20, Validation MAE: 0.0431


                                                              

Epoch 89/20, Validation MAE: 0.0398


                                                              

Epoch 90/20, Validation MAE: 0.0402


                                                              

Epoch 91/20, Validation MAE: 0.0668


                                                              

Epoch 92/20, Validation MAE: 0.0411


                                                              

Epoch 93/20, Validation MAE: 0.0366


                                                              

Epoch 94/20, Validation MAE: 0.0521


                                                              

Epoch 95/20, Validation MAE: 0.0527


                                                              

Epoch 96/20, Validation MAE: 0.0688


                                                              

Epoch 97/20, Validation MAE: 0.0568


                                                              

Epoch 98/20, Validation MAE: 0.0474


                                                              

Epoch 99/20, Validation MAE: 0.0454


                                                               

Epoch 100/20, Validation MAE: 0.0437
Best model saved with validation MAE: 0.0366
