In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv


# --- Run configuration: which saved model / graph to evaluate ---
model_name = 'old-womprat-13'  # Replace with your model name
# model_name = 'holographic-master-15'  # Replace with your model name
graph_num = 17
weights_prefix = 'best_loss'
random_seed = 100
dropout_p = 0.5
epochs = 100

# --- Device: use the first CUDA device when one is available, otherwise the CPU ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    print("Using CPU", flush = True)

# Bin boundaries, parsed from the space-separated string and placed on `device`.
bins = torch.tensor(
    list(map(int, "400 800 1300 2100 3000 3700 4700 7020 9660".split(' '))),
    device=device,
)

### load graph data
# CAUTION: pickle.load can execute arbitrary code stored in the file, so only
# open graph files that you produced yourself or otherwise trust.
graph_file = f'../data/graphs/{graph_num}/linegraph_tg.pkl'
with open(graph_file, 'rb') as f:
    data = pickle.load(f)

def _split_three_way(indices, train_ratio, val_ratio, test_ratio, seed):
    """Shuffle-split a 1-D index tensor into (train, val, test) long tensors."""
    # Hand scikit-learn a NumPy array rather than a torch tensor; the shuffle
    # only depends on the number of samples and the seed.
    idx = indices.detach().cpu().numpy()
    train_idx, rest_idx = train_test_split(idx, train_size=train_ratio, random_state=seed)
    # The remainder is divided between val and test according to their relative sizes.
    val_idx, test_idx = train_test_split(
        rest_idx,
        test_size=(test_ratio / (val_ratio + test_ratio)),
        random_state=seed,
    )
    return tuple(
        torch.as_tensor(part, dtype=torch.long)
        for part in (train_idx, val_idx, test_idx)
    )


def stratified_split(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=None):
    """Attach train/val/test boolean node masks to `data`, stratified by `y > 0`.

    Nodes with `y > 0` and nodes with `y <= 0` are split separately with the
    same ratios and then merged, so both groups are represented in each split.

    Args:
        data: graph object exposing `y` and `num_nodes`; it is modified in place.
        train_ratio: fraction of each group assigned to the training mask.
        val_ratio, test_ratio: only their relative size matters; the nodes left
            after the training split are divided between val and test in the
            proportion `val_ratio : test_ratio`.
        seed: random state for the shuffles. Defaults to the notebook-level
            `random_seed` when omitted.

    Returns:
        The same `data` object with `train_mask`, `val_mask` and `test_mask`
        set to 1-D boolean tensors of length `data.num_nodes`.

    Raises:
        ValueError: if `y` does not hold exactly one value per node.
    """
    if seed is None:
        seed = random_seed

    # Flatten so a column-vector `y` of shape (N, 1) yields plain node indices.
    # Without this, nonzero() returns (row, col) pairs and the all-zero column
    # would mark node 0 in every mask.
    positive_mask = (data.y > 0).reshape(-1)
    if positive_mask.numel() != data.num_nodes:
        raise ValueError(
            f"expected one label per node, got y with shape {tuple(data.y.shape)} "
            f"for {data.num_nodes} nodes"
        )

    # reshape(-1) (not squeeze) keeps a 1-D tensor even when a group has one member.
    positive_indices = positive_mask.nonzero(as_tuple=False).reshape(-1)
    negative_indices = (~positive_mask).nonzero(as_tuple=False).reshape(-1)

    pos_parts = _split_three_way(positive_indices, train_ratio, val_ratio, test_ratio, seed)
    neg_parts = _split_three_way(negative_indices, train_ratio, val_ratio, test_ratio, seed)

    # Merge the per-group indices and turn each split into a boolean node mask.
    mask_names = ('train_mask', 'val_mask', 'test_mask')
    for mask_name, pos_idx, neg_idx in zip(mask_names, pos_parts, neg_parts):
        node_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        node_mask[torch.cat([pos_idx, neg_idx])] = True
        setattr(data, mask_name, node_mask)

    return data

# Make the three tensors used below contiguous in memory.
data.x, data.edge_index, data.y = (
    data.x.contiguous(),
    data.edge_index.contiguous(),
    data.y.contiguous(),
)

# Quick shape check: features, connectivity, labels.
print(data.x.shape, data.edge_index.shape, data.y.shape, flush = True)

# Attach the train/val/test node masks.
data = stratified_split(data)


In [None]:
# class GCN(torch.nn.Module):
#     def __init__(self, hidden_channels):
#         super().__init__()
#         torch.manual_seed(random_seed)
#         self.conv1 = GCNConv(data.num_features, hidden_channels, improved = True, cached = True)
#         conv2_list = []
#         hc = hidden_channels
#         # for _ in range(num_layers):
#         #     conv2_list.append(
#         #         GCNConv(hc, hc)
#         #     )
#             # hc //= 2
#         # self.conv2 = torch.nn.ModuleList(conv2_list)
#         self.conv3 = GCNConv(hc, len(bins) + 1, cached = True)

#     def forward(self, x, edge_index):
#         x = self.conv1(x, edge_index)
#         x = F.relu(x)
#         x = F.dropout(x, p=dropout_p, training=self.training)
#         # for conv in self.conv2:
#         #     x = conv(x, edge_index)
#         #     x = F.relu(x)
#         #     x = F.dropout(x, p=dropout_p, training=self.training)
#         x = self.conv3(x, edge_index)
#         return x


