In [1]:
import numpy as np
from pymatgen.core import Structure

def tile_structure(lattice, species, coordinates, scale):
    """
    Build a ``scale x scale x scale`` supercell of a periodic structure.

    Every lattice vector is multiplied by ``scale`` and each atom is replicated
    into all ``scale**3`` sub-cells. The result therefore has ``scale**3``
    times as many sites as the input. Cartesian interatomic distances are
    preserved: a site's Cartesian position is
    ((frac + offset) / scale) @ (scale * lattice) == (frac + offset) @ lattice.

    Parameters:
    - lattice: array-like, the lattice vectors of the structure
      (e.g. a 3x3 matrix with one vector per row, as pymatgen expects).
    - species: sequence, the species (types) of the atoms, one per row of
      ``coordinates``.
    - coordinates: array-like of shape (n_atoms, 3), the fractional
      coordinates of the atoms.
    - scale: int >= 1, the number of repetitions along each lattice vector.

    Returns:
    - tiled_lattice: np.ndarray, ``lattice * scale``.
    - tiled_species: list, ``species`` repeated ``scale**3`` times
      (one full copy per sub-cell, in the same order as the coordinates).
    - tiled_coordinates: np.ndarray of shape (n_atoms * scale**3, 3), the
      fractional coordinates expressed in the supercell basis.

    Raises:
    - ValueError: if ``scale`` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale!r}")

    # Scale the lattice by exactly the tiling factor so that bond lengths are
    # unchanged. An extra constant factor here would distort the crystal.
    tiled_lattice = np.asarray(lattice) * scale

    coords = np.asarray(coordinates).reshape(-1, 3)

    # Integer offsets (i, j, k) of every sub-cell, ordered with k varying
    # fastest. This is the same order as nested `for i / for j / for k` loops.
    offsets = np.indices((scale, scale, scale)).reshape(3, -1).T

    # Broadcast (n_cells, 1, 3) + (1, n_atoms, 3) -> (n_cells, n_atoms, 3).
    # Dividing by `scale` maps each shifted copy into the supercell's
    # fractional basis. The reshape lays out all atoms of cell 0, then all
    # atoms of cell 1, and so on.
    tiled_coordinates = (
        (offsets[:, None, :] + coords[None, :, :]) / scale
    ).reshape(-1, 3)

    # One copy of the species list per sub-cell, matching the coordinate order.
    tiled_species = list(species) * (scale ** 3)

    return tiled_lattice, tiled_species, tiled_coordinates

In [8]:
from diffusion.inference.process_generated_crystals import (
    get_one_crystal,
    load_sample_results_from_hdf5,
)
from diffusion.inference.visualize_crystal import plot_crystal
import crystal_toolkit

OUT_DIR = "out"
# Alternative input (unrelaxed samples): f"{OUT_DIR}/crystals.h5"
crystal_file = f"{OUT_DIR}/relax/relaxed.h5"
N_CRYSTALS = 1      # how many crystals from the file to display
TILE_SCALE = 2      # supercell multiplier along each lattice vector
PLOT_CRYSTAL = False  # also draw the untiled cell with plot_crystal

sample_result = load_sample_results_from_hdf5(crystal_file)

for i in range(N_CRYSTALS):
    lattice, frac_x, atomic_numbers = get_one_crystal(sample_result, i)
    # Shift so the smallest fractional coordinate on each axis is 0.
    # Done out of place on purpose: if get_one_crystal returns a view into
    # sample_result (not visible here -- TODO confirm), an in-place `-=`
    # would silently modify the loaded data.
    frac_x = frac_x - np.min(frac_x, axis=0)
    tiled_lattice, tiled_species, tiled_coordinates = tile_structure(
        lattice, atomic_numbers, frac_x, TILE_SCALE
    )
    display(Structure(tiled_lattice, tiled_species, tiled_coordinates))
    if PLOT_CRYSTAL:
        display(plot_crystal(atomic_numbers, lattice, frac_x, show_bonds=False))

In [9]:
# Example usage: tile a small four-atom cell into a 2x2x2 supercell.
scale = 2  # supercell multiplier along each lattice vector

lattice = np.array(
    [
        [5.29736053, 0.0, 0.0],
        [2.64868026, 4.58764879, 0.0],
        [2.64868026, 1.52921626, 4.32527676],
    ]
)
species = ["Ac", "Ac", "Ir", "Ag"]
# Every site sits on the body diagonal, so each row is one value repeated.
coordinates = np.array([[frac] * 3 for frac in (0.75, 0.25, 0.0, 0.5)])

tiled_lattice, tiled_species, tiled_coordinates = tile_structure(
    lattice, species, coordinates, scale
)
Structure(tiled_lattice, tiled_species, tiled_coordinates)