In [None]:
# Import required libraries
import itertools
import json
import os
import urllib.request
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional, Union

import numpy as np
import pandas as pd
import tensorflow as tf

import pymatgen as mg
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.analysis import phase_diagram, interface_reactions
from pymatgen.core import Element, Composition

# Configuration
class Config:
    """Static settings for data download, thermodynamic conditions and the ML model."""

    # Data locations: files are fetched from
    # PUBLIC_LINK/BUCKET_NAME/<folder>/<file>.
    PUBLIC_LINK = "https://storage.googleapis.com/"
    BUCKET_NAME = "gdm_materials_discovery"
    FOLDER_NAME = "gnome_data"
    FILES = ["stable_materials_summary.csv"]
    EXTERNAL_FOLDER_NAME = "external_data"
    EXTERNAL_FILES = ["external_materials_summary.csv"]

    # Physical constants
    ROOM_TEMP = 300  # K
    # 21200 Pa is ~0.21 atm, i.e. the O2 partial pressure in air -- not the
    # total atmospheric pressure (~101325 Pa). It is the pressure handed to
    # pymatgen's get_chempot_correction for oxygen.
    OXYGEN_PARTIAL_PRESSURE = 21200  # Pa
    # Historical (misleading) name, kept as an alias so existing callers work.
    ATMOSPHERIC_PRESSURE = OXYGEN_PARTIAL_PRESSURE  # Pa
    # Base oxygen chemical potential before the T/P correction;
    # presumably per O atom of the O2 reference state -- confirm.
    OXYGEN_REFERENCE_ENERGY = -4.95  # eV

    # ML Parameters
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    EMBEDDING_DIM = 64

@dataclass
class MaterialEntry:
    """One material record consumed by the stability analyses.

    ``main`` fills it from a single row of the crystal summary table.
    """
    composition: str  # formula string, parsed with pymatgen ``Composition``
    material_id: str
    elements: List[str]  # element symbols
    nsites: int
    energy: float  # 'Corrected Energy' column in main; presumably total eV for the cell -- confirm
    formation_energy: float  # 'Formation Energy Per Atom' column in main
    chemical_system: Tuple[str, ...]  # element symbols, sorted by _annotate_chemical_system

    @property
    def element_fraction(self) -> Dict[str, float]:
        """Return the atomic fraction of each element, keyed by element symbol.

        The values sum to 1, e.g. ``"Fe2O3"`` -> ``{"Fe": 0.4, "O": 0.6}``.
        """
        # get_el_amt_dict() on the raw composition returns atom counts
        # (Fe: 2, O: 3); normalising first turns them into fractions.
        fractional = Composition(self.composition).fractional_composition
        return {str(el): float(amt) for el, amt in fractional.get_el_amt_dict().items()}

