In [None]:
# Install required packages
!pip install -q dgl torch numpy matplotlib seaborn scikit-learn pandas pymatgen plotly tqdm umap-learn jarvis-tools ase

# Mount Google Drive to access your files
from google.colab import drive
drive.mount('/content/drive')

# Import common libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import dgl
from dgl.nn import GraphConv, EdgeConv
from pymatgen.core.structure import Structure
from pymatgen.analysis.local_env import CrystalNN
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, r2_score
import warnings
warnings.filterwarnings('ignore')

# Base directory holding one sub-folder of CIF files per Heusler type
base_dir = '/content/drive/My Drive/btp_cif'
heusler_types = ['full', 'half', 'inverse', 'quaternary']

# Seed the global RNGs (model weight initialisation, any numpy sampling) so that
# a Restart & Run All reproduces the same embeddings and results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Class to handle loading and processing of CIF files
class HeuslerCIFDataset:
    """Load Heusler-compound CIF files grouped by Heusler type and summarise them.

    CIFs are read from ``base_dir/<heusler_type>/``; each structure's label is the
    index of its folder name in ``heusler_types``. Loading happens eagerly in
    ``__init__``.

    Attributes (parallel lists, one entry per successfully parsed CIF):
        structures: pymatgen ``Structure`` objects.
        labels: int index into ``heusler_types``.
        formulas: reduced formula strings.
        filenames: CIF file names (basename only).
    """

    def __init__(self, base_dir, heusler_types):
        self.base_dir = base_dir
        self.heusler_types = heusler_types
        self.structures = []
        self.labels = []
        self.formulas = []
        self.filenames = []
        self.load_structures()

    def load_structures(self):
        """Parse every ``*.cif`` file in each type folder and append it to the parallel lists.

        Missing folders are skipped with a warning. Files that fail to parse are
        reported and skipped, and nothing is appended for them, so the parallel
        lists always stay the same length.
        """
        print("Loading CIF files...")
        for label, ht in enumerate(self.heusler_types):
            folder_path = os.path.join(self.base_dir, ht)
            if not os.path.exists(folder_path):
                print(f"Warning: {folder_path} not found, skipping...")
                continue

            # os.listdir order is arbitrary; sort so that structure order (and any
            # seeded train/test split done on it later) is reproducible. Match the
            # extension case-insensitively so files named *.CIF are not ignored.
            cif_files = sorted(
                f for f in os.listdir(folder_path) if f.lower().endswith('.cif')
            )

            for cif_file in tqdm(cif_files, desc=f"Loading {ht} structures"):
                try:
                    structure = Structure.from_file(os.path.join(folder_path, cif_file))
                    # Compute everything that can fail before appending, so a failure
                    # cannot leave the parallel lists out of sync.
                    formula = structure.composition.reduced_formula
                except Exception as e:
                    print(f"Error loading {cif_file}: {e}")
                    continue
                self.structures.append(structure)
                self.labels.append(label)  # Use index as label for Heusler type
                self.formulas.append(formula)
                self.filenames.append(cif_file)

        print(f"Loaded {len(self.structures)} structures in total")

    @staticmethod
    def _finite_mean(values):
        """Mean of the truthy, finite entries of ``values``; NaN when none remain.

        Falsy entries (None, 0) are treated as unknown. NaN entries are also dropped,
        because a single NaN would otherwise make the whole mean NaN.
        """
        kept = [float(v) for v in values if v and np.isfinite(float(v))]
        return np.mean(kept) if kept else np.nan

    def get_structure_info(self):
        """Return a DataFrame with one row of basic structural descriptors per structure.

        Columns: formula, heusler_type, volume, density, num_sites, num_elements,
        avg_atomic_radius, avg_electronegativity, filename. Both averages are taken
        over sites, so repeated elements are weighted by multiplicity. Missing or NaN
        element data is ignored.
        """
        rows = []
        for i, structure in enumerate(self.structures):
            species = [site.specie for site in structure]
            rows.append({
                'formula': self.formulas[i],
                'heusler_type': self.heusler_types[self.labels[i]],
                'volume': structure.lattice.volume,
                'density': structure.density,
                'num_sites': len(structure),
                'num_elements': len({str(sp) for sp in species}),
                'avg_atomic_radius': self._finite_mean(sp.atomic_radius for sp in species),
                'avg_electronegativity': self._finite_mean(sp.X for sp in species),
                'filename': self.filenames[i],
            })
        return pd.DataFrame(rows)

