# In [None]:
"""
# Materials Selection for Specific Applications
# ============================================
#
# This project teaches students to filter materials based on desired properties for 
# specific applications. Students will learn to:
# 1. Define property criteria for specific applications
# 2. Query the Materials Project database with appropriate filters
# 3. Rank materials based on performance metrics
# 4. Visualize the top candidates and their properties
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mp_api.client import MPRester
import plotly.express as px
import plotly.graph_objects as go
from pymatgen.core import Structure
import crystal_toolkit
from IPython.display import display
from sklearn.preprocessing import MinMaxScaler
import warnings
warnings.filterwarnings('ignore')

# Materials Project API key. It is read from the MP_API_KEY environment variable so
# that real keys never live in source control; the placeholder makes a missing key
# obvious. Students: export MP_API_KEY (or replace the placeholder locally).
API_KEY = os.environ.get("MP_API_KEY", "YOUR_API_KEY")

def fetch_materials_data(criteria, properties):
    """
    Fetch materials data from the Materials Project API based on specified criteria.

    Args:
        criteria (dict): Keyword filters forwarded unchanged to
            ``mpr.materials.summary.search`` (e.g. ``{"band_gap": (3.0, 10.0)}``).
        properties (list): Document fields to request; each becomes a column.

    Returns:
        pd.DataFrame: One row per returned document. A requested field that a
        document lacks is stored as ``None``. When ``"symmetry"`` is requested a
        ``crystal_system`` column is added ("Unknown" if the symmetry data is
        missing). When ``"formula_pretty"`` is requested, ``nelements`` is set to
        the number of distinct element symbols in the formula (left as provided
        by the API, or None, when the formula itself is missing).
    """
    import re  # only needed here; kept local so the module imports stay unchanged

    # Element symbols: one capital letter optionally followed by a lowercase one.
    element_pattern = re.compile(r'[A-Z][a-z]?')

    print(f"Fetching materials with criteria: {criteria}")

    with MPRester(API_KEY) as mpr:
        # Query for materials based on provided criteria
        docs = mpr.materials.summary.search(
            **criteria,
            fields=properties
        )

        data = []
        for doc in docs:
            # Requested fields, with None for anything the document does not carry.
            material_dict = {prop: getattr(doc, prop, None) for prop in properties}

            if 'symmetry' in properties:
                # getattr on a missing/None symmetry object falls back to "Unknown".
                symmetry = getattr(doc, 'symmetry', None)
                material_dict['crystal_system'] = getattr(symmetry, 'crystal_system', "Unknown")

            if 'formula_pretty' in properties:
                formula = material_dict['formula_pretty']
                if isinstance(formula, str):
                    material_dict['nelements'] = len(set(element_pattern.findall(formula)))
                else:
                    # No formula to parse (e.g. None): keep any API-provided value.
                    material_dict.setdefault('nelements', None)

            data.append(material_dict)

        df = pd.DataFrame(data)

        print(f"Found {len(df)} materials matching criteria")

        return df

def rank_materials(df, performance_metrics, weights=None, higher_is_better=None):
    """
    Rank materials by a weighted sum of min-max normalized performance metrics.

    Each usable metric is scaled to [0, 1] across ``df`` (a constant column maps
    to 0), inverted when lower values are better, and combined using weights
    renormalized to sum to 1 over the metrics that were actually used.

    Args:
        df (pd.DataFrame): DataFrame containing materials data.
        performance_metrics (list): Column names to use for ranking. The list
            is never modified.
        weights (list, optional): Weight for each metric, positionally aligned
            with ``performance_metrics``. Defaults to equal weights.
        higher_is_better (list, optional): Booleans aligned with
            ``performance_metrics``; False means lower values rank higher.
            Defaults to True for all metrics.

    Returns:
        pd.DataFrame: A copy of ``df`` with ``<metric>_normalized`` columns,
        ``performance_score`` and ``rank`` (1 = best), sorted by descending
        score. If no metric is usable, the unsorted copy is returned without
        score or rank columns. Metric columns keep their original values;
        missing values are median-imputed only while computing the scores.
    """
    print(f"Ranking materials based on: {performance_metrics}")

    # Make a copy to avoid modifying the original
    ranked_df = df.copy()

    n_metrics = len(performance_metrics)
    if weights is None:
        weights = [1.0] * n_metrics
    if higher_is_better is None:
        higher_is_better = [True] * n_metrics

    # Keep each metric bundled with its own weight and direction so dropping one
    # never shifts the weights of the others (zip stops at the shortest list).
    candidates = []
    for metric, weight, is_higher_better in zip(performance_metrics, weights, higher_is_better):
        if metric not in ranked_df.columns:
            print(f"Warning: {metric} not found in data, skipping in ranking")
            continue
        candidates.append((metric, weight, is_higher_better))

    if not candidates:
        print("No valid performance metrics found for ranking")
        return ranked_df

    scored = []  # (normalized column name, weight) for every metric actually used
    for metric, weight, is_higher_better in candidates:
        column = ranked_df[metric]

        # Skip metrics with too many missing values
        if column.isna().sum() > 0.5 * len(ranked_df):
            print(f"Skipping {metric} due to too many missing values")
            continue

        # Median-impute for scoring only, so displayed metric values stay genuine.
        filled = column.fillna(column.median()).astype(float)

        # Min-max scale to [0, 1]; a constant (or empty) column scales to 0.
        low, high = filled.min(), filled.max()
        span = high - low
        if span > 0:
            normalized = (filled - low) / span
        else:
            normalized = pd.Series(0.0, index=filled.index)

        # If lower values are better, invert the normalized values
        if not is_higher_better:
            normalized = 1.0 - normalized

        normalized_col = f"{metric}_normalized"
        ranked_df[normalized_col] = normalized
        scored.append((normalized_col, weight))

    if not scored:
        print("No metrics could be normalized for ranking")
        return ranked_df

    total_weight = sum(weight for _, weight in scored)
    if total_weight == 0:
        # All-zero weights would divide by zero; fall back to equal weighting.
        scored = [(normalized_col, 1.0) for normalized_col, _ in scored]
        total_weight = float(len(scored))

    # Weighted sum with weights renormalized over the metrics actually used
    ranked_df['performance_score'] = 0.0
    for normalized_col, weight in scored:
        ranked_df['performance_score'] += (weight / total_weight) * ranked_df[normalized_col]

    # Sort by performance score and number the rows 1..N
    ranked_df = ranked_df.sort_values('performance_score', ascending=False)
    ranked_df['rank'] = range(1, len(ranked_df) + 1)

    return ranked_df

def visualize_top_materials(df, n=10, properties=None, title="Top Materials"):
    """
    Print and plot the top ``n`` rows of a ranked materials DataFrame.

    Prints a table of the top rows and draws a horizontal bar chart of
    ``performance_score``, saved as ``top_materials_<title>.png``. When at
    least two ``properties`` are given and present, it also shows an
    interactive Plotly scatter of the first two.

    Args:
        df (pd.DataFrame): Ranked materials, best first (e.g. from ``rank_materials``).
        n (int, optional): Number of top materials to visualize. Defaults to 10.
        properties (list, optional): Extra columns to print; the first two are
            also used as scatter-plot axes. Defaults to None.
        title (str, optional): Title used in headings and the PNG filename.
            Defaults to "Top Materials".
    """
    # Ensure there are enough materials to visualize
    n = min(n, len(df))
    if n <= 0:
        print("No materials available to visualize")
        return

    top_df = df.head(n).copy()

    print(f"\nTop {n} materials for {title}:")
    # Only show columns that exist: score/rank are absent if ranking failed.
    display_cols = [col for col in ['material_id', 'formula_pretty', 'performance_score', 'rank']
                    if col in top_df.columns]
    if properties:
        display_cols.extend(p for p in properties
                            if p in top_df.columns and p not in display_cols)
    print(top_df[display_cols].to_string(index=False))

    if 'performance_score' not in top_df.columns:
        print("No performance scores available; skipping plots")
        return

    has_formula = 'formula_pretty' in top_df.columns
    has_id = 'material_id' in top_df.columns
    labels = (top_df['formula_pretty'].astype(str).tolist() if has_formula
              else [str(i) for i in range(1, n + 1)])

    # Plot against numeric positions and label the ticks afterwards. Using the
    # formulas directly as categorical y-values would merge polymorphs that share
    # a formula into one bar and misplace the material_id labels.
    positions = np.arange(n)
    plt.figure(figsize=(12, 6))
    bars = plt.barh(positions, top_df['performance_score'], color='skyblue')
    plt.yticks(positions, labels)
    plt.xlabel('Performance Score')
    plt.ylabel('Material')
    plt.title(f'Top {n} Materials for {title}')
    plt.gca().invert_yaxis()  # Display highest score at the top

    # Add material_id as text on bars
    if has_id:
        for bar, material_id in zip(bars, top_df['material_id']):
            plt.text(bar.get_width() * 0.02, bar.get_y() + bar.get_height() / 2,
                     f"{material_id}",
                     va='center', color='black', fontsize=8)

    plt.tight_layout()
    plt.savefig(f'top_materials_{title.replace(" ", "_").lower()}.png')
    plt.show()

    # Interactive scatter of the first two properties, if both are available
    if not properties or len(properties) < 2:
        return
    x_prop, y_prop = properties[0], properties[1]
    if x_prop not in top_df.columns or y_prop not in top_df.columns:
        return

    def _pretty(name):
        """Turn a snake_case column name into a Title Case axis label."""
        return name.replace('_', ' ').title()

    hover_cols = [col for col in ['material_id', 'performance_score', 'rank']
                  if col in top_df.columns]
    fig = px.scatter(
        top_df,
        x=x_prop,
        y=y_prop,
        text='formula_pretty' if has_formula else None,
        size='performance_score',
        color='performance_score',
        hover_name='formula_pretty' if has_formula else None,
        hover_data=hover_cols,
        labels={
            x_prop: _pretty(x_prop),
            y_prop: _pretty(y_prop),
            'performance_score': 'Performance Score'
        },
        title=f'{_pretty(y_prop)} vs {_pretty(x_prop)} for Top Materials'
    )
    fig.update_traces(textposition='top center')
    fig.update_layout(template='plotly_white')
    fig.show()

def visualize_structure(structure, title=None):
    """
    Display a crystal structure, print a summary of it, and save it as a CIF file.

    Args:
        structure: Pymatgen Structure object (None is reported and skipped).
        title: Optional heading printed before the structure.

    Side effects:
        Writes ``<reduced_formula>.cif`` to the current working directory.
    """
    if title:
        print(f"\n{title} Structure:")

    if structure is None:
        print("No structure data available to visualize")
        return

    # Display the structure (rendered by crystal-toolkit when it is active)
    display(structure)

    # Print additional structure information
    print(f"Formula: {structure.composition.reduced_formula}")
    try:
        space_group = structure.get_space_group_info()
    except Exception:  # symmetry analysis is best-effort; never abort the report
        print("Space Group: Unable to determine")
    else:
        print(f"Space Group: {space_group}")
    lattice = structure.lattice
    print(f"Lattice Parameters: a={lattice.a:.4f}, b={lattice.b:.4f}, c={lattice.c:.4f}")
    print(f"Lattice Angles: α={lattice.alpha:.2f}°, β={lattice.beta:.2f}°, γ={lattice.gamma:.2f}°")
    print(f"Volume: {structure.volume:.2f} Å³")
    print(f"Density: {structure.density:.4f} g/cm³")

    # Pass filename by keyword: older pymatgen releases take `fmt` as the first
    # positional argument of Structure.to, which would misread the filename.
    filename = f"{structure.composition.reduced_formula}.cif"
    structure.to(filename=filename)
    print(f"Saved structure to {filename}")

def create_radar_chart(df, properties, title="Material Properties Comparison"):
    """
    Create a radar chart comparing multiple properties across the rows of ``df``.

    Each property is min-max scaled across ``df``. A missing value is drawn as
    0, and a property that is constant across all rows is drawn as 0.5.

    Args:
        df (pd.DataFrame): Materials to compare; needs ``formula_pretty`` and
            ``material_id`` columns for the legend.
        properties (list): Properties to compare; names not in ``df`` are ignored.
        title (str, optional): Title for the chart. Defaults to "Material Properties Comparison".
    """
    # Ensure all properties exist in the DataFrame (a new list; caller's is untouched)
    properties = [p for p in properties if p in df.columns]

    if len(properties) < 3:
        print("Not enough valid properties for radar chart (need at least 3)")
        return

    # Column ranges depend only on the column, so compute them once up front
    # instead of once per row and property.
    bounds = {prop: (df[prop].min(), df[prop].max()) for prop in properties}

    fig = go.Figure()

    for _, row in df.iterrows():
        values = []
        for prop in properties:
            value = row[prop]
            prop_min, prop_max = bounds[prop]
            if pd.isna(value):
                values.append(0)            # missing value
            elif prop_min == prop_max:
                values.append(0.5)          # constant column: avoid division by zero
            else:
                values.append((value - prop_min) / (prop_max - prop_min))

        # Add trace for this material, repeating the first point to close the polygon
        fig.add_trace(go.Scatterpolar(
            r=values + [values[0]],
            theta=properties + [properties[0]],
            fill='toself',
            name=f"{row['formula_pretty']} ({row['material_id']})"
        ))

    fig.update_layout(
        title=title,
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 1]
            )
        ),
        showlegend=True
    )

    fig.show()

def transparent_conductors():
    """
    Find and rank materials for transparent conductor applications.

    Screening applied here:
    1. Wide band gap (3.0-10.0 eV) as a proxy for optical transparency.
    2. ``is_stable`` as a hard filter on thermodynamic stability.

    Ranking favours a larger band gap (weight 0.6), then lower energy above
    hull (0.3) and lower formation energy per atom (0.1). Electrical
    conductivity/resistivity is NOT evaluated: it is not among the queried
    fields, so top candidates still need a separate dopability/conductivity check.

    Returns:
        pd.DataFrame or None: Ranked candidates, or None if nothing was found.
    """
    print("\nFinding materials for transparent conductor applications...")

    # Define search criteria
    criteria = {
        "band_gap": (3.0, 10.0),  # Wide band gap for transparency
        "is_stable": True  # Ensure stability
    }

    # Define properties to retrieve
    properties = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull",
        "formation_energy_per_atom",
        "density",
        "volume",
        "symmetry",
        "structure",
        "total_magnetization"
    ]

    tc_df = fetch_materials_data(criteria, properties)

    if len(tc_df) == 0:
        print("No suitable materials found for transparent conductors")
        return None

    # Performance metrics:
    # - Higher band gap is better (more transparent)
    # - Lower energy_above_hull is better (more stable)
    # - Lower formation_energy_per_atom is better (more formable)
    # NOTE(review): if is_stable=True selects only on-hull phases, energy_above_hull
    # is ~0 for every candidate and its weight barely changes the ordering;
    # confirm, and consider a looser energy_above_hull range instead.
    performance_metrics = ['band_gap', 'energy_above_hull', 'formation_energy_per_atom']
    weights = [0.6, 0.3, 0.1]
    higher_is_better = [True, False, False]

    ranked_tc_df = rank_materials(tc_df, performance_metrics, weights, higher_is_better)

    visualize_top_materials(
        ranked_tc_df,
        n=5,
        properties=['band_gap', 'energy_above_hull', 'formation_energy_per_atom'],
        title="Transparent Conductors"
    )

    # Radar chart for the top 5 materials
    create_radar_chart(
        ranked_tc_df.head(5),
        ['band_gap', 'energy_above_hull', 'formation_energy_per_atom', 'density'],
        title="Transparent Conductor Properties Comparison"
    )

    # Visualize structure of the top material
    if not ranked_tc_df.empty:
        top_material = ranked_tc_df.iloc[0]
        print("\nVisualizing structure of top transparent conductor material:")
        print(f"{top_material['formula_pretty']} ({top_material['material_id']})")
        # energy_above_hull is reported by Materials Project in eV/atom
        print(f"Band gap: {top_material['band_gap']:.2f} eV, "
              f"Stability: {top_material['energy_above_hull']:.3f} eV/atom")
        visualize_structure(top_material['structure'], top_material['formula_pretty'])

    return ranked_tc_df

def thermoelectric_materials():
    """
    Find and rank materials for thermoelectric applications.

    Screening applied here:
    1. Narrow band gap (0.1-1.5 eV) and ``is_stable`` in the API query.
    2. 3-10 distinct elements (compositional complexity, used as a proxy for
       low lattice thermal conductivity), filtered locally after retrieval.

    Ranking weights a band-gap score (0.4), element count (0.4) and low energy
    above hull (0.2).

    Returns:
        pd.DataFrame or None: Ranked candidates, or None if nothing survives the filters.
    """
    print("\nFinding materials for thermoelectric applications...")

    # Define search criteria
    criteria = {
        "band_gap": (0.1, 1.5),  # Narrow band gap for electronic conductivity
        "is_stable": True         # Ensure stability
    }

    # The element-count filter is applied locally after retrieval, using the
    # nelements column populated by fetch_materials_data.

    # Define properties to retrieve
    properties = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull",
        "formation_energy_per_atom",
        "density",
        "volume",
        "symmetry",
        "structure",
        "nelements",
        "total_magnetization"
    ]

    te_df = fetch_materials_data(criteria, properties)

    if len(te_df) == 0:
        print("No suitable materials found for thermoelectric applications")
        return None

    # .copy() makes the filtered frame independent, so adding band_gap_score below
    # does not write to a slice of the unfiltered frame (SettingWithCopy).
    te_df = te_df[(te_df['nelements'] >= 3) & (te_df['nelements'] <= 10)].copy()

    if len(te_df) == 0:
        print("No materials found after filtering for complex compositions (3-10 elements)")
        return None

    print(f"Found {len(te_df)} materials after filtering for complex compositions")

    # Band-gap score: triangular, 1.0 at 0.5 eV, falling linearly to 0 at 0 eV
    # and at 1.0 eV; gaps above 1.0 eV (and missing gaps) score 0.
    te_df['band_gap_score'] = te_df['band_gap'].apply(
        lambda x: 1.0 - abs((x - 0.5) / 0.5) if x <= 1.0 else 0.0
    )

    # Performance metrics:
    # - Higher band_gap_score is better (closer to 0.5 eV)
    # - Higher nelements is better (proxy for lower thermal conductivity)
    # - Lower energy_above_hull is better (more stable)
    performance_metrics = ['band_gap_score', 'nelements', 'energy_above_hull']
    weights = [0.4, 0.4, 0.2]
    higher_is_better = [True, True, False]

    ranked_te_df = rank_materials(te_df, performance_metrics, weights, higher_is_better)

    visualize_top_materials(
        ranked_te_df,
        n=5,
        properties=['band_gap', 'nelements', 'energy_above_hull', 'band_gap_score'],
        title="Thermoelectric Materials"
    )

    # Scatter plot of band gap vs number of elements for the top 20
    fig = px.scatter(
        ranked_te_df.head(20),
        x='band_gap',
        y='nelements',
        color='performance_score',
        size='volume',
        hover_name='formula_pretty',
        hover_data=['material_id', 'energy_above_hull'],
        labels={
            'band_gap': 'Band Gap (eV)',
            'nelements': 'Number of Elements',
            'performance_score': 'Performance Score',
            'volume': 'Unit Cell Volume (Å³)'
        },
        title='Band Gap vs Complexity for Top Thermoelectric Materials'
    )
    fig.update_layout(template='plotly_white')
    fig.show()

    # Visualize structure of the top material
    if not ranked_te_df.empty:
        top_material = ranked_te_df.iloc[0]
        print("\nVisualizing structure of top thermoelectric material:")
        print(f"{top_material['formula_pretty']} ({top_material['material_id']})")
        print(f"Band gap: {top_material['band_gap']:.2f} eV, Elements: {top_material['nelements']}")
        visualize_structure(top_material['structure'], top_material['formula_pretty'])

    return ranked_te_df

def magnetic_materials():
    """
    Find and rank materials for magnetic applications.

    Screening applied here: total magnetization of at least 1.0 μB and
    ``is_stable``. Ranking weights total magnetization (0.7), low energy above
    hull (0.2) and low formation energy per atom (0.1).

    Returns:
        pd.DataFrame or None: Ranked candidates, or None if nothing was found.
    """
    print("\nFinding materials for magnetic applications...")

    # Define search criteria
    criteria = {
        "total_magnetization": (1.0, float('inf')),  # At least 1 μB of total magnetization
        "is_stable": True                            # Ensure stability
    }

    # Define properties to retrieve
    properties = [
        "material_id",
        "formula_pretty",
        "total_magnetization",
        "energy_above_hull",
        "formation_energy_per_atom",
        "density",
        "volume",
        "symmetry",
        "structure"
    ]

    mag_df = fetch_materials_data(criteria, properties)

    if len(mag_df) == 0:
        print("No suitable materials found for magnetic applications")
        return None

    # Performance metrics:
    # - Higher total_magnetization is better
    # - Lower energy_above_hull is better (more stable)
    # - Lower formation_energy_per_atom is better (more formable)
    # NOTE(review): total_magnetization appears to be for the whole cell, which
    # favours large unit cells; a volume- or formula-normalized magnetization
    # field may rank more fairly — confirm against the API docs.
    performance_metrics = ['total_magnetization', 'energy_above_hull', 'formation_energy_per_atom']
    weights = [0.7, 0.2, 0.1]
    higher_is_better = [True, False, False]

    ranked_mag_df = rank_materials(mag_df, performance_metrics, weights, higher_is_better)

    visualize_top_materials(
        ranked_mag_df,
        n=5,
        properties=['total_magnetization', 'energy_above_hull', 'formation_energy_per_atom'],
        title="Magnetic Materials"
    )

    # Scatter plot of magnetization vs stability for the top 20
    fig = px.scatter(
        ranked_mag_df.head(20),
        x='total_magnetization',
        y='energy_above_hull',
        color='performance_score',
        size='total_magnetization',
        hover_name='formula_pretty',
        hover_data=['material_id', 'formation_energy_per_atom'],
        labels={
            'total_magnetization': 'Total Magnetization (μB)',
            'energy_above_hull': 'Energy Above Hull (eV/atom)',
            'performance_score': 'Performance Score'
        },
        title='Magnetization vs Stability for Top Magnetic Materials'
    )
    fig.update_layout(template='plotly_white')
    fig.show()

    # Visualize structure of the top material
    if not ranked_mag_df.empty:
        top_material = ranked_mag_df.iloc[0]
        print("\nVisualizing structure of top magnetic material:")
        print(f"{top_material['formula_pretty']} ({top_material['material_id']})")
        # energy_above_hull is reported by Materials Project in eV/atom
        print(f"Magnetization: {top_material['total_magnetization']:.2f} μB, "
              f"Stability: {top_material['energy_above_hull']:.3f} eV/atom")
        visualize_structure(top_material['structure'], top_material['formula_pretty'])

    return ranked_mag_df

def custom_application(name, criteria, properties, performance_metrics, weights=None, higher_is_better=None):
    """
    Find and rank materials for a custom application.

    Args:
        name (str): Name of the application (used in messages and titles).
        criteria (dict): Search criteria for the Materials Project API.
        properties (list): Properties to retrieve; should include "material_id"
            and "formula_pretty". Include "structure" to visualize the winner.
        performance_metrics (list): Metrics to use for ranking.
        weights (list, optional): Weights aligned with ``performance_metrics``. Defaults to None.
        higher_is_better (list, optional): Whether higher values are better, aligned
            with ``performance_metrics``. Defaults to None.

    Returns:
        pd.DataFrame or None: Ranked materials, or None if the query found nothing.
    """
    print(f"\nFinding materials for {name} applications...")

    app_df = fetch_materials_data(criteria, properties)

    if len(app_df) == 0:
        print(f"No suitable materials found for {name} applications")
        return None

    ranked_app_df = rank_materials(app_df, performance_metrics, weights, higher_is_better)

    visualize_top_materials(
        ranked_app_df,
        n=5,
        properties=performance_metrics,
        title=f"{name} Materials"
    )

    # Radar chart for the top 5 materials when there are enough metrics
    if len(performance_metrics) >= 3:
        create_radar_chart(
            ranked_app_df.head(5),
            performance_metrics,
            title=f"{name} Properties Comparison"
        )

    # Visualize structure of the top material
    if not ranked_app_df.empty:
        top_material = ranked_app_df.iloc[0]
        print(f"\nVisualizing structure of top {name} material:")
        print(f"{top_material['formula_pretty']} ({top_material['material_id']})")
        # 'structure' is only present if the caller requested it in `properties`.
        structure = top_material.get('structure')
        if structure is not None:
            visualize_structure(structure, top_material['formula_pretty'])
        else:
            print("No structure available; include 'structure' in properties to visualize it")

    return ranked_app_df

def main():
    """Run the transparent-conductor, thermoelectric and magnetic explorations in turn."""
    # Best-effort setup for crystal-toolkit's Jupyter rendering. The rest of the
    # script works without it. `except Exception` (not a bare except) keeps
    # Ctrl-C / SystemExit working.
    try:
        import crystal_toolkit.helpers.jupyter as ctj
        ctj.init_jupyter_mode()
    except Exception:
        print("Note: For best visualization experience, ensure crystal-toolkit is properly initialized.")

    print("Materials Selection for Specific Applications")
    print("===========================================")

    # Run application-specific explorations (results kept for interactive inspection)
    tc_df = transparent_conductors()
    te_df = thermoelectric_materials()
    mag_df = magnetic_materials()

    # Example of a custom application - superconductors.
    # It is an inert string literal; remove the surrounding triple quotes to run it.
    """
    sc_criteria = {
        "elements": ["Cu", "O"],  # Looking for copper oxide superconductors
        "is_stable": True
    }
    
    sc_properties = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull", 
        "formation_energy_per_atom",
        "density",
        "symmetry",
        "structure"
    ]
    
    sc_metrics = ["band_gap", "energy_above_hull"]
    sc_weights = [0.6, 0.4]
    sc_higher_is_better = [False, False]  # Lower band gap and energy above hull
    
    sc_df = custom_application(
        "Superconductor", 
        sc_criteria, 
        sc_properties, 
        sc_metrics, 
        sc_weights, 
        sc_higher_is_better
    )
    """

    print("\nAll material explorations completed successfully!")

# Entry point: run all explorations when this file is executed directly.
if __name__ == "__main__":
    main()

# Assignment Tasks:
# 1. Obtain your Materials Project API key and set API_KEY near the top of this file (never share or commit a real key)
# 2. Run the code to explore materials for different applications
# 3. Choose one of the applications and modify the criteria to find better candidates:
#    - Adjust the property filters (e.g., band gap range)
#    - Change the performance metrics or their weights
# 4. Add a new application of your choice:
#    - Define appropriate criteria for the application
#    - Choose relevant properties to query
#    - Set up appropriate performance metrics and weights
# 5. Add a new visualization to better analyze the materials
# 6. Write a brief report discussing your findings:
#    - Compare the top materials for your chosen application
#    - Discuss structure-property relationships
#    - Explain why your performance metrics are appropriate for the application