# In [None]:
"""
decomposition_energy.py

A comprehensive module for computing decomposition energies of crystal structures
against the GNoME database and known materials. Provides utilities for assessing
thermodynamic stability and analyzing decomposition pathways.



Author: Michael R. Lafave 
Date: November 2024
"""

import ast
import itertools
import os
import urllib.request
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pymatgen as mg
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.core import Composition
from pymatgen.entries.computed_entries import ComputedEntry

class DecompositionEnergyAnalyzer:
    """Analyzer for computing decomposition energies against GNoME database.

    On construction the GNoME and reference CSVs are loaded (downloaded first
    if missing) and concatenated. The rows are then grouped by chemical system
    so that a phase diagram can be assembled for any set of elements.

    Attributes:
        gnome_data: DataFrame loaded from the GNoME CSV.
        reference_data: DataFrame loaded from the reference CSV.
        all_data: Row-wise concatenation of both, re-indexed 0..n-1.
        minimal_entries: Subset of ``all_data`` columns used for phase diagrams.
        grouped_entries: ``minimal_entries`` grouped by "Chemical System".
    """

    # Annotations naming pymatgen types are written as strings so they are not
    # evaluated while the class body executes.

    def __init__(
        self,
        gnome_data_path: str = "stable_materials_summary.csv",
        reference_data_path: str = "external_materials_summary.csv"
    ):
        """Initialize the analyzer with GNoME and reference data.

        Args:
            gnome_data_path: Path to GNoME structures CSV
            reference_data_path: Path to reference structures CSV
        """
        self.gnome_data = self._load_data(gnome_data_path)
        self.reference_data = self._load_data(reference_data_path)
        self.all_data = pd.concat(
            [self.gnome_data, self.reference_data],
            ignore_index=True
        )
        self._process_data()

    def _load_data(self, data_path: str) -> pd.DataFrame:
        """Load crystal structure data from CSV, downloading it if absent.

        Args:
            data_path: Path to CSV file

        Returns:
            DataFrame with an added "Chemical System" column
        """
        if not os.path.exists(data_path):
            self._download_data(data_path)

        df = pd.read_csv(data_path, index_col=0)
        return self._annotate_chemical_system(df)

    def _download_data(self, filename: str) -> None:
        """Download a data file from the public GCP bucket to ``filename``.

        The bucket folder is chosen from the file's base name: names
        containing "external" come from ``external_data``, everything else
        from ``gnome_data``.

        Args:
            filename: Local path to write; its base name selects the remote file

        Raises:
            urllib.error.URLError: If the download fails (the original
                ``wget`` call silently ignored failures).
        """
        base_url = "https://storage.googleapis.com/gdm_materials_discovery"
        basename = os.path.basename(filename)
        folder = "external_data" if "external" in basename else "gnome_data"
        url = f"{base_url}/{folder}/{basename}"

        # Save exactly where _load_data will look, not the current directory.
        target_dir = os.path.dirname(filename)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        urllib.request.urlretrieve(url, filename)

    def _annotate_chemical_system(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add a sorted-element-tuple "Chemical System" column, in place.

        Args:
            df: DataFrame whose "Elements" column holds list literals as text,
                e.g. "['Fe', 'O']"

        Returns:
            The same DataFrame, with the added column
        """
        def process_elements(elements_str: str) -> tuple:
            # literal_eval only accepts Python literals, so malformed or
            # malicious cell text raises ValueError instead of running code.
            elements = ast.literal_eval(elements_str)
            return tuple(sorted(elements))

        df["Chemical System"] = df["Elements"].apply(process_elements)
        return df

    def _process_data(self) -> None:
        """Process loaded data for decomposition calculations."""
        required_cols = [
            "Composition",
            "NSites",
            "Corrected Energy",
            "Formation Energy Per Atom",
            "Chemical System"
        ]
        self.minimal_entries = self.all_data[required_cols]
        self.grouped_entries = self.minimal_entries.groupby("Chemical System")

    def _build_phase_diagram(
        self,
        composition: "Union[str, Composition]"
    ) -> "Tuple[Composition, List[ComputedEntry], PhaseDiagram]":
        """Build the phase diagram spanning the elements of ``composition``.

        Args:
            composition: Formula string or Composition object

        Returns:
            Tuple of (Composition object, entries used, PhaseDiagram)
        """
        if isinstance(composition, str):
            composition = Composition(composition)
        chemsys = composition.chemical_system.split("-")
        entries = self._collect_phase_diagram_entries(chemsys)
        return composition, entries, PhaseDiagram(entries)

    def compute_decomposition_energy(
        self,
        composition: "Union[str, Composition]",
        energy: float
    ) -> Tuple[float, Dict]:
        """Compute decomposition energy for a structure.

        Args:
            composition: Chemical composition as string or Composition object
            energy: Total energy in eV

        Returns:
            Tuple containing:
            - Decomposition energy in eV/atom (negative when the structure
              lies below the hull of the database phases)
            - Dictionary mapping each product's reduced formula to the amount
              reported by pymatgen's decomposition
        """
        composition, _, phase_diagram = self._build_phase_diagram(composition)
        entry = ComputedEntry(composition, energy)

        # allow_negative keeps the signed distance for structures below the hull.
        decomp, e_above_hull = phase_diagram.get_decomp_and_e_above_hull(
            entry,
            allow_negative=True
        )

        decomp_products = {
            str(e.composition.reduced_formula): amt
            for e, amt in decomp.items()
        }
        return e_above_hull, decomp_products

    def _collect_phase_diagram_entries(
        self,
        chemsys: List[str]
    ) -> "List[ComputedEntry]":
        """Collect all relevant entries for phase diagram construction.

        Every non-empty subset of ``chemsys`` is looked up in the grouped data,
        so binaries, ternaries, etc. inside the system are all included.

        Args:
            chemsys: List of elements in chemical system

        Returns:
            List of ComputedEntry objects
        """
        entries = []

        for size in range(1, len(chemsys) + 1):
            for subsystem in itertools.combinations(chemsys, size):
                key = tuple(sorted(subsystem))
                # ``groups`` maps each key to index *labels*, so select with .loc.
                labels = self.grouped_entries.groups.get(key)
                if labels is None or len(labels) == 0:
                    continue
                rows = self.minimal_entries.loc[labels]
                entries.extend(
                    ComputedEntry(composition=comp, energy=total_energy)
                    for comp, total_energy in zip(
                        rows["Composition"], rows["Corrected Energy"]
                    )
                )

        # Elemental reference states.
        # NOTE(review): these use 0.0 eV while database rows use
        # "Corrected Energy"; confirm the two energy scales are compatible.
        for element in chemsys:
            entries.append(ComputedEntry(element, 0.0))

        return entries

    def analyze_competing_phases(
        self,
        composition: "Union[str, Composition]",
        energy: float,
        energy_window: float = 0.1
    ) -> pd.DataFrame:
        """List database phases within ``energy_window`` of the convex hull.

        Args:
            composition: Chemical composition whose elements define the system
            energy: Total energy in eV. Currently unused: the candidate
                structure itself is not added to the phase diagram.
            energy_window: Energy window above hull in eV/atom

        Returns:
            DataFrame with columns formula, energy_above_hull and
            formation_energy, sorted by energy_above_hull. It is empty, but
            keeps those columns, when no phase qualifies.
        """
        _, entries, phase_diagram = self._build_phase_diagram(composition)

        competing = []
        for e in entries:
            e_hull = phase_diagram.get_e_above_hull(e)
            if e_hull <= energy_window:
                competing.append({
                    'formula': e.composition.reduced_formula,
                    'energy_above_hull': e_hull,
                    'formation_energy': phase_diagram.get_form_energy_per_atom(e)
                })

        # Explicit columns keep sort_values from raising KeyError on no rows.
        columns = ['formula', 'energy_above_hull', 'formation_energy']
        return pd.DataFrame(competing, columns=columns).sort_values(
            'energy_above_hull'
        )

    def get_stable_compositions(
        self,
        threshold: float = 0.001
    ) -> pd.DataFrame:
        """Get all compositions whose decomposition energy is below threshold.

        Negative decomposition energies mean the material lies below the hull
        of its competitors, so they always count as stable. The previous
        ``abs()`` test wrongly excluded them. Rows with a missing (NaN)
        decomposition energy are excluded.

        Args:
            threshold: Energy threshold in eV/atom

        Returns:
            DataFrame with columns composition, space_group,
            decomposition_energy and formation_energy, indexed 0..n-1
        """
        data = self.all_data
        stable = data.loc[data['Decomposition Energy Per Atom'] < threshold]
        return pd.DataFrame({
            'composition': stable['Composition'].to_numpy(),
            'space_group': stable['Space Group'].to_numpy(),
            'decomposition_energy':
                stable['Decomposition Energy Per Atom'].to_numpy(),
            'formation_energy': stable['Formation Energy Per Atom'].to_numpy(),
        })

if __name__ == "__main__":
    # Demonstration run: build the analyzer (which fetches the CSVs if they are
    # missing), then exercise each public analysis on one placeholder input.
    analyzer = DecompositionEnergyAnalyzer()

    formula = "Fe2O3"
    total_energy = -100.0  # placeholder total energy in eV, not a computed value

    # Signed distance to the hull plus the products it would decompose into.
    energy_per_atom, product_amounts = analyzer.compute_decomposition_energy(
        formula, total_energy
    )
    print(f"Decomposition Analysis for {formula}:")
    print(f"Decomposition Energy: {energy_per_atom:.3f} eV/atom")
    print("Decomposition Products:")
    for name, share in product_amounts.items():
        print(f"  {name}: {share:.3f}")

    # Phases of the same chemical system lying within 0.1 eV/atom of the hull.
    phases_in_window = analyzer.analyze_competing_phases(
        formula, total_energy, energy_window=0.1
    )
    print("\nCompeting Phases:")
    print(phases_in_window)

    # Database-wide count of compositions passing the default threshold.
    stable_table = analyzer.get_stable_compositions()
    print(f"\nFound {len(stable_table)} stable compositions")