In [8]:
import pandas as pd
import numpy as np
from pymatgen.core import Composition
from pymatgen.ext.matproj import MPRester
import warnings

# Suppress minor warnings for clean output
warnings.filterwarnings('ignore', category=UserWarning)

class PseudohalideAnalyzer:
    def __init__(self, mp_csv_path, shannon_csv_path):
        # Initialize datasets with validation
        try:
            self.mp_df = pd.read_csv(mp_csv_path)
            self.shannon_df = pd.read_csv(shannon_csv_path)
            self._validate_datasets()
        except Exception as e:
            raise ValueError(f"Data loading failed: {str(e)}")

        # Pseudohalide definition with charge information
        self.pseudohalides = {
            'CN': {'elements': ['C', 'N'], 'charge': -1, 'radius': 1.91},
            'SCN': {'elements': ['S', 'C', 'N'], 'charge': -1, 'radius': 2.13},
            'N3': {'elements': ['N'], 'charge': -1, 'radius': 1.82},
            'OCN': {'elements': ['O', 'C', 'N'], 'charge': -1, 'radius': 2.05},
            'SeCN': {'elements': ['Se', 'C', 'N'], 'charge': -1, 'radius': 2.20},
            'NCS': {'elements': ['N', 'C', 'S'], 'charge': -1, 'radius': 2.13}
        }

    def _validate_datasets(self):
        """Ensure required columns exist in datasets"""
        required_mp_cols = {'material_id', 'formula_pretty', 'elements', 
                           'band_gap', 'is_metal', 'energy_above_hull'}
        required_shannon_cols = {'element', 'oxidation_state', 'r_ionic'}
        
        if not required_mp_cols.issubset(self.mp_df.columns):
            missing = required_mp_cols - set(self.mp_df.columns)
            raise ValueError(f"MP dataset missing columns: {missing}")
            
        if not required_shannon_cols.issubset(self.shannon_df.columns):
            missing = required_shannon_cols - set(self.shannon_df.columns)
            raise ValueError(f"Shannon dataset missing columns: {missing}")

    def filter_pv_candidates(self, min_gap=1.0, max_gap=1.8, max_hull=0.1):
        """Filter materials for photovoltaic potential"""
        return self.mp_df[
            (~self.mp_df['is_metal']) &
            (self.mp_df['band_gap'].between(min_gap, max_gap)) &
            (self.mp_df['energy_above_hull'] <= max_hull)
        ].copy()

    def identify_pseudohalides(self, df):
        """Identify compounds containing pseudohalide groups"""
        results = []
        for _, row in df.iterrows():
            comp = Composition(row['formula_pretty'])
            # Get element symbols as strings for quick lookup
            elements = {e.symbol for e in comp.elements}
            
            for p_id, p_data in self.pseudohalides.items():
                # Create pseudohalide composition
                pseudo_comp = Composition("".join(p_data['elements']))
                
                # Check if all pseudohalide elements exist in compound
                if all(e.symbol in elements for e in pseudo_comp.elements):
                    results.append({
                        'material_id': row['material_id'],
                        'formula': row['formula_pretty'],
                        'pseudohalide': p_id,
                        'band_gap': row['band_gap'],
                        'formation_energy': row.get('formation_energy_per_atom', np.nan),
                        'energy_above_hull': row['energy_above_hull']
                    })
        return pd.DataFrame(results)

    def get_ionic_radius(self, element, oxidation_state):
        """Get average ionic radius with validation"""
        radii = self.shannon_df[
            (self.shannon_df['element'] == element) &
            (self.shannon_df['oxidation_state'] == oxidation_state)
        ]['r_ionic']
        
        if radii.empty:
            raise ValueError(f"No radius found for {element}+{oxidation_state}")
        return radii.mean()

    def generate_abx3_candidates(self, a_charges=(1, 2, 3), b_charges=(2, 3, 4)):
        """Generate new ABX3 candidates with charge balance validation"""
        candidates = []
        a_elements = self.shannon_df[
            self.shannon_df['oxidation_state'].isin(a_charges)
        ]['element'].unique()
        
        b_elements = self.shannon_df[
            self.shannon_df['oxidation_state'].isin(b_charges)
        ]['element'].unique()

        for a in a_elements:
            for a_charge in self.shannon_df[self.shannon_df['element'] == a]['oxidation_state'].unique():
                for b in b_elements:
                    for b_charge in self.shannon_df[self.shannon_df['element'] == b]['oxidation_state'].unique():
                        if (a_charge + b_charge) != 3:  # ABX3 charge balance
                            continue
                            
                        for x_id, x_data in self.pseudohalides.items():
                            try:
                                a_radius = self.get_ionic_radius(a, a_charge)
                                b_radius = self.get_ionic_radius(b, b_charge)
                                x_radius = x_data['radius']
                                
                                # Structural compatibility
                                tolerance = (a_radius + x_radius) / (np.sqrt(2) * (b_radius + x_radius))
                                octahedral = b_radius / x_radius
                                
                                if 0.8 < tolerance < 1.1 and 0.4 < octahedral < 0.9:
                                    candidates.append({
                                        'A': a,
                                        'A_charge': a_charge,
                                        'B': b,
                                        'B_charge': b_charge,
                                        'X': x_id,
                                        'tolerance_factor': tolerance,
                                        'octahedral_factor': octahedral,
                                        'formula': f"{a}{b}({x_id})3"
                                    })
                            except ValueError:
                                continue
                                
        return pd.DataFrame(candidates)

    def full_pipeline(self):
        """Complete analysis pipeline"""
        # Step 1: Filter PV candidates
        pv_candidates = self.filter_pv_candidates()
        print(f"Step 1: Found {len(pv_candidates)} PV candidates")
        
        # Step 2: Identify existing pseudohalides
        existing = self.identify_pseudohalides(pv_candidates)
        print(f"Step 2: Identified {len(existing)} existing pseudohalides")
        
        # Step 3: Generate new ABX3 candidates
        new_candidates = self.generate_abx3_candidates()
        print(f"Step 3: Generated {len(new_candidates)} new ABX3 candidates")
        
        # Step 4: Prepare DataFrames for merging
        new_candidates['energy_above_hull'] = np.nan
        existing_df = existing.assign(type='existing')
        new_df = new_candidates.assign(
            type='new',
            material_id=None,
            band_gap=np.nan
        )
        
        # Step 5: Combine and filter
        combined = pd.concat([existing_df, new_df], ignore_index=True)
        filtered = combined[
            (combined['tolerance_factor'].between(0.9, 1.0)) &
            (combined['octahedral_factor'].between(0.4, 0.7)) &
            (~combined['B'].isin(['Hg', 'Cd'])) &
            (~combined['X'].isin(['CN', 'SCN']))
        ]
        
        # Step 6: Dynamic sorting
        sort_columns = []
        ascending = []
        
        if 'energy_above_hull' in filtered.columns:
            sort_columns.append('energy_above_hull')
            ascending.append(True)
            
        sort_columns.append('tolerance_factor')
        ascending.append(False)
        
        return filtered.sort_values(
            by=sort_columns,
            ascending=ascending
        )
# Initialize analyzer with validated datasets
analyzer = PseudohalideAnalyzer(
    mp_csv_path="MPDataset.csv",
    shannon_csv_path="ShannonDataset.csv"
)

# Run complete pipeline
results = analyzer.full_pipeline()

# Show final output
print("Top 5 Candidates:")
print(results[['formula', 'type', 'energy_above_hull', 'tolerance_factor']].head(5))

# Export results
results.to_csv("final_pseudohalide_candidates.csv", index=False)

Step 1: Found 19 PV candidates
Step 2: Identified 0 existing pseudohalides
Step 3: Generated 511 new ABX3 candidates
Top 5 Candidates:
       formula type  energy_above_hull  tolerance_factor
220  FrPd(N3)3  new                NaN          0.996003
131  BrCf(N3)3  new                NaN          0.993445
323  CsPd(N3)3  new                NaN          0.992793
287  FrCr(N3)3  new                NaN          0.990223
86   BrBk(N3)3  new                NaN          0.989949


In [58]:

import pandas as pd
import numpy as np
from mp_api.client import MPRester

# Replace with your API key (register at materialsproject.org)
API_KEY = "DGBlr4BVU3ukySoP3uILlx4rd8YTb3BC"

with MPRester(API_KEY) as mpr:
    oxide_data = mpr.materials.summary.search(
        formula="ABO3",
        fields=["formula_pretty", "elements", "nsites", "band_gap", 
                "formation_energy_per_atom", "energy_above_hull", "is_stable"]
    )
    halide_data = mpr.materials.summary.search(
        formula="ABX3",
        chemsys=["Cs-Pb-I", "Cs-Sn-I", "Rb-Pb-I", "Rb-Sn-I", 
                 "Cs-Pb-Br", "Cs-Sn-Br", "Rb-Pb-Br", "Rb-Sn-Br",
                 "Cs-Pb-Cl", "Cs-Sn-Cl", "Rb-Pb-Cl", "Rb-Sn-Cl"],
        fields=["formula_pretty", "elements", "nsites", "band_gap", 
                "formation_energy_per_atom", "energy_above_hull", "is_stable"]
    )
    perovskite_data = oxide_data + halide_data

mp_df = pd.DataFrame([{
    "formula_pretty": entry.formula_pretty,
    "elements": [str(e) for e in entry.elements],
    "nsites": entry.nsites,
    "band_gap": entry.band_gap if entry.band_gap is not None else np.nan,
    "formation_energy_per_atom": entry.formation_energy_per_atom if entry.formation_energy_per_atom is not None else np.nan,
    "energy_above_hull": entry.energy_above_hull if entry.energy_above_hull is not None else np.nan,
    "is_stable": entry.is_stable
} for entry in perovskite_data])
mp_df = mp_df[mp_df['nsites'] == 5]
print("Fetched Perovskites:\n", mp_df.head())
print(f"Total entries: {len(mp_df)}")

# Load and expand Shannon Dataset
shannon_df = pd.read_csv("ShannonDataset.csv")
pseudohalides = pd.DataFrame({
    'element': ['SCN', 'CN', 'N3'],
    'oxidation_state': [-1, -1, -1],
    'r_ionic': [1.95, 1.9, 2.1]
})
extra_elements = pd.DataFrame({
    'element': ['Cs', 'Rb', 'Pb', 'Sn', 'I', 'Br', 'Cl', 'O', 'Al', 'B', 'Cr', 'Cu', 'Fe', 'Ac', 
                'Ga', 'Tb', 'Ge', 'Mg', 'Ti', 'Tl', 'Zn'],
    'oxidation_state': [+1, +1, +2, +2, -1, -1, -1, -2, +3, +3, +3, +2, +3, +3, 
                        +3, +3, +4, +2, +4, +3, +2],
    'r_ionic': [1.67, 1.52, 1.19, 0.69, 1.33, 1.17, 0.99, 1.40, 0.535, 0.27, 0.615, 0.73, 0.645, 1.12, 
                0.47, 0.86, 0.53, 0.72, 0.605, 0.885, 0.74]
})
shannon_df = pd.concat([shannon_df, pseudohalides, extra_elements], ignore_index=True)
shannon_df = shannon_df.drop_duplicates(subset=['element', 'oxidation_state'], keep='last')
print("Shannon Tail:\n", shannon_df.tail())

# Assign oxidation states
def assign_oxidation_states(formula, elements):
    if 'O' in formula:
        return {'A': +3, 'B': +3, 'X': -2}  # Oxides: ABOâ‚ƒ
    elif any(x in formula for x in ['I', 'Br', 'Cl']):
        # Halides: ABXâ‚ƒ (e.g., CsPbIâ‚ƒ)
        a, b, x = elements
        if a in ['Cs', 'Rb']:  # Common A-site cations
            return {'A': +1, 'B': +2, 'X': -1}
        elif b in ['Pb', 'Sn']:  # Common B-site cations
            return {'A': +1, 'B': +2, 'X': -1}
    return {'A': +1, 'B': +2, 'X': -1}  # Default for pseudohalides

