In [5]:
from moltda.construct_pd import construct_pds
from moltda.read_file import make_supercell
from moltda.vectorize_pds import diagrams_to_arrays, get_images, pd_vectorization

from pymatgen.core import Structure

from mofdscribe.topology._tda_helpers import get_persistent_images_for_structure


from collections import defaultdict
from pathlib import Path
from typing import List, Union


from functools import lru_cache

import numpy as np
from loguru import logger
from moltda.construct_pd import construct_pds
from moltda.read_file import make_supercell
from moltda.vectorize_pds import diagrams_to_arrays, get_images, pd_vectorization
from pymatgen.core import Structure

from mofdscribe.utils.substructures import elements_in_structure, filter_element


In [2]:
s = Structure.from_file("/Users/kevinmaikjablonka/git/kjappelbaum/mofdscribe/tests/test_files/HKUST-1.cif")

In [6]:
def get_persistent_images_for_structure(
    structure: Structure,
    elements: List[List[str]],
    compute_for_all_elements: bool = True,
    min_size: int = 20,
    spread: float = 0.2,
    weighting: str = "identity",
    pixels: List[int] = [50, 50],
    maxB: int = 18,
    maxP: int = 18,
    minB: int = 0,
) -> dict:
    """
    Get the persistent images for a structure.
    Args:
        structure (Structure): input structure
        elements (List[List[str]]): list of elements to compute for
        compute_for_all_elements (bool): compute for all elements
        min_size (int): minimum size of the cell for construction of persistent images
        spread (float): spread of kernel for construction of persistent images
        weighting (str): weighting scheme for construction of persistent images
        pixels (List[int]): size of the image in pixels
        maxB (int): maximum birth time for construction of persistent images
        maxP (int): maximum persistence time for construction of persistent images
        minB (int): minimum birth time for construction of persistent images
    Returns:
        persistent_images (dict): dictionary of persistent images and their barcode representations
    """

    element_images = defaultdict(dict)
    for element in elements:
        try:
            filtered_structure = filter_element(structure, element)
            coords = make_supercell(
                filtered_structure.cart_coords, filtered_structure.lattice.matrix, min_size
            )
            pds = construct_pds(coords)
            pd = diagrams_to_arrays(pds)

            images = get_images(
                pd,
                spread=spread,
                weighting=weighting,
                pixels=pixels,
                specs={"minBD": minB, "maxB": maxB, "maxP": maxP},
            )
        except ValueError as e:
            images = np.zeros((0, pixels[0], pixels[1]))
            pd = np.zeros((0, maxP + 1))

        element_images["image"][element] = images
        element_images["array"][element] = pd

    if compute_for_all_elements:
        coords = make_supercell(structure.cart_coords, structure.lattice.matrix, min_size)
        pd = diagrams_to_arrays(construct_pds(coords))

        images = get_images(pd, spread=spread, weighting=weighting, pixels=pixels)
        element_images["image"]["all"] = images
        element_images["array"]["all"] = pd

    return element_images


In [36]:
coords = make_supercell(s.cart_coords, s.lattice.matrix, 20)
pds = diagrams_to_arrays(construct_pds(coords))

In [80]:
pds['dim0']

array([(0.,        inf,   0), (0., 0.9761139 ,   1),
       (0., 0.7434609 ,   2), (0., 0.9761139 ,   3),
       (0., 0.68832904,   4), (0., 0.4679128 ,   5),
       (0., 0.9761139 ,   6), (0., 0.9761139 ,   7),
       (0., 0.9761139 ,   8), (0., 2.4401782 ,   9),
       (0., 0.9761139 ,  10), (0., 0.9761139 ,  11),
       (0., 0.7434609 ,  12), (0., 0.9761139 ,  13),
       (0., 0.9761139 ,  14), (0., 0.7434609 ,  15),
       (0., 0.9761139 ,  16), (0., 0.7434609 ,  17),
       (0., 0.9761139 ,  18), (0., 0.7434609 ,  19),
       (0., 0.7434609 ,  20), (0., 0.7434609 ,  21),
       (0., 0.9761139 ,  22), (0., 0.69920695,  23),
       (0., 1.369836  ,  24), (0., 0.6937268 ,  25),
       (0., 0.7434609 ,  26), (0., 0.9761139 ,  27),
       (0., 0.7434609 ,  28), (0., 0.9761139 ,  29),
       (0., 0.7434609 ,  30), (0., 0.69920695,  31),
       (0., 0.7434609 ,  32), (0., 0.9761139 ,  33),
       (0., 0.9761139 ,  34), (0., 0.68832904,  35),
       (0., 0.9761139 ,  36), (0., 0.9761139 ,

In [90]:
def get_min_max_from_dia(dia):
    if len(dia) == 0:
        return [0,0,0,0]
    d = np.array([[x["birth"], x["death"]] for x in dia])
    
    # convert to birth - persistence
    d[:, 1] -= d[:, 0]
    d = np.ma.masked_invalid(d)
    return [d[:, 0].min(), d[:, 0].max(), d[:,1].min(), d[:, 1].max()]

In [91]:
def get_persistence_image_limits_for_structure(
    structure: Structure,
    elements: List[List[str]],
    compute_for_all_elements: bool = True,
    min_size: int = 20,
) -> dict:
    limits = defaultdict(list)
    for element in elements:
        try:
            filtered_structure = filter_element(structure, element)
            coords = make_supercell(
                filtered_structure.cart_coords, filtered_structure.lattice.matrix, min_size
            )
            pds = construct_pds(coords)
            pd = diagrams_to_arrays(pds)
            for k, v in pd.items():
                limits[k].append(get_min_max_from_dia(v))
        except ValueError as e:
            pass
    
    if compute_for_all_elements:
        coords = make_supercell(structure.cart_coords, structure.lattice.matrix, min_size)
        pd = diagrams_to_arrays(construct_pds(coords))
        for k, v in pd.items():
            limits[k].append(get_min_max_from_dia(v))
    return limits

In [92]:
get_persistence_image_limits_for_structure(s, ["C", "H"])

defaultdict(list,
            {'dim0': [[0.0, 0.0, 0.68832904, 1.817667],
              [0.0, 0.0, 2.01175, 3.16116],
              [0.0, 0.0, 0.46250537, 2.4401782]],
             'dim1': [[0.6937268, 1.817667, 0.67927796, 3.1930804],
              [2.01175, 3.16116, 0.31121874, 1.3093953],
              [0.6937268, 2.4401782, 5.9604645e-06, 3.4944415]],
             'dim2': [[1.398083, 5.0107474, 8.273125e-05, 2.2315536],
              [3.4732823, 4.4705553, 0.45268965, 3.0167909],
              [1.398083, 4.4705553, 8.273125e-05, 2.7717457]],
             'dim3': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]})