In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
GCNConv._orig_propagate = GCNConv.propagate

import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch_geometric.explain import GNNExplainer, Explainer


# --- Run configuration ------------------------------------------------------
# Every os.getenv call below falls back to its second argument when the
# environment variable is unset.
epochs = int(os.getenv("EPOCHS", 10))
learning_rate = float(os.getenv("LEARNING_RATE", 0.001))
hidden_c = int(os.getenv("HIDDEN_C", 16))
random_seed = int(os.getenv("RANDOM_SEED", 42))
# Whitespace-separated bin boundaries; the default is the nine edges in the
# string literal below. str.split() with no argument tolerates repeated,
# leading and trailing whitespace, which split(' ') turned into int('') errors.
bins = [int(edge) for edge in os.getenv("BINS", "400 800 1300 2100 3000 3700 4700 7020 9660").split()]
num_layers = int(os.getenv("NUM_LAYERS", 5))
nh = int(os.getenv("NUM_HEADS", 10))
gat = int(os.getenv("GAT", 0))
api_key = os.getenv("API_KEY", None)
dropout_p = float(os.getenv("DROPOUT", 0.5))

# From here on `bins` is a 1-D tensor of boundaries rather than a Python list.
bins = torch.tensor(bins, device='cuda' if torch.cuda.is_available() else 'cpu')

# The graph id is pinned for this notebook. A GRAPH_NUM environment variable
# used to be read here and then immediately overwritten by this value, so it
# never had any effect; the dead read has been removed.
graph_num = 17

model_name = 'deft-plant-303'  # Replace with your model name
weight_prefix = 'best_accuracy'  # Replace with your weight prefix

# Pick the compute device once (CUDA when available, otherwise CPU) and
# report which one is in use.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    print("Using CPU", flush = True)

# Load the pickled graph object for the selected graph id.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only point this at files you produced yourself or otherwise trust.
with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as graph_file:
    data = pickle.load(graph_file)

# Store contiguous copies of the tensors that are used further down.
for _attr in ('edge_index', 'x', 'y'):
    setattr(data, _attr, getattr(data, _attr).contiguous())

# Define or import the GCN class
class GCN(torch.nn.Module):
    """Two-layer graph convolutional node classifier.

    Layer sizes come from module-level state: the input width is
    ``data.num_features`` and the output width is ``len(bins) + 1``.
    The attribute names ``conv1`` and ``conv3`` are kept as-is because
    state-dict keys are derived from them.

    An optional stack of hidden ``GCNConv`` layers (``conv2``, driven by
    ``num_layers``) existed only as commented-out code and is not part of
    this model.
    """

    def __init__(self, hidden_channels):
        """Build both convolutions after seeding torch with ``random_seed``."""
        super().__init__()
        # Reseed so parameter initialisation is reproducible across runs.
        torch.manual_seed(random_seed)
        num_outputs = len(bins) + 1
        self.conv1 = GCNConv(data.num_features, hidden_channels, improved=True, cached=True)
        self.conv3 = GCNConv(hidden_channels, num_outputs, cached=True)

    def forward(self, x, edge_index):
        """Return the raw output of the second convolution.

        Pipeline: conv1 -> ReLU -> dropout (probability ``dropout_p``, only
        active while ``self.training`` is True) -> conv3. No softmax or
        log-softmax is applied.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=dropout_p, training=self.training)
        return self.conv3(hidden, edge_index)


# Load the model with the GCN class
# The first checkpoint is used as a whole module object (it is moved to a
# device and receives a state dict below), i.e. it is a fully pickled model
# rather than a bare state dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which rejects such files, so the flag is set explicitly.
# NOTE(security): weights_only=False unpickles arbitrary objects; only load
# checkpoints you created yourself or otherwise trust.
model = torch.load(
    f'../data/graphs/{graph_num}/models/{model_name}.pt',
    map_location=device,
    weights_only=False,
)
model = model.to(device)

# Overwrite the parameters with the weights stored under `weight_prefix`.
model.load_state_dict(torch.load(f'../data/graphs/{graph_num}/models/{model_name}_{weight_prefix}.pt', map_location=device))


In [None]:
def stratified_split(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Attach boolean train/val/test node masks to ``data``, stratified by ``y > 0``.

    Nodes are divided into a positive group (``y > 0``) and a negative group
    (everything else). Each group is split independently with
    ``train_test_split`` using the module-level ``random_seed``: first
    ``train_ratio`` of the group goes to train, then the remainder is divided
    between validation and test in the proportion ``val_ratio : test_ratio``.

    Args:
        data: Graph object exposing ``y`` (one label per node, shaped ``(N,)``
            or ``(N, 1)``) and ``num_nodes``. It is modified in place.
        train_ratio: Fraction of each group assigned to the training mask.
        val_ratio: Relative weight of the validation part of the remainder.
        test_ratio: Relative weight of the test part of the remainder.

    Returns:
        The same ``data`` object with ``train_mask``, ``val_mask`` and
        ``test_mask`` set to 1-D boolean tensors of length ``num_nodes``.

    Raises:
        ValueError: If ``y`` does not hold exactly one label per node. Errors
            raised by ``train_test_split`` itself (for example when a group
            is too small to split) propagate unchanged.
    """
    # Flatten first so (N,) and (N, 1) labels behave identically. On an (N, 1)
    # tensor, nonzero() yields (row, col) pairs; the old .squeeze() kept that
    # (K, 2) shape, so column index 0 was used as a node id and node 0 ended
    # up in all three masks.
    labels = data.y.reshape(-1)
    if labels.numel() != data.num_nodes:
        raise ValueError(
            f"expected one label per node, got {labels.numel()} labels "
            f"for {data.num_nodes} nodes"
        )
    positive_mask = labels > 0

    # as_tuple=True always returns 1-D index tensors, even for 0 or 1 matches.
    # sklearn receives plain NumPy arrays rather than tensors.
    positive_indices = positive_mask.nonzero(as_tuple=True)[0].cpu().numpy()
    negative_indices = (~positive_mask).nonzero(as_tuple=True)[0].cpu().numpy()

    # Share of the non-training remainder that goes to the test set.
    holdout_test_fraction = test_ratio / (val_ratio + test_ratio)

    def _split_group(indices):
        # Two-stage split of one group: train vs. remainder, then val vs. test.
        train_part, holdout = train_test_split(
            indices, train_size=train_ratio, random_state=random_seed
        )
        val_part, test_part = train_test_split(
            holdout, test_size=holdout_test_fraction, random_state=random_seed
        )
        return train_part, val_part, test_part

    pos_parts = _split_group(positive_indices)
    neg_parts = _split_group(negative_indices)

    # Build each mask on the same device as the labels.
    mask_device = labels.device
    for mask_name, pos_idx, neg_idx in zip(
        ("train_mask", "val_mask", "test_mask"), pos_parts, neg_parts
    ):
        node_ids = torch.cat([
            torch.as_tensor(pos_idx, dtype=torch.long, device=mask_device),
            torch.as_tensor(neg_idx, dtype=torch.long, device=mask_device),
        ])
        node_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=mask_device)
        node_mask[node_ids] = True
        setattr(data, mask_name, node_mask)

    return data

