In [1]:
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Tuple, Optional, List
import os
from rdkit import Chem
from rdkit.Chem import AllChem
import copy


def _fa_key(fa: Optional[FattyAcid]) -> tuple:
    """
    Return a fully comparable, canonical sorting key for a FattyAcid (or None).

    This function is used to impose a total ordering on fatty acids so that
    sn-1 and sn-3 positions can be treated symmetrically in SymmetricGlyceride.

    The returned key is a tuple of immutable, orderable components that
    completely describe the canonical fatty acid structure.

    Rules:
    • None (EMPTY) sorts before any actual FattyAcid.
    • Otherwise, ordering is determined lexicographically by:
        (length, db_positions, db_stereo, branches)
        where each field comes from the canonicalized FattyAcid.
    """
    if fa is None:
        # Empty slot (e.g., diacylglyceride)
        return (0,)

    # Canonicalize first (sorts db positions, normalizes Z/E, etc.)
    fac = fa.canonical()

    # The key must contain only immutable, natively comparable types.
    # Each attribute is already sorted and normalized in canonical().
    return (
        1,                                 # Tag so all real FAs > EMPTY
        fac.length,                        # Chain length
        tuple(fac.db_positions),           # Sorted double bond positions
        tuple(fac.db_stereo),              # Aligned stereochem (Z/E)
        tuple(fac.branches),               # Sorted branches (pos, label)
    )

@dataclass(frozen=True)
class FattyAcid:
    """
    Immutable description of a fatty acid chain.

    Attributes:
        length (int): number of carbons in the main chain.
        db_positions (Tuple[int, ...]): integers k, each meaning a double bond
            between C_k and C_{k+1} (delta numbering).
        db_stereo (Tuple[str, ...]): one stereo label per entry of
            db_positions. "Z"/"cis"/"C" (kinked chain) and "E"/"trans"/"T"
            (linear chain) are recognised, case-insensitively.
        branches (Tuple[Tuple[int, str], ...]): (position, label) pairs for
            substituents on the main chain, e.g. (15, "Me").

    Methods:
        canonical(): returns a canonical (normalized) copy of the fatty acid.
        name: standardized string name of the fatty acid.
    """

    # default_factory=tuple gives every instance an empty tuple by default
    length: int
    db_positions: Tuple[int, ...] = field(default_factory=tuple)
    db_stereo: Tuple[str, ...] = field(default_factory=tuple)
    branches: Tuple[Tuple[int, str], ...] = field(default_factory=tuple)

    def __post_init__(self):
        """Validate chain length, double-bond positions and branch positions.

        Raises:
            ValueError: on a negative length, mismatched db_positions /
                db_stereo lengths, or an out-of-range position.
        """
        if self.length < 0:
            raise ValueError("length must be >= 0")
        if len(self.db_positions) != len(self.db_stereo):
            raise ValueError("db_positions and db_stereo must be same length")
        for k in self.db_positions:
            # A double bond at k joins C_k and C_{k+1}, so k stops at length-1.
            if not (1 <= k <= max(self.length - 1, 0)):
                raise ValueError(
                    f"double-bond position Δ{k} out of range for C{self.length}"
                )
        for pos, _label in self.branches:
            if not (1 <= pos <= self.length):
                raise ValueError(
                    f"branch position C{pos} out of range for C{self.length}"
                )

    @staticmethod
    def _normalize_stereo(label: str) -> str:
        """Map a stereo label onto "Z" or "E".

        Every spelling of cis maps to "Z" and every spelling of trans to "E".
        Unrecognised labels are returned stripped and upper-cased.
        """
        key = label.strip().lower()
        aliases = {
            "z": "Z", "cis": "Z", "c": "Z",
            "e": "E", "trans": "E", "t": "E",
        }
        return aliases.get(key, key.upper())

    def canonical(self) -> "FattyAcid":
        """
        Return a canonical (normalized) version of the fatty acid so that
        equivalent fatty acids compare and hash identically.

        Double-bond positions are sorted with their stereo labels kept
        aligned, stereo labels are normalized to "Z"/"E", and branches are
        sorted as (int position, str label) pairs.
        """
        # Sort (position, stereo) pairs together so the alignment survives.
        pairs = sorted(
            (p, self._normalize_stereo(st))
            for p, st in zip(self.db_positions, self.db_stereo)
        )
        positions = tuple(p for p, _st in pairs)
        stereo = tuple(st for _p, st in pairs)
        branches = tuple(sorted((int(p), str(lbl)) for p, lbl in self.branches))
        return FattyAcid(self.length, positions, stereo, branches)

    @property
    def name(self) -> str:
        """
        Generate a standardized name for the fatty acid using the format:

        N{CC}D{DD}[P{pp}{S}...][M{pp}...][OH{pp}...]

        N{CC} - Number of carbons, zero-padded (e.g. 06 or 12)
        D{DD} - Number of double bonds
        P{pp}{S} - One entry per double bond: position pp and stereo S,
                   where Z is cis and E is trans (e.g. P09Z)
        M{pp} - Methyl branch at position pp
        OH{pp} - Hydroxyl branch at position pp

        The name is built from the canonical form, so double bonds appear in
        ascending position order and "cis"/"trans" are written as Z/E.

        Example: N18D01P09Z means 18 carbons, 1 double bond at position 9 with
        Z (cis) stereo, i.e. oleic acid.
        """
        fac = self.canonical()
        parts = [f"N{fac.length:02d}", f"D{len(fac.db_positions):02d}"]
        parts.extend(
            f"P{p:02d}{s}" for p, s in zip(fac.db_positions, fac.db_stereo)
        )

        methyls: List[str] = []
        hydroxyls: List[str] = []
        for bpos, blabel in fac.branches:
            if blabel.strip().lower() in ("oh", "hydroxyl"):
                hydroxyls.append(f"OH{bpos:02d}")
            else:
                # Every other label keeps the historical 'M' encoding.
                methyls.append(f"M{bpos:02d}")
        return "".join(parts + methyls + hydroxyls)


