In [8]:
import sys
import os
import pickle
import numpy as np

sys.path.append(os.path.abspath(os.path.join('..', 'src')))

In [9]:
from torch_geometric.utils import to_dense_adj
import torch.nn.functional as F
from utils import get_edge_index_and_theta
from model import FuzzyDirGCN
import torch

import importlib
import utils.data_loading

# Reload the entire module
importlib.reload(utils.data_loading)

# Now, re-import the specific function
from utils.data_loading import (
    get_classification_dataset,
    get_graph_ensemble_dataset,
)


In [10]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random`` module, NumPy's legacy global RNG and
    PyTorch. When CUDA is available it also seeds all CUDA devices, forces
    deterministic cuDNN kernels and disables cuDNN autotuning. Together these
    make repeated runs with the same seed reproducible, at a possible cost in
    speed.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: the notebook's import cell does not import `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU (a superset of cuda.manual_seed).
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [4]:
# Smoke-test the node-classification loader on Cora; `data` is reassigned by the next cell.
cora_bundle = get_classification_dataset('cora')
data, mask = cora_bundle

In [30]:
# Load the raw perturb-seq bundle. The `with` block guarantees the file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only unpickle files you trust.
with open('../datasets/perturb_seq/Replogle-gwps_training_data_with_three_splits.pkl', 'rb') as raw_fh:
    data = pickle.load(raw_fh)

In [32]:
# List the entries available in the first element of the loaded bundle.
first_entry = data[0]
first_entry.keys()

dict_keys(['training_data', 'val_data', 'test_data', 'training_data_count', 'val_data_count', 'test_data_count', 'graph_info'])

In [34]:
# Peek at the edge list of graph_info entry 3 (which graph index 3 denotes is not documented here -- TODO confirm).
graph_info = data[0]['graph_info']
graph_info[3]['edge_indices']

array([[   0,    0,    0, ..., 1992, 1993, 1995],
       [  21,  244,  473, ..., 1998, 1998, 1998]])

In [36]:
# Re-export a slimmer pickle holding only the `*_count` splits and the edge list of
# graph_info entry 3. The `with` block guarantees the data is flushed and the handle closed.
with open('../datasets/perturb_seq/Replogle-gwps.pkl', 'wb') as out_fh:
    pickle.dump(
        {'train_data': data[0]['training_data_count'],
         'val_data': data[0]['val_data_count'],
         'test_data': data[0]['test_data_count'],
         'edge_index': data[0]['graph_info'][3]['edge_indices']},
        out_fh)

In [134]:
# Directed lattice ensemble; the 4th return value `pe` is presumably the requested
# eigenvector encoding (pe_dim=10) -- confirm in utils.data_loading.
lattice_options = dict(undirected=False, pe_type='eigenvector', pe_dim=10)
train_loader, val_loader, test_loader, pe = get_graph_ensemble_dataset('lattice', **lattice_options)


In [143]:
# Perturb-seq ensemble: unpacked into three loaders here (the lattice call above unpacks four values).
perturb_seq_loaders = get_graph_ensemble_dataset('perturb_seq')
train_loader, val_loader, test_loader = perturb_seq_loaders


In [144]:
# Display the first training batch to check tensor shapes.
first_batch = next(iter(train_loader))
first_batch

DataBatch(x=[32000, 1], edge_index=[2, 94928], y=[32000, 1], mask=[32000], batch=[32000], ptr=[17])

In [156]:
# Root of the graph-ensemble datasets. Set the DATASETS_ROOT environment variable
# to point elsewhere instead of editing this machine-specific default.
DATASETS_ROOT = os.environ.get('DATASETS_ROOT', '/sahandlab/Team/directifying_graph/ICLR/datasets')
train_loader, val_loader, test_loader, meta_data = get_graph_ensemble_dataset(
    'power_grid', DATASETS_ROOT)

## Model loading

### 1. Table 1

In [11]:
# CoED setup on Wisconsin: load the graph, derive directed edges + per-edge theta.
device = torch.device('cuda:1')
# Reuse `device` instead of repeating the literal; str() yields 'cuda:1' as before.
data, (train_mask, val_mask, test_mask) = get_classification_dataset('wisconsin', device=str(device))

adj = to_dense_adj(data.edge_index)[0]
# Total weight on the diagonal (the self-loop count, assuming no duplicate self-loop edges).
print(adj.diagonal().sum())
# Uncomment to drop self-loops before deriving the edge index:
# adj.fill_diagonal_(0.0)

src_to_dst_edge, dst_to_src_edge, theta = get_edge_index_and_theta(adj)
theta = theta.float()
edge_index = src_to_dst_edge
# Edge weights are constant; allocate them on the target device once.
edge_weight = torch.ones(edge_index.shape[1], device=device)

num_nodes, num_edges = data.x.shape[0], edge_index.shape[1]

tensor(16., device='cuda:1')


