In [87]:
import numpy as np
import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.spatial import cKDTree


def read_xyz(filename):
    """
    Read an XYZ file and return:
        labels: list of atom symbols
        coords: numpy array of shape (N,3)
    """
    labels, coords = [], []
    with open(filename) as f:
        lines = f.readlines()
    for line in lines[2:]:
        parts = line.split()
        if len(parts) >= 4:
            labels.append(parts[0])
            coords.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return labels, np.array(coords)


def generate_cluster_points(self,
                            target_number,
                            box_size,
                            min_distance=2.0,
                            cluster_center=None,
                            cluster_radius=None,
                            seed=None):
    """
    Generate points with minimum distance constraint, optionally clustered, fast and reproducible.

    Parameters
    ----------
    target_number : int
        Number of points to generate.
    box_size : array-like
        Box size in x, y, z.
    min_distance : float
        Minimum distance between points.
    cluster_center : array-like, optional
        Center of cluster. If None, points fill the box.
    cluster_radius : float, optional
        Radius of cluster sphere.
    seed : int, optional
        Random seed.

    Returns
    -------
    points : np.ndarray (target_number, 3)
    """
    rng = np.random.default_rng(seed)
    if self.buffer is not None:
        min_distance = max(min_distance, self.buffer)

    if cluster_center is not None and cluster_radius is not None:
        # Generate points on a cubic grid inside the sphere
        n_side = int(np.ceil((2 * cluster_radius) / min_distance))
        lin = np.linspace(-cluster_radius, cluster_radius, n_side)
        X, Y, Z = np.meshgrid(lin, lin, lin, indexing='ij')
        pts = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
        # Keep points inside the sphere
        mask = np.linalg.norm(pts, axis=1) <= cluster_radius
        pts = pts[mask] + cluster_center
    else:
        # Generate points on a cubic grid inside the box
        n_side = [int(np.ceil(bs / min_distance)) for bs in box_size]
        linx = np.linspace(0, box_size[0], n_side[0], endpoint=False)
        liny = np.linspace(0, box_size[1], n_side[1], endpoint=False)
        linz = np.linspace(0, box_size[2], n_side[2], endpoint=False)
        X, Y, Z = np.meshgrid(linx, liny, linz, indexing='ij')
        pts = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T

    # Shuffle and pick required number
    rng.shuffle(pts)
    if pts.shape[0] < target_number:
        raise RuntimeError(
            f"Not enough points generated: {pts.shape[0]} < {target_number}")
        return pts
    return pts[:target_number]


def generate_candidates_each_solvent(solvent_coords,
                                     solvent_labels,
                                     box_size,
                                     target_number,
                                     residue_idx_start=0,
                                     points_template=None):
    """
    Vectorized tiling and rotation of solvent molecules.
    Returns:
        candidates: (N_atoms_total,3)
        labels: (N_atoms_total,)
        residue_idx: (N_atoms_total,)
    """
    n_atoms = solvent_coords.shape[0]

    # Random translations and rotations
    if points_template is None:
        #random_points = np.random.rand(target_number, 3) * box_size
        random_points = generate_cluster_points(target_number,
                                                box_size,
                                                min_distance=2.0,
                                                seed=42)  # (target_number,3)
    else:
        random_points = points_template
    rots = R.random(target_number).as_matrix()  # (target_number,3,3)

    n_atoms = solvent_coords.shape[0]
    coords_exp = solvent_coords[np.newaxis, :, :]  # (1, n_atoms, 3)
    # rots: (n_tiles, 3, 3)
    rot_coords = np.matmul(coords_exp,
                           rots.transpose(0, 2, 1))  # (n_tiles, n_atoms, 3)

    candidates = rot_coords.reshape(-1, 3)
    candidates += np.repeat(random_points, n_atoms, axis=0)
    labels = np.array(list(solvent_labels) * target_number).reshape(-1, 1)
    residue_idx = np.repeat(
        np.arange(residue_idx_start, residue_idx_start + target_number),
        n_atoms).reshape(-1, 1)

    return candidates, labels, residue_idx


# -----------------------------
# Step 3: Remove overlapping residues (fast vectorized)
# -----------------------------


