In [1]:
import numpy as np
import pandas as pd
from pymatgen.core.structure import Structure
from pymatgen.core.composition import Composition
from matminer.featurizers.composition import ElementProperty, Stoichiometry 
from matminer.featurizers.composition import ValenceOrbital, IonProperty, AtomicOrbitals
from matminer.featurizers.base import MultipleFeaturizer

from typing import Dict, Tuple

import torch
from torch_geometric.data import Data
import json
import warnings



In [2]:
def load_atom_features(atom_init_path: str) -> Dict:
    """Read an atomic-embedding JSON file and return its parsed content.

    Args:
        atom_init_path: Path to the JSON file holding the per-element embeddings
            (traditionally keyed by atomic number; JSON object keys are parsed
            as strings).

    Returns:
        The object decoded by ``json.load`` (expected to be a dict).
    """
    with open(atom_init_path, 'r') as handle:
        return json.load(handle)

In [4]:
def build_pyg_cgcnn_graph_from_structure(structure: Structure, 
                                         atom_features_dict: Dict, 
                                         radius: float=10.0, 
                                         max_neighbors: int=12) -> Data:
    """Convert a pymatgen Structure into a PyTorch Geometric ``Data`` graph.

    Every site becomes a node whose feature vector is looked up by atomic
    number in ``atom_features_dict``. For every site, the ``max_neighbors``
    closest neighbors within ``radius`` become directed edges ``i -> j`` whose
    single attribute is the neighbor distance.

    Args:
        structure: Crystal structure to convert.
        atom_features_dict: Mapping from atomic number to a feature vector.
            Keys are tried as strings first (as produced by loading a JSON
            file) and then as plain integers.
        radius: Cutoff passed to ``structure.get_all_neighbors``.
        max_neighbors: Maximum number of (closest) neighbors kept per site.

    Returns:
        ``Data`` with ``x`` (one feature row per site), ``edge_index`` of shape
        ``(2, E)`` and ``edge_attr`` of shape ``(E, 1)`` holding distances.

    Raises:
        ValueError: If an element of the structure has no entry in
            ``atom_features_dict``.

    Warns:
        UserWarning: If one or more sites end up with no neighbors.
    """
    # Node features: one embedding per site, looked up by atomic number.
    atomic_features = []
    for site in structure:
        number = site.specie.number
        feature = atom_features_dict.get(str(number))
        if feature is None:
            # Dicts built in memory (rather than loaded from JSON) may use int keys.
            feature = atom_features_dict.get(number)
        if feature is None:
            raise ValueError(f"Atomic feature not found for element: {number}")
        atomic_features.append(feature)

    x = torch.tensor(atomic_features, dtype=torch.float32)

    # Edge features: for each site keep only the closest `max_neighbors` neighbors.
    edge_index = []
    edge_attr = []
    disconnected_atoms = []

    all_neighbors = structure.get_all_neighbors(radius, include_index=True)
    for i, neighbors in enumerate(all_neighbors):
        # Neighbor entries are indexed as (site, distance, site index, ...).
        closest = sorted(neighbors, key=lambda nbr: nbr[1])[:max_neighbors]
        if len(closest) == 0:
            disconnected_atoms.append(i)
        for nbr in closest:
            edge_index.append([i, nbr[2]])
            edge_attr.append([nbr[1]])

    if disconnected_atoms:
        warnings.warn(
            f"{len(disconnected_atoms)} atoms had no neighbors within radius {radius}. "
            f"Disconnected atom indices: {disconnected_atoms}"
        )

    # Convert to tensors; keep well-formed empty shapes when there are no edges.
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float32)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float32)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [34]:
# Load the example crystal structure from a CIF file.
# NOTE(review): absolute, user-specific path -- consider a configurable data directory.
struct = Structure.from_file(
    '/Users/elena.patyukova/Documents/cif_files_ICSD/cifs_binary_Mg/YourCustomFileName_CollCode258540.cif'
)

In [35]:
# Load the per-element embedding table used for node features.
# NOTE(review): absolute, user-specific path -- consider a configurable data directory.
feat_dict = load_atom_features(
    '/Users/elena.patyukova/Documents/github/k-points/src/CGCNN+ALIGNN_PyG/embeddings/atom_init_original.json'
)

In [36]:
# Build the radius-cutoff graph with the default radius and max_neighbors.
data = build_pyg_cgcnn_graph_from_structure(
    structure=struct,
    atom_features_dict=feat_dict,
)

In [37]:
# Display the graph object; its repr summarises the shapes of x, edge_index and edge_attr.
data

Data(x=[12, 92], edge_index=[2, 144], edge_attr=[144, 1])

In [38]:
from pymatgen.analysis.local_env import VoronoiNN

In [39]:
from pymatgen.analysis.local_env import CrystalNN
# CrystalNN near-neighbor finder with default settings; queried per site in the cells below.
cnn = CrystalNN()
# Example (author's note: gives bonded atoms only):
#   neighbors = cnn.get_nn_info(struct, 0)


In [40]:
# Number of sites in the loaded structure.
len(struct)

12

In [42]:
# Show the structure summary (lattice parameters and the list of periodic sites).
struct

Structure Summary
Lattice
    abc : 6.76 6.76 6.76
 angles : 90.0 90.0 90.0
 volume : 308.91577599999994
      A : np.float64(6.76) np.float64(0.0) np.float64(4.1393061811180535e-16)
      B : np.float64(-4.1393061811180535e-16) np.float64(6.76) np.float64(4.1393061811180535e-16)
      C : np.float64(0.0) np.float64(0.0) np.float64(6.76)
    pbc : True True True
