In [None]:
!pip install pymatgen mp-api megnet tensorflow scikit-learn pandas numpy matplotlib seaborn

Collecting pymatgen
  Downloading pymatgen-2025.2.18-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting mp-api
  Downloading mp_api-0.45.3-py3-none-any.whl.metadata (2.3 kB)
Collecting megnet
  Downloading megnet-1.3.2-py3-none-any.whl.metadata (23 kB)
Collecting monty>=2025.1.9 (from pymatgen)
  Downloading monty-2025.1.9-py3-none-any.whl.metadata (3.6 kB)
Collecting palettable>=3.3.3 (from pymatgen)
  Downloading palettable-3.3.3-py2.py3-none-any.whl.metadata (3.3 kB)
Collecting pybtex>=0.24.0 (from pymatgen)
  Downloading pybtex-0.24.0-py2.py3-none-any.whl.metadata (2.0 kB)
Collecting ruamel.yaml>=0.17.0 (from pymatgen)
  Downloading ruamel.yaml-0.18.10-py3-none-any.whl.metadata (23 kB)
Collecting spglib>=2.5 (from pymatgen)
  Downloading spglib-2.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting uncertainties>=3.1.4 (from pymatgen)
  Downloading uncertainties-3.2.2-py3-none-any.whl.metadata (6.9 kB)
Colle

In [None]:
# Import necessary libraries
import os
import pickle
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
import warnings
warnings.filterwarnings('ignore')

# Pymatgen imports
from pymatgen.core import Structure
from mp_api.client import MPRester

# MEGNet is optional: record in MEGNET_AVAILABLE whether it could be imported
# so later cells can branch on it instead of failing at import time.
try:
    from megnet.models import MEGNetModel
    from megnet.utils.models import load_model
    from megnet.utils.descriptor import MEGNetDescriptor
    MEGNET_AVAILABLE = True
except ImportError:
    print("MEGNet not available, will use fallback feature extraction")
    MEGNET_AVAILABLE = False

# Seed NumPy's and TensorFlow's global random generators for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print("Libraries imported successfully!")

# Materials Project API key, read from the API_KEY environment variable
# (os.getenv returns None when the variable is not set)
MP_API_KEY = os.getenv("API_KEY")

ModuleNotFoundError: No module named 'pymatgen'

In [None]:
def _collect_search_results(mpr, elements, fields, quota, seen_ids):
    """
    Run one Materials Project summary search and pick up to ``quota`` new materials.

    Args:
        mpr: Open MPRester client
        elements (list): Element symbols passed to the search
        fields (list): Document fields to request
        quota (int): Maximum number of materials to take from this search
        seen_ids (set): Material ids already collected; updated in place

    Returns:
        list: Material dictionaries (empty if the search failed)
    """
    label = ', '.join(elements)
    picked = []
    picked_ids = set()

    try:
        print(f"Searching for materials with {label}...")

        # Search for stable materials with these elements
        docs = mpr.materials.summary.search(
            elements=elements,
            energy_above_hull=(0, 0.1),  # Relatively stable materials
            fields=fields
        )

        # NOTE: list() materialises every matching document even though only
        # the first few are kept.
        element_materials = list(docs)

        for mat in element_materials:
            if len(picked) >= quota:
                break

            # Overlapping searches (e.g. "Fe, O" and "Li, Fe, O") can return
            # the same entry; keep each material only once.
            key = str(mat.material_id)
            if key in seen_ids or key in picked_ids:
                continue

            picked_ids.add(key)
            picked.append({
                "material_id": mat.material_id,
                "structure": mat.structure,
                "band_gap": mat.band_gap,
                "formation_energy_per_atom": mat.formation_energy_per_atom,
                "energy_above_hull": mat.energy_above_hull,
                "elements": mat.elements
            })

        print(f"Found {len(element_materials)} materials with {label}, adding {len(picked)}")

    except Exception as e:
        # Best effort: a failed search is reported and skipped so the remaining
        # searches can still run.
        print(f"Error searching for {label} materials: {e}")
        return []

    seen_ids.update(picked_ids)
    return picked