class Glyceride:
    """
    A glycerol backbone carrying up to three fatty-acid chains.

    Attributes:
        sn: tuple of three Optional[FattyAcid] in sn-1, sn-2, sn-3 order.
            Use None for an empty position (e.g. a diacylglyceride).

    Note:
        Instances are hashable (see __hash__) but add_fatty_acid and
        swap_fatty_acids with deep_copy=False mutate ``sn`` in place. Do not
        mutate an instance that is currently used as a dict key or set member.
    """

    def __init__(self, sn: Tuple[Optional[FattyAcid], Optional[FattyAcid], Optional[FattyAcid]]):
        # Validate before storing, and normalise to a tuple so that a list
        # passed by the caller cannot be mutated behind this object's back.
        sn = tuple(sn)
        if len(sn) != 3:
            raise ValueError("sn must have length 3 (sn-1, sn-2, sn-3)")
        self.sn = sn

    @classmethod
    def from_name(cls, name: str) -> "Glyceride":
        """
        Create a Glyceride using the naming scheme:

        Fatty acid format:

        N{CC}D{DD}[P{pp}{S}...][M{pp}...][OH{pp}...]

        N{CC} - Number of carbons like 06 or 12
        D{DD} - Number of double bonds
        P{pp}{S} - double bond position and stereo, where Z (or C) is cis and
                   E (or T) is trans (e.g. P09Z). Several double bonds may be
                   written either as P09ZP12Z or as P09Z12Z. A missing stereo
                   letter defaults to Z.
        M{pp}... - Methyl branches at position pp
        OH{pp}... - Hydroxyl branches at position pp
        EMPTY - an unoccupied position

        Example: N18D01P09Z means 18 carbons, 1 double bond at position 9 with
        Z (cis) stereo and is oleic acid.

        Glyceride format:

        G_{FA1}_{FA2}_{FA3}

        Example: G_N18D01P09Z_N18D01P09Z_N18D01P09Z is triolein.

        Args:
            name (str): Name of the glyceride in the specified format.

        Returns:
            Glyceride: The corresponding Glyceride object.

        Raises:
            ValueError: if the name does not follow the format, or the number
                of parsed double bonds differs from the D{DD} field.
        """
        # Accepted single-letter stereo codes, normalised to Z/E on parse.
        stereo_letters = {"Z": "Z", "C": "Z", "E": "E", "T": "E"}

        def parse_fatty_acid(fa_str: str) -> Optional[FattyAcid]:
            if fa_str == "EMPTY":
                return None
            if not fa_str.startswith("N") or "D" not in fa_str:
                raise ValueError(f"Invalid fatty acid format: {fa_str}")
            d_index = fa_str.index("D")
            # Everything between 'N' and 'D' is the carbon count, so chain
            # lengths wider than two digits are read correctly as well.
            length = int(fa_str[1:d_index])
            num_db = int(fa_str[d_index + 1 : d_index + 3])
            db_positions = []
            db_stereo = []
            branches = []
            end = len(fa_str)

            def has_position(at: int) -> bool:
                # True when a full two-digit position starts at index `at`.
                return at + 1 < end and fa_str[at : at + 2].isdigit()

            i = d_index + 3
            while i < end:
                if fa_str[i] == "P":
                    i += 1
                    # Read position/stereo pairs for as long as digits follow.
                    # This accepts both 'P09ZP12Z' (the form emitted by
                    # `name`) and 'P09Z12Z' (one P followed by all pairs).
                    while has_position(i):
                        db_positions.append(int(fa_str[i : i + 2]))
                        i += 2
                        if i < end and fa_str[i].upper() in stereo_letters:
                            db_stereo.append(stereo_letters[fa_str[i].upper()])
                            i += 1
                        else:
                            db_stereo.append("Z")  # Default to Z if not specified
                elif fa_str[i] == "M":
                    i += 1
                    while has_position(i):
                        branches.append((int(fa_str[i : i + 2]), "Me"))
                        i += 2
                elif fa_str[i : i + 2] == "OH":
                    i += 2
                    while has_position(i):
                        branches.append((int(fa_str[i : i + 2]), "OH"))
                        i += 2
                else:
                    raise ValueError(f"Invalid fatty acid format: {fa_str}")
            if len(db_positions) != num_db:
                raise ValueError(
                    f"Number of double bonds does not match positions in: {fa_str}"
                )
            return FattyAcid(
                length, tuple(db_positions), tuple(db_stereo), tuple(branches)
            )

        if not name.startswith("G_"):
            raise ValueError(f"Invalid glyceride format: {name}")
        parts = name[2:].split("_")
        if len(parts) != 3:
            raise ValueError(f"Glyceride must have three fatty acids: {name}")
        return cls(tuple(parse_fatty_acid(part) for part in parts))

    def signature_tuple(self) -> Tuple:
        """Canonical, hashable structure signature (topology only).

        One entry per sn position: ("EMPTY",) for None, otherwise
        ("FA", length, db_positions, db_stereo, branches) of the canonical
        fatty acid.
        """
        parts = []
        for fa in self.sn:
            if fa is None:
                parts.append(("EMPTY",))
            else:
                fac = fa.canonical()
                parts.append(
                    ("FA", fac.length, fac.db_positions, fac.db_stereo, fac.branches)
                )
        return tuple(parts)

    def add_fatty_acid(self, index: int, fatty_acid: FattyAcid, deep_copy: bool = True):
        """
        Place a fatty acid on an empty sn position.

        Parameters:
            index (int): Index (0, 1, or 2) to add the fatty acid to.
            fatty_acid (FattyAcid): The fatty acid to add.
            deep_copy (bool): If True (default) leave this glyceride untouched
                and return a new instance holding a deep copy of fatty_acid.
                If False, modify this glyceride in place and return it.

        Returns:
            Glyceride: The glyceride carrying the added fatty acid.

        Raises:
            ValueError: if index is not 0, 1 or 2, or the position is occupied.
        """
        if index not in (0, 1, 2):
            raise ValueError("Index must be 0, 1, or 2.")
        if self.sn[index] is not None:
            raise ValueError(f"Position sn-{index + 1} is already occupied.")

        new_sn = list(self.sn)
        if deep_copy:
            new_sn[index] = copy.deepcopy(fatty_acid)
            return self.__class__(tuple(new_sn))
        new_sn[index] = fatty_acid
        self.sn = tuple(new_sn)
        return self

    def swap_fatty_acids(self, index1: int, index2: int, deep_copy: bool = True):
        """
        Exchange the fatty acids at two sn positions.

        Parameters:
            index1 (int): Index (0, 1, or 2) of the first fatty acid to swap.
            index2 (int): Index (0, 1, or 2) of the second fatty acid to swap.
            deep_copy (bool): If True (default) return a new instance built
                from a deep copy of the chains. If False, swap in place and
                return this instance.

        Returns:
            Glyceride: The glyceride with the two positions exchanged.

        Raises:
            ValueError: if an index is not 0, 1 or 2, or both are equal.
        """
        if index1 not in (0, 1, 2) or index2 not in (0, 1, 2):
            raise ValueError("Indices must be 0, 1, or 2.")
        if index1 == index2:
            raise ValueError("Indices must be different to perform a swap.")

        new_sn = list(copy.deepcopy(self.sn)) if deep_copy else list(self.sn)
        new_sn[index1], new_sn[index2] = new_sn[index2], new_sn[index1]
        if deep_copy:
            return self.__class__(tuple(new_sn))
        self.sn = tuple(new_sn)
        return self

    def glyceride_to_rdkit(self, optimize: bool = True) -> Chem.Mol:
        """
        Build an RDKit molecule for this glyceride and, optionally, embed it
        in 3D and relax it with a force field.

        Args:
            optimize (bool): If True, generate 3D coordinates (ETKDG, called
                with keyword arguments and ETversion=2) and relax them with
                MMFF, falling back to UFF. If False, the sanitized molecule
                with explicit hydrogens is returned WITHOUT any conformer.

        Returns:
            Chem.Mol: The molecule with explicit hydrogens. With optimize=True
            it carries exactly one conformer.

        Raises:
            RuntimeError: if no conformer could be generated.
        """
        rw, sn_os, _ = self._build_glycerol_backbone()
        for idx, fa in enumerate(self.sn):
            if fa is None:
                continue
            carbonyl_c, _ = self._build_acyl_chain(fa.canonical(), rw)
            rw.AddBond(sn_os[idx], carbonyl_c, Chem.BondType.SINGLE)  # ester bond

        mol = rw.GetMol()
        Chem.SanitizeMol(mol)
        mol = Chem.AddHs(mol)

        if not optimize:
            return mol

        # Embedding options shared by every attempt below.
        embed_kwargs = dict(
            maxAttempts=8000,
            randomSeed=0xBEEF,
            boxSizeMult=2.0,
            randNegEig=True,
            numZeroFail=1,
            forceTol=0.001,
            ignoreSmoothingFailures=False,
            enforceChirality=True,
            useExpTorsionAnglePrefs=True,
            useBasicKnowledge=True,
            useSmallRingTorsions=False,
            useMacrocycleTorsions=True,
            ETversion=2,
        )

        # Single-conformer attempts: deterministic start -> random coordinates
        confId = AllChem.EmbedMolecule(mol, useRandomCoords=False, **embed_kwargs)
        if confId == -1:
            confId = AllChem.EmbedMolecule(mol, useRandomCoords=True, **embed_kwargs)

        if confId == -1:
            # Last resort: generate several conformers and keep the best one.
            conf_ids = list(
                AllChem.EmbedMultipleConfs(
                    mol, numConfs=24, useRandomCoords=True, **embed_kwargs
                )
            )
            if not conf_ids:
                raise RuntimeError(
                    "Conformer embedding failed (no conformers generated)."
                )

            # UFFOptimizeMoleculeConfs optimizes every conformer on the
            # molecule and has no conformer-id filter; it returns one
            # (not_converged, energy) pair per conformer, in conformer order.
            all_ids = [conf.GetId() for conf in mol.GetConformers()]
            results = AllChem.UFFOptimizeMoleculeConfs(mol, maxIters=1000)
            energies = [energy for _flag, energy in results]
            best = min(range(len(all_ids)), key=energies.__getitem__)
            confId = all_ids[best]

            # Drop the losers so that mol.GetConformer() (used by the Gaussian
            # writer and by downstream converters) returns the chosen one.
            for cid in all_ids:
                if cid != confId:
                    mol.RemoveConformer(cid)

        # Force-field relaxation of the chosen conformer
        if AllChem.MMFFHasAllMoleculeParams(mol):
            try:
                AllChem.MMFFOptimizeMolecule(mol, confId=confId, maxIters=2000)
            except Exception:
                # Best effort: fall back to UFF if MMFF fails for any reason.
                AllChem.UFFOptimizeMolecule(mol, confId=confId, maxIters=2000)
        else:
            AllChem.UFFOptimizeMolecule(mol, confId=confId, maxIters=2000)

        # Assign stereochemistry after coords exist
        Chem.AssignStereochemistry(mol, cleanIt=True, force=True)
        return mol

    def rdkit_mol_to_gaussian_gjf(
        self,
        mol: Chem.Mol,
        gjf_path: str,
        jobname: str = "glyceride_opt",
        mem="8GB",
        nproc=8,
        chg=0,
        mult=1,
    ) -> None:
        """
        Write a Gaussian .gjf file from an RDKit molecule.

        Args:
            mol (Chem.Mol): The RDKit molecule with 3D coordinates.
            gjf_path (str): Output path for the Gaussian .gjf file. The
                checkpoint name is this path with its extension replaced by
                ".chk".
            jobname (str): Title for the Gaussian job.
            mem (str): Memory allocation for Gaussian (e.g., "8GB").
            nproc (int): Number of processors for Gaussian.
            chg (int): Total charge of the molecule.
            mult (int): Spin multiplicity of the molecule.

        Raises:
            RuntimeError: If the molecule has no 3D conformer.

        Returns:
            None: Writes the .gjf file to gjf_path.
        """
        if mol.GetNumConformers() == 0:
            raise RuntimeError(
                "No 3D conformer found; create with glyceride_to_rdkit(optimize=True)."
            )

        conf = mol.GetConformer()
        lines = [
            f"%mem={mem}",
            f"%nprocshared={nproc}",
            f"%chk={os.path.splitext(gjf_path)[0]}.chk",
            "#p B3LYP/6-311G(d,p) EmpiricalDispersion=GD3BJ Opt SCF=Tight",
            "",
            jobname,
            "",
            f"{chg} {mult}",
        ]
        for i, atom in enumerate(mol.GetAtoms()):
            pos = conf.GetAtomPosition(i)
            lines.append(
                f"{atom.GetSymbol():<2}  {pos.x: .6f}  {pos.y: .6f}  {pos.z: .6f}"
            )
        lines.append("")
        with open(gjf_path, "w") as f:
            # The extra newline leaves a genuinely blank line after the
            # coordinates, which terminates the molecule specification.
            f.write("\n".join(lines) + "\n")

    def _add_branch_methyl(self, rw: Chem.RWMol, carbon_idx: int) -> None:
        """Attach a methyl (-CH3) to the given carbon atom index."""
        c = rw.AddAtom(Chem.Atom(6))
        rw.AddBond(carbon_idx, c, Chem.BondType.SINGLE)

    def _build_acyl_chain(self, fa: FattyAcid, rw: Chem.RWMol) -> Tuple[int, List[int]]:
        """
        Build an acyl group for the fatty acid into rw:
           O=C(-) — C2 — C3 — ... — Cn

        Branches are attached before double bonds are set, so the first
        neighbour found on each side of a double bond is the main-chain atom.

        Returns:
            (carbonyl_C_idx, [C1=carbonyl, C2, ..., Cn] indices)

        Raises:
            NotImplementedError: for branch labels other than methyl.
            RuntimeError: if a chain bond expected for a double bond is missing.
        """
        # Carbonyl carbon + carbonyl oxygen (double-bond O)
        c1 = rw.AddAtom(Chem.Atom(6))  # carbonyl carbon (C1)
        o_dbl = rw.AddAtom(Chem.Atom(8))
        rw.AddBond(c1, o_dbl, Chem.BondType.DOUBLE)

        # Build the rest of the chain (C2...Cn)
        chain_idx = [c1]
        for _ in range(2, fa.length + 1):
            ci = rw.AddAtom(Chem.Atom(6))
            rw.AddBond(chain_idx[-1], ci, Chem.BondType.SINGLE)
            chain_idx.append(ci)

        # Branches: position pos maps to chain_idx[pos - 1]
        for pos, lbl in fa.branches:
            if lbl.lower() not in ("me", "methyl"):
                raise NotImplementedError(
                    f"Branch label '{lbl}' not implemented yet (only 'Me')."
                )
            if 1 <= pos <= fa.length:
                self._add_branch_methyl(rw, chain_idx[pos - 1])

        # Double bonds along chain
        # Map positions k (C_k -- C_{k + 1}) to indices (chain_idx[k - 1], chain_idx[k])
        for k, st in zip(fa.db_positions, fa.db_stereo):
            if k < 2:
                # C1 is the carbonyl carbon; a C1=C2 bond is skipped so the
                # carbonyl carbon is never given a second double bond.
                continue
            a = chain_idx[k - 1]
            b = chain_idx[k]
            bond = rw.GetBondBetweenAtoms(a, b)
            if bond is None:
                raise RuntimeError("Internal: expected a bond to set C=C.")
            bond.SetBondType(Chem.BondType.DOUBLE)

            # Assign E/Z stereo if possible: pick one reference neighbour on
            # each side that is not the other double-bond atom.
            a_neighbors = [
                nbr.GetIdx()
                for nbr in rw.GetAtomWithIdx(a).GetNeighbors()
                if nbr.GetIdx() != b
            ]
            b_neighbors = [
                nbr.GetIdx()
                for nbr in rw.GetAtomWithIdx(b).GetNeighbors()
                if nbr.GetIdx() != a
            ]
            if a_neighbors and b_neighbors:
                bond.SetStereoAtoms(a_neighbors[0], b_neighbors[0])
                norm = st.strip().lower()
                if norm in ("z", "cis"):
                    bond.SetStereo(Chem.BondStereo.STEREOZ)
                elif norm in ("e", "trans"):
                    bond.SetStereo(Chem.BondStereo.STEREOE)
                # Any other label leaves the bond stereo unspecified.

        return c1, chain_idx

    def _build_glycerol_backbone(
        self,
    ) -> Tuple[Chem.RWMol, Tuple[int, int, int], List[int]]:
        """
        Build glycerol (as triol) and return:
            rw_mol, (o_sn1, o_sn2, o_sn3), carbon_indices

        The three oxygens are returned in sn-1, sn-2, sn-3 order so that they
        can be esterified later. Atoms are created in the order
        C(sn-1), C(sn-2), C(sn-3), O(sn-3), O(sn-1), O(sn-2).
        """
        rw = Chem.RWMol()

        # Carbons
        c_sn1 = rw.AddAtom(Chem.Atom(6))  # CH2 (sn-1 carbon)
        c_sn2 = rw.AddAtom(Chem.Atom(6))  # CH  (sn-2 carbon)
        c_sn3 = rw.AddAtom(Chem.Atom(6))  # CH2 (sn-3 carbon)

        # Connect backbone
        rw.AddBond(c_sn1, c_sn2, Chem.BondType.SINGLE)
        rw.AddBond(c_sn2, c_sn3, Chem.BondType.SINGLE)

        # Add hydroxyl oxygens (creation order kept: sn-3, sn-1, sn-2)
        o_sn3 = rw.AddAtom(Chem.Atom(8))
        rw.AddBond(c_sn3, o_sn3, Chem.BondType.SINGLE)
        o_sn1 = rw.AddAtom(Chem.Atom(8))
        rw.AddBond(c_sn1, o_sn1, Chem.BondType.SINGLE)
        o_sn2 = rw.AddAtom(Chem.Atom(8))
        rw.AddBond(c_sn2, o_sn2, Chem.BondType.SINGLE)

        return rw, (o_sn1, o_sn2, o_sn3), [c_sn1, c_sn2, c_sn3]

    @property
    def molar_mass(self) -> float:
        """Calculate the molar mass of a glyceride in g/mol.

        The mass depends only on the atoms, so the molecule is built without
        3D embedding or force-field optimization.
        """
        mol = self.glyceride_to_rdkit(optimize=False)
        mass = 0
        for atom in mol.GetAtoms():
            mass += atom.GetMass()
        return mass

    @property
    def num_fatty_acids(self) -> int:
        """Number of occupied sn positions (chains that are not None)."""
        return sum(1 for fa in self.sn if fa is not None)

    @property
    def chain_lengths(self) -> Tuple[int, int, int]:
        """Tuple of chain lengths (0 if empty) in sn-1, sn-2, sn-3 order."""

        def L(fa: Optional[FattyAcid]) -> int:
            return fa.length if fa is not None else 0

        return (L(self.sn[0]), L(self.sn[1]), L(self.sn[2]))

    @property
    def name(self) -> str:
        """Generate a standardized name for the glyceride: G_{FA1}_{FA2}_{FA3}.

        Each chain is written with FattyAcid.name; an empty position is
        written as EMPTY.
        """
        return "G_" + "_".join(
            "EMPTY" if fa is None else fa.name for fa in self.sn
        )

    def __eq__(self, other):
        """Equality based on the signature tuple."""
        if not isinstance(other, Glyceride):
            return NotImplemented
        return self.signature_tuple() == other.signature_tuple()

    def __hash__(self):
        """Hash based on the signature tuple."""
        return hash((self.signature_tuple(),))

    def __str__(self):
        lines = []
        for idx, fa in enumerate(self.sn):
            fa_str = "EMPTY" if fa is None else fa.name
            lines.append(f"sn-{idx + 1}: {fa_str}")
        return "\n".join(lines)

    def __gt__(self, other: Glyceride):
        """Greater than, comparing the chain_lengths tuples."""
        if not isinstance(other, Glyceride):
            # Let Python try the reflected operation and raise TypeError
            # itself if the operands are not comparable.
            return NotImplemented
        return self.chain_lengths > other.chain_lengths

    def __lt__(self, other: Glyceride):
        """Less than, comparing the chain_lengths tuples."""
        if not isinstance(other, Glyceride):
            return NotImplemented
        return self.chain_lengths < other.chain_lengths
    