PeriodicSite: Mg1 (Mg0+) (1.69, 1.69, 5.07) [0.25, 0.25, 0.75]
PeriodicSite: Mg1 (Mg0+) (1.69, 5.07, 1.69) [0.25, 0.75, 0.25]
PeriodicSite: Mg1 (Mg0+) (5.07, 1.69, 1.69) [0.75, 0.25, 0.25]
PeriodicSite: Mg1 (Mg0+) (5.07, 5.07, 5.07) [0.75, 0.75, 0.75]
PeriodicSite: Mg1 (Mg0+) (5.07, 5.07, 1.69) [0.75, 0.75, 0.25]
PeriodicSite: Mg1 (Mg0+) (5.07, 1.69, 5.07) [0.75, 0.25, 0.75]
PeriodicSite: Mg1 (Mg0+) (1.69, 5.07, 5.07) [0.25, 0.75, 0.75]
PeriodicSite: Mg1 (Mg0+) (1.69, 1.69, 1.69) [0.25, 0.25, 0.25]
PeriodicSite: Sn1 (Sn0+) (0.0, 0.0, 0.0) [0.0, 0.0, 0.0]
PeriodicSite: Sn1 (Sn0+) (-2.07e-16, 3.38, 3.38) [0.0, 0.5, 0.5]
PeriodicS

In [43]:
# CrystalNN neighbors of the site at index 2, followed by how many were found.
neighbors = cnn.get_nn_info(struct, 2)
len(neighbors)

4

In [44]:
# Inspect the first neighbor entry: a dict with 'site', 'image', 'weight' and 'site_index' keys.
neighbors[0]

{'site': PeriodicNeighbor: Sn1 (Sn0+) (3.38, 0.0, 3.38) [0.5, 0.0, 0.5],
 'image': array([0., 0., 0.]),
 'weight': 1,
 'site_index': np.int64(10)}

In [100]:
from pymatgen.analysis.local_env import CrystalNN, IsayevNN
from pymatgen.core import Structure
import numpy as np

# Near-neighbor finders: CrystalNN is used to build the edges below; IsayevNN is
# kept available for side-by-side comparison in later cells.
cnn = CrystalNN()
cnn1 = IsayevNN()
# NOTE(review): absolute, user-specific path -- consider a configurable data directory.
structure = Structure.from_file("/Users/elena.patyukova/Documents/cif_files_ICSD/cifs_binary_Mg/YourCustomFileName_CollCode601198.cif")  # or POSCAR, etc.

edge_index = []
edge_attr = []

for i in range(len(structure)):
    # Only the CrystalNN neighbors define edges; the IsayevNN query that used to
    # run here on every iteration was never used, so it is no longer computed.
    neighbors = cnn.get_nn_info(structure, i)
    for neighbor in neighbors:
        j = neighbor["site_index"]
        # Measure the distance to the specific periodic image reported by the
        # neighbor finder. Without `jimage`, get_distance returns the
        # minimum-image distance, which is wrong when the bonded neighbor is not
        # the nearest image (e.g. a site bonded to its own periodic image -> 0).
        dist = structure.get_distance(i, j, jimage=neighbor["image"])

        # Record edge i → j
        edge_index.append((i, j))
        edge_attr.append([dist])


In [102]:
# Re-create the IsayevNN neighbor finder (same construction as in the earlier edge-building cell).
cnn1 = IsayevNN()

In [128]:
# Compare how many neighbors CrystalNN and IsayevNN report for the same site.
i = 10
neighbors = cnn.get_nn_info(structure, i)
neighbors1 = cnn1.get_nn_info(structure, i)
(len(neighbors), len(neighbors1))

(8, 4)

In [131]:
# Show the full CrystalNN neighbor list obtained for site i.
neighbors

[{'site': PeriodicNeighbor: Gd1 (Gd0+) (7.324, 3.662, 3.662) [1.0, 0.5, 0.5],
  'image': array([1., 0., 0.]),
  'weight': 1,
  'site_index': np.int64(1)},
 {'site': PeriodicNeighbor: Gd1 (Gd0+) (7.324, 0.0, 4.485e-16) [1.0, 0.0, 0.0],
  'image': array([1., 0., 0.]),
  'weight': 1,
  'site_index': np.int64(0)},
 {'site': PeriodicNeighbor: Gd1 (Gd0+) (3.662, 3.662, 4.485e-16) [0.5, 0.5, 0.0],
  'image': array([0., 0., 0.]),
  'weight': 1,
  'site_index': np.int64(3)},
 {'site': PeriodicNeighbor: Gd1 (Gd0+) (3.662, 0.0, 3.662) [0.5, 0.0, 0.5],
  'image': array([0., 0., 0.]),
  'weight': 1,
  'site_index': np.int64(2)},
 {'site': PeriodicNeighbor: Mg1 (Mg0+) (7.324, 3.662, 6.727e-16) [1.0, 0.5, 0.0],
  'image': array([1., 0., 0.]),
  'weight': 1,
  'site_index': np.int64(6)},
 {'site': PeriodicNeighbor: Mg1 (Mg0+) (3.662, 0.0, 2.242e-16) [0.5, 0.0, 0.0],
  'image': array([0., 0., 0.]),
  'weight': 1,
  'site_index': np.int64(5)},
 {'site': PeriodicNeighbor: Mg1 (Mg0+) (7.324, 0.0, 3.662) [

In [130]:
# Show the central site (index i) whose neighbors were queried.
structure.sites[i]

PeriodicSite: Mg2 (Mg0+) (5.493, 1.831, 1.831) [0.75, 0.25, 0.25]