def collect_materials_data(api_key, num_materials=100):
    """
    Collect materials data from the Materials Project API with improved diversity.

    Searches a fixed list of binary/ternary element combinations first, then a
    list of single-element queries with half the per-search quota. Each material
    id is collected at most once and the total never exceeds ``num_materials``.

    Args:
        api_key (str): Materials Project API key
        num_materials (int): Maximum number of materials to retrieve

    Returns:
        list: List of material dictionaries with the keys "material_id",
            "structure", "band_gap", "formation_energy_per_atom",
            "energy_above_hull" and "elements"
    """
    # Define fields to retrieve
    fields = [
        "material_id",
        "structure",
        "elements",
        "band_gap",
        "formation_energy_per_atom",
        "energy_above_hull"
    ]

    # Binary and ternary combinations for more diverse compounds
    element_combinations = [
        ["Fe", "O"],       # Iron oxide
        ["Al", "O"],       # Aluminum oxide
        ["Ti", "O"],       # Titanium oxide
        ["Ni", "O"],       # Nickel oxide
        ["Cu", "O"],       # Copper oxide
        ["Si", "O"],       # Silicon oxide
        ["Li", "O"],       # Lithium oxide
        ["Fe", "Si"],      # Iron silicide
        ["Al", "Ni"],      # Aluminum nickel
        ["Ti", "Si", "O"], # Titanium silicate
        ["Fe", "Ti", "O"], # Iron titanate
        ["Li", "Fe", "O"], # Lithium ferrite
        ["Li", "Ni", "O"], # Lithium nickelate
        ["Fe", "Al", "O"]  # Iron aluminate
    ]

    # Single-element queries (for comparison). The hit counts printed by the
    # search (thousands per element) suggest these match any material that
    # contains the element, not only elemental phases.
    single_elements = ["Fe", "Al", "Ti", "Ni", "Cu", "Si", "O", "Li"]

    # Per-search quota: at least 2, otherwise an even share of the budget
    materials_per_combo = max(
        2, num_materials // (len(element_combinations) + len(single_elements))
    )

    # Compound searches first (full quota), then single elements (half quota)
    searches = [(combo, materials_per_combo) for combo in element_combinations]
    searches += [([element], materials_per_combo // 2) for element in single_elements]

    all_materials = []
    seen_ids = set()

    # Connect to the Materials Project API
    with MPRester(api_key) as mpr:
        for elements, quota in searches:
            # Enforce num_materials as a hard upper bound
            remaining = num_materials - len(all_materials)
            if remaining <= 0:
                break

            all_materials.extend(
                _collect_search_results(mpr, elements, fields, min(quota, remaining), seen_ids)
            )

    print(f"Retrieved a total of {len(all_materials)} materials")
    return all_materials

# Example usage: query the Materials Project with the key from the environment
# and keep the result in `materials`
materials = collect_materials_data(MP_API_KEY, num_materials=100)

Searching for materials with Fe, O...


Retrieving SummaryDoc documents:   0%|          | 0/5591 [00:00<?, ?it/s]

Found 5591 materials with Fe, O, adding 4
Searching for materials with Al, O...


Retrieving SummaryDoc documents:   0%|          | 0/2043 [00:00<?, ?it/s]

Found 2043 materials with Al, O, adding 4
Searching for materials with Ti, O...


Retrieving SummaryDoc documents:   0%|          | 0/3632 [00:00<?, ?it/s]

Found 3632 materials with Ti, O, adding 4
Searching for materials with Ni, O...


Retrieving SummaryDoc documents:   0%|          | 0/2635 [00:00<?, ?it/s]

Found 2635 materials with Ni, O, adding 4
Searching for materials with Cu, O...


Retrieving SummaryDoc documents:   0%|          | 0/3125 [00:00<?, ?it/s]

Found 3125 materials with Cu, O, adding 4
Searching for materials with Si, O...


Retrieving SummaryDoc documents:   0%|          | 0/5206 [00:00<?, ?it/s]

Found 5206 materials with Si, O, adding 4
Searching for materials with Li, O...


Retrieving SummaryDoc documents:   0%|          | 0/13693 [00:00<?, ?it/s]

Found 13693 materials with Li, O, adding 4
Searching for materials with Fe, Si...


Retrieving SummaryDoc documents:   0%|          | 0/694 [00:00<?, ?it/s]

Found 694 materials with Fe, Si, adding 4
Searching for materials with Al, Ni...


Retrieving SummaryDoc documents:   0%|          | 0/319 [00:00<?, ?it/s]

Found 319 materials with Al, Ni, adding 4
Searching for materials with Ti, Si, O...


Retrieving SummaryDoc documents:   0%|          | 0/395 [00:00<?, ?it/s]

Found 395 materials with Ti, Si, O, adding 4
Searching for materials with Fe, Ti, O...


Retrieving SummaryDoc documents:   0%|          | 0/305 [00:00<?, ?it/s]

Found 305 materials with Fe, Ti, O, adding 4
Searching for materials with Li, Fe, O...


Retrieving SummaryDoc documents:   0%|          | 0/2453 [00:00<?, ?it/s]

Found 2453 materials with Li, Fe, O, adding 4
Searching for materials with Li, Ni, O...


Retrieving SummaryDoc documents:   0%|          | 0/1195 [00:00<?, ?it/s]

Found 1195 materials with Li, Ni, O, adding 4
Searching for materials with Fe, Al, O...


Retrieving SummaryDoc documents:   0%|          | 0/115 [00:00<?, ?it/s]

Found 115 materials with Fe, Al, O, adding 4
Searching for materials with Fe...


Retrieving SummaryDoc documents:   0%|          | 0/8286 [00:00<?, ?it/s]

Found 8286 materials with Fe, adding 2
Searching for materials with Al...


Retrieving SummaryDoc documents:   0%|          | 0/5440 [00:00<?, ?it/s]

Found 5440 materials with Al, adding 2
Searching for materials with Ti...


Retrieving SummaryDoc documents:   0%|          | 0/5134 [00:00<?, ?it/s]

Found 5134 materials with Ti, adding 2
Searching for materials with Ni...


Retrieving SummaryDoc documents:   0%|          | 0/6102 [00:00<?, ?it/s]

Found 6102 materials with Ni, adding 2
Searching for materials with Cu...


Retrieving SummaryDoc documents:   0%|          | 0/6670 [00:00<?, ?it/s]

Found 6670 materials with Cu, adding 2
Searching for materials with Si...


Retrieving SummaryDoc documents:   0%|          | 0/8605 [00:00<?, ?it/s]

Found 8605 materials with Si, adding 2
Searching for materials with O...


Retrieving SummaryDoc documents:   0%|          | 0/52688 [00:00<?, ?it/s]

Found 52688 materials with O, adding 2
Searching for materials with Li...


Retrieving SummaryDoc documents:   0%|          | 0/17089 [00:00<?, ?it/s]

Found 17089 materials with Li, adding 2
Retrieved a total of 72 materials


In [None]:
def load_megnet_models():
    """
    Return the names of the pre-trained MEGNet models used for feature extraction.

    Only the model names are returned, not loaded model objects, to avoid
    serialization issues.

    Returns:
        dict: Mapping of property key to MEGNet model name, or None if MEGNet
            is not available
    """
    if not MEGNET_AVAILABLE:
        return None

    # Named `model_names` so the imported `tensorflow.keras.models` module is
    # not shadowed inside this function.
    model_names = {
        'formation_energy': 'Eform_MP_2019',
        'band_gap': 'Egap_MP_2019',
        'bulk_modulus': 'logK_MP_2019'
    }
    print("Successfully prepared MEGNet model names")
    return model_names

# Descriptors keyed by MEGNet model name, so each pre-trained model is loaded
# at most once per kernel session instead of once per structure.
_MEGNET_DESCRIPTOR_CACHE = {}


def _get_megnet_descriptor(model_name):
    """Return the MEGNetDescriptor for ``model_name``, loading the model on first use."""
    descriptor = _MEGNET_DESCRIPTOR_CACHE.get(model_name)
    if descriptor is None:
        descriptor = MEGNetDescriptor(load_model(model_name))
        # Only successful loads are cached; a failure is retried on the next call.
        _MEGNET_DESCRIPTOR_CACHE[model_name] = descriptor
    return descriptor


def extract_megnet_features(structure, model_names):
    """
    Extract features using pre-trained MEGNet models.

    Each model named in ``model_names`` contributes one feature array; the
    arrays of all models that succeeded are concatenated in dictionary order.
    A model that fails to load or to featurize the structure is reported and
    skipped.

    Args:
        structure (pymatgen.Structure): Crystal structure
        model_names (dict): Dictionary of model names (values are passed to
            megnet's ``load_model``), or None

    Returns:
        np.ndarray: Concatenated feature vector, or None if ``model_names`` is
            None or no model produced features
    """
    if model_names is None:
        return None

    try:
        all_features = []

        for model_name in model_names.values():
            try:
                descriptor = _get_megnet_descriptor(model_name)

                # Get features directly from the structure
                feature = descriptor.get_structure_features(structure)

                if feature is not None:
                    all_features.append(feature)
            except Exception as e:
                print(f"MEGNet feature extraction failed for {model_name}: {e}")
                # Continue with next model instead of failing entirely
                continue

        if not all_features:
            # If all models failed, return None to trigger fallback
            print("All MEGNet models failed, using fallback features")
            return None

        return np.concatenate(all_features)

    except Exception as e:
        # Covers failures outside the per-model handler, e.g. concatenating
        # feature arrays of incompatible shapes.
        print(f"MEGNet feature extraction failed: {e}")
        return None

def extract_fallback_features(structure):
    """
    Extract simple composition and structural features when MEGNet fails.

    The returned vector has 113 entries, laid out as:
        [0:8]    a, b, c, alpha, beta, gamma, volume, density
        [8:11]   amount-weighted mean, minimum and maximum atomic number
        [11:13]  number of sites, number of distinct elements
        [13:113] per-element amounts indexed by atomic number Z (slot 13 + Z);
                 elements with Z >= 100 are not represented

    Args:
        structure (pymatgen.Structure): Crystal structure

    Returns:
        np.ndarray: Feature vector
    """
    # Lattice geometry
    lattice = structure.lattice
    a, b, c = lattice.abc
    alpha, beta, gamma = lattice.angles

    # Composition: atomic number and amount of every element present
    composition = structure.composition
    species = list(composition.elements)
    atomic_nums = [el.Z for el in species]
    amounts = [composition[el] for el in species]

    # Amount-weighted mean atomic number
    avg_atomic_num = sum(z * w for z, w in zip(atomic_nums, amounts)) / sum(amounts)

    # Fixed-size composition vector indexed by atomic number; zero for elements
    # that are not present
    element_vec = np.zeros(100)
    for z, amount in zip(atomic_nums, amounts):
        if z < 100:
            element_vec[z] = amount

    # Combine all features
    return np.concatenate([
        [a, b, c, alpha, beta, gamma, structure.volume, structure.density],
        [avg_atomic_num, min(atomic_nums), max(atomic_nums)],
        [len(structure), len(species)],
        element_vec
    ])

def extract_features(materials_data):
    """
    Extract features for a list of materials using only fallback features.
    This version skips MEGNet completely to avoid serialization errors.

    Property rows are ``[band_gap, formation_energy_per_atom, 0]``; the third
    column is a placeholder for bulk modulus. A property that is missing or
    None is stored as 0.0 so the property matrix is always numeric.

    Args:
        materials_data (list): List of material dictionaries; each must have a
            "structure" entry

    Returns:
        tuple: (feature_matrix, property_matrix, materials_with_features)
    """
    features = []
    properties = []
    materials_with_features = []

    total = len(materials_data)
    for i, material in enumerate(materials_data):
        if i % 10 == 0:
            print(f"Processing material {i}/{total}")

        # Skip MEGNet extraction and use fallback directly
        feature_vector = extract_fallback_features(material["structure"])

        # Materials without a usable feature vector are dropped
        if feature_vector is None or len(feature_vector) == 0:
            continue

        # dict.get's default only covers absent keys; a key that is present
        # with value None must be replaced explicitly.
        band_gap = material.get("band_gap")
        formation_energy = material.get("formation_energy_per_atom")

        features.append(feature_vector)
        properties.append([
            0.0 if band_gap is None else band_gap,
            0.0 if formation_energy is None else formation_energy,
            0.0  # Placeholder for bulk modulus which is often missing
        ])
        materials_with_features.append(material)

    # Convert to numpy arrays (properties forced to float so downstream
    # NaN handling and scaling operate on a numeric array)
    feature_matrix = np.array(features)
    property_matrix = np.array(properties, dtype=float)

    print(f"Extracted features for {len(features)} materials")
    print(f"Feature matrix shape: {feature_matrix.shape}")
    print(f"Property matrix shape: {property_matrix.shape}")

    return feature_matrix, property_matrix, materials_with_features

# Example usage: build the feature and property matrices from `materials`
feature_matrix, property_matrix, filtered_materials = extract_features(materials)

Processing material 0/72
Processing material 10/72
Processing material 20/72
Processing material 30/72
Processing material 40/72
Processing material 50/72
Processing material 60/72
Processing material 70/72
Extracted features for 72 materials
Feature matrix shape: (72, 113)
Property matrix shape: (72, 3)


In [None]:
def preprocess_data(feature_matrix, property_matrix):
    """
    Preprocess the feature and property matrices for VAE training.

    NaNs are replaced by 0, the data is split into train/validation/test
    (15% test, then 15% of the remainder for validation, random_state=42), and
    both MinMaxScalers are fitted on the training rows only so that no
    validation/test statistics leak into the scaling. Validation and test
    values can therefore fall slightly outside [0, 1].

    Args:
        feature_matrix (np.ndarray): Matrix of material features
        property_matrix (np.ndarray): Matrix of material properties

    Returns:
        tuple: (scaled_features, scaled_properties, feature_scaler, property_scaler, train_val_test_splits)
            where scaled_features / scaled_properties are the full matrices
            transformed with the train-fitted scalers and the splits dict maps
            'train', 'val' and 'test' to (features, properties) pairs
    """
    # Handle missing values
    feature_matrix = np.nan_to_num(feature_matrix, nan=0.0)
    property_matrix = np.nan_to_num(property_matrix, nan=0.0)

    # Split the raw data first; scaling happens afterwards
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        feature_matrix, property_matrix, test_size=0.15, random_state=42
    )

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.15, random_state=42
    )

    # Fit the [0, 1] scalers on the training rows only
    feature_scaler = MinMaxScaler().fit(X_train)
    property_scaler = MinMaxScaler().fit(y_train)

    splits = {
        'train': (feature_scaler.transform(X_train), property_scaler.transform(y_train)),
        'val': (feature_scaler.transform(X_val), property_scaler.transform(y_val)),
        'test': (feature_scaler.transform(X_test), property_scaler.transform(y_test))
    }

    # Full matrices in the same scaled space, in the original row order
    scaled_features = feature_scaler.transform(feature_matrix)
    scaled_properties = property_scaler.transform(property_matrix)

    print(f"Training set: {splits['train'][0].shape[0]} samples")
    print(f"Validation set: {splits['val'][0].shape[0]} samples")
    print(f"Test set: {splits['test'][0].shape[0]} samples")

    return scaled_features, scaled_properties, feature_scaler, property_scaler, splits

