# For pt / site specific

In [12]:
import torch
import numpy as np
from glob import glob
import os

graphs_to_investigate = 'C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_graphs/graphs_site_specific/fives'


def examine_graph_details(graph_dir=graphs_to_investigate, num_examples=1):
    
    pt_files = glob(os.path.join(graph_dir, "*.pt"))
    
    if len(pt_files) == 0:
        print("No .pt files found!")
        return
    
    print(f"=== EXAMINING {num_examples} EXAMPLE GRAPHS ===\n")
    
    for i, pt_file in enumerate(pt_files[:num_examples]):
        graph = torch.load(pt_file)
        filename = os.path.basename(pt_file)
        
        print(f"Example {i+1}: {filename}")
        print(f"Composition: {graph.composition if hasattr(graph, 'composition') else 'Unknown'}")
        print(f"Adsorbate: {graph.ads if hasattr(graph, 'ads') else 'Unknown'}")
        print(f"Target Energy: {float(graph.y[0]):.4f} eV" if hasattr(graph, 'y') else "No target")
        print(f"One-hot labels: {graph.onehot_labels}")
        
        print(f"\n--- NODE FEATURES (graph.x) ---")
        print(f"Shape: {graph.x.shape} (nodes × features)")
        print(f"Node feature matrix:")
        node_matrix = graph.x.numpy()
        
        for node_idx, features in enumerate(node_matrix):
            element_idx = np.where(features[:-2] == 1)[0]
            if len(element_idx) > 0:
                element = graph.onehot_labels[element_idx[0]]
            else:
                element = "Unknown"
            
            tag = int(features[-2])
            aoi = int(features[-1])
            
            print(f"  Node {node_idx:2d}: {features} -> {element} (tag={tag}, aoi={aoi})")
        
        print(f"\n--- EDGE CONNECTIONS (graph.edge_index) ---")
        print(f"Shape: {graph.edge_index.shape} (2 × edges)")
        print(f"Edge connections:")
        edges = graph.edge_index.numpy()
        
        for edge_idx in range(edges.shape[1]):
            src, dst = edges[0, edge_idx], edges[1, edge_idx]
            src_element = get_element_from_node(node_matrix[src], graph.onehot_labels)
            dst_element = get_element_from_node(node_matrix[dst], graph.onehot_labels)
            print(f"  Edge {edge_idx:2d}: Node {src}({src_element}) <-> Node {dst}({dst_element})")
        
        
        print(f"\n--- SUMMARY FOR THIS GRAPH ---")
        elements_in_graph = []
        adsorbate_nodes = []
        surface_nodes = []
        
        for node_idx, features in enumerate(node_matrix):
            element = get_element_from_node(features, graph.onehot_labels)
            tag = int(features[-2])
            elements_in_graph.append(element)
            
            if tag == 0:  # adsorbate
                adsorbate_nodes.append(f"Node {node_idx}({element})")
            elif tag == 1:  # surface
                surface_nodes.append(f"Node {node_idx}({element})")
        
        print(f"  Elements present: {list(set(elements_in_graph))}")
        print(f"  Adsorbate nodes: {adsorbate_nodes}")
        print(f"  Surface nodes: {surface_nodes}")
        print(f"  Total connectivity: {edges.shape[1]} edges between {len(node_matrix)} nodes")
        
        print("\n" + "="*80 + "\n")

def get_element_from_node(node_features, onehot_labels):
    element_idx = np.where(node_features[:-2] == 1)[0]
    if len(element_idx) > 0:
        return onehot_labels[element_idx[0]]
    return "Unknown"

def show_training_data_format(graph_dir=graphs_to_investigate):
    
    pt_files = glob(os.path.join(graph_dir, "*.pt"))
    if len(pt_files) == 0:
        return
    
    print("=== TRAINING DATA FORMAT ===\n")
    
    graph = torch.load(pt_files[0])
    
    print("What the model receives for each training sample:")
    print(f"1. Node features (X): tensor of shape {graph.x.shape}")
    print(f"   - Each row is one atom with {graph.x.shape[1]} features")
    print(f"   - Features: one-hot element encoding + tag + AOI flag")
    
    print(f"\n2. Edge indices: tensor of shape {graph.edge_index.shape}")
    print(f"   - Defines which atoms are connected")
    print(f"   - Format: [[source_nodes], [target_nodes]]")
    
    print(f"\n3. Target value (y): {graph.y}")
    print(f"   - Single number: averaged adsorption energy")
    
    print(f"\n4. Batch information gets added during training")
    print(f"   - Batch tensor indicates which nodes belong to which graph")
    
    print("\nExample of what goes into model.forward():")
    print(f"  data.x shape: {graph.x.shape}")
    print(f"  data.edge_index shape: {graph.edge_index.shape}")
    print(f"  data.y: {float(graph.y[0]):.4f}")
    
    print(f"\nActual arrays that would be processed:")
    print(f"Node features (first 3 nodes):")
    for i in range(min(3, graph.x.shape[0])):
        print(f"  Node {i}: {graph.x[i].numpy()}")
    
    print(f"\nEdge connections (first 5 edges):")
    for i in range(min(5, graph.edge_index.shape[1])):
        src, dst = graph.edge_index[0, i], graph.edge_index[1, i]
        print(f"  Edge {i}: {src} -> {dst}")

