In [4]:
import numpy as np
from pymatgen.core import Structure

def tile_structure(lattice, species, coordinates, scale):
    """
    Tile a periodic structure into a scale x scale x scale supercell.

    Each atom is copied to every one of the scale^3 integer cell offsets, so the
    result holds scale^3 times as many atoms as the input. The supercell lattice
    is the input lattice multiplied by ``scale``, and fractional coordinates are
    divided by ``scale`` so they are expressed relative to the larger cell.

    Parameters:
    - lattice: array-like, the lattice vectors of the structure (presumably 3x3,
      one vector per row as pymatgen expects).
    - species: sequence, the species (types) of the atoms, one per coordinate row.
    - coordinates: array-like of shape (n_atoms, 3), the fractional coordinates
      of the atoms in the original cell.
    - scale: int >= 1, the number of repetitions along each lattice vector.

    Returns:
    - tiled_lattice: np.ndarray, the lattice vectors multiplied by ``scale``.
    - tiled_species: list, ``species`` repeated scale^3 times.
    - tiled_coordinates: np.ndarray of shape (n_atoms * scale^3, 3), fractional
      coordinates in the supercell. Copies are ordered by cell offset (i, j, k)
      with k varying fastest; within each copy atoms keep their input order, so
      row r of tiled_coordinates corresponds to tiled_species[r].

    Raises:
    - TypeError: if ``scale`` is not an integer value.
    - ValueError: if ``scale`` < 1, if ``coordinates`` is not (n_atoms, 3), or if
      ``species`` and ``coordinates`` have different lengths.
    """
    n = int(scale)
    if n != scale:
        raise TypeError(f"scale must be an integer, got {scale!r}")
    if n < 1:
        # Without this check the failure surfaces as an opaque error from
        # concatenating zero tiles.
        raise ValueError(f"scale must be >= 1, got {scale!r}")

    # asarray so that list inputs are scaled numerically instead of being
    # list-repeated by `*`.
    lattice = np.asarray(lattice, dtype=float)
    coordinates = np.asarray(coordinates, dtype=float)
    if coordinates.size == 0:
        coordinates = coordinates.reshape(0, 3)
    if coordinates.ndim != 2 or coordinates.shape[1] != 3:
        raise ValueError(
            f"coordinates must have shape (n_atoms, 3), got {coordinates.shape}"
        )
    if len(species) != len(coordinates):
        raise ValueError(
            f"got {len(species)} species for {len(coordinates)} coordinate rows"
        )

    # The supercell spans exactly n original cells along each lattice vector, so
    # its vectors are exactly n times longer. Any extra constant factor would
    # shrink the cell and compress every interatomic distance.
    tiled_lattice = lattice * n

    # Every integer cell offset (i, j, k) in [0, n)^3, with k varying fastest.
    offsets = np.indices((n, n, n)).reshape(3, -1).T

    # Broadcast (n^3, 1, 3) + (1, n_atoms, 3) -> (n^3, n_atoms, 3), then flatten
    # so each cell offset contributes one contiguous block of atoms.
    tiled_coordinates = (
        (offsets[:, np.newaxis, :] + coordinates[np.newaxis, :, :]) / n
    ).reshape(-1, 3)

    # One copy of the species list per cell offset, in the same block order.
    tiled_species = list(species) * n**3

    return tiled_lattice, tiled_species, tiled_coordinates

In [11]:
from diffusion.inference.process_generated_crystals import (
    get_one_crystal,
    load_sample_results_from_hdf5,
)
from diffusion.inference.visualize_crystal import plot_crystal
import crystal_toolkit

OUT_DIR = "out"
crystal_file = f"{OUT_DIR}/crystals.h5"
N_PREVIEW = 20     # number of generated samples to inspect (assumes the file holds at least this many — TODO confirm)
TILE_SCALE = None  # set to an int (e.g. 2) to also display a tiled supercell of each sample
PLOT_LAST = False  # set True to render the last inspected sample with plot_crystal

