# In [None]:
# Import required libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import json
import itertools
import os
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
from scipy.spatial import ConvexHull

import pymatgen as mg
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.analysis import phase_diagram
from pymatgen.core import Composition, Element
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDEntry

# Set Keras' global dtype policy to 'mixed_float16' at import time.
# NOTE(review): nothing visible in this file builds a Keras layer or model, so
# this process-wide setting appears unused here -- confirm before relying on it.
tf.keras.mixed_precision.set_global_policy(policy='mixed_float16')

@dataclass()
class DecompositionResult:
    """Container for the outputs of one convex-hull decomposition query.

    Field order is significant: it is also the positional-argument order of
    the generated ``__init__``.
    """

    # Energy of the queried entry relative to the convex hull; may be negative.
    energy: float
    # Decomposition products mapped to their fractions.
    decomposition: Dict
    # Energy above the convex hull.
    hull_distance: float
    # Whether the entry is treated as lying on (or within tolerance of) the hull.
    is_stable: bool
    # Reduced formulas of phases considered as competitors.
    competing_phases: List[str]

class GNoMEDataProcessor:
    """Fetch the GNoME summary CSVs and derive a chemical system for each row.

    Despite the "TensorFlow acceleration" wording used elsewhere, none of the
    work here benefits from TensorFlow: it is per-row string parsing and file
    download, done with pandas and the standard library.
    """

    def __init__(self, cache_dir: str = './cache'):
        """Create the processor and make sure ``cache_dir`` exists.

        Args:
            cache_dir: Directory created (if missing) on construction.
        """
        self.cache_dir = cache_dir
        self._setup_cache_directory()
        # NOTE(review): not read by any method of this class; kept so external
        # code that touches the attribute keeps working.
        self.tf_energy_cache = {}

    def _setup_cache_directory(self):
        """Create the cache directory if it doesn't exist."""
        os.makedirs(self.cache_dir, exist_ok=True)

    @staticmethod
    def download_dataset(bucket_name: str = "gdm_materials_discovery") -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Download (if needed) and load the GNoME and external summary CSVs.

        Files are fetched from ``https://storage.googleapis.com/<bucket_name>/``
        into the current working directory. A file that already exists locally
        is reused instead of being downloaded again.

        Args:
            bucket_name: Public GCS bucket that holds the dataset.

        Returns:
            ``(gnome_crystals, reference_crystals)``: the stable-materials
            summary (first CSV column used as the index) and the
            external-materials summary.

        Raises:
            RuntimeError: if a file cannot be downloaded after all retries.
        """
        import urllib.request

        PUBLIC_LINK = "https://storage.googleapis.com/"
        FOLDER_NAME = "gnome_data"
        FILES = ["stable_materials_summary.csv"]
        EXTERNAL_FOLDER_NAME = "external_data"
        EXTERNAL_FILES = ["external_materials_summary.csv"]

        def download_with_retry(link: str, output_dir: str, max_retries: int = 3) -> bool:
            """Fetch ``link`` into ``output_dir``; return True on success.

            urllib is used instead of shelling out to wget for three reasons.
            The URL is never interpolated into a shell command. A failed
            download raises, instead of returning an exit status that callers
            ignored. Re-runs don't create ``*.csv.1`` duplicates while the
            stale file keeps being read.
            """
            destination = os.path.join(output_dir, link.rsplit('/', 1)[-1])
            if os.path.exists(destination):
                return True
            last_error = None
            for _ in range(max_retries):
                try:
                    urllib.request.urlretrieve(link, destination)
                    return True
                except OSError as exc:  # URLError / HTTPError subclass OSError
                    last_error = exc
                    # Don't leave a truncated file that a later run would reuse.
                    if os.path.exists(destination):
                        os.remove(destination)
            raise RuntimeError(
                f"Failed to download {link} after {max_retries} attempts: {last_error}"
            ) from last_error

        # Build URLs with "/" explicitly; os.path.join would use "\\" on Windows.
        parent_url = "/".join((PUBLIC_LINK.rstrip("/"), bucket_name))
        for folder, filenames in ((FOLDER_NAME, FILES), (EXTERNAL_FOLDER_NAME, EXTERNAL_FILES)):
            for filename in filenames:
                download_with_retry(f"{parent_url}/{folder}/{filename}", '.')

        # Load datasets from the working directory.
        gnome_crystals = pd.read_csv('stable_materials_summary.csv', index_col=0)
        reference_crystals = pd.read_csv('external_materials_summary.csv')

        return gnome_crystals, reference_crystals

    @staticmethod
    def annotate_chemical_system(crystals: pd.DataFrame) -> pd.DataFrame:
        """Add a ``'Chemical System'`` column of sorted element-symbol tuples.

        ``crystals['Elements']`` must hold list-like strings such as
        ``"['Fe', 'O']"``. Single quotes are converted to double quotes so the
        text parses as JSON. The column is added to ``crystals`` in place, and
        the same frame is returned.
        """
        def process_elements(elements_str: str) -> tuple:
            elements = json.loads(elements_str.replace("'", '"'))
            return tuple(sorted(elements))

        # Series.map keeps each tuple as a single object per row. Routing this
        # through tf.numpy_function(..., tf.string) either failed on ragged
        # tuples or produced byte strings. Either way, the tuple-keyed groupby
        # lookups downstream could never match.
        crystals['Chemical System'] = crystals['Elements'].map(process_elements)
        return crystals

class DecompositionCalculator:
    """Compute convex-hull decomposition metrics against GNoME + reference data.

    On construction, both frames get a ``'Chemical System'`` column and are
    concatenated. The needed columns are then grouped by chemical system, so
    the rows of any subsystem can be looked up by its sorted element tuple.
    """

    def __init__(self, gnome_crystals: pd.DataFrame, reference_crystals: pd.DataFrame):
        self.gnome_crystals = gnome_crystals
        self.reference_crystals = reference_crystals
        self.all_crystals = None
        self.minimal_entries = None
        self.grouped_entries = None
        self._prepare_data()

    def _prepare_data(self):
        """Annotate, merge, and group the crystal data."""
        # annotate_chemical_system is a staticmethod. Calling it on the class
        # avoids building a GNoMEDataProcessor, whose __init__ creates a
        # ./cache directory as a side effect.
        self.gnome_crystals = GNoMEDataProcessor.annotate_chemical_system(self.gnome_crystals)
        self.reference_crystals = GNoMEDataProcessor.annotate_chemical_system(self.reference_crystals)
        # ignore_index=True gives one fresh RangeIndex across both sources.
        self.all_crystals = pd.concat([self.gnome_crystals, self.reference_crystals], ignore_index=True)

        required_columns = [
            'Composition', 'NSites', 'Corrected Energy',
            'Formation Energy Per Atom', 'Chemical System'
        ]
        self.minimal_entries = self.all_crystals[required_columns]
        self.grouped_entries = self.minimal_entries.groupby('Chemical System')

    def gather_convex_hull(self, chemsys: List[str]) -> List[ComputedEntry]:
        """Collect the known entries of ``chemsys`` and of all its subsystems.

        Args:
            chemsys: Element symbols, in any order.

        Returns:
            One ``ComputedEntry`` per matching row, built from 'Composition' and
            'Corrected Energy'. These are followed by one 0.0-energy
            ``ComputedEntry`` per element of ``chemsys``.
        """
        groups = self.grouped_entries.groups
        # Plain itertools. Converting the combinations through tf.constant
        # turned the symbols into bytes, so no key ever matched the str-tuple
        # group keys.
        matched_labels = []
        for length in range(len(chemsys) + 1):
            for subsystem in itertools.combinations(chemsys, length):
                labels = groups.get(tuple(sorted(subsystem)))
                if labels is not None and len(labels):
                    matched_labels.extend(labels)

        mg_entries = []
        if matched_labels:
            # .groups yields index labels, so select with .loc rather than .iloc.
            rows = self.minimal_entries.loc[matched_labels]
            mg_entries = [
                ComputedEntry(composition, float(energy))
                for composition, energy in zip(rows['Composition'], rows['Corrected Energy'])
            ]

        # NOTE(review): these elemental references use 0.0 eV, while the rows
        # above carry 'Corrected Energy' values. Confirm that mixing the two
        # energy scales is intended.
        mg_entries.extend(ComputedEntry(element, 0.0) for element in chemsys)
        return mg_entries

    def calculate_decomposition(self, composition: str, energy: float) -> DecompositionResult:
        """Place one entry on the phase diagram of its chemical system.

        Args:
            composition: Formula to evaluate. If falsy, a random row of
                ``gnome_crystals`` is used instead, and ``energy`` is ignored.
            energy: Passed unchanged as the ``ComputedEntry`` energy for
                ``composition``.

        Returns:
            A ``DecompositionResult`` with these fields:
            ``energy``: distance to the hull, allowed to be negative.
            ``decomposition``: the decomposition products.
            ``hull_distance``: that distance clamped at 0.
            ``is_stable``: True when ``hull_distance < 1e-3``.
            ``competing_phases``: sorted reduced formulas of the diagram's
            stable entries.
        """
        if not composition:
            sample = self.gnome_crystals.sample()
            sample_entry = ComputedEntry(
                composition=sample['Composition'].item(),
                energy=sample['Corrected Energy'].item(),
            )
            chemsys = sample['Chemical System'].item()
        else:
            # pymatgen no longer exposes Composition at the package top level.
            comp = Composition(composition)
            sample_entry = ComputedEntry(
                composition=comp,
                energy=energy,
            )
            chemsys = [str(el) for el in comp.elements]

        # Calculate phase diagram and decomposition
        mg_entries = self.gather_convex_hull(chemsys)
        diagram = PhaseDiagram(mg_entries)
        decomposition, decomp_energy = diagram.get_decomp_and_e_above_hull(
            sample_entry, allow_negative=True
        )

        # get_e_above_hull() would redo this with allow_negative=False and
        # raise ValueError for entries below the hull (the newly-stable case),
        # so clamp the value already computed instead.
        hull_distance = max(decomp_energy, 0.0)
        is_stable = hull_distance < 1e-3
        # PhaseDiagram has no get_all_equilibrium_entries(); report the
        # diagram's stable phases.
        competing_phases = sorted(
            {entry.composition.reduced_formula for entry in diagram.stable_entries}
        )

        return DecompositionResult(
            energy=decomp_energy,
            decomposition=decomposition,
            hull_distance=hull_distance,
            is_stable=is_stable,
            competing_phases=competing_phases
        )

# Utility functions for visualization and analysis
def visualize_hull_analysis(result: DecompositionResult) -> None:
    """Plot the hull distance and decomposition products of ``result``.

    Left panel: a bar for ``result.hull_distance``, with a dashed line at the
    1e-3 eV/atom stability threshold. Right panel: a pie chart of
    ``result.decomposition`` fractions, labelled by reduced formula, or a
    placeholder message when there are no products. The figure is shown with
    ``plt.show()``.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # The bare 'seaborn' style name was deprecated in matplotlib 3.6 (renamed
    # 'seaborn-v0_8') and later removed. Use whichever one is available.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    sns.set_palette("husl")

    # Create decomposition visualization
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot hull distance
    ax1.bar(['Hull Distance'], [result.hull_distance])
    ax1.axhline(y=1e-3, color='r', linestyle='--', label='Stability Threshold')
    ax1.set_ylabel('Energy (eV/atom)')
    ax1.set_title('Distance to Hull')
    ax1.legend()

    # Plot decomposition products
    products = list(result.decomposition.items())
    if products:
        labels = [str(entry.composition.reduced_formula) for entry, _ in products]
        fractions = [fraction for _, fraction in products]
        ax2.pie(fractions, labels=labels, autopct='%1.1f%%')
    else:
        # An empty pie is meaningless; say so explicitly instead.
        ax2.text(0.5, 0.5, 'No decomposition products',
                 ha='center', va='center', transform=ax2.transAxes)
        ax2.axis('off')
    ax2.set_title('Decomposition Products')

    plt.tight_layout()
    plt.show()