In [None]:
class GCN(torch.nn.Module):
    """GCNConv stack: an input layer, `num_layers` hidden layers, an output layer.

    The input width is read from the notebook-level `data.num_features` and the
    output layer has `len(bins) + 1` channels. Every layer except the last is
    followed by ReLU and dropout with probability `dropout_p`.
    """

    def __init__(self, hidden_channels, num_layers):
        super().__init__()
        # Seed before any layer is built so parameter initialisation is repeatable.
        torch.manual_seed(random_seed)

        n_outputs = len(bins) + 1

        self.input_layer = GCNConv(data.num_features, hidden_channels, improved=True, cached=True)

        # Optional intermediate layers; all keep the hidden width.
        self.hidden_layers = torch.nn.ModuleList(
            [
                GCNConv(hidden_channels, hidden_channels, improved=True, cached=True)
                for _ in range(num_layers)
            ]
        )

        self.output_layer = GCNConv(hidden_channels, n_outputs, cached=True)

    def forward(self, x, edge_index):
        """Return the output-layer activations for every node."""
        # Input and hidden layers share the same conv -> ReLU -> dropout pattern.
        for conv in (self.input_layer, *self.hidden_layers):
            x = conv(x, edge_index)
            x = F.dropout(F.relu(x), p=dropout_p, training=self.training)

        return self.output_layer(x, edge_index)


In [None]:

model_path = f'../data/graphs/{graph_num}/models/{model_name}.pt'
weights_path = f'../data/graphs/{graph_num}/models/{model_name}_{weights_prefix}.pt'

# The first file is loaded as a complete module (load_state_dict is called on
# the result), so it must be fully unpickled. Recent PyTorch releases default
# to weights_only=True, which rejects pickled modules; request the full load
# explicitly. This runs pickle under the hood, so only load model files you trust.
# NOTE(review): the weights_only keyword needs a PyTorch version that supports it.
model = torch.load(model_path, map_location=device, weights_only=False)
### load weights onto model
model.load_state_dict(torch.load(weights_path, map_location=device))
model = model.to(device)
model.eval()


In [None]:
# `learning_rate` was used here without being defined in any visible cell,
# which raises NameError on a fresh kernel.
# NOTE(review): 1e-3 is a placeholder; no visible cell steps the optimizer or
# scheduler, so it does not influence the evaluation below. Replace it with the
# value used for training if this notebook is extended to train.
learning_rate = 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)
# The criterion is the only object from this cell that the evaluation uses.
criterion = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=epochs//100)


In [None]:

@torch.no_grad()
def test(split='val'):
    """Evaluate the notebook-level `model` on one split of `data`.

    Only nodes that are in the chosen split and have `y > 0` are scored. Their
    labels are converted to class ids with `torch.bucketize(y, bins)`.

    Args:
        split: prefix of the mask attribute to use; `data.<split>_mask` must
            exist (for example 'val' or 'test'). Defaults to 'val'.

    Returns:
        (accuracy, out, loss): the fraction of scored nodes whose arg-max
        prediction equals the target, the model output for all nodes, and the
        criterion value on the scored nodes.

    Raises:
        ValueError: if no node in the split has `y > 0`.
    """
    model.eval()

    # Move everything the forward pass and the scoring need onto `device`.
    data.x = data.x.to(device)
    data.edge_index = data.edge_index.to(device)
    data.y = data.y.to(device)
    mask_attr = f'{split}_mask'
    setattr(data, mask_attr, getattr(data, mask_attr).to(device))

    # reshape(-1) instead of squeeze(): stays 1-D even for a single element.
    mask = getattr(data, mask_attr).reshape(-1) & (data.y > 0).reshape(-1)
    n_scored = int(mask.sum().item())
    if n_scored == 0:
        raise ValueError(f"no nodes with y > 0 in the '{split}' split")

    out = model(data.x, data.edge_index)
    target = torch.bucketize(data.y[mask], bins).reshape(-1)
    loss = criterion(out[mask], target.long())  # target must be 1-D and long

    predicted = out[mask].argmax(dim=1)
    accuracy = (predicted == target).sum().item() / n_scored
    return accuracy, out, loss


In [None]:

# def test():
#     model.eval()
#     data.x = data.x.to(device)
#     data.edge_index = data.edge_index.to(device)
#     data.test_mask = data.test_mask.to(device)
#     data.y = data.y.to(device)
#     mask = data.test_mask.squeeze() & (data.y > 0).squeeze()
#     out = model(data.x, data.edge_index)
#     target = torch.bucketize(data.y[mask], bins).squeeze()
#     loss = criterion(out[mask], target.long())  # Ensure target is 1D and long
#     correct_preds = out[mask].argmax(dim=1)
#     correct = (correct_preds == target).sum()
#     accuracy = correct.item() / mask.sum().item()
#     return accuracy, out, loss


In [None]:
# Run the evaluation once and keep the accuracy, the model output and the loss.
acc, out, loss = test()


In [None]:
# Validation nodes whose label is strictly positive.
mask = (data.y > 0).squeeze() & data.val_mask.squeeze()
# Bin index of each selected label.
target = torch.bucketize(data.y[mask], bins).squeeze()

# Arg-max over dim 1 of the model output for the selected nodes, as a NumPy array.
pred = out[mask].argmax(dim=1).cpu().numpy()


In [None]:
# Copy the targets to the host in one transfer; iterating the tensor directly
# would format (and synchronise) one 0-dim element per row.
target_values = target.cpu().tolist()

# One "predicted,target" row per scored node, after a header line.
with open('predictions.csv', 'w') as f:
    f.write('predicted,target\n')
    f.writelines(f'{p},{t}\n' for p, t in zip(pred, target_values))
print(f"Accuracy: {acc:.4f}", flush = True)
