In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
GCNConv._orig_propagate = GCNConv.propagate

import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch_geometric.explain import GNNExplainer, Explainer

# --- Run configuration (every value can be overridden by an environment variable) ---
epochs = int(os.getenv("EPOCHS", 5000))
learning_rate = float(os.getenv("LEARNING_RATE", 0.0001))
hidden_c = int(os.getenv("HIDDEN_C", 350))
random_seed = int(os.getenv("RANDOM_SEED", 100))
# Space-separated integer bin edges. split() (rather than split(' ')) tolerates repeated or
# trailing whitespace in the environment value instead of failing on int('').
bins = [int(i) for i in os.getenv("BINS", "400 800 1300 2100 3000 3700 4700 7020 9660").split()]
num_layers = int(os.getenv("NUM_LAYERS", 0))  # number of extra hidden conv layers in GCN/GAT
nh = int(os.getenv("NUM_HEADS", 1))  # presumably the GAT head count — TODO confirm at call site
gat = int(os.getenv("GAT", 1))  # presumably 1 selects GAT, 0 selects GCN — TODO confirm
api_key = os.getenv("API_KEY", None)
# NOTE(review): this is the int 26 when GRAPH_NUM is unset but a str when it is set;
# in this cell it is only interpolated into the data path, where both behave the same.
graph_num = os.getenv("GRAPH_NUM", 26)
dropout_p = float(os.getenv("DROPOUT", 0.5))

# Pick the compute device once and reuse it for everything below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    print("Using CPU", flush = True)

# Bin edges as an integer tensor on the selected device.
bins = torch.tensor(bins, device=device)

# SECURITY: pickle.load can execute arbitrary code — only load graph files from a trusted source.
with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

# Force a contiguous memory layout for the core graph tensors.
data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()


# --- Model Definitions ---
class GCN(torch.nn.Module):
    """Graph convolutional network for per-node classification over the bins.

    Layout: one input GCNConv (``data.num_features`` -> ``hidden_channels``), ``num_layers``
    hidden GCNConvs (``hidden_channels`` -> ``hidden_channels``), and one output GCNConv
    producing ``len(bins) + 1`` raw scores per node. Every layer except the output one is
    followed by ReLU and dropout with probability ``dropout_p``.

    Reads the module-level globals ``data``, ``bins``, ``random_seed`` and ``dropout_p``.
    """

    def __init__(self, hidden_channels, num_layers):
        super().__init__()
        # Re-seed the global torch RNG so that weight initialisation is reproducible.
        torch.manual_seed(random_seed)

        # cached=True makes each GCNConv cache its normalised adjacency after the first call,
        # which is only valid while the same edge_index is passed on every forward.
        self.input_layer = GCNConv(data.num_features, hidden_channels, improved=True, cached=True)
        self.hidden_layers = torch.nn.ModuleList(
            [GCNConv(hidden_channels, hidden_channels, improved=True, cached=True) for _ in range(num_layers)]
        )
        n_outputs = len(bins) + 1
        self.output_layer = GCNConv(hidden_channels, n_outputs, cached=True)

    def forward(self, x, edge_index):
        """Return raw (un-normalised) per-node scores of width ``len(bins) + 1``."""
        h = x
        for conv in [self.input_layer, *self.hidden_layers]:
            h = F.dropout(F.relu(conv(h, edge_index)), p=dropout_p, training=self.training)
        return self.output_layer(h, edge_index)

