In [16]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, global_add_pool

from pytorch_lightning.loggers import WandbLogger

import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import random_split
import pytorch_lightning as pl
 
from tqdm.notebook import tqdm

# Grab the running IPython shell object and switch on two of its display options.
# NOTE(review): these look like InteractiveShell settings for rich (HTML) rendering
# of docstrings in the pager; sphinxify_docstring presumably needs the optional
# `docrepr` package installed — confirm before relying on it.
ip = get_ipython()
ip.sphinxify_docstring = True
ip.enable_html_pager = True

In [3]:
# Load the QM9 dataset through PyTorch Geometric, using /tmp/QM9 as its storage root.
# NOTE(review): presumably downloaded and processed on first use — confirm; /tmp may
# be wiped between sessions, which would force that work to be repeated.
dataset = QM9(root='/tmp/QM9')

PyTorch Geometric requires certain attributes to be present in the dataset. For example:
- `x` should contain node features
- `edge_attr` edge features
- `edge_index` edge list
- `y` labels

The QM9 dataset also has
- `pos` the 3D position of each atom
- `z` the atomic number of each atom

We're going to train on QM9 to predict isotropic polarizability.

In [12]:
class GCN(pl.LightningModule):
    """Edge-conditioned graph network (two NNConv layers) for per-graph regression.

    Node features pass through two ``NNConv`` layers that keep the feature
    width unchanged, are sum-pooled per graph, and are mapped to a single
    scalar by a small MLP head. Training and validation minimise the MSE
    against one column of ``batch.y``.

    Args:
        num_node_features: Width of ``data.x``. Both convolutions map this
            width to itself so their output can be combined with the input.
        num_edge_features: Width of ``data.edge_attr``, the input of the
            per-edge networks.
        target_idx: Column of ``batch.y`` used as the regression target.
            Stored on the module so the steps no longer depend on a
            notebook-level global. Defaults to 1, the value this notebook
            uses for the polarizability label.
        lr: Learning rate handed to Adam.
    """

    def __init__(self, num_node_features, num_edge_features, target_idx=1, lr=0.01):
        super().__init__()

        self.target_idx = target_idx
        self.lr = lr

        # One independent edge network per convolution (no weight sharing).
        edge_net1 = self._edge_network(num_node_features, num_edge_features)
        edge_net2 = self._edge_network(num_node_features, num_edge_features)

        # NNConv(in_channels, out_channels, nn.Module)
        self.convC1 = NNConv(num_node_features, num_node_features, edge_net1)
        self.convC2 = NNConv(num_node_features, num_node_features, edge_net2)

        self.fc1 = nn.Linear(num_node_features, 32)
        self.out = nn.Linear(32, 1)

    @staticmethod
    def _edge_network(num_node_features, num_edge_features):
        """Build the MLP that turns edge features into a flattened NNConv weight."""
        # Output width must be in_channels * out_channels of the convolution.
        return nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.SiLU(),
            nn.Linear(32, num_node_features * num_node_features),
        )

    def forward(self, data):
        """Return one prediction per graph in ``data``, shaped ``(num_graphs, 1)``."""
        batch, x, edge_index, edge_attr = \
            data.batch, data.x, data.edge_index, data.edge_attr

        # Residual-style layers: the convolution output is subtracted from
        # its own input before the non-linearity.
        x = F.relu(x - self.convC1(x, edge_index, edge_attr))
        x = F.relu(x - self.convC2(x, edge_index, edge_attr))

        # Sum node embeddings into a single vector per graph.
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1(x))

        return self.out(x)

    def _regression_loss(self, batch):
        """MSE between the prediction and the selected target column."""
        y_hat = self(batch)
        # unsqueeze so the target has the same (num_graphs, 1) layout as y_hat.
        target = batch.y[:, self.target_idx].unsqueeze(1)
        return F.mse_loss(y_hat, target)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one mini-batch."""
        loss = self._regression_loss(batch)
        # Pass the graph count explicitly instead of letting Lightning guess
        # the batch size from the batch object.
        self.log("train_loss", loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True, batch_size=batch.num_graphs)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one mini-batch."""
        loss = self._regression_loss(batch)
        self.log("val_loss", loss, prog_bar=True, batch_size=batch.num_graphs)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [13]:
# Split QM9 into train / validation / test subsets of fixed size.
# The split is drawn from a seeded generator so that restarting the kernel
# and re-running the notebook yields the same partition every time.
SPLIT_SEED = 42
train_set, valid_set, test_set = random_split(
    dataset,
    [110000, 10831, 10000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, num_workers=12)

train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, shuffle=False, **loader_kwargs)
test_loader  = DataLoader(test_set, shuffle=False, **loader_kwargs)

In [14]:
# Column of `y` that holds the polarizability label.
target_idx = 1

# Feature widths handed to the network: 11 values per node, 4 per edge.
qm9_node_feats = 11
qm9_edge_feats = 4

model = GCN(
    num_node_features=qm9_node_feats,
    num_edge_features=qm9_edge_feats,
)

In [20]:
# wandblogger = WandbLogger()

In [19]:
epochs = 5

# Configure matmul precision before the trainer is built and training starts.
# 'medium' lets PyTorch use lower-precision internals for float32 matrix
# multiplications, trading a little accuracy for speed on supporting hardware.
torch.set_float32_matmul_precision('medium')

# accelerator="auto" selects a GPU when one is available and otherwise falls
# back to CPU, so the notebook also runs on machines without CUDA.
trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=epochs,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type   | Params
----------------------------------
0 | convC1 | NNConv | 4.3 K 
1 | convC2 | NNConv | 4.3 K 
2 | fc1    | Linear | 384   
3 | out    | Linear | 33    
----------------------------------
9.0 K     Trainable params
0         Non-trainable params
9.0 K     Total params
0.036     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
