In [1]:
import seaborn as sns
from pathlib import Path
from baysic.utils import quick_view, json_to_df
import pandas as pd
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

df = pd.read_pickle('merged_test_data3.pkl')

In [2]:
from baysic.utils import get_group
from pyxtal import Group
from collections import defaultdict
import numpy as np
import itertools
from pymatgen.core import Composition

class Wyckoffs:
    def __init__(self, group: int | str | Group, n_max: int = 32) -> None:
        self.group = get_group(group)        
        self.general = set()
        special_counts = defaultdict(int)
        for wp in self.group.Wyckoff_positions:
            if wp.get_dof() == 0:
                special_counts[wp.multiplicity] += 1
            else:
                self.general.add(wp.multiplicity)

        self.special_counts = []
        self.special_mults = []
        for k in sorted(special_counts.keys()):
            self.special_mults.append(k)
            self.special_counts.append(special_counts[k])

        self.n_max = n_max
        self._build(self.n_max)

    def __repr__(self):
        return f'''    
Group {self.group.number}
General: {" ".join(map(str, self.general))}
Special:
Mult. | {" | ".join(map(str, self.special_mults))}
# WPs | {" | ".join(map(str, self.special_counts))}
'''    

    @staticmethod
    def pareto_front(pts: np.array) -> np.array:
        front = []
        for pt in pts:
            if not any(all(pt >= front_pt) for front_pt in front):
                front.append(pt)

        return np.array(front, dtype=int)
    
    def _build(self, n_max: int):   
        ndim = len(self.special_mults)        
        def zero() -> np.array:
            return np.zeros((1, ndim), dtype=int)

        all_counts = [zero()]

        self.limit = np.array(self.special_counts)

        costs = []
        mults = []
        for j_vec, j_mult in zip(np.eye(ndim), self.special_mults):
            if j_mult not in self.general:
                costs.append(j_vec)
                mults.append(j_mult)

        for mult in self.general:
            costs.append(zero())
            mults.append(mult)

        O = tuple(np.zeros(ndim))
        for k in range(1, n_max + 1):
            all_vecs = set()
            for cost, mult in zip(costs, mults):        
                if mult <= k and all_counts[k - mult] is not None:        
                    vecs = all_counts[k - mult] + cost
                    all_vecs.update(tuple(vec) for vec in vecs[np.all(vecs <= self.limit, axis=1)])
            
            if O in all_vecs:
                all_counts.append(zero())
            elif all_vecs:
                all_counts.append(self.pareto_front(np.array(list(all_vecs))))
            else:
                all_counts.append(None)

        
        self.frontiers = {}        

        for i, counts in enumerate(all_counts):
            if i == 0 or counts is None:
                continue
            else:
                self.frontiers[i] = counts                
        
        self.n_dim = ndim        

    def _can_make_single(self, card: int) -> bool:
        if card > self.n_max:
            raise ValueError(f'Only computed values up to {self.n_max}, not {card}')
        
        return card in self.frontiers
        
    def can_make(self, num_atoms: list[int] | Composition) -> bool:
        if isinstance(num_atoms, Composition):
            num_atoms = list(num_atoms.values())

        int_num_atoms = np.array(num_atoms).astype(int)
        if not np.allclose(int_num_atoms, num_atoms):
            raise ValueError(f'Does not support fractional composition {num_atoms}')

        # also checks for too many atoms
        if not all(self._can_make_single(card) for card in int_num_atoms):
            return False
                
        # can we choose vectors v from self.frontiers[c], for each
        # c in num_atoms, such that the sum of all v <= self.limit?
        choices = [range(len(self.frontiers[c])) for c in int_num_atoms]
        frontiers = [self.frontiers[c] for c in num_atoms]
        for choice in itertools.product(*choices):
            if np.all(sum(vec[i] for vec, i in zip(frontiers, choice)) <= self.limit):
                return True
            
        return False

In [202]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
symms = []
sgas = []
for struct in df['struct']:
    sga = SpacegroupAnalyzer(struct)
    symm = sga.get_conventional_standard_structure()    
    sgas.append(sga)
    symms.append(symm)    


In [206]:
all_wycks = []
for sga, symm in zip(sgas, symms):
    all_wycks.append(Wyckoffs(sga.get_space_group_number(), 256))

In [207]:
frontiers = []
for wycks, symm in zip(all_wycks, symms):
    comp = symm.composition

    frontiers.append(np.prod([len(x) for x in wycks.frontiers.values()]))
    for num_atoms in comp.values():
        assert wycks._can_make_single(int(num_atoms))

    assert wycks.can_make(comp)

