# Example starting from building blocks

In [27]:
import os
import gzip
import pickle
import torch
import numpy as np
from ase import Atoms
from ase.io import read, write
from ase.data import atomic_numbers
from nglview import show_ase

In [28]:
# Work from the parent of the starting working directory so the relative paths
# used further down ('example/*.xyz', the checkpoint under 'logs/') resolve
# from there.
# The starting directory is captured only on the first execution of this cell,
# so running the cell again re-enters the same directory instead of climbing
# one more level each time (a bare os.chdir('../') is not idempotent).
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))
print(f"Current working directory: {os.getcwd()}")

Current working directory: /home/nayoung/MOFFlow


## Load building block files
- Load each building block file as an `ase.Atoms` object

In [29]:
# Import files
import glob

# glob returns paths in arbitrary, filesystem-dependent order. Sort them so the
# building-block order (and every index used later, e.g. bb_atoms[1]) is the
# same on every machine and every run.
bb_files = sorted(glob.glob('example/*.xyz'))

# Load building block files as ase.Atoms objects, one per file, in path order
bb_atoms = [read(path) for path in bb_files]

In [30]:
# Interactive view of the building block stored at index 1
selected_block = bb_atoms[1]
show_ase(selected_block)

NGLWidget()

## Prepare `data`
- Goal: align each building block with its PCA axes

In [31]:
def get_equiv_vec(cart_coords, atom_types):
    """Return a non-zero reference vector derived from the atom positions.

    The primary choice is the offset between the weighted centroid (weights
    are ``atom_types`` normalised to sum to one) and the plain centroid.
    When that offset is numerically zero, the atoms are scanned in order of
    increasing distance from the origin and the first position that is not
    numerically zero is used instead.

    Args:
        cart_coords: array of shape (n_atoms, 3) with Cartesian positions.
        atom_types: array of shape (n_atoms,) with per-atom weights (the
            caller in this notebook passes atomic numbers).

    Returns:
        A length-3 array. In the fallback case this is a row of
        ``cart_coords`` itself, not a copy.

    Raises:
        AssertionError: if no non-zero vector could be found.
    """
    normalised = atom_types / atom_types.sum()
    offset = (
        np.sum(cart_coords * normalised[:, None], axis=0)
        - np.mean(cart_coords, axis=0)
    )

    if np.allclose(offset, 0):
        # Both centroids coincide: fall back to the atom closest to the
        # origin whose position is not (numerically) the origin itself.
        by_distance = np.argsort(np.linalg.norm(cart_coords, axis=1))
        for atom_idx in by_distance:
            offset = cart_coords[atom_idx]
            if not np.allclose(offset, 0):
                break

    assert not np.allclose(offset, 0), "Equivariant vector is zero"
    return offset

def get_pca_axes(data):
    """Compute principal axes of ``data`` from its covariance matrix.

    Args:
        data: array whose rows are observations and whose columns are
            variables (``np.cov`` is called with ``rowvar=False``).

    Returns:
        ``(eigenvalues, eigenvectors)`` sorted by decreasing eigenvalue;
        column ``k`` of ``eigenvectors`` belongs to ``eigenvalues[k]``.
        When ``np.cov`` collapses to a scalar (e.g. a single-column input),
        ``(np.zeros(3), np.eye(3))`` is returned instead.
    """
    centered = data - np.mean(data, axis=0)
    cov = np.cov(centered, rowvar=False)

    # A 0-d covariance cannot be decomposed; hand back a neutral frame.
    if cov.ndim == 0:
        return np.zeros(3), np.eye(3)

    # eigh returns ascending eigenvalues; reverse to get descending order.
    evals, evecs = np.linalg.eigh(cov)
    descending = np.argsort(evals)[::-1]
    return evals[descending], evecs[:, descending]

