# Site-Fraction Mixture Representation

This notebook defines a **site-fraction mixture** representation for crystalline materials with one or more **sublattices**. This formalism is standard in **ordered alloys, interstitial solutions, defect chemistry, and CALPHAD-style models**, where bulk mole fractions alone cannot describe the configuration.

The representation treats **site occupancy**, not total moles, as the fundamental state variable.

---

## Motivation

Many crystalline phases cannot be described by a single set of mole fractions $X_i$:

- Ordered compounds (e.g. B2, L1$_2$)
- Interstitial solutions (C in Fe)
- Vacancy-mediated diffusion
- Defect thermodynamics
- Sublattice models in CALPHAD

In these systems, it matters *which lattice sites* each species occupies, not just how much of it is present.

---

## Definitions

Consider a phase with sublattices $s = 1,\dots,S$.

Let:
- $N_s$ be the number of sites on sublattice $s$
- $n_{i}^{(s)}$ be the number of sites on sublattice $s$ occupied by species $i$

### Site Fractions

The **site fraction** of species $i$ on sublattice $s$ is

$$
y_i^{(s)} \equiv \frac{n_i^{(s)}}{N_s}.
$$
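As a quick numerical reading of this definition, the minimal sketch below divides each occupancy count by $N_s$ for a single sublattice. The helper name `site_fractions` is hypothetical, and the interstitial sublattice with 200 sites (10 holding C, 190 vacant) is an illustrative assumption.

```python
from typing import Dict


def site_fractions(counts: Dict[str, float], n_sites: float) -> Dict[str, float]:
    """Return y_i^(s) = n_i^(s) / N_s for one sublattice (hypothetical helper)."""
    if n_sites <= 0:
        raise ValueError("N_s must be positive")
    return {species: n / n_sites for species, n in counts.items()}


# Hypothetical interstitial sublattice with N_s = 200 sites:
# 10 sites hold C and the remaining 190 are vacant (Va)
y = site_fractions({"C": 10.0, "Va": 190.0}, n_sites=200.0)
print(y)  # {'C': 0.05, 'Va': 0.95}
```

Counting vacancies as an explicit species is what makes every site "occupied", so the counts on a sublattice add up to $N_s$.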

For each sublattice $s$:

$$
y_i^{(s)} \ge 0,
\qquad
\sum_{i \in \mathcal{A}_s} y_i^{(s)} = 1,
$$

where $\mathcal{A}_s$ is the set of species allowed on sublattice $s$.
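Because the normalization is imposed separately on each sublattice, the admissible states form a product of simplices rather than one global simplex. The sketch below checks the constraints under that reading. The function `check_site_fractions`, its tolerance, and the B2-like example values are hypothetical.

```python
from typing import Dict, Sequence


def check_site_fractions(
    y_by_sublattice: Dict[str, Sequence[float]], tol: float = 1e-9
) -> bool:
    """Return True if every sublattice has y_i >= 0 and sum_i y_i = 1.

    Each sublattice is checked independently (one simplex per sublattice).
    """
    for y in y_by_sublattice.values():
        if any(v < 0 for v in y):
            return False
        if abs(sum(y) - 1.0) > tol:
            return False
    return True


# Perfectly ordered B2-like state: A fills "alpha", B fills "beta"
state = {"alpha": [1.0, 0.0], "beta": [0.0, 1.0]}
print(check_site_fractions(state))  # True

# Summed over all sublattices the fractions total S = 2, not 1,
# so the state does not lie on a single global simplex
print(sum(sum(y) for y in state.values()))  # 2.0
```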


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence
import numpy as np


@dataclass(frozen=True)
class SublatticeSpecification:
    """
    Specification of a sublattice for a site-fraction mixture.

    Parameters
    ----------
    name : str
        Sublattice name (e.g., "sub", "int", "alpha", "beta").
    species : list of str
        Allowed species on this sublattice (e.g., ["Fe", "Cr", "Va"]).
    multiplicity : float, optional
        Number of sites of this sublattice per formula unit. Defaults to 1.0.

    Notes
    -----
    The multiplicity is used to map site fractions to overall composition:
        N_i = sum_s multiplicity_s * y_i^(s)
    """
    name: str
    species: List[str]
    multiplicity: float = 1.0

    def __post_init__(self):
        if not self.name:
            raise ValueError("Sublattice name must be non-empty")

        if len(self.species) == 0:
            raise ValueError("Sublattice species list must be non-empty")

        if len(set(self.species)) != len(self.species):
            raise ValueError(f"Duplicate species in sublattice '{self.name}'")

        if self.multiplicity <= 0:
            raise ValueError("Sublattice multiplicity must be positive")


@dataclass(frozen=True)
class SiteFractionMixture:
    """
    Mixture defined by site fractions on one or more sublattices.

    Parameters
    ----------
    sublattices : list of SublatticeSpecification
        Sublattice specifications defining allowed species and multiplicities.
    y_by_sublattice : dict[str, np.ndarray]
        Mapping from sublattice name to site-fraction vector ordered as
        `sublattices[k].species`.

    Notes
    -----
    For each sublattice s, site fractions satisfy:
        - y_i^(s) >= 0
        - sum_i y_i^(s) = 1

    The admissible states therefore form a product of simplices (one simplex
    per sublattice), not a single global simplex.
    """
    sublattices: List[SublatticeSpecification]
    y_by_sublattice: Dict[str, np.ndarray]

    def __post_init__(self):
        if len(self.sublattices) == 0:
            raise ValueError("At least one sublattice is required")

        names = [sl.name for sl in self.sublattices]
        if len(set(names)) != len(names):
            raise ValueError("Sublattice names must be unique")

        # Validate y vectors for each declared sublattice and collect canonical
        # float copies, so the caller's dict and arrays are never mutated or aliased
        validated: Dict[str, np.ndarray] = {}
        for sl in self.sublattices:
            if sl.name not in self.y_by_sublattice:
                raise ValueError(f"Missing site-fraction vector for sublattice '{sl.name}'")

            y = np.array(self.y_by_sublattice[sl.name], dtype=float, copy=True)

            if y.ndim != 1:
                raise ValueError(f"y for sublattice '{sl.name}' must be a 1D array")

            if y.size != len(sl.species):
                raise ValueError(
                    f"y length mismatch for sublattice '{sl.name}': "
                    f"expected {len(sl.species)}, got {y.size}"
                )

            if np.any(y < 0):
                raise ValueError(f"Negative site fraction in sublattice '{sl.name}'")

            if not np.isclose(y.sum(), 1.0):
                raise ValueError(f"Site fractions must sum to 1 in sublattice '{sl.name}'")

            validated[sl.name] = y

        # Also reject any extra keys not declared in sublattices (strictness)
        extra = set(self.y_by_sublattice.keys()) - set(names)
        if extra:
            raise ValueError(f"Unknown sublattice keys in y_by_sublattice: {sorted(extra)}")

        # The dataclass is frozen, so bypass __setattr__ to store the validated copies
        object.__setattr__(self, "y_by_sublattice", validated)

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------
    def site_fraction(self, sublattice: str, species: str) -> float:
        """
        Return the site fraction y_i^(s) for a given (sublattice, species) pair.

        Parameters
        ----------
        sublattice : str
            Sublattice name.
        species : str
            Species name.

        Returns
        -------
        float
            Site fraction of the given species on the given sublattice.

        Raises
        ------
        KeyError
            If the sublattice or species is not present.
        """
        sl = self._get_sublattice_spec(sublattice)
        try:
            j = sl.species.index(species)
        except ValueError:
            raise KeyError(f"Species '{species}' not allowed on sublattice '{sublattice}'") from None

        return float(self.y_by_sublattice[sublattice][j])

    def species_on(self, sublattice: str) -> List[str]:
        """
        Return the ordered species list for a sublattice.

        Parameters
        ----------
        sublattice : str
            Sublattice name.

        Returns
        -------
        list of str
            Species allowed on that sublattice, in the order used by y.
        """
        return list(self._get_sublattice_spec(sublattice).species)

    def y_vector(self, sublattice: str) -> np.ndarray:
        """
        Return a copy of the site-fraction vector for a sublattice.

        Parameters
        ----------
        sublattice : str
            Sublattice name.

        Returns
        -------
        np.ndarray
            1D array of site fractions ordered as `species_on(sublattice)`.
        """
        return np.array(self.y_by_sublattice[sublattice], dtype=float, copy=True)

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @staticmethod
    def from_site_counts(
        sublattices: List[SublatticeSpecification],
        n_by_sublattice: Dict[str, Sequence[float] | np.ndarray],
    ) -> "SiteFractionMixture":
        """
        Construct a SiteFractionMixture from site-occupation counts.

        Parameters
        ----------
        sublattices : list of SublatticeSpecification
            Sublattice specifications.
        n_by_sublattice : dict[str, array-like]
            Mapping sublattice name -> counts n_i^(s) ordered as that sublattice's
            species list.

        Returns
        -------
        SiteFractionMixture
            Site-fraction mixture where y_i^(s) = n_i^(s) / sum_j n_j^(s).

        Raises
        ------
        ValueError
            If a sublattice's counts are missing, malformed, or negative, or if the
            total site count on any sublattice is not positive.
        """
        y_map: Dict[str, np.ndarray] = {}

        for sl in sublattices:
            if sl.name not in n_by_sublattice:
                raise ValueError(f"Missing counts for sublattice '{sl.name}'")

            n = np.asarray(n_by_sublattice[sl.name], dtype=float)

            if n.ndim != 1:
                raise ValueError(f"Counts for sublattice '{sl.name}' must be 1D")

            if n.size != len(sl.species):
                raise ValueError(
                    f"Count length mismatch for sublattice '{sl.name}': "
                    f"expected {len(sl.species)}, got {n.size}"
                )

            if np.any(n < 0):
                raise ValueError(f"Negative counts in sublattice '{sl.name}'")

            n_tot = n.sum()
            if n_tot <= 0:
                raise ValueError(f"Total site count must be positive in sublattice '{sl.name}'")

            y_map[sl.name] = n / n_tot

        return SiteFractionMixture(sublattices=sublattices, y_by_sublattice=y_map)

    # ------------------------------------------------------------------
    # Mapping to overall composition
    # ------------------------------------------------------------------
    def overall_amounts_per_formula_unit(self) -> Dict[str, float]:
        """
        Compute species amounts per formula unit from site fractions.

        Returns
        -------
        dict[str, float]
            Mapping species -> N_i where:
                N_i = sum_s multiplicity_s * y_i^(s)
            Species not present on any sublattice are absent from the dict.
        """
        N: Dict[str, float] = {}

        for sl in self.sublattices:
            y = self.y_by_sublattice[sl.name]
            a = sl.multiplicity

            for sp, y_sp in zip(sl.species, y):
                N[sp] = N.get(sp, 0.0) + a * float(y_sp)

        return N

    def overall_mole_fractions(
        self,
        exclude: Optional[Iterable[str]] = None,
    ) -> Dict[str, float]:
        """
        Compute overall mole fractions from site fractions.

        Parameters
        ----------
        exclude : iterable of str, optional
            Species to exclude from the normalization (commonly {"Va"}).

        Returns
        -------
        dict[str, float]
            Overall mole fractions X_i computed from N_i per formula unit:
                X_i = N_i / sum_j N_j
            where the sum excludes any species in `exclude`.

        Raises
        ------
        ValueError
            If the total included amount is not positive.
        """
        N = self.overall_amounts_per_formula_unit()
        excl = set(exclude) if exclude is not None else set()

        total = 0.0
        for sp, val in N.items():
            if sp not in excl:
                total += val

        if total <= 0:
            raise ValueError("Total included amount must be positive")

        X: Dict[str, float] = {}
        for sp, val in N.items():
            if sp not in excl:
                X[sp] = val / total

        return X

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_sublattice_spec(self, name: str) -> SublatticeSpecification:
        for sl in self.sublattices:
            if sl.name == name:
                return sl
        raise KeyError(f"Sublattice '{name}' not present")


# ----------------------------------------------------------------------
# Example usage (B2 ordering: alpha/beta sublattices)
# ----------------------------------------------------------------------
if __name__ == "__main__":
    sublattices = [
        SublatticeSpecification(name="alpha", species=["A", "B"], multiplicity=1.0),
        SublatticeSpecification(name="beta",  species=["A", "B"], multiplicity=1.0),
    ]

    mix = SiteFractionMixture(
        sublattices=sublattices,
        y_by_sublattice={
            "alpha": np.array([1.0, 0.0]),  # A on alpha
            "beta":  np.array([0.0, 1.0]),  # B on beta
        },
    )

    print(mix.site_fraction("alpha", "A"))              # 1.0
    print(mix.overall_amounts_per_formula_unit())       # {'A': 1.0, 'B': 1.0}
    print(mix.overall_mole_fractions())                 # {'A': 0.5, 'B': 0.5}
    print(mix.overall_mole_fractions(exclude={"Va"}))   # same here (no vacancies)
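
The mapping $N_i = \sum_s a_s\, y_i^{(s)}$ (with $a_s$ the multiplicity of sublattice $s$) used by `overall_amounts_per_formula_unit`, and the vacancy exclusion in `overall_mole_fractions`, can be checked by hand on an interstitial case such as C in Fe. The standalone sketch below assumes a $(\mathrm{Fe})_1(\mathrm{C},\mathrm{Va})_1$ sublattice description with an illustrative C site fraction of 0.05; the function names `amounts_per_formula_unit` and `mole_fractions` and all numbers are hypothetical.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical representation: (multiplicity a_s, {species: y_i^(s)})
Sublattice = Tuple[float, Dict[str, float]]


def amounts_per_formula_unit(sublattices: List[Sublattice]) -> Dict[str, float]:
    """N_i = sum_s a_s * y_i^(s), accumulated over all sublattices."""
    N: Dict[str, float] = {}
    for a_s, y in sublattices:
        for species, y_i in y.items():
            N[species] = N.get(species, 0.0) + a_s * y_i
    return N


def mole_fractions(
    sublattices: List[Sublattice], exclude: Iterable[str] = ()
) -> Dict[str, float]:
    """X_i = N_i / sum_j N_j, with excluded species (e.g. "Va") dropped."""
    excl = set(exclude)
    N = amounts_per_formula_unit(sublattices)
    total = sum(v for sp, v in N.items() if sp not in excl)
    if total <= 0:
        raise ValueError("Total included amount must be positive")
    return {sp: v / total for sp, v in N.items() if sp not in excl}


# Hypothetical (Fe)1(C,Va)1 state: 5% of interstitial sites hold C
fe_c = [(1.0, {"Fe": 1.0}), (1.0, {"C": 0.05, "Va": 0.95})]

print(amounts_per_formula_unit(fe_c))  # {'Fe': 1.0, 'C': 0.05, 'Va': 0.95}
print(round(mole_fractions(fe_c, exclude={"Va"})["C"], 4))  # 0.0476, i.e. 0.05 / 1.05
print(round(mole_fractions(fe_c)["C"], 4))  # 0.025: vacancies wrongly counted as atoms
```

Leaving `Va` in the normalization treats empty sites as if they were atoms, which is why excluding vacancies is the usual choice when reporting overall mole fractions.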