if __name__ == "__main__":
    print("CHECKING NEW CODE FROM train_data_graphs") 
    examine_graph_details(num_examples=2)
    print("\n")
    show_training_data_format()

CHECKING NEW CODE FROM train_data_graphs
=== EXAMINING 2 EXAMPLE GRAPHS ===

Example 1: CoNiCuZnAg_H_site0.pt
Composition: ['Co', 'Ni', 'Cu', 'Zn', 'Ag']
Adsorbate: H
Target Energy: -1.6650 eV
One-hot labels: ['Ag', 'Co', 'Cr', 'Cu', 'Fe', 'H', 'Mn', 'Ni', 'S', 'Zn', 'Zr']

--- NODE FEATURES (graph.x) ---
Shape: torch.Size([19, 13]) (nodes × features)
Node feature matrix:
  Node  0: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 3. 0.] -> Ni (tag=3, aoi=0)
  Node  1: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 3. 0.] -> Co (tag=3, aoi=0)
  Node  2: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 3. 0.] -> Ag (tag=3, aoi=0)
  Node  3: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 3. 0.] -> Co (tag=3, aoi=0)
  Node  4: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 3. 0.] -> Cu (tag=3, aoi=0)
  Node  5: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 3. 0.] -> Ni (tag=3, aoi=0)
  Node  6: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 3. 0.] -> Cu (tag=3, aoi=0)
  Node  7: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 3. 0.] -> Ni (tag=3, aoi=0)
  Node  8: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.

  graph = torch.load(pt_file)
  graph = torch.load(pt_files[0])


# For graph extentions OURs

In [None]:
import torch
import numpy as np
from glob import glob
import os
import pickle 

# graphs_to_investigate = 'C:/Users/Tseh/Documents/Files/HEA/hea_project/cheat_related/graphs'
# graphs_to_investigate = 'C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_graphs/graphs_site_specific'

def get_element_from_node(node_features, onehot_labels):
    element_idx = np.where(node_features[:-2] == 1)[0]
    if len(element_idx) > 0:
        return onehot_labels[element_idx[0]]
    return "Unknown"


