In [None]:
# Import required libraries
import itertools
import json
import os
import subprocess
from typing import List, Tuple, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymatgen as mg
import seaborn as sns
import tensorflow as tf
from pymatgen.analysis import phase_diagram, interface_reactions
from pymatgen.entries.computed_entries import ComputedEntry
from scipy.stats import norm
from sklearn.preprocessing import StandardScaler

class AirStabilityAnalyzer:
    """
    Advanced analyzer for material stability in air using ML-enhanced predictions
    and comprehensive chemical analysis.
    """
    def __init__(self, data_path: str = None, use_gpu: bool = True):
        """
        Set up the compute strategy, optional input data and the ML model.

        Args:
            data_path: Passed to ``_load_data`` when truthy; when it is
                ``None`` or empty, ``self.data`` is set to ``None``.
            use_gpu: Forwarded to ``_setup_compute_strategy``.
        """
        # The strategy must exist first: _build_stability_model opens its scope.
        self.device_strategy = self._setup_compute_strategy(use_gpu)

        if data_path:
            self.data = self._load_data(data_path)
        else:
            self.data = None

        self.ml_model = self._build_stability_model()
        self.reference_data = None
        self.scaler = StandardScaler()
        
    def _setup_compute_strategy(self, use_gpu: bool) -> tf.distribute.Strategy:
        """
        Choose the TensorFlow distribution strategy used to build the model.

        Returns a ``MirroredStrategy`` when ``use_gpu`` is truthy and
        TensorFlow lists at least one physical GPU; otherwise returns
        whatever ``tf.distribute.get_strategy()`` provides.
        """
        if not use_gpu:
            return tf.distribute.get_strategy()

        # Only queried when a GPU was actually requested.
        visible_gpus = tf.config.list_physical_devices('GPU')
        if not visible_gpus:
            return tf.distribute.get_strategy()

        return tf.distribute.MirroredStrategy()

    def _build_stability_model(self) -> tf.keras.Model:
        """
        Create and compile the dense regression network.

        The network is a ``Sequential`` stack of three ReLU ``Dense`` layers
        (256, 128, 64 units) with dropout (0.3, 0.2) after the first two,
        followed by a single-unit linear output. It is compiled with the
        Adam optimizer, MSE loss and an MAE metric, all inside the scope of
        ``self.device_strategy``.
        """
        keras_layers = tf.keras.layers

        with self.device_strategy.scope():
            layer_stack = [
                keras_layers.Dense(256, activation='relu'),
                keras_layers.Dropout(0.3),
                keras_layers.Dense(128, activation='relu'),
                keras_layers.Dropout(0.2),
                keras_layers.Dense(64, activation='relu'),
                keras_layers.Dense(1),
            ]
            network = tf.keras.Sequential(layer_stack)
            network.compile(optimizer='adam', loss='mse', metrics=['mae'])

        return network

    def download_dataset(self, bucket_name: str = "gdm_materials_discovery"):
        """
        Download the two summary CSV files into the current directory.

        Each file is fetched with ``wget`` from
        ``https://storage.googleapis.com/<bucket_name>/...`` and written
        under its local name (``stable_materials_summary.csv`` and
        ``external_materials_summary.csv``).

        Args:
            bucket_name: Name of the storage bucket that hosts the files.

        Raises:
            RuntimeError: If ``wget`` cannot be started or exits with a
                non-zero status for any file. The underlying error is
                chained as ``__cause__``.
        """
        base_url = f"https://storage.googleapis.com/{bucket_name}"
        files = {
            "gnome_data/stable_materials_summary.csv": "stable_materials_summary.csv",
            "external_data/external_materials_summary.csv": "external_materials_summary.csv"
        }

        for remote_path, local_name in files.items():
            url = f"{base_url}/{remote_path}"
            try:
                # An argument list (no shell) keeps bucket_name from being
                # interpreted as shell syntax, and check=True turns a failed
                # download into an exception instead of a silently ignored
                # exit status.
                subprocess.run(["wget", url, "-O", local_name], check=True)
            except (OSError, subprocess.CalledProcessError) as e:
                raise RuntimeError(f"Failed to download {local_name}: {str(e)}") from e

    def load_and_process_data(self):
        """
        Load and preprocess crystal data with enhanced validation.
        """
        self.gnome_crystals = self._load_and_validate_csv('stable_materials_summary.csv')
        self.reference_crystals = self._load_and_validate_csv('external_materials_summary.csv')
        
        self.gnome_crystals = self._annotate_chemical_system(self.gnome_crystals)
        self.reference_crystals = self._annotate_chemical_system(self.reference_crystals)
        
        self.all_crystals = pd.concat([self.gnome_crystals, self.reference_crystals], 
                                    ignore_index=True)
        self._prepare_minimal_entries()

    def _load_and_validate_csv(self, filename: str) -> pd.DataFrame:
        """
        Load CSV with validation and error handling.
        """
        try:
            df = pd.read_csv(filename, index_col=0)
            required_columns = ['Composition', 'Elements', 'NSites', 'Corrected Energy']
            missing_cols = [col for col in required_columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns: {missing_cols}")
            return df
        except Exception as e:
            raise RuntimeError(f"Error loading {filename}: {str(e)}")

    def _annotate_chemical_system(self, crystals: pd.DataFrame) -> pd.DataFrame:
        """
        Enhanced chemical system annotation with validation.
        """
        def parse_elements(elements_str: str) -> tuple:
            try:
                elements = json.loads(elements_str.replace("'", '"'))
                return tuple(sorted(elements))
            except json.JSONDecodeError:
                raise ValueError(f"Invalid elements format: {elements_str}")

        crystals = crystals.copy()
        crystals['Chemical System'] = crystals['Elements'].apply(parse_elements)
        return crystals

    def _prepare_minimal_entries(self):
        """
        Prepare minimal entries for analysis with validation.
        """
        required_columns = [
            'Composition', 'NSites', 'Corrected Energy',
            'Formation Energy Per Atom', 'Chemical System'
        ]
        self.minimal_entries = self.all_crystals[required_columns]
        self.grouped_entries = self.minimal_entries.groupby('Chemical System')

    def calculate_air_stability(self,
                              structure: mg.core.structure.Structure,
                              temperature: float = 300,
                              pressure: float = 21200,
                              include_water: bool = True,
                              include_co2: bool = True) -> Dict[str, float]:
        """
        Evaluate a structure against oxygen and, optionally, water and CO2.

        Args:
            structure: Structure to analyse; its composition and the value
                of ``_calculate_formation_energy(structure)`` form the entry
                that every check is run on.
            temperature: Forwarded to the oxygen check only.
            pressure: Forwarded to the oxygen check only.
                NOTE(review): presumably kelvin and pascal, as expected by
                pymatgen's chemical-potential correction; confirm.
            include_water: Whether to add the 'water_stability' measure.
            include_co2: Whether to add the 'co2_stability' measure.

        Returns:
            Dict with 'oxygen_stability', the optional 'water_stability' and
            'co2_stability', and finally 'overall_stability_score' (always
            inserted last).
        """
        entry = ComputedEntry(
            structure.composition,
            self._calculate_formation_energy(structure),
        )

        # Insertion order is part of the result: the overall score goes last.
        results = {
            'oxygen_stability': self._calculate_oxygen_stability(
                entry, temperature, pressure),
        }
        if include_water:
            results['water_stability'] = self._calculate_water_stability(entry)
        if include_co2:
            results['co2_stability'] = self._calculate_co2_stability(entry)

        results['overall_stability_score'] = self._calculate_overall_stability(results)
        return results

    def _calculate_oxygen_stability(self,
                                  entry: ComputedEntry,
                                  temperature: float,
                                  pressure: float) -> float:
        """
        Return the entry's energy above hull in an oxygen-open phase diagram.

        The oxygen chemical potential is -4.95 plus pymatgen's chemical
        potential correction for 'O' at the given temperature and pressure.
        NOTE(review): -4.95 is presumably a reference O2 energy in eV/atom;
        confirm its source.

        Args:
            entry: Entry to evaluate.
            temperature: Passed to ``get_chempot_correction``.
            pressure: Passed to ``get_chempot_correction``.

        Returns:
            The energy above hull reported by the grand potential diagram.
        """
        correction = interface_reactions.InterfacialReactivity.get_chempot_correction(
            'O', temperature, pressure)
        oxygen_potential = -4.95 + correction

        # The entry's own elements plus oxygen define the diagram's system.
        system_symbols = [*entry.composition.chemical_system.split("-"), 'O']
        chempots = {mg.core.Element('O'): oxygen_potential}

        candidates = self._collect_phase_diagram_entries(system_symbols)
        grand_pd = phase_diagram.GrandPotentialPhaseDiagram(candidates, chempots=chempots)

        target = phase_diagram.GrandPotPDEntry(entry, chempots=chempots)
        _decomposition, hull_distance = grand_pd.get_decomp_and_e_above_hull(target)
        return hull_distance

    def _calculate_water_stability(self, entry: ComputedEntry) -> float:
        """
        Return the minimum interfacial reaction energy of the entry with H2O.

        A phase diagram is built over the entry's elements plus H and O, and
        pymatgen's ``InterfacialReactivity`` (with ``use_hull_energy=True``)
        is evaluated between the entry's composition and water. The second
        item of its ``minimum`` property is returned.
        """
        system_symbols = [*entry.composition.chemical_system.split("-"), 'H', 'O']
        candidates = self._collect_phase_diagram_entries(system_symbols)
        hull = phase_diagram.PhaseDiagram(candidates)

        interface = interface_reactions.InterfacialReactivity(
            entry.composition,
            mg.core.Composition('H2O'),
            hull,
            use_hull_energy=True,
        )
        lowest_point = interface.minimum
        return lowest_point[1]

    def _calculate_co2_stability(self, entry: ComputedEntry) -> float:
        """
        Return the minimum interfacial reaction energy of the entry with CO2.

        A phase diagram is built over the entry's elements plus C and O, and
        pymatgen's ``InterfacialReactivity`` (with ``use_hull_energy=True``)
        is evaluated between the entry's composition and carbon dioxide. The
        second item of its ``minimum`` property is returned.
        """
        system_symbols = [*entry.composition.chemical_system.split("-"), 'C', 'O']
        candidates = self._collect_phase_diagram_entries(system_symbols)
        hull = phase_diagram.PhaseDiagram(candidates)

        interface = interface_reactions.InterfacialReactivity(
            entry.composition,
            mg.core.Composition('CO2'),
            hull,
            use_hull_energy=True,
        )
        lowest_point = interface.minimum
        return lowest_point[1]

    def _calculate_overall_stability(self, 
                                   stability_results: Dict[str, float]) -> float:
        """
        Calculate weighted overall stability score.
        """
        weights = {
            'oxygen_stability': 0.4,
            'water_stability': 0.3,
            'co2_stability': 0.3
        }
        
        score = 0.0
        for key, weight in weights.items():
            if key in stability_results:
                score += weight * (1.0 / (1.0 + np.exp(stability_results[key])))
        
        return score

    def visualize_stability_analysis(self,
                                   stability_results: Dict[str, float],
                                   save_path: Optional[str] = None):
        """
        Plot the individual measures as bars and the overall score as a gauge.

        The left panel is a bar chart of every entry in
        ``stability_results`` except 'overall_stability_score'. The right
        panel is a disc coloured on a red-yellow-green scale over [0, 1]
        according to the overall score, with the score printed on it.

        Args:
            stability_results: Mapping of measure name to value; must
                contain 'overall_stability_score'.
            save_path: If truthy, the figure is also written to this path
                before being shown.

        Raises:
            KeyError: If 'overall_stability_score' is missing.
        """
        overall_key = 'overall_stability_score'
        # Read the score before creating a figure so a missing key does not
        # leave an orphaned figure behind.
        overall_score = stability_results[overall_key]

        # Exclude the overall score by name rather than by position, so the
        # chart stays right whatever order the dict was built in.
        measures = [name for name in stability_results if name != overall_key]
        values = [stability_results[name] for name in measures]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Stability measures comparison (skipped when there is nothing to plot)
        if measures:
            sns.barplot(x=measures, y=values, ax=ax1)
        ax1.set_title('Stability Measures Comparison')
        ax1.set_ylabel('Energy (eV/atom)')
        ax1.tick_params(axis='x', rotation=45)

        # Overall stability gauge
        score_norm = plt.Normalize(0, 1)
        sm = plt.cm.ScalarMappable(cmap=plt.cm.RdYlGn, norm=score_norm)
        sm.set_array([])

        ax2.set_title('Overall Stability Score')
        # to_rgba returns RGBA; drop alpha so the disc is drawn opaque.
        ax2.add_patch(plt.Circle((0.5, 0.5), 0.4, color=sm.to_rgba(overall_score)[:-1]))
        ax2.text(0.5, 0.5, f'{overall_score:.2f}',
                 horizontalalignment='center', verticalalignment='center')
        ax2.set_xlim(0, 1)
        ax2.set_ylim(0, 1)
        ax2.axis('off')

        fig.colorbar(sm, ax=ax2)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
        plt.show()

    def _collect_phase_diagram_entries(self,
                                     chemsys: List[str]) -> List[ComputedEntry]:
        """
        Gather every known entry whose elements are a subset of ``chemsys``.

        For each non-empty combination of the given element symbols, the
        rows of ``self.minimal_entries`` grouped under that chemical system
        are turned into ``ComputedEntry`` objects from their 'Composition'
        and 'Corrected Energy' columns. One zero-energy entry per element is
        appended at the end.

        Args:
            chemsys: Element symbols spanning the system; repeated symbols
                are ignored.

        Returns:
            The collected entries followed by the elemental references.
        """
        # Callers append 'O' / 'H' / 'C' even when the material already
        # contains them; without de-duplication the same subsystem would be
        # visited (and its entries added) more than once.
        elements = list(dict.fromkeys(chemsys))
        groups = self.grouped_entries.groups

        entries = []
        # Start at 1: the empty combination can never name a chemical system.
        for length in range(1, len(elements) + 1):
            for subsystem in itertools.combinations(elements, length):
                subsystem_key = tuple(sorted(subsystem))
                if subsystem_key not in groups:
                    continue
                # `groups` maps each key to index LABELS, so select with
                # .loc; .iloc only happened to work on a 0..n-1 index.
                subsystem_data = self.minimal_entries.loc[groups[subsystem_key]]
                entries.extend(
                    ComputedEntry(composition=composition, energy=energy)
                    for composition, energy in zip(
                        subsystem_data['Composition'],
                        subsystem_data['Corrected Energy'])
                )

        # Add elemental reference states.
        # NOTE(review): a reference energy of 0.0 presumably assumes that
        # 'Corrected Energy' is on a scale where elements sit at zero;
        # confirm against the dataset.
        entries.extend(ComputedEntry(element, 0.0) for element in elements)

        return entries