# In [None]:
# Import required libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import json
import os
import re
import itertools
from typing import List, Tuple, Dict, Optional, Union
from dataclasses import dataclass
import pymatgen as mg
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.analysis import phase_diagram
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import logging
import warnings
from tqdm import tqdm

# Configure logging: set the root logger to INFO and create the module-level
# logger that the code below writes its status, warning and error messages to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class ChemicalSystemConfig:
    """Configuration for chemical system analysis.

    Holds the pieces from which ``DataLoader.download_files`` assembles each
    download URL: ``public_link`` / ``bucket_name`` / folder / file name.
    """
    # Base URL of the public storage host.
    public_link: str = "https://storage.googleapis.com/"
    # Bucket appended to ``public_link``.
    bucket_name: str = "gdm_materials_discovery"
    # Folder inside the bucket from which ``files`` are fetched.
    folder_name: str = "gnome_data"
    # Folder inside the bucket from which ``external_files`` are fetched.
    external_folder_name: str = "external_data"
    # File names downloaded from ``folder_name``.
    files: tuple = ("stable_materials_summary.csv",)
    # File names downloaded from ``external_folder_name``.
    external_files: tuple = ("mp_snapshot_summary.csv", "external_materials_summary.csv")

class DataLoader:
    """Download the CSV files named in a ``ChemicalSystemConfig``.

    Files are fetched with the external ``wget`` program into the current
    working directory. Downloading is best-effort: a failure is logged and
    the remaining files are still attempted; nothing is raised.
    """

    def __init__(self, config: ChemicalSystemConfig):
        """Store the configuration that names the host, bucket, folders and files."""
        self.config = config

    @staticmethod
    def _join_url(base: str, *parts: str) -> str:
        """Join URL segments with exactly one forward slash between them.

        ``os.path.join`` is deliberately not used: it inserts the platform
        path separator, which produces backslashes in URLs on Windows.
        """
        segments = [base.rstrip("/")]
        segments.extend(part.strip("/") for part in parts)
        return "/".join(segments)

    def download_files(self) -> None:
        """Download every configured file into the current directory.

        ``config.files`` are taken from ``config.folder_name`` first, then
        ``config.external_files`` from ``config.external_folder_name``.
        """
        parent_directory = self._join_url(self.config.public_link, self.config.bucket_name)
        sources = (
            (self.config.folder_name, self.config.files),
            (self.config.external_folder_name, self.config.external_files),
        )

        for folder, filenames in sources:
            for filename in filenames:
                self._download_from_link(self._join_url(parent_directory, folder, filename))

    def _download_from_link(self, link: str) -> None:
        """Download a single file from a public link with ``wget``.

        The outcome is logged: an error when ``wget`` cannot be started or
        exits with a non-zero status, an info message otherwise.

        Args:
            link: Full URL of the file to fetch.
        """
        # Local import: only this method needs it.
        import subprocess

        try:
            # Argument list (no shell): the link is never interpreted by a
            # shell, and the exit status is available for checking.
            completed = subprocess.run(["wget", link, "-P", "."], check=False)
        except OSError as e:
            # Raised when wget is not installed or cannot be executed.
            logger.error(f"Failed to download {link}: {str(e)}")
            return

        if completed.returncode != 0:
            logger.error(
                f"Failed to download {link}: wget exited with status {completed.returncode}"
            )
            return

        logger.info(f"Successfully downloaded: {link}")