def examine_graph_details_by_file(graph_dir=graphs_to_investigate, examples_per_file=2):
    """
    Examine multiple examples from each .graphs file separately
    """
    graph_files = glob(os.path.join(graph_dir, "*.graphs"))
    
    if len(graph_files) == 0:
        print("No .graphs files found!")
        return
    
    print(f"=== EXAMINING {examples_per_file} EXAMPLES FROM EACH FILE ===\n")
    
    for file_idx, graph_file in enumerate(graph_files):
        filename = os.path.basename(graph_file)
        print(f"{'='*20} FILE: {filename} {'='*20}")
        
        with open(graph_file, 'rb') as f:
            data = pickle.load(f)
        
        if isinstance(data, list):
            graphs_in_file = data
            print(f"This file contains {len(graphs_in_file)} graphs")
        else:
            graphs_in_file = [data]
            print(f"This file contains 1 graph")
        
        # Sample examples_per_file graphs from this file
        num_to_sample = min(examples_per_file, len(graphs_in_file))
        
        # Take evenly spaced samples to get variety
        if len(graphs_in_file) > examples_per_file:
            step = len(graphs_in_file) // examples_per_file
            sample_indices = [i * step for i in range(examples_per_file)]
        else:
            sample_indices = list(range(len(graphs_in_file)))
        
        for sample_idx, graph_idx in enumerate(sample_indices):
            graph = graphs_in_file[graph_idx]
            
            print(f"\n--- Example {sample_idx+1} (Graph {graph_idx} from {filename}) ---")
            
            # Try to access adsorbate information
            try:
                if hasattr(graph, 'ads'):
                    ads = graph.ads
                    print(f"Adsorbate: {ads}")
                else:
                    print("Adsorbate: Not found in graph attributes")
            except Exception as e:
                print(f"Adsorbate: Error accessing - {e}")
            
            # Try to access target energy
            try:
                if hasattr(graph, 'y'):
                    y_value = graph.y
                    if torch.is_tensor(y_value):
                        print(f"Target Energy: {float(y_value.item()):.4f} eV")
                    else:
                        print(f"Target Energy: {float(y_value):.4f} eV")
                else:
                    print("Target Energy: No 'y' attribute found")
            except Exception as e:
                print(f"Target Energy: Error accessing - {e}")
            
            # Examine node features to see metal composition
            try:
                if hasattr(graph, 'x') and hasattr(graph, 'onehot_labels'):
                    x = graph.x
                    print(f"Graph has {x.shape[0]} nodes")
                    
                    # Convert to numpy if it's a tensor
                    if torch.is_tensor(x):
                        node_matrix = x.detach().numpy()
                    else:
                        node_matrix = np.array(x)
                    
                    # Count elements in this graph
                    element_counts = {}
                    for node_idx in range(len(node_matrix)):
                        features = node_matrix[node_idx]
                        # Element features are all except last 2 (tag and aoi)
                        element_features = features[:-2]
                        element_idx = np.where(element_features == 1)[0]
                        if len(element_idx) > 0:
                            element = graph.onehot_labels[element_idx[0]]
                            element_counts[element] = element_counts.get(element, 0) + 1
                    
                    print(f"Metal composition: {element_counts}")
                    
                    # Show a few node examples
                    print("Sample nodes:")
                    for node_idx in range(len(node_matrix)):
                        features = node_matrix[node_idx]
                        element_features = features[:-2]
                        element_idx = np.where(element_features == 1)[0]
                        if len(element_idx) > 0:
                            element = graph.onehot_labels[element_idx[0]]
                            tag = int(features[-2])
                            aoi = int(features[-1])
                            print(f"  Node {node_idx}: {element} (tag={tag}, aoi={aoi})")
                
                else:
                    print("Cannot examine composition - missing 'x' or 'onehot_labels'")
            except Exception as e:
                print(f"Error examining composition: {e}")
            
            # Show edge information
            try:
                if hasattr(graph, 'edge_index'):
                    edge_index = graph.edge_index
                    num_edges = edge_index.shape[1] if len(edge_index.shape) > 1 else 0
                    print(f"Number of edges: {num_edges}")
                else:
                    print("No edge information found")
            except Exception as e:
                print(f"Error examining edges: {e}")

            try:
                print(f"\n--- EDGE CONNECTIONS (graph.edge_index) ---")
                print(f"Shape: {graph.edge_index.shape} (2 × edges)")
                print(f"Edge connections:")
                edges = graph.edge_index.numpy()
            
                for edge_idx in range(edges.shape[1]):
                    src, dst = edges[0, edge_idx], edges[1, edge_idx]
                    src_element = get_element_from_node(node_matrix[src], graph.onehot_labels)
                    dst_element = get_element_from_node(node_matrix[dst], graph.onehot_labels)
                    print(f"  Edge {edge_idx:2d}: Node {src}({src_element}) <-> Node {dst}({dst_element})")
            except Exception as e:
                print(f"Error examining edges: {e}")
        
        print(f"\n{'='*60}\n")

def show_dataset_summary(graph_dir=graphs_to_investigate):
    """
    Show a summary of all files and their contents
    """
    graph_files = glob(os.path.join(graph_dir, "*.graphs"))
    
    if len(graph_files) == 0:
        print("No .graphs files found!")
        return
    
    print("=== DATASET SUMMARY ===\n")
    
    total_graphs = 0
    all_adsorbates = set()
    all_elements = set()
    
    for graph_file in graph_files:
        filename = os.path.basename(graph_file)
        
        with open(graph_file, 'rb') as f:
            data = pickle.load(f)
        
        if isinstance(data, list):
            graphs = data
        else:
            graphs = [data]
        
        print(f"{filename}: {len(graphs)} graphs")
        total_graphs += len(graphs)
        
        # Sample a few graphs to get variety info
        sample_size = len(graphs)
        step = len(graphs) // sample_size if len(graphs) > sample_size else 1
        sample_indices = [i * step for i in range(sample_size)]
        
        file_adsorbates = set()
        file_elements = set()
        
        for idx in sample_indices:
            graph = graphs[idx]
            
            # Get adsorbate
            if hasattr(graph, 'ads'):
                file_adsorbates.add(graph.ads)
                all_adsorbates.add(graph.ads)
            
            # Get elements
            if hasattr(graph, 'x') and hasattr(graph, 'onehot_labels'):
                try:
                    x = graph.x
                    if torch.is_tensor(x):
                        node_matrix = x.detach().numpy()
                    else:
                        node_matrix = np.array(x)
                    
                    for node_idx in range(len(node_matrix)):
                        features = node_matrix[node_idx]
                        element_features = features[:-2]
                        element_idx = np.where(element_features == 1)[0]
                        if len(element_idx) > 0:
                            element = graph.onehot_labels[element_idx[0]]
                            file_elements.add(element)
                            all_elements.add(element)
                except:
                    pass
        
        print(f"  Sample adsorbates: {sorted(file_adsorbates)}")
        print(f"  Sample elements: {sorted(file_elements)}")
        print()


    print(f"TOTAL DATASET:")
    print(f"  Total graphs: {total_graphs}")
    print(f"  All adsorbates found: {sorted(all_adsorbates)}")
    print(f"  All elements found: {sorted(all_elements)}")
            