# Load every CIF under base_dir/<heusler_type>/ and summarise it.
dataset = HeuslerCIFDataset(base_dir, heusler_types)
# Fail fast on a wrong Drive path / folder layout instead of erroring obscurely later.
assert dataset.structures, (
    f"No CIF structures were loaded from {base_dir}; check the path and the "
    f"sub-folder names {heusler_types}"
)

# Get and display structure info (one row per loaded structure)
structure_info = dataset.get_structure_info()
assert len(structure_info) == len(dataset.structures)
print("\nStructure Info Sample:")
display(structure_info.head())

In [None]:
# Graph construction functions
def one_hot_encoding(atomic_number, max_z=94):
    """One-hot encode an atomic number Z as a float vector of length ``max_z``.

    Position ``Z - 1`` is set to 1.0. Numbers outside ``1..max_z`` (e.g. Z = 0 for a
    dummy species) yield an all-zero vector. Without the lower-bound check, Z <= 0
    would wrap around via negative indexing and mark the last element.
    """
    encoding = np.zeros(max_z)
    if 1 <= atomic_number <= max_z:
        encoding[atomic_number - 1] = 1.0
    return encoding

def _finite_or_zero(value):
    """Return ``value`` as a float, or 0.0 when it is missing (None/0) or NaN."""
    if not value:
        return 0.0
    value = float(value)
    return value if np.isfinite(value) else 0.0


def _atom_feature_vector(specie):
    """One-hot Z followed by [electronegativity, atomic radius, atomic mass].

    Missing or NaN element data is encoded as 0.0, so no NaN reaches the model.
    """
    extra = [
        _finite_or_zero(getattr(specie, 'X', None)),
        _finite_or_zero(getattr(specie, 'atomic_radius', None)),
        _finite_or_zero(getattr(specie, 'atomic_mass', None)),
    ]
    return np.concatenate([one_hot_encoding(specie.Z), extra])


def build_crystal_graph(structure, cutoff=8.0):
    """Build the atom graph and its bond (line) graph for a pymatgen structure.

    Args:
        structure: pymatgen ``Structure``; every site must have a single ``specie``.
        cutoff: ``search_cutoff`` passed to ``CrystalNN`` when finding neighbours.

    Returns:
        (atom_graph, bond_graph, atom_features, edge_features):
            atom_graph: DGL graph with one node per site and one directed edge i -> j
                for each CrystalNN neighbour j of site i. Neighbours that are periodic
                images of i itself are skipped.
            bond_graph: ``dgl.line_graph(atom_graph, backtracking=False)``.
            atom_features: float32 tensor with one row per site
                (``one_hot_encoding(Z)`` + 3 element properties).
            edge_features: float32 tensor of shape (num_edges, 4) holding
                [bond length, unit displacement vector i -> j], row-aligned with
                atom_graph's edges. The shape is (0, 4) when no bonds were found.
    """
    edge_feat_dim = 4  # bond length + 3D unit vector
    sites = structure.sites
    num_atoms = len(sites)

    atom_rows = [_atom_feature_vector(site.specie) for site in sites]

    cnn = CrystalNN(search_cutoff=cutoff, weighted_cn=True)
    src_atoms, dst_atoms, edge_rows = [], [], []
    for i, site in enumerate(sites):
        try:
            for neighbor in cnn.get_nn_info(structure, i):
                j = neighbor['site_index']
                if i == j:  # skip bonds to periodic images of the same site
                    continue
                # neighbor['site'] is the neighbouring periodic image, so this is the
                # real i -> j bond vector. Its norm is the bond length; the CrystalNN
                # 'weight' is a coordination weight, not a distance.
                delta = neighbor['site'].coords - site.coords
                bond_length = float(np.linalg.norm(delta))
                unit_vec = delta / bond_length if bond_length > 1e-6 else np.zeros(3)
                src_atoms.append(i)
                dst_atoms.append(j)
                edge_rows.append([bond_length, *unit_vec])
        except Exception as e:
            # Best effort: report the problem and keep the bonds found for other sites.
            print(f"Error building graph for site {i}: {e}")

    # Explicit int64 IDs: empty Python lists would otherwise not carry an integer dtype.
    atom_graph = dgl.graph(
        (torch.tensor(src_atoms, dtype=torch.int64), torch.tensor(dst_atoms, dtype=torch.int64)),
        num_nodes=num_atoms,
    )
    # Line graph: node k corresponds to edge k of atom_graph.
    bond_graph = dgl.line_graph(atom_graph, backtracking=False)

    atom_features = torch.tensor(np.array(atom_rows), dtype=torch.float32)
    # reshape keeps a 2-D (0, 4) tensor when no bonds exist, so downstream code can
    # always index edge_features.shape[1].
    edge_features = torch.tensor(
        np.asarray(edge_rows, dtype=np.float64).reshape(-1, edge_feat_dim),
        dtype=torch.float32,
    )
    return atom_graph, bond_graph, atom_features, edge_features