def get_equivariant_axes(cart_coords, atom_types):
    """Build a right-handed 3x3 frame from the PCA axes of ``cart_coords``.

    Each PCA axis is sign-flipped when the reference vector from
    ``get_equiv_vec`` has a negative component along it. The third column
    is then replaced by the cross product of the first two, so the result
    is right-handed.

    Args:
        cart_coords: array of shape (n_atoms, 3).
        atom_types: array of shape (n_atoms,), forwarded to ``get_equiv_vec``.

    Returns:
        A (3, 3) array whose columns are the axes; the identity matrix when
        there is only one atom.
    """
    # A single point has no orientation: use the identity frame.
    if cart_coords.shape[0] == 1:
        return np.eye(3)

    reference = get_equiv_vec(cart_coords, atom_types)
    _, pca_axes = get_pca_axes(cart_coords)

    # Component of the reference vector along each axis (one per column).
    points_backwards = (reference @ pca_axes) < 0
    oriented = np.where(points_backwards[None], -pca_axes, pca_axes)

    first = oriented[:, 0]
    second = oriented[:, 1]
    third = np.cross(first, second)
    return np.stack([first, second, third], axis=1)

def get_local_coords(bb_atoms):
    """Express every building block in its own centered, axis-aligned frame.

    For each block the positions are shifted so their mean is at the origin
    and then multiplied by the frame returned by ``get_equivariant_axes``.

    Args:
        bb_atoms: iterable of objects providing ``get_positions()`` and
            ``get_chemical_symbols()`` (``ase.Atoms`` in this notebook).

    Returns:
        A list with one (n_atoms, 3) coordinate array per building block,
        in input order.
    """
    def _to_local_frame(bb):
        # Positions and atomic numbers of this block.
        positions = bb.get_positions()
        numbers = np.array([atomic_numbers[sym] for sym in bb.get_chemical_symbols()])
        # Center on the mean position, then rotate into the block's frame.
        shifted = positions - positions.mean(axis=0)
        frame = get_equivariant_axes(shifted, numbers)
        return shifted @ frame

    return [_to_local_frame(bb) for bb in bb_atoms]

In [32]:
# Centered, axis-aligned coordinates: one array per building block
local_coords_list = get_local_coords(bb_atoms=bb_atoms)

In [33]:
# Rebuild every building block as an ase.Atoms object placed at its aligned
# local coordinates (same chemical symbols, no periodic boundary conditions)
local_atoms = [
    Atoms(
        symbols=bb.get_chemical_symbols(),
        positions=aligned_coords,
        pbc=False,
    )
    for bb, aligned_coords in zip(bb_atoms, local_coords_list)
]

# Show the aligned version of the building block at index 1
show_ase(local_atoms[1])

NGLWidget()

In [34]:
# Prepare data
from torch_geometric.data import Data, Batch

# Number of atoms in each building block (one int32 entry per block)
atoms_per_bb = [len(bb_atom) for bb_atom in bb_atoms]
bb_num_vec = torch.tensor(atoms_per_bb).int()

# Atomic numbers of all atoms, building blocks concatenated in order (int32)
numbers_per_bb = [torch.tensor(bb_atom.get_atomic_numbers()) for bb_atom in bb_atoms]
atom_types = torch.cat(numbers_per_bb).int()

# Aligned local coordinates, concatenated in the same order (float32)
coords_per_bb = [torch.tensor(block_coords) for block_coords in local_coords_list]
local_coords = torch.cat(coords_per_bb).float()

# Create data, batch object
n_building_blocks = bb_num_vec.shape[0]
n_atoms_total = local_coords.shape[0]
data = Data(
    num_nodes=n_building_blocks,
    num_atoms=n_atoms_total,
    num_bbs=n_building_blocks,
    local_coords=local_coords,
    atom_types=atom_types,
    bb_num_vec=bb_num_vec,
)
batch = Batch.from_data_list([data])

## Load model