if __name__ == "__main__":
    print("EXAMINING PICKLED GRAPH FILES") 
    print(f"Looking in directory: {graphs_to_investigate}\n")
    
    # Show overall dataset summary
    show_dataset_summary()
    
    print("\n" + "="*80 + "\n")
    
    # Examine 5 examples from each file
    examine_graph_details_by_file(examples_per_file=2)

EXAMINING PICKLED GRAPH FILES
Looking in directory: C:/Users/Tseh/Documents/Files/HEA/hea_project/cheat_related/graphs

=== DATASET SUMMARY ===

test.graphs: 479 graphs
  Sample adsorbates: ['O', 'OH']
  Sample elements: ['Ag', 'H', 'Ir', 'O', 'Pd', 'Pt', 'Ru']

train.graphs: 3926 graphs
  Sample adsorbates: ['O', 'OH']
  Sample elements: ['Ag', 'H', 'Ir', 'O', 'Pd', 'Pt', 'Ru']

val.graphs: 497 graphs
  Sample adsorbates: ['O', 'OH']
  Sample elements: ['Ag', 'H', 'Ir', 'O', 'Pd', 'Pt', 'Ru']

TOTAL DATASET:
  Total graphs: 4902
  All adsorbates found: ['O', 'OH']
  All elements found: ['Ag', 'H', 'Ir', 'O', 'Pd', 'Pt', 'Ru']


=== EXAMINING 2 EXAMPLES FROM EACH FILE ===

This file contains 479 graphs

--- Example 1 (Graph 0 from test.graphs) ---
Adsorbate: OH
Target Energy: 0.2817 eV
Graph has 18 nodes
Metal composition: {'Pd': 2, 'Pt': 3, 'Ru': 8, 'Ir': 2, 'Ag': 1, 'O': 1, 'H': 1}
Sample nodes:
  Node 0: Pd (tag=3, aoi=1)
  Node 1: Pt (tag=3, aoi=0)
  Node 2: Ru (tag=3, aoi=1)
  Nod

# For graph extensions CHEAT

In [10]:
import torch
import numpy as np
from glob import glob
import os
import pickle 

# graphs_to_investigate = 'C:/Users/Tseh/Documents/Files/HEA/hea_project/cheat_related/graphs'
graphs_type = 'train.graphs'
graphs_to_investigate = 'C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_graphs/graphs'

