In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
import torch_geometric as tg

from models.se3_transformer import Se3EquivariantTransformer
from utils.load_md17 import load_md17
import e3nn

import torch
import torch_geometric as tg
import pytorch_lightning as pl
import torchmetrics


In [None]:
class MD17Transformer(pl.LightningModule):
    """Lightning module that fits an energy model on energies and forces.

    The wrapped ``model`` is called on a graph and its output is treated as
    the predicted energy. Forces are derived in :meth:`forward` as the
    negative gradient of that energy with respect to ``graph.pos``. The
    training loss is a convex combination of the energy and force MSE, and
    mean absolute errors are tracked separately for the train, validation
    and test stages.

    Args:
        model: Module called as ``model(graph)``; its output must be
            differentiable with respect to ``graph.pos``.
        lr: Initial learning rate handed to Adam.
        loss_weight: Weight of the force term in the loss, in ``[0, 1]``;
            the energy term is weighted by ``1 - loss_weight``.
    """

    def __init__(self, model, lr=1e-3, loss_weight=0.5):
        super().__init__()
        self.model = model
        self.lr = lr

        assert 0.0 <= loss_weight <= 1.0, "loss_weight must lie in [0, 1]"
        self.loss_weight = loss_weight

        # Output normalisers applied to predictions before the evaluation
        # metrics. They default to the identity so that evaluation works
        # without any dataset statistics.
        # NOTE(review): the training loss and training metric use the raw
        # predictions; if non-identity values are ever assigned here, the
        # training targets must be normalised accordingly -- confirm intent.
        self.scale = 1.0
        self.shift = 0.0

        self.energy_train_metric = torchmetrics.MeanAbsoluteError()
        self.energy_valid_metric = torchmetrics.MeanAbsoluteError()
        self.energy_test_metric = torchmetrics.MeanAbsoluteError()
        self.force_train_metric = torchmetrics.MeanAbsoluteError()
        self.force_valid_metric = torchmetrics.MeanAbsoluteError()
        self.force_test_metric = torchmetrics.MeanAbsoluteError()

    def compute_normalisers(self, dataset):
        """Placeholder for deriving ``scale``/``shift`` from ``dataset``.

        Not implemented yet: the method does nothing and returns ``None``.
        """
        return None

    def forward(self, graph):
        """Predict the energy of ``graph`` and the forces on its positions.

        ``graph.pos`` is replaced by a copy that requires gradients, so the
        caller's graph object is modified in place.

        Returns:
            Tuple ``(energy, force)``: the model output with its trailing
            dimension squeezed, and the negative gradient of the energy with
            respect to ``graph.pos``.
        """
        # Forces need autograd even when the caller (e.g. an evaluation
        # loop) has switched gradient tracking off.
        with torch.enable_grad():
            # Fresh leaf tensor that tracks gradients. The clone is intended
            # to also yield a regular tensor when ``pos`` was created under
            # inference mode, where requires_grad cannot be set in place.
            graph.pos = graph.pos.detach().clone().requires_grad_(True)
            predicted_energy = self.model(graph)
            # autograd.grad returns a tuple with one entry per input; unpack
            # it *before* negating (``-1 * tuple`` would give an empty tuple).
            (energy_grad,) = torch.autograd.grad(
                predicted_energy,
                graph.pos,
                grad_outputs=torch.ones_like(predicted_energy),
                # Keep the graph so the force loss can be back-propagated
                # through this gradient computation.
                create_graph=True,
                retain_graph=True,
            )

        predicted_force = -energy_grad
        predicted_energy = predicted_energy.squeeze(-1)

        return predicted_energy, predicted_force

    def energy_and_force_loss(self, graph, energy, force):
        """Return ``(1 - w) * MSE(energy) + w * MSE(force)``.

        ``w`` is ``self.loss_weight``; targets are read from ``graph.energy``
        and ``graph.force``.
        """
        energy_loss = torch.nn.functional.mse_loss(energy, graph.energy)
        force_loss = torch.nn.functional.mse_loss(force, graph.force)
        return (1 - self.loss_weight) * energy_loss + self.loss_weight * force_loss

    def training_step(self, graph, batch_idx=None):
        """Compute the loss for one batch, update train metrics and log the LR."""
        energy, force = self(graph)
        loss = self.energy_and_force_loss(graph, energy, force)
        self.energy_train_metric(energy, graph.energy)
        self.force_train_metric(force, graph.force)

        cur_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("lr", cur_lr, prog_bar=True, on_step=True)
        return loss

    def on_train_epoch_end(self):
        """Log the energy and force MAE accumulated over the training epoch."""
        self.log("Energy train MAE", self.energy_train_metric, prog_bar=True)
        self.log("Force train MAE", self.force_train_metric, prog_bar=True)

    # Inference mode must be left for this step because the force
    # prediction in ``forward`` relies on autograd.
    @torch.inference_mode(False)
    def validation_step(self, graph, batch_idx):
        """Update the validation MAE metrics with rescaled predictions."""
        energy, force = self(graph)
        self.energy_valid_metric(energy * self.scale + self.shift, graph.energy)
        self.force_valid_metric(force * self.scale, graph.force)

    def on_validation_epoch_end(self):
        """Log the energy and force MAE accumulated over the validation epoch."""
        self.log("Energy valid MAE", self.energy_valid_metric, prog_bar=True)
        self.log("Force valid MAE", self.force_valid_metric, prog_bar=True)

    # Same guard as ``validation_step``: forces require autograd.
    @torch.inference_mode(False)
    def test_step(self, graph, batch_idx):
        """Update the test MAE metrics with rescaled predictions."""
        energy, force = self(graph)
        # Rescaled exactly like validation so both stages report the same
        # quantity (identity with the default scale/shift).
        self.energy_test_metric(energy * self.scale + self.shift, graph.energy)
        self.force_test_metric(force * self.scale, graph.force)

    def on_test_epoch_end(self):
        """Log the energy and force MAE accumulated over the test epoch."""
        self.log("Energy test MAE", self.energy_test_metric, prog_bar=True)
        self.log("Force test MAE", self.force_test_metric, prog_bar=True)

    def configure_optimizers(self):
        """Set up Adam with a cosine-annealing schedule stepped every batch.

        The annealing period is the trainer's estimated total number of
        optimizer steps.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        num_steps = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler_config]