# Classification of Cora dataset with Noderformer. Trainformer is training with drop in loss. Complete with testing loop ✅ ✅ ✅ 

In [1]:

import sys
import os, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import  remove_self_loops, add_self_loops
from sklearn.neighbors import kneighbors_graph
from utils  import *
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.data import Data
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import dropout_adj, to_undirected, to_networkx
import time
from ortools.sat.python import cp_model
import pandas as pd


import warnings
warnings.filterwarnings('ignore')


# NOTE: for consistent data splits, see data_utils.rand_train_test_idx
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


fix_seed(42)

device = torch.device("cpu")

# Function to count violations

In [2]:
def count_same_color_edges(graph, colors, num_classes):
    count = 0
    if colors != 'no_solution':
        for u, v in graph.edges():
            if colors[u] >= num_classes - 1:
                count += 1
            elif colors[v] >= num_classes - 1:
                count += 1
            elif colors[u] == colors[v]:
                count += 1
    else:
        count = np.nan
    return count

# Function to make graph using ORTools

In [3]:
def color_graph_with_ortools(G, chromatic_number):
    # Create the model
    model = cp_model.CpModel()

    # Create variables
    # For graph coloring, each node is a variable and the colors are their domains
    color_vars = {
        node: model.NewIntVar(0, chromatic_number - 1, f'node_{node}')
        for node in G.nodes
    }

    # Create constraints
    # Adjacent nodes must have different colors
    for edge in G.edges:
        model.Add(color_vars[edge[0]] != color_vars[edge[1]])

    # Create a solver and solve the model
    solver = cp_model.CpSolver()
    status = solver.Solve(model)

    if status == cp_model.OPTIMAL or status == cp_model.FEASIBLE:
        # If a solution exists, extract the colors
        solution = {node: solver.Value(var) for node, var in color_vars.items()}
 
        return solution
    else:
        # If the problem could not be solved, raise an error
        solution =  'no_solution'
        return solution
    


# Function to color Graphs using Greedy method

In [4]:
def color_graph(G, k):
    # Initialize the result dictionary with all vertices uncolored initially
    coloring = {node: None for node in G.nodes()}

    # Go through each node and assign the first available color
    for node in G.nodes():
        neighbor_colors = {coloring[neighbor] for neighbor in G.neighbors(node)}
        available_colors = set(range(len(G))) - neighbor_colors

        # Assign the first available color
        for color in range(len(G)):
            if color in available_colors:
                coloring[node] = color
                break

    return coloring

# Function to color graphs using DSATUR

In [5]:
def dsatur(graph):
    # Initialize color assignment and saturation degrees
    colors = {node: None for node in graph.nodes()}
    saturation_degrees = {node: 0 for node in graph.nodes()}

    # Utility function to update saturation degrees
    def update_saturation(node):
        adjacent_colors = set(colors[neighbor] for neighbor in graph.neighbors(node) if colors[neighbor] is not None)
        saturation_degrees[node] = len(adjacent_colors)

    # Color assignment loop
    while None in colors.values():
        # Select the uncolored node with the highest saturation degree, breaking ties by degree
        node_to_color = max(
            (node for node in graph.nodes() if colors[node] is None),
            key=lambda n: (saturation_degrees[n], len(list(graph.neighbors(n))))
        )

        # Find the first available color
        used_colors = set(colors[neighbor] for neighbor in graph.neighbors(node_to_color) if colors[neighbor] is not None)
        color = 0
        while color in used_colors:
            color += 1

        colors[node_to_color] = color

        # Update saturation degrees of adjacent nodes
        for neighbor in graph.neighbors(node_to_color):
            if colors[neighbor] is None:
                update_saturation(neighbor)

    return colors

# Decoder

In [6]:
def decoder(x_final):

    max_args = torch.argmax(x_final, dim = 1)
    gnn_sol = {index: value.item() for index, value in enumerate(max_args)}

    return gnn_sol

# Loading and splitting data

In [7]:
def read_graph(filepath):
    with open(filepath,"r") as f:

        line = ''

        # Parse number of vertices
        while 'DIMENSION' not in line: line = f.readline();
        n = int(line.split()[1])
        
        Ma = np.zeros((n,n),dtype=int)
        
        # Parse edges
        while 'EDGE_DATA_SECTION' not in line: line = f.readline();
        line = f.readline()
        while '-1' not in line:
            i,j = [ int(x) for x in line.split() ]
            Ma[i,j] = 1
            line = f.readline()
        #end while

        # Parse diff edge
        while 'DIFF_EDGE' not in line: line = f.readline();
        diff_edge = [ int(x) for x in f.readline().split() ]

        # Parse target cost
        while 'CHROM_NUMBER' not in line: line = f.readline();
        chrom_number = int(f.readline().strip())

    #end
    return Ma,chrom_number,diff_edge


