In [82]:
"""Full stochastic structure generator using Pyro."""

from copy import deepcopy
from ctypes.wintypes import WPARAM
from doctest import debug
import logging
from math import floor
from signal import struct_siginfo
import numpy as np
from sympy import N
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroParam, PyroSample
from pymatgen.core import Composition, Lattice, Structure, Element
from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
from tqdm import trange
from baysic.structure_evaluation import MIN_DIST_RATIO, e_form, point_energy
from baysic.pyro_wp import WyckoffSet
from baysic.lattice import CubicLattice, atomic_volume, LatticeModel
from baysic.interpolator import LinearSpline
from baysic.feature_space import FeatureSpace
from pyxtal import Group, Wyckoff_position
from baysic.utils import get_group, get_wp
from cctbx.sgtbx.direct_space_asu.reference_table import get_asu
from scipy.spatial import ConvexHull
import networkx as nx


ngrid = 12
nbeam = 1000

def debug_shapes(*names):
    import inspect
    frame = inspect.currentframe().f_back.f_locals
    try:                
        shapes = [frame[name].shape for name in names]        
        max_len = int(max(map(len, shapes)))
        max_digits = len(str(max(map(max, shapes))))
        max_name_len = max(len(name) for name in names)             
        for name, shape in zip(names, shapes):
            logging.debug(f'{name:>{max_name_len}} = ' + ' '.join([' ' * max_digits] * (max_len - len(shape)) + [f'{dim:>{max_digits}}' for dim in shape]))
    finally:
        del frame

def pairwise_dist_ratio(c1, c2, rads1, rads2, lattice):
    set_diffs = c1.unsqueeze(0).unsqueeze(0) - c2.unsqueeze(-2).unsqueeze(-2)
    set_diffs = set_diffs % 1
    set_diffs = torch.minimum(set_diffs, 1 - set_diffs)
    set_cart_diffs = torch.matmul(set_diffs, lattice.T)                
    diffs = torch.sqrt(torch.sum(torch.square(set_cart_diffs), axis=-1))
    rads = rads1.unsqueeze(0).unsqueeze(1) + rads2.unsqueeze(-1).unsqueeze(-1)
    return diffs / rads

