In [45]:
import torch_geometric as tg

from models.se3_transformer import Se3EquivariantTransformer
from utils.load_md17 import load_md17
import e3nn

import torch
import torch_geometric as tg
import pytorch_lightning as pl
import torchmetrics


In [62]:
class MD17TransformerTask(pl.LightningModule):
    def __init__(self, model: torch.nn.Module, lr=1e-3, force_loss_weight=500):
        super().__init__()
        self.model = model
        self.lr = lr

        self.force_loss_weight = force_loss_weight

        self.energy_train_metric = torchmetrics.MeanAbsoluteError()
        self.energy_valid_metric = torchmetrics.MeanAbsoluteError()
        self.energy_test_metric = torchmetrics.MeanAbsoluteError()
        self.force_train_metric = torchmetrics.MeanAbsoluteError()
        self.force_valid_metric = torchmetrics.MeanAbsoluteError()
        self.force_test_metric = torchmetrics.MeanAbsoluteError()

    @staticmethod
    def compute_energy_normalisers(dataset):
        sum_energies = 0
        total_nodes = 0
        force_scales = 0

        for graph in dataset:
            total_nodes += graph.num_nodes
            sum_energies += graph.energy
            force_scales += torch.linalg.vector_norm(graph.force, dim=1).sum()

        mean = sum_energies / total_nodes
        std = force_scales / total_nodes

        return mean, std

    def forward(self, graph):
        graph.relative_positions = torch.autograd.Variable(graph.relative_positions, requires_grad=True)
        predicted_energy = self.model(graph).sum()
        predicted_force = -1 * torch.autograd.grad(predicted_energy,
                                                   graph.relative_positions,
                                                   grad_outputs=torch.ones_like(predicted_energy),
                                                   create_graph=True,
                                                   retain_graph=True,
                                                   # allow_unused=True
                                                   )[0]

        predicted_energy = predicted_energy.squeeze(-1)

        return predicted_energy, predicted_force

    def energy_and_force_loss(self, graph, energy, force):
        energy_loss = torch.nn.functional.mse_loss(energy, graph.energy)
        force_loss = torch.nn.functional.mse_loss(force, graph.force)
        loss = energy_loss + self.force_loss_weight * force_loss
        return loss

    def training_step(self, graph):
        energy, force = self(graph)
        loss = self.energy_and_force_loss(graph, energy, force)
        self.energy_train_metric(energy, graph.energy)
        self.force_train_metric(force, graph.force)

        cur_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("lr", cur_lr, prog_bar=True, on_step=True)
        return loss

    def on_train_epoch_end(self):
        self.log("Energy train MAE", self.energy_train_metric, prog_bar=True)
        self.log("Force train MAE", self.force_train_metric, prog_bar=True)

    @torch.inference_mode(False)
    def validation_step(self, graph, batch_idx):
        energy, force = self(graph)
        self.energy_valid_metric(energy * self.scale + self.shift, graph.energy)
        self.force_valid_metric(force * self.scale, graph.force)

    def on_validation_epoch_end(self):
        self.log("Energy valid MAE", self.energy_valid_metric, prog_bar=True)
        self.log("Force valid MAE", self.force_valid_metric, prog_bar=True)

    def test_step(self, graph, batch_idx):
        energy, force = self.forward(graph)
        self.energy_test_metric(energy, graph.energy)
        self.force_test_metric(force, graph.force)

    def on_test_epoch_end(self):
        self.log("Energy test MAE", self.energy_test_metric, prog_bar=True)
        self.log("Force test MAE", self.force_test_metric, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        num_steps = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler_config]

In [63]:
### Prototype workflow

In [64]:

model = Se3EquivariantTransformer.construct_from_number_of_channels_and_lmax(num_channels=8,
                                                                             l_max=2,
                                                                             num_features=9,
                                                                             num_attention_layers=4,
                                                                             num_attention_heads=4,
                                                                             radial_network_hidden_units=32
                                                                             )

In [65]:
batch_size=2
radius=2

data = load_md17(dataset_name='aspirin CCSD', dataset_dir='../real_datasets', radius=radius)

train_dataloader = tg.loader.DataLoader(data['train'], batch_size=batch_size, shuffle=True)
validation_dataloader = tg.loader.DataLoader(data['validation'], batch_size=batch_size)
test_dataloader = tg.loader.DataLoader(data['test'], batch_size=batch_size)

In [66]:
trainer = pl.Trainer(max_epochs=1)
task = MD17TransformerTask(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [67]:
trainer.fit(task, train_dataloader, test_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.

  | Name                | Type                      | Params
------------------------------------------------------------------
0 | model               | Se3EquivariantTransformer | 1.2 K 
1 | energy_train_metric | MeanAbsoluteError         | 0     
2 | energy_valid_metric | MeanAbsoluteError         | 0     
3 | energy_test_metric  | MeanAbsoluteError         | 0     
4 | force_train_metric  | MeanAbsoluteError         | 0     
5 | force_valid_metric  | MeanAbsoluteError         | 0     
6 | force_test_metric   | MeanAbsoluteError         | 0     
------------------------------------------------------------------
1.2 K     Trainable params
0         Non-trainable params
1.2 K     Total params
0.005     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

AttributeError: 'MD17TransformerTask' object has no attribute 'scale'