In [33]:
from models.flow_module import FlowModule
from hydra import initialize, compose
from omegaconf import OmegaConf

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [35]:
# Compose the inference configuration with Hydra
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name='inference.yaml')

# Checkpoint location and target device (GPU when available, otherwise CPU)
ckpt_path = "logs/mof-csp/_published_batch_ver/ckpt/last.ckpt" # TODO: change to your own path
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Restore the module from the checkpoint using the composed config
flow_module = FlowModule.load_from_checkpoint(checkpoint_path=ckpt_path, cfg=cfg)

# Switch to evaluation mode, then move to the device. eval() and to() both
# return the module, so it is still displayed as the cell output.
flow_module.eval().to(device)

FlowModule(
  (model): FlowModel(
    (node_feature_net): NodeFeatureNet(
      (bb_embedder): Linear(in_features=64, out_features=128, bias=False)
      (linear): Linear(in_features=384, out_features=256, bias=True)
    )
    (edge_feature_net): EdgeFeatureNet(
      (linear_s_p): Linear(in_features=256, out_features=64, bias=True)
      (edge_embedder): Sequential(
        (0): Linear(in_features=172, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): ReLU()
        (4): Linear(in_features=128, out_features=128, bias=True)
        (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
    )
    (bb_embedder): BuildingBlockEmbedder(
      (atom_type_embedder): Embedding(100, 64)
      (edge_dist_embedder): GaussianSmearing()
      (egnn_layers): ModuleList(
        (0-3): 4 x E_GCL(
          (edge_mlp): Sequential(
            (0): Linear(in_features=193, out_features=64, bias=True)
            (

## Sample

In [36]:
from data.interpolant import Interpolant

# Build the interpolant from its section of the config, then point it at the
# same device the model was moved to
interpolant_cfg = cfg.interpolant
interpolant = Interpolant(interpolant_cfg)
interpolant.set_device(device)

In [37]:
# Sample
# Wrap the single example in a fresh batch and move it to the target device
batch = Batch.from_data_list([data]).to(device)

# Inputs for the sampler, all taken from the batch
sample_inputs = dict(
    num_batch=batch.num_graphs,
    num_bbs=batch.num_bbs,
    model=flow_module.model,
    atom_types=batch.atom_types,
    local_coords=batch.local_coords,
    batch_vec=batch.batch,
    bb_num_vec=batch.bb_num_vec,
)

# Only the first returned value (the trajectory) is used below
mof_traj, _ = interpolant.sample(**sample_inputs)

In [38]:
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from ase.build import make_supercell

# Take the last entry of the sampled trajectory
coords, lattice = mof_traj[-1]
lattice = lattice.squeeze().detach().cpu().numpy()

# Make ase atoms (through an intermediate pymatgen Structure)
# NOTE(review): `lattice` is unpacked positionally into Lattice.from_parameters,
# presumably as (a, b, c, alpha, beta, gamma) - confirm against the sampler output.
species_np = atom_types.detach().cpu().numpy()
coords_np = coords.detach().cpu().numpy()
mof_structure = Structure(
    lattice=Lattice.from_parameters(*lattice),
    species=species_np,
    coords=coords_np,
    coords_are_cartesian=True,
)
mof = AseAtomsAdaptor.get_atoms(mof_structure)

# Supercell: repeat the cell twice along each lattice direction
multiplier = 2 * np.eye(3)
mof_supercell = make_supercell(mof, multiplier)

In [39]:
# Visualize prediction
# Open a viewer on the predicted structure `mof` and call add_unitcell() on it.
viewer = show_ase(mof)
viewer.add_unitcell()
# Bare last expression: the viewer is displayed as the cell output.
viewer

NGLWidget()

In [40]:
# Visualize supercell
# Same view as for the prediction, but on `mof_supercell`.
viewer = show_ase(mof_supercell)
viewer.add_unitcell()
# Bare last expression: the viewer is displayed as the cell output.
viewer

NGLWidget()