class SystemStructureModel(PyroModule):
    """A stochastic structure generator working within a particular lattice type."""    
    def __init__(self, comp: Composition, lattice: LatticeModel):
        super().__init__()
        self.comp = comp
        self.lattice_model = lattice
        
        # mode 4.5/5, mean 5.5/5
        # around 1, matches empirical distribution well
        # self.volume_ratio = PyroSample(dist.Gamma(5.5, 5))            
        self.volume_ratio = PyroSample(dist.Gamma(12, 10))            
        self.atom_volume = atomic_volume(comp)

        groups = self.lattice_model.get_groups()
        self.group_options = []
        self.wyckoff_options = []
        self.group_cards = []
        self.opt_cards = []
        self.inds = []

        n_els = np.array(list(comp.values()))        
        for sg in groups:
            combs, _has_freedom, _inds = sg.list_wyckoff_combinations(n_els)
            if combs:
                self.group_options.extend([sg.number] * len(combs))
                self.wyckoff_options.extend(combs)
                self.group_cards.extend([1 / len(combs)] * len(combs))
                self.opt_cards.extend([1] * len(combs))                                

        self.group_cards = torch.tensor(self.group_cards).float()
        self.group_cards /= self.group_cards.sum().float()
        self.opt_cards = torch.tensor(self.opt_cards).float()
        self.opt_cards /= self.opt_cards.sum().float()

        self.wyck_opt = PyroSample(dist.Categorical(probs=self.group_cards))      
                
        
    def forward(self):
        self.volume = self.volume_ratio * self.atom_volume
        self.lattice = self.lattice_model(self.volume)()
        
        opt = self.wyck_opt
        sg = self.group_options[opt]
        comb = self.wyckoff_options[opt]

        # todo add pairwise distance
        self.coords = torch.tensor([])
        self.elems = []
        self.wsets = []        
        for el, spots in zip(self.comp.elements, comb):
            radius = torch.tensor([CovalentRadius.radius[el.symbol]])
            for spot in spots:
                wset = WyckoffSet(sg, spot)
                if wset.dof == 0:
                    posns = torch.zeros(3)
                    set_coords = wset.to_all_positions(posns)
                else:
                    base = torch.cartesian_prod(*[torch.linspace(0, 1, ngrid + 2)[1:-1] for _ in range(wset.dof)])
                    debug_shapes('base')
                    base = base.reshape(ngrid ** wset.dof, wset.dof)
                    max_move = 0.49 / (ngrid + 1)
                    low = base - max_move
                    high = base + max_move
                    posns = pyro.sample(f'coords_{len(self.elems)}', dist.Uniform(low, high))
                
                    set_coords = wset.to_all_positions(wset.to_asu(posns))
                    
                debug_shapes('set_coords', 'posns')
                if set_coords.shape[-2] > 1:
                    # check pairwise distances
                    set_diffs = pairwise_dist_ratio(set_coords[..., 1:, :], set_coords[..., [0], :], radius, radius, self.lattice)
                    debug_shapes('set_diffs')
                    # [ngrid, 1, ngrid, dof - 1] if used a grid search
                    # [1, 1, 1, dof - 1] if no degrees of freedom
                    # here, we only care about comparing a single WP to its own copies, not the full pairwise
                    n_new_coords = set_diffs.shape[0]
                    set_diffs = set_diffs[torch.arange(n_new_coords), 0, torch.arange(n_new_coords), :].reshape(-1, set_diffs.shape[-1])
                    # [ngrid, dof - 1]
                    debug_shapes('set_diffs')
                    set_valid = (set_diffs >= MIN_DIST_RATIO).all(dim=-1)
                    debug_shapes('set_valid')
                else:
                    # 1 coordinate is always valid
                    set_valid = torch.Tensor([1])
            
                if not set_valid.any():
                    raise ValueError('Could not find assignment')
                
                debug_shapes('set_coords', 'set_valid')
                good_all_coords = set_coords[torch.where(set_valid)[0], :, :]
                # only need to check base coord
                good_coords = good_all_coords[:, :1, :]
                
                
                if self.coords.numel():                    
                    radii = torch.tensor([CovalentRadius.radius[el.symbol] for el in self.elems])
                    coords = self.coords                          
                    debug_shapes('good_coords', 'coords', 'radius', 'radii') 
                    # print(self.elems, self.wsets, wset.multiplicity)
                    cdists = pairwise_dist_ratio(good_coords, coords, radius, radii, self.lattice)
                    # shape [coords_batch, coords_num, good_batch, good_num]
                    
                    min_cdists = cdists.permute((0, 2, 1, 3)).min(dim=-1)[0].min(dim=-1)[0]
                    # shape [coords_batch, good_batch]                    
                
                    if not (min_cdists >= MIN_DIST_RATIO).any():
                        raise ValueError('Could not find assignment')
                    
                    # take the best nbeam pairs of (old_coords, new_coords) that work
                    all_old, all_new = torch.where(min_cdists >= MIN_DIST_RATIO)
                    adds = torch.argsort(min_cdists[all_old, all_new], descending=True)[:nbeam]

                    old = self.coords[all_old[adds]]
                    new = good_all_coords[all_new[adds]]
                    debug_shapes('old', 'new')
                    self.coords = torch.cat([old, new], dim=1)
                    # self.coords.append(set_coords[torch.where(set_valid)[0][0]].unsqueeze(0))

                else:
                    # no other coordinates to worry about, just add all found coordinates
                    self.coords = good_all_coords
                    
                self.elems.extend([el] * wset.multiplicity)
                self.wsets.append(wset)
                        
        return (self.coords, self.lattice, self.elems, self.wsets)
    
    def to_structure(self) -> Structure:
        np_coords = self.coords.detach().cpu().numpy().squeeze(0)
        return Structure(self.lattice, self.elems, np_coords)

    

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, force=True)
    torch.manual_seed(34761)
    mod = SystemStructureModel(
        # Composition.from_dict({'K': 8, 'Li': 4, 'Cr': 4, 'F': 24}),
        Composition.from_dict({'Sr': 3, 'Ti': 1, 'O': 1}),
        CubicLattice
    )

    structs = []
    success = []
    actual_success = []
    for _ in trange(100):
        try:
            coords, lat, elems, wsets = mod.forward()
            struct = mod.to_structure()
            structs.append(struct)
            assert np.allclose(struct.lattice.matrix, lat.numpy())
            actual_success.append(point_energy(deepcopy(struct)) < 80)
            success.append(1)
        except ValueError:
            success.append(0)

    print(np.mean(success), np.mean(actual_success))

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:03<00:00, 30.45it/s]

0.52 1.0