class ChemicalSystemAnalyzer:
    """Advanced analysis of chemical systems using TensorFlow and Pymatgen.

    Pre-processes crystal summary tables, collects all entries belonging to a
    chemical system (including its sub-systems), blends each entry's corrected
    energy with a neural-network prediction and converts the rows into
    pymatgen ``ComputedEntry`` objects.

    NOTE(review): no training step for ``self.model`` is visible in this
    file. Unless the model is fitted elsewhere, the 10% "ML" share of
    ``'ML Enhanced Energy'`` comes from randomly initialised weights --
    confirm before relying on the blended energies.
    """

    def __init__(self):
        """Create the feature scaler and build the (unfitted) Keras model."""
        self.scaler = StandardScaler()
        self._initialize_tf_model()

    def _initialize_tf_model(self):
        """Initialize TensorFlow model for energy predictions.

        Builds a small fully connected regression network with a single
        output and compiles it with Adam (learning rate 0.001), MSE loss and
        an MAE metric. The model is stored on ``self.model``.
        """
        self.model = tf.keras.Sequential([
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(1)
        ])

        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss='mse',
            metrics=['mae']
        )

    @staticmethod
    def _parse_chemical_system(elements) -> Tuple[str, ...]:
        """Turn one ``'Elements'`` cell into a sorted tuple of element symbols.

        The cell is parsed as a list literal after swapping single for double
        quotes. Anything that cannot be parsed -- malformed text, a
        non-string value such as NaN, or a non-sortable payload -- is logged
        and mapped to the empty tuple.
        """
        try:
            chemsys = json.loads(elements.replace("'", '"'))
            return tuple(sorted(chemsys))
        except (AttributeError, TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; AttributeError covers
            # non-string cells; TypeError covers non-iterable/mixed payloads.
            logger.warning(f"Failed to parse elements: {elements}")
            return tuple()

    def preprocess_data(self, crystals: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``crystals`` with derived columns added.

        Args:
            crystals: Table with at least the columns ``'Elements'``,
                ``'Corrected Energy'`` and ``'NSites'``.

        Returns:
            A new DataFrame with three extra columns:
            ``'Chemical System'`` (sorted tuple of element symbols, or ``()``
            when the row could not be parsed), ``'System Size'`` (length of
            that tuple) and ``'Energy Density'`` (``'Corrected Energy'``
            divided by ``'NSites'``).
        """
        processed_df = crystals.copy()

        # Convert chemical systems to sorted tuples
        processed_df['Chemical System'] = [
            self._parse_chemical_system(elements) for elements in crystals['Elements']
        ]

        # Add advanced features
        processed_df['System Size'] = processed_df['Chemical System'].map(len)
        processed_df['Energy Density'] = processed_df['Corrected Energy'] / processed_df['NSites']

        return processed_df

    def analyze_chemical_system(self,
                                chemsys: Tuple[str, ...],
                                grouped_entries: pd.core.groupby.generic.DataFrameGroupBy,
                                minimal_entries: pd.DataFrame) -> List[ComputedEntry]:
        """Collect the entries of ``chemsys`` and all of its sub-systems.

        Args:
            chemsys: Element symbols of the target system.
            grouped_entries: Entries grouped by ``'Chemical System'``.
            minimal_entries: Frame the rows are taken from. It must share its
                index with the frame that was grouped, because rows are
                selected by the group's index labels.

        Returns:
            ``ComputedEntry`` objects for every matching row, or an empty
            list (with a warning) when nothing matches.
        """
        phase_diagram_entries = []
        groups = grouped_entries.groups

        # Generate all non-empty subsystems. Length 0 is skipped on purpose:
        # the empty tuple is the placeholder preprocess_data assigns to rows
        # whose elements could not be parsed, and those rows must not leak
        # into every phase diagram.
        for length in range(1, len(chemsys) + 1):
            for subsystem in itertools.combinations(chemsys, length):
                subsystem_key = tuple(sorted(subsystem))
                subsystem_labels = groups.get(subsystem_key, [])

                if len(subsystem_labels):
                    # ``groups`` holds index labels, so select with .loc
                    # (.iloc is only equivalent for a default RangeIndex).
                    entries_df = minimal_entries.loc[subsystem_labels]
                    # Add ML-enhanced prediction
                    phase_diagram_entries.append(self._enhance_predictions(entries_df))

        if not phase_diagram_entries:
            logger.warning(f"No entries found for chemical system: {chemsys}")
            return []

        combined_entries = pd.concat(phase_diagram_entries)
        return self._convert_to_pymatgen_entries(combined_entries)

    def _enhance_predictions(self, entries_df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``entries_df`` with an ``'ML Enhanced Energy'`` column.

        The new column is ``0.9 * 'Corrected Energy' + 0.1 * model prediction``.
        A copy is used so the caller's frame (often a slice) is not modified.

        NOTE(review): the scaler is re-fitted on every call, i.e. features are
        standardised per subsystem rather than globally -- confirm intended.
        """
        enhanced = entries_df.copy()

        # Create feature matrix
        features = self._extract_features(enhanced)
        scaled_features = self.scaler.fit_transform(features)

        # Generate predictions; verbose=0 keeps Keras from printing a
        # progress bar for each subsystem.
        predictions = self.model.predict(scaled_features, verbose=0)
        enhanced['ML Enhanced Energy'] = (
            enhanced['Corrected Energy'] * 0.9 + predictions.flatten() * 0.1
        )

        return enhanced

    def _extract_features(self, df: pd.DataFrame) -> np.ndarray:
        """Return the ``'NSites'`` and ``'Formation Energy Per Atom'`` columns as an array."""
        feature_cols = ['NSites', 'Formation Energy Per Atom']
        return df[feature_cols].values

    def _convert_to_pymatgen_entries(self, df: pd.DataFrame) -> List[ComputedEntry]:
        """Convert DataFrame entries to Pymatgen ComputedEntry objects.

        Each row's ``'Composition'`` is paired with its
        ``'ML Enhanced Energy'`` when that value is present and not NaN,
        otherwise with ``'Corrected Energy'``. Rows that fail to convert are
        logged and skipped.
        """
        mg_entries = []

        for _, row in df.iterrows():
            try:
                # Use ML-enhanced energy if available, otherwise use corrected energy
                energy = row.get('ML Enhanced Energy')
                if energy is None or pd.isna(energy):
                    energy = row['Corrected Energy']
                mg_entries.append(ComputedEntry(row['Composition'], energy))
            except Exception as e:
                # Best-effort conversion: one bad row must not abort the rest.
                # row.get keeps the handler itself from raising on a missing column.
                logger.error(f"Failed to convert entry: {row.get('Composition')}, Error: {str(e)}")

        return mg_entries

class PhaseDiagramVisualizer:
    """Enhanced visualization of phase diagrams.

    Draws pymatgen phase diagrams with the matplotlib backend and optionally
    adds a colorbar spanning the entries' energy-per-atom range.
    """

    # Style names to try, newest first: newer matplotlib releases renamed the
    # bundled seaborn styles to "seaborn-v0_8-*" and no longer accept the old
    # names.
    _STYLE_CANDIDATES = ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid')

    def __init__(self):
        """Set figure defaults and activate the first available darkgrid style."""
        self.figure_size = (12, 8)
        # Stored for callers; not applied to figures by this class.
        self.dpi = 300

        for style_name in self._STYLE_CANDIDATES:
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break
        else:
            logger.warning("No seaborn darkgrid style available; using matplotlib defaults")

    @staticmethod
    def _as_figure(plot_obj) -> plt.Figure:
        """Return the matplotlib Figure behind whatever ``get_plot`` returned.

        Depending on the pymatgen version the matplotlib backend hands back
        an ``Axes`` or the ``pyplot`` module itself.
        """
        if isinstance(plot_obj, plt.Figure):
            return plot_obj
        if isinstance(plot_obj, plt.Axes):
            return plot_obj.figure
        return plot_obj.gcf()

    def plot_phase_diagram(self,
                           pd_entries: List[ComputedEntry],
                           title: str,
                           show_stability_regions: bool = True,
                           show_energy_scale: bool = True) -> plt.Figure:
        """Create enhanced phase diagram plot.

        Args:
            pd_entries: Entries from which the phase diagram is built.
            title: Figure title.
            show_stability_regions: Accepted for backward compatibility.
                NOTE(review): ``PDPlotter.get_plot`` has no
                ``show_stability_window`` option that this flag could be
                forwarded to, so it currently has no effect -- TODO confirm
                the intended pymatgen feature.
            show_energy_scale: When true, add an energy-per-atom colorbar.

        Returns:
            The matplotlib figure containing the diagram.

        Raises:
            ValueError: If ``pd_entries`` is empty.
        """
        if not pd_entries:
            raise ValueError(f"Cannot plot phase diagram '{title}': no entries provided")

        phase_diagram_obj = phase_diagram.PhaseDiagram(pd_entries)
        # ``show_unstable`` and the backend are PDPlotter constructor options,
        # not get_plot() arguments. pymatgen reads show_unstable as an
        # energy-above-hull cutoff in eV/atom, so True behaves like 1.
        plotter = phase_diagram.PDPlotter(
            phase_diagram_obj,
            show_unstable=True,
            backend="matplotlib",
        )

        fig = self._as_figure(plotter.get_plot(label_unstable=True))

        if show_energy_scale:
            self._add_energy_scale(fig, phase_diagram_obj)

        fig.set_size_inches(self.figure_size)
        fig.suptitle(title, fontsize=14, y=0.95)

        return fig

    def _add_energy_scale(self, fig: plt.Figure, diagram: phase_diagram.PhaseDiagram):
        """Add energy scale colorbar to phase diagram.

        The colorbar spans the minimum to maximum ``energy_per_atom`` of the
        diagram's entries and is placed in a narrow axes at the right edge of
        ``fig``. Returns the colorbar, or ``None`` when there are no entries.
        """
        energies = [entry.energy_per_atom for entry in diagram.all_entries]
        if not energies:
            return None

        norm = plt.Normalize(min(energies), max(energies))

        colorbar_axes = fig.add_axes([0.92, 0.1, 0.02, 0.6])
        # Use fig.colorbar so the colorbar is attached to this figure rather
        # than to whichever figure pyplot considers current.
        return fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='viridis'),
                            cax=colorbar_axes, label='Energy per atom (eV)')