# Define the ALIGNN model
class ALIGNNConv(nn.Module):
    """One message-passing layer over an atom graph and its line (bond) graph.

    The residual updates (``atom_feat + ...``, ``bond_feat + ...``) require
    ``in_dim == out_dim == edge_dim``. ``ALIGNN`` builds every layer that way.

    Args:
        in_dim: size of the incoming atom features.
        out_dim: output size of the convolutions and update MLPs.
        edge_dim: size of the incoming bond features.
    """

    def __init__(self, in_dim, out_dim, edge_dim):
        super().__init__()
        # allow_zero_in_degree: an atom with no incoming bond, or a bond with no
        # continuing bond in the non-backtracking line graph, would otherwise make
        # DGL raise.
        self.atom_conv = GraphConv(in_dim, out_dim, allow_zero_in_degree=True)
        self.bond_conv = EdgeConv(edge_dim, out_dim, allow_zero_in_degree=True)
        self.update_bond = nn.Sequential(
            nn.Linear(out_dim * 2 + edge_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )
        self.update_atom = nn.Sequential(
            nn.Linear(out_dim * 2, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, atom_graph, bond_graph, atom_feat, bond_feat):
        """Update atom and bond features.

        Args:
            atom_graph: graph whose edges are bonds; ``edges()`` gives (src, dst) atom IDs.
            bond_graph: line graph of ``atom_graph`` (one node per bond).
            atom_feat: (num_atoms, in_dim) tensor.
            bond_feat: (num_bonds, edge_dim) tensor, row-aligned with atom_graph's edges.

        Returns:
            (atom_feat, bond_feat) with the same shapes as the inputs.
        """
        # Convolve atoms over the atom graph and bonds over the line graph.
        atom_h = self.atom_conv(atom_graph, atom_feat)   # (num_atoms, out_dim)
        bond_h = self.bond_conv(bond_graph, bond_feat)   # (num_bonds, out_dim)

        src, dst = atom_graph.edges()

        # Bond update from the convolved endpoint atoms plus the current bond features.
        # The width out_dim * 2 + edge_dim matches update_bond's first Linear.
        edge_inputs = torch.cat([atom_h[src], atom_h[dst], bond_feat], dim=1)
        bond_feat = bond_feat + bond_h + self.update_bond(edge_inputs)

        # Atom update: sum the updated features of the bonds pointing into each atom.
        # This is a per-atom scatter-sum; dgl.sum_edges would instead give one sum
        # per graph.
        bond_sum = bond_feat.new_zeros(atom_feat.shape[0], bond_feat.shape[1])
        bond_sum = bond_sum.index_add(0, dst, bond_feat)
        atom_feat = atom_feat + self.update_atom(torch.cat([atom_h, bond_sum], dim=1))

        return atom_feat, bond_feat

class ALIGNN(nn.Module):
    """Stack of ``ALIGNNConv`` layers with a mean-pooled, per-graph readout.

    Args:
        node_feat_dim: size of input atom features (97 = 94-dim one-hot Z + 3 element
            properties, as built by ``build_crystal_graph``).
        edge_feat_dim: size of input bond features (bond length + 3D unit vector).
        hidden_dim: width used for embeddings and every ALIGNNConv layer.
        num_layers: number of ALIGNNConv layers.
        output_dim: size of the per-graph output vector.
    """

    def __init__(self, node_feat_dim=97, edge_feat_dim=4, hidden_dim=128, num_layers=4, output_dim=64):
        super().__init__()
        self.node_embed = nn.Linear(node_feat_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_feat_dim, hidden_dim)

        self.alignn_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.alignn_layers.append(ALIGNNConv(hidden_dim, hidden_dim, hidden_dim))

        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, atom_graph, bond_graph, atom_feat, bond_feat):
        """Embed, run the ALIGNN layers, and mean-pool atoms per graph.

        Returns:
            (output, atom_feat, bond_feat): ``output`` has one row per graph in the
            (possibly batched) ``atom_graph``; the other two are the final per-atom and
            per-bond hidden features.
        """
        # Initial embedding
        atom_feat = self.node_embed(atom_feat)
        bond_feat = self.edge_embed(bond_feat)

        # Apply ALIGNN layers
        for layer in self.alignn_layers:
            atom_feat, bond_feat = layer(atom_graph, bond_graph, atom_feat, bond_feat)

        # Global pooling. dgl.mean_nodes reads a *named* node feature; local_scope
        # keeps the temporary 'h' field from leaking into the caller's graph.
        with atom_graph.local_scope():
            atom_graph.ndata['h'] = atom_feat
            global_feat = dgl.mean_nodes(atom_graph, 'h')

        # Final readout
        output = self.readout(global_feat)

        return output, atom_feat, bond_feat

In [None]:
# Create a PyTorch dataset for the ALIGNN model
class ALIGNNHeuslerDataset(Dataset):
    """Map-style PyTorch dataset that yields ALIGNN inputs for a list of structures.

    Graphs are built by ``build_crystal_graph`` every time an item is requested;
    nothing is cached.
    """

    def __init__(self, structures):
        # Indexed positionally in __getitem__.
        self.structures = structures

    def __len__(self):
        """Number of structures in the dataset."""
        n_structures = len(self.structures)
        return n_structures

    def __getitem__(self, idx):
        """Return ``(atom_graph, bond_graph, atom_features, edge_features)`` for item ``idx``."""
        crystal = self.structures[idx]
        g_atoms, g_bonds, x_atoms, x_edges = build_crystal_graph(crystal)
        return g_atoms, g_bonds, x_atoms, x_edges

# Collate function for batching
def collate_fn(batch):
    """Merge a list of dataset samples into one batched sample for a DataLoader.

    Each sample is ``(atom_graph, bond_graph, atom_features, edge_features)``. Graphs
    are merged with ``dgl.batch`` and feature tensors are concatenated along dim 0,
    all in the same sample order.
    """
    atom_graphs, bond_graphs, atom_feats, edge_feats = (
        [sample[pos] for sample in batch] for pos in range(4)
    )
    merged_atoms = dgl.batch(atom_graphs)
    merged_bonds = dgl.batch(bond_graphs)
    stacked_atom_x = torch.cat(atom_feats, dim=0)
    stacked_edge_x = torch.cat(edge_feats, dim=0)
    return merged_atoms, merged_bonds, stacked_atom_x, stacked_edge_x

In [None]:
# Function to predict properties based on structural embeddings
def _plot_actual_vs_predicted(y_true, y_pred, prop, k, r2):
    """Scatter held-out actual vs predicted values of ``prop`` with a y = x reference line."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(y_true, y_pred, alpha=0.7)
    lo, hi = min(y_true), max(y_true)
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_xlabel(f'Actual {prop}')
    ax.set_ylabel(f'Predicted {prop}')
    ax.set_title(f'Actual vs Predicted {prop} (k={k}, test R²={r2:.4f})')
    ax.grid(True, alpha=0.3)
    plt.show()


def predict_properties(embeddings, structure_info, k_values=(3, 5, 7, 9),
                       test_size=0.2, random_state=42, cv_folds=5):
    """Predict per-structure properties by k-nearest-neighbour regression in embedding space.

    One held-out test split is used for all properties. The number of neighbours k
    is chosen by cross-validated R² on the training split only, so the reported
    test MAE / R² are not inflated by selecting k on the test data.

    Args:
        embeddings: array-like with one row per structure, row-aligned with
            ``structure_info``.
        structure_info: DataFrame with columns 'volume', 'density',
            'avg_atomic_radius' and 'avg_electronegativity'. Missing values are filled
            with the column median.
        k_values: candidate neighbour counts. Values too large for the CV training
            folds are skipped.
        test_size, random_state: passed to ``train_test_split``.
        cv_folds: number of cross-validation folds used to choose k (capped at the
            number of training samples).

    Returns:
        dict mapping property -> {'best_k', 'mae', 'r2', 'cv_r2', 'model'}. 'mae' and
        'r2' are measured on the held-out test split; 'model' is fitted on the
        training split.

    Raises:
        ValueError: if ``embeddings`` and ``structure_info`` differ in length, or the
            training split is too small for any candidate k.
    """
    import math
    from sklearn.model_selection import cross_val_score

    properties = ['volume', 'density', 'avg_atomic_radius', 'avg_electronegativity']

    if len(embeddings) != len(structure_info):
        raise ValueError(
            f"embeddings has {len(embeddings)} rows but structure_info has {len(structure_info)}"
        )

    # Numeric target columns only; fill missing values with the column median.
    clean_df = structure_info[properties].copy()
    for prop in properties:
        n_missing = int(clean_df[prop].isna().sum())
        if n_missing:
            print(f"Warning: {n_missing} missing values in {prop}")
            clean_df[prop] = clean_df[prop].fillna(clean_df[prop].median())

    # The split depends only on the row count and random_state, so it is the same for
    # every property; compute it once.
    X_train, X_test, y_train_all, y_test_all = train_test_split(
        embeddings, clean_df, test_size=test_size, random_state=random_state
    )

    # KNN needs k <= the number of samples in every CV training fold. The smallest
    # fold-train set is n_train minus the largest validation fold.
    n_train = len(X_train)
    n_folds = min(cv_folds, n_train)
    if n_folds < 2:
        raise ValueError(f"Need at least 2 training samples for cross-validation, got {n_train}")
    max_k = n_train - math.ceil(n_train / n_folds)
    candidate_ks = [k for k in k_values if 1 <= k <= max_k]
    if not candidate_ks:
        raise ValueError(
            f"No k in {list(k_values)} fits {n_train} training samples with {n_folds}-fold CV"
        )

    results = {}
    for prop in properties:
        print(f"\nPredicting {prop} based on structural similarity...")
        y_train, y_test = y_train_all[prop], y_test_all[prop]

        # Model selection on the training split only.
        cv_r2_scores = []
        for k in candidate_ks:
            fold_scores = cross_val_score(
                KNeighborsRegressor(n_neighbors=k), X_train, y_train,
                cv=n_folds, scoring='r2',
            )
            # R² is undefined on single-sample folds (NaN); ignore those, and rank a
            # k with no valid fold last.
            mean_score = -np.inf if np.all(np.isnan(fold_scores)) else float(np.nanmean(fold_scores))
            cv_r2_scores.append(mean_score)
            print(f"k={k}, CV R²={mean_score:.4f}")

        best_idx = int(np.argmax(cv_r2_scores))
        best_k = candidate_ks[best_idx]
        print(f"Best k for {prop}: {best_k} (CV R²={cv_r2_scores[best_idx]:.4f})")

        # Unbiased evaluation of the selected k on the held-out test split.
        final_model = KNeighborsRegressor(n_neighbors=best_k)
        final_model.fit(X_train, y_train)
        y_pred = final_model.predict(X_test)
        test_mae = mean_absolute_error(y_test, y_pred)
        test_r2 = r2_score(y_test, y_pred)
        print(f"Held-out test: MAE={test_mae:.4f}, R²={test_r2:.4f}")

        _plot_actual_vs_predicted(y_test, y_pred, prop, best_k, test_r2)

        results[prop] = {
            'best_k': best_k,
            'mae': test_mae,
            'r2': test_r2,
            'cv_r2': cv_r2_scores[best_idx],
            'model': final_model,
        }

    # Summary
    print("\nProperty Prediction Summary:")
    for prop, res in results.items():
        print(f"{prop}: R²={res['r2']:.4f}, MAE={res['mae']:.4f}, best k={res['best_k']}")

    return results

# Predict properties from the structural embeddings.
# `embeddings` must be produced by an earlier cell, with one row per structure in the
# same order as `structure_info`. On a fresh kernel, fail with a clear message instead
# of a bare NameError.
if 'embeddings' not in globals():
    raise NameError(
        "`embeddings` is not defined: run the cell that computes the ALIGNN "
        "structure embeddings before this one"
    )
assert len(embeddings) == len(structure_info), (
    "embeddings must have exactly one row per structure in structure_info"
)
print("Predicting properties based on structural embeddings...")
property_predictions = predict_properties(embeddings, structure_info)