<a href="https://colab.research.google.com/github/engelberger/InfGCN-pytorch/blob/main/inference_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before each cell executes (edits to the project's .py files are picked
# up without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:
# Base runtime dependencies, installed quietly (-q) into the kernel's environment.
# NOTE(review): lz4 is the only package here without a version pin.
%pip install -q lz4
%pip install -q e3nn==0.5.1
# torch is pinned here; compiled PyG extension wheels installed afterwards have to be
# built for this exact torch version.
%pip install -q torch==1.13.1
%pip install -q torchvision==0.14.1

In [3]:
%pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+121.html

%pip install torch-cluster==1.6.3
%pip install torch-geometric==2.2.0

Looking in links: https://data.pyg.org/whl/torch-2.3.0+121.html
Collecting torch-scatter
  Using cached torch_scatter-2.1.2.tar.gz (108 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py bdist_wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for torch-scatter (setup.py) ... [?25lerror
[31m  ERROR: Failed building wheel for torch-scatter[0m[31m
[0m[?25h  Running setup.py clean for torch-scatter
Failed to build torch-scatter
[31mERROR: Could not build wheels for torch-scatter, which is required to install pyproject.toml-based projects[0m[31m
[0mCollecting torch-cluster==1.6.3
  Using cached torch_cluster-1.6.3.tar.gz (54 kB)
  Preparing metadata (setu

In [4]:
!git clone https://github.com/engelberger/InfGCN-pytorch.git
%cd InfGCN-pytorch
%pip install -q -r requirements.txt

fatal: destination path 'InfGCN-pytorch' already exists and is not an empty directory.
/content/InfGCN-pytorch
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.7/33.7 MB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.8/11.8 MB[0m [31m59.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.5/78.5 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

# Inference with InfGCN

By Chaoran Cheng, Oct 1, 2023


In [5]:
import os

import torch
import lz4.frame
from tqdm import tqdm
import plotly.graph_objects as go
from IPython.display import Image, display

from datasets import DensityDataset
from models import get_model
from utils import load_config

# Expose only the first GPU to CUDA. The variable has to be set before CUDA is
# initialised in this process for it to have any effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Use the GPU when one is attached; otherwise fall back to the CPU so the notebook
# still runs (slowly) instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ModuleNotFoundError: No module named 'torch_sparse'

In [None]:
# Set to True to generate static figures
# Set to False to generate interactive figures
# WARNING: Interactive figures significantly increase the notebook size
# NOTE(review): the static path calls `fig.to_image`, which presumably needs plotly's
# image-export backend (kaleido) to be installed — confirm it is in requirements.txt.
static_fig = True

In [None]:
def get_pretrained_model(cfg_path, model_path):
    """Build the model described by a config file and load checkpoint weights into it.

    Args:
        cfg_path: Path handed to `load_config`; the resulting config's `.model`
            attribute is passed to `get_model`.
        model_path: Checkpoint file read with `torch.load`; the state dict is taken
            from its 'model' entry.

    Returns:
        The model, moved to the notebook-level `device`, with the weights loaded.
    """
    print(f'Loading config from {cfg_path}')
    config = load_config(cfg_path)
    net = get_model(config.model)
    net = net.to(device)

    print(f'Loading model from {model_path}')
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)['model']
    net.load_state_dict(state_dict)
    return net

In [None]:
def inference_model(model, g, density, grid_coord, infos, grid_batch_size=8196):
    """Evaluate `model` on one sample without gradients and score it against `density`.

    Args:
        model: Callable module invoked as `model(g.x, g.pos, grid, g.batch, infos)`;
            it is switched to eval mode here (and left in eval mode).
        g: Object providing `.x`, `.pos` and `.batch`.
        density: Reference values compared element-wise with the predictions.
        grid_coord: Query coordinates with a leading axis of size 1; split along
            dim 1 when batching is enabled.
        infos: Passed through to the model unchanged.
        grid_batch_size: Number of grid points per model call, or None to evaluate
            the whole grid in a single call.

    Returns:
        (preds, loss, mae): the predictions with the leading axis squeezed out, the
        summed squared error, and the summed absolute error divided by `density.sum()`.
    """
    model.eval()
    with torch.no_grad():
        if grid_batch_size is not None:
            # Evaluate the grid chunk by chunk, then stitch the pieces back together.
            chunks = grid_coord.split(grid_batch_size, dim=1)
            pieces = [
                model(g.x, g.pos, chunk.contiguous(), g.batch, infos).squeeze(0)
                for chunk in tqdm(chunks)
            ]
            preds = torch.cat(pieces, dim=0)
        else:
            preds = model(g.x, g.pos, grid_coord, g.batch, infos).squeeze(0)

        abs_err = (preds - density).abs()
        loss = abs_err.pow(2).sum()
        mae = abs_err.sum() / density.sum()
    return preds, loss, mae

In [None]:
def draw_volume(grid, density, atom_type, atom_coord, isomin=0.05, isomax=None, surface_count=5, title=None):
    """Build a 3D figure showing a scalar field as translucent isosurfaces plus the atoms.

    Args:
        grid: Array whose last axis holds the x, y, z coordinate of every sample point.
        density: Field values handed to the volume trace as `value`.
        atom_type: Per-atom numbers used to colour the markers (colour range 0..4).
        atom_coord: Per-atom coordinates; columns 0, 1, 2 are used as x, y, z.
        isomin: Lower iso value passed to `go.Volume`.
        isomax: Upper iso value passed to `go.Volume`.
        surface_count: Number of isosurfaces passed to `go.Volume`.
        title: Optional title text, placed at x=0.5, y=0.3 when given.

    Returns:
        The assembled `go.Figure`; the caller decides how to display it.
    """
    volume_trace = go.Volume(
        x=grid[..., 0],
        y=grid[..., 1],
        z=grid[..., 2],
        value=density,
        isomin=isomin,
        isomax=isomax,
        opacity=0.1,  # kept low so inner surfaces remain visible through outer ones
        surface_count=surface_count,  # more surfaces give a smoother volume rendering
        caps={'x_show': False, 'y_show': False, 'z_show': False},
    )
    atom_trace = go.Scatter3d(
        x=atom_coord[:, 0],
        y=atom_coord[:, 1],
        z=atom_coord[:, 2],
        mode='markers',
        marker={
            'size': 10,
            'color': atom_type,
            'cmin': 0,
            'cmax': 4,
            'colorscale': ['grey', 'white', 'red', 'blue', 'green'],
            'opacity': 0.6,
        },
    )

    # One shared spec hides grid, background, zero line and ticks on all three axes.
    hidden_axis = {
        'showgrid': False,
        'showbackground': False,
        'zeroline': False,
        'visible': False,
    }
    if title is None:
        title_spec = None
    else:
        title_spec = {
            'text': title,
            'x': 0.5,
            'y': 0.3,
            'xanchor': 'center',
            'yanchor': 'bottom',
        }

    fig = go.Figure()
    fig.add_trace(volume_trace)
    fig.add_trace(atom_trace)
    fig.update_layout(
        autosize=False,
        width=800,
        height=800,
        showlegend=False,
        scene={'xaxis': hidden_axis, 'yaxis': hidden_axis, 'zaxis': hidden_axis},
        title=title_spec,
        title_font_family='Times New Roman',
    )
    return fig

In [None]:
# Load the dataset
# If you want to run the pretrained model, you can make a dummy data split file like
# {"train": [], "val": [], "test": []}
# NOTE(review): the positional arguments are presumably (data root, split name,
# split file, atom-info file, file extension, compression) — confirm against the
# `DensityDataset` signature.
dataset = DensityDataset('data/QM9', 'test', 'data_split.json', './atom_info/qm9.json', 'CHGCAR', 'lz4')

In [None]:
# Choose which molecule to load by its numeric file id.
file_id = 24492  # indole
# file_id = 114514  # nonane
# file_id = 214  # benzene
# file_id = 2  # ammonia
with lz4.frame.open(f'data/QM9/{file_id:06d}.CHGCAR.lz4') as f:
    g, density, grid_coord, info = dataset.read_chgcar(f)

# Single sample: every entry gets batch index 0.
g.batch = torch.zeros_like(g.x)
# Move the graph, the reference density and the grid to the compute device.
g, density, grid_coord = (item.to(device) for item in (g, density, grid_coord))

In [None]:
# Ground-truth electron density, as read from the CHGCAR file
fig = draw_volume(
    grid=grid_coord.detach().cpu().numpy(),
    density=density.detach().cpu().numpy(),
    atom_type=g.x.detach().cpu().numpy(),
    atom_coord=g.pos.detach().cpu().numpy(),
    isomin=0.05,
    isomax=3.5,
    surface_count=5,
)
if not static_fig:
    fig.show()
else:
    # Static PNG avoids the notebook-size cost of an interactive figure
    img_bytes = fig.to_image(format="png", scale=2)
    display(Image(img_bytes))

In [None]:
# Load the pretrained model
# There might be a warning from jitting, which arises inside the `e3nn` package.
# You can safely ignore it.
# NOTE(review): the checkpoint path is hard-coded; `get_pretrained_model` passes it to
# `torch.load`, so 'logs/train/40000.pt' must exist relative to the working directory.
model = get_pretrained_model('configs/qm9.yml', 'logs/train/40000.pt')
print('Complete!')

In [None]:
# Run the model over the grid in chunks of `grid_batch_size` points.
# The first call can be slow (jitting inside `e3nn`); subsequent calls should be fast.
grid_batch_size = 4096
preds, loss, mae = inference_model(
    model,
    g,
    density,
    grid_coord.unsqueeze(0),  # add the leading axis of size 1
    [info],
    grid_batch_size=grid_batch_size,
)
print(f'Loss: {loss.item():.6f}, MAE: {mae.item():.6f}')

In [None]:
# Difference between reference and predicted density (reference minus prediction)
fig = draw_volume(
    grid=grid_coord.detach().cpu().numpy(),
    density=(density - preds).detach().cpu().numpy(),
    atom_type=g.x.detach().cpu().numpy(),
    atom_coord=g.pos.detach().cpu().numpy(),
    isomin=-0.06,
    isomax=0.06,
    surface_count=4,
)
if not static_fig:
    fig.show()
else:
    # Static PNG avoids the notebook-size cost of an interactive figure
    img_bytes = fig.to_image(format="png", scale=2)
    display(Image(img_bytes))

In [None]:
# Electron density predicted by the model, drawn with the same iso range as the reference
fig = draw_volume(
    grid=grid_coord.detach().cpu().numpy(),
    density=preds.detach().cpu().numpy(),
    atom_type=g.x.detach().cpu().numpy(),
    atom_coord=g.pos.detach().cpu().numpy(),
    isomin=0.05,
    isomax=3.5,
    surface_count=5,
)
if not static_fig:
    fig.show()
else:
    # Static PNG avoids the notebook-size cost of an interactive figure
    img_bytes = fig.to_image(format="png", scale=2)
    display(Image(img_bytes))