class SymmetricGlyceride(Glyceride):
    """Glyceride where sn-1 and sn-3 are considered equivalent for equality/hash."""

    def _ordered_sn(self) -> tuple:
        """Return (outer_low, sn-2, outer_high): the two outer chains sorted
        with _fa_key, the middle chain left where it is."""
        first, middle, last = self.sn
        outer_low, outer_high = sorted((first, last), key=_fa_key)
        return (outer_low, middle, outer_high)

    def signature_tuple(self) -> tuple:
        """Signature with sn-1/sn-3 ordered canonically, so both orientations
        of the same molecule produce the same tuple."""
        entries = []
        for chain in self._ordered_sn():
            if chain is not None:
                canon = chain.canonical()
                entries.append(
                    ("FA", canon.length, canon.db_positions, canon.db_stereo, canon.branches)
                )
            else:
                entries.append(("EMPTY",))
        return tuple(entries)

    @property
    def name(self) -> str:
        """
        Generate a standardized name for the glyceride with sn-1 and sn-3
        ordered canonically so symmetric species share the same string.

        Each chain is written with FattyAcid.name; an empty position is
        written as EMPTY.
        """
        labels = [
            "EMPTY" if chain is None else chain.name
            for chain in self._ordered_sn()
        ]
        return "G_" + "_".join(labels)