def main():
    """Run the example workflow for the Na-Zn-P system.

    Downloads the summary CSVs, pre-processes them, builds phase-diagram
    entries from the combined GNoME + external data and from the Materials
    Project snapshot, and plots each set plus their union.

    Returns:
        Tuple ``(gnome_plot, mp_plot, joint_plot)`` of figures.
    """
    # Components are created in this order: configuration, loader, analyzer,
    # visualizer.
    config = ChemicalSystemConfig()
    loader = DataLoader(config)
    analyzer = ChemicalSystemAnalyzer()
    visualizer = PhaseDiagramVisualizer()

    # Fetch the CSV files into the working directory.
    loader.download_files()

    # Read all three tables first, then pre-process them in the same order.
    raw_gnome = pd.read_csv('stable_materials_summary.csv', index_col=0)
    raw_reference = pd.read_csv('external_materials_summary.csv')
    raw_mp = pd.read_csv('mp_snapshot_summary.csv')

    gnome_frame, reference_frame, mp_frame = (
        analyzer.preprocess_data(frame)
        for frame in (raw_gnome, raw_reference, raw_mp)
    )

    # GNoME and external rows are stacked under a fresh 0..n-1 index.
    combined_frame = pd.concat([gnome_frame, reference_frame], ignore_index=True)
    slim_frame = combined_frame[[
        'Composition',
        'NSites',
        'Corrected Energy',
        'Formation Energy Per Atom',
        'Chemical System',
    ]]

    # Group both data sources by chemical system.
    combined_by_system = slim_frame.groupby('Chemical System')
    mp_by_system = mp_frame.groupby('Chemical System')

    # Example usage with Na-Zn-P system
    target_system = tuple(sorted(['Na', 'Zn', 'P']))

    # The full combined frame (not the column subset) supplies the rows.
    gnome_entries = analyzer.analyze_chemical_system(
        target_system, combined_by_system, combined_frame
    )
    mp_entries = analyzer.analyze_chemical_system(target_system, mp_by_system, mp_frame)

    # One plot per source ...
    gnome_plot = visualizer.plot_phase_diagram(gnome_entries, "GNoME Phase Diagram")
    mp_plot = visualizer.plot_phase_diagram(mp_entries, "Materials Project Phase Diagram")

    # ... and one over the union of both entry lists.
    joint_plot = visualizer.plot_phase_diagram(
        gnome_entries + mp_entries,
        "Joint Convex Hull Analysis",
        show_stability_regions=True,
        show_energy_scale=True,
    )

    return gnome_plot, mp_plot, joint_plot

# Run the example workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()