# Example usage
# scaled_features, scaled_properties, feature_scaler, property_scaler, splits = preprocess_data(feature_matrix, property_matrix)

In [None]:
class KLAnnealing(tf.keras.callbacks.Callback):
    """
    Keras callback that linearly ramps a VAE's KL-loss weight during warm-up.

    At the start of every epoch the weight is written to ``vae_model.kl_weight``:
    it moves linearly from ``start`` (epoch 0) to ``end`` (epoch
    ``warmup_epochs``) and stays at ``end`` afterwards.

    NOTE(review): this only influences training if the model reads
    ``kl_weight`` at run time; confirm the model does not capture it as a
    constant when its call is traced.

    Args:
        vae_model: Model object whose ``kl_weight`` attribute is updated
        start (float): KL weight at epoch 0
        end (float): KL weight once warm-up is over
        warmup_epochs (int): Number of epochs over which the weight is ramped;
            0 or a negative value disables the ramp and uses ``end`` immediately
    """

    def __init__(self, vae_model, start=0.0, end=1.0, warmup_epochs=10):
        super().__init__()
        self.vae_model = vae_model
        self.start = start
        self.end = end
        self.warmup_epochs = warmup_epochs

    def on_epoch_begin(self, epoch, logs=None):
        """Set the KL weight for the epoch that is about to start and report it."""
        # Guard against warmup_epochs == 0, which would otherwise divide by zero
        # at epoch 0.
        if 0 < self.warmup_epochs and epoch <= self.warmup_epochs:
            weight = self.start + (self.end - self.start) * (epoch / self.warmup_epochs)
        else:
            weight = self.end

        self.vae_model.kl_weight = weight
        print(f"Epoch {epoch}: KL weight = {weight:.4f}")

In [None]:
class VAEModel(models.Model):
    """
    Custom VAE model class that properly handles the KL loss.

    The forward pass encodes ``[features, properties]`` to ``(z_mean,
    z_log_var)``, samples ``z`` with the reparameterization trick, decodes
    ``[z, properties]`` and registers ``kl_weight * KL`` via ``add_loss``.

    ``kl_weight`` is stored in a non-trainable ``tf.Variable``. A plain Python
    float read inside ``call`` would be captured as a constant when Keras
    traces the training step, so later assignments (e.g. from a KL-annealing
    callback) would have no effect. Assigning ``model.kl_weight = value``
    updates the variable in place.

    Args:
        encoder: Model mapping ``[features, properties]`` to ``(z_mean, z_log_var)``
        decoder: Model mapping ``[z, properties]`` to reconstructed features
        latent_dim (int): Dimension of the latent space
        **kwargs: Forwarded to the Keras ``Model`` constructor
    """

    def __init__(self, encoder, decoder, latent_dim, **kwargs):
        super(VAEModel, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim
        # Default weight for KL loss; a variable so updates are visible inside
        # an already-traced training step
        self._kl_weight = tf.Variable(
            1.0, trainable=False, dtype=tf.float32, name='kl_weight'
        )

    @property
    def kl_weight(self):
        """Current KL-loss weight as a Python float."""
        return float(self._kl_weight.numpy())

    @kl_weight.setter
    def kl_weight(self, value):
        # In-place update keeps the variable object that the traced graph reads
        self._kl_weight.assign(value)

    def call(self, inputs):
        """Run encode -> sample -> decode and add the weighted KL term as a loss."""
        feature_input, property_input = inputs

        # Encode
        z_mean, z_log_var = self.encoder([feature_input, property_input])

        # Sample: z = mean + std * eps with eps ~ N(0, I)
        batch_size = tf.shape(z_mean)[0]
        epsilon = tf.random.normal(shape=(batch_size, self.latent_dim))
        z = z_mean + tf.exp(0.5 * z_log_var) * epsilon

        # Decode
        reconstructed = self.decoder([z, property_input])

        # KL divergence to the standard normal prior: summed over latent
        # dimensions, averaged over the batch
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
        )
        # Read the variable (not the float property) so the graph sees updates
        self.add_loss(self._kl_weight * kl_loss)

        return reconstructed