# Get radii with fallback
def get_radii(element, ox_state, shannon_df):
    match = shannon_df[(shannon_df['element'] == element) & 
                       (shannon_df['oxidation_state'] == ox_state)]
    radius = match['r_ionic'].values[0] if not match.empty else None
    if radius is None:
        try:
            el = Element(element)
            radius = el.ionic_radius if el.ionic_radius else np.nan
        except:
            radius = np.nan
    # print(f"Element: {element}, Ox State: {ox_state}, Radius: {radius}")
    return radius

# Get electronegativity with default
def get_electronegativity(element):
    try:
        return Element(element).X if Element(element).X else 2.5
    except:
        return 2.5

# Expand MP data
mp_df_expanded = mp_df.copy()
mp_df_expanded['A'] = mp_df_expanded['elements'].apply(lambda x: x[0])
mp_df_expanded['B'] = mp_df_expanded['elements'].apply(lambda x: x[1])
mp_df_expanded['X'] = mp_df_expanded['elements'].apply(lambda x: x[2])

for idx, row in mp_df_expanded.iterrows():
    ox_states = assign_oxidation_states(row['formula_pretty'], row['elements'])
    mp_df_expanded.loc[idx, 'r_A'] = get_radii(row['A'], ox_states['A'], shannon_df)
    mp_df_expanded.loc[idx, 'r_B'] = get_radii(row['B'], ox_states['B'], shannon_df)
    mp_df_expanded.loc[idx, 'r_X'] = get_radii(row['X'], ox_states['X'], shannon_df)

mp_df_expanded['X_A'] = mp_df_expanded['A'].apply(get_electronegativity)
mp_df_expanded['X_B'] = mp_df_expanded['B'].apply(get_electronegativity)
mp_df_expanded['X_X'] = mp_df_expanded['X'].apply(get_electronegativity)

print("Merged Dataset:\n", mp_df_expanded[['formula_pretty', 'r_A', 'r_B', 'r_X', 'X_A', 'X_B', 'X_X']].head(10))

# Feature engineering
def calc_tolerance_factor(r_A, r_B, r_X):
    return (r_A + r_X) / (np.sqrt(2) * (r_B + r_X)) if all([r_A, r_B, r_X]) else np.nan

def calc_octahedral_factor(r_B, r_X):
    return r_B / r_X if r_B and r_X else np.nan

mp_df_expanded['tolerance_factor'] = mp_df_expanded.apply(
    lambda row: calc_tolerance_factor(row['r_A'], row['r_B'], row['r_X']), axis=1)
mp_df_expanded['octahedral_factor'] = mp_df_expanded.apply(
    lambda row: calc_octahedral_factor(row['r_B'], row['r_X']), axis=1)

# Include formation_energy_per_atom as a feature
mp_df_expanded.dropna(subset=['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor', 'band_gap'], inplace=True)
print("Engineered Features:\n", mp_df_expanded[['formula_pretty', 'tolerance_factor', 'octahedral_factor', 'formation_energy_per_atom']].head(10))
print(f"Entries after dropna: {len(mp_df_expanded)}")

# Screen pseudohalide candidates
candidates = {'A': ['Cs', 'Rb'], 'B': ['Pb', 'Sn'], 'X': ['SCN']}
screened_candidates = []
for a in candidates['A']:
    for b in candidates['B']:
        for x in candidates['X']:
            r_A = get_radii(a, +1, shannon_df)
            r_B = get_radii(b, +2, shannon_df)
            r_X = get_radii(x, -1, shannon_df)
            if pd.isna([r_A, r_B, r_X]).any():
                continue
            t = calc_tolerance_factor(r_A, r_B, r_X)
            mu = calc_octahedral_factor(r_B, r_X)
            print(f"{a}{b}{x}3: t = {t:.3f}, mu = {mu:.3f}")
            if 0.8 <= t <= 1.0 and 0.4 <= mu <= 0.7:
                screened_candidates.append({
                    'formula': f'{a}{b}{x}3',
                    'tolerance_factor': t,
                    'octahedral_factor': mu,
                    'r_A': r_A,
                    'r_B': r_B,
                    'r_X': r_X,
                    'X_A': get_electronegativity(a),
                    'X_B': get_electronegativity(b),
                    'X_X': get_electronegativity(x) if x in ['I', 'Br', 'Cl'] else 2.5,
                    'formation_energy_per_atom': np.nan  # Placeholder, not predicted here
                })
screened_df = pd.DataFrame(screened_candidates)
print("Screened Candidates (Strict):\n", screened_df)

screened_candidates_relaxed = []
for a in candidates['A']:
    for b in candidates['B']:
        for x in candidates['X']:
            r_A = get_radii(a, +1, shannon_df)
            r_B = get_radii(b, +2, shannon_df)
            r_X = get_radii(x, -1, shannon_df)
            if pd.isna([r_A, r_B, r_X]).any():
                continue
            t = calc_tolerance_factor(r_A, r_B, r_X)
            mu = calc_octahedral_factor(r_B, r_X)
            if 0.75 <= t <= 1.05 and 0.35 <= mu <= 0.75:
                screened_candidates_relaxed.append({
                    'formula': f'{a}{b}{x}3',
                    'tolerance_factor': t,
                    'octahedral_factor': mu,
                    'r_A': r_A,
                    'r_B': r_B,
                    'r_X': r_X,
                    'X_A': get_electronegativity(a),
                    'X_B': get_electronegativity(b),
                    'X_X': get_electronegativity(x) if x in ['I', 'Br', 'Cl'] else 2.5,
                    'formation_energy_per_atom': np.nan
                })
screened_df_relaxed = pd.DataFrame(screened_candidates_relaxed)
print("Screened Candidates (Relaxed):\n", screened_df_relaxed)

# Machine Learning Prediction
features = ['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor', 'X_A', 'X_B', 'X_X', 'formation_energy_per_atom']
X = mp_df_expanded[features]
y = mp_df_expanded['band_gap']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = XGBRegressor(n_estimators=300, learning_rate=0.03, max_depth=4, random_state=42)
model.fit(X_train, y_train)
print("Model RÂ² Score:", model.score(X_test, y_test))

for df, label in [(screened_df, "Strict"), (screened_df_relaxed, "Relaxed")]:
    if not df.empty:
        X_new = df[features]
        df['predicted_band_gap'] = model.predict(X_new)
        df['predicted_band_gap'] = df['predicted_band_gap'].clip(lower=0)
        print(f"Predictions ({label}):\n", df[['formula', 'predicted_band_gap']])

Retrieving SummaryDoc documents:   0%|          | 0/4696 [00:00<?, ?it/s]

Retrieving SummaryDoc documents:   0%|          | 0/32 [00:00<?, ?it/s]

Fetched Perovskites:
   formula_pretty     elements  nsites  band_gap  formation_energy_per_atom  \
0         AcAlO3  [Ac, Al, O]       5    4.1024                  -3.690019   
1          AcBO3   [Ac, B, O]       5    0.8071                  -2.475390   
2         AcCrO3  [Ac, Cr, O]       5    2.0031                  -3.138972   
3         AcCuO3  [Ac, Cu, O]       5    0.0000                  -2.422892   
4         AcFeO3  [Ac, Fe, O]       5    0.9888                  -2.771539   

   energy_above_hull  is_stable  
0           0.000000       True  
1           0.792473      False  
2           0.000000       True  
3           0.000000       True  
4           0.000000       True  
Total entries: 1837
Shannon Tail:
      Unnamed: 0 element  oxidation_state element_oxidation_state  \
516         NaN      Ge                4                     NaN   
517         NaN      Mg                2                     NaN   
518         NaN      Ti                4                     NaN  

In [141]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from mp_api.client import MPRester
from pymatgen.core import Element, Composition
from sklearn.model_selection import train_test_split, KFold, cross_validate
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.impute import SimpleImputer
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error, make_scorer
from xgboost import XGBRegressor
import matplotlib.pyplot as plt
from pathlib import Path
import joblib
import logging
from dataclasses import dataclass, field
import warnings

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
np.random.seed(42)
torch.manual_seed(42)

@dataclass
class PipelineConfig:
    api_key: str = ""
    shannon_file: str = "ShannonDataset.csv"
    test_size: float = 0.2
    random_state: int = 42
    nn_epochs: int = 200
    nn_patience: int = 20
    candidate_ranges: dict = field(default_factory=lambda: {
        'tolerance_factor': (0.75, 1.05),
        'octahedral_factor': (0.35, 0.75)
    })

