In [4]:
import argparse
import json
import os
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import r2_score as R2
from torch_geometric.nn import Linear, SAGEConv, global_mean_pool
import time
import snntorch as snn
import pickle
# from utils import *


def seed_all(seed):
    '''
    Set random seeds for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed. ``None`` falls back to the default seed 42.
            ``0`` is a valid seed and is used as given.
    '''
    # Only a missing seed selects the default; a plain truthiness check
    # would silently replace the valid seed 0 with 42.
    if seed is None:
        seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Trade cuDNN autotuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import spmm
from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor
from typing import Union, Tuple, Optional

class SAGEConv_new(MessagePassing):
    """Message-passing convolution with the activation step split out.

    The layer aggregates neighbour features (``aggr``, "mean" by default),
    applies ``lin_l`` to the aggregate and, when ``root_weight`` is set, adds
    ``lin_r`` applied to the node's own features. Unlike a fused layer, the
    pre-activation output, the activation and the optional L2 normalisation
    are available as separate methods so a caller can substitute its own
    non-linearity between them.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        aggr: str = "mean",
        normalize: bool = False,
        root_weight: bool = True,
        project: bool = False,
        bias: bool = True,
    ):
        super().__init__(aggr)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.project = project

        # A single int means source and target nodes share one feature width.
        dims = (in_channels, in_channels) if isinstance(in_channels, int) else in_channels
        src_dim = dims[0]
        dst_dim = dims[1]

        # Optional projection of the source features before aggregation.
        if self.project:
            self.lin = Linear(src_dim, src_dim, bias=True)

        # lin_l transforms the aggregated neighbourhood, lin_r the root node.
        self.lin_l = Linear(src_dim, out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(dst_dim, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the base class and every linear layer that exists."""
        super().reset_parameters()
        layers = []
        if self.project:
            layers.append(self.lin)
        layers.append(self.lin_l)
        if self.root_weight:
            layers.append(self.lin_r)
        for layer in layers:
            layer.reset_parameters()

    def forward_without_activation(
        self, x: Union[torch.Tensor, OptPairTensor], edge_index: Adj, size: Size = None
    ) -> torch.Tensor:
        """Run projection, aggregation and the linear maps; no activation.

        Args:
            x: Node features, either one tensor (used for both source and
                target) or a ``(source, target)`` pair.
            edge_index: Graph connectivity passed on to ``propagate``.
            size: Optional size forwarded to ``propagate``.

        Returns:
            The pre-activation layer output.
        """
        pair = (x, x) if isinstance(x, torch.Tensor) else x

        # The projection is applied linearly (no activation on it).
        if self.project and hasattr(self, "lin"):
            pair = (self.lin(pair[0]), pair[1])

        aggregated = self.propagate(edge_index, x=pair, size=size)
        result = self.lin_l(aggregated)

        root = pair[1]
        if not self.root_weight or root is None:
            return result
        return result + self.lin_r(root)

    def apply_activation(self, x: torch.Tensor, activation_fn=torch.relu) -> torch.Tensor:
        """Apply ``activation_fn`` (ReLU by default) to ``x``."""
        return activation_fn(x)

    def normalize_output(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalise ``x`` along the last dimension when ``normalize`` is set."""
        return F.normalize(x, p=2.0, dim=-1) if self.normalize else x

    def forward(self, x: torch.Tensor, edge_index: Adj, activation_fn=torch.relu):
        """Full pass: linear/aggregation step, activation, then normalisation."""
        pre_activation = self.forward_without_activation(x, edge_index)
        activated = self.apply_activation(pre_activation, activation_fn)
        return self.normalize_output(activated)

    def message(self, x_j: torch.Tensor) -> torch.Tensor:
        """Messages are the unmodified neighbour features."""
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> torch.Tensor:
        """Fused message + aggregation through a sparse matrix product."""
        # Edge values are dropped from a SparseTensor adjacency before spmm.
        adjacency = (
            adj_t.set_value(None, layout=None)
            if isinstance(adj_t, SparseTensor)
            else adj_t
        )
        return spmm(adjacency, x[0], reduce=self.aggr)


# Initialize model
class GNN(torch.nn.Module):
    '''
    Graph neural network with spiking (leaky integrate-and-fire) gating.

    Pipeline: a linear layer on the 5 input node features, two
    ``SAGEConv_new`` layers, global mean pooling over each graph, then three
    linear layers producing one value per graph. The non-linearity after the
    first linear layer and after each convolution is a gate: the
    pre-activation is multiplied by the spike output of a ``snn.Leaky``
    neuron with learnable per-channel beta and threshold.

    Args:
        N_fl1: Width of the pre-processing linear layer.
        N_mpl: Width of both message-passing layers.
        N_fl2: Width of the first post-pooling linear layer.
        N_fl3: Width of the second post-pooling linear layer.
    '''
    def __init__(self, N_fl1, N_mpl, N_fl2, N_fl3):
        super(GNN, self).__init__()
        self.pre = Linear(5, N_fl1)
        # normalize=True only affects SAGEConv_new.normalize_output(); forward()
        # below calls forward_without_activation(), so no normalisation runs.
        self.conv1 = SAGEConv_new(N_fl1, N_mpl, normalize=True)
        self.conv2 = SAGEConv_new(N_mpl, N_mpl, normalize=True)
        self.post1 = Linear(N_mpl, N_fl2)
        self.post2 = Linear(N_fl2, N_fl3)
        self.out = Linear(N_fl3, 1)

        # Spiking neurons. Their per-channel parameters must match the width
        # of the layer they gate, so they are sized from N_fl1 / N_mpl rather
        # than fixed literals (which only work for N_fl1=32, N_mpl=64).
        self.lif_1 = self._make_lif(N_fl1)
        # lif_2 and lif_3 are not used in forward(); they are kept so the
        # parameter set and the RNG draw order stay unchanged.
        self.lif_2 = self._make_lif(N_mpl)
        self.lif_2a = self._make_lif(N_mpl)
        self.lif_3 = self._make_lif(N_mpl)
        self.lif_3a = self._make_lif(N_mpl)

    @staticmethod
    def _make_lif(width):
        '''Build a Leaky neuron with random learnable beta/threshold of size ``width``.'''
        beta = torch.rand(width)
        threshold = torch.rand(width)*0.001
        return snn.Leaky(beta=beta, learn_beta=True, threshold=threshold,
                         learn_threshold=True, reset_mechanism='zero')

    def forward(self, data):
        '''
        Predict one value per graph and report per-layer spike rates.

        Args:
            data: Batch object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tuple ``(pred, s1_sum, s2a_sum, s3a_sum)``: the output of the
            final linear layer, followed by three 1-element tensors holding
            the fraction of spiking entries (sum of spikes / number of
            entries) after the pre-processing layer, conv1 and conv2.
        '''
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Membrane states are re-initialised on every call.
        mem_1 = self.lif_1.init_leaky()
        mem_2a = self.lif_2a.init_leaky()
        mem_3a = self.lif_3a.init_leaky()

        # Accumulators live on the input's device, so forward() does not
        # depend on a module-level `device` variable.
        s1_sum = torch.zeros([1], device=x.device)
        s2a_sum = torch.zeros([1], device=x.device)
        s3a_sum = torch.zeros([1], device=x.device)

        # Pre Processing Linear Layer
        # Replacing ReLU with spiking
        x1 = self.pre(x)
        spk_in1, mem_1 = self.lif_1(x1, mem_1)
        x = spk_in1*x1
        s1_sum[0] += torch.sum(spk_in1)/spk_in1.numel()

        # 1. Obtain node embeddings
        # Replacing ReLU with spiking
        x2a = self.conv1.forward_without_activation(x, edge_index)
        spk_in2a, mem_2a = self.lif_2a(x2a, mem_2a)
        x = spk_in2a*x2a
        s2a_sum[0] += torch.sum(spk_in2a)/spk_in2a.numel()

        x3a = self.conv2.forward_without_activation(x, edge_index)
        spk_in3a, mem_3a = self.lif_3a(x3a, mem_3a)
        x = spk_in3a*x3a
        s3a_sum[0] += torch.sum(spk_in3a)/spk_in3a.numel()

        # 2. Readout layer
        x = global_mean_pool(x, batch)

        x = F.relu(self.post1(x))

        x = F.relu(self.post2(x))

        x = self.out(x)
        return x, s1_sum, s2a_sum, s3a_sum


def init_model():
    '''
    Initialize model and optimizer.

    Reads the module-level settings ``seed``, ``device``, ``N_fl1``,
    ``N_mpl``, ``N_fl2``, ``N_fl3``, ``l_rate`` and ``w_decay``.

    Returns:
        Tuple ``(model, optimizer)``: a freshly seeded ``GNN`` on ``device``
        and an Adam optimizer over its parameters.
    '''
    seed_all(seed)
    # Use the same module-level `device` that train()/test() move batches to.
    # A separately chosen device (previously 'cuda:1') puts the model and its
    # inputs on different devices whenever CUDA is available.
    model = GNN(N_fl1, N_mpl, N_fl2, N_fl3).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=l_rate, weight_decay=w_decay)
    return model, optimizer


def train(model, optimizer, train_loader, val_loader, n_epoch, prop, config, fold):
    '''
    Train GNN and log the loss history to a text file.

    For every epoch the model is optimised over all batches of
    ``train_loader`` with an MSE loss against the batch attribute named
    ``prop``, then evaluated through ``test()`` on ``val_loader``. One line
    per epoch is written to
    ``{output_dir}/eval-{eval}_config-{config}_fold-{fold}_loss_history.txt``.

    Args:
        model: Network returning ``(pred, s1, s2, s3)`` for a batch.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_loader: Iterable of training batches.
        val_loader: Loader handed to ``test()`` after each epoch.
        n_epoch: Number of epochs.
        prop: Name of the target attribute on each batch.
        config: Configuration id, used only in the log file name.
        fold: Fold label, used only in the log file name.

    Note:
        The logged training value is the loss of the last batch of the epoch,
        not an average over the epoch. Reads the module-level names
        ``output_dir``, ``eval``, ``seed`` and ``device``.
    '''
    filename = f'{output_dir}/eval-{eval}_config-{config}_fold-{fold}_loss_history.txt'

    # The context manager guarantees the log file is closed, including when
    # training raises part-way through.
    with open(filename, "w") as output:
        print('Epoch Training_MSE Validation_MSE', file=output, flush=True)

        seed_all(seed)
        for epoch in range(n_epoch):
            model.train()
            # Logged as NaN if the loader yields no batch, instead of failing
            # on an unbound variable.
            train_loss = float('nan')
            # Train batches
            for train_batch in train_loader:
                train_batch = train_batch.to(device)
                train_pred, s1_t, s2_t, s3_t = model(train_batch)
                train_true = getattr(train_batch, prop)
                train_loss = F.mse_loss(train_pred, train_true)
                train_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            # Evaluate
            val_pred, val_true, s1_test, s2_test, s3_test = test(model, val_loader, prop)
            val_loss = F.mse_loss(val_pred, val_true)
            print(f'{epoch:d}, {train_loss:e}, {val_loss:e}', file=output, flush=True)


def test(model, data_loader, prop):
    '''
    Test GNN on the first batch of a loader.

    Args:
        model: Network returning ``(pred, s1, s2, s3)`` for a batch.
        data_loader: Loader to evaluate; only its first batch is used.
        prop: Name of the target attribute on the batch.

    Returns:
        Tuple ``(pred, true, s1_test, s2_test, s3_test)``: predictions,
        targets and the three spike-rate tensors returned by the model.

    Note:
        Re-seeds every RNG through ``seed_all(seed)`` and switches the model
        to eval mode on each call. Reads the module-level ``seed`` and
        ``device``.
    '''
    seed_all(seed)
    model.eval()
    # NOTE(review): presumably the loader holds the whole evaluation set in a
    # single batch; any further batches are ignored - confirm.
    data = next(iter(data_loader)).to(device)
    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        pred, s1_test, s2_test, s3_test = model(data)
    true = getattr(data, prop)
    return pred, true, s1_test, s2_test, s3_test


if __name__ == '__main__':

    # ---- Run settings ----
    eval = 2
    prop = 'strength'
    config_dir = './'
    config = 0
    output_dir = './out/'
    seed = 42

    # Create the configuration and output directories when they are missing.
    for needed_dir in (config_dir, output_dir):
        if not os.path.exists(needed_dir):
            os.makedirs(needed_dir)

    # Hyper-parameters are stored in "<config_dir><config>.json".
    config_name = f'{config_dir}{config}.json'
    with open(config_name, 'r') as h:
        params = json.load(h)

    # Optimisation settings first, then the layer widths.
    l_rate, w_decay, n_epoch, b_size = (
        params[key] for key in ('l_rate', 'w_decay', 'n_epoch', 'b_size'))
    N_fl1, N_mpl, N_fl2, N_fl3 = (
        params[key] for key in ('N_fl1', 'N_mpl', 'N_fl2', 'N_fl3'))

    # Seed every RNG before anything stochastic happens.
    seed_all(seed)

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Human-readable label per evaluation id (index = eval - 2).
    cases = [
        'Train (A-G) / Test (A-G)',
        'Train (A-G) / Test (H-L)',
    ]

    print('\n====== Configuration ======')
    print(f'Evaluation #{eval}:\t\t{cases[eval-2]}')
    print(f'Regression task:\t{prop}')
    print(f'Hyper-parameters :\t{config}.json')

# *************************************************************************** #
    print('\n====== Training / Testing ======')
    # Wall-clock start, used for the processing-time report.
    start = time.time()
    
    def seed_worker(worker_id):
        '''Seed NumPy and ``random`` inside a DataLoader worker.

        Intended as a ``worker_init_fn``. The seed is derived from
        ``torch.initial_seed()`` and reduced modulo 2**32 so that it fits
        NumPy's seed range.

        Args:
            worker_id: Worker index supplied by the DataLoader (unused).
        '''
        worker_seed = torch.initial_seed() % 2**32
        # Use the derived seed: a fixed literal would give every worker the
        # same NumPy / random stream and leave worker_seed unused.
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # Load data
    train_loader = torch.load("data/train_dataset.pt",weights_only=False)
    
    #### Eval = 2 if you're running Evaluation-1 and Eval = 3 if you're running Evaluation-2
    if eval==2:
        test_loader = torch.load("data/test_dataset_2.pt",weights_only=False)
    else:
        test_loader = torch.load("data/test_dataset_3.pt",weights_only=False)

    # Define model and optimizer
    model, optimizer = init_model()
    
    #### Load the checkpoint file if you directly want results else train
#     model.load_state_dict(torch.load("out/ALL_LIF_eval-2_config-0_prop-strength_fold-NA_checkpoint_AFTER_REVISION_NO_NORM.pth",map_location=torch.device('cpu')))
    model.eval()

    #Train model
    train(model, optimizer, train_loader, test_loader,
          n_epoch, prop, config, 'NA')

    # Test Model
    preds, trues, s1_test, s2_test, s3_test = test(model, test_loader, prop)

    # Save model
    # torch.save(model.state_dict(), f"{output_dir}/ALL_LIF_eval-{eval}_config-{config}_prop-{prop}_fold-{'NA'}_checkpoint_AFTER_REVISION_NO_NORM.pth")

    print(f'Processing time: {time.time()-start:.2f} seconds')
# *************************************************************************** #
    # Report and Visualize predictions
    
    with open("out/stiffness_scaler.pickle", "rb") as f:
        stiffness_scaler = pickle.load(f)

    with open("out/strength_scaler.pickle", "rb") as f:
        strength_scaler = pickle.load(f)

    with open("out/x_scaler.pickle", "rb") as f:
        x_scaler = pickle.load(f)
    
    scaler = {}
    scaler['stiffness'] = stiffness_scaler
    scaler['strength'] = strength_scaler
    scaler['x'] = x_scaler

    print('\n====== RESULTS ======')
    preds = scaler[prop].inverse_transform(
        preds.detach().detach().cpu().numpy())
    trues = scaler[prop].inverse_transform(
        trues.detach().detach().cpu().numpy())
    meanARE, maxARE = mean_maxARE(preds, trues)

    print("Spiking activity is :" )
    print(s1_test)
    print(s2_test)
    print(s3_test)
    print(s4_test)
    print(s5_test)

    print(f'(MeanARE, MaxARE):\t({meanARE}, {maxARE})')

    def plot_results(preds, trues, output_dir, eval, config, prop):
        '''Plot evaluation results as a true-vs-predicted parity scatter plot.

        The figure is saved as
        ``ALL_LIF_eval-{eval}_prop-{prop}_config-{config}.parity.png`` inside
        ``output_dir``.

        Args:
            preds: Predicted values.
            trues: Reference values, same shape as ``preds``.
            output_dir: Directory the PNG is written to.
            eval: Evaluation id; ``2`` draws a single-colour scatter, any
                other value colours points by their estimated density.
            config: Configuration id, used in the file name.
            prop: ``'strength'`` selects the strength labels/limits and a
                x1000 rescaling of the inputs; anything else uses the
                modulus labels/limits.

        NOTE(review): ``sns``, ``plt``, ``mcolors``, ``gaussian_kde`` and
        ``truncate_colormap`` are not imported in the visible code - confirm
        they are in scope.
        '''
        if prop == 'strength':
            preds = preds*1000
            trues = trues*1000

        sns.set(font_scale=1.75)
        sns.set_style("ticks")
        # fig, ax = plt.subplots(figsize=(8.5, 5.5), dpi=300)
        fig, ax = plt.subplots()

        # Use only the upper part of the colormap so no point is near-white.
        minColor = 0.4
        maxColor = 1.00
        if prop == 'strength':
            cmap = truncate_colormap(plt.get_cmap("Greens"), minColor, maxColor)
        else:
            cmap = truncate_colormap(plt.get_cmap("Blues"), minColor, maxColor)
        col = mcolors.to_hex(cmap(0.5))

        if eval != 2:
            x = np.squeeze(trues)
            y = np.squeeze(preds)
            xy = np.vstack([x, y])
            z = gaussian_kde(xy)(xy)
            # Sort the points by density, so that the densest points are plotted last
            idx = z.argsort()
            x, y, z = x[idx], y[idx], z[idx]

            plt.scatter(x,
                        y,
                        c=z,
                        s=20,
                        cmap=cmap)
        else:
            plt.scatter(trues,
                        preds,
                        s=20,
                        ec='k',
                        lw=0.5,
                        color=col)
        # Fixed axis limits plus the y = x reference line.
        if prop == 'strength':
            plt.xlabel('True strength (MPa)')
            plt.ylabel('Predicted strength (MPa)')
            plt.xlim([700, 1220])
            plt.ylim([700, 1220])
            plt.plot([700, 1220], [700, 1220], '-k', linewidth=2)
        else:
            plt.xlabel('True modulus (GPa)')
            plt.ylabel('Predicted modulus (GPa)')
            plt.xlim([110, 152])
            plt.ylim([110, 152])
            plt.plot([110, 152], [110, 152], '-k', linewidth=2)

        # Square plotting area regardless of the data ranges.
        ax.set_aspect(1.0/ax.get_data_ratio(), adjustable='box')

        # Save into the caller-supplied directory instead of a machine-specific
        # absolute path, so the function works outside the original host.
        figure_path = os.path.join(
            output_dir,
            f'ALL_LIF_eval-{eval}_prop-{prop}_config-{config}.parity.png')
        plt.savefig(figure_path, dpi=300, bbox_inches="tight")

    # Inputs are divided by 1000 here; plot_results multiplies them back for
    # 'strength' and leaves them divided for any other property.
    plot_results(preds/1000, trues/1000, output_dir, eval, config, prop)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "C:\ProgramData\Anaconda3\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\ProgramData\Anaconda3\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\ProgramData\Anaconda3\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "C:\ProgramData\Anaconda3\lib\site-packages\traitlets\config\application.py", line 846, in launch_instance
    app.start()
  File "C:\ProgramData\Anaconda3\lib\site-pack

AttributeError: _ARRAY_API not found

ImportError: numpy.core.multiarray failed to import

MSE Values

In [12]:
# Re-apply the target scaler so the error can also be reported on the
# normalised scale; `preds` / `trues` are the inverse-transformed arrays
# produced by the training cell.
pred_norm = scaler[prop].transform(preds)
true_norm = scaler[prop].transform(trues)

# Displayed tuple: MSE on the original scale, MSE on the normalised scale
# multiplied by 1000, and the shapes of both normalised arrays.
(
    np.mean((preds - trues) ** 2),
    np.mean((pred_norm - true_norm) ** 2) * 1000,
    pred_norm.shape,
    true_norm.shape,
)

(170.39403, 10.246600955724716, (70, 1), (70, 1))