class GAT(torch.nn.Module):
    """Graph attention network built from a stack of GATConv layers.

    All layers live in ``self.convs`` in this order:
      * input:  ``data.num_features`` -> ``hidden_channels`` per head, ``num_heads`` heads concatenated;
      * hidden: ``num_layers`` layers, ``hidden_channels * num_heads`` -> ``hidden_channels`` per head,
        heads concatenated;
      * output: a single head producing ``len(bins) + 1`` outputs (one score per bin class),
        or 1 output when ``bins`` is the string ``'regression'``.

    Every layer except the last is followed by ELU and dropout with probability ``dropout_p``.
    Reads the module-level globals ``data``, ``bins``, ``random_seed`` and ``dropout_p``.
    """

    def __init__(self, hidden_channels, num_layers, num_heads):
        super().__init__()
        # Re-seed the global torch RNG so that weight initialisation is reproducible.
        torch.manual_seed(random_seed)

        # Width of a concatenated multi-head output, i.e. the input width of every later layer.
        concat_width = hidden_channels * num_heads

        # `bins` is normally a tensor of bin edges. Test for the 'regression' sentinel by type:
        # `tensor != 'regression'` is not a value comparison, so it must not decide the head.
        is_regression = isinstance(bins, str) and bins == 'regression'
        out_channels = 1 if is_regression else len(bins) + 1

        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(data.num_features, hidden_channels, heads=num_heads, concat=True))
        for _ in range(num_layers):
            self.convs.append(GATConv(concat_width, hidden_channels, heads=num_heads, concat=True))
        self.convs.append(GATConv(concat_width, out_channels, heads=1, concat=False))

    def forward(self, x, edge_index):
        """Return raw per-node outputs from the final GATConv (no activation applied)."""
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, p=dropout_p, training=self.training)
        return self.convs[-1](x, edge_index)

In [None]:
def _split_index_group(indices, train_ratio, val_ratio, test_ratio):
    """Split one 1-D array of node indices into (train, val, test) index arrays.

    ``train_ratio`` of the indices go to training; the remainder is divided between
    validation and test in the proportion ``val_ratio : test_ratio``. An empty group yields
    three empty arrays instead of making ``train_test_split`` raise. Very small non-empty
    groups (a handful of nodes) can still make ``train_test_split`` raise.
    """
    if indices.size == 0:
        return indices, indices, indices
    train_idx, temp_idx = train_test_split(indices, train_size=train_ratio, random_state=random_seed)
    val_idx, test_idx = train_test_split(
        temp_idx, test_size=(test_ratio / (val_ratio + test_ratio)), random_state=random_seed
    )
    return train_idx, val_idx, test_idx


def stratified_split(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Attach train/val/test boolean node masks to ``data``, stratified on ``data.y > 0``.

    Nodes with a positive target and nodes with a non-positive target are split
    independently, so every split keeps roughly the same positive/non-positive mix.
    Splits are deterministic through the module-level ``random_seed``.

    Args:
        data: graph object exposing ``y`` and ``num_nodes``; it is modified in place.
            Assumes ``data.y`` is 1-D (one target per node) — TODO confirm.
        train_ratio: fraction of each group assigned to training.
        val_ratio, test_ratio: relative sizes of validation vs. test within the remainder
            (only their ratio matters).

    Returns:
        The same ``data`` object with ``train_mask``, ``val_mask`` and ``test_mask`` set
        (CPU bool tensors of length ``data.num_nodes``).
    """
    positive_mask = data.y > 0

    # as_tuple=True(...)[0] is always 1-D; `.squeeze()` on the (k, 1) result would collapse a
    # single-index group to a 0-d array that train_test_split rejects.
    positive_indices = positive_mask.nonzero(as_tuple=True)[0].cpu().numpy()
    negative_indices = (~positive_mask).nonzero(as_tuple=True)[0].cpu().numpy()

    pos_train_idx, pos_val_idx, pos_test_idx = _split_index_group(
        positive_indices, train_ratio, val_ratio, test_ratio
    )
    neg_train_idx, neg_val_idx, neg_test_idx = _split_index_group(
        negative_indices, train_ratio, val_ratio, test_ratio
    )

    # Merge the two groups back into torch index tensors.
    train_idx = torch.from_numpy(np.concatenate([pos_train_idx, neg_train_idx])).long()
    val_idx = torch.from_numpy(np.concatenate([pos_val_idx, neg_val_idx])).long()
    test_idx = torch.from_numpy(np.concatenate([pos_test_idx, neg_test_idx])).long()

    data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)

    data.train_mask[train_idx] = True
    data.val_mask[val_idx] = True
    data.test_mask[test_idx] = True

    return data


# The core tensors were already made contiguous right after loading, so that is not repeated here.
print(data.x.shape, data.edge_index.shape, data.y.shape, flush = True)

data = stratified_split(data)

# Sanity check: every node must land in exactly one of train / val / test.
assert bool(((data.train_mask.int() + data.val_mask.int() + data.test_mask.int()) == 1).all()), \
    "train/val/test masks must partition the nodes"