def remove_overlaps_kdtree(existing_coords, candidate_coords,
                           candidate_residues, buffer):
    """
    Remove solvent residues that overlap with:
        1. existing atoms (solute + accepted solvents)
        2. other candidate residues themselves

    Parameters
    ----------
    existing_coords : np.ndarray (Ne,3)
        Coordinates of solute + previously accepted solvents.
    candidate_coords : np.ndarray (Nc,3)
        Coordinates of candidate solvent atoms.
    candidate_residues : np.ndarray (Nc,)
        Residue index for each candidate atom.
    buffer : float
        Minimum allowed distance.

    Returns
    -------
    keep_mask : np.ndarray (Nc,), bool
        Mask indicating kept atoms.
    drop_mask : np.ndarray (Nc,), bool
        Mask indicating dropped atoms.
    """
    # Ensure candidate_residues is 1D
    candidate_residues = candidate_residues.reshape(-1)

    # --- Round 1: Overlap with existing atoms ---
    tree_existing = cKDTree(existing_coords)
    neighbors_existing = tree_existing.query_ball_point(candidate_coords,
                                                        r=buffer)
    mask_overlap_existing = np.array(
        [len(neigh) > 0 for neigh in neighbors_existing])
    bad_residues_existing = np.unique(
        candidate_residues[mask_overlap_existing])

    # --- Round 2: Overlap among candidates themselves ---
    tree_candidates = cKDTree(candidate_coords)
    neighbors_candidates = tree_candidates.query_ball_point(candidate_coords,
                                                            r=buffer)

    # A candidate overlaps if it has neighbors (excluding itself)
    #bad residues should be if find neighbors but not the residue itself, so based on residue index
    mask_overlap_candidates = np.array([
        any(candidate_residues[neigh] != candidate_residues[i]
            for neigh in neighbors_candidates[i] if neigh != i)
        for i in range(len(candidate_coords))
    ])
    bad_residues_candidates = np.unique(
        candidate_residues[mask_overlap_candidates])

    # --- Combine bad residues from both rounds ---
    bad_residues = np.union1d(bad_residues_existing, bad_residues_candidates)

    keep_mask = ~np.isin(candidate_residues, bad_residues)
    drop_mask = ~keep_mask  # same length as candidate_residues

    return keep_mask, drop_mask


def _remove_overlaps_kdtree(existing_coords, candidate_coords,
                            candidate_residues, buffer):
    print(
        f"shape check: existing_coords {existing_coords.shape}, candidate_coords {candidate_coords.shape}, candidate_residues {candidate_residues.shape}"
    )
    """
    Remove solvent residues that overlap with existing atoms using KDTree.
    
    Parameters:
        existing_coords: np.ndarray (Ne,3) - coordinates of solute + previously accepted solvents
        candidate_coords: np.ndarray (Nc,3) - coordinates of candidate solvent atoms
        candidate_residues: np.ndarray (Nc,) - residue index for each candidate atom
        buffer: float - minimum allowed distance
        
    Returns:
        filtered_candidate_coords: np.ndarray
        filtered_candidate_residues: np.ndarray
        keep_mask: boolean array of length Nc indicating kept atoms
    """
    #reshape candidate_residues to ensure it's a 1D array
    candidate_residues = candidate_residues.reshape(-1)
    # Build KDTree for existing atoms
    tree = cKDTree(existing_coords)

    # Query all candidate atoms for neighbors within buffer
    # Returns list of neighbor indices for each candidate atom
    neighbors = tree.query_ball_point(candidate_coords, r=buffer)

    # Candidate atoms that have at least one neighbor are overlapping
    mask_overlap = np.array([len(neigh) > 0 for neigh in neighbors])

    # Identify residues to remove
    bad_residues = np.unique(candidate_residues[mask_overlap])

    # Keep atoms whose residue is not in bad_residues
    keep_mask = ~np.isin(candidate_residues, bad_residues)
    drop_mask = np.isin(candidate_residues, bad_residues)

    filtered_candidate_coords = candidate_coords[keep_mask]
    filtered_candidate_residues = candidate_residues[keep_mask]

    return keep_mask, drop_mask