def export_results(result: "DecompositionResult", filename: str) -> None:
    """Write the metrics and decomposition of ``result`` to a JSON file.

    The JSON object has the keys ``decomposition_energy``, ``hull_distance``,
    ``is_stable``, ``competing_phases`` and ``decomposition_products``. The
    last one maps each reduced formula to its fraction.

    Args:
        result: Outcome of a decomposition calculation. Each key of
            ``result.decomposition`` must expose
            ``.composition.reduced_formula``.
        filename: Destination path. An existing file is overwritten.
    """
    # Sum rather than overwrite: two product entries can share a reduced
    # formula, and a plain dict comprehension would silently keep only the
    # last entry's fraction.
    products = {}
    for entry, fraction in result.decomposition.items():
        formula = str(entry.composition.reduced_formula)
        products[formula] = products.get(formula, 0.0) + float(fraction)

    output = {
        "decomposition_energy": float(result.energy),
        "hull_distance": float(result.hull_distance),
        "is_stable": bool(result.is_stable),
        "competing_phases": list(result.competing_phases),
        "decomposition_products": products,
    }

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

# Example usage
if __name__ == "__main__":
    # Step 1: fetch both CSV summaries and build the hull calculator from them.
    gnome_crystals, reference_crystals = GNoMEDataProcessor.download_dataset()
    calculator = DecompositionCalculator(gnome_crystals, reference_crystals)

    # Step 2: evaluate one query. The energy is passed unchanged as the
    # ComputedEntry energy for this formula; the value is illustrative.
    test_composition, test_energy = "Fe2O3", -15.2
    result = calculator.calculate_decomposition(test_composition, test_energy)

    # Step 3: console report.
    print(f"Decomposition Energy: {result.energy:.6f} eV/atom")
    print(f"Hull Distance: {result.hull_distance:.6f} eV/atom")
    print(f"Is Stable: {result.is_stable}")
    print("\nDecomposition Products:")
    for product, fraction in result.decomposition.items():
        print(f"{product.composition.reduced_formula}: {fraction:.3f}")

    # Step 4: plot the analysis, then persist it as JSON.
    visualize_hull_analysis(result)
    export_results(result, "decomposition_results.json")