In [51]:
import numpy as np
from typing import Dict, Mapping, Callable, List, Tuple
import MDAnalysis as mda
import mdapackmol
import shutil

# Avogadro constant (1/mol); a molar mass in g/mol divided by it gives grams per molecule.
AVOGADRO = 6.02214076e23  
# Volume conversion: 1 Å^3 = 1e-24 cm^3, so a volume in cm^3 divided by this is in Å^3.
CM3_PER_A3 = 1e-24 

class GlycerideMix:
    """
    Represents the composition of glycerides in a mixture.

    Attributes:
        mix (Dict[Glyceride, float]): A dictionary mapping Glyceride objects to their quantities.
        units (str): The units for the quantities (default is "mole").
    """

    def __init__(self, mix: List[Tuple[Glyceride, float]], units: str = "mole"):
        """
        Args:
            mix: (glyceride, quantity) pairs. A glyceride listed more than
                once has its quantities summed, consistent with add().
            units: label for the quantities; stored as given.
        """
        self.mix = {}
        self.units = units
        for entry in mix:
            self.add(entry[0], entry[1])

    def add(self, glyceride: Glyceride, quantity: float):
        """
        Add a glyceride and its quantity to the composition.

        If the glyceride is already present, the quantity is added to the
        existing amount.

        Args:
            glyceride (Glyceride): The glyceride to add.
            quantity (float): The quantity of the glyceride.
        """
        if glyceride in self.mix:
            self.mix[glyceride] += quantity
        else:
            self.mix[glyceride] = quantity

    def total_quantity(self) -> float:
        """
        Calculate the total quantity of all glycerides in the composition.

        Returns:
            float: The total quantity.
        """
        return sum(self.mix.values())

    def build_simulation_box(
        self,
        num_molecules: int,
        density_g_per_cm3: float,
        min_dist: float = 2.0,
        seed: int | None = None,
    ) -> mda.Universe:
        """
        Pack a homogeneous mixture of glycerides into a cubic box at a target density.

        Each species with a non-zero count is built and optimized once; the
        same RDKit molecule supplies its mass, its atom count and the
        single-molecule MDAnalysis Universe handed to Packmol.

        Args:
            num_molecules: total molecules in the box (all species combined).
            density_g_per_cm3: target bulk density in g/cm^3.
            min_dist: Packmol 'tolerance' (angstroms), the minimum allowed
                interatomic distance.
            seed: accepted for interface compatibility.
                NOTE(review): the value is currently NOT passed to Packmol;
                only `tolerance` is given to mdapackmol.packmol here. Confirm
                whether the installed mdapackmol exposes a seed option before
                relying on this argument for reproducibility.

        Returns:
            MDAnalysis.Universe for the packed system, with unit cell set.

        Raises:
            ValueError: if the mix is empty, the total quantity / molecule
                count / density is not positive, the total mass is not
                positive, a species has no residues, Packmol fails with an
                error other than code 173, or the packed system does not
                contain the expected number of atoms.
        """
        if not self.mix:
            raise ValueError("GlycerideMix is empty.")

        total_qty = self.total_quantity()
        if total_qty <= 0:
            raise ValueError("Total quantity must be positive.")
        if num_molecules <= 0:
            raise ValueError("num_molecules must be positive.")
        if density_g_per_cm3 <= 0:
            raise ValueError("density_g_per_cm3 must be positive.")

        #  Determine integer counts by molar fractions
        mol_fractions = {g: qty / total_qty for g, qty in self.mix.items()}
        counts = self._integer_counts_from_fractions(mol_fractions, num_molecules)

        #  Build every species that is actually present exactly once and
        #  collect total mass and expected atom count from the same molecule.
        species = []  # (glyceride, count, rdkit molecule)
        mass_g = 0.0
        expected_num_atoms = 0
        for g, n in counts.items():
            if n == 0:
                continue
            rd_mol = g.glyceride_to_rdkit(optimize=True)
            molar_mass = sum(atom.GetMass() for atom in rd_mol.GetAtoms())
            mass_g += n * (molar_mass / AVOGADRO)
            expected_num_atoms += n * rd_mol.GetNumAtoms()
            species.append((g, n, rd_mol))

        if mass_g <= 0:
            raise ValueError("Computed total mass is non-positive; check molar masses.")

        #  Compute box length from target density
        volume_cm3 = mass_g / float(density_g_per_cm3)
        volume_A3 = volume_cm3 / CM3_PER_A3
        L = float(volume_A3 ** (1.0 / 3.0))

        #  Build Packmol instructions for a cubic box [0, L]^3
        instructions = [f"inside box 0. 0. 0. {L:.6f} {L:.6f} {L:.6f}"]

        # Prepare species blocks
        species_blocks = []
        for g, n, rd_mol in species:
            # Turn an rdkit object into a mda object
            u = mda.Universe(rd_mol)

            nres = len(u.residues)
            if not nres:
                raise ValueError("No residues present—cannot label. (Did the RDKit import create residues?)")

            try:
                _ = u.residues.resnames
            except AttributeError:
                u.add_TopologyAttr("resnames", [""] * nres)

            u.residues.resnames = [g.name] * nres
            species_blocks.append(
                mdapackmol.PackmolStructure(
                    u, number=int(n), instructions=instructions
                )
            )

        # Run Packmol. Returns an MDAnalysis.Universe with merged topology.
        try:
            system = mdapackmol.packmol(species_blocks, tolerance=float(min_dist))
        except ValueError as e:
            msg = str(e)
            if "STOP 173" in msg or "errorcode 173" in msg:
                # Error 173: fall back to the forced output file, renamed to
                # Packmol's default output name, and load it from disk.
                if os.path.exists("output.pdb_FORCED"):
                    shutil.move("output.pdb_FORCED", "output.pdb")
                system = mda.Universe("output.pdb")
            else:
                raise
        else:
            # Success path (exit code 0): rename output if it exists
            if os.path.exists("output.pdb_FORCED"):
                shutil.move("output.pdb_FORCED", "output.pdb")

        system.dimensions = np.array([L, L, L, 90.0, 90.0, 90.0], dtype=float)

        # Ensure number of atoms in the box was met
        if len(system.atoms) != expected_num_atoms:
            raise ValueError("Expected number of atoms in the box was not met")

        # Relabel the residues of EVERY segment with a three-letter tag
        # derived from the segment id. The attribute check happens once, up
        # front, so a system without segments no longer fails on an
        # undefined loop variable.
        try:
            _ = system.residues.resnames
        except AttributeError:
            system.add_TopologyAttr("resnames", [""] * len(system.residues))
        for seg in system.segments:
            tag = (seg.segid or "RES")[:3].upper()
            seg.residues.resnames = [tag] * len(seg.residues)

        return system

    @staticmethod
    def _integer_counts_from_fractions(fracs: Mapping[Glyceride, float], N: int) -> Dict[Glyceride, int]:
        """Round fractional allocations to integers while preserving the total N.

        Every key first receives floor(fraction * N); the molecules still
        missing are handed out one each to the keys with the largest
        fractional remainders. Equal remainders are ordered by comparing the
        keys themselves.
        """
        raw = {g: fracs[g] * N for g in fracs}
        floors = {g: int(np.floor(raw[g])) for g in fracs}
        deficit = N - sum(floors.values())
        # Distribute remaining molecules to the largest fractional remainders
        remainders = sorted(((raw[g] - floors[g], g) for g in fracs), reverse=True)
        for i in range(deficit):
            _, g = remainders[i]
            floors[g] += 1
        return floors

    @property
    def name(self) -> str:
        """
        Generate a standardized name for the glyceride mix.

        Fatty acids are named with the format:

        N{CC}D{DD}[P{pp}{S}...][M{pp}...][OH{pp}...]

        N{CC} - Number of carbons like 06 or 12
        D{DD} - Number of double bonds
        P{pp}{S} - double bond position and stereo, where Z is cis and E is trans (e.g. P09Z)
        M{pp}... - Methyl branches at position pp
        OH{pp}... - Hydroxyl branches at position pp

        Example: N18D01P09Z means 18 carbons, 1 double bond at position 9 with Z (cis) stereo and is oleic acid.

        A glyceride is written as:

        G_fa1_fa2_fa3

        where fa1, fa2, and fa3 are the fatty acids attached to the glycerol backbone.

        Finally, a mixture of glycerides is written as:

        MIX_G_fa1_fa2_fa3-qty_G_fa1_fa2_fa3-qty...

        where qty is the quantity of each glyceride, formatted to three significant figures.
        """
        parts = [
            f"{glyceride.name}-{qty:.3g}"
            for glyceride, qty in self.mix.items()
        ]
        return "MIX_" + "_".join(parts)

    def __repr__(self):
        parts = [
            f"{glyceride.name}: {qty}" for glyceride, qty in self.mix.items()
        ]
        return "Glyceride_Composition({" + ", ".join(parts) + "})"