In [197]:
len(list(wycks.frontiers.values())[0])

1

In [1]:
from baysic.pyro_generator import SystemStructureModel
from baysic.config import MainConfig
import pyrallis

with open('configs/perf_test.toml', 'r') as infile:
    conf = pyrallis.load(MainConfig, infile)

conf



MainConfig(log=LogConfig(use_wandb=False, wandb_api_key=None, wandb_project='baysic', use_directory=True, log_directory=PosixPath('logs'), log_dir_mode=<FileLoggingMode.overwrite: 'overwrite'>), search=SearchConfig(rng_seed=16274, n_parallel_structs=500, n_grid=12, wyckoff_strategy=<WyckoffSelectionStrategy.sample_distinct: 'sample_distinct'>, smoke_test=False, num_generations=200, max_gens_at_once=10, allowed_attempts_per_gen=10.0, lattice_scale_factor_mu=1.1, lattice_scale_factor_sigma=0.47, order_positions_by_radius=False, groups_to_search=[47, 77, 16, 194, 131, 2, 10, 223, 71, 195, 89]), target=TargetStructureConfig(mp_id='mp-11251', api_key=None), cli=CliConfig(verbosity=<LoggingLevel.info: 20>, show_progress=True))

In [3]:
from baysic.lattice import CubicLattice
model = SystemStructureModel(conf.log, conf.search, conf.target.composition, CubicLattice, force_group=195)
model.log_info

{'num_assignments': 0}

In [8]:
model.forward()

WyckoffAssignmentFailed: (Composition('Mg6 Au2'), <class 'baysic.lattice.CubicLattice'>, [-- Spacegroup --# 195 (P23)--
12j	site symm: 1
6i	site symm: 2 . .
6h	site symm: 2 . .
6g	site symm: 2 . .
6f	site symm: 2 . .
4e	site symm: . 3 .
3d	site symm: 222 . .
3c	site symm: 22 . .
1b	site symm: 2 33 .
1a	site symm: 2 3 .], 'Could not find valid Wyckoff assignment')

In [31]:
import torch
import pyro.distributions as dist
from baysic.errors import WyckoffAssignmentFailed

self = model

group = self.group_options[self.group_opt]            
num_atoms = list(self.comp.values())            
all_wps = group.Wyckoff_positions                            
mults = torch.tensor([wp.multiplicity for wp in all_wps]).float()
has_freedom = torch.tensor([wp.get_dof() != 0 for wp in all_wps])
def try_assignment():
    complete_assignment = []
    for count in num_atoms:    
        removed = torch.zeros_like(has_freedom)
        assignment = []
        curr_count = count
        while curr_count != 0:
            is_possible = torch.where(~removed & (mults <= curr_count))[0]
            if len(is_possible) == 0:
                return None
            weights = mults[is_possible] * 5            
            _uniq, inv, counts = torch.unique(mults[is_possible], return_inverse=True, return_counts=True)
            weights /= counts[inv]
            weights /= weights.sum()            
            selection = is_possible[dist.Categorical(probs=weights).sample()].item()        
            print(weights.numpy().round(2), ' '.join([all_wps[i].letter for i in is_possible]))
            assignment.append(all_wps[selection].letter)
            curr_count -= mults[selection].item()
            if not has_freedom[selection]:
                removed[selection] = True                            

        complete_assignment.append(assignment)
    return complete_assignment

tries = 0
comb = None
while tries < 100 and comb is None:
    tries += 1
    comb = try_assignment()

self.log_info['num_assignments'] = -tries
if comb is None:
    # note the difference from WyckoffAssignmentImpossible: this should 
    # (hopefully) never happen and is a problem with the model
    raise WyckoffAssignmentFailed(
        self.comp,
        self.lattice_model,
        [group],
        'Could not find valid Wyckoff assignment')

[0.11 0.11 0.11 0.11 0.29 0.11 0.11 0.04 0.04] i h g f e d c b a
[0.75 0.12 0.12] d b a
[1.] b
[0.11 0.11 0.11 0.11 0.29 0.11 0.11 0.04 0.04] i h g f e d c b a
[0.75 0.12 0.12] c b a
[0.5 0.5] b a
[1.] a


  has_freedom = torch.tensor([wp.get_dof() != 0 for wp in all_wps])


In [3]:
import torch

v = torch.randn(3, 4).flatten()

v.topk(4, largest=False)

torch.return_types.topk(
values=tensor([-1.6375, -1.4772, -1.0596, -0.5137]),
indices=tensor([ 8,  7, 11,  0]))