# Debug the Training Scheme

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
import sys
import os
import yaml

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# Add the parent directory to the import path (presumably where the
# `architectures` package imported in the next cell lives — confirm).
sys.path.append("../")
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
from architectures.EquivariantGNN.Models.legnn import LEGNN

## Debug Model

In [None]:
class EGNNBase(LightningModule):
    """
    Lightning module base for scanning over different Equivariant GNN training regimes.

    This class provides dataset loading, dataloaders, optimiser/scheduler
    configuration, a learning-rate warm-up and the train/validation steps.
    It defines no ``forward``; the steps call ``self(h, p, edges)`` and unpack
    two return values, so a subclass must supply a matching ``forward``.

    Hyperparameters read from ``hparams``: ``input_dir``, ``data_split``,
    ``lr``, ``patience``, ``factor`` and ``warmup``.

    NOTE(review): ``DataLoader``, ``load_datasets``, ``get_edges`` and
    ``compute_radials`` must be in scope when this class is used; they are not
    imported in the cells above this definition.
    """

    def __init__(self, hparams):
        """Store the hyperparameters and initialise the dataset handles.

        Args:
            hparams: mapping of hyperparameters, saved via
                ``save_hyperparameters`` and read back through ``self.hparams``.
        """
        super().__init__()

        # Dataset handles are filled in by setup(). Default them to None so the
        # dataloader hooks (including test_dataloader, whose dataset setup()
        # never assigns) can check them without raising AttributeError.
        self.trainset = None
        self.valset = None
        self.testset = None

        # Assign hyperparameters
        self.save_hyperparameters(hparams)

    def setup(self, stage):
        """Load the train and validation datasets described by the hyperparameters.

        Args:
            stage: stage name passed by the caller; not used here.
        """
        # Handle any subset of [train, val, test] data split, assuming that ordering
        self.trainset, self.valset = load_datasets(
            self.hparams["input_dir"], self.hparams["data_split"]
        )

    def train_dataloader(self):
        """Return a batch-size-1 loader over the training set, or None if it is unset."""
        if self.trainset is None:
            return None
        return DataLoader(self.trainset, batch_size=1, num_workers=1)

    def val_dataloader(self):
        """Return a batch-size-1 loader over the validation set, or None if it is unset."""
        if self.valset is None:
            return None
        return DataLoader(self.valset, batch_size=1, num_workers=1)

    def test_dataloader(self):
        """Return a batch-size-1 loader over the test set, or None if it is unset."""
        if self.testset is None:
            return None
        return DataLoader(self.testset, batch_size=1, num_workers=1)

    def configure_optimizers(self):
        """Build an AMSGrad AdamW optimiser with a per-epoch StepLR schedule.

        Returns:
            A ``(optimizers, schedulers)`` pair of single-element lists. The
            scheduler multiplies the learning rate by ``hparams["factor"]``
            every ``hparams["patience"]`` epochs.
        """
        optimizer = [
            torch.optim.AdamW(
                self.parameters(),
                lr=(self.hparams["lr"]),
                betas=(0.9, 0.999),
                eps=1e-08,
                amsgrad=True,
            )
        ]
        scheduler = [
            {
                "scheduler": torch.optim.lr_scheduler.StepLR(
                    optimizer[0],
                    step_size=self.hparams["patience"],
                    gamma=self.hparams["factor"],
                ),
                "interval": "epoch",
                "frequency": 1,
            }
        ]
        return optimizer, scheduler

    def _shared_step(self, batch):
        """Run the forward pass for one batch and compute its loss and accuracy.

        Args:
            batch: mapping with keys ``"p"`` and ``"y"`` (the label).

        Returns:
            ``(loss, acc)`` where ``loss`` is the binary cross-entropy tensor
            and ``acc`` is the fraction of predictions equal to the label.
        """
        p, y = torch.squeeze(batch["p"]), batch["y"]

        n_nodes = p.size()[0]
        edges = get_edges(n_nodes)

        # The first value returned by compute_radials is used as the model input h.
        h, _ = compute_radials(edges, p)

        output, _ = self(h, p, edges)

        # Collapse the model output to one probability: mean, then sigmoid.
        output = torch.sigmoid(torch.mean(output)).unsqueeze(0)

        loss = torch.nn.functional.binary_cross_entropy(output, y.float())

        prediction = output.round()

        # Accuracy is matches over the number of labels. Dividing by the number
        # of positive labels (y.sum()) only gives the right value for a single
        # label per batch.
        n_correct = (prediction == y).sum().item()
        acc = n_correct / max(y.numel(), 1)

        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy; return the loss."""
        loss, acc = self._shared_step(batch)

        self.log_dict({"train_loss": loss, "train_acc": acc})

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss, accuracy and the current learning rate."""
        loss, acc = self._shared_step(batch)

        current_lr = self.optimizers().param_groups[0]["lr"]

        self.log_dict({"val_loss": loss, "acc": acc, "current_lr": current_lr})

        return {
            "loss": loss
        }

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure=None,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Apply a linear learning-rate warm-up, then step and zero the optimiser.

        While ``hparams["warmup"]`` is set and the global step is below it, the
        learning rate of every parameter group is set to
        ``hparams["lr"] * (global_step + 1) / warmup``.
        """
        # warm up lr
        if (self.hparams["warmup"] is not None) and (
            self.trainer.global_step < self.hparams["warmup"]
        ):
            lr_scale = min(
                1.0, float(self.trainer.global_step + 1) / self.hparams["warmup"]
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams["lr"]

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

## Load model

In [7]:
# Read the run's hyperparameters from config.yaml in the working directory.
# NOTE(review): yaml.FullLoader can build more than plain data; consider
# yaml.safe_load if this config could ever come from an untrusted source.
with open("config.yaml") as config_file:
    hparams = yaml.load(config_file, Loader=yaml.FullLoader)

In [8]:
# Instantiate the LEGNN model imported above with the loaded hyperparameters.
model = LEGNN(hparams)

## Debug output

In [19]:
from architectures.EquivariantGNN.data_loader import *
from architectures.EquivariantGNN.egnn_base import compute_radials

In [9]:
# Call the setup hook by hand (no Trainer is involved yet) so the datasets are
# attached to the model; model.trainset is read in the next cell.
model.setup("train")

In [10]:
# Take the first training example to push through the model manually.
sample = model.trainset[0]

In [26]:
# Unpack the sample the same way training_step does: squeezed "p" and label "y".
p, y = torch.squeeze(sample["p"]), sample["y"]

In [28]:
%%time
# Time the construction of the edge list for this sample's node count.
n_nodes = p.size()[0]

edges = get_edges(n_nodes)
# NOTE(review): row and column are not used by the later cells shown here.
row, column = edges

CPU times: user 8.08 ms, sys: 51 µs, total: 8.14 ms
Wall time: 8.11 ms


In [29]:
# Use the first value returned by compute_radials as the model input h
# (the trailing comment records an alternative that was tried or considered).
h, _ = compute_radials(edges, p)  # torch.zeros(n_nodes, 1)

In [30]:
# Run one forward pass by hand; the model returns two values, inspected below.
output, x = model(h, p, edges)

In [33]:
# Mean of the raw model output — the quantity the training step passes through a sigmoid.
output.mean()

tensor(-0.0480, grad_fn=<MeanBackward0>)

In [32]:
# Inspect the second value returned by the model.
x

tensor([[ 1.3837e+02, -6.7306e+01,  9.3524e+01,  7.6595e+01],
        [ 6.2718e+01, -3.1180e+01,  4.2315e+01,  3.4206e+01],
        [ 3.9231e+01, -2.4910e+01,  2.0649e+01,  2.2178e+01],
        [ 2.9761e+01, -1.7083e+01,  1.8460e+01,  1.5904e+01],
        [ 3.3264e+01, -4.6813e+00,  2.3749e+01,  2.2812e+01],
        [ 2.6629e+01, -9.1773e+00,  2.2136e+01,  1.1612e+01],
        [ 3.2561e+01, -5.7305e+00,  2.0435e+01,  2.4691e+01],
        [ 2.8715e+01, -5.7487e+00,  1.9229e+01,  2.0533e+01],
        [ 1.6868e+01, -1.0846e+01,  9.0213e+00,  9.2439e+00],
        [ 1.4599e+01, -5.9257e+00,  1.2565e+01,  4.4878e+00],
        [ 1.4308e+01, -5.1888e+00,  1.2358e+01,  5.0098e+00],
        [ 1.6326e+01, -7.9385e+00,  1.0776e+01,  9.3466e+00],
        [ 1.3995e+01, -5.1673e+00,  1.1878e+01,  5.2975e+00],
        [ 1.3982e+01, -8.8781e+00,  7.3591e+00,  7.9042e+00],
        [ 1.3370e+01, -9.2855e+00,  6.7325e+00,  6.8690e+00],
        [ 1.5275e+01, -2.9373e+00,  1.0407e+01,  1.0788e+01],
        

In [16]:
# Inspect the label of the sample for comparison with the model output.
y

tensor(1)

## Train

In [None]:
# Log the run to Weights & Biases and train for the configured number of epochs.
logger = WandbLogger(project="LorentzNet", group="InitialTest")
trainer = Trainer(
    # Request a GPU only when torch can see one; a hard-coded gpus=1 fails on
    # CPU-only machines.
    gpus=int(torch.cuda.is_available()),
    max_epochs=hparams["max_epochs"],
    logger=logger,
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name        | Type   | Params
---------------------------------------
0 | feature_in  | Linear | 128   
1 | feature_out | Linear | 130   
2 | gcl_0       | L_GCL  | 29.2 K
3 | gcl_1       | L_GCL  | 29.2 K
4 | gcl_2       | L_GCL  | 29.2 K
5 | gcl_3       | L_GCL  | 29.2 K
6 | gcl_4       | L_GCL  | 29.2 K
7 | gcl_5       | L_GCL  | 29.2 K
---------------------------------------
175 K     Trainable params
0         Non-trainable params
175 K     Total params
0.701     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Validate