In [58]:
import py3Dmol

# Two example chains: an 18-carbon chain with one Z double bond at position 9,
# and a 6-carbon chain with no double bonds.
olein = FattyAcid(length=18, db_positions=(9,), db_stereo=('Z',))
fa2 = FattyAcid(length=6)

# One glyceride per chain type, with the same chain on all three sn positions.
g1 = Glyceride(sn=(olein, olein, olein))
g2 = Glyceride(sn=(fa2, fa2, fa2))

# Mixture with equal quantities (0.5 each) of the two glycerides.
mix = GlycerideMix(mix=[(g1, 0.5), (g2, 0.5)])


print(mix.name)
print(f"Mass of Triolein: {g1.molar_mass}")
# Pack 24 molecules at a target density of 0.7 g/cm^3 and save the result.
box = mix.build_simulation_box(num_molecules=24, density_g_per_cm3=0.7)
box.atoms.write("temp.pdb")

# Read coordinates for py3Dmol
with open("temp.pdb") as f:
    pdb = f.read()

# Create the 3D viewer
view = py3Dmol.view(width=1100, height=1100)
view.addModel(pdb, "pdb")
view.setStyle({'stick': {}})
view.zoomTo()
view.show()

MIX_G_N18D01P09Z_N18D01P09Z_N18D01P09Z-0.5_G_N06D00_N06D00_N06D00-0.5
Mass of Triolein: 885.4530000000037