def examine_all_train_graphs(graph_dir=graphs_to_investigate):
    """
    Examine ALL graphs in the train.graphs file
    """
    train_file = os.path.join(graph_dir, graphs_type)
    
    if not os.path.exists(train_file):
        print("train.graphs file not found!")
        return
    
    print("=== EXAMINING ALL GRAPHS IN TRAIN.GRAPHS ===\n")
    
    with open(train_file, 'rb') as f:
        data = pickle.load(f)
    
    if isinstance(data, list):
        graphs = data
    else:
        graphs = [data]
    
    print(f"Total graphs in train.graphs: {len(graphs)}\n")
    
    # Track statistics
    adsorbate_counts = {}
    element_sets = {}
    energy_ranges = {}
    
    for graph_idx, graph in enumerate(graphs):
        print(f"--- Graph {graph_idx + 1}/{len(graphs)} ---")
        
        # Get adsorbate
        adsorbate = "Unknown"
        try:
            if hasattr(graph, 'ads'):
                adsorbate = graph.ads
                print(f"Adsorbate: {adsorbate}")
                adsorbate_counts[adsorbate] = adsorbate_counts.get(adsorbate, 0) + 1
            else:
                print("Adsorbate: Not found")
        except Exception as e:
            print(f"Adsorbate: Error - {e}")
        
        # Get target energy
        energy = None
        try:
            if hasattr(graph, 'y'):
                y_value = graph.y
                if torch.is_tensor(y_value):
                    energy = float(y_value.item())
                else:
                    energy = float(y_value)
                print(f"Target Energy: {energy:.4f} eV")
                
                # Track energy ranges per adsorbate
                if adsorbate not in energy_ranges:
                    energy_ranges[adsorbate] = []
                energy_ranges[adsorbate].append(energy)
            else:
                print("Target Energy: Not found")
        except Exception as e:
            print(f"Target Energy: Error - {e}")
        
        # Get composition
        try:
            if hasattr(graph, 'x') and hasattr(graph, 'onehot_labels'):
                x = graph.x
                print(f"Number of nodes: {x.shape[0]}")
                
                # Convert to numpy if it's a tensor
                if torch.is_tensor(x):
                    node_matrix = x.detach().numpy()
                else:
                    node_matrix = np.array(x)
                
                # Count elements
                element_counts = {}
                for node_idx in range(len(node_matrix)):
                    features = node_matrix[node_idx]
                    element_features = features[:-2]  # All except last 2 (tag and aoi)
                    element_idx = np.where(element_features == 1)[0]
                    if len(element_idx) > 0:
                        element = graph.onehot_labels[element_idx[0]]
                        element_counts[element] = element_counts.get(element, 0) + 1
                
                composition_str = ", ".join([f"{elem}:{count}" for elem, count in sorted(element_counts.items())])
                print(f"Composition: {composition_str}")
                
                # Track unique element combinations per adsorbate
                element_set = frozenset(element_counts.keys())
                if adsorbate not in element_sets:
                    element_sets[adsorbate] = set()
                element_sets[adsorbate].add(element_set)
                
            else:
                print("Composition: Cannot determine")
        except Exception as e:
            print(f"Composition: Error - {e}")
        
        # Get edge count
        try:
            if hasattr(graph, 'edge_index'):
                edge_index = graph.edge_index
                num_edges = edge_index.shape[1] if len(edge_index.shape) > 1 else 0
                print(f"Number of edges: {num_edges}")
            else:
                print("Edges: Not found")
        except Exception as e:
            print(f"Edges: Error - {e}")
        
        print()  # Empty line between graphs
    
    # Print summary statistics
    print("\n" + "="*80)
    print("SUMMARY STATISTICS")
    print("="*80)
    
    print(f"\nAdsorbate distribution:")
    for ads, count in sorted(adsorbate_counts.items()):
        print(f"  {ads}: {count} graphs")
    
    print(f"\nEnergy ranges by adsorbate:")
    for ads, energies in sorted(energy_ranges.items()):
        if energies:
            min_e, max_e = min(energies), max(energies)
            avg_e = sum(energies) / len(energies)
            print(f"  {ads}: {min_e:.4f} to {max_e:.4f} eV (avg: {avg_e:.4f} eV)")
    
    print(f"\nUnique element combinations by adsorbate:")
    for ads, element_combos in sorted(element_sets.items()):
        print(f"  {ads}: {len(element_combos)} unique combinations")
        for combo in sorted(element_combos):
            combo_str = ", ".join(sorted(combo))
            print(f"    - {combo_str}")

def examine_specific_graphs(graph_dir=graphs_to_investigate, graph_indices=None):
    """
    Examine specific graphs by index from train.graphs
    Usage: examine_specific_graphs(graph_indices=[0, 10, 50, 100])
    """
    train_file = os.path.join(graph_dir, graphs_type)
    
    if not os.path.exists(train_file):
        print("train.graphs file not found!")
        return
    
    with open(train_file, 'rb') as f:
        data = pickle.load(f)
    
    if isinstance(data, list):
        graphs = data
    else:
        graphs = [data]
    
    if graph_indices is None:
        graph_indices = list(range(len(graphs)))
    
    print(f"=== EXAMINING SPECIFIC GRAPHS FROM TRAIN.GRAPHS ===")
    print(f"Total graphs available: {len(graphs)}")
    print(f"Examining indices: {graph_indices}\n")
    
    for i, graph_idx in enumerate(graph_indices):
        if graph_idx >= len(graphs):
            print(f"Index {graph_idx} is out of range (max: {len(graphs)-1})")
            continue
            
        graph = graphs[graph_idx]
        print(f"--- Graph {graph_idx} (#{i+1} of requested) ---")
        
        # Same examination code as above but for specific graphs
        # [Include the same examination logic here]
        try:
            if hasattr(graph, 'ads'):
                print(f"Adsorbate: {graph.ads}")
            
            if hasattr(graph, 'y'):
                y_value = graph.y
                if torch.is_tensor(y_value):
                    energy = float(y_value.item())
                else:
                    energy = float(y_value)
                print(f"Target Energy: {energy:.4f} eV")
            
            if hasattr(graph, 'x') and hasattr(graph, 'onehot_labels'):
                x = graph.x
                print(f"Number of nodes: {x.shape[0]}")
                
                if torch.is_tensor(x):
                    node_matrix = x.detach().numpy()
                else:
                    node_matrix = np.array(x)
                
                element_counts = {}
                for node_idx in range(len(node_matrix)):
                    features = node_matrix[node_idx]
                    element_features = features[:-2]
                    element_idx = np.where(element_features == 1)[0]
                    if len(element_idx) > 0:
                        element = graph.onehot_labels[element_idx[0]]
                        element_counts[element] = element_counts.get(element, 0) + 1
                
                composition_str = ", ".join([f"{elem}:{count}" for elem, count in sorted(element_counts.items())])
                print(f"Composition: {composition_str}")
            
            if hasattr(graph, 'edge_index'):
                edge_index = graph.edge_index
                num_edges = edge_index.shape[1] if len(edge_index.shape) > 1 else 0
                print(f"Number of edges: {num_edges}")
                
        except Exception as e:
            print(f"Error examining graph {graph_idx}: {e}")
        
        print()