# -----------------------------
# Step 4: Cluster isolated cavity points
# -----------------------------
def cluster_candidates(candidate_coords, candidate_residues, cluster_radius):
    """
    Cluster candidate residues using nearest neighbor distance threshold.
    Returns cluster centers and representative residue indices.
    """
    #check shape of candidate_residues
    candidate_residues = candidate_residues.reshape(-1)
    # Compute centers of each unique residue
    unique_residues = np.unique(candidate_residues)
    centers = np.array([
        candidate_coords[candidate_residues == res].astype(
            np.float32).mean(axis=0) for res in unique_residues
    ])

    # Flatten pairwise distances
    diff = centers[:, None, :] - centers[None, :, :]
    dist2 = np.sum(diff**2, axis=2)

    # Build mask for neighbors
    neighbors_mask = dist2 < cluster_radius**2
    np.fill_diagonal(neighbors_mask, False)

    # Accept centers that have no neighbor within cluster_radius
    isolated_mask = ~np.any(neighbors_mask, axis=1)
    accepted_centers = centers[isolated_mask]
    accepted_residues = unique_residues[isolated_mask]

    return accepted_centers, accepted_residues


_solute_labels, _solute_coords = read_xyz(
    "output/UiO-66_mofbuilder_output.xyz")
_water_labels, _water_coords = read_xyz("water.xyz")
_dmso_labels, _dmso_coords = read_xyz("dmso.xyz")
box_size = [30.0, 30.0, 30.0]  # Define your box size
target_water_number = 1000  # Target number of water molecules
target_dmso_number = 500  # Target number of DMSO molecules
buffer = 1.0  # Minimum distance to existing atoms
water_candidates, water_labels, water_residue_idx = generate_candidates_each_solvent(
    _water_coords,
    _water_labels,
    box_size,
    target_water_number,
    residue_idx_start=0)
dmso_candidates, dmso_labels, dmso_residue_idx = generate_candidates_each_solvent(
    _dmso_coords,
    _dmso_labels,
    box_size,
    target_dmso_number,
    residue_idx_start=target_water_number)

water_candidates_info = np.hstack(
    (water_residue_idx, water_labels, water_candidates))
dmso_candidates_info = np.hstack(
    (dmso_residue_idx, dmso_labels, dmso_candidates))
all_candidates_info = np.vstack((water_candidates_info, dmso_candidates_info))
all_candidate_coords = all_candidates_info[:, 2:5]
all_candidate_labels = all_candidates_info[:, 1]
all_candidate_residues = all_candidates_info[:, 0].astype(int)

# Minimum distance to existing atoms
print("Total candidate atoms:", all_candidate_coords.shape[0])
keep_mask, drop_mask = remove_overlaps_kdtree(_solute_coords,
                                              all_candidate_coords,
                                              all_candidate_residues, buffer)
print("Number of accepted candidate atoms:", np.sum(keep_mask))
keep_candidate_coords = all_candidate_coords[keep_mask]
keep_candidate_labels = all_candidate_labels[keep_mask]

cavity_coords = all_candidate_coords[drop_mask]
cavity_residues = all_candidate_residues[drop_mask]
cavity_labels = all_candidate_labels[drop_mask]

possible_centers, possible_residues = cluster_candidates(cavity_coords,
                                                         cavity_residues,
                                                         cluster_radius=3.0)
print("Number of possible centers:", possible_centers.shape)

#regenerate solvent molecules at possible centers
round2water_candidates, round2water_labels, round2water_residue_idx = generate_candidates_each_solvent(
    _water_coords,
    _water_labels,
    box_size,
    target_number=possible_centers.shape[0],
    residue_idx_start=target_water_number + target_dmso_number,
    points_template=possible_centers)

round2water_info = np.hstack(
    (round2water_residue_idx, round2water_labels, round2water_candidates))

print("Number of newly added water atoms:", keep_candidate_coords.shape)

round2_keep_mask, round2_drop_mask = remove_overlaps_kdtree(
    keep_candidate_coords, round2water_candidates, round2water_residue_idx,
    buffer)