class MaterialDataset:
    """Download, load and featurize the GNoME and external crystal summary tables."""

    def __init__(self):
        self.config = Config()
        # Not used by any method of this class.
        self.element_encoder = None
        self.composition_model = None

    def download_data(self) -> None:
        """Download the configured summary CSVs into the current directory.

        URLs have the form ``PUBLIC_LINK/BUCKET_NAME/<folder>/<file>``. A file
        that already exists locally is not downloaded again. Each download is
        written to ``<file>.part`` first and renamed only on success, so an
        interrupted transfer never leaves a truncated CSV under the real name.

        Raises:
            urllib.error.URLError: if a download fails (HTTPError for 4xx/5xx).
        """
        base_url = f"{self.config.PUBLIC_LINK.rstrip('/')}/{self.config.BUCKET_NAME}"
        targets = [(self.config.FOLDER_NAME, name) for name in self.config.FILES]
        targets += [
            (self.config.EXTERNAL_FOLDER_NAME, name) for name in self.config.EXTERNAL_FILES
        ]

        for folder, filename in targets:
            if os.path.exists(filename):
                print(f"Skipping {filename}: already present")
                continue
            url = f"{base_url}/{folder}/{filename}"
            print(f"Downloading {url}")
            partial_path = f"{filename}.part"
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filename)

    def load_and_preprocess(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Load both CSVs and add 'Chemical System' plus derived feature columns.

        Returns:
            (gnome_crystals, reference_crystals) data frames.
        """
        gnome_crystals = pd.read_csv('stable_materials_summary.csv', index_col=0)
        reference_crystals = pd.read_csv('external_materials_summary.csv')

        gnome_crystals = self._annotate_chemical_system(gnome_crystals)
        reference_crystals = self._annotate_chemical_system(reference_crystals)

        gnome_crystals = self._add_advanced_features(gnome_crystals)
        reference_crystals = self._add_advanced_features(reference_crystals)

        return gnome_crystals, reference_crystals

    @staticmethod
    def _parse_elements(raw) -> Optional[List[str]]:
        """Parse an 'Elements' cell such as ``"['Fe', 'O']"`` into a list of symbols.

        Returns None for missing cells (non-strings such as NaN), malformed
        text, or text that does not decode to a list.
        """
        if not isinstance(raw, str):
            return None
        try:
            # The CSV stores Python-style single-quoted lists; JSON needs double quotes.
            parsed = json.loads(raw.replace("'", '"'))
        except json.JSONDecodeError:
            return None
        if not isinstance(parsed, list):
            return None
        return [str(el) for el in parsed]

    def _annotate_chemical_system(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add a 'Chemical System' column holding the sorted tuple of element symbols.

        Rows whose 'Elements' value cannot be parsed get None and are reported
        on stdout. The frame is modified in place and also returned.
        """
        chemical_systems = []
        for raw in df['Elements']:
            symbols = self._parse_elements(raw)
            if symbols is None:
                print(f"Error processing elements: {raw}")
                chemical_systems.append(None)
            else:
                chemical_systems.append(tuple(sorted(symbols)))

        df['Chemical System'] = chemical_systems
        return df

    def _add_advanced_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add 'Mean Electronegativity' and 'Max Radius Ratio' columns (in place; also returned)."""
        df['Mean Electronegativity'] = df['Elements'].apply(self._mean_electronegativity)
        df['Max Radius Ratio'] = df['Elements'].apply(self._calculate_radius_ratio)
        return df

    def _mean_electronegativity(self, elements_str: str) -> float:
        """Mean pymatgen ``Element.X`` over the listed elements; NaN if the cell is unparseable or empty."""
        symbols = self._parse_elements(elements_str)
        if not symbols:
            return float('nan')
        return float(np.mean([Element(el).X for el in symbols]))

    def _calculate_radius_ratio(self, elements_str: str) -> float:
        """Ratio of the largest to the smallest atomic radius among the listed elements.

        Elements without an atomic radius are ignored. Returns 1.0 when fewer
        than two radii are available, and NaN when the cell cannot be parsed.
        """
        symbols = self._parse_elements(elements_str)
        if symbols is None:
            return float('nan')
        radii = []
        for symbol in symbols:
            radius = Element(symbol).atomic_radius  # look up once per element
            if radius:
                radii.append(float(radius))
        if len(radii) < 2:
            return 1.0
        return max(radii) / min(radii)

class StabilityAnalyzer:
    """Thermodynamic stability checks (O2, CO2, H2O) for a single material.

    The instance also owns a small Keras model that turns the CO2/H2O reaction
    energies into a "lifetime" score. NOTE(review): the model is compiled but
    never trained anywhere in this file, so that score is an untrained
    placeholder in (0, 1), not a physical lifetime.
    """

    def __init__(self, temperature: float = Config.ROOM_TEMP, pressure: float = Config.ATMOSPHERIC_PRESSURE):
        # temperature (K) and pressure (Pa) are forwarded to
        # InterfacialReactivity.get_chempot_correction for oxygen.
        self.temperature = temperature
        self.pressure = pressure
        self.stability_model = self._build_stability_model()

    def _build_stability_model(self) -> tf.keras.Model:
        """Build and compile a small dense network: (co2, h2o) energies -> sigmoid score.

        Returns:
            A compiled, untrained ``tf.keras.Model`` taking inputs of shape
            ``(batch, 2)`` and producing outputs of shape ``(batch, 1)``.
        """
        # The features are continuous (typically negative) reaction energies.
        # An Embedding layer would cast them to integer token ids, collapsing
        # most values onto the same index, and negative ids are not valid
        # indices, so a Dense input layer is used instead.
        inputs = tf.keras.Input(shape=(2,))
        x = tf.keras.layers.Dense(Config.EMBEDDING_DIM, activation='relu')(inputs)
        x = tf.keras.layers.Dense(32, activation='relu')(x)
        outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)

        model = tf.keras.Model(inputs=inputs, outputs=outputs)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(Config.LEARNING_RATE),
            loss='binary_crossentropy',
            metrics=['accuracy'],
        )
        return model

    @staticmethod
    def _prepare_phase_diagram_entries(
        chemical_system: Tuple[str, ...],
        entries: List[ComputedEntry],
    ) -> List[ComputedEntry]:
        """Keep only entries whose elements all belong to *chemical_system*.

        This only filters; it cannot add phases. The caller must already supply
        entries for every element of the system (pymatgen phase diagrams need
        the elemental end-members).
        """
        allowed = {str(el) for el in chemical_system}
        return [
            entry for entry in entries
            if {el.symbol for el in entry.composition.elements} <= allowed
        ]

    def analyze_oxygen_stability(self, material: MaterialEntry, phase_diagram_entries: List[ComputedEntry]) -> float:
        """Return the material's energy above the grand-potential hull with O as an open element.

        Oxygen's chemical potential comes from ``_calculate_oxygen_chemical_potential``.
        Only entries within the material's chemical system plus O are used.

        Returns:
            The e_above_hull reported by
            ``GrandPotentialPhaseDiagram.get_decomp_and_e_above_hull``.
        """
        chempots = {Element('O'): self._calculate_oxygen_chemical_potential()}
        oxygen_chemsys = tuple(material.chemical_system) + ('O',)
        entries = self._prepare_phase_diagram_entries(oxygen_chemsys, phase_diagram_entries)

        grand_diagram = phase_diagram.GrandPotentialPhaseDiagram(entries, chempots=chempots)
        grand_entry = phase_diagram.GrandPotPDEntry(
            ComputedEntry(material.composition, material.energy),
            chempots=chempots,
        )

        _, e_above_hull = grand_diagram.get_decomp_and_e_above_hull(grand_entry)
        return e_above_hull

    def analyze_environmental_stability(
        self,
        material: MaterialEntry,
        phase_diagram_entries: List[ComputedEntry]
    ) -> Dict[str, float]:
        """Compute the most favourable reaction energy of the material with CO2 and with H2O.

        A phase diagram is built over the material's elements plus H, C and O
        (so *phase_diagram_entries* must contain those phases), and
        ``InterfacialReactivity.minimum`` gives the lowest reaction energy.

        Returns:
            Dict with 'co2_stability', 'h2o_stability' and 'predicted_lifetime'
            (the untrained model's score, see the class docstring).
        """
        system = tuple(material.chemical_system) + ('H', 'C', 'O')
        pd_entries = self._prepare_phase_diagram_entries(system, phase_diagram_entries)
        material_pd = phase_diagram.PhaseDiagram(pd_entries)
        composition = Composition(material.composition)

        co2_energy = self._min_reaction_energy(composition, Composition('CO2'), material_pd)
        h2o_energy = self._min_reaction_energy(composition, Composition('H2O'), material_pd)

        return {
            'co2_stability': co2_energy,
            'h2o_stability': h2o_energy,
            'predicted_lifetime': self._predict_environmental_lifetime(co2_energy, h2o_energy),
        }

    @staticmethod
    def _min_reaction_energy(reactant: Composition, other: Composition, diagram) -> float:
        """Lowest reaction energy between two compositions on *diagram* (uses hull energies)."""
        reactivity = interface_reactions.InterfacialReactivity(
            reactant,
            other,
            diagram,
            use_hull_energy=True,
        )
        # ``minimum`` is an (x_mixing_ratio, energy) pair.
        return reactivity.minimum[1]

    def _calculate_oxygen_chemical_potential(self) -> float:
        """Oxygen chemical potential: Config reference energy plus pymatgen's T/P correction."""
        correction = interface_reactions.InterfacialReactivity.get_chempot_correction(
            'O',
            self.temperature,
            self.pressure
        )
        return Config.OXYGEN_REFERENCE_ENERGY + correction

    def _predict_environmental_lifetime(self, co2_stability: float, h2o_stability: float) -> float:
        """Score the two reaction energies with ``stability_model`` (untrained placeholder)."""
        features = tf.constant([[co2_stability, h2o_stability]], dtype=tf.float32)
        prediction = self.stability_model(features)  # shape (1, 1)
        return float(prediction[0, 0])

def main():
    """Run the end-to-end demo on one randomly chosen GNoME binary compound.

    Downloads the datasets, gathers phase-diagram entries around the sampled
    material and prints its oxygen, CO2 and H2O stability.

    NOTE(review): ``collect_phase_diagram_entries`` is not defined in this
    cell and must be provided elsewhere in the notebook/module. It is assumed
    to take (chemical system, grouped entries, minimal entries) and return
    pymatgen entries for that system and its sub-systems -- confirm.
    """
    # Initialize dataset handler
    dataset = MaterialDataset()
    dataset.download_data()
    gnome_crystals, reference_crystals = dataset.load_and_preprocess()

    # Combine datasets
    all_crystals = pd.concat([gnome_crystals, reference_crystals], ignore_index=True)
    required_columns = ['Composition', 'NSites', 'Corrected Energy', 'Formation Energy Per Atom', 'Chemical System']
    minimal_entries = all_crystals[required_columns]
    grouped_entries = minimal_entries.groupby('Chemical System')

    # Initialize stability analyzer
    analyzer = StabilityAnalyzer()

    # Select a random binary compound. Rows whose element list failed to parse
    # hold None instead of a tuple, so check the type before calling len().
    is_binary = gnome_crystals['Chemical System'].map(
        lambda system: isinstance(system, tuple) and len(system) == 2
    )
    binaries = gnome_crystals[is_binary]
    if binaries.empty:
        raise RuntimeError("No binary compounds found in the GNoME dataset")
    sample = binaries.sample(n=1).iloc[0]

    # Create material entry
    material = MaterialEntry(
        composition=sample['Composition'],
        material_id=sample['MaterialId'],
        elements=json.loads(sample['Elements'].replace("'", '"')),
        nsites=sample['NSites'],
        energy=sample['Corrected Energy'],
        formation_energy=sample['Formation Energy Per Atom'],
        chemical_system=sample['Chemical System']
    )

    # The oxygen analysis needs O-containing phases and the CO2/H2O analysis
    # builds a diagram over the material's elements plus H, C and O, so
    # collect entries for that extended system, not just the material's own.
    extended_system = tuple(sorted(set(material.chemical_system) | {'H', 'C', 'O'}))
    phase_diagram_entries = collect_phase_diagram_entries(
        extended_system,
        grouped_entries,
        minimal_entries
    )

    oxygen_stability = analyzer.analyze_oxygen_stability(material, phase_diagram_entries)
    environmental_stability = analyzer.analyze_environmental_stability(material, phase_diagram_entries)

    # Print results
    print(f"Material: {material.composition}")
    print(f"Oxygen stability: {oxygen_stability:.4f} eV/atom")
    print(f"CO2 stability: {environmental_stability['co2_stability']:.4f} eV/atom")
    print(f"H2O stability: {environmental_stability['h2o_stability']:.4f} eV/atom")
    # The lifetime model is never trained in this file; its sigmoid output is
    # a 0-1 placeholder score, not a duration in years.
    print(f"Predicted environmental lifetime score (untrained model, 0-1): {environmental_stability['predicted_lifetime']:.2f}")

if __name__ == "__main__":
    main()