In [1]:
# env: zeo_generative_equivariant, somehow zeo_diffusion doesn't work (stuck at parsing MWF cif file)
import os
import numpy as np
import requests
import tqdm as tqdm

from pymatgen.io.cif import CifParser
from pymatgen.core.lattice import Lattice

import ase
import ase.neighborlist

import torch
import torch_scatter
import torch_geometric
import torch_geometric.data
from torch_geometric.utils import to_dense_adj

import matplotlib.pyplot as plt
from typing import Dict, Union

default_dtype = torch.float32
torch.set_default_dtype(default_dtype)

import pickle

from iza_codes import codes # 3-lettter zeolite IZA codes
from get_dummy_graph import get_dummy_graph_for_dense_amor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Download cifs from IZA database
for code in tqdm.tqdm(codes):
    if not os.path.exists('cifs/{}.cif'.format(code)):
        url = 'https://america.iza-structure.org/IZA-SC/cif/{}.cif'.format(code)
        r = requests.get(url, allow_redirects=True)
        open('cifs/{}.cif'.format(code), 'wb').write(r.content)

100%|██████████| 258/258 [00:00<00:00, 348175.81it/s]


In [3]:
# Convert to each cif to ase Atoms object
ase_list = []

for code in tqdm.tqdm(codes):
    # print('zeo:', code)
    zeo = CifParser(f'cifs/{code}.cif')
    # print('cif parsed')
    zeo = zeo.get_structures()[0] # slow step
    # print('structure parsed')
    lattice = zeo.lattice.matrix
    # print('lattice done')
    coords  = zeo.cart_coords
    # print('coords done')
    atom_types = [site.specie.name for site in zeo.sites]
    # print('atom_types done')
    zeo_ase = ase.Atoms(symbols=atom_types, positions=coords, cell=lattice, pbc=True)
    # print('ase done')
    ase_list.append(zeo_ase)
    
len(ase_list)

100%|██████████| 258/258 [01:19<00:00,  3.25it/s]


258

In [4]:
# Create dataset with each code having x, edge_index, edge_vec and code
n_datapoints = len(ase_list)

radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
type_encoding = {'Si': 0, 'O': 1}
type_onehot = torch.eye(len(type_encoding))

dataset = []

for crystal, code in zip(ase_list, codes):
    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift indicates whether the neighbors are in different images / copies of the unit cell
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)
    
    pos = torch.tensor(crystal.get_positions()).float()
    lattice = torch.tensor(crystal.cell.array).unsqueeze(0).float()
    x          = type_onehot[[type_encoding[atom] for atom in crystal.symbols]]
    edge_index = torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0)
    edge_shift = torch.tensor(edge_shift, dtype=default_dtype).float()
    edge_vec = (pos[edge_dst] - pos[edge_src] # Relative distance vectors of edges with periodic boundaries
            + torch.einsum('ni,nij->nj', edge_shift, lattice))
    
    data = torch_geometric.data.Data(
        x          = x, 
        edge_index = edge_index,
        edge_vec   = edge_vec,
        code       = code, # IZA code
    )

    dataset.append(data)

In [5]:
# Create dict mappping IZA code to graph
zeo2graph = {}
for graph in dataset:
    zeo2graph[graph.code] = graph

In [6]:
# Add Dense/Amorphous to dict
code = 'Dense/Amorphous'
data = get_dummy_graph_for_dense_amor(code=code)
zeo2graph[code] = data
data

Data(x=[2, 2], edge_index=[2, 2], edge_vec=[2, 3], code='Dense/Amorphous')

In [7]:
# pickle save zeo2graph
with open('zeo2graph.pkl', 'wb') as f:
    pickle.dump(zeo2graph, f)
zeo2graph

{'ABW': Data(x=[12, 2], edge_index=[2, 108], edge_vec=[108, 3], code='ABW'),
 'ACO': Data(x=[24, 2], edge_index=[2, 216], edge_vec=[216, 3], code='ACO'),
 'AEI': Data(x=[72, 2], edge_index=[2, 660], edge_vec=[660, 3], code='AEI'),
 'AEL': Data(x=[60, 2], edge_index=[2, 588], edge_vec=[588, 3], code='AEL'),
 'AEN': Data(x=[72, 2], edge_index=[2, 776], edge_vec=[776, 3], code='AEN'),
 'AET': Data(x=[108, 2], edge_index=[2, 1100], edge_vec=[1100, 3], code='AET'),
 'AFG': Data(x=[144, 2], edge_index=[2, 1320], edge_vec=[1320, 3], code='AFG'),
 'AFI': Data(x=[72, 2], edge_index=[2, 648], edge_vec=[648, 3], code='AFI'),
 'AFN': Data(x=[48, 2], edge_index=[2, 476], edge_vec=[476, 3], code='AFN'),
 'AFO': Data(x=[60, 2], edge_index=[2, 552], edge_vec=[552, 3], code='AFO'),
 'AFR': Data(x=[96, 2], edge_index=[2, 892], edge_vec=[892, 3], code='AFR'),
 'AFS': Data(x=[168, 2], edge_index=[2, 1584], edge_vec=[1584, 3], code='AFS'),
 'AFT': Data(x=[216, 2], edge_index=[2, 1980], edge_vec=[1980, 3], 