if __name__ == "__main__":
    print("EXAMINING ALL TRAINING GRAPHS") 
    print(f"Looking in directory: {graphs_to_investigate}\n")
    
    # Examine all graphs in train.graphs
    examine_all_train_graphs()
    
    # Alternative: If you want to examine just specific graphs by index, uncomment below:
    # examine_specific_graphs(graph_indices=[0, 10, 50, 100, 200])

EXAMINING ALL TRAINING GRAPHS
Looking in directory: C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_graphs/graphs

=== EXAMINING ALL GRAPHS IN TRAIN.GRAPHS ===

Total graphs in train.graphs: 1584

--- Graph 1/1584 ---
Adsorbate: H
Target Energy: -2.5139 eV
Number of nodes: 10
Composition: Ag:1, Cr:4, H:1, Ni:1, Zn:2, Zr:1
Number of edges: 50

--- Graph 2/1584 ---
Adsorbate: H
Target Energy: -2.4486 eV
Number of nodes: 10
Composition: Ag:1, Cr:4, H:1, Ni:1, Zn:2, Zr:1
Number of edges: 50

--- Graph 3/1584 ---
Adsorbate: H
Target Energy: -2.4778 eV
Number of nodes: 10
Composition: Ag:1, Cr:4, H:1, Ni:1, Zn:2, Zr:1
Number of edges: 50

--- Graph 4/1584 ---
Adsorbate: H
Target Energy: -2.4341 eV
Number of nodes: 10
Composition: Ag:1, Cr:4, H:1, Ni:1, Zn:2, Zr:1
Number of edges: 50

--- Graph 5/1584 ---
Adsorbate: H
Target Energy: -2.4979 eV
Number of nodes: 10
Composition: Ag:1, Cr:4, H:1, Ni:1, Zn:2, Zr:1
Number of edges: 50

--- Graph 6/1584 ---
Adsorbate: H
Target Energy: -2.53

# For pt / averaged

In [9]:
import torch
import numpy as np
from glob import glob
import os

graphs_to_investigate = 'C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_graphs/graphs_averaged/triplets'


def examine_graph_details(graph_dir=graphs_to_investigate, num_examples=1):
    
    pt_files = glob(os.path.join(graph_dir, "*.pt"))
    
    if len(pt_files) == 0:
        print("No .pt files found!")
        return
    
    print(f"=== EXAMINING {num_examples} EXAMPLE GRAPHS ===\n")
    
    for i, pt_file in enumerate(pt_files[:num_examples]):
        graph = torch.load(pt_file)
        filename = os.path.basename(pt_file)
        
        print(f"Example {i+1}: {filename}")
        print(f"Composition: {graph.composition if hasattr(graph, 'composition') else 'Unknown'}")
        print(f"Adsorbate: {graph.ads if hasattr(graph, 'ads') else 'Unknown'}")
        print(f"Target Energy: {float(graph.y[0]):.4f} eV" if hasattr(graph, 'y') else "No target")
        print(f"One-hot labels: {graph.onehot_labels}")
        
        print(f"\n--- NODE FEATURES (graph.x) ---")
        print(f"Shape: {graph.x.shape} (nodes × features)")
        print(f"Node feature matrix:")
        node_matrix = graph.x.numpy()
        
        for node_idx, features in enumerate(node_matrix):
            element_idx = np.where(features[:-2] == 1)[0]
            if len(element_idx) > 0:
                element = graph.onehot_labels[element_idx[0]]
            else:
                element = "Unknown"
            
            tag = int(features[-2])
            aoi = int(features[-1])
            
            print(f"  Node {node_idx:2d}: {features} -> {element} (tag={tag}, aoi={aoi})")
        
        print(f"\n--- EDGE CONNECTIONS (graph.edge_index) ---")
        print(f"Shape: {graph.edge_index.shape} (2 × edges)")
        print(f"Edge connections:")
        edges = graph.edge_index.numpy()
        
        for edge_idx in range(edges.shape[1]):
            src, dst = edges[0, edge_idx], edges[1, edge_idx]
            src_element = get_element_from_node(node_matrix[src], graph.onehot_labels)
            dst_element = get_element_from_node(node_matrix[dst], graph.onehot_labels)
            print(f"  Edge {edge_idx:2d}: Node {src}({src_element}) <-> Node {dst}({dst_element})")
        
        
        print(f"\n--- SUMMARY FOR THIS GRAPH ---")
        elements_in_graph = []
        adsorbate_nodes = []
        surface_nodes = []
        
        for node_idx, features in enumerate(node_matrix):
            element = get_element_from_node(features, graph.onehot_labels)
            tag = int(features[-2])
            elements_in_graph.append(element)
            
            if tag == 0:  # adsorbate
                adsorbate_nodes.append(f"Node {node_idx}({element})")
            elif tag == 1:  # surface
                surface_nodes.append(f"Node {node_idx}({element})")
        
        print(f"  Elements present: {list(set(elements_in_graph))}")
        print(f"  Adsorbate nodes: {adsorbate_nodes}")
        print(f"  Surface nodes: {surface_nodes}")
        print(f"  Total connectivity: {edges.shape[1]} edges between {len(node_matrix)} nodes")
        
        print("\n" + "="*80 + "\n")