In [83]:
import typing
struct = structs[3]
M = typing.TypeVar('M', np.ndarray, torch.Tensor)
def upper_tri(mat: M) -> M:
    """Get the upper triangle of the matrix."""
    inds0, inds1 = np.triu_indices(mat.shape[-1], 1)
    return mat[..., inds0, inds1]

radii = np.array([CovalentRadius.radius[site.specie.symbol] for site in structs[0].sites])
dists = upper_tri(struct.distance_matrix)
rads = upper_tri(np.add.outer(radii, radii))

print((dists / rads).min())
np.where(np.isclose(struct.distance_matrix / np.add.outer(radii, radii), (dists / rads).min()))

0.9272865172173878


(array([0, 0, 1, 1, 2, 2]), array([1, 2, 0, 2, 0, 1]))

In [3]:
structs[1].num_sites

40

In [85]:
mod.wyckoff_options

[[['3c'], ['1a'], ['1b']],
 [['3c'], ['1b'], ['1a']],
 [['3d'], ['1a'], ['1b']],
 [['3d'], ['1b'], ['1a']],
 [['3c'], ['1a'], ['1b']],
 [['3c'], ['1b'], ['1a']],
 [['3d'], ['1a'], ['1b']],
 [['3d'], ['1b'], ['1a']],
 [['3c'], ['1a'], ['1b']],
 [['3c'], ['1b'], ['1a']],
 [['3d'], ['1a'], ['1b']],
 [['3d'], ['1b'], ['1a']],
 [['3c'], ['1a'], ['1b']],
 [['3c'], ['1b'], ['1a']],
 [['3d'], ['1a'], ['1b']],
 [['3d'], ['1b'], ['1a']],
 [['3c'], ['1a'], ['1b']],
 [['3c'], ['1b'], ['1a']],
 [['3d'], ['1a'], ['1b']],
 [['3d'], ['1b'], ['1a']]]

In [4]:
from baysic.utils import quick_view


quick_view(structs[1])

In [5]:
from baysic.structure_evaluation import VACUUM_SIZE

struct = structs[0]


def vacuum_cond(struct):
    def get_foot(p, a, b):
        p = np.array(p)
        a = np.array(a)
        b = np.array(b)
        ap = p - a
        ab = b - a
        result = a + np.dot(ap, ab) / np.dot(ab, ab) * ab
        return result

    def get_distance(a, b):
        return np.sqrt(np.sum(np.square(b - a)))


    line_a_points = [[0, 0, 0], ]
    line_b_points = [[0, 0, 1], [0, 1, 0], [1, 0, 0],
                        [0, 1, 1], [1, 0, 1], [1, 1, 0], [0, 1, -1], [1, 0, -1], [1, -1, 0],
                        [1, 1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]]
    for a in line_a_points:
        for b in line_b_points:
            foot_points = []
            for p in struct.frac_coords:
                f_p = get_foot(p, a, b)
                foot_points.append(f_p)
            foot_points = sorted(foot_points, key=lambda x: [x[0], x[1], x[2]])

            # 转为笛卡尔坐标
            foot_points = np.asarray(np.mat(foot_points) * np.mat(struct.lattice.matrix))
            for fp_i in range(0, len(foot_points) - 1):
                fp_distance = get_distance(foot_points[fp_i + 1], foot_points[fp_i])
                if fp_distance > VACUUM_SIZE:
                    return False
                
    return True

sum(map(vacuum_cond, structs))

6

In [6]:
list(map(point_energy, structs))

[tensor(98.9114),
 tensor(99.1017),
 tensor(99.3614),
 tensor(99.1762),
 tensor(99.5001),
 tensor(99.2650)]

In [7]:
def pairwise_dist_ratio(c1, c2, rads1, rads2, lattice):
    set_diffs = c1.unsqueeze(0).unsqueeze(0) - c2.unsqueeze(-2).unsqueeze(-2)
    set_diffs = set_diffs % 1
    set_diffs = torch.minimum(set_diffs, 1 - set_diffs)
    set_cart_diffs = torch.matmul(set_diffs, lattice.T)                
    diffs = torch.sqrt(torch.sum(torch.square(set_cart_diffs), axis=-1))
    rads = rads1.unsqueeze(0).unsqueeze(1) + rads2.unsqueeze(-1).unsqueeze(-1)
    return diffs / rads


pairwise_dist_ratio(torch.randn((4, 5, 3)), torch.randn((4, 6, 3)), torch.rand(1), torch.rand(1), torch.eye(3) * 3.2).shape

torch.Size([4, 6, 4, 5])

In [86]:
import pandas as pd