def get_data(num_classes, directory):
    Mas = []
    chrom_numbers = []
    diff_edges = []
    torch_dats = []
    nx_graphs = []

    for root, dirs, files in os.walk(directory):
        for filename in files:
            filepath = os.path.join(root, filename)
            Ma,chrom_number,diff_edge = read_graph(filepath)

            adj_matrix = np.array(Ma)  # Replace with your n*n adjacency matrix

            # Convert adjacency matrix to edge list
            edge_index = np.array(adj_matrix.nonzero(), dtype=np.int64)
            edge_index = torch.tensor(edge_index, dtype=torch.long)

            n = Ma.shape[0]
            train_mask = torch.cat((torch.ones(n // 3, dtype=torch.bool), torch.zeros(2 * n // 3, dtype=torch.bool)))
            val_mask = torch.cat((torch.zeros(n // 3, dtype=torch.bool), torch.ones(n // 3, dtype=torch.bool), torch.zeros(n // 3, dtype=torch.bool)))
            test_mask = torch.cat((torch.zeros(2 * n // 3, dtype=torch.bool), torch.ones(n // 3, dtype=torch.bool)))


            # Create dummy node features (e.g., one-hot encoding)
            num_nodes = adj_matrix.shape[0]
            node_features = torch.ones(num_nodes, num_classes)  

            # Create PyTorch Geometric data object
            data = Data(x=node_features, edge_index=edge_index)
            data.y = torch.ones(n)
            data.train_mask = train_mask
            data.val_mask = val_mask
            data.test_mask = test_mask
            

            # Creating networkx graphs
            nx_graph = nx.from_numpy_array(adj_matrix)


            Mas.append(Ma)
            chrom_numbers.append(chrom_number)
            diff_edges.append(diff_edge)
            torch_dats.append(data)
            nx_graphs.append(nx_graph)

    chrom_list = []
    for data,k, nx_graph in zip(torch_dats, chrom_numbers, nx_graphs):
        chrom_list.append((data, k, nx_graph))

    k_list = [item for item in chrom_list if item[1] == num_classes]

    # Keep only the first 250 elements of the filtered list
    k_list = k_list[:170]

    
    data = [item[0] for item in k_list]

    return data


In [8]:
class NCDataset(object):
    def __init__(self, num_nodes, num_classes):

        self.name = str(num_nodes) + '_' +str(num_classes)  # original name, e.g., ogbn-proteins
        self.graph = {}
        self.label = None

    def get_idx_split(self,  train_prop=.5, valid_prop=.25):
        """
        split_type: 'random' for random splitting, 'class' for splitting with equal node num per class
        train_prop: The proportion of dataset for train split. Between 0 and 1.
        valid_prop: The proportion of dataset for validation split. Between 0 and 1.
        label_num_per_class: num of nodes per class
        """

        ignore_negative = False if self.name == 'ogbn-proteins' else True
        train_idx, valid_idx, test_idx = rand_train_test_idx(
            self.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative)
        split_idx = {'train': train_idx,
                        'valid': valid_idx,
                        'test': test_idx}

        return split_idx

    def __getitem__(self, idx):
        assert idx == 0, 'This dataset has only one graph'
        return self.graph, self.label

    def __len__(self):
        return 1

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, len(self))
    
def load_dataset( num_nodes, num_classes, ind):

    if num_nodes == 50:
        directory = 'chromatic_graphs/n50'
    elif num_nodes == 60:
        directory = 'chromatic_graphs/n60'
    elif num_nodes == 40:
        directory = 'chromatic_graphs/n40'
    elif num_nodes == 30:
        directory = 'chromatic_graphs/n30'

    data = get_data(num_classes, directory)
    data = data[ind]


    edge_index = data.edge_index
    node_feat = data.x
    label = data.y
    num_nodes = data.num_nodes

    dataset = NCDataset(num_nodes, num_classes)


    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': num_nodes}
    dataset.label = label

    return dataset, data
    

# Making the model

In [9]:
class NodeFormer(nn.Module):

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, num_heads=4, dropout=0.9,
                 kernel_transformation=softmax_kernel_transformation, nb_random_features=30, use_bn=True, use_gumbel=True,
                 use_residual=True, use_act=False, use_jk=False, nb_gumbel_sample=10, rb_order=0, rb_trans='sigmoid', use_edge_loss=True):
        super(NodeFormer, self).__init__()

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(in_channels, hidden_channels))
        self.bns = nn.ModuleList()
        self.bns.append(nn.LayerNorm(hidden_channels))
        for i in range(num_layers):
            self.convs.append(
                NodeFormerConv(hidden_channels, hidden_channels, num_heads=num_heads, kernel_transformation=kernel_transformation,
                              nb_random_features=nb_random_features, use_gumbel=use_gumbel, nb_gumbel_sample=nb_gumbel_sample,
                               rb_order=rb_order, rb_trans=rb_trans, use_edge_loss=use_edge_loss))
            self.bns.append(nn.LayerNorm(hidden_channels))

        if use_jk:
            self.fcs.append(nn.Linear(hidden_channels * num_layers + hidden_channels, out_channels))
        else:
            self.fcs.append(nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout
        self.activation = F.relu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act
        self.use_jk = use_jk
        self.use_edge_loss = use_edge_loss

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()



    def forward(self, x, edge_index, adjs, tau=1.0):

        row, col = edge_index

        x = x.unsqueeze(0) # [B, N, H, D], B=1 denotes number of graph
        layer_ = []
        link_loss_ = []
  
        z = self.fcs[0](x)     
        if self.use_bn:
            z = self.bns[0](z)
        z = self.activation(z)
        z = F.dropout(z, p=self.dropout, training=self.training)

        layer_.append(z)

        for i, conv in enumerate(self.convs):
            if self.use_edge_loss:
                z, link_loss = conv(z, adjs, tau)
                link_loss_.append(link_loss)
            else:
                z = conv(z, adjs, tau)
            if self.use_residual:
                z += layer_[i]
            if self.use_bn:
                z = self.bns[i+1](z)
            if self.use_act:
                z = self.activation(z)
            z = F.dropout(z, p=self.dropout, training=self.training)
            layer_.append(z)

        if self.use_jk: # use jk connection for each layer
            z = torch.cat(layer_, dim=-1)
             
        x_out = self.fcs[-1](z).squeeze(0)       
        x_softmax = F.softmax(100*x_out, dim=1)

        # Vectorized indexing
        pi = x_softmax[col.long()]  # Assuming no_loop_col contains integer values
        pj = x_softmax[row.long()]  # Assuming no_loop_row contains integer values  
        prod = torch.mul(pi, pj)
        loss = torch.sum(prod)

        return x_out, loss


In [10]:
### Load method ###
hidden_channels = 32
num_layers = 2
dropout = 0.9
num_heads = 64
use_bn = True
M = 30
use_gumbel = True
use_residual = True
use_act = True
use_jk = True
K = 10
rb_order = 0
rb_trans = 'sigmoid'

lr = 0.005
num_epochs = 350
num_nodes = 60 # 30, 40, 50, 60
num_classes = 4 # 4 or 5


In [11]:
num_nodes_list  = []
num_classes_list = []
indices = []
gnn_violations = []
cp_violations = []
greedy_violations = []
dsatur_violations = []

for num_nodes in [30, 40, 50, 60]:
    for num_classes in [4, 5]:
        for i in range(25):

            dataset, raw_data = load_dataset(num_nodes, num_classes, i)

            ### Basic information of datasets ###
            n = dataset.graph['num_nodes']
            e = dataset.graph['edge_index'].shape[1]
            # infer the number of classes for non one-hot and one-hot labels
            c = num_classes
            d = dataset.graph['node_feat'].shape[1]

            dataset.graph['edge_index'], dataset.graph['node_feat'] = \
                dataset.graph['edge_index'].to(device), dataset.graph['node_feat'].to(device)


            model=NodeFormer(d, hidden_channels, c, num_layers=num_layers, dropout=dropout,
                        num_heads=num_heads, use_bn=use_bn, nb_random_features=M,
                        use_gumbel=use_gumbel, use_residual=use_residual, use_act=use_act, use_jk=use_jk,
                        nb_gumbel_sample=K, rb_order=rb_order, rb_trans=rb_trans).to(device)

            model.train()

            ### Adj storage for relational bias ###
            adjs = []
            adj, _ = remove_self_loops(dataset.graph['edge_index'])
            adj, _ = add_self_loops(adj, num_nodes=n)
            adjs.append(adj)
            for i in range(0 - 1): # edge_index of high order adjacency
                adj = adj_mul(adj, adj, n)
                adjs.append(adj)
            dataset.graph['adjs'] = adjs

            model.reset_parameters()
            optimizer = torch.optim.Adam(model.parameters(),weight_decay=5e-3, lr = lr)

            losses = []
            node_feats = []

            for epoch in range(num_epochs):
                model.train()
                optimizer.zero_grad()

                out, loss = model(dataset.graph['node_feat'],dataset.graph['edge_index'] ,dataset.graph['adjs'], 0.25)

                losses.append(loss.item())

                loss.backward()
                optimizer.step()

                #print('epoch ', epoch, ' loss ', loss.item())

            graph = to_networkx(raw_data, to_undirected= True)
            x_final, loss = model(dataset.graph['node_feat'],dataset.graph['edge_index'] ,dataset.graph['adjs'], 0.25)
            colors = decoder(x_final)
            same_color_edge_count = count_same_color_edges(graph, colors, num_classes)
            edges = graph.number_of_edges()
            #print(same_color_edge_count, ' out of ', edges)
            gnn_violation = 100*same_color_edge_count/edges

            ortools_colors = color_graph_with_ortools(graph, num_classes)
            same_color_edge_count = count_same_color_edges(graph, ortools_colors, num_classes)
            edges = graph.number_of_edges()
            #print(same_color_edge_count, ' out of ', edges)
            cp_violation = 100*same_color_edge_count/edges

            greedy_colors = color_graph(graph, num_classes)
            same_color_edge_count = count_same_color_edges(graph, greedy_colors, num_classes)
            edges = graph.number_of_edges()
            #print(same_color_edge_count, ' out of ', edges)
            greedy_violation = 100*same_color_edge_count/edges

            dsatur_colors = dsatur(graph)
            same_color_edge_count = count_same_color_edges(graph, dsatur_colors, num_classes)
            edges = graph.number_of_edges()
            #print(same_color_edge_count, ' out of ', edges)
            dsatur_violation = 100*same_color_edge_count/edges

            num_nodes_list.append(num_nodes)
            num_classes_list.append(num_classes)
            indices.append(i)
            gnn_violations.append(gnn_violation)
            cp_violations.append(cp_violation)
            greedy_violations.append(greedy_violation)
            dsatur_violations.append(dsatur_violation)

            print('num nodes ', num_nodes,
                ' num classes ', num_classes, 
                ' index ', i,
                ' gnn violation ', round(gnn_violation,2),
                'cp violation ', round(cp_violation,2),
                'greedy violation ', round(greedy_violation,2),
                'dsatur violation ', round(dsatur_violation,2))


out_df = pd.DataFrame({'Num_nodes' : num_nodes_list,
                       'Num_classes' : num_classes_list,
                       'i': indices,
                       'GNN' : gnn_violations, 
                       'CP' : cp_violations,
                       'Greedy' : greedy_violations,
                       'DSATUR' : dsatur_violations
                        })

file_path = 'results.csv'

# Check if file is empty (or if it exists)
if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
    header = False  # Don't write header if file already exists and is not empty
else:
    header = True  # Write header if file is empty or doesn't exist

# Append to CSV
out_df.to_csv(file_path, mode='a', header=header, index=False)


num nodes  30  num classes  4  index  0  gnn violation  55.56 cp violation  37.78 greedy violation  52.22 dsatur violation  44.44
num nodes  30  num classes  4  index  1  gnn violation  64.0 cp violation  45.33 greedy violation  54.67 dsatur violation  46.67
num nodes  30  num classes  4  index  2  gnn violation  59.38 cp violation  45.83 greedy violation  56.25 dsatur violation  59.38
num nodes  30  num classes  4  index  3  gnn violation  68.48 cp violation  54.35 greedy violation  41.3 dsatur violation  53.26
num nodes  30  num classes  4  index  4  gnn violation  61.18 cp violation  56.47 greedy violation  50.59 dsatur violation  47.06
num nodes  30  num classes  4  index  5  gnn violation  57.83 cp violation  43.37 greedy violation  60.24 dsatur violation  32.53
num nodes  30  num classes  4  index  6  gnn violation  65.17 cp violation  37.08 greedy violation  57.3 dsatur violation  46.07
num nodes  30  num classes  4  index  7  gnn violation  64.52 cp violation  47.31 greedy viol