In [12]:
def train(x, y, model, optimizer, edge_index, theta, edge_weight, mask, index=None):
    """Perform a single optimisation step on the nodes selected by ``mask``.

    Args:
        x: Node features passed to ``model``.
        y: Integer class labels.
        model: Module called as ``model(x, edge_index, theta, edge_weight)``.
            Its output goes straight into ``F.nll_loss``, so it is expected to
            return log-probabilities.
        optimizer: Optimizer over ``model``'s parameters.
        edge_index, theta, edge_weight: Graph inputs forwarded to ``model``.
        mask: Boolean node mask. When ``index`` is given, a 2-D mask whose
            column ``index`` selects the active split.
        index: Split column to use, or ``None`` to use ``mask`` as-is.

    Returns:
        The training loss as a Python float.
    """
    active = mask if index is None else mask[:, index]

    model.train()
    optimizer.zero_grad()
    log_probs = model(x, edge_index, theta, edge_weight)
    loss = F.nll_loss(log_probs[active], y[active])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test(x, y, model, optimizer, edge_index, theta, edge_weight, masks, index):
    model.eval()
    log_probs, accs = model(x, edge_index, theta, edge_weight), []
    #log_probs, accs = model(x, edge_index), []
    for mask in masks:
        if index is not None:
            pred = log_probs[mask[:, index]].max(1)[1]
            acc = pred.eq(y[mask[:, index]]).sum().item() / mask[:, index].sum().item()
        else:
            pred = log_probs[mask].max(1)[1]
            acc = pred.eq(y[mask]).sum().item() / mask.sum().item()            
        accs.append(acc)
    return accs

In [17]:
# Hyper-parameters for the CoED (FuzzyDirGCN) runs over the Wisconsin splits.
in_channels = data.x.shape[-1]
out_channels = data.y.max().item() + 1
hidden_channels = 128
num_layers = 2
lr = 2e-2
wd = 1e-3
dropout_rate = 0.5
alpha = 0.5
normalize = False
jumping_knowledge = None
self_loop = False
self_feature_transform = True

num_splits = 10    # mask columns evaluated (masks are indexed as mask[:, index])
num_epochs = 999   # epochs 1..999, identical to the previous range(1, 1000)
patience = 200     # stop once val acc has failed to improve for more than this many epochs
log_every = 50

seed = 42
set_seed(seed)
test_accs = []

# Graph inputs are constant across splits/epochs; move the edge weights once.
edge_weight_on_device = edge_weight.to(device)
eval_masks = (train_mask, val_mask, test_mask)

for index in range(num_splits):
    model = FuzzyDirGCN(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        num_layers=num_layers,
        num_nodes=num_nodes,
        num_edges=num_edges,
        alpha=alpha,
        normalize=normalize,
        self_feature_transform=self_feature_transform,
        self_loop=self_loop,
        layerwise_theta=False,
        regression=False,
        dropout_rate=dropout_rate,
        jumping_knowledge=jumping_knowledge).to(device)
    model.reset_parameters()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Model selection: keep the test accuracy from the epoch with the best val accuracy.
    best_val_acc = 0.0
    best_test_acc = 0.0
    n_non_decreasing_step = 0  # epochs since the last strict val-acc improvement

    for epoch in range(1, num_epochs + 1):
        tr_loss = train(
            data.x, data.y, model, optimizer,
            edge_index, theta, edge_weight_on_device,
            train_mask, index)
        train_acc, val_acc, test_acc = test(
            data.x, data.y, model, optimizer,
            edge_index, theta, edge_weight_on_device,
            eval_masks, index)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            n_non_decreasing_step = 0
        else:
            n_non_decreasing_step += 1

        if n_non_decreasing_step > patience:
            break

        if epoch % log_every == 0:
            print(f'index: {index} | '
                  f'Epoch: {epoch:03d}, Loss: {tr_loss:.5f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
                  f'Best: {best_test_acc:.4f}, early_stopping: {n_non_decreasing_step}')

    test_accs.append(best_test_acc)

print(f'test acc: {np.mean(test_accs):.6f} +/- {np.std(test_accs):.6f}')
print(test_accs)


index: 0 | Epoch: 050, Loss: 0.05403, Train: 1.0000, Val: 0.8875, Best: 0.7647, early_stopping: 7
index: 0 | Epoch: 100, Loss: 0.04807, Train: 1.0000, Val: 0.8625, Best: 0.7843, early_stopping: 40
index: 0 | Epoch: 150, Loss: 0.04378, Train: 1.0000, Val: 0.8250, Best: 0.7843, early_stopping: 90
index: 0 | Epoch: 200, Loss: 0.03632, Train: 1.0000, Val: 0.8750, Best: 0.7843, early_stopping: 140
index: 0 | Epoch: 250, Loss: 0.05496, Train: 1.0000, Val: 0.8500, Best: 0.7843, early_stopping: 190
index: 1 | Epoch: 050, Loss: 0.07993, Train: 0.9917, Val: 0.7625, Best: 0.9216, early_stopping: 19
index: 1 | Epoch: 100, Loss: 0.06236, Train: 0.9917, Val: 0.8000, Best: 0.902, early_stopping: 18
index: 1 | Epoch: 150, Loss: 0.08245, Train: 1.0000, Val: 0.7625, Best: 0.902, early_stopping: 68
index: 1 | Epoch: 200, Loss: 0.04684, Train: 1.0000, Val: 0.8125, Best: 0.902, early_stopping: 118
index: 1 | Epoch: 250, Loss: 0.06326, Train: 1.0000, Val: 0.8000, Best: 0.902, early_stopping: 168
index: 2 | 