def get_element_from_node(node_features, onehot_labels):
    element_idx = np.where(node_features[:-2] == 1)[0]
    if len(element_idx) > 0:
        return onehot_labels[element_idx[0]]
    return "Unknown"

def show_training_data_format(graph_dir=graphs_to_investigate):
    
    pt_files = glob(os.path.join(graph_dir, "*.pt"))
    if len(pt_files) == 0:
        return
    
    print("=== TRAINING DATA FORMAT ===\n")
    
    graph = torch.load(pt_files[0])
    
    print("What the model receives for each training sample:")
    print(f"1. Node features (X): tensor of shape {graph.x.shape}")
    print(f"   - Each row is one atom with {graph.x.shape[1]} features")
    print(f"   - Features: one-hot element encoding + tag + AOI flag")
    
    print(f"\n2. Edge indices: tensor of shape {graph.edge_index.shape}")
    print(f"   - Defines which atoms are connected")
    print(f"   - Format: [[source_nodes], [target_nodes]]")
    
    print(f"\n3. Target value (y): {graph.y}")
    print(f"   - Single number: averaged adsorption energy")
    
    print(f"\n4. Batch information gets added during training")
    print(f"   - Batch tensor indicates which nodes belong to which graph")
    
    print("\nExample of what goes into model.forward():")
    print(f"  data.x shape: {graph.x.shape}")
    print(f"  data.edge_index shape: {graph.edge_index.shape}")
    print(f"  data.y: {float(graph.y[0]):.4f}")
    
    print(f"\nActual arrays that would be processed:")
    print(f"Node features (first 3 nodes):")
    for i in range(min(3, graph.x.shape[0])):
        print(f"  Node {i}: {graph.x[i].numpy()}")
    
    print(f"\nEdge connections (first 5 edges):")
    for i in range(min(5, graph.edge_index.shape[1])):
        src, dst = graph.edge_index[0, i], graph.edge_index[1, i]
        print(f"  Edge {i}: {src} -> {dst}")

if __name__ == "__main__":
    print("CHECKING NEW CODE FROM train_data_graphs") 
    examine_graph_details(num_examples=2)
    print("\n")
    show_training_data_format()

CHECKING NEW CODE FROM train_data_graphs
=== EXAMINING 2 EXAMPLE GRAPHS ===

Example 1: CoCuAg_H.pt
Composition: ['Co', 'Cu', 'Ag']
Adsorbate: H
Target Energy: -1.9003 eV
One-hot labels: ['Ag', 'Co', 'Cr', 'Cu', 'Fe', 'H', 'Mn', 'Ni', 'S', 'Zn', 'Zr']