#use possible centers to add solvent molecules repeatedly until no more can be added, if no possible centersm then get a random dots in the box
round2_keep_candidate_coords = round2water_candidates[round2_keep_mask]
round2_keep_candidate_labels = round2water_labels[round2_keep_mask]
round2_cavity_coords = round2water_candidates[round2_drop_mask]
round2_cavity_residues = round2water_residue_idx[round2_drop_mask]
round2_cavity_labels = round2water_labels[round2_drop_mask]

round2_possible_centers, round2_possible_residues = cluster_candidates(
    round2_cavity_coords, round2_cavity_residues, cluster_radius=3.0)
print("Number of possible centers:", round2_possible_centers.shape)

#merge the two sets of kept candidates
final_candidate_coords = np.vstack(
    (_solute_coords, keep_candidate_coords, round2_keep_candidate_coords))
#merge the 1d arrays of labels
final_candidate_labels = np.hstack(
    (np.array(_solute_labels), keep_candidate_labels,
     round2_keep_candidate_labels.flatten()))

data = np.hstack((final_candidate_labels.reshape(-1,
                                                 1), final_candidate_coords))

#write to xyz file
with open("solvated_structure.xyz", "w") as fp:
    fp.write(f"{data.shape[0]}\n")
    fp.write("Solvated structure\n")
    for row in data:
        fp.write(
            f"{row[0]} {float(row[1]):.4f} {float(row[2]):.4f} {float(row[3]):.4f}\n"
        )
print(len(data), "atoms written to solvated_structure.xyz")

TypeError: generate_cluster_points() missing 1 required positional argument: 'box_size'

In [2]:
import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.spatial import cKDTree