df = pd.read_pickle('merged_test_data3.pkl')

df

Unnamed: 0,material_id,formula_pretty,nsites,spacegroup,nelements,elements_list,CrystalSystem,category,nontrivial_coordinates,struct
0,mp-557997,CaSeO3,20,14,3,Ca O Se,Monoclinic,polymorph_ternary,14,"[[3.34824742 5.7240056 5.93286188] Ca, [0.887..."
1,mp-13171,YMgCu,9,189,3,Cu Mg Y,Hexagonal,polymorph_ternary,1,"[[-2.18135055 3.7782094 0. ] Y, [ 4...."
2,mp-7550,CeNbO4,12,15,3,Ce Nb O,Monoclinic,polymorph_ternary,7,"[[4.11571443 0.6447703 3.28952249] Ce, [-0.53..."
3,mp-23550,KBrF4,12,140,3,Br F K,Tetragonal,polymorph_ternary,3,"[[1.05567769 1.34280835 2.17287608] K, [ 3.165..."
4,mp-5126,ZnSO4,24,62,3,O S Zn,Orthorhombic,polymorph_ternary,8,"[[2.37637601 3.33392793 4.30100142] Zn, [0. ..."
...,...,...,...,...,...,...,...,...,...,...
175,mp-1106064,Ho4Ga2Ni,17,229,3,Ga Ho Ni,Cubic,template-based_ternary,1,"[[ 2.11805495 2.11805495 -2.11805495] Ho, [ 2..."
176,mp-1105955,Er3Cu3Sb4,20,220,3,Cu Er Sb,Cubic,template-based_ternary,5,"[[-2.3728156 3.55922339 4.74563119] Er, [2...."
177,mp-1105893,La3Cu3Bi4,20,220,3,Bi Cu La,Cubic,template-based_ternary,5,"[[-2.52068287 3.7810243 5.04136574] La, [2...."
178,mp-1105802,CaGe2Pt,16,71,3,Ca Ge Pt,Orthorhombic,template-based_ternary,6,[[5.30951944e-17 2.31557074e+00 6.99969495e-17...


In [101]:
from monty.json import MSONable, MontyDecoder, MontyEncoder

def save_df(df: pd.DataFrame, fn):
    df.to_json(fn, orient='records', default_handler=MontyEncoder().default)

save_df(df, 'test.json')

import json
with open('test.json', 'r') as infile:
    data = json.load(infile, cls=MontyDecoder)

pd.json_normalize(data)

Unnamed: 0,material_id,formula_pretty,nsites,spacegroup,nelements,elements_list,CrystalSystem,category,nontrivial_coordinates,struct
0,mp-557997,CaSeO3,20,14,3,Ca O Se,Monoclinic,polymorph_ternary,14,"[[3.34824742 5.7240056 5.93286188] Ca, [0.887..."
1,mp-13171,YMgCu,9,189,3,Cu Mg Y,Hexagonal,polymorph_ternary,1,"[[-2.18135055 3.7782094 0. ] Y, [ 4...."
2,mp-7550,CeNbO4,12,15,3,Ce Nb O,Monoclinic,polymorph_ternary,7,"[[4.11571443 0.6447703 3.28952249] Ce, [-0.53..."
3,mp-23550,KBrF4,12,140,3,Br F K,Tetragonal,polymorph_ternary,3,"[[1.05567769 1.34280835 2.17287608] K, [ 3.165..."
4,mp-5126,ZnSO4,24,62,3,O S Zn,Orthorhombic,polymorph_ternary,8,"[[2.37637601 3.33392793 4.30100142] Zn, [0. ..."
...,...,...,...,...,...,...,...,...,...,...
175,mp-1106064,Ho4Ga2Ni,17,229,3,Ga Ho Ni,Cubic,template-based_ternary,1,"[[ 2.11805495 2.11805495 -2.11805495] Ho, [ 2..."
176,mp-1105955,Er3Cu3Sb4,20,220,3,Cu Er Sb,Cubic,template-based_ternary,5,"[[-2.3728156 3.55922339 4.74563119] Er, [2...."
177,mp-1105893,La3Cu3Bi4,20,220,3,Bi Cu La,Cubic,template-based_ternary,5,"[[-2.52068287 3.7810243 5.04136574] La, [2...."
178,mp-1105802,CaGe2Pt,16,71,3,Ca Ge Pt,Orthorhombic,template-based_ternary,6,[[5.30951944e-17 2.31557074e+00 6.99969495e-17...