sample_result = load_sample_results_from_hdf5(crystal_file)

for i in range(N_PREVIEW):
    lattice, frac_x, atomic_numbers = get_one_crystal(sample_result, i)
    # Label each matrix so the printed lattices can be matched to sample indices.
    print(f"sample {i}:")
    print(lattice)
    if TILE_SCALE:
        # Shift atoms so the smallest fractional coordinate on each axis is 0
        # before tiling; a new array is built so frac_x itself is not mutated.
        # Assumes frac_x is (n_atoms, 3) — TODO confirm against get_one_crystal.
        frac_x_shifted = frac_x - np.min(frac_x, axis=0)
        tiled_lattice, tiled_species, tiled_coordinates = tile_structure(
            lattice, atomic_numbers, frac_x_shifted, TILE_SCALE
        )
        display(Structure(tiled_lattice, tiled_species, tiled_coordinates))

if PLOT_LAST:
    # Renders the last sample read by the loop above.
    fig = plot_crystal(atomic_numbers, lattice, frac_x, show_bonds=False)
    display(fig)

[[-12.43242656  -4.31100055   1.09444839]
 [ 10.80482135  -7.97667024   4.44763025]
 [  7.06126406   3.05692248   6.53137197]]
[[ -5.08694638   8.82677214  -3.46121738]
 [ -4.71786844  -1.09531127  -7.63957805]
 [-14.63669024  -3.40746059  -2.37948377]]
[[15.71311333 -2.02542082  3.59050382]
 [ 1.34668618 11.02346565 -2.99007526]
 [ 6.16415976 -0.20893886  6.1703389 ]]
[[  6.66375365  -4.33683933   7.26270778]
 [  2.29144872 -10.39164951   3.23321741]
 [ 15.54195403   2.05130529   1.63324053]]
[[-8.74856616  2.80196106  2.85011984]
 [ 9.70354109 -5.89079301  7.11206427]
 [12.05115048  7.88247869  2.34356772]]
[[-0.70048436 -5.90240462  4.80785975]
 [11.43298481 -5.25268733 -1.39520226]
 [11.58224802  4.46177105  6.24709786]]
[[ 0.72656779 -1.00570488  5.18610287]
 [12.71856772 -5.91601415  4.45329181]
 [ 5.69659504  8.32516638 -1.76139469]]
[[-8.42111769  8.11226736 -3.57192532]
 [ 4.36404739  0.94966994  5.79775668]
 [13.147116    7.32341454  0.04466999]]
[[ 14.44390832  -0.67882492  

In [9]:
# Example: tile a small test cell 2x2x2 and display it as a pymatgen Structure.
# The three lattice vectors all have length ~5.297 and meet at 60 degrees
# (an fcc primitive cell); the four atoms sit at 0, 1/4, 1/2 and 3/4 along the
# body diagonal.
lattice = np.array([[5.29736053, 0., 0.],
                    [2.64868026, 4.58764879, 0.],
                    [2.64868026, 1.52921626, 4.32527676]])
species = ["Ac", "Ac", "Ir", "Ag"]
coordinates = np.array([[0.75, 0.75, 0.75],
                        [0.25, 0.25, 0.25],
                        [0., 0., 0.],
                        [0.5, 0.5, 0.5]])
scale = 2  # repetitions along each lattice vector -> scale**3 copies of the cell

tiled_lattice, tiled_species, tiled_coordinates = tile_structure(lattice, species, coordinates, scale)

# Invariants that any correct tiling must satisfy.
n_tiled_atoms = len(species) * scale**3
assert len(tiled_species) == n_tiled_atoms
assert tiled_coordinates.shape == (n_tiled_atoms, 3)
assert np.all((tiled_coordinates >= 0) & (tiled_coordinates < 1)), "tiled atoms should stay inside the supercell"

Structure(tiled_lattice, tiled_species, tiled_coordinates)