In [1]:
import torch
from torch import nn
from torch_geometric.nn import GCNConv
from torch.nn import Linear
import torch.nn.functional as F
from GNNNestedCVEvaluation import GNNNestedCVEvaluation
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import add_self_loops
from hyperopt import hp

  _torch_pytree._register_pytree_node(


In [2]:
class GNN(nn.Module):
    """Two-layer GCN that maps node features to per-node output scores.

    Forward pass: dropout -> GCNConv -> ReLU -> dropout -> GCNConv. No final
    activation is applied, so the outputs are raw scores (presumably consumed
    by a logits-based loss such as cross-entropy in the training code —
    confirm in GNNNestedCVEvaluation).

    Args:
        in_dim: Number of input features per node.
        hidden_dim: Width of the hidden GCN layer.
        out_dim: Number of outputs per node (e.g. the number of classes).
        dropout: Dropout probability; the same rate is applied to the input
            features and to the hidden activations.
        normalize: Forwarded to both ``GCNConv`` layers. Note the default here
            is ``False``, whereas ``GCNConv`` itself defaults to ``True``.
        add_self_loops: Forwarded to both ``GCNConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2, normalize=False, add_self_loops=True):
        # Zero-argument super() is the Python 3 idiom and does not repeat the class name.
        super().__init__()

        # Both convolution layers share the same normalization / self-loop settings.
        self.conv1 = GCNConv(in_dim, hidden_dim, normalize=normalize, add_self_loops=add_self_loops)
        self.conv2 = GCNConv(hidden_dim, out_dim, normalize=normalize, add_self_loops=add_self_loops)
        # One Dropout module is reused at both sites; it holds no parameters,
        # so sharing it is equivalent to two separate modules with the same p.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Compute per-node output scores.

        Args:
            x: Node feature matrix, ``(num_nodes, in_dim)`` per PyG convention.
            edge_index: Graph connectivity in COO format, ``(2, num_edges)``.

        Returns:
            Tensor of shape ``(num_nodes, out_dim)``.
        """
        x = self.dropout(x)
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        return self.conv2(x, edge_index)

In [3]:
# Load the Cora citation graph from the Planetoid collection with its "public" split.
# NormalizeFeatures is passed to the constructor rather than assigned to
# `dataset.transform` afterwards, so the dataset is fully configured in one
# idempotent statement; PyG applies the transform whenever a graph is indexed.
dataset_name = 'Cora'
split = "public"
dataset = Planetoid(
    root='data/',
    name=dataset_name,
    split=split,
    transform=T.NormalizeFeatures(),
)

In [4]:
# Cora is a single-graph dataset; take that graph. Indexing (unlike
# `dataset.get(0)`) applies the transform configured on `dataset` (NormalizeFeatures).
data = dataset[0]

In [5]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs end-to-end (Restart & Run All) without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Build the nested-CV evaluator from the target device, the model class (passed
# uninstantiated — presumably built per trial from sampled hyperparameters;
# confirm in GNNNestedCVEvaluation) and the single Cora graph.
gnn_nestedCV_evaluation = GNNNestedCVEvaluation(
    device,
    GNN,
    data,
)

In [7]:
# Hyperparameter search space for the GCN, expressed with hyperopt primitives.
# The categorical options live only inside `gnn_choices` instead of being bound to
# module-level names: a global such as `add_self_loops` would shadow the
# `add_self_loops` function imported from torch_geometric.utils in the import cell.
gnn_choices = {
    'hidden_dim': [2**i for i in range(3, 8)],  # 8, 16, 32, 64, 128
    'out_dim': [dataset.num_classes],           # fixed: one output per class
    'normalize': [True],                        # fixed: always normalize in GCNConv
    'add_self_loops': [True, False],
}

gnn_space = {
    # Each categorical option becomes an hp.choice labelled by its key.
    # NOTE(review): fmin reports hp.choice values as *indices*; if the evaluator
    # reads the best point straight from fmin, map it back with
    # hyperopt.space_eval(gnn_space, best) — confirm in GNNNestedCVEvaluation.
    **{key: hp.choice(key, value) for key, value in gnn_choices.items()},
    # hp.loguniform(label, low, high) samples exp(uniform(low, high)).
    'lr': hp.loguniform('lr', -8, -4),                       # ~3.4e-4 .. 1.8e-2
    'weight_decay': hp.loguniform('weight_decay', -11, -9),  # ~1.7e-5 .. 1.2e-4
    'dropout': hp.uniform('dropout', 0, .6),
}

In [8]:
# Run nested cross-validation, presumably with a hyperopt search inside each outer fold.
# Positional-argument meanings below are inferred from the captured log, not from the
# method signature — TODO confirm against GNNNestedCVEvaluation.nested_cross_validate:
#   2, 2      -> outer / inner fold counts (log: two "START HYPERPARAM SEARCH" runs,
#                each trial printing two scores); which 2 is which is not visible here
#   gnn_space -> the hyperopt search space defined above
#   5         -> trials per hyperparameter search (log: "Total Trials: 5")
gnn_nestedCV_evaluation.nested_cross_validate(2, 2, gnn_space, 5)

0it [00:00, ?it/s]

START HYPERPARAM SEARCH
{'hidden_dim': <hyperopt.pyll.base.Apply object at 0x7f4041b977c0>, 'out_dim': <hyperopt.pyll.base.Apply object at 0x7f4041b946d0>, 'normalize': <hyperopt.pyll.base.Apply object at 0x7f4041b955a0>, 'add_self_loops': <hyperopt.pyll.base.Apply object at 0x7f4041b956c0>, 'lr': <hyperopt.pyll.base.Apply object at 0x7f4041b95840>, 'weight_decay': <hyperopt.pyll.base.Apply object at 0x7f4041b95960>, 'dropout': <hyperopt.pyll.base.Apply object at 0x7f4041b96f80>}


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/07/04 14:47:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/07/04 14:47:49 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
  _torch_pytree._register_pytree_node(
[0.31166911 0.33382571]
  _torch_pytree._register_pytree_node(
[0.3028065  0.31314623]
  _torch_pytree._register_pytree_node(
[0.20384046 0.31610045]
  _torch_pytree._register_pytree_node(
[0.30428359 0.30576071]
  _torch_pytree._register_pytree_node(
[0.86410636 0.3028065 ]
Total Trials: 5: 5 succeeded, 0 failed, 0 cancelled.                            


[[0. 0.]
 [0. 0.]]
START HYPERPARAM SEARCH
{'hidden_dim': <hyperopt.pyll.base.Apply object at 0x7f4041b977c0>, 'out_dim': <hyperopt.pyll.base.Apply object at 0x7f4041b946d0>, 'normalize': <hyperopt.pyll.base.Apply object at 0x7f4041b955a0>, 'add_self_loops': <hyperopt.pyll.base.Apply object at 0x7f4041b956c0>, 'lr': <hyperopt.pyll.base.Apply object at 0x7f4041b95840>, 'weight_decay': <hyperopt.pyll.base.Apply object at 0x7f4041b95960>, 'dropout': <hyperopt.pyll.base.Apply object at 0x7f4041b96f80>}


  _torch_pytree._register_pytree_node(
[0.30428359 0.80059081]
  _torch_pytree._register_pytree_node(
[0.41949779 0.32496306]
  _torch_pytree._register_pytree_node(
[0.31314623 0.31610045]
  _torch_pytree._register_pytree_node(
[0.77548003 0.31166911]
  _torch_pytree._register_pytree_node(
[0.85672081 0.30871493]
Total Trials: 5: 5 succeeded, 0 failed, 0 cancelled.                            


[[0. 0.]
 [0. 0.]]


<NestedCV.NestedTransductiveCV at 0x7f4041b95c30>

In [9]:
# One score per outer fold (presumably test accuracy of the model tuned on the
# inner folds — confirm in NestedCV.NestedTransductiveCV).
# NOTE(review): the saved run shows very different folds (~0.30 vs ~0.88), so with
# only two outer folds report the spread, not just the mean.
gnn_nestedCV_evaluation.nested_transd_cv.outer_scores

array([0.30354506, 0.88183159])