In [None]:
class MaterialVAE:
    """
    Property-conditioned VAE over material feature vectors.

    Wraps three Keras models: an encoder ``[features, properties] ->
    (z_mean, z_log_var)``, a decoder ``[z, properties] -> features`` and the
    combined ``VAEModel`` used for training.
    """

    def __init__(self, input_dim, property_dim=3, latent_dim=16, hidden_dims=[64, 32]):
        """
        Initialize the VAE model for material generation.

        Args:
            input_dim (int): Dimension of input features
            property_dim (int): Dimension of material properties
            latent_dim (int): Dimension of latent space
            hidden_dims (list): Dimensions of hidden layers
        """
        self.input_dim = input_dim
        self.property_dim = property_dim
        self.latent_dim = latent_dim
        # Copy so the shared default list (or the caller's list) is never
        # aliased by this instance
        self.hidden_dims = list(hidden_dims)
        self.kl_weight = 1.0  # Initial KL weight handed to the VAE model

        # Build the encoder, decoder, and full VAE model
        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()
        self.vae_model = self._build_vae()

    def _build_encoder(self):
        """Build the encoder network: [features, properties] -> (z_mean, z_log_var)."""
        # Input layers
        feature_input = layers.Input(shape=(self.input_dim,), name='feature_input')
        property_input = layers.Input(shape=(self.property_dim,), name='property_input')

        # Combine inputs so the posterior is conditioned on the properties too
        x = layers.Concatenate(name='encoder_concat')([feature_input, property_input])

        # Add hidden layers
        for i, dim in enumerate(self.hidden_dims):
            x = layers.Dense(dim, activation='relu', name=f'encoder_dense_{i}')(x)

        # Output layers for mean and log variance
        z_mean = layers.Dense(self.latent_dim, name='z_mean')(x)
        z_log_var = layers.Dense(self.latent_dim, name='z_log_var')(x)

        # Create encoder model
        encoder = models.Model([feature_input, property_input], [z_mean, z_log_var], name='encoder')
        return encoder

    def _build_decoder(self):
        """Build the decoder network: [z, properties] -> reconstructed features."""
        # Input layers
        latent_input = layers.Input(shape=(self.latent_dim,), name='latent_input')
        property_input = layers.Input(shape=(self.property_dim,), name='property_input')

        # Concatenate latent vector with property conditioning
        x = layers.Concatenate()([latent_input, property_input])

        # Add hidden layers in reverse order
        for i, dim in enumerate(reversed(self.hidden_dims)):
            x = layers.Dense(dim, activation='relu', name=f'decoder_dense_{i}')(x)

        # Sigmoid output: reconstructions live in (0, 1)
        outputs = layers.Dense(self.input_dim, activation='sigmoid', name='decoder_output')(x)

        # Create decoder model
        decoder = models.Model([latent_input, property_input], outputs, name='decoder')
        return decoder

    def _build_vae(self):
        """Build and compile the VAE model using the custom VAEModel class."""
        # Create the custom VAE model
        vae = VAEModel(
            encoder=self.encoder,
            decoder=self.decoder,
            latent_dim=self.latent_dim,
            name='vae'
        )

        # Define symbolic inputs (needed to build the model)
        dummy_features = tf.keras.Input(shape=(self.input_dim,))
        dummy_properties = tf.keras.Input(shape=(self.property_dim,))

        # Build the model
        vae([dummy_features, dummy_properties])

        # Compile the model; 'mse' is the reconstruction term, the KL term is
        # added by the model itself
        vae.compile(optimizer=optimizers.Adam(learning_rate=0.001), loss='mse')

        # Set the KL weight
        vae.kl_weight = self.kl_weight

        return vae

    def train(self, splits, epochs=50, batch_size=32, use_kl_annealing=True):
        """
        Train the VAE model.

        Args:
            splits (dict): Dictionary containing training and validation data
                under the keys 'train' and 'val', each a (features, properties) pair
            epochs (int): Number of training epochs
            batch_size (int): Batch size
            use_kl_annealing (bool): Whether to use KL annealing

        Returns:
            dict: Training history
        """
        X_train, y_train = splits['train']
        X_val, y_val = splits['val']

        # Define callbacks
        # NOTE(review): with KL annealing the monitored val_loss includes a KL
        # term whose weight grows during warm-up, so early epochs may look
        # artificially good to EarlyStopping - confirm this is acceptable.
        callbacks = [
            EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
        ]

        # KL annealing schedule
        if use_kl_annealing:
            callbacks.append(KLAnnealing(self.vae_model, warmup_epochs=10))

        # Train the model: the features are both input and reconstruction target
        history = self.vae_model.fit(
            [X_train, y_train], X_train,
            validation_data=([X_val, y_val], X_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=1
        )

        return history.history

    def encode(self, features, properties):
        """Encode features to latent space (returns the posterior mean only)."""
        z_mean, z_log_var = self.encoder.predict([features, properties])
        return z_mean

    def decode(self, latent_vectors, properties):
        """Decode latent vectors to feature space."""
        return self.decoder.predict([latent_vectors, properties])

    def generate(self, properties, n_samples=1, temperature=1.2):
        """
        Generate new materials conditioned on desired properties with temperature control.

        Args:
            properties (np.ndarray): Target properties, either one vector of
                shape (property_dim,) or a matrix with one row per target
            n_samples (int): Number of samples to generate per property row
            temperature (float): Standard deviation of the latent samples;
                higher values (>1.0) increase diversity but may reduce accuracy

        Returns:
            np.ndarray: Generated feature vectors, ``n_samples`` consecutive
                rows for each property row

        Raises:
            ValueError: If ``n_samples`` is smaller than 1
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")

        # Ensure properties have the right shape
        properties = np.asarray(properties)
        if properties.ndim == 1:
            properties = properties.reshape(1, -1)

        # Repeat properties for n_samples
        if n_samples > 1:
            properties = np.repeat(properties, n_samples, axis=0)

        # One latent sample per (repeated) property row, so the decoder always
        # receives inputs with matching batch sizes
        random_latent_vectors = np.random.normal(
            0, temperature, size=(properties.shape[0], self.latent_dim)
        )

        # Decode the random latent vectors
        generated_features = self.decoder.predict([random_latent_vectors, properties])

        return generated_features

    def save(self, filepath):
        """Save the weights of the VAE model."""
        self.vae_model.save_weights(filepath)

    def load(self, filepath):
        """Load previously saved weights into the VAE model."""
        self.vae_model.load_weights(filepath)

# Example usage
# vae = MaterialVAE(input_dim=scaled_features.shape[1], property_dim=scaled_properties.shape[1])
# history = vae.train(splits, epochs=20, batch_size=16)

In [None]:
class StructureRecovery:
    """
    Map generated feature vectors back to known structures by nearest-neighbor
    lookup in the original feature matrix.
    """

    # Size limits used by recover_structures
    _MAX_CANDIDATES = 5        # candidates returned per vector (return_multiple)
    _DIVERSE_SCAN_LIMIT = 15   # neighbors scanned for distinct formulas
    _SINGLE_SCAN_LIMIT = 10    # neighbors scanned for a single diverse match
    _MEMORY_SIZE = 20          # recently selected formulas that are remembered

    def __init__(self, feature_matrix, materials, feature_scaler=None, n_neighbors=5):
        """
        Initialize the structure recovery module.

        Args:
            feature_matrix (np.ndarray): Original feature matrix
            materials (list): List of materials with structures, row-aligned
                with ``feature_matrix``
            feature_scaler (sklearn.preprocessing.MinMaxScaler): Scaler used for features
            n_neighbors (int): Number of nearest neighbors to consider
        """
        self.feature_matrix = feature_matrix
        self.materials = materials
        self.feature_scaler = feature_scaler
        self.n_neighbors = n_neighbors

        # Build nearest neighbors model
        self.nn_model = self._build_nn_model()

        # Recently selected formulas (set for membership tests) plus their
        # selection order (oldest first) so the oldest can be forgotten
        self.previously_selected = set()
        self._selection_order = []

    def _build_nn_model(self):
        """Build nearest neighbors model for structure lookup."""
        # A query cannot ask for more neighbors than there are fitted samples
        n_neighbors = max(1, min(self.n_neighbors, len(self.feature_matrix)))
        nn_model = NearestNeighbors(n_neighbors=n_neighbors)
        nn_model.fit(self.feature_matrix)
        return nn_model

    def _make_candidate(self, material_index, distance):
        """Build the candidate record for one neighbor."""
        material = self.materials[material_index]
        return {
            "structure": material["structure"],
            "distance": distance,
            "material_id": material["material_id"]
        }

    def _collect_candidates(self, neighbor_indices, neighbor_distances):
        """Return up to _MAX_CANDIDATES neighbors, distinct formulas first."""
        candidates = []
        used_positions = set()
        seen_formulas = set()
        n_found = len(neighbor_indices)

        # First pass: closest neighbor of each distinct reduced formula
        for j in range(min(self._DIVERSE_SCAN_LIMIT, n_found)):
            if len(candidates) >= self._MAX_CANDIDATES:
                break
            idx = neighbor_indices[j]
            formula = self.materials[idx]["structure"].composition.reduced_formula
            if formula in seen_formulas:
                continue
            seen_formulas.add(formula)
            used_positions.add(j)
            candidates.append(self._make_candidate(idx, neighbor_distances[j]))

        # Second pass: fill up with the closest neighbors that were skipped
        # above, never repeating a neighbor that is already a candidate
        for j in range(n_found):
            if len(candidates) >= self._MAX_CANDIDATES:
                break
            if j in used_positions:
                continue
            used_positions.add(j)
            candidates.append(self._make_candidate(neighbor_indices[j], neighbor_distances[j]))

        return candidates

    def _remember_formula(self, formula):
        """Mark ``formula`` as most recently selected and forget the oldest ones."""
        if formula in self._selection_order:
            self._selection_order.remove(formula)
        self._selection_order.append(formula)
        self.previously_selected.add(formula)

        # Limit memory of previously selected formulas
        while len(self._selection_order) > self._MEMORY_SIZE:
            oldest = self._selection_order.pop(0)
            self.previously_selected.discard(oldest)

    def _select_diverse(self, neighbor_indices, diversity_weight):
        """Pick one structure, preferring formulas that were not selected recently."""
        for j in range(min(self._SINGLE_SCAN_LIMIT, len(neighbor_indices))):
            idx = neighbor_indices[j]
            structure = self.materials[idx]["structure"]
            formula = structure.composition.reduced_formula

            # Accept a new formula always; accept a recently used one only with
            # probability (1 - diversity_weight)
            if formula not in self.previously_selected or np.random.random() > diversity_weight:
                self._remember_formula(formula)
                return structure

        # Fallback: use the closest match
        return self.materials[neighbor_indices[0]]["structure"]

    def recover_structures(self, generated_features, return_multiple=False, diversity_weight=0.7):
        """
        Recover crystal structures from generated features with improved diversity.

        Args:
            generated_features (np.ndarray): Generated feature vectors; they
                are inverse-transformed first when a feature scaler was given
            return_multiple (bool): Whether to return multiple candidate structures
            diversity_weight (float): Weight for diversity penalty (0-1); the
                probability of rejecting a recently selected formula

        Returns:
            list: One entry per generated vector - a structure, or (with
                ``return_multiple``) a list of candidate dicts with the keys
                "structure", "distance" and "material_id"
        """
        # Inverse transform if scaler is provided
        if self.feature_scaler is not None:
            generated_features = self.feature_scaler.inverse_transform(generated_features)

        # Find nearest neighbors
        distances, indices = self.nn_model.kneighbors(generated_features)

        recovered_structures = []
        for i in range(len(generated_features)):
            if return_multiple:
                recovered_structures.append(self._collect_candidates(indices[i], distances[i]))
            else:
                recovered_structures.append(self._select_diverse(indices[i], diversity_weight))

        return recovered_structures

In [None]:
def plot_training_history(history):
    """
    Plot the training history of the VAE model.

    The left panel shows the training loss and, when it was recorded, the
    validation loss. If the history contains a learning-rate series, it is
    drawn on a log scale in the right panel.

    Args:
        history (dict): Training history mapping metric names to per-epoch
            values. Must contain 'loss'; 'val_loss' and a learning-rate series
            ('lr' or 'learning_rate') are optional. An object exposing such a
            dict as ``.history`` (e.g. a Keras History) is accepted as well.

    Raises:
        KeyError: If the history has no 'loss' entry.
    """
    # Unwrap History-like objects; a plain dict has no `.history` attribute
    # and is therefore used unchanged.
    history = getattr(history, 'history', history)

    # The learning rate is logged as 'lr' by tf.keras 2.x callbacks and as
    # 'learning_rate' by Keras 3, so accept whichever key is present.
    lr_key = next((key for key in ('lr', 'learning_rate') if key in history), None)

    plt.figure(figsize=(12, 5))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(history['loss'], label='Training loss')
    # Validation loss is absent when training ran without validation data.
    if 'val_loss' in history:
        plt.plot(history['val_loss'], label='Validation loss')
    plt.title('VAE Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot learning rate if available
    if lr_key is not None:
        plt.subplot(1, 2, 2)
        plt.plot(history[lr_key])
        plt.title('Learning Rate')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.yscale('log')

    plt.tight_layout()
    plt.show()

def plot_latent_space(encoder, X, y, property_names=None):
    """
    Visualize the latent space of the VAE.

    Encodes the inputs, projects the latent means to 2-D with t-SNE when the
    latent space has more than two dimensions, and draws one scatter plot per
    property (at most three), colored by that property's values.

    Args:
        encoder (tf.keras.Model): Encoder model; ``encoder.predict([X, y])``
            must return a pair whose first element is the latent mean.
        X (np.ndarray): Input features
        y (np.ndarray): Target properties, one column per property
        property_names (list): Names of the properties. Columns without a
            name are labelled 'Property <n>'.
    """
    # Get latent space representation
    z_mean, _ = encoder.predict([X, y])

    # Use t-SNE for dimensionality reduction if latent dim > 2
    used_tsne = z_mean.shape[1] > 2
    if used_tsne:
        from sklearn.manifold import TSNE
        # scikit-learn rejects a perplexity that is not below the number of
        # samples, so the default of 30 fails on small test sets; clamp it.
        n_points = z_mean.shape[0]
        perplexity = min(30.0, float(max(n_points - 1, 1)))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        z_plot = reducer.fit_transform(z_mean)
    else:
        z_plot = z_mean

    # Set default property names if not provided
    if property_names is None:
        property_names = ['Band Gap', 'Formation Energy', 'Bulk Modulus']

    # After t-SNE the axes are embedding components, not latent dimensions.
    axis_name = 't-SNE Component' if used_tsne else 'Latent Dimension'

    # Create scatter plots colored by each property
    plt.figure(figsize=(15, 5))
    for i in range(min(3, y.shape[1])):
        # Fall back to a generic label instead of raising IndexError when
        # fewer names than property columns were supplied.
        label = property_names[i] if i < len(property_names) else f'Property {i + 1}'

        plt.subplot(1, 3, i+1)
        scatter = plt.scatter(z_plot[:, 0], z_plot[:, 1], c=y[:, i], cmap='viridis', alpha=0.8)
        plt.colorbar(scatter, label=label)
        plt.title(f'Latent Space colored by {label}')
        plt.xlabel(f'{axis_name} 1')
        plt.ylabel(f'{axis_name} 2')

    plt.tight_layout()
    plt.show()

# Example usage
# plot_training_history(history)
# X_test, y_test = splits['test']
# plot_latent_space(vae.encoder, X_test, y_test)

In [None]:
def generate_materials_with_constraints(vae, property_scaler, recovery, n_samples=10, constraints=None,
                                        temperature=1.2, features_per_target=3, diversity_weight=0.7):
    """
    Generate materials with specific property constraints and improved diversity.

    For each of ``n_samples`` targets drawn uniformly inside the constraint
    box, the VAE generates several feature vectors, the recovery module maps
    them to candidate structures, and the first candidate whose reduced
    formula has not been selected yet is kept (falling back to the first
    candidate when every formula was already used).

    Args:
        vae (MaterialVAE): Trained VAE model
        property_scaler (MinMaxScaler): Scaler for properties
        recovery (StructureRecovery): Structure recovery module
        n_samples (int): Number of samples to generate
        constraints (dict): Property constraints (min and max values) for the
            keys 'band_gap', 'formation_energy' and 'bulk_modulus'
        temperature (float): Controls diversity in sampling (>1.0 = more diverse)
        features_per_target (int): Number of feature vectors generated by the
            VAE for every target
        diversity_weight (float): Diversity weight forwarded to
            ``recovery.recover_structures``

    Returns:
        list: Generated materials with properties and structures. Each entry
        is a dict with the keys 'target_properties', 'material_id',
        'structure', 'formula' and 'distance'.

    Raises:
        IndexError: If the recovery module returns no candidates for a target.
    """
    if constraints is None:
        constraints = {
            'band_gap': {'min': 1.0, 'max': 3.0},
            'formation_energy': {'min': -2.0, 'max': -0.5},
            'bulk_modulus': {'min': 100, 'max': 200}
        }

    # Draw all targets in one call: one row per material, columns ordered as
    # (band gap, formation energy, bulk modulus).
    property_keys = ('band_gap', 'formation_energy', 'bulk_modulus')
    lower_bounds = [constraints[key]['min'] for key in property_keys]
    upper_bounds = [constraints[key]['max'] for key in property_keys]
    target_props_real = np.random.uniform(
        lower_bounds, upper_bounds, size=(n_samples, len(property_keys))
    )

    # Scale properties to normalized space
    target_props_norm = property_scaler.transform(target_props_real)

    # Generate materials for each target with increased temperature for diversity
    generated_materials = []
    # Reduced formulas already selected, kept as a set for O(1) membership tests.
    used_formulas = set()

    for i, target_norm in enumerate(target_props_norm):
        print(f"Generating material {i+1}/{n_samples}")
        print(f"Target properties: Band Gap = {target_props_real[i, 0]:.2f} eV, "
              f"Formation Energy = {target_props_real[i, 1]:.2f} eV/atom, "
              f"Bulk Modulus = {target_props_real[i, 2]:.2f} GPa")

        # Generate multiple features per target for more diversity
        gen_features = vae.generate(
            target_norm.reshape(1, -1), n_samples=features_per_target, temperature=temperature
        )

        # Recover structures with diversity improvement
        candidates_list = recovery.recover_structures(
            gen_features, return_multiple=True, diversity_weight=diversity_weight
        )

        # Flatten the per-feature candidate groups, preserving their order
        all_candidates = [candidate for group in candidates_list for candidate in group]

        if not all_candidates:
            # Same exception type the bare `all_candidates[0]` lookup raised,
            # but with a message that says what actually went wrong.
            raise IndexError(
                f"No candidate structures were recovered for target {i+1}/{n_samples}"
            )

        # Prefer the first candidate with an unused formula; if all formulas
        # have been used, fall back to the first candidate.
        selected_candidate = next(
            (
                candidate for candidate in all_candidates
                if candidate['structure'].composition.reduced_formula not in used_formulas
            ),
            all_candidates[0],
        )
        selected_formula = selected_candidate['structure'].composition.reduced_formula
        used_formulas.add(selected_formula)

        # Store information
        generated_materials.append({
            'target_properties': target_props_real[i],
            'material_id': selected_candidate['material_id'],
            'structure': selected_candidate['structure'],
            'formula': selected_formula,
            'distance': selected_candidate['distance']
        })

        print(f"Generated material: {selected_formula}")
        print("-" * 50)

    return generated_materials

def evaluate_reconstruction(vae, splits):
    """
    Measure how well the VAE reproduces the held-out feature vectors.

    Args:
        vae (MaterialVAE): VAE model; its ``vae_model`` is used for prediction
        splits (dict): Data splits; ``splits['test']`` is a (features, properties) pair

    Returns:
        dict: Reconstruction metrics with the keys
            'mse' / 'mae'     - errors averaged over every entry,
            'feature_mse'     - mean squared error of each feature column,
            'worst_features'  - up to five column indices, highest error first,
            'best_features'   - up to five column indices, lowest error first.
    """
    features, properties = splits['test']

    # Run the test set through the full encoder/decoder stack
    reconstructed = vae.vae_model.predict([features, properties])

    # Element-wise error terms shared by all metrics below
    residual = features - reconstructed
    squared_error = np.square(residual)

    # Per-column error and the column ranking derived from it (ascending)
    per_feature_mse = np.mean(squared_error, axis=0)
    ranking = np.argsort(per_feature_mse)

    metrics = {}
    metrics['mse'] = np.mean(squared_error)
    metrics['mae'] = np.mean(np.abs(residual))
    metrics['worst_features'] = ranking[-5:][::-1]
    metrics['best_features'] = ranking[:5]
    metrics['feature_mse'] = per_feature_mse
    return metrics

def evaluate_diversity(materials):
    """
    Evaluate the diversity of generated materials.

    Composition diversity is the fraction of distinct reduced formulas.
    Structural diversity is the mean Euclidean distance from each structure
    to its nearest neighbour in the six-dimensional space of lattice
    parameters (a, b, c, alpha, beta, gamma).

    Args:
        materials (list): List of generated materials; each entry is a dict
            whose 'structure' exposes ``composition.reduced_formula``,
            ``lattice.abc`` and ``lattice.angles``

    Returns:
        dict: Diversity metrics with the keys 'num_structures',
        'num_unique_formulas', 'formula_diversity' and
        'avg_nearest_distance'. The last two are 0 for an empty input, and
        'avg_nearest_distance' is also 0 when there is a single structure
        (no neighbour to compare against).
    """
    # Extract structures
    structures = [mat['structure'] for mat in materials]
    n_structures = len(structures)

    # Calculate composition diversity
    unique_formulas = {s.composition.reduced_formula for s in structures}

    # Calculate structural diversity using lattice parameters. A pairwise
    # distance needs at least two structures; with fewer, pdist would either
    # raise (empty input) or produce nothing to average (single structure).
    avg_nearest_distance = 0
    if n_structures >= 2:
        from scipy.spatial.distance import pdist, squareform

        # NOTE(review): lengths and angles are combined in one unscaled vector,
        # so whichever has the larger numeric spread dominates the distance.
        lattice_params = np.array(
            [list(s.lattice.abc) + list(s.lattice.angles) for s in structures],
            dtype=float,
        )
        lattice_distances = squareform(pdist(lattice_params, metric='euclidean'))

        # Exclude self-distances (the zero diagonal) from the nearest-neighbour search
        np.fill_diagonal(lattice_distances, np.inf)
        avg_nearest_distance = float(lattice_distances.min(axis=1).mean())

    return {
        'num_structures': n_structures,
        'num_unique_formulas': len(unique_formulas),
        'formula_diversity': len(unique_formulas) / n_structures if n_structures else 0,
        'avg_nearest_distance': avg_nearest_distance
    }

# Example usage
# semiconductor_constraints = {
#     'band_gap': {'min': 0.5, 'max': 2.5},       # Semiconducting range
#     'formation_energy': {'min': -2.0, 'max': -0.1},  # Relatively stable
#     'bulk_modulus': {'min': 50, 'max': 200}     # Moderate hardness
# }
#
# semiconductor_materials = generate_materials_with_constraints(
#     vae,
#     property_scaler,
#     recovery,
#     n_samples=5,
#     constraints=semiconductor_constraints
# )
#
# diversity_metrics = evaluate_diversity(semiconductor_materials)

In [None]:
def save_model_complete(save_dir, vae, feature_scaler, property_scaler,
                         filtered_materials=None, recovery=None):
    """
    Persist every artefact needed to reuse the material generation model.

    Files written into ``save_dir`` (created if missing):
        vae.weights.h5       - written by ``vae.save``
        vae_config.json      - input/property/latent dimensions and hidden layer sizes
        feature_scaler.pkl   - pickled feature scaler
        property_scaler.pkl  - pickled property scaler
        materials_data.pkl   - only when ``filtered_materials`` is given
        feature_matrix.npy   - only when ``recovery`` is given

    Args:
        save_dir (str): Directory to save the model
        vae (MaterialVAE): Trained VAE model
        feature_scaler (MinMaxScaler): Feature scaler
        property_scaler (MinMaxScaler): Property scaler
        filtered_materials (list, optional): Materials data
        recovery (StructureRecovery, optional): Structure recovery module
    """
    os.makedirs(save_dir, exist_ok=True)

    def _target(filename):
        # Absolute/relative location of one artefact inside the save directory
        return os.path.join(save_dir, filename)

    # Model weights; the .weights.h5 suffix is the one Keras expects
    vae.save(_target("vae.weights.h5"))

    # Architecture needed to rebuild the network before loading the weights
    architecture = {
        attr: getattr(vae, attr)
        for attr in ('input_dim', 'property_dim', 'latent_dim', 'hidden_dims')
    }
    with open(_target("vae_config.json"), 'w') as config_file:
        json.dump(architecture, config_file)

    # Both scalers are stored the same way
    scaler_files = (
        ("feature_scaler.pkl", feature_scaler),
        ("property_scaler.pkl", property_scaler),
    )
    for filename, scaler in scaler_files:
        with open(_target(filename), 'wb') as scaler_file:
            pickle.dump(scaler, scaler_file)

    # Materials are reduced to a handful of fields to keep the file small;
    # structures are serialised through their as_dict() representation
    if filtered_materials is not None:
        slim_records = [
            {
                'material_id': entry['material_id'],
                'structure': entry['structure'].as_dict(),
                'band_gap': entry.get('band_gap', 0),
                'formation_energy_per_atom': entry.get('formation_energy_per_atom', 0),
                'energy_above_hull': entry.get('energy_above_hull', 0)
            }
            for entry in filtered_materials
        ]
        with open(_target("materials_data.pkl"), 'wb') as materials_file:
            pickle.dump(slim_records, materials_file)

    # Feature matrix backing the structure recovery lookup
    if recovery is not None:
        np.save(_target("feature_matrix.npy"), recovery.feature_matrix)

    print(f"Model successfully saved to {save_dir}")


def load_model_complete(load_dir):
    """
    Restore the components written by ``save_model_complete``.

    The VAE is rebuilt from ``vae_config.json`` and its weights are loaded
    from ``vae.weights.h5``. Materials and the structure recovery module are
    optional: materials are restored only if ``materials_data.pkl`` exists,
    and the recovery module only if both the materials and
    ``feature_matrix.npy`` are present.

    NOTE: the scalers and materials are read with ``pickle``, which can run
    arbitrary code while loading; only point this at directories you trust.

    Args:
        load_dir (str): Directory containing the saved model

    Returns:
        dict: Loaded components under the keys 'vae', 'feature_scaler',
        'property_scaler', 'filtered_materials' (list or None) and
        'recovery' (StructureRecovery or None)
    """
    def _source(filename):
        # Location of one artefact inside the model directory
        return os.path.join(load_dir, filename)

    def _unpickle(filename):
        with open(_source(filename), 'rb') as handle:
            return pickle.load(handle)

    def _rehydrate(record):
        # Turn the stored structure dictionary back into a Structure object
        rebuilt = Structure.from_dict(record['structure'])
        return {
            'material_id': record['material_id'],
            'structure': rebuilt,
            'band_gap': record['band_gap'],
            'formation_energy_per_atom': record['formation_energy_per_atom'],
            'energy_above_hull': record['energy_above_hull']
        }

    # Architecture first: the network must exist before weights can be loaded
    with open(_source("vae_config.json"), 'r') as config_file:
        architecture = json.load(config_file)

    vae = MaterialVAE(
        input_dim=architecture['input_dim'],
        property_dim=architecture['property_dim'],
        latent_dim=architecture['latent_dim'],
        hidden_dims=architecture['hidden_dims']
    )
    # Same .weights.h5 file name that the save routine writes
    vae.load(_source("vae.weights.h5"))

    feature_scaler = _unpickle("feature_scaler.pkl")
    property_scaler = _unpickle("property_scaler.pkl")

    # Optional materials data
    filtered_materials = None
    if os.path.exists(_source("materials_data.pkl")):
        filtered_materials = [_rehydrate(record) for record in _unpickle("materials_data.pkl")]

    # Optional recovery module; it needs the materials as well as the matrix
    recovery = None
    if os.path.exists(_source("feature_matrix.npy")):
        feature_matrix = np.load(_source("feature_matrix.npy"))
        if filtered_materials is not None:
            recovery = StructureRecovery(feature_matrix, filtered_materials, feature_scaler)

    print(f"Model successfully loaded from {load_dir}")

    return {
        'vae': vae,
        'feature_scaler': feature_scaler,
        'property_scaler': property_scaler,
        'filtered_materials': filtered_materials,
        'recovery': recovery
    }

In [None]:
def run_materials_generation_pipeline(api_key, num_materials=100, latent_dim=16, save_model=True,
                                      model_dir="trained_material_vae", epochs=20, batch_size=16,
                                      hidden_dims=None, n_samples=5):
    """
    Run the complete materials generation pipeline with model saving option.

    Steps: collect materials, extract features, preprocess, train the VAE,
    evaluate reconstruction, build the structure recovery module, optionally
    save everything, generate semiconductor-like materials and evaluate
    their diversity.

    Args:
        api_key (str): Materials Project API key
        num_materials (int): Number of materials to retrieve
        latent_dim (int): Dimension of the latent space
        save_model (bool): Whether to save the model after training
        model_dir (str): Directory to save the model to
        epochs (int): Number of training epochs
        batch_size (int): Training batch size
        hidden_dims (list, optional): Hidden layer sizes of the VAE;
            defaults to [64, 32]
        n_samples (int): Number of materials to generate after training

    Returns:
        dict: Results of the pipeline with the keys 'materials',
        'filtered_materials', 'feature_matrix', 'property_matrix',
        'feature_scaler', 'property_scaler', 'splits', 'vae', 'recovery',
        'history', 'reconstruction_metrics', 'generated_materials' and
        'diversity_metrics'.
    """
    # Resolve the default here rather than in the signature to avoid a shared
    # mutable default argument.
    hidden_dims = [64, 32] if hidden_dims is None else list(hidden_dims)

    print("1. Collecting materials data...")
    materials = collect_materials_data(api_key, num_materials=num_materials)

    print("\n2. Extracting features...")
    feature_matrix, property_matrix, filtered_materials = extract_features(materials)

    print("\n3. Preprocessing data...")
    scaled_features, scaled_properties, feature_scaler, property_scaler, splits = preprocess_data(
        feature_matrix, property_matrix
    )

    print("\n4. Building and training VAE model...")
    input_dim = scaled_features.shape[1]
    property_dim = scaled_properties.shape[1]

    vae = MaterialVAE(
        input_dim=input_dim,
        property_dim=property_dim,
        latent_dim=latent_dim,
        hidden_dims=hidden_dims
    )

    history = vae.train(splits, epochs=epochs, batch_size=batch_size, use_kl_annealing=True)

    print("\n5. Evaluating model reconstruction...")
    reconstruction_metrics = evaluate_reconstruction(vae, splits)
    print(f"Reconstruction MSE: {reconstruction_metrics['mse']:.6f}")
    print(f"Reconstruction MAE: {reconstruction_metrics['mae']:.6f}")

    print("\n6. Creating structure recovery module...")
    # The recovery module receives the unscaled feature matrix together with
    # the fitted scaler (the same inputs load_model_complete rebuilds it from).
    recovery = StructureRecovery(feature_matrix, filtered_materials, feature_scaler)

    # Save the model if requested
    if save_model:
        print("\n-- Saving the trained model...")
        save_model_complete(
            model_dir,
            vae,
            feature_scaler,
            property_scaler,
            filtered_materials,
            recovery
        )
        print(f"Model saved to {model_dir}")

    print("\n7. Generating materials with target properties...")
    # Define constraints for semiconductor-like materials
    semiconductor_constraints = {
        'band_gap': {'min': 0.5, 'max': 2.5},       # Semiconducting range
        'formation_energy': {'min': -2.0, 'max': -0.1},  # Relatively stable
        'bulk_modulus': {'min': 50, 'max': 200}     # Moderate hardness
    }

    # Generate materials with semiconductor-like properties
    semiconductor_materials = generate_materials_with_constraints(
        vae,
        property_scaler,
        recovery,
        n_samples=n_samples,
        constraints=semiconductor_constraints
    )

    print("\n8. Evaluating diversity of generated materials...")
    diversity_metrics = evaluate_diversity(semiconductor_materials)

    print("\n9. Printing summary of generated materials...")
    print("-" * 60)
    print(f"{'Formula':<20} {'Band Gap (eV)':<15} {'Formation E. (eV/atom)':<25} {'Bulk Modulus (GPa)':<20}")
    print("-" * 60)

    # The numeric columns are the sampled *target* properties of each material.
    for mat in semiconductor_materials:
        props = mat['target_properties']
        print(f"{mat['formula']:<20} {props[0]:<15.2f} {props[1]:<25.2f} {props[2]:<20.2f}")

    return {
        'materials': materials,
        'filtered_materials': filtered_materials,
        'feature_matrix': feature_matrix,
        'property_matrix': property_matrix,
        # Scalers and splits are returned so follow-up cells (latent-space
        # plots, further generation) can run without repeating the pipeline.
        'feature_scaler': feature_scaler,
        'property_scaler': property_scaler,
        'splits': splits,
        'vae': vae,
        'recovery': recovery,
        'history': history,
        'reconstruction_metrics': reconstruction_metrics,
        'generated_materials': semiconductor_materials,
        'diversity_metrics': diversity_metrics
    }

In [None]:
def generate_from_saved_model(model_dir="trained_material_vae", constraints=None, n_samples=5, temperature=1.0):
    """
    Generate materials using a previously saved model.

    Args:
        model_dir (str): Directory containing the saved model
        constraints (dict): Property constraints (min and max values) for the
            keys 'band_gap', 'formation_energy' and 'bulk_modulus'; defaults
            to a semiconductor-like range
        n_samples (int): Number of samples to generate
        temperature (float): Controls diversity in sampling (>1.0 = more diverse)

    Returns:
        list: Generated materials

    Raises:
        ValueError: If the saved model contains no structure recovery data,
            i.e. it was saved without materials or without a feature matrix.
    """
    print(f"Loading saved model from {model_dir}...")
    loaded = load_model_complete(model_dir)

    vae = loaded['vae']
    property_scaler = loaded['property_scaler']
    recovery = loaded['recovery']

    # load_model_complete returns recovery=None when materials_data.pkl or
    # feature_matrix.npy is missing. Fail here with an actionable message
    # instead of an AttributeError deep inside the generation loop.
    if recovery is None:
        raise ValueError(
            f"Saved model in '{model_dir}' has no structure recovery data: "
            "both materials_data.pkl and feature_matrix.npy are required. "
            "Save the model again with filtered_materials and recovery provided."
        )

    if constraints is None:
        constraints = {
            'band_gap': {'min': 0.5, 'max': 2.5},       # Semiconducting range
            'formation_energy': {'min': -2.0, 'max': -0.1},  # Relatively stable
            'bulk_modulus': {'min': 50, 'max': 200}     # Moderate hardness
        }

    # Ranges are printed as "min to max": a dash separator is unreadable for
    # negative bounds (it produced output such as "-4.0--1.0").
    print(f"\nGenerating {n_samples} materials with constraints:")
    print(f"Band gap: {constraints['band_gap']['min']} to {constraints['band_gap']['max']} eV")
    print(f"Formation energy: {constraints['formation_energy']['min']} to {constraints['formation_energy']['max']} eV/atom")
    print(f"Bulk modulus: {constraints['bulk_modulus']['min']} to {constraints['bulk_modulus']['max']} GPa")

    generated_materials = generate_materials_with_constraints(
        vae,
        property_scaler,
        recovery,
        n_samples=n_samples,
        constraints=constraints,
        temperature=temperature
    )

    # Print results (the numeric columns are the sampled target properties)
    print("\nGenerated materials:")
    print("-" * 60)
    print(f"{'Formula':<20} {'Band Gap (eV)':<15} {'Formation E. (eV/atom)':<25} {'Bulk Modulus (GPa)':<20}")
    print("-" * 60)

    for mat in generated_materials:
        props = mat['target_properties']
        print(f"{mat['formula']:<20} {props[0]:<15.2f} {props[1]:<25.2f} {props[2]:<20.2f}")

    return generated_materials


In [None]:
# Run the complete pipeline on up to 5000 materials and save the trained model
# to the pipeline's default directory (MP_API_KEY must already be defined).
results = run_materials_generation_pipeline(
    api_key=MP_API_KEY,
    num_materials=5000,
    latent_dim=16,
    save_model=True,
)

1. Collecting materials data...
Searching for materials with Fe, O...


Retrieving SummaryDoc documents:   0%|          | 0/5591 [00:00<?, ?it/s]

Found 5591 materials with Fe, O, adding 227
Searching for materials with Al, O...


Retrieving SummaryDoc documents:   0%|          | 0/2043 [00:00<?, ?it/s]

Found 2043 materials with Al, O, adding 227
Searching for materials with Ti, O...


Retrieving SummaryDoc documents:   0%|          | 0/3632 [00:00<?, ?it/s]

Found 3632 materials with Ti, O, adding 227
Searching for materials with Ni, O...


Retrieving SummaryDoc documents:   0%|          | 0/2635 [00:00<?, ?it/s]

Found 2635 materials with Ni, O, adding 227
Searching for materials with Cu, O...


Retrieving SummaryDoc documents:   0%|          | 0/3125 [00:00<?, ?it/s]

Found 3125 materials with Cu, O, adding 227
Searching for materials with Si, O...


Retrieving SummaryDoc documents:   0%|          | 0/5206 [00:00<?, ?it/s]

Found 5206 materials with Si, O, adding 227
Searching for materials with Li, O...


Retrieving SummaryDoc documents:   0%|          | 0/13693 [00:00<?, ?it/s]

Found 13693 materials with Li, O, adding 227
Searching for materials with Fe, Si...


Retrieving SummaryDoc documents:   0%|          | 0/694 [00:00<?, ?it/s]

Found 694 materials with Fe, Si, adding 227
Searching for materials with Al, Ni...


Retrieving SummaryDoc documents:   0%|          | 0/319 [00:00<?, ?it/s]

Found 319 materials with Al, Ni, adding 227
Searching for materials with Ti, Si, O...


Retrieving SummaryDoc documents:   0%|          | 0/395 [00:00<?, ?it/s]

Found 395 materials with Ti, Si, O, adding 227
Searching for materials with Fe, Ti, O...


Retrieving SummaryDoc documents:   0%|          | 0/305 [00:00<?, ?it/s]

Found 305 materials with Fe, Ti, O, adding 227
Searching for materials with Li, Fe, O...


Retrieving SummaryDoc documents:   0%|          | 0/2453 [00:00<?, ?it/s]

Found 2453 materials with Li, Fe, O, adding 227
Searching for materials with Li, Ni, O...


Retrieving SummaryDoc documents:   0%|          | 0/1195 [00:00<?, ?it/s]

Found 1195 materials with Li, Ni, O, adding 227
Searching for materials with Fe, Al, O...


Retrieving SummaryDoc documents:   0%|          | 0/115 [00:00<?, ?it/s]

Found 115 materials with Fe, Al, O, adding 115
Searching for materials with Fe...


Retrieving SummaryDoc documents:   0%|          | 0/8286 [00:00<?, ?it/s]

Found 8286 materials with Fe, adding 113
Searching for materials with Al...


Retrieving SummaryDoc documents:   0%|          | 0/5440 [00:00<?, ?it/s]

Found 5440 materials with Al, adding 113
Searching for materials with Ti...


Retrieving SummaryDoc documents:   0%|          | 0/5134 [00:00<?, ?it/s]

Found 5134 materials with Ti, adding 113
Searching for materials with Ni...


Retrieving SummaryDoc documents:   0%|          | 0/6102 [00:00<?, ?it/s]

Found 6102 materials with Ni, adding 113
Searching for materials with Cu...


Retrieving SummaryDoc documents:   0%|          | 0/6670 [00:00<?, ?it/s]

Found 6670 materials with Cu, adding 113
Searching for materials with Si...


Retrieving SummaryDoc documents:   0%|          | 0/8605 [00:00<?, ?it/s]

Found 8605 materials with Si, adding 113
Searching for materials with O...


Retrieving SummaryDoc documents:   0%|          | 0/52688 [00:00<?, ?it/s]

Found 52688 materials with O, adding 113
Searching for materials with Li...


Retrieving SummaryDoc documents:   0%|          | 0/17089 [00:00<?, ?it/s]

Found 17089 materials with Li, adding 113
Retrieved a total of 3970 materials

2. Extracting features...
Processing material 0/3970
Processing material 10/3970
Processing material 20/3970
Processing material 30/3970
Processing material 40/3970
Processing material 50/3970
Processing material 60/3970
Processing material 70/3970
Processing material 80/3970
Processing material 90/3970
Processing material 100/3970
Processing material 110/3970
Processing material 120/3970
Processing material 130/3970
Processing material 140/3970
Processing material 150/3970
Processing material 160/3970
Processing material 170/3970
Processing material 180/3970
Processing material 190/3970
Processing material 200/3970
Processing material 210/3970
Processing material 220/3970
Processing material 230/3970
Processing material 240/3970
Processing material 250/3970
Processing material 260/3970
Processing material 270/3970
Processing material 280/3970
Processing material 290/3970
Processing material 300/3970
Process

In [None]:
# Reload the saved model and sample ten candidates inside a property window
# chosen for battery materials.
battery_constraints = dict(
    band_gap=dict(min=1.5, max=4.0),           # Wide band gap for stability
    formation_energy=dict(min=-4.0, max=-1.0),  # Very stable materials
    bulk_modulus=dict(min=80, max=250),        # Mechanically robust
)
battery_materials = generate_from_saved_model(
    "trained_material_vae",
    battery_constraints,
    n_samples=10,
    temperature=1.2,  # Higher diversity
)

Loading saved model from trained_material_vae...
Model successfully loaded from trained_material_vae

Generating 10 materials with constraints:
Band gap: 1.5-4.0 eV
Formation energy: -4.0--1.0 eV/atom
Bulk modulus: 80-250 GPa
Generating material 1/10
Target properties: Band Gap = 2.13 eV, Formation Energy = -2.35 eV/atom, Bulk Modulus = 201.48 GPa
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 145ms/step
Generated material: CuAuO2
--------------------------------------------------
Generating material 2/10
Target properties: Band Gap = 3.15 eV, Formation Energy = -3.16 eV/atom, Bulk Modulus = 242.33 GPa
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step
Generated material: AlAgO2
--------------------------------------------------
Generating material 3/10
Target properties: Band Gap = 3.34 eV, Formation Energy = -2.34 eV/atom, Bulk Modulus = 183.99 GPa
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
Generated material: LiLuO2
--