class SolventPacker:

    def __init__(self, buffer=1.0, box_size=(30.0, 30.0, 30.0)):
        """
        Parameters
        ----------
        buffer : float
            Minimum allowed distance between atoms.
        box_size : tuple of float
            Dimensions of the box for placing solvents.
        """
        self.buffer = buffer
        self.box_size = np.array(box_size)
        self.verbose = True

    def read_xyz(self, filename):
        labels, coords = [], []
        with open(filename) as f:
            lines = f.readlines()
        for line in lines[2:]:
            parts = line.split()
            if len(parts) >= 4:
                labels.append(parts[0])
                coords.append(
                    [float(parts[1]),
                     float(parts[2]),
                     float(parts[3])])
        return labels, np.array(coords)

    def generate_cluster_points(self,
                                target_number,
                                box_size,
                                min_distance=0.5,
                                cluster_center=None,
                                cluster_radius=None,
                                seed=None):
        """
        Generate points with minimum distance constraint, optionally clustered, fast and reproducible.

        Parameters
        ----------
        target_number : int
            Number of points to generate.
        box_size : array-like
            Box size in x, y, z.
        min_distance : float
            Minimum distance between points.
        cluster_center : array-like, optional
            Center of cluster. If None, points fill the box.
        cluster_radius : float, optional
            Radius of cluster sphere.
        seed : int, optional
            Random seed.

        Returns
        -------
        points : np.ndarray (target_number, 3)
        """
        rng = np.random.default_rng(seed)

        if cluster_center is not None and cluster_radius is not None:
            # Generate points on a cubic grid inside the sphere
            n_side = int(np.ceil((2 * cluster_radius) / min_distance))
            lin = np.linspace(-cluster_radius, cluster_radius, n_side)
            X, Y, Z = np.meshgrid(lin, lin, lin, indexing='ij')
            pts = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
            # Keep points inside the sphere
            mask = np.linalg.norm(pts, axis=1) <= cluster_radius
            pts = pts[mask] + cluster_center
        else:
            # Generate points on a cubic grid inside the box
            n_side = [int(np.ceil(bs / min_distance)) for bs in box_size]
            linx = np.linspace(0, box_size[0], n_side[0], endpoint=False)
            liny = np.linspace(0, box_size[1], n_side[1], endpoint=False)
            linz = np.linspace(0, box_size[2], n_side[2], endpoint=False)
            X, Y, Z = np.meshgrid(linx, liny, linz, indexing='ij')
            pts = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T

        # Shuffle and pick required number
        rng.shuffle(pts)
        if pts.shape[0] < target_number:
            print(
                f"Not enough points generated: {pts.shape[0]} < {target_number}"
            )
            return pts
        return pts[:target_number]

    def generate_candidates_each_solvent(self,
                                         solvent_coords,
                                         solvent_labels,
                                         target_number,
                                         residue_idx_start=0,
                                         points_template=None):
        n_atoms = solvent_coords.shape[0]
        if points_template is None:
            random_points = np.random.rand(target_number, 3) * self.box_size
            #random_points = self.generate_cluster_points(
            #    target_number, self.box_size, min_distance=2.0,
            #    seed=42)  # (target_number,3)
        else:
            random_points = points_template
        target_number = random_points.shape[0]

        rots = R.random(target_number).as_matrix()
        coords_exp = solvent_coords[np.newaxis, :, :]
        rot_coords = np.matmul(coords_exp, rots.transpose(0, 2, 1))
        candidates = rot_coords.reshape(-1, 3)
        candidates += np.repeat(random_points, n_atoms, axis=0)

        labels = np.array(list(solvent_labels) * target_number).reshape(-1, 1)
        residue_idx = np.repeat(
            np.arange(residue_idx_start, residue_idx_start + target_number),
            n_atoms).reshape(-1, 1)

        return candidates, labels, residue_idx

    def remove_overlaps_kdtree(self, existing_coords, candidate_coords,
                               candidate_residues):
        candidate_residues = candidate_residues.reshape(-1)

        # Round 1: overlap with existing atoms
        tree_existing = cKDTree(existing_coords)
        neighbors_existing = tree_existing.query_ball_point(candidate_coords,
                                                            r=self.buffer)
        mask_overlap_existing = np.array(
            [len(neigh) > 0 for neigh in neighbors_existing])
        bad_residues_existing = np.unique(
            candidate_residues[mask_overlap_existing])

        # Round 2: overlap among candidates
        tree_candidates = cKDTree(candidate_coords)
        neighbors_candidates = tree_candidates.query_ball_point(
            candidate_coords, r=self.buffer)
        mask_overlap_candidates = np.array([
            any(candidate_residues[neigh] != candidate_residues[i]
                for neigh in neighbors_candidates[i] if neigh != i)
            for i in range(len(candidate_coords))
        ])
        bad_residues_candidates = np.unique(
            candidate_residues[mask_overlap_candidates])

        bad_residues = np.union1d(bad_residues_existing,
                                  bad_residues_candidates)
        keep_mask = ~np.isin(candidate_residues, bad_residues)
        drop_mask = ~keep_mask
        return keep_mask, drop_mask

    def _remove_overlaps_kdtree(self, existing_coords, candidate_coords,
                                candidate_residues):
        print(f"shape check: existing_coords {existing_coords.shape}, "
              f"candidate_coords {candidate_coords.shape}, "
              f"candidate_residues {candidate_residues.shape}")
        candidate_residues = candidate_residues.reshape(-1)
        tree = cKDTree(existing_coords)
        neighbors = tree.query_ball_point(candidate_coords, r=self.buffer)
        mask_overlap = np.array([len(neigh) > 0 for neigh in neighbors])
        bad_residues = np.unique(candidate_residues[mask_overlap])
        keep_mask = ~np.isin(candidate_residues, bad_residues)
        drop_mask = np.isin(candidate_residues, bad_residues)

        filtered_candidate_coords = candidate_coords[keep_mask]
        filtered_candidate_residues = candidate_residues[keep_mask]

        return keep_mask, drop_mask

    def cluster_candidates(self, candidate_coords, candidate_residues,
                           cluster_radius):
        candidate_residues = candidate_residues.reshape(-1)
        unique_residues = np.unique(candidate_residues)
        centers = np.array([
            candidate_coords[candidate_residues == res].astype(
                np.float32).mean(axis=0) for res in unique_residues
        ])

        diff = centers[:, None, :] - centers[None, :, :]
        dist2 = np.sum(diff**2, axis=2)
        neighbors_mask = dist2 < cluster_radius**2
        np.fill_diagonal(neighbors_mask, False)

        isolated_mask = ~np.any(neighbors_mask, axis=1)
        accepted_centers = centers[isolated_mask]
        accepted_residues = unique_residues[isolated_mask]

        return accepted_centers, accepted_residues

    def solvate(self,
                solute_file,
                water_file,
                dmso_file,
                target_water_number=1000,
                target_dmso_number=500,
                output_file="solvated_structure.xyz",
                trial_rounds=10):

        # --- read solute and solvents ---
        _solute_labels, _solute_coords = self.read_xyz(solute_file)
        _water_labels, _water_coords = self.read_xyz(water_file)
        _dmso_labels, _dmso_coords = self.read_xyz(dmso_file)

        best_coords = None
        best_labels = None
        max_added = 0

        # --- Trial loop for 10 random seeds ---
        for trial in range(trial_rounds):
            if self.verbose:
                print(f"Trial {trial+1}/{trial_rounds},before{max_added}")

            np.random.seed(trial)  # different random seed for each trial

            # --- Generate initial solvent candidates ---
            water_candidates, water_labels, water_residue_idx = self.generate_candidates_each_solvent(
                _water_coords,
                _water_labels,
                target_water_number,
                residue_idx_start=0)
            dmso_candidates, dmso_labels, dmso_residue_idx = self.generate_candidates_each_solvent(
                _dmso_coords,
                _dmso_labels,
                target_dmso_number,
                residue_idx_start=target_water_number)

            all_candidates_info = np.vstack(
                (np.hstack(
                    (water_residue_idx, water_labels, water_candidates)),
                 np.hstack((dmso_residue_idx, dmso_labels, dmso_candidates))))

            all_candidate_coords = all_candidates_info[:, 2:5].astype(float)
            all_candidate_labels = all_candidates_info[:, 1]
            all_candidate_residues = all_candidates_info[:, 0].astype(int)

            print(
                f"Generated {all_candidate_coords.shape[0]} initial solvent candidates."
            )

            # --- Round 1 overlap removal ---
            keep_mask, drop_mask = self.remove_overlaps_kdtree(
                _solute_coords, all_candidate_coords, all_candidate_residues)
            accepted_coords = all_candidate_coords[keep_mask]
            accepted_labels = all_candidate_labels[keep_mask]
            cavity_coords = all_candidate_coords[drop_mask]
            cavity_residues = all_candidate_residues[drop_mask]

            print(
                f"After Round 1 overlap removal: {accepted_coords.shape[0]} accepted, {cavity_coords.shape[0]} left in cavity."
            )

            # --- Iterative cavity filling (big round) ---
            max_fill_rounds = 1000
            round_idx = 0
            while round_idx < max_fill_rounds and cavity_coords.shape[0] > 0:
                round_idx += 1
                possible_centers, _ = self.cluster_candidates(
                    cavity_coords, cavity_residues, cluster_radius=self.buffer)
                print(
                    f"Round {round_idx}: {possible_centers.shape[0]} possible centers identified."
                )

                if possible_centers.shape[0] == 0:
                    break

                round_candidates, round_labels, round_residues = self.generate_candidates_each_solvent(
                    _water_coords,
                    _water_labels,
                    target_number=possible_centers.shape[0],
                    residue_idx_start=accepted_coords.shape[0],
                    points_template=possible_centers)

                keep_mask, drop_mask = self.remove_overlaps_kdtree(
                    accepted_coords, round_candidates, round_residues)
                round_keep_coords = round_candidates[keep_mask]
                round_keep_labels = round_labels[keep_mask]

                print(
                    f"Round {round_idx}: {round_keep_coords.shape[0]} added, {round_candidates[drop_mask].shape[0]} left in cavity."
                )

                if round_keep_coords.shape[0] == 0:
                    break

                # Update accepted molecules
                accepted_coords = np.vstack(
                    (accepted_coords, round_keep_coords))
                accepted_labels = np.hstack(
                    (accepted_labels, round_keep_labels.flatten()))

                # Update cavity for next iteration
                cavity_coords = round_candidates[drop_mask]
                cavity_residues = round_residues[drop_mask]

            # --- Update best trial ---
            print(
                f"after{max_added},accepted_coords shape {accepted_coords.shape},solute shape {_solute_coords.shape}"
            )
            n_added = accepted_coords.shape[0]
            if n_added > max_added:
                max_added = n_added
                best_coords = accepted_coords.copy()
                best_labels = accepted_labels.copy()
                if self.verbose:
                    print(
                        f"Trial {trial+1} is new best with {n_added} added atoms."
                    )

        # --- Merge solute and best solvent trial ---
        if best_coords is not None:
            final_coords = np.vstack(
                (_solute_coords, best_coords.astype(float)))
        else:
            raise ValueError(
                "No valid solvent molecules were added in any trial.")
        final_labels = np.hstack((_solute_labels, best_labels))

        # --- Write xyz ---
        with open(output_file, "w") as fp:
            fp.write(f"{final_coords.shape[0]}\n")
            fp.write("Solvated structure\n")
            for label, (x, y, z) in zip(final_labels, final_coords):
                fp.write(f"{label} {x:.4f} {y:.4f} {z:.4f}\n")

        if self.verbose:
            print(len(final_coords), "atoms written to", output_file)
            print("Best trial added", max_added, "solvent atoms.")

        return final_coords, final_labels