# Re-store contiguous versions of the three tensors, print their shapes as a
# quick sanity check, then attach the train/val/test masks.
data.edge_index, data.x, data.y = (
    data.edge_index.contiguous(),
    data.x.contiguous(),
    data.y.contiguous(),
)

print(data.x.shape, data.edge_index.shape, data.y.shape, flush = True)

data = stratified_split(data)


In [None]:
from torch_geometric.explain import GNNExplainer, Explainer

# Feature-attribution explainer for the loaded model: it learns a per-node,
# per-feature mask ('attributes') and no edge mask.
explainer = Explainer(
    model=model,
    # NOTE(review): epochs=1 gives the mask optimiser a single update step,
    # which looks like a smoke-test setting. Confirm whether more epochs are
    # intended before trusting the resulting importances.
    algorithm=GNNExplainer(epochs=1),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type=None,
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        # GCN.forward returns the output of its last GCNConv with no
        # softmax / log_softmax applied, so the model emits raw scores.
        # Declaring 'log_probs' here made the explainer treat those raw
        # scores as log-probabilities when computing its loss.
        return_type='raw',
    ),
)


In [None]:
# Select the nodes to explain: members of the validation split whose label is
# strictly positive. Both operands are squeezed before they are combined.
mask = torch.logical_and(data.val_mask.squeeze(), (data.y > 0).squeeze())

# 1-D tensor holding the indices of the selected nodes.
node_idx = mask.nonzero(as_tuple=True)[0]


In [None]:
# GCN.forward applies dropout whenever `model.training` is True, and loading a
# checkpoint does not reset that flag. Switch to eval mode so the predictions
# are deterministic, and skip autograd bookkeeping because only the argmax is
# used afterwards.
model.eval()
with torch.no_grad():
    out = model(data.x.to(device), data.edge_index.to(device))

# Predicted class per node = index of the largest output score.
pred = out.argmax(dim=1)


In [None]:
# used_feats: every feature index that appeared in some node's top 10.
# scores: feature index -> list of that feature's score, one entry per node
#         in whose top 10 it appeared.
used_feats = []
scores = {}

# The inputs are identical for every explained node, so move them to the
# device once instead of on every iteration.
x_on_device = data.x.to(device)
edge_index_on_device = data.edge_index.to(device)

for node in node_idx:
    explanation = explainer(x_on_device, edge_index_on_device, index=node)

    # Ensure node_mask is 2D
    node_mask = explanation.node_mask
    if node_mask.dim() == 1:
        node_mask = node_mask.unsqueeze(0)
    elif node_mask.dim() == 3:
        node_mask = node_mask.squeeze(0)

    curr_pred = pred[node].item()
    # Map the label value of this node to its bin index via `bins`.
    target = int(torch.bucketize(data.y[node].item(), bins))

    # Collapse the mask over its first dimension to get one score per feature.
    score = node_mask.sum(dim=0).detach().cpu().numpy().flatten()

    top10 = np.argsort(score)[::-1][:10]  # Get indices of top 10 features
    top10_score = score[top10]  # Get scores of top 10 features
    print(top10[top10_score > 0], top10_score[top10_score > 0], '\n PRedicted class:' ,curr_pred, 'actual class:', target, '\n---------------------------------------', flush = True)

    # A separate name is used here: the original inner loop reused `i`,
    # shadowing the node index of the outer loop.
    for feat in top10:
        feat = int(feat)  # plain int keys display cleanly in later cells
        used_feats.append(feat)
        scores.setdefault(feat, []).append(score[feat])


In [None]:
from collections import Counter
# Tally how often each entry of `used_feats` occurs and display up to the 100
# most frequent (entry, count) pairs, most frequent first.
Counter(used_feats).most_common(100)


In [None]:
# Replace every entry of `scores` with the arithmetic mean of its value.
scores = {feat: np.mean(values) for feat, values in scores.items()}


In [None]:
# Rebuild `scores` with its entries ordered by value, largest first.
scores = {
    feat: value
    for feat, value in sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
}


In [None]:
# Display the current contents of `scores` as this cell's output.
scores