--- NODE FEATURES (graph.x) ---
Shape: torch.Size([27, 13]) (nodes × features)
Node feature matrix:
  Node  0: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] -> H (tag=0, aoi=0)
  Node  1: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.] -> Co (tag=1, aoi=1)
  Node  2: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 0.] -> Ag (tag=2, aoi=0)
  Node  3: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 0.] -> Ag (tag=2, aoi=0)
  Node  4: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 0.] -> Ag (tag=2, aoi=0)
  Node  5: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 0.] -> Co (tag=2, aoi=0)
  Node  6: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1.] -> Cu (tag=1, aoi=1)
  Node  7: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.] -> Ag (tag=1, aoi=1)
  Node  8: [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.] -> Co (tag=1

  graph = torch.load(pt_file)
  graph = torch.load(pt_files[0])


# XYZ to cif

In [2]:
import numpy as np
from pymatgen.core import Structure, Lattice
from pymatgen.io.cif import CifWriter
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

def read_xyz_file_custom(xyz_file):
    """
    Custom XYZ file reader that can handle non-standard formats
    """
    atoms = []
    coords = []
    
    with open(xyz_file, 'r') as f:
        lines = f.readlines()
    
    # Skip first two lines (number of atoms and comment)
    for line in lines[2:]:
        line = line.strip()
        if not line or line.startswith('!'):  # Skip empty lines and lines starting with !
            continue
            
        parts = line.split()
        if len(parts) >= 4:
            element = parts[0]
            x, y, z = float(parts[1]), float(parts[2]), float(parts[3])
            atoms.append(element)
            coords.append([x, y, z])
    
    return atoms, np.array(coords)

def xyz_to_cif_robust(xyz_file, cif_file, lattice_params=None):
    """
    Convert XYZ file to CIF format with robust parsing
    
    Args:
        xyz_file: Path to input XYZ file
        cif_file: Path to output CIF file
        lattice_params: Optional tuple of (a, b, c, alpha, beta, gamma)
                       If None, will attempt to determine automatically
    """
    
    # Read XYZ file with custom parser
    atoms, coords = read_xyz_file_custom(xyz_file)
    
    print(f"Read {len(atoms)} atoms from {xyz_file}")
    
    if lattice_params is None:
        # Attempt to determine unit cell automatically
        min_coords = np.min(coords, axis=0)
        max_coords = np.max(coords, axis=0)
        
        # Add padding (adjust as needed)
        padding = 3.0  # Angstroms
        a = max_coords[0] - min_coords[0] + padding
        b = max_coords[1] - min_coords[1] + padding  
        c = max_coords[2] - min_coords[2] + padding
        
        lattice_params = (a, b, c, 90, 90, 90)  # Assume orthogonal
        print(f"Auto-determined lattice parameters: a={a:.3f}, b={b:.3f}, c={c:.3f}")
    
    # Create lattice
    lattice = Lattice.from_parameters(*lattice_params)
    
    # Create structure
    structure = Structure(
        lattice=lattice,
        species=atoms,
        coords=coords,
        coords_are_cartesian=True
    )
    
    print(f"Created structure with {len(structure)} sites")
    
    try:
        # Analyze symmetry and get primitive cell
        sga = SpacegroupAnalyzer(structure, symprec=0.1)
        primitive_structure = sga.get_primitive_standard_structure()
        
        print(f"Space group: {sga.get_space_group_symbol()}")
        print(f"Space group number: {sga.get_space_group_number()}")
    except Exception as e:
        print(f"Symmetry analysis failed: {e}")
        print("Using original structure without symmetry reduction")
        primitive_structure = structure
    
    # Write CIF
    writer = CifWriter(primitive_structure)
    writer.write_file(cif_file)
    
    print(f"Successfully converted {xyz_file} to {cif_file}")
    print(f"Final lattice parameters: {primitive_structure.lattice.parameters}")
    
    return primitive_structure

# Alternative function with manual lattice specification
def xyz_to_cif_manual_lattice(xyz_file, cif_file, a, b, c, alpha=90, beta=90, gamma=90):
    """
    Convert XYZ to CIF with manually specified lattice parameters
    """
    atoms, coords = read_xyz_file_custom(xyz_file)
    
    lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)
    structure = Structure(lattice=lattice, species=atoms, coords=coords, coords_are_cartesian=True)
    
    writer = CifWriter(structure)
    writer.write_file(cif_file)
    
    print(f"Converted with manual lattice: a={a}, b={b}, c={c}")
    return structure

# Usage with automatic lattice determination
try:
    structure = xyz_to_cif_robust(
        "C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/geometry.xyz", 
        "C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/hea.cif"
    )
except Exception as e:
    print(f"Error with automatic conversion: {e}")
    print("Trying with manual lattice parameters...")
    
    # If automatic fails, try with manual lattice parameters
    # Based on your coordinates, these seem reasonable:
    structure = xyz_to_cif_manual_lattice(
        "C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/geometry.xyz",
        "C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/hea_manual.cif",
        a=12.0,  # Adjust based on your structure
        b=8.0,   # Adjust based on your structure  
        c=12.0   # Adjust based on your structure
    )

Read 45 atoms from C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/geometry.xyz
Auto-determined lattice parameters: a=13.027, b=8.789, c=11.187
Created structure with 45 sites
Space group: P1
Space group number: 1
Successfully converted C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/geometry.xyz to C:/Users/Tseh/Documents/Files/HEA/hea_project/train_data_dft/HEA_results_fives/CoNiCuZnAg/hea.cif
Final lattice parameters: (8.789124488999999, 11.187057495, 13.02705574, 90.0, 90.0, 90.0)