In [4]:
box_size = [80.0, 50.0, 50.0]  # Å³
water_density = 1.0  # g/cm³
water_molar_mass = 18.015  # g/mol


def solvent_number_from_density(box_size, density, molar_mass):
    """
    Compute the number of solvent molecules to fill a box at given density.

    Parameters
    ----------
    box_size : array-like of length 3
        Box size in Angstroms [x, y, z]
    density : float
        Density in g/cm³
    molar_mass : float
        Molar mass in g/mol

    Returns
    -------
    n_molecules : int
        Number of solvent molecules to generate
    """
    V_A3 = np.prod(box_size)  # Å³
    V_cm3 = V_A3 * 1e-24  # cm³
    N_A = 6.022e23  # Avogadro's number
    n_molecules = int(density * V_cm3 * N_A / molar_mass)
    return n_molecules


n_water = solvent_number_from_density(box_size, water_density,
                                      water_molar_mass)
print("Number of water molecules:", n_water)

dmso_density = 1.1  # g/cm³
dmso_molar_mass = 78.13  # g/mol
n_dmso = solvent_number_from_density(box_size, dmso_density, dmso_molar_mass)
print("Number of DMSO molecules:", n_dmso)

Number of water molecules: 6685
Number of DMSO molecules: 1695


In [5]:
packer = SolventPacker(buffer=1, box_size=(80.0, 50.0, 50.0))
packer.solvate(solute_file="output/UiO-66_mofbuilder_output.xyz",
               water_file="water.xyz",
               dmso_file="dmso.xyz",
               target_water_number=4000,
               target_dmso_number=20,
               output_file="solvated_structure.xyz",
               trial_rounds=50)