In [61]:
# import MDAnalysis as md

# Read coordinates for py3Dmol.
# NOTE(review): "output.pdb" is not written by this cell; presumably it is the
# file left in the working directory by the Packmol run in
# build_simulation_box, so that cell has to be executed first. Confirm.
with open("output.pdb") as f:
    pdb = f.read()

# Create the 3D viewer
# Color by residue name

view = py3Dmol.view(width=1100, height=1100)
view.addModel(pdb, "pdb")           # add once
view.setStyle({'stick': {}})        # base style

# Color by residue *name*
# NOTE(review): assumes the residues in output.pdb are named R0 and R1 (one
# name per packed species) — verify against the file contents.
view.setStyle({'resn': 'R0'}, {'stick': {'color': 'orange'}})
view.setStyle({'resn': 'R1'}, {'stick': {'color': 'skyblue'}})

view.zoomTo()
view.show()

# Load the same file with MDAnalysis and print the total number of atoms.
u = mda.Universe("output.pdb")
print(len(u.atoms))

# mol1 = g1.glyceride_to_rdkit()
# mol2 = g2.glyceride_to_rdkit()
# u = mda.Universe(mol1)
# print(u.residues)
# print(f"Printing number of atoms in glyceride1: {mol1.GetNumAtoms()}")
# print(f"Printing number of atoms in glyceride2: {mol2.GetNumAtoms()}")
# print(f"Number of expected atoms = {((mol1.GetNumAtoms() * 250) + (mol2.GetNumAtoms() * 250))}")
# print(f"Calculated dimensions: {u.dimensions}")


2784
