# Introduction

In this tutorial we will guide you through implementing an $SE(3)$ steerable graph convolutional network. We will start by going over the theory behind these models and then look at one way to implement these ideas in a neural architecture. 

Note that this notebook is *not* self-contained. It references several Python packages and papers that further explain some of the concepts. To get the most out of this notebook, be sure to check out these references as they come up. 

Good luck!

In [1]:
%%capture 
# Let's install some packages.
!pip install pytorch_lightning
!pip install e3nn
!pip install vapeplot
!pip install plotly

import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
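# torch.version.cuda is None on CPU-only builds; PyG publishes CPU wheels under "+cpu"
CUDA = format_cuda_version(CUDA_version) if CUDA_version is not None else 'cpu'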

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics
import torch_geometric as tg
import plotly.graph_objects as go

import e3nn.o3 as o3
from e3nn.o3 import Irreps
from e3nn.nn import Gate

import vapeplot
cmap = vapeplot.cmap('crystal_pepsi')

## Spherical Harmonics
As you have seen in this course, to build equivariant networks we need functions whose outputs transform predictably under the group action. For the group $SE(3)$, we handle translations by working only with relative positions, and rotations by expressing the features of our network in terms of [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonics). Spherical harmonics have two properties that make them well suited for this:

1. They transform predictably under rotations: rotating the input of the order-$l$ harmonics multiplies the vector $(Y^l_{-l}, \dots, Y^l_l)$ by a [Wigner-D matrix](https://en.wikipedia.org/wiki/Wigner_D-matrix) $D^l(R)$. These matrices are the irreducible representations of the group $SO(3)$.
2. They form an orthonormal basis for (square-integrable) functions on the sphere. This means we can approximate any such function $f: \mathbb{S}^2 \rightarrow \mathbb{R}$ by a weighted sum of spherical harmonics.
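For the simplest non-trivial case, $l=1$, property 1 can be checked by hand: up to a constant and a choice of component ordering, the three real $l=1$ harmonics are just the components of the unit vector $\hat{\mathbf{x}}$, so in that basis the Wigner-D matrix $D^1(R)$ is the rotation matrix $R$ itself. The plain-Python sketch below verifies $Y^1(R\mathbf{x}) = D^1(R)\,Y^1(\mathbf{x})$ numerically; the $(x, y, z)$ component ordering and the helper names are illustrative assumptions, not e3nn's conventions.

```python
import math

# Real l = 1 spherical harmonics in a Cartesian (x, y, z) ordering (an assumption;
# libraries order and normalise these components differently).
C1 = math.sqrt(3 / (4 * math.pi))


def y1(v):
    """Evaluate the three l = 1 real spherical harmonics at the direction of v."""
    norm = math.sqrt(sum(c * c for c in v))
    return [C1 * c / norm for c in v]


def rot_z(theta):
    """Rotation matrix about the z-axis by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rot_x(theta):
    """Rotation matrix about the x-axis by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def matvec(m, v):
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


R = matmul(rot_z(0.7), rot_x(-1.3))  # a generic rotation
x = [0.3, -1.2, 0.5]

rotate_then_evaluate = y1(matvec(R, x))
evaluate_then_rotate = matvec(R, y1(x))  # D^1(R) = R in this basis
err = max(abs(a - b) for a, b in zip(rotate_then_evaluate, evaluate_then_rotate))
print(err)
```

For $l \geq 2$ the matrices $D^l(R)$ are $(2l+1) \times (2l+1)$ and no longer equal to $R$, but the same identity holds.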

Think of decomposing a function in a Fourier basis: a weighted combination of frequencies can be used to fit a signal, for example audio. In the same way, we can decompose a function $f$ on the sphere as:

\begin{equation}
    f(\mathbf{x}) = \sum_{l=0}^\infty \sum_{m=-l}^l a_m^l Y_m^l(\mathbf{x}).
\end{equation}

The spherical harmonics are indexed by two integers $l \geq 0$ and $m$, where $-l \leq m \leq l$, and so are their coefficients $a_m^l \in \mathbb{R}$ (we use the real-valued spherical harmonics throughout). Below I have written some code to visualise the individual spherical harmonics, as well as a spherical function with random coefficients. 
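As a small check of the expansion above, we can build a function from known coefficients and recover them by projection: because the harmonics are orthonormal, $a_m^l = \int_{\mathbb{S}^2} f \, Y_m^l \, d\Omega$. The sketch below does this for $l \leq 1$ in plain Python, using the textbook real spherical harmonics (which may differ from e3nn's axis convention) and storing the coefficients in a flat vector at index $l^2 + l + m$. The coefficient values and the quadrature grid are arbitrary choices.

```python
import math


def real_sh_up_to_l1(theta, phi):
    """Textbook real spherical harmonics for l = 0 and l = 1, stored at flat index
    l*l + l + m, i.e. in the order (0, 0), (1, -1), (1, 0), (1, 1)."""
    x = math.sin(theta) * math.cos(phi)
    y = math.sin(theta) * math.sin(phi)
    z = math.cos(theta)
    c0 = math.sqrt(1 / (4 * math.pi))
    c1 = math.sqrt(3 / (4 * math.pi))
    return [c0, c1 * y, c1 * z, c1 * x]


def flat_index(l, m):
    """Position of Y_m^l in a flat coefficient vector (blocks of size 2l+1, m = -l..l)."""
    return l * l + l + m


true_coeffs = [2.0, -0.5, 3.0, 1.0]  # arbitrary a_m^l, in flat order


def f(theta, phi):
    """f = sum over (l, m) of a_m^l * Y_m^l, truncated at l = 1."""
    return sum(a * y for a, y in zip(true_coeffs, real_sh_up_to_l1(theta, phi)))


# Orthonormality gives a_m^l = integral of f * Y_m^l over the sphere.
# Midpoint rule in (theta, phi) with the area element sin(theta) dtheta dphi.
n_theta, n_phi = 100, 200
d_theta, d_phi = math.pi / n_theta, 2 * math.pi / n_phi
recovered = [0.0] * len(true_coeffs)
for i in range(n_theta):
    theta = (i + 0.5) * d_theta
    weight = math.sin(theta) * d_theta * d_phi
    for j in range(n_phi):
        phi = (j + 0.5) * d_phi
        value = f(theta, phi)
        for k, y in enumerate(real_sh_up_to_l1(theta, phi)):
            recovered[k] += value * y * weight

print([round(a, 3) for a in recovered])
print("a_0^1 =", recovered[flat_index(1, 0)])
```

The same flat layout is what the plotting code below relies on when it picks out component $(l, m)$ from the output of `o3.spherical_harmonics`.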

In [3]:

axis = dict(
    showbackground=False,
    showticklabels=False,
    showgrid=False,
    zeroline=False,
    title='',
)

layout = dict(
    showlegend=False,
    scene=dict(
        aspectmode="data",
        xaxis=dict(
            **axis,
        ),
        yaxis=dict(
            **axis,
        ),
        zaxis=dict(
            **axis,
        ),
    ),
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(l=0, r=0, t=0, b=0)
)

def s2_grid(N=100):
    """ Create grid on which we can sample spherical signals """
    betas = torch.linspace(0, math.pi, int(N/2))
    alphas = torch.linspace(0, 2 * math.pi, N)
    beta, alpha = torch.meshgrid(betas, alphas, indexing="ij")
    return o3.angles_to_xyz(alpha, beta)

def sh_trace(sh, cmin, cmax, grid, warp=True, pos=None):

    if warp:
        grid = grid*sh.abs().unsqueeze(-1)

    if pos is not None:
        grid = grid + pos

    # Plot nodes
    x = grid[..., 0]
    y = grid[..., 1]
    z = grid[..., 2]
    trace = go.Surface(x=x, y=y, z=z, surfacecolor=sh, colorscale=vapeplot.palette('vaporwave'), cmin=cmin, cmax=cmax)
    return trace

def plot_all_shs(lmax, grid=s2_grid(), warp=True):
    fig = go.Figure(layout=layout)

    irreps = Irreps.spherical_harmonics(lmax)
    shs = o3.spherical_harmonics(irreps, grid, True)

    cmin = shs.min().item()
    cmax = shs.max().item()

    for l in range(irreps.lmax+1):
        for m in range(-l, l+1):
            i = l**2 + l + m  # flat index of (l, m): blocks of size 2l+1, ordered m = -l..l
            pos = torch.tensor([0, 2*m, -2*l])
            trace = sh_trace(shs[..., i], cmin, cmax, grid=grid, warp=warp, pos=pos)
            fig.add_trace(trace)
    
    fig.show()

def plot_sphere(coefficients, l_max, grid=s2_grid(), warp=True):
    fig = go.Figure(layout=layout)

    irreps = Irreps.spherical_harmonics(l_max)
    shs = o3.spherical_harmonics(irreps, grid, True)

    shs *= coefficients.view(1, 1, -1)
    shs = shs.sum(-1)

    cmin = shs.min().item()
    cmax = shs.max().item()

    trace = sh_trace(shs, cmin, cmax, grid=grid, warp=warp)
    fig.add_trace(trace)
    fig.show()



Since the spherical harmonics are functions on the sphere, we can visualise them by stretching each point of the sphere radially by the absolute value of the function at that point. We use the argument ```warp``` to toggle this. 

In [4]:
# Shows all harmonics up to some order
l_max = 3
plot_all_shs(l_max, warp=False)

In [15]:
# Shows a random spherical function by sampling random coefficients.
l_max = 11
grid = s2_grid(200)
N = (l_max + 1) ** 2  # number of (l, m) pairs with l <= l_max
c = torch.randn(N)
plot_sphere(c, l_max, grid=grid, warp=False)

## Building a Steerable Convolutional Network

In order to build a steerable network, we need to define
1. A linear operation
2. A non-linearity

The first we get from representation theory: it is the [Clebsch-Gordan tensor product](https://docs.e3nn.org/en/stable/api/o3/o3_tp.html). This operation is linear in each of its two inputs, so for a fixed kernel it is linear in the features, and it allows us to combine two spherical harmonics, possibly of different orders $l$. For example, two first-order spherical harmonics can combine to form a zeroth-order harmonic, just like the dot product between two vectors yields a scalar. 
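The dot-product example is the $l=0$ part of a more general decomposition: the product of two $l=1$ objects splits into orders $l = 0 \oplus 1 \oplus 2$, with dimensions $1 + 3 + 5 = 9$ (in e3nn's notation, `1o x 1o -> 0e + 1e + 2e`). For Cartesian vectors these three parts are the dot product, the cross product and the symmetric traceless part of the outer product. The plain-Python sketch below (our own helpers; it uses the Cartesian form of the $l=2$ part rather than e3nn's Clebsch-Gordan coefficients) checks that each part transforms as its order predicts.

```python
import math


def rot_z(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rot_x(t):
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose(a):
    return [[a[j][i] for j in range(3)] for i in range(3)]


def matvec(m, v):
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def decompose(a, b):
    """Split the 9 products a_i * b_j into parts of order l = 0, 1 and 2."""
    outer = [[a[i] * b[j] for j in range(3)] for i in range(3)]
    scalar = sum(a[i] * b[i] for i in range(3))            # l = 0: dot product (1 number)
    vector = [a[1] * b[2] - a[2] * b[1],                    # l = 1: cross product (3 numbers)
              a[2] * b[0] - a[0] * b[2],
              a[0] * b[1] - a[1] * b[0]]
    traceless = [[(outer[i][j] + outer[j][i]) / 2 - (scalar / 3 if i == j else 0.0)
                  for j in range(3)] for i in range(3)]     # l = 2: symmetric traceless (5 numbers)
    return scalar, vector, traceless


a, b = [1.0, 2.0, -0.5], [0.3, -1.0, 2.0]
R = matmul(rot_z(0.9), rot_x(0.4))

s0, v0, t0 = decompose(a, b)
s1, v1, t1 = decompose(matvec(R, a), matvec(R, b))

scalar_err = abs(s1 - s0)                                         # invariant
vector_err = max(abs(p - q) for p, q in zip(v1, matvec(R, v0)))   # rotates like a vector
RtRT = matmul(matmul(R, t0), transpose(R))                        # rotates like a rank-2 tensor
tensor_err = max(abs(t1[i][j] - RtRT[i][j]) for i in range(3) for j in range(3))
print(scalar_err, vector_err, tensor_err)
```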

We will parameterise the weights of this tensor product with a small network that takes the distance between the two points as input. Since distance is invariant under $SE(3)$, weights computed from it do not break equivariance. 
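The sketch below illustrates why conditioning on distance is safe: the distance, and hence anything computed from it, is unchanged when both points are rotated and translated. It expands the distance in a sine-based radial basis in the spirit of DimeNet's Bessel basis and applies a toy linear map in place of the `RadialNet`. The basis is simplified (torch_geometric's `BesselBasisLayer`, which the notebook uses, additionally has learnable frequencies and a smooth envelope that vanishes at the cutoff), and the weight matrix `W` is made up.

```python
import math


def bessel_basis(d, num_basis=10, cutoff=4.0):
    """Radial basis sin(n*pi*d/c) / d for n = 1..num_basis (DimeNet-style, no envelope)."""
    return [math.sin(n * math.pi * d / cutoff) / d for n in range(1, num_basis + 1)]


def radial_weights(d, W):
    """Hypothetical linear 'radial network': one output weight per row of W."""
    basis = bessel_basis(d, num_basis=len(W[0]))
    return [sum(w * e for w, e in zip(row, basis)) for row in W]


def rigid_motion(p, theta, shift):
    """Rotate p about the z-axis by theta, then translate it by shift."""
    c, s = math.cos(theta), math.sin(theta)
    x, y, z = p
    return [c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2]]


W = [[0.1 * (i + 1) * (-1) ** j for j in range(10)] for i in range(4)]  # toy weights
p_i, p_j = [0.2, 1.0, -0.4], [1.1, 0.3, 0.5]

d_before = math.dist(p_i, p_j)
d_after = math.dist(rigid_motion(p_i, 1.2, [3.0, -2.0, 0.5]),
                    rigid_motion(p_j, 1.2, [3.0, -2.0, 0.5]))

w_before = radial_weights(d_before, W)
w_after = radial_weights(d_after, W)
print(max(abs(a - b) for a, b in zip(w_before, w_after)))
```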

The non-linearity for $l=0$ is easy: since spherical harmonics of order $l=0$ are scalars, we are free to use standard non-linearities. For $l>0$ we cannot do this, since standard element-wise non-linearities do not commute with rotations. What we can do is scale all $2l+1$ components $-l \leq m \leq l$ by the same invariant scalar. Taking inspiration from the SiLU non-linearity, $\mathrm{SiLU}(x) = x\,\sigma(x)$, we scale these components by a sigmoidal term that acts as a gate. The gate scalar $s$ is an additional $l=0$ output of the preceding layer, so it is learned by the network.

\begin{equation}
    \mathrm{Gate}(\mathbf{x}, s) = \sigma(s)
    \begin{pmatrix}
        x_{-l} \\ \vdots \\ x_{l}
    \end{pmatrix}
\end{equation}
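A quick numerical check of why the gate is safe while an element-wise non-linearity is not, written in plain Python for an $l=1$ feature in Cartesian form (an illustrative setup; in the code below this role is played by e3nn's `Gate` module, with SiLU on the scalars and a sigmoid on the gates):

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def silu(t):
    return t * sigmoid(t)


def gate(x, s):
    """Scale every component of the l > 0 feature x by the same factor sigmoid(s)."""
    g = sigmoid(s)
    return [g * c for c in x]


def rot_z(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def matvec(m, v):
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


R = rot_z(1.0)
v = [0.8, -0.3, 1.5]  # an l = 1 feature in Cartesian form
s = 0.2               # an l = 0 gate scalar: rotations leave it unchanged

# Gate: rotating before or after gating gives the same result.
gate_err = max(abs(a - b) for a, b in zip(matvec(R, gate(v, s)), gate(matvec(R, v), s)))
# Element-wise SiLU: the two orders disagree, so equivariance is lost.
silu_err = max(abs(a - b) for a, b in zip(matvec(R, [silu(c) for c in v]),
                                          [silu(c) for c in matvec(R, v)]))
print(gate_err, silu_err)
```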

## E3NN
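e3nn is a library that implements all the functionality required to build steerable networks. Our feature vectors will consist of multiple copies of spherical harmonics of different orders. e3nn denotes these in the format ```multiplicity x order parity```, written without spaces. For example, a feature vector ```2x0e + 4x1o + 13x2e``` consists of 2 zeroth-order, 4 first-order and 13 second-order harmonics, where each order-$l$ harmonic has $2l+1$ components. The parities ```e``` and ```o``` stand for "even" and "odd" and denote whether the function keeps or flips its sign under the inversion $\mathbf{x} \mapsto -\mathbf{x}$. A spherical harmonic of order $l$ has parity $(-1)^l$, which is why the edge irreps printed later in this notebook read ```1x0e+1x1o+1x2e+1x3o```.  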
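The notation is easy to evaluate by hand: each term `mul x l p` contributes `mul * (2l + 1)` dimensions. The hypothetical helper below is not e3nn's own parser (e3nn exposes the same number as `Irreps(...).dim`, which the notebook prints for the hidden irreps); it only makes the counting explicit:

```python
import re


def irreps_dim(spec):
    """Dimension of an irreps string such as '2x0e + 4x1o + 13x2e' (hypothetical helper)."""
    total = 0
    for term in spec.split("+"):
        match = re.fullmatch(r"\s*(\d+)x(\d+)([eo])\s*", term)
        if match is None:
            raise ValueError(f"cannot parse irreps term: {term!r}")
        mul, l = int(match.group(1)), int(match.group(2))
        total += mul * (2 * l + 1)  # each order-l irrep has 2l + 1 components
    return total


print(irreps_dim("2x0e + 4x1o + 13x2e"))      # 2*1 + 4*3 + 13*5 = 79
print(irreps_dim("32x0e+32x1o+32x2e+32x3o"))  # 32 * (1 + 3 + 5 + 7) = 512
```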

## Graph Convolution
Our convolutional kernels will have the form 

\begin{equation}
    K(\mathbf{r}_{ij}) = R(|\mathbf{r}_{ij}|)Y_m^l(\hat{\mathbf{r}}_{ij}).
\end{equation}

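Here $R$ is a learned function of the distance (the `RadialNet` below, which outputs the tensor-product weights), and the spherical harmonics $Y_m^l(\hat{\mathbf{r}}_{ij})$ of the normalised relative position are passed in as `rel_pos_sh`. The kernel interacts with the neighbour features through the Clebsch-Gordan tensor product.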
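The number of weights the radial network must output (`self.tp.weight_numel` in the code below) follows from the selection rule of the Clebsch-Gordan product: inputs of orders $l_1, l_2$ with parities $p_1, p_2$ can produce order $l_3$ only if $|l_1 - l_2| \leq l_3 \leq l_1 + l_2$ and $p_3 = p_1 p_2$. In a fully connected tensor product, each allowed path between blocks with multiplicities $m_1, m_2, m_3$ carries $m_1 m_2 m_3$ weights. The plain-Python sketch below (our own helper, assuming this counting rule for e3nn's `FullyConnectedTensorProduct`) reproduces the 7168 weights reported for the first layer later in the notebook:

```python
def parse(spec):
    """Turn '32x0e+32x1o' into [(32, 0, +1), (32, 1, -1)] as (multiplicity, l, parity)."""
    out = []
    for term in spec.split("+"):
        mul, rest = term.strip().split("x")
        out.append((int(mul), int(rest[:-1]), 1 if rest[-1] == "e" else -1))
    return out


def fc_tp_weight_count(in1, in2, out):
    """Count weights of a fully connected tensor product via the selection rule."""
    total = 0
    for m1, l1, p1 in parse(in1):
        for m2, l2, p2 in parse(in2):
            for m3, l3, p3 in parse(out):
                if abs(l1 - l2) <= l3 <= l1 + l2 and p3 == p1 * p2:
                    total += m1 * m2 * m3  # one weight per (u, v, w) channel triple
    return total


n = fc_tp_weight_count("32x0e", "1x0e+1x1o+1x2e+1x3o", "128x0e+32x1o+32x2e+32x3o")
print(n)  # 4096 + 1024 + 1024 + 1024 = 7168
```

With scalar inputs, each edge harmonic of order $l$ can only feed the output block of the same order and parity, which is why exactly four paths survive.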

In [6]:
class Convolution(nn.Module):
    """ SE(3) equivariant convolution, parameterised by a radial network """
    def __init__(self, irreps_in1, irreps_in2, irreps_out):
        super().__init__()
        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2
        self.irreps_out = irreps_out
        self.tp =  o3.FullyConnectedTensorProduct(
            irreps_in1,
            irreps_in2,
            irreps_out,
            irrep_normalization="component",
            path_normalization="element",
            internal_weights=False,
            shared_weights=False
        )

        self.radial_net = RadialNet(self.tp.weight_numel)

    def forward(self, x, rel_pos_sh, distance):
        """
        Features of shape [E, irreps_in1.dim]
        rel_pos_sh of shape [E, irreps_in2.dim]
        distance of shape [E, 1]
        """
        weights = self.radial_net(distance)
        return self.tp(x, rel_pos_sh, weights)

def compute_gate_irreps(irreps_out):
    """Compute irreps_scalars, irreps"""
    irreps_scalars = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l == 0])
    irreps_gated = Irreps([(mul, ir) for mul, ir in irreps_out if ir.l > 0])
    irreps_gates = Irreps([(mul, "0e") for mul, _ in irreps_gated]).simplify()

    return irreps_scalars, irreps_gated, irreps_gates 

class RadialNet(nn.Module):
    def __init__(self, num_weights):
        super().__init__()

        num_basis = 10
        basis = tg.nn.models.dimenet.BesselBasisLayer(num_basis, cutoff=4)

        self.net = nn.Sequential(basis,
                                nn.Linear(num_basis, 16),
                                nn.SiLU(),
                                nn.Linear(16, num_weights))
    def forward(self, dist):
        return self.net(dist.squeeze(-1))
    

class ConvLayerSE3(tg.nn.MessagePassing):
    def __init__(self, irreps_in1, irreps_in2, irreps_out, activation=True):
        super().__init__(aggr="add")

        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2
        self.irreps_out = irreps_out 

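        irreps_scalars, irreps_gated, irreps_gates = compute_gate_irreps(irreps_out)

        if activation:
            # e3nn's Gate expects its input ordered as scalars + gates + gated
            self.gate = Gate(irreps_scalars, [nn.SiLU()], irreps_gates, [nn.Sigmoid()], irreps_gated)
            conv_irreps_out = irreps_scalars + irreps_gates + irreps_gated
        else:
            self.gate = nn.Identity()
            conv_irreps_out = irreps_out
        self.conv = Convolution(irreps_in1, irreps_in2, conv_irreps_out)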

    def forward(self, edge_index, x, rel_pos_sh, dist):
        x = self.propagate(edge_index, x=x, rel_pos_sh=rel_pos_sh, dist=dist)
        x = self.gate(x)
        return x
    
    def message(self, x_j, rel_pos_sh, dist):
        return self.conv(x_j, rel_pos_sh, dist)

class ConvModel(nn.Module):
    def __init__(self, irreps_in, irreps_hidden, irreps_edge, irreps_out, depth, max_z=10):
        super().__init__()

        self.irreps_in = irreps_in
        self.irreps_hidden = irreps_hidden
        self.irreps_edge = irreps_edge
        self.irreps_out = irreps_out

        self.embedder = nn.Embedding(max_z, irreps_in.dim)

        self.layers = nn.ModuleList()
        self.layers.append(ConvLayerSE3(irreps_in, irreps_edge, irreps_hidden))
        for i in range(depth-2):
            self.layers.append(ConvLayerSE3(irreps_hidden, irreps_edge, irreps_hidden))
        self.layers.append(ConvLayerSE3(irreps_hidden, irreps_edge, irreps_out, activation=False))


    def forward(self, graph):
        edge_index = graph.edge_index
        z = graph.z
        pos = graph.pos
        batch = graph.batch

        # Prepare quantities for convolutional layers
        src, tgt = edge_index[0], edge_index[1]
        rel_pos = pos[tgt] - pos[src]
        rel_pos_sh = o3.spherical_harmonics(self.irreps_edge, rel_pos, normalize=True)
        dist = torch.linalg.vector_norm(rel_pos, dim=-1, keepdim=True)

        x = self.embedder(z)
        # Let's go!
        for layer in self.layers:
            x = layer(edge_index, x, rel_pos_sh, dist)

        # Global pooling
        x = tg.nn.global_add_pool(x, batch)
        return x

## Force and Energy prediction on MD17

The MD17 dataset consists of molecular dynamics trajectories of small molecules. At every time step, we have access to the force vector on each atom and the total energy of the molecule. Since the forces form a conservative vector field, they can be written as the negative gradient of the potential energy:

\begin{equation}
    \mathbf{F} = - \nabla V(\mathbf{x}).
\end{equation}
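To see that differentiating an energy gives consistent forces, here is the same recipe as `pred_energy_and_force` (which uses `torch.autograd.grad`) applied in plain Python to a made-up pairwise harmonic potential. The potential, its parameters and the positions are illustrative assumptions, not a model of aspirin; the analytic forces $-\nabla V$ are compared with central finite differences. Because $V$ depends only on distances, the forces also sum to zero.

```python
import math


def energy(pos, k=1.0, r0=1.0):
    """Toy potential V = sum over pairs i < j of k/2 * (|r_i - r_j| - r0)^2."""
    e = 0.0
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            d = math.dist(pos[i], pos[j])
            e += 0.5 * k * (d - r0) ** 2
    return e


def forces_analytic(pos, k=1.0, r0=1.0):
    """F_i = -dV/dr_i = -k (d - r0) (r_i - r_j) / d, summed over partners j."""
    f = [[0.0] * 3 for _ in pos]
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            d = math.dist(pos[i], pos[j])
            coef = -k * (d - r0) / d
            for a in range(3):
                diff = pos[i][a] - pos[j][a]
                f[i][a] += coef * diff  # force on atom i
                f[j][a] -= coef * diff  # equal and opposite force on atom j
    return f


def forces_finite_difference(pos, h=1e-6):
    """Central differences of -V with respect to every coordinate."""
    f = []
    for i in range(len(pos)):
        fi = []
        for a in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[i][a] += h
            minus[i][a] -= h
            fi.append(-(energy(plus) - energy(minus)) / (2 * h))
        f.append(fi)
    return f


pos = [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.2, 0.9, 0.3]]
fa = forces_analytic(pos)
fd = forces_finite_difference(pos)
max_err = max(abs(fa[i][a] - fd[i][a]) for i in range(len(pos)) for a in range(3))
net_force = [sum(fa[i][a] for i in range(len(pos))) for a in range(3)]
print(max_err, net_force)
```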

We predict a scalar energy contribution for each atom and sum these to obtain the total energy of the molecule, which matches the reference energy up to a constant term. The forces are then the negative gradient of this predicted total energy with respect to the atom positions. The loss we optimise is

\begin{equation}
    L = |\hat{E} - E|^2 + \lambda_F \frac{1}{3N} \sum_{i=1}^N \sum_{\alpha=1}^3 \left| F_{i, \alpha} + \frac{\partial \hat{E}}{\partial r_{i, \alpha}} \right|^2
\end{equation}

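Here hats denote predictions and $N$ is the number of atoms. We weight the force term relative to the energy term with the factor $\lambda_F$. Furthermore, we shift the energy to have zero mean and scale both energy and force by the mean magnitude of the force vectors over all atoms in the training set. 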
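To see what the shift and scale do, the sketch below restates the loss and the metric de-normalisation of the `MD17` module on plain Python lists. The numbers are made-up toy values, not MD17 data; as with `F.mse_loss`, the energy term is averaged over molecules and the force term over all force components.

```python
def energy_and_force_loss(e_pred_n, e_true, f_pred_n, f_true, shift, scale, weight):
    """Plain-list mirror of MD17.energy_and_force_loss: predictions are in normalised units."""
    e_term = sum((p - (t - shift) / scale) ** 2 for p, t in zip(e_pred_n, e_true)) / len(e_true)
    f_term = sum((p - t / scale) ** 2 for p, t in zip(f_pred_n, f_true)) / len(f_true)
    return e_term + weight * f_term


def mae(pred, true):
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)


# Toy numbers: two molecules, and the three force components of one atom.
shift, scale, weight = -100.0, 2.0, 3.0
e_true = [-98.0, -103.0]
e_pred_n = [1.0, -1.0]        # normalised targets are [1.0, -1.5]
f_true = [2.0, -4.0, 0.0]
f_pred_n = [1.0, -2.5, 0.5]   # normalised targets are [1.0, -2.0, 0.0]

loss = energy_and_force_loss(e_pred_n, e_true, f_pred_n, f_true, shift, scale, weight)
print(loss)  # 0.125 + 3 * (0.5 / 3) = 0.625, up to rounding

# Metrics are reported in the original units, as in training_step: E = E_n * scale + shift.
e_pred = [p * scale + shift for p in e_pred_n]
mae_original = mae(e_pred, e_true)
mae_normalised = mae(e_pred_n, [(t - shift) / scale for t in e_true])
print(mae_original, scale * mae_normalised)  # both 0.5
```

Because the transform is affine, errors reported in the original units are simply `scale` times the errors in normalised units.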

The following code implements all relevant training details. 

In [7]:
class MD17(pl.LightningModule):
    def __init__(
        self,
        model,
        lr,
        weight=1,
        shift=0,
        scale=1,
    ):
        super().__init__()
        self.model = model
        self.lr = lr

        self.weight = weight
        self.shift = shift
        self.scale = scale

        self.energy_train_metric = torchmetrics.MeanAbsoluteError()
        self.energy_valid_metric = torchmetrics.MeanAbsoluteError()
        self.energy_test_metric = torchmetrics.MeanAbsoluteError()
        self.force_train_metric = torchmetrics.MeanAbsoluteError()
        self.force_valid_metric = torchmetrics.MeanAbsoluteError()
        self.force_test_metric = torchmetrics.MeanAbsoluteError()

    def forward(self, graph):
        energy, force = self.pred_energy_and_force(graph)
        return energy, force

    def pred_energy_and_force(self, graph):
        graph.pos = torch.autograd.Variable(graph.pos, requires_grad=True)
        pred_energy = self.model(graph)

        sign = -1.0
        pred_force = (
            sign
            * torch.autograd.grad(
                pred_energy,
                graph.pos,
                grad_outputs=torch.ones_like(pred_energy),
                create_graph=True,
                retain_graph=True,
            )[0]
        )
        return pred_energy.squeeze(-1), pred_force

    def energy_and_force_loss(self, graph, energy, force):
        loss = F.mse_loss(energy, (graph.energy - self.shift) / self.scale)
        loss = loss + self.weight * F.mse_loss(force, graph.force / self.scale)
        return loss

    def training_step(self, graph):
        energy, force = self(graph)
        loss = self.energy_and_force_loss(graph, energy, force)
        self.energy_train_metric(energy * self.scale + self.shift, graph.energy)
        self.force_train_metric(force * self.scale, graph.force)

        cur_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("lr", cur_lr, prog_bar=True, on_step=True)
        return loss

    def on_train_epoch_end(self):
        self.log("Energy train MAE", self.energy_train_metric, prog_bar=True)
        self.log("Force train MAE", self.force_train_metric, prog_bar=True)

    @torch.inference_mode(False)
    def validation_step(self, graph, batch_idx):
        energy, force = self(graph)
        self.energy_valid_metric(energy * self.scale + self.shift, graph.energy)
        self.force_valid_metric(force * self.scale, graph.force)

    def on_validation_epoch_end(self):
        self.log("Energy valid MAE", self.energy_valid_metric, prog_bar=True)
        self.log("Force valid MAE", self.force_valid_metric, prog_bar=True)

    @torch.inference_mode(False)  # forces need autograd, as in validation_step
    def test_step(self, graph, batch_idx):
        energy, force = self(graph)
        self.energy_test_metric(energy * self.scale + self.shift, graph.energy)
        self.force_test_metric(force * self.scale, graph.force)

    def on_test_epoch_end(self):
        self.log("Energy test MAE", self.energy_test_metric, prog_bar=True)
        self.log("Force test MAE", self.force_test_metric, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        
        num_steps = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler_config]

In [8]:
num_features = 32
l_max = 3
depth = 3

irreps_in = (Irreps("1x0e")*num_features).simplify()
irreps_hidden = (Irreps.spherical_harmonics(l_max)*num_features).sort()[0].simplify()
irreps_edge = Irreps.spherical_harmonics(l_max)
irreps_out = Irreps("1x0e")

print("Input irreps", irreps_in)
print("Hidden irreps", irreps_hidden)
print("Edge irreps", irreps_edge)
print("Output irreps", irreps_out)
print("Dim hidden irreps:", irreps_hidden.dim)

model = ConvModel(irreps_in, irreps_hidden, irreps_edge, irreps_out, depth)
print()
print(model)
print(model.layers[0].conv)

Input irreps 32x0e
Hidden irreps 32x0e+32x1o+32x2e+32x3o
Edge irreps 1x0e+1x1o+1x2e+1x3o
Output irreps 1x0e
Dim hidden irreps: 512




ConvModel(
  (embedder): Embedding(10, 32)
  (layers): ModuleList(
    (0-2): 3 x ConvLayerSE3()
  )
)
Convolution(
  (tp): FullyConnectedTensorProduct(32x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+32x1o+32x2e+32x3o | 7168 paths | 7168 weights)
  (radial_net): RadialNet(
    (net): Sequential(
      (0): BesselBasisLayer(
        (envelope): Envelope()
      )
      (1): Linear(in_features=10, out_features=16, bias=True)
      (2): SiLU()
      (3): Linear(in_features=16, out_features=7168, bias=True)
    )
  )
)


In [9]:
epochs = 100
lr = 1e-3
F_weight = 500
radius = 2
batch_size = 5

In [10]:
# Let's load the datasets. 
class Kcal2meV:
    def __init__(self):
        # Kcal/mol to meV
        self.conversion = 43.3634

    def __call__(self, graph):
        graph.energy = graph.energy * self.conversion
        graph.force = graph.force * self.conversion
        return graph

transform = tg.transforms.Compose([
    tg.transforms.RadiusGraph(radius),
    Kcal2meV(),
])

train_dataset = tg.datasets.MD17("data", name="aspirin CCSD", train=True, transform=transform)
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [950, 50], generator=torch.Generator().manual_seed(42))
test_dataset = tg.datasets.MD17("data", name="aspirin CCSD", train=False, transform=transform)

dataloaders = {
    "train": tg.loader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    "valid": tg.loader.DataLoader(valid_dataset, batch_size=batch_size),
    "test": tg.loader.DataLoader(test_dataset, batch_size=batch_size),
}

Downloading http://quantum-machine.org/gdml/data/npz/aspirin_ccsd.zip
Extracting data/aspirin CCSD/raw/aspirin_ccsd.zip
Processing...
Done!


In [11]:
N_nodes = 0
shift = 0
scale = 0
for graph in train_dataset:
    shift += graph.energy
    scale += torch.linalg.vector_norm(graph.force, dim=1).sum()
    N_nodes += graph.num_nodes

shift = (shift/len(train_dataset)).item()
scale = (scale/N_nodes).item()

print("Shift:", shift)
print("Scale:", scale)

Shift: -17591926.0
Scale: 2018.620361328125


In [12]:
task = MD17(model, lr, F_weight, shift, scale)
trainer = pl.Trainer(max_epochs=epochs)
trainer.fit(task, dataloaders["train"], dataloaders["valid"])

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                | Type              | Params
----------------------------------------------------------
0 | model               | ConvModel         | 734 K 
1 | energy_train_metric | MeanAbsoluteError | 0     
2 | energy_valid_metric | MeanAbsoluteError | 0     
3 | energy_test_metric  | MeanAbsoluteError | 0     
4 | force_train_metric  | MeanAbsoluteError | 0     
5 | force_valid_metric  | MeanAbsoluteError | 0     

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.