class PerovskiteNN(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
    def forward(self, x):
        return self.network(x)

class PerovskitePipeline:
    def __init__(self, config):
        self.config = config
        self.shannon_df = self._prepare_shannon_data()
        self._initialize_components()
        self._validate_config()
        
    def _initialize_components(self):
        self.candidates = {'A': ['Cs', 'Rb', 'Sr', 'Ba'], 
                          'B': ['Pb', 'Sn', 'Ge'], 
                          'X': ['SCN', 'CN', 'N3']}
        self.models = {}
        self.scaler = StandardScaler()
        self.imputer = SimpleImputer(strategy='mean')
        self.output_dir = Path("perovskite_models")
        self.output_dir.mkdir(exist_ok=True)
        self.valid_elements = {e.symbol for e in Element}
        self.reference_df = None
        self.X_train = self.X_test = None
        self.y_train = self.y_test = None
        self.test_indices = None

    def _prepare_shannon_data(self):
        shannon_df = pd.read_csv(self.config.shannon_file)
        pseudohalides = pd.DataFrame({
            'element': ['SCN', 'CN', 'N3'],
            'oxidation_state': [-1, -1, -1],
            'r_ionic': [2.13, 1.98, 2.25],
            'electronegativity': [2.82, 3.05, 3.12],
            'polarizability': [4.1, 3.8, 4.3]
        })
        return pd.concat([shannon_df, pseudohalides]).drop_duplicates(
            subset=['element', 'oxidation_state'], keep='last')

    def _validate_config(self):
        if not 0 < self.config.test_size < 1:
            raise ValueError("Test size must be between 0 and 1")
        if self.config.nn_patience <= 0:
            raise ValueError("Patience must be positive integer")

    def fetch_data(self):
        logging.info("Fetching data from Materials Project...")
        try:
            with MPRester(self.config.api_key) as mpr:
                queries = [
                    ("ABO3", None, ["Cs-Pb-I", "Cs-Sn-I", "Rb-Pb-I"]),  # Example systems
                    ("ABX3", ["Cs-Pb-I", "Cs-Sn-I", "Rb-Pb-I"], None)
                ]
                data = []
                for formula, chemsys, fields in queries:
                    results = mpr.materials.summary.search(
                        formula=formula,
                        chemsys=chemsys,
                        fields=["formula_pretty", "elements", "nsites", 
                               "formation_energy_per_atom", "energy_above_hull",
                               "band_gap", "is_stable"]
                    )
                    data.extend(results)
                return self._process_raw_data(data)
        except Exception as e:
            logging.error(f"Data fetch failed: {str(e)}")
            raise

    def _process_raw_data(self, raw_data):
        processed = []
        for entry in raw_data:
            try:
                processed.append({
                    "formula": entry.formula_pretty,
                    "elements": [str(e) for e in entry.elements],
                    "nsites": entry.nsites,
                    "formation_energy_per_atom": entry.formation_energy_per_atom,
                    "energy_above_hull": entry.energy_above_hull,
                    "band_gap": entry.band_gap,
                    "is_stable": entry.is_stable
                })
            except AttributeError as ae:
                logging.warning(f"Skipping invalid entry: {str(ae)}")
        df = pd.DataFrame(processed)
        return df[(df['nsites'] == 5) & 
                (df['formation_energy_per_atom'].notna())].drop_duplicates('formula')

    def preprocess_data(self, df):
        logging.info("Preprocessing data...")
        # Ensure exactly 3 elements per entry
        df = df[df['elements'].apply(lambda x: len(x) == 3)].copy()
        df['elements'] = df['elements'].apply(
            lambda x: [e for e in x if Element.is_valid_symbol(e)])
        df = df.dropna(subset=['elements'])
        logging.info(f"Remaining entries after preprocessing: {len(df)}")
        return df

    def engineer_features(self, df):
        logging.info("Engineering features...")
        
        # Properly extract elements from list-type column
        df = df.assign(
            A=df['elements'].str.get(0),  # Use .str.get() for list elements
            B=df['elements'].str.get(1),
            X=df['elements'].str.get(2)
        )
        
        # Filter out any rows with missing elements
        df = df.dropna(subset=['A', 'B', 'X'])

        invalid = df[['A', 'B', 'X']].isna().any(axis=1)
        if invalid.any():
            logging.warning(f"Dropping {invalid.sum()} rows with invalid elements")
            df = df[~invalid].copy()
            
        # Oxidation states with composition-based guessing
        df['ox_states'] = df.apply(lambda row: self._get_oxidation_states(row), axis=1)
        
        # Ionic radii
        df[['r_A', 'r_B', 'r_X']] = df.apply(
            lambda row: self._get_ionic_radii(row), axis=1, result_type='expand')
        
        # Electronic properties
        electronegativity = lambda el: Element(el).X if el in self.valid_elements else np.nan
        for el, col in [('A', 'X_A'), ('B', 'X_B'), ('X', 'X_X')]:
            df[col] = df[el].apply(electronegativity)
        
        # Structure-property relationships
        df['tolerance_factor'] = (df['r_A'] + df['r_X']) / (np.sqrt(2) * (df['r_B'] + df['r_X']))
        df['octahedral_factor'] = df['r_B'] / df['r_X']
        df['delta_X'] = abs(df['X_B'] - df['X_X'])
        
        # Stability indicators
        df['goldschmidt_ok'] = df['tolerance_factor'].between(0.8, 1.1).astype(int)
        df['octahedral_ok'] = df['octahedral_factor'].between(0.4, 0.7).astype(int)
        
        return df.dropna(subset=['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor'])

    def _get_oxidation_states(self, row):
        try:
            comp = Composition(row['formula'])
            oxi_states = comp.oxi_state_guesses()[0]
            # Map to A/B/X positions
            return {
                'A': oxi_states.get(row['A']),  # Added missing parenthesis
                'B': oxi_states.get(row['B']),
                'X': oxi_states.get(row['X'])
            }
        except:
            return {'A': +1, 'B': +2, 'X': -1}

    def _get_ionic_radii(self, row):
        try:
            return (
                self.shannon_df[
                    (self.shannon_df.element == row['A']) & 
                    (self.shannon_df.oxidation_state == row['ox_states']['A'])
                ].r_ionic.values[0],
                self.shannon_df[
                    (self.shannon_df.element == row['B']) & 
                    (self.shannon_df.oxidation_state == row['ox_states']['B'])
                ].r_ionic.values[0],
                self.shannon_df[
                    (self.shannon_df.element == row['X']) & 
                    (self.shannon_df.oxidation_state == row['ox_states']['X'])
                ].r_ionic.values[0]
            )
        except (IndexError, KeyError):  # Add KeyError catch
            return (np.nan, np.nan, np.nan)

    def train_models(self, df, target='formation_energy_per_atom'):
        logging.info(f"Training models for {target}...")
        X = pd.get_dummies(df[['A', 'B', 'X']]).join(
            df[['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor', 
               'X_A', 'X_B', 'X_X', 'energy_above_hull', 'delta_X']])
        y = df[target]
        
        # Train-test split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.config.test_size, 
            random_state=self.config.random_state)
        self.test_indices = X_test.index
        
        # Imputation and scaling
        X_train = self.imputer.fit_transform(X_train)
        X_test = self.imputer.transform(X_test)
        self.X_train = self.scaler.fit_transform(X_train)
        self.X_test = self.scaler.transform(X_test)
        self.y_train, self.y_test = y_train, y_test
        
        nn_model, train_losses, test_losses = self._train_neural_network()
    
        # Store all models with consistent interface
        self.models = {
            'xgb': XGBRegressor().fit(self.X_train, self.y_train),
            'rf': RandomForestRegressor().fit(self.X_train, self.y_train),
            'gbr': GradientBoostingRegressor().fit(self.X_train, self.y_train),
            'nn': {
                'model': nn_model,  # The actual PyTorch model
                'train_losses': train_losses,
                'test_losses': test_losses
            }
        }
        # Model evaluation
        results = self._evaluate_models()
        self._save_best_model(results)
        return results

    def _plot_learning_curve(self):
        """Neural network training progress"""
        if 'nn' not in self.models:
            return
        
        plt.figure(figsize=(10, 6))
        plt.plot(self.models['nn']['train_losses'], label='Train Loss')
        plt.plot(self.models['nn']['test_losses'], label='Validation Loss')
        
        plt.title("Neural Network Learning Curve")
        plt.xlabel("Epoch")
        plt.ylabel("MSE Loss")
        plt.legend()
        plt.grid(True)
        
        plt.savefig(self.output_dir/'nn_learning_curve.png', dpi=300)
        plt.close()

    

    def _train_single_fold(self, X_train, y_train):
        """Train neural network on a single CV fold"""
        model = PerovskiteNN(X_train.shape[1])
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        
        X_tensor = torch.FloatTensor(X_train)
        y_tensor = torch.FloatTensor(y_train.values).view(-1, 1)
        
        best_loss = float('inf')
        patience_counter = 0
        
        model.train()
        for epoch in range(self.config.nn_epochs):
            optimizer.zero_grad()
            outputs = model(X_tensor)
            loss = criterion(outputs, y_tensor)
            loss.backward()
            optimizer.step()
            
            # Early stopping check
            with torch.no_grad():
                model.eval()
                val_loss = criterion(model(X_tensor), y_tensor)
                model.train()
                
            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                
            if patience_counter >= self.config.nn_patience:
                break
                
        return model

    def _train_neural_network(self):
        model = PerovskiteNN(self.X_train.shape[1])
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        
        train_tensor = torch.FloatTensor(self.X_train)
        train_target = torch.FloatTensor(self.y_train.values).view(-1, 1)
        test_tensor = torch.FloatTensor(self.X_test)
        test_target = torch.FloatTensor(self.y_test.values).view(-1, 1)
        
        best_loss = float('inf')
        patience_counter = 0
        train_losses, test_losses = [], []
        
        model.train()
        for epoch in range(self.config.nn_epochs):
            optimizer.zero_grad()
            outputs = model(train_tensor)
            loss = criterion(outputs, train_target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            
            with torch.no_grad():
                model.eval()
                test_loss = criterion(model(test_tensor), test_target).item()
                test_losses.append(test_loss)
                
            if test_loss < best_loss:
                best_loss = test_loss
                patience_counter = 0
                torch.save(model.state_dict(), self.output_dir/'nn_best.pth')
            else:
                patience_counter += 1
                
            if patience_counter >= self.config.nn_patience:
                logging.info(f"Early stopping at epoch {epoch}")
                break
                
        model.load_state_dict(torch.load(self.output_dir/'nn_best.pth'))
        return model, train_losses, test_losses


    def _evaluate_models(self):
        """Evaluate all models on test set"""
        metrics = {}
        for name, model in self.models.items():
            if name == 'nn':
                # Access the actual model from the dictionary
                with torch.no_grad():
                    preds = model['model'](torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                preds = model.predict(self.X_test)
                
            metrics[name] = {
                'r2': r2_score(self.y_test, preds),
                'mae': mean_absolute_error(self.y_test, preds),
                'rmse': np.sqrt(mean_squared_error(self.y_test, preds))
            }
        return metrics

    def _save_best_model(self, results):
        best_model = max(results.items(), key=lambda x: x[1]['r2'])
        logging.info(f"Best model: {best_model[0]} (RÂ²={best_model[1]['r2']:.3f})")
        joblib.dump(self.models[best_model[0]], self.output_dir/f"{best_model[0]}_model.pkl")
        joblib.dump(self.scaler, self.output_dir/'scaler.pkl')
        joblib.dump(self.imputer, self.output_dir/'imputer.pkl')

    def validate_models(self, cv=5):
        logging.info("Running cross-validation...")
        X = np.vstack([self.X_train, self.X_test])
        y = pd.concat([self.y_train, self.y_test])
        
        scorers = {
            'r2': make_scorer(r2_score),
            'mae': make_scorer(mean_absolute_error),
            'rmse': make_scorer(lambda y, p: np.sqrt(mean_squared_error(y, p)))
        }
        
        results = {}
        for name, model in self.models.items():
            if name == 'nn':
                results[name] = self._cross_validate_nn(X, y, cv)
            else:
                cv_results = cross_validate(model, X, y, cv=cv, scoring=scorers)
                results[name] = {
                    'r2': cv_results['test_r2'].mean(),
                    'mae': cv_results['test_mae'].mean(),
                    'rmse': cv_results['test_rmse'].mean()
                }
        return results

    def _cross_validate_nn(self, X, y, cv):
        kf = KFold(cv)
        scores = {'r2': [], 'mae': [], 'rmse': []}
        
        for train_idx, test_idx in kf.split(X):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
            
            model = self._train_single_fold(X_train, y_train)
            with torch.no_grad():
                preds = model(torch.FloatTensor(X_test)).numpy().flatten()
            
            scores['r2'].append(r2_score(y_test, preds))
            scores['mae'].append(mean_absolute_error(y_test, preds))
            scores['rmse'].append(np.sqrt(mean_squared_error(y_test, preds)))
        
        return {k: np.mean(v) for k, v in scores.items()}

    def chemical_sanity_check(self, df):
        logging.info("Performing chemical sanity checks...")
        valid = (
            df['tolerance_factor'].between(0.8, 1.1) &
            df['octahedral_factor'].between(0.4, 0.7)
        )
        validity_rate = valid.mean()
        if validity_rate < 0.9:
            warnings.warn(f"Low validity rate: {validity_rate:.1%}")
        return validity_rate

    def run_pipeline(self):
        df = self.fetch_data()
        df = self.preprocess_data(df)
        df = self.engineer_features(df)
        self.chemical_sanity_check(df)
        
        results = self.train_models(df)
        validation = self.validate_models()
        
        # This line invokes the visualizations
        self.visualize_results(df)  # Make sure this line is present
        
        logging.info("\nFinal Results:")
        for model, scores in results.items():
            print(f"{model.upper():<5} RÂ²: {scores['r2']:.3f}  MAE: {scores['mae']:.3f}  RMSE: {scores['rmse']:.3f}")
        
        return results, validation

    ## _________VISUALIZATIONS______________

    def visualize_results(self, df):
        """Generate comprehensive validation visualizations"""
        logging.info("Generating model validation visualizations...")
        
        try:
            # 1. Parity Plots with Chemical Validity
            self._plot_parity_with_chemicals(df)
            
            # 2. Learning Curves for Neural Network
            if 'nn' in self.models:
                self._plot_learning_curve()
            
            # 3. Residual Analysis
            self._plot_residuals(df)
            
            # 4. Feature Importance
            self._plot_feature_importance(df)
            
            # 5. Prediction Distributions
            self._plot_prediction_distributions(df)
            
            # 6. Top Candidate Analysis
            self._plot_top_candidates(df)
            
        except Exception as e:
            logging.error(f"Visualization failed: {str(e)}")
            raise
    
    def _plot_parity_with_chemicals(self, df):
        # Update model prediction access
        test_df = df.loc[self.test_indices].copy()
        
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    test_df[f'pred_{name}'] = self.models[name]['model'](
                        torch.FloatTensor(self.X_test)
                    ).numpy().flatten()
            else:
                test_df[f'pred_{name}'] = self.models[name].predict(self.X_test)
        
        # Create subplots
        fig, axes = plt.subplots(2, 2, figsize=(20, 18))
        
        # Plot each model's parity plot
        for idx, (name, ax) in enumerate(zip(self.models.keys(), axes.flatten())):
            if idx >= 4:  # Only plot first 4 models
                break
                
            # Calculate metrics
            y_true = test_df['formation_energy_per_atom']
            y_pred = test_df[f'pred_{name}']
            r2 = r2_score(y_true, y_pred)
            mae = mean_absolute_error(y_true, y_pred)
            
            # Create scatter plot with chemical validity
            sc = ax.scatter(
                y_true, y_pred, 
                c=test_df['tolerance_factor'],
                cmap='viridis', 
                alpha=0.7,
                vmin=0.7, vmax=1.1
            )
            
            # Formatting
            ax.plot([y_true.min(), y_true.max()], 
                    [y_true.min(), y_true.max()], 'r--')
            ax.set_title(f"{name.upper()} Parity Plot\n(RÂ²={r2:.3f}, MAE={mae:.3f})")
            ax.set_xlabel("True Formation Energy (eV/atom)")
            ax.set_ylabel("Predicted Formation Energy (eV/atom)")
            plt.colorbar(sc, ax=ax, label='Tolerance Factor')
            
        plt.tight_layout()
        plt.savefig(self.output_dir/'parity_plots.png', dpi=300)
        plt.close()
    
    def _plot_learning_curve(self):
        """Neural network training progress"""
        plt.figure(figsize=(10, 6))
        plt.plot(self.models['nn']['train_losses'], label='Train Loss')
        plt.plot(self.models['nn']['test_losses'], label='Validation Loss')
        
        plt.title("Neural Network Learning Curve")
        plt.xlabel("Epoch")
        plt.ylabel("MSE Loss")
        plt.legend()
        plt.grid(True)
        
        plt.savefig(self.output_dir/'nn_learning_curve.png', dpi=300)
        plt.close()
    
    def _plot_residuals(self, df):
        """Residual analysis across models"""
        test_df = df.loc[self.test_indices].copy()
        
        plt.figure(figsize=(15, 10))
        for idx, name in enumerate(self.models.keys(), 1):
            # Get predictions based on model type
            if name == 'nn':
                with torch.no_grad():
                    preds = self.models[name]['model'](  # Access the actual model
                        torch.FloatTensor(self.X_test)
                    ).numpy().flatten()
            else:
                preds = self.models[name].predict(self.X_test)
                
            residuals = test_df['formation_energy_per_atom'] - preds
            
            # Plotting code remains the same
            plt.subplot(2, 2, idx)
            plt.scatter(preds, residuals, alpha=0.5)
            plt.axhline(0, color='red', linestyle='--')
            plt.title(f"{name.upper()} Residuals")
            plt.xlabel("Predicted Values")
            plt.ylabel("Residuals")
            plt.grid(True)
            
        plt.tight_layout()
        plt.savefig(self.output_dir/'residual_analysis.png', dpi=300)
        plt.close()
    
    def _plot_feature_importance(self, df):
        """Feature importance for tree-based models"""
        plt.figure(figsize=(15, 10))
        
        # Get feature names
        features = pd.get_dummies(df[['A', 'B', 'X']]).columns.tolist() + [
            'r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor',
            'X_A', 'X_B', 'X_X', 'energy_above_hull', 'delta_X'
        ]
        
        for idx, name in enumerate(self.models.keys(), 1):
            if name == 'nn':
                continue  # Skip NN
                
            model = self.models[name]
            if hasattr(model, 'feature_importances_'):
                importances = model.feature_importances_
            else:
                continue
                
            indices = np.argsort(importances)[::-1]
            
            plt.subplot(2, 2, idx)
            plt.title(f"{name.upper()} Feature Importance")
            plt.barh(range(10), importances[indices][:10], align='center')
            plt.yticks(range(10), [features[i] for i in indices[:10]])
            plt.xlabel('Relative Importance')
            plt.gca().invert_yaxis()
            
        plt.tight_layout()
        plt.savefig(self.output_dir/'feature_importance.png', dpi=300)
        plt.close()
    
    def _plot_prediction_distributions(self, df):
        """Distribution of predicted vs actual values"""
        test_df = df.loc[self.test_indices].copy()
        
        plt.figure(figsize=(12, 8))
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    preds = self.models[name]['model'](
                        torch.FloatTensor(self.X_test)
                    ).numpy().flatten()
            else:
                preds = self.models[name].predict(self.X_test)
                
            plt.hist(preds, alpha=0.5, bins=30, label=name.upper())
        
        plt.hist(test_df['formation_energy_per_atom'], 
                bins=30, alpha=0.3, label='Actual', color='black')
        
        plt.title("Prediction Distributions vs Actual Values")
        plt.xlabel("Formation Energy (eV/atom)")
        plt.ylabel("Frequency")
        plt.legend()
        plt.savefig(self.output_dir/'prediction_distributions.png', dpi=300)
        plt.close()
    
    def _plot_top_candidates(self, df):
        """Visualize top predicted candidates with annotations"""
        test_df = df.loc[self.test_indices].copy()
        
        # Get predictions from all models
        for name in self.models.keys():
            if name == 'nn':
                # Access the actual model from the dictionary
                with torch.no_grad():
                    test_df[f'pred_{name}'] = self.models[name]['model'](  # Changed here
                        torch.FloatTensor(self.X_test)
                    ).numpy().flatten()
            else:
                test_df[f'pred_{name}'] = self.models[name].predict(self.X_test)
        
        # Rest of the method remains the same
        plt.figure(figsize=(15, 10))
        
        # Identify top 10 most stable candidates across all models
        top_candidates = pd.concat([
            test_df.nsmallest(10, f'pred_{name}') for name in self.models.keys()
        ]).drop_duplicates('formula')
        
        # Create parallel coordinates plot
        plt.figure(figsize=(15, 10))
        pd.plotting.parallel_coordinates(
            top_candidates[['formula', 'tolerance_factor', 'octahedral_factor'] + 
                          [f'pred_{name}' for name in self.models.keys()]],
            'formula',
            colormap='viridis'
        )
        
        plt.title("Top Candidates - Multi-Model Comparison")
        plt.xlabel("Features and Model Predictions")
        plt.ylabel("Normalized Values")
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(self.output_dir/'top_candidates.png', dpi=300)
        plt.close()

if __name__ == "__main__":
    config = PipelineConfig(
        api_key="0IdI9tURdQnNXU2ROBizUG4UiTye6B73",
        nn_epochs=150,
        nn_patience=20
    )
    pipeline = PerovskitePipeline(config)
    results, validation = pipeline.run_pipeline()

2025-03-13 18:26:04,988 - INFO - Fetching data from Materials Project...


Retrieving SummaryDoc documents:   0%|          | 0/4696 [00:00<?, ?it/s]

Retrieving SummaryDoc documents:   0%|          | 0/8 [00:00<?, ?it/s]

2025-03-13 18:26:08,245 - INFO - Preprocessing data...
2025-03-13 18:26:08,250 - INFO - Remaining entries after preprocessing: 1640
2025-03-13 18:26:08,250 - INFO - Engineering features...
2025-03-13 18:26:08,957 - INFO - Performing chemical sanity checks...
2025-03-13 18:26:08,958 - INFO - Training models for formation_energy_per_atom...
2025-03-13 18:26:09,847 - INFO - Best model: xgb (RÂ²=0.877)
2025-03-13 18:26:09,853 - INFO - Running cross-validation...
2025-03-13 18:26:14,099 - INFO - Generating model validation visualizations...
2025-03-13 18:26:18,459 - INFO - 
Final Results:


XGB   RÂ²: 0.877  MAE: 0.206  RMSE: 0.280
RF    RÂ²: 0.814  MAE: 0.267  RMSE: 0.344
GBR   RÂ²: 0.864  MAE: 0.230  RMSE: 0.294
NN    RÂ²: 0.816  MAE: 0.247  RMSE: 0.342


<Figure size 1500x1000 with 0 Axes>

In [152]:
from mp_api.client import MPRester

with MPRester("0IdI9tURdQnNXU2ROBizUG4UiTye6B73") as mpr:
    try:
        docs = mpr.materials.summary.search(
            elements=["Cs", "Pb", "I"],  # Materials containing Cs, Pb, I
            fields=["material_id", "formula_pretty", "band_gap", "energy_above_hull", "volume", "symmetry"],
            chunk_size=10,
            num_chunks=1
        )
        for doc in docs:
            print(f"Material ID: {doc.material_id}, "
                  f"Formula: {doc.formula_pretty}, "
                  f"Band Gap: {doc.band_gap} eV, "
                  f"Energy Above Hull: {doc.energy_above_hull} eV/atom, "
                  f"Volume: {doc.volume} Ã…Â³, "
                  f"Spacegroup: {doc.symmetry.symbol} (Number: {doc.symmetry.number})")
    except Exception as e:
        print(f"ðŸš¨ Query failed: {str(e)}")

Retrieving SummaryDoc documents:   0%|          | 0/9 [00:00<?, ?it/s]

Material ID: mp-1238804, Formula: Cs4PbI6, Band Gap: 3.4095 eV, Energy Above Hull: 0.0 eV/atom, Volume: 1163.7499601357163 Ã…Â³, Spacegroup: R-3c (Number: 167)
Material ID: mp-1069538, Formula: CsPbI3, Band Gap: 1.4785 eV, Energy Above Hull: 0.025049344000002003 eV/atom, Volume: 247.0988555291859 Ã…Â³, Spacegroup: Pm-3m (Number: 221)
Material ID: mp-1120768, Formula: CsPbI3, Band Gap: 1.6421000000000001 eV, Energy Above Hull: 0.015823317000002002 eV/atom, Volume: 1008.6977999446959 Ã…Â³, Spacegroup: Pnma (Number: 62)
Material ID: mp-540839, Formula: CsPbI3, Band Gap: 2.5181 eV, Energy Above Hull: 0.0 eV/atom, Volume: 930.846658290349 Ã…Â³, Spacegroup: Pnma (Number: 62)
Material ID: mp-2646981, Formula: CsPbI2Br, Band Gap: 1.4384000000000001 eV, Energy Above Hull: 0.035953406166667 eV/atom, Volume: 242.33084912816582 Ã…Â³, Spacegroup: P4/mmm (Number: 123)
Material ID: mp-2647097, Formula: Cs2Pb(ICl)2, Band Gap: 2.9646 eV, Energy Above Hull: 0.0 eV/atom, Volume: 648.2370343292783 Ã…Â³, S

In [154]:
import pandas as pd
import numpy as np
import logging
from dataclasses import dataclass
from pathlib import Path
import os
from pymatgen.core import Element, Lattice, Structure
from pymatgen.io.cif import CifWriter
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from mp_api.client import MPRester

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
np.random.seed(42)

@dataclass
class PipelineConfig:
    shannon_file: str = "ShannonDataset.csv"
    test_size: float = 0.2
    val_size: float = 0.15
    random_state: int = 42
    mp_api_key: str = "0IdI9tURdQnNXU2ROBizUG4UiTye6B73"

class PerovskitePipeline:
    def __init__(self, config):
        self.config = config
        self._initialize_components()
        self.shannon_df = self._prepare_shannon_data()
        self.reference_df = self._prepare_mp_data()  # Use real MP data
        self._validate_config()

    def _initialize_components(self):
        self.candidates = {'A': ['Cs', 'Rb', 'Sr', 'Ba'], 
                          'B': ['Pb', 'Sn', 'Ge'], 
                          'X': ['I', 'SCN', 'CN', 'N3']}
        self.valid_elements = {e.symbol for e in Element}
        self.structures_dir = "structures/"
        os.makedirs(self.structures_dir, exist_ok=True)

    def _validate_config(self):
        if not 0 < self.config.test_size < 1 or not 0 < self.config.val_size < 1:
            raise ValueError("Test and validation sizes must be between 0 and 1")

    def _prepare_shannon_data(self):
        """Load and prepare the Shannon ionic radii dataset."""
        try:
            shannon_df = pd.read_csv(self.config.shannon_file)
            logging.info(f"Loaded Shannon dataset from {self.config.shannon_file} with shape: {shannon_df.shape}")
            logging.info(f"Shannon dataset columns: {shannon_df.columns.tolist()}")

            # Define required and expected columns
            required_columns = ['element', 'oxidation_state', 'r_ionic']
            final_columns = ['element', 'oxidation_state', 'r_ionic', 'electronegativity', 'polarizability']
            available_columns = shannon_df.columns.tolist()

            # Check for required columns
            missing_required = [col for col in required_columns if col not in available_columns]
            if missing_required:
                raise ValueError(f"Shannon dataset is missing required columns: {missing_required}")

            # Drop extra columns
            shannon_df = shannon_df[required_columns]

            # Add electronegativity and polarizability with default values if missing
            if 'electronegativity' not in shannon_df.columns:
                logging.warning("Electronegativity column missing. Adding default value of 0.")
                shannon_df['electronegativity'] = 0
            if 'polarizability' not in shannon_df.columns:
                logging.warning("Polarizability column missing. Adding default value of 0.")
                shannon_df['polarizability'] = 0

            # Check for required elements and provide defaults if missing
            required_elements = {
                'Cs': (1, 1.67, 0.79, 0),
                'Rb': (1, 1.52, 0.82, 0),
                'Pb': (2, 0.84, 2.33, 0),
                'Sn': (2, 0.69, 1.96, 0),
                'I': (-1, 2.20, 2.66, 0)
            }
            missing_elements = [el for el in required_elements if el not in shannon_df['element'].values]
            if missing_elements:
                logging.warning(f"Missing ionic radii for elements: {missing_elements}. Using defaults.")
                defaults = pd.DataFrame({
                    'element': list(required_elements.keys()),
                    'oxidation_state': [data[0] for data in required_elements.values()],
                    'r_ionic': [data[1] for data in required_elements.values()],
                    'electronegativity': [data[2] for data in required_elements.values()],
                    'polarizability': [data[3] for data in required_elements.values()]
                })
                shannon_df = pd.concat([shannon_df, defaults]).drop_duplicates(subset=['element', 'oxidation_state'], keep='last')

            # Ensure pseudohalides are included
            pseudohalides = pd.DataFrame({
                'element': ['SCN', 'CN', 'N3'],
                'oxidation_state': [-1, -1, -1],
                'r_ionic': [2.13, 1.98, 2.25],
                'electronegativity': [2.82, 3.05, 3.12],
                'polarizability': [4.1, 3.8, 4.3]
            })
            shannon_df = pd.concat([shannon_df, pseudohalides]).drop_duplicates(
                subset=['element', 'oxidation_state'], keep='last')

            # Ensure final DataFrame has exactly the expected columns
            shannon_df = shannon_df[final_columns]
            logging.info(f"Processed Shannon dataset shape: {shannon_df.shape}")
            logging.info(f"Processed Shannon dataset columns: {shannon_df.columns.tolist()}")
            return shannon_df
        except FileNotFoundError:
            logging.error(f"ShannonDataset.csv not found at {self.config.shannon_file}. Using defaults.")
            defaults = pd.DataFrame({
                'element': ['Cs', 'Rb', 'Pb', 'Sn', 'I'],
                'oxidation_state': [1, 1, 2, 2, -1],
                'r_ionic': [1.67, 1.52, 0.84, 0.69, 2.20],
                'electronegativity': [0.79, 0.82, 2.33, 1.96, 2.66],
                'polarizability': [0, 0, 0, 0, 0]
            })
            pseudohalides = pd.DataFrame({
                'element': ['SCN', 'CN', 'N3'],
                'oxidation_state': [-1, -1, -1],
                'r_ionic': [2.13, 1.98, 2.25],
                'electronegativity': [2.82, 3.05, 3.12],
                'polarizability': [4.1, 3.8, 4.3]
            })
            shannon_df = pd.concat([defaults, pseudohalides]).drop_duplicates(
                subset=['element', 'oxidation_state'], keep='last')
            shannon_df = shannon_df[final_columns]
            logging.info(f"Default Shannon dataset shape: {shannon_df.shape}")
            logging.info(f"Default Shannon dataset columns: {shannon_df.columns.tolist()}")
            return shannon_df

    def _prepare_mp_data(self):
        """Retrieve data from Materials Project API."""
        try:
            with MPRester(self.config.mp_api_key) as mpr:
                docs = mpr.materials.summary.search(
                    elements=["Cs", "Pb", "I"],
                    fields=["material_id", "formula_pretty", "elements", "nsites", "band_gap", 
                            "energy_above_hull", "volume", "symmetry", "structure"],
                    chunk_size=50,
                    num_chunks=1
                )
                mp_data_list = []
                for doc in docs:
                    mp_data_list.append({
                        "material_id": doc.material_id,
                        "formula_pretty": doc.formula_pretty,
                        "elements": [str(elem) for elem in doc.elements],  # Convert Element objects to strings
                        "nsites": doc.nsites,
                        "band_gap": doc.band_gap,
                        "energy_above_hull": doc.energy_above_hull,
                        "volume": doc.volume,
                        "symmetry": doc.symmetry,
                        "structure": doc.structure
                    })
                if mp_data_list:
                    mp_data = pd.DataFrame(mp_data_list)
                    logging.info(f"Retrieved MP dataset shape: {mp_data.shape}")
                    logging.info(f"MP dataset features: {mp_data.columns.tolist()}")

                    # Save structures as CIF files
                    for idx, row in mp_data.iterrows():
                        if row['structure'] is not None:
                            CifWriter(row["structure"]).write_file(f"{self.structures_dir}/{row['material_id']}.cif")

                    return mp_data
                else:
                    logging.warning("No data retrieved from MP API.")
                    return pd.DataFrame()  # Return empty DataFrame if no data
        except Exception as e:
            logging.error(f"Failed to retrieve MP data: {str(e)}")
            return pd.DataFrame()  # Return empty DataFrame on failure

    def prepare_data(self):
        """Load and prepare the initial datasets, returning their shapes and features."""
        logging.info("Preparing initial datasets...")
        shannon_shape = self.shannon_df.shape
        mp_shape = self.reference_df.shape
        mp_features = self.reference_df.columns.tolist()
        logging.info(f"Shannon dataset shape: {shannon_shape}")
        logging.info(f"MP dataset shape: {mp_shape}")
        logging.info(f"MP dataset features: {mp_features}")
        return shannon_shape, mp_shape, mp_features

if __name__ == "__main__":
    # Initialize the pipeline with the configuration
    config = PipelineConfig(shannon_file="./ShannonDataset.csv")
    pipeline = PerovskitePipeline(config)
    
    # Prepare the data and get shapes and features
    shannon_shape, mp_shape, mp_features = pipeline.prepare_data()
    print(f"Shannon dataset shape: {shannon_shape}")
    print(f"MP dataset shape: {mp_shape}")
    print(f"MP dataset features: {mp_features}")

2025-03-13 19:38:17,253 - INFO - Loaded Shannon dataset from ./ShannonDataset.csv with shape: (497, 8)
2025-03-13 19:38:17,253 - INFO - Shannon dataset columns: ['Unnamed: 0', 'element', 'oxidation_state', 'element_oxidation_state', 'oxidation_type', 'r_crystal', 'remark', 'r_ionic']
2025-03-13 19:38:17,259 - INFO - Processed Shannon dataset shape: (209, 5)
2025-03-13 19:38:17,260 - INFO - Processed Shannon dataset columns: ['element', 'oxidation_state', 'r_ionic', 'electronegativity', 'polarizability']


Retrieving SummaryDoc documents:   0%|          | 0/9 [00:00<?, ?it/s]

2025-03-13 19:38:17,809 - INFO - Retrieved MP dataset shape: (9, 9)
2025-03-13 19:38:17,810 - INFO - MP dataset features: ['material_id', 'formula_pretty', 'elements', 'nsites', 'band_gap', 'energy_above_hull', 'volume', 'symmetry', 'structure']
2025-03-13 19:38:17,821 - INFO - Preparing initial datasets...
2025-03-13 19:38:17,822 - INFO - Shannon dataset shape: (209, 5)
2025-03-13 19:38:17,822 - INFO - MP dataset shape: (9, 9)
2025-03-13 19:38:17,822 - INFO - MP dataset features: ['material_id', 'formula_pretty', 'elements', 'nsites', 'band_gap', 'energy_above_hull', 'volume', 'symmetry', 'structure']


Shannon dataset shape: (209, 5)
MP dataset shape: (9, 9)
MP dataset features: ['material_id', 'formula_pretty', 'elements', 'nsites', 'band_gap', 'energy_above_hull', 'volume', 'symmetry', 'structure']


In [1]:
# Increasing Data sources:

In [None]:
from mp_api.client import MPRester  # Use mp_api for both MP and OQMD
import joblib
from dataclasses import dataclass, field
from chemdataextractor import Document
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import logging
from pathlib import Path
from pymatgen.core.composition import Composition, Element
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold, cross_validate  # Corrected typo
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error, make_scorer
from xgboost import XGBRegressor
from tqdm import tqdm
import warnings


# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
np.random.seed(42)
torch.manual_seed(42)

@dataclass
class PipelineConfig:
    api_key: str = "0IdI9tURdQnNXU2ROBizUG4UiTye6B73"
    shannon_file: str = "ShannonDataset.csv"
    test_size: float = 0.2
    val_size: float = 0.15
    random_state: int = 42
    nn_epochs: int = 150
    nn_patience: int = 20
    candidate_ranges: dict = field(default_factory=lambda: {
        'tolerance_factor': (0.75, 1.05),
        'octahedral_factor': (0.35, 0.75)
    })

class PerovskiteNN(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        
    def forward(self, x):
        return self.network(x)

class PerovskitePipeline:
    def __init__(self, config):
        self.config = config
        self.shannon_df = self._prepare_shannon_data()
        self._initialize_components()
        self._validate_config()
        
    def _initialize_components(self):
        self.candidates = {'A': ['Cs', 'Rb', 'Sr', 'Ba'], 
                          'B': ['Pb', 'Sn', 'Ge'], 
                          'X': ['SCN', 'CN', 'N3']}
        self.models = {}
        self.scaler = StandardScaler()
        self.imputer = SimpleImputer(strategy='mean')
        self.output_dir = Path("perovskite_models")
        self.output_dir.mkdir(exist_ok=True)
        self.valid_elements = {e.symbol for e in Element}
        self.reference_df = None
        self.X_train = self.X_val = self.X_test = None
        self.y_train = self.y_val = self.y_test = None
        self.test_indices = None

    def _prepare_shannon_data(self):
        shannon_df = pd.read_csv(self.config.shannon_file)
        pseudohalides = pd.DataFrame({
            'element': ['SCN', 'CN', 'N3'],
            'oxidation_state': [-1, -1, -1],
            'r_ionic': [2.13, 1.98, 2.25],
            'electronegativity': [2.82, 3.05, 3.12],
            'polarizability': [4.1, 3.8, 4.3]
        })
        return pd.concat([shannon_df, pseudohalides]).drop_duplicates(
            subset=['element', 'oxidation_state'], keep='last')

    def _validate_config(self):
        if not 0 < self.config.test_size < 1 or not 0 < self.config.val_size < 1:
            raise ValueError("Test and validation sizes must be between 0 and 1")
        if self.config.nn_patience <= 0:
            raise ValueError("Patience must be positive integer")

    def fetch_icsd_data(self):
        try:
            icsd_data = pd.read_csv("icsd_data.csv")  # Replace with actual file path
            icsd_data['source'] = 'ICSD'
            return icsd_data
        except Exception as e:
            logging.error(f"ICSD data fetch failed: {str(e)}")
            return pd.DataFrame({
                'formula': ['CsPbI3', 'CsSnI3'],
                'elements': [['Cs', 'Pb', 'I'], ['Cs', 'Sn', 'I']],
                'nsites': [5, 5],
                'formation_energy_per_atom': [-0.5, -0.6],
                'is_stable': [True, True],
                'volume': [200.0, 210.0],
                'source': ['ICSD', 'ICSD']
            })

    # def fetch_aflow_data(self):
    #     try:
    #         aflow_data = []
    #         target_formulas = ["CsPbI3", "CsSnI3", "RbPbI3"]
    #         for formula in target_formulas:
    #             # Fetch entries containing the elements of the formula
    #             elements = set(Composition(formula).elements)  # e.g., {Cs, Pb, I}
    #             element_symbols = [str(el) for el in elements]
    #             query = search(catalog="icsd").filter(K.species.contains(element_symbols[0]))
    #             for entry in query:
    #                 # Check if the entry matches the target formula
    #                 if getattr(entry, 'compound', '') == formula:
    #                     aflow_data.append({
    #                         "formula": getattr(entry, 'compound', formula),
    #                         "elements": getattr(entry, 'species', element_symbols),
    #                         "nsites": getattr(entry, 'natoms', np.nan),
    #                         "formation_energy_per_atom": getattr(entry, 'enthalpy_formation_atom', np.nan),
    #                         "energy_above_hull": getattr(entry, 'delta_electronic_energy', np.nan),
    #                         "band_gap": getattr(entry, 'egap', np.nan),
    #                         "is_stable": True if getattr(entry, 'delta_electronic_energy', 0) < 0.1 else False,
    #                         "volume": getattr(entry, 'volume_cell', np.nan),
    #                         "source": "AFLOW"
    #                     })
    #         return pd.DataFrame(aflow_data)
    #     except Exception as e:
    #         logging.error(f"AFLOW data fetch failed: {str(e)}")
    #         return pd.DataFrame()

    def fetch_aflow_data(self):
        logging.warning("AFLOW data fetch disabled due to library incompatibility. Returning empty DataFrame.")
        return pd.DataFrame()

    def fetch_literature_data(self):
        try:
            literature_data = []
            with open("perovskite_paper.pdf", "rb") as f:
                doc = Document(f)
                for compound in doc.cems:
                    if "formation_energy" in doc.records:
                        for record in doc.records:
                            if hasattr(record, 'formation_energy'):
                                literature_data.append({
                                    "formula": compound.names[0],
                                    "elements": compound.elements,
                                    "nsites": np.nan,
                                    "formation_energy_per_atom": record.formation_energy.value if record.formation_energy else np.nan,
                                    "is_stable": True,
                                    "volume": np.nan,
                                    "source": "Literature"
                                })
            return pd.DataFrame(literature_data)
        except Exception as e:
            logging.error(f"Literature data fetch failed: {str(e)}")
            return pd.DataFrame()

    def fetch_data(self):
        logging.info("Fetching data from Materials Project, ICSD, Perovskite Database, AFLOW, and Literature...")
        original_level = logging.getLogger().getEffectiveLevel()
        logging.getLogger().setLevel(logging.ERROR)
        
        try:
            with MPRester(self.config.api_key) as mpr:
                queries = [("ABO3", None, ["Cs-Pb-I", "Cs-Sn-I", "Rb-Pb-I"]),
                           ("ABX3", ["Cs-Pb-I", "Cs-Sn-I", "Rb-Pb-I"], None)]
                mp_data = []
                for formula, chemsys, fields in queries:
                    results = mpr.materials.summary.search(
                        formula=formula, chemsys=chemsys,
                        fields=["formula_pretty", "elements", "nsites", 
                                "formation_energy_per_atom", "energy_above_hull",
                                "band_gap", "is_stable", "volume"])
                    mp_data.extend(results)
                logging.debug(f"Fetched {len(mp_data)} entries from Materials Project")
    
            icsd_data = self.fetch_icsd_data()
            logging.debug(f"Fetched {len(icsd_data)} entries from ICSD")
    
            try:
                perovskite_db_data = pd.read_csv("perovskite_dataset.csv")
                perovskite_db_data = perovskite_db_data.rename(columns={
                    'Formula': 'formula', 'FormationEnergy': 'formation_energy_per_atom',
                    'Elements': 'elements', 'NSites': 'nsites', 'Volume': 'volume', 'IsStable': 'is_stable'})
                perovskite_db_data['source'] = 'PerovskiteDB'
                perovskite_db_data['elements'] = perovskite_db_data['elements'].apply(
                    lambda x: eval(x) if isinstance(x, str) else x)
                logging.debug(f"Fetched {len(perovskite_db_data)} entries from Perovskite Database")
            except Exception as e:
                logging.error(f"Perovskite Database data fetch failed: {str(e)}")
                perovskite_db_data = pd.DataFrame()
                logging.debug("Perovskite Database data set to empty DataFrame")
    
            aflow_data = self.fetch_aflow_data()
            logging.debug(f"Fetched {len(aflow_data)} entries from AFLOW")
    
            literature_data = self.fetch_literature_data()
            logging.debug(f"Fetched {len(literature_data)} entries from Literature")

            def standardize_dataframe(df, required_cols):
                for col in required_cols:
                    if col not in df.columns:
                        df[col] = np.nan
                return df[required_cols]
    
            required_cols = ["formula", "elements", "nsites", "formation_energy_per_atom", 
                            "energy_above_hull", "band_gap", "is_stable", "volume", "source"]
            mp_df = standardize_dataframe(self._process_raw_data(mp_data, source="Materials Project"), required_cols)
            icsd_data = standardize_dataframe(icsd_data, required_cols)
            perovskite_db_data = standardize_dataframe(perovskite_db_data, required_cols)
            aflow_data = standardize_dataframe(aflow_data, required_cols)
            literature_data = standardize_dataframe(literature_data, required_cols)
    
            combined_df = pd.concat([mp_df, icsd_data, perovskite_db_data, aflow_data, literature_data]).drop_duplicates('formula')
            
            logging.getLogger().setLevel(original_level)
            logging.info(f"Combined dataset size: {len(combined_df)} entries")
            return combined_df
        except Exception as e:
            logging.getLogger().setLevel(original_level)
            logging.error(f"Data fetch failed: {str(e)}")
            raise

    def _process_raw_data(self, raw_data, source):
        processed = []
        for entry in raw_data:
            try:
                processed.append({
                    "formula": entry.formula_pretty,
                    "elements": [str(e) for e in entry.elements],
                    "nsites": entry.nsites,
                    "formation_energy_per_atom": entry.formation_energy_per_atom,
                    "energy_above_hull": getattr(entry, 'energy_above_hull', np.nan),
                    "band_gap": entry.band_gap,
                    "is_stable": entry.is_stable,
                    "volume": entry.volume,
                    "source": source
                })
            except AttributeError as ae:
                logging.warning(f"Skipping invalid entry in {source}: {str(ae)}")
        return pd.DataFrame(processed)

    def _process_oqmd_data(self, raw_data):
        processed = []
        for entry in raw_data:
            try:
                comp = entry.composition
                formula = comp.get_reduced_formula()
                elements = [str(el) for el in comp.elements]
                num_atoms = int(comp.num_atoms)
                processed.append({
                    "formula": formula,
                    "elements": elements,
                    "nsites": num_atoms,
                    "formation_energy_per_atom": entry.data.get("formation_energy_per_atom", np.nan),
                    "energy_above_hull": entry.data.get("energy_above_hull", np.nan),
                    "band_gap": np.nan,
                    "is_stable": entry.data.get("is_stable", False),
                    "volume": entry.data.get("volume", np.nan),
                    "source": "OQMD"
                })
            except AttributeError as ae:
                logging.warning(f"Skipping invalid OQMD entry: {str(ae)}")
        return pd.DataFrame(processed)

    def preprocess_data(self, df):
        logging.info("Preprocessing data...")
        df = df[df['nsites'] == 5].dropna(subset=['formation_energy_per_atom'])
        df['elements'] = df['elements'].apply(lambda x: [e for e in x if Element.is_valid_symbol(e)])
        logging.info(f"Remaining entries after preprocessing: {len(df)}")
        return df

    def engineer_features(self, df):
        logging.info("Engineering features...")
        df = df.assign(
            A=df['elements'].str.get(0),
            B=df['elements'].str.get(1),
            X=df['elements'].str.get(2)
        ).dropna(subset=['A', 'B', 'X'])

        df['ox_states'] = df.apply(self._get_oxidation_states, axis=1)
        df[['r_A', 'r_B', 'r_X']] = df.apply(self._get_ionic_radii, axis=1, result_type='expand')

        electronegativity = lambda el: Element(el).X if el in self.valid_elements else np.nan
        df['X_A'] = df['A'].map(electronegativity)
        df['X_B'] = df['B'].map(electronegativity)
        df['X_X'] = df['X'].map(electronegativity)

        df['delta_X'] = abs(df['X_B'] - df['X_X'])
        df['volume_per_atom'] = df['volume'] / df['nsites']
        df['tolerance_factor'] = (df['r_A'] + df['r_X']) / (np.sqrt(2) * (df['r_B'] + df['r_X']))
        df['octahedral_factor'] = df['r_B'] / df['r_X']

        df['goldschmidt_ok'] = df['tolerance_factor'].between(0.75, 1.1).astype(int)
        df['octahedral_ok'] = df['octahedral_factor'].between(0.35, 0.75).astype(int)
        df['stable'] = df['energy_above_hull'] < 0.1

        return df.dropna(subset=['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor'])

    def _get_oxidation_states(self, row):
        try:
            comp = Composition(row['formula'])
            oxi_states = comp.oxi_state_guesses()
            if oxi_states:
                return {'A': oxi_states[0].get(row['A'], +1),
                        'B': oxi_states[0].get(row['B'], +2),
                        'X': oxi_states[0].get(row['X'], -1)}
            return {'A': +1, 'B': +2, 'X': -1}
        except Exception:
            return {'A': +1, 'B': +2, 'X': -1}

    def _get_ionic_radii(self, row):
        try:
            radii = []
            for site, oxi in [('A', row['ox_states']['A']), ('B', row['ox_states']['B']), ('X', row['ox_states']['X'])]:
                radius = self.shannon_df[
                    (self.shannon_df['element'] == row[site]) & 
                    (self.shannon_df['oxidation_state'] == oxi)
                ]['r_ionic'].values
                radii.append(radius[0] if radius.size else np.nan)
            return tuple(radii)
        except (IndexError, KeyError):
            return (np.nan, np.nan, np.nan)

    def train_models(self, df, target='formation_energy_per_atom'):
        logging.info(f"Training models for {target}...")
        computational_df = df[df['source'].isin(['Materials Project', 'OQMD', 'AFLOW'])].dropna(subset=[target])
        experimental_df = df[df['source'].isin(['ICSD', 'PerovskiteDB', 'Literature'])].dropna(subset=[target])
        
        if experimental_df.empty:
            logging.warning("No experimental data available (ICSD, PerovskiteDB, Literature). Skipping experimental evaluation.")
        else:
            logging.info(f"Found {len(experimental_df)} experimental entries.")
        
        # Ensure indices are reset to avoid mismatch
        computational_df = computational_df.reset_index(drop=True)
        X_comp = pd.get_dummies(computational_df[['A', 'B', 'X', 'source']]).join(
            computational_df[['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor', 
                             'X_A', 'X_B', 'X_X', 'energy_above_hull', 'delta_X', 'volume_per_atom']])
        y_comp = computational_df[target].reset_index(drop=True)
        
        if not experimental_df.empty:
            experimental_df = experimental_df.reset_index(drop=True)
            X_exp = pd.get_dummies(experimental_df[['A', 'B', 'X', 'source']]).join(
                experimental_df[['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor', 
                                'X_A', 'X_B', 'X_X', 'energy_above_hull', 'delta_X', 'volume_per_atom']])
            y_exp = experimental_df[target].reset_index(drop=True)
        
            missing_cols = set(X_comp.columns) - set(X_exp.columns)
            if missing_cols:
                X_exp = pd.concat([X_exp, pd.DataFrame(0, index=X_exp.index, columns=missing_cols)], axis=1)
            X_exp = X_exp[X_comp.columns]
        else:
            X_exp = np.array([])
            y_exp = pd.Series(dtype=float)
        
        if X_comp.empty:
            raise ValueError("No computational data available for training.")
        
        # Create binned target for stratification after preprocessing
        y_bins = pd.qcut(y_comp, q=5, labels=False, duplicates='drop')
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=self.config.random_state)
        train_val_indices, test_indices = next(skf.split(X_comp, y_bins))
        
        X_temp = X_comp.iloc[train_val_indices].reset_index(drop=True)
        X_test = X_comp.iloc[test_indices].reset_index(drop=True)
        y_temp = y_comp.iloc[train_val_indices].reset_index(drop=True)
        y_test = y_comp.iloc[test_indices].reset_index(drop=True)
        
        # Reindex y_bins to match X_temp
        y_bins_temp = y_bins[train_val_indices].reset_index(drop=True)
        
        train_indices, val_indices = train_test_split(
            range(len(X_temp)),
            test_size=self.config.val_size/(1-self.config.test_size),
            stratify=y_bins_temp,
            random_state=self.config.random_state
        )
        X_train = X_temp.iloc[train_indices]
        X_val = X_temp.iloc[val_indices]
        y_train = y_temp.iloc[train_indices]
        y_val = y_temp.iloc[val_indices]
        self.test_indices = test_indices
        
        X_train = self.imputer.fit_transform(X_train)
        X_val = self.imputer.transform(X_val)
        X_test = self.imputer.transform(X_test)
        if isinstance(X_exp, pd.DataFrame) and not X_exp.empty:
            X_exp = self.imputer.transform(X_exp)
        elif isinstance(X_exp, np.ndarray) and X_exp.size > 0:
            X_exp = self.imputer.transform(X_exp)
        self.X_train = self.scaler.fit_transform(X_train)
        self.X_val = self.scaler.transform(X_val)
        self.X_test = self.scaler.transform(X_test)
        self.X_exp = X_exp
        self.y_train, self.y_val, self.y_test, self.y_exp = y_train, y_val, y_test, y_exp
        
        param_grid = {
            'max_depth': [3, 5, 7],
            'learning_rate': [0.01, 0.1],
            'n_estimators': [100, 200]
        }
        xgb_grid = GridSearchCV(XGBRegressor(), param_grid, cv=5, scoring='r2')
        xgb_grid.fit(self.X_train, self.y_train)
        self.models['xgb'] = xgb_grid.best_estimator_
        
        self.models['rf'] = RandomForestRegressor().fit(self.X_train, self.y_train)
        self.models['gbr'] = GradientBoostingRegressor().fit(self.X_train, self.y_train)
        nn_model, train_losses, val_losses = self._train_neural_network()
        
        self.models['nn'] = {
            'model': nn_model,
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        
        results = self._evaluate_models()
        if (isinstance(self.X_exp, pd.DataFrame) and not X_exp.empty) or \
           (isinstance(self.X_exp, np.ndarray) and X_exp.size > 0):
            self._evaluate_experimental()
        else:
            logging.warning("Skipping experimental evaluation due to empty experimental dataset.")
        self._physical_sanity_check(df)
        self._save_best_model(results)
        return results

    def _train_neural_network(self):
        model = PerovskiteNN(self.X_train.shape[1])
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.3)
    
        train_tensor = torch.FloatTensor(self.X_train)
        train_target = torch.FloatTensor(self.y_train.values).view(-1, 1)
        val_tensor = torch.FloatTensor(self.X_val)
        val_target = torch.FloatTensor(self.y_val.values).view(-1, 1)
    
        # Create a DataLoader for mini-batching
        train_dataset = torch.utils.data.TensorDataset(train_tensor, train_target)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    
        best_loss = float('inf')
        patience_counter = 0
        train_losses, val_losses = [], []
    
        model.train()
        for epoch in range(self.config.nn_epochs):
            epoch_train_loss = 0
            for batch_x, batch_y in train_loader:
                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                epoch_train_loss += loss.item() * batch_x.size(0)
            epoch_train_loss /= len(train_dataset)
            train_losses.append(epoch_train_loss)
    
            with torch.no_grad():
                model.eval()
                val_loss = criterion(model(val_tensor), val_target).item()
                val_losses.append(val_loss)
    
            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), self.output_dir/'nn_best.pth')
            else:
                patience_counter += 1
    
            scheduler.step(val_loss)
            if patience_counter >= 10:
                logging.info(f"Early stopping at epoch {epoch}")
                break
    
        model.load_state_dict(torch.load(self.output_dir/'nn_best.pth', weights_only=True))
        return model, train_losses, val_losses

    def _evaluate_models(self):
        metrics = {}
        for name, model in self.models.items():
            if name == 'nn':
                with torch.no_grad():
                    preds = model['model'](torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                preds = model.predict(self.X_test)
            metrics[name] = {
                'r2': r2_score(self.y_test, preds),
                'mae': mean_absolute_error(self.y_test, preds),
                'rmse': np.sqrt(mean_squared_error(self.y_test, preds)),
                'mape': np.mean(np.abs((self.y_test - preds) / self.y_test)) * 100
            }
        return metrics

    def _evaluate_experimental(self):
        logging.info("Evaluating models on experimental data...")
        metrics = {}
        for name, model in self.models.items():
            if name == 'nn':
                with torch.no_grad():
                    preds = model['model'](torch.FloatTensor(self.X_exp)).numpy().flatten()
            else:
                preds = model.predict(self.X_exp)
            metrics[name] = {
                'r2': r2_score(self.y_exp, preds),
                'mae': mean_absolute_error(self.y_exp, preds),
                'rmse': np.sqrt(mean_squared_error(self.y_exp, preds)),
                'mape': np.mean(np.abs((self.y_exp - preds) / self.y_exp)) * 100
            }
        logging.info("\nExperimental Data Results:")
        for model, scores in metrics.items():
            print(f"{model.upper():<5} RÂ²: {scores['r2']:.3f}  MAE: {scores['mae']:.3f}  "
                  f"RMSE: {scores['rmse']:.3f}  MAPE: {scores['mape']:.1f}%")

    def _physical_sanity_check(self, df):
        logging.info("Performing physical sanity checks on predictions...")
        test_df = df.loc[self.test_indices].copy()
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    test_df[f'pred_{name}'] = self.models[name]['model'](
                        torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                test_df[f'pred_{name}'] = self.models[name].predict(self.X_test)

            stable_df = test_df[test_df['stable']]
            if not stable_df.empty:
                mean_pred_energy = stable_df[f'pred_{name}'].mean()
                if mean_pred_energy > 0:
                    warnings.warn(f"{name.upper()}: Stable compounds have positive mean predicted formation energy ({mean_pred_energy:.3f} eV/atom)")

            corr = test_df['tolerance_factor'].corr(test_df[f'pred_{name}'])
            if abs(corr) < 0.1:
                warnings.warn(f"{name.upper()}: Low correlation between tolerance factor and predicted formation energy ({corr:.3f})")

    def _cross_validate_nn(self, X, y, cv):
        kf = KFold(n_splits=cv, shuffle=True, random_state=self.config.random_state)
        scores = {'r2': [], 'mae': [], 'rmse': [], 'mape': []}
        all_train_losses, all_val_losses = [], []
    
        for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
            X_train, X_val, y_train, y_val = train_test_split(
                X_train, y_train, test_size=0.15, random_state=self.config.random_state)
    
            model, train_losses, val_losses = self._train_neural_network(X_train, X_val, y_train, y_val)
            all_train_losses.append(train_losses)
            all_val_losses.append(val_losses)
    
            with torch.no_grad():
                preds = model(torch.FloatTensor(X_test)).numpy().flatten()
    
            scores['r2'].append(r2_score(y_test, preds))
            scores['mae'].append(mean_absolute_error(y_test, preds))
            scores['rmse'].append(np.sqrt(mean_squared_error(y_test, preds)))
            scores['mape'].append(np.mean(np.abs((y_test - preds) / y_test)) * 100)
    
        # Average the learning curves across folds
        max_len = max(len(lst) for lst in all_train_losses)
        avg_train_losses = np.zeros(max_len)
        avg_val_losses = np.zeros(max_len)
        for i in range(max_len):
            train_vals = [lst[i] for lst in all_train_losses if i < len(lst)]
            val_vals = [lst[i] for lst in all_val_losses if i < len(lst)]
            avg_train_losses[i] = np.mean(train_vals)
            avg_val_losses[i] = np.mean(val_vals)
    
        self.models['nn']['train_losses'] = avg_train_losses
        self.models['nn']['val_losses'] = avg_val_losses
        return {k: np.mean(v) for k, v in scores.items()}

    
    def _save_best_model(self, results):
        best_model = max(results.items(), key=lambda x: x[1]['r2'])
        logging.info(f"Best model: {best_model[0]} (RÂ²={best_model[1]['r2']:.3f})")
        joblib.dump(self.models[best_model[0]], self.output_dir/f"{best_model[0]}_model.pkl")
        joblib.dump(self.scaler, self.output_dir/'scaler.pkl')
        joblib.dump(self.imputer, self.output_dir/'imputer.pkl')

    def validate_models(self, cv=5):
        logging.info("Running cross-validation...")
        X = np.vstack([self.X_train, self.X_val, self.X_test])
        y = pd.concat([self.y_train, self.y_val, self.y_test])

        scorers = {
            'r2': make_scorer(r2_score),
            'mae': make_scorer(mean_absolute_error),
            'rmse': make_scorer(lambda y, p: np.sqrt(mean_squared_error(y, p))),
            'mape': make_scorer(lambda y, p: np.mean(np.abs((y - p) / y)) * 100)
        }

        results = {}
        for name, model in self.models.items():
            if name == 'nn':
                results[name] = self._cross_validate_nn(X, y, cv)
            else:
                cv_results = cross_validate(model, X, y, cv=cv, scoring=scorers)
                results[name] = {k: v.mean() for k, v in cv_results.items() if k.startswith('test_')}
        return results

    def _cross_validate_nn(self, X, y, cv):
        kf = KFold(cv)
        scores = {'r2': [], 'mae': [], 'rmse': [], 'mape': []}

        for train_idx, test_idx in kf.split(X):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
            X_train, X_val, y_train, y_val = train_test_split(
                X_train, y_train, test_size=0.15, random_state=self.config.random_state)

            model = self._train_single_fold(X_train, X_val, y_train, y_val)
            with torch.no_grad():
                preds = model(torch.FloatTensor(X_test)).numpy().flatten()

            scores['r2'].append(r2_score(y_test, preds))
            scores['mae'].append(mean_absolute_error(y_test, preds))
            scores['rmse'].append(np.sqrt(mean_squared_error(y_test, preds)))
            scores['mape'].append(np.mean(np.abs((y_test - preds) / y_test)) * 100)

        return {k: np.mean(v) for k, v in scores.items()}

    def _train_single_fold(self, X_train, X_val, y_train, y_val):
        model = PerovskiteNN(X_train.shape[1])
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

        train_tensor = torch.FloatTensor(X_train)
        train_target = torch.FloatTensor(y_train.values).view(-1, 1)
        val_tensor = torch.FloatTensor(X_val)
        val_target = torch.FloatTensor(y_val.values).view(-1, 1)

        best_loss = float('inf')
        patience_counter = 0

        model.train()
        for epoch in range(self.config.nn_epochs):
            optimizer.zero_grad()
            outputs = model(train_tensor)
            loss = criterion(outputs, train_target)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                model.eval()
                val_loss = criterion(model(val_tensor), val_target).item()
                model.train()

            if val_loss < best_loss:
                best_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), self.output_dir/'nn_fold_best.pth')
            else:
                patience_counter += 1

            scheduler.step(val_loss)
            if patience_counter >= self.config.nn_patience:
                break

        model.load_state_dict(torch.load(self.output_dir/'nn_fold_best.pth', weights_only=True))
        return model

    def chemical_sanity_check(self, df):
        logging.info("Performing chemical sanity checks...")
        valid = (
            df['tolerance_factor'].between(0.75, 1.1) &
            df['octahedral_factor'].between(0.35, 0.75)
        )
        validity_rate = valid.mean()
        if validity_rate < 0.8:
            warnings.warn(f"Low validity rate: {validity_rate:.1%}")
        return validity_rate

    def run_pipeline(self):
        df = self.fetch_data()
        df = self.preprocess_data(df)
        df = self.engineer_features(df)
        self.chemical_sanity_check(df)

        results = self.train_models(df)
        validation = self.validate_models()

        self.visualize_results(df)

        logging.info("\nFinal Results (Computational Test Set):")
        for model, scores in results.items():
            print(f"{model.upper():<5} RÂ²: {scores['r2']:.3f}  MAE: {scores['mae']:.3f}  "
                  f"RMSE: {scores['rmse']:.3f}  MAPE: {scores['mape']:.1f}%")
        return results, validation

    def visualize_results(self, df):
        logging.info("Generating model validation visualizations...")
        try:
            self._plot_parity_with_chemicals(df)
            if 'nn' in self.models:
                self._plot_learning_curve()
            self._plot_residuals(df)
            self._plot_feature_importance(df)
            self._plot_prediction_distributions(df)
            self._plot_top_candidates(df)
            self._plot_feature_correlations(df)
            self._plot_error_by_composition(df)
            self._plot_formation_energy_vs_tolerance(df)
        except Exception as e:
            logging.error(f"Visualization failed: {str(e)}")
            raise

    def _plot_parity_with_chemicals(self, df):
        test_df = df.loc[self.test_indices].copy()
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    test_df[f'pred_{name}'] = self.models[name]['model'](
                        torch.FloatTensor(self.X_test)
                    ).numpy().flatten()
            else:
                test_df[f'pred_{name}'] = self.models[name].predict(self.X_test)

        fig, axes = plt.subplots(2, 2, figsize=(20, 18))
        for idx, (name, ax) in enumerate(zip(self.models.keys(), axes.flatten())):
            if idx >= 4:
                break
            y_true = test_df['formation_energy_per_atom']
            y_pred = test_df[f'pred_{name}']
            r2 = r2_score(y_true, y_pred)
            mae = mean_absolute_error(y_true, y_pred)

            sc = ax.scatter(y_true, y_pred, c=test_df['tolerance_factor'],
                           cmap='viridis', alpha=0.7, vmin=0.7, vmax=1.1)
            ax.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'r--')
            ax.set_title(f"{name.upper()} Parity Plot\n(RÂ²={r2:.3f}, MAE={mae:.3f})")
            ax.set_xlabel("True Formation Energy (eV/atom)")
            ax.set_ylabel("Predicted Formation Energy (eV/atom)")
            plt.colorbar(sc, ax=ax, label='Tolerance Factor')

        plt.tight_layout()
        plt.savefig(self.output_dir/'parity_plots.png', dpi=300)
        plt.close()

    def _plot_learning_curve(self):
        plt.figure(figsize=(10, 6))
        plt.plot(self.models['nn']['train_losses'], label='Train Loss')
        plt.plot(self.models['nn']['val_losses'], label='Validation Loss')
        plt.title("Neural Network Learning Curve")
        plt.xlabel("Epoch")
        plt.ylabel("MSE Loss")
        # Use a log scale to better visualize the loss trends
        plt.yscale('log')
        plt.legend()
        plt.grid(True, which="both", ls="--")
        plt.savefig(self.output_dir/'nn_learning_curve.png', dpi=300)
        plt.close()

    def _plot_residuals(self, df):
        test_df = df.loc[self.test_indices].copy()
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        for idx, (name, ax) in enumerate(zip(self.models.keys(), axes.flatten())):
            if idx >= 4:
                break
            if name == 'nn':
                with torch.no_grad():
                    preds = self.models[name]['model'](torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                preds = self.models[name].predict(self.X_test)
            residuals = test_df['formation_energy_per_atom'] - preds
            ax.scatter(preds, residuals, alpha=0.5)
            ax.axhline(0, color='red', linestyle='--')
            ax.set_title(f"{name.upper()} Residuals")
            ax.set_xlabel("Predicted Values")
            ax.set_ylabel("Residuals")
            ax.grid(True)
        plt.tight_layout()
        plt.savefig(self.output_dir/'residual_analysis.png', dpi=300)
        plt.close()

    def _plot_feature_importance(self, df):
        plt.figure(figsize=(15, 10))
        features = pd.get_dummies(df[['A', 'B', 'X', 'source']]).columns.tolist() + [
            'r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor',
            'X_A', 'X_B', 'X_X', 'energy_above_hull', 'delta_X', 'volume_per_atom']
        for idx, name in enumerate(['xgb', 'rf', 'gbr'], 1):
            if name in self.models and hasattr(self.models[name], 'feature_importances_'):
                importances = self.models[name].feature_importances_
                indices = np.argsort(importances)[::-1]
                plt.subplot(1, 3, idx)
                plt.title(f"{name.upper()} Feature Importance")
                plt.barh(range(10), importances[indices][:10], align='center')
                plt.yticks(range(10), [features[i] for i in indices[:10]])
                plt.xlabel('Relative Importance')
                plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.savefig(self.output_dir/'feature_importance.png', dpi=300)
        plt.close()

    def _plot_prediction_distributions(self, df):
        test_df = df.loc[self.test_indices].copy()
        plt.figure(figsize=(12, 8))
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    preds = self.models[name]['model'](torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                preds = self.models[name].predict(self.X_test)
            plt.hist(preds, alpha=0.5, bins=30, label=name.upper())
        plt.hist(test_df['formation_energy_per_atom'], bins=30, alpha=0.3, label='Actual', color='black')
        plt.title("Prediction Distributions vs Actual Values")
        plt.xlabel("Formation Energy (eV/atom)")
        plt.ylabel("Frequency")
        plt.legend()
        plt.savefig(self.output_dir/'prediction_distributions.png', dpi=300)
        plt.close()

    def _plot_top_candidates(self, df):
        test_df = df.loc[self.test_indices].copy()
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    test_df[f'pred_{name}'] = self.models[name]['model'](
                        torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                test_df[f'pred_{name}'] = self.models[name].predict(self.X_test)

        top_candidates = pd.concat([
            test_df.nsmallest(10, f'pred_{name}') for name in self.models.keys()
        ]).drop_duplicates('formula')
        plt.figure(figsize=(15, 10))
        pd.plotting.parallel_coordinates(
            top_candidates[['formula', 'tolerance_factor', 'octahedral_factor'] + 
                          [f'pred_{name}' for name in self.models.keys()]],
            'formula',
            colormap='viridis'
        )
        plt.title("Top Candidates - Multi-Model Comparison")
        plt.xlabel("Features and Model Predictions")
        plt.ylabel("Normalized Values")
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(self.output_dir/'top_candidates.png', dpi=300)
        plt.close()

    def _plot_feature_correlations(self, df):
        plt.figure(figsize=(10, 8))
        features = df[['r_A', 'r_B', 'r_X', 'tolerance_factor', 'octahedral_factor',
                      'X_A', 'X_B', 'X_X', 'energy_above_hull', 'delta_X', 'volume_per_atom']]
        sns.heatmap(features.corr(), annot=True, cmap='coolwarm', vmin=-1, vmax=1)
        plt.title("Feature Correlations")
        plt.savefig(self.output_dir/'feature_correlations.png', dpi=300)
        plt.close()

    def _plot_error_by_composition(self, df):
        test_df = df.loc[self.test_indices].copy()
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    preds = self.models[name]['model'](torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                preds = self.models[name].predict(self.X_test)
            test_df[f'error_{name}'] = np.abs(test_df['formation_energy_per_atom'] - preds)
        plt.figure(figsize=(12, 6))
        sns.boxplot(x='A', y=f'error_xgb', data=test_df)
        plt.title(f"Prediction Errors by A-site Element (XGB)")
        plt.savefig(self.output_dir/'error_by_A_xgb.png', dpi=300)
        plt.close()

    def _plot_formation_energy_vs_tolerance(self, df):
        test_df = df.loc[self.test_indices].copy()
        for name in self.models.keys():
            if name == 'nn':
                with torch.no_grad():
                    test_df[f'pred_{name}'] = self.models[name]['model'](
                        torch.FloatTensor(self.X_test)).numpy().flatten()
            else:
                test_df[f'pred_{name}'] = self.models[name].predict(self.X_test)

        plt.figure(figsize=(12, 6))
        handles = {}
        for name in self.models.keys():
            scatter = sns.scatterplot(
                x='tolerance_factor', 
                y=f'pred_{name}', 
                hue='stable', 
                size='energy_above_hull',
                data=test_df, 
                alpha=0.7
            )
            handles[name.upper()] = scatter

        plt.axvline(x=1.0, color='red', linestyle='--', label='Ideal Tolerance Factor')
        plt.title("Predicted Formation Energy vs Tolerance Factor")
        plt.xlabel("Tolerance Factor")
        plt.ylabel("Predicted Formation Energy (eV/atom)")

        ax = plt.gca()
        hue_handles, hue_labels = ax.get_legend_handles_labels()
        custom_handles = [plt.scatter([], [], label=name.upper(), color='gray') for name in self.models.keys()]
        custom_handles.append(plt.axvline(x=0, color='red', linestyle='--', label='Ideal Tolerance Factor'))
        
        final_handles = custom_handles + hue_handles
        final_labels = [h.get_label() for h in custom_handles] + hue_labels
        
        plt.legend(handles=final_handles, labels=final_labels, title="Legend")
        plt.savefig(self.output_dir/'formation_energy_vs_tolerance.png', dpi=300)
        plt.close()

    def ensemble_predict(self, X):
        weights = {'xgb': 0.3, 'rf': 0.2, 'gbr': 0.2, 'nn': 0.3}
        preds = np.zeros(len(X))
        for name, weight in weights.items():
            if name == 'nn':
                with torch.no_grad():
                    pred = self.models[name]['model'](torch.FloatTensor(X)).numpy().flatten()
            else:
                pred = self.models[name].predict(X)
            preds += weight * pred
        return preds

if __name__ == "__main__":
    config = PipelineConfig(val_size=0.15)
    pipeline = PerovskitePipeline(config)
    results, validation = pipeline.run_pipeline()

Retrieving SummaryDoc documents:   0%|          | 0/4700 [00:00<?, ?it/s]

Retrieving SummaryDoc documents:   0%|          | 0/8 [00:00<?, ?it/s]