Trial 1/50,before0
Generated 12200 initial solvent candidates.
After Round 1 overlap removal: 7267 accepted, 4933 left in cavity.
Round 1: 1316 possible centers identified.
Round 1: 2664 added, 1284 left in cavity.
Round 2: 378 possible centers identified.
Round 2: 732 added, 402 left in cavity.
Round 3: 124 possible centers identified.
Round 3: 225 added, 147 left in cavity.
Round 4: 47 possible centers identified.
Round 4: 96 added, 45 left in cavity.
Round 5: 15 possible centers identified.
Round 5: 30 added, 15 left in cavity.
Round 6: 5 possible centers identified.
Round 6: 9 added, 6 left in cavity.
Round 7: 2 possible centers identified.
Round 7: 3 added, 3 left in cavity.
Round 8: 1 possible centers identified.
Round 8: 0 added, 3 left in cavity.
after0,accepted_coords shape (11026, 3),solute shape (8326, 3)
Trial 1 is new best with 11026 added atoms.
Trial 2/50,before11026
Generated 12200 initial solvent candidates.
After Round 1 overlap removal: 6963 accepted, 5237 left in ca

(array([[11.042     , -2.826     , 10.44      ],
        [10.066     , -1.973     , 11.007     ],
        [ 9.257     , -2.138     ,  9.859     ],
        ...,
        [22.12479607,  7.92882089,  9.89420294],
        [21.2038733 ,  8.09331263,  9.56947691],
        [22.25852514,  6.95454096,  9.77802259]]),
 array(['D', 'D', 'D', ..., 'O', 'H', 'H'], dtype='<U32'))