In [1]:
from torch_geometric.datasets import IMDB, DBLP
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
import torch
import numpy as np
from sklearn.metrics import normalized_mutual_info_score, f1_score, accuracy_score
from HeteroNestedCV import NestedTransductiveCV
from tqdm.notebook import tqdm
import time
import copy
from hyperopt import hp

  _torch_pytree._register_pytree_node(


In [2]:
# Load (downloading on first run) both heterogeneous benchmarks; the rest of the
# notebook works on the single DBLP graph.
imdb, dblp = IMDB(root="./data/IMDB"), DBLP(root="./data/DBLP")
data = dblp[0]

In [3]:
# Give every "conference" node a single scalar feature equal to its node index
# (presumably because the dataset provides no features for this node type).
# NOTE(review): an integer index used as a float feature imposes an arbitrary
# ordering on conferences; a one-hot or learnable embedding may be preferable.
data["conference"].x = torch.arange(data["conference"].num_nodes).unsqueeze(-1)
data = data.cpu()

In [4]:
import torch
from torch_geometric.nn import HGTConv, HANConv
from typing import Callable

# Activation names accepted by the models below, mapped to their torch.nn classes
# (classes, not instances: each model instantiates its own).
ACTIVATION_FUNS = dict(
    RELU=torch.nn.ReLU,
    LEAKY_RELU=torch.nn.LeakyReLU,
    ELU=torch.nn.ELU,
)


class TransductiveHeteroGNN(torch.nn.Module):
    """Base class for transductive node classification on a heterogeneous graph.

    Subclasses build the message-passing ``layers`` (each called as
    ``layer(x_dict, edge_index_dict) -> x_dict``) and pass them in. This class
    provides the forward pass, Adam training with early stopping, and
    prediction/evaluation on the ``target_node_type`` nodes selected by a mask.

    Args:
        layers: sequence of heterogeneous layers; the last one yields the logits.
        target_node_type: node type whose ``y`` is classified.
        dropout: dropout probability applied after each hidden layer's activation.
        weight_decay: Adam weight decay.
        patience: early-stopping window length, in epochs.
        lr: Adam learning rate.
        act: key of ``ACTIVATION_FUNS``.
        device: device used for training/inference (defaults to CPU).
    """

    def __init__(self, layers, target_node_type, dropout, weight_decay,patience, lr, act = "RELU", device = None):
        assert act in ACTIVATION_FUNS.keys(), f"You can only pick one of the following activation funcations: {ACTIVATION_FUNS.keys()}"
        super(TransductiveHeteroGNN, self).__init__()
        self.device = device if device is not None else torch.device("cpu")
        self.layers = layers
        self.patience = patience

        self.dropout = torch.nn.Dropout(p = dropout)
        self.activate = ACTIVATION_FUNS[act]()
        self.target_node_type = target_node_type
        self.lr = lr
        self.weight_decay = weight_decay
        # Per-epoch training / validation losses of the most recent fit() call.
        self.loss_values = []
        self.val_loss_values = []

    @staticmethod
    def split_mask(mask, split_ratio):
        """Randomly move a ``split_ratio`` fraction of the True entries of ``mask`` to a validation mask.

        Returns:
            ``(train_mask, val_mask)``: disjoint masks with the shape and dtype of
            ``mask`` whose union equals ``mask``. Uses the global torch RNG.
        """
        true_indices = torch.nonzero(mask, as_tuple=False)
        num_to_switch = int(split_ratio * true_indices.size(0))
        chosen = true_indices[torch.randperm(true_indices.size(0))[:num_to_switch]]
        # One index tensor per mask dimension, so masks of any rank are supported.
        chosen_idx = tuple(chosen.unbind(dim=1))
        train_mask = torch.clone(mask)
        val_mask = torch.zeros_like(mask)
        train_mask[chosen_idx] = False
        val_mask[chosen_idx] = True
        return train_mask, val_mask

    def forward(self, x_dict, edge_index_dict):
        """Run all layers and return the last layer's per-node-type output dict.

        The caller's ``x_dict`` is never modified.
        """
        # NOTE(review): the activation is applied to the raw input features before
        # the first layer; the original TODO asked whether input dropout was meant.
        # A new dict is built instead of writing back into x_dict: data.x_dict is
        # passed here directly, so in-place writes compounded LeakyReLU/ELU on the
        # stored features every forward pass.
        x_dict = {key: self.activate(x) for key, x in x_dict.items()}
        for layer in self.layers[:-1]:
            x_dict = layer(x_dict, edge_index_dict)
            x_dict = {key: self.dropout(self.activate(x)) for key, x in x_dict.items()}
        return self.layers[-1](x_dict, edge_index_dict)

    def _move_graph(self, data):
        """Move node features (cast to float) and edge indices of ``data`` to ``self.device``.

        Reassigns ``data.x_dict`` / ``data.edge_index_dict`` on the given object and returns it.
        """
        data.x_dict = {key: data.x_dict[key].to(self.device).type(torch.float) for key in data.x_dict}
        data.edge_index_dict = {key: data.edge_index_dict[key].to(self.device) for key in data.edge_index_dict}
        return data

    def shift_data_to_device(self, data, mask):
        """Move features, edges, target labels and ``mask`` to ``self.device``; returns ``(data, mask)``."""
        data = self._move_graph(data)
        data[self.target_node_type].y = data[self.target_node_type].y.to(self.device)
        return data, mask.to(self.device)

    def fit(self, data, mask, max_epochs=10_000, val_ratio=0.2):
        """Train with Adam on the ``mask`` nodes, early-stopping on a held-out part of them.

        A random ``val_ratio`` fraction of ``mask`` is held out once, before
        training, and used only for the validation loss. Training stops when the
        current validation loss is >= every one of the previous ``self.patience``
        validation losses, or after ``max_epochs`` epochs.

        Args:
            data: heterogeneous graph; its tensors are moved to ``self.device``.
            mask: boolean mask over ``target_node_type`` nodes available for training.
            max_epochs: upper bound on the number of epochs.
            val_ratio: fraction of ``mask`` held out for early stopping.
        """
        optim = torch.optim.Adam(params=self.parameters(), lr = self.lr, weight_decay= self.weight_decay)
        loss_fun = torch.nn.CrossEntropyLoss()
        # The device move and the train/validation split are loop-invariant. The
        # split used to be redrawn every epoch, which rotated validation nodes into
        # the training set and made the early-stopping signal meaningless.
        data, mask = self.shift_data_to_device(data, mask)
        train_mask, val_mask = self.split_mask(mask, val_ratio)
        y = data[self.target_node_type].y
        # Fresh history per fit() so early stopping never compares against an earlier run.
        self.loss_values = []
        self.val_loss_values = []
        for _ in range(max_epochs):
            self.train()
            out = self(data.x_dict, data.edge_index_dict)
            loss = loss_fun(out[self.target_node_type][train_mask], y[train_mask])
            val_loss = self.evaluate_loss(data, val_mask, loss_fun).item()

            # Window = the previous `patience` validation losses. The old slice
            # [-(p+1):-1] skipped the most recent one; the explicit start index keeps
            # patience == 0 behaving as before (empty window).
            history = self.val_loss_values
            if len(history) > self.patience and all(
                val_loss >= previous for previous in history[len(history) - self.patience:]
            ):
                break

            self.loss_values.append(loss.item())
            self.val_loss_values.append(val_loss)
            optim.zero_grad()
            loss.backward()
            optim.step()

    def evaluate_loss(self, data, mask, loss_fun):
        """Return ``loss_fun`` over the ``mask``-selected target nodes, without gradients.

        Leaves the module in eval mode.
        """
        with torch.no_grad():
            self.eval()
            data, mask = self.shift_data_to_device(data, mask)
            out = self(data.x_dict, data.edge_index_dict)
            return loss_fun(out[self.target_node_type][mask], data[self.target_node_type].y[mask])

    def predict_proba(self, data, mask):
        """Softmax class probabilities of the ``mask``-selected target nodes, on CPU.

        Leaves the module in eval mode.
        """
        with torch.no_grad():
            self.eval()
            data = self._move_graph(data)
            out = self(data.x_dict, data.edge_index_dict)
            logits = out[self.target_node_type][mask.to(self.device)]
            return torch.nn.functional.softmax(logits, dim = -1).cpu()

    def pred(self, data, mask):
        """Predicted class index (argmax of ``predict_proba``) for each masked target node."""
        return self.predict_proba(data, mask).argmax(-1)

    def evaluate(self, data, mask, metric_fun, metric_kwargs=None):
        """Score this model's predictions with ``metric_fun(y_true, y_pred, **metric_kwargs)``.

        Raises:
            AssertionError: if ``metric_fun`` is not callable.
        """
        assert callable(metric_fun), f"Your evaluation metric (metric_fun) is not callable! {metric_fun}"
        # Was `model.pred(...)`, which silently used a global variable instead of self.
        pred = self.pred(data, mask)
        # Index on CPU so labels and mask may live on different devices.
        y_true = data[self.target_node_type].y.cpu()[mask.cpu()]
        return metric_fun(y_true, pred, **(metric_kwargs or {}))

class HGT(TransductiveHeteroGNN):
    """Heterogeneous Graph Transformer classifier built from ``HGTConv`` layers.

    Builds exactly ``num_layers`` HGTConv layers. Hidden layers output
    ``hidden_dim`` channels with ``heads`` heads; the final layer outputs one
    logit per class with a single head. Integer hyperparameters may arrive as
    floats (e.g. from hyperopt) and are cast with ``int``.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, target_node_type, data, num_layers, hidden_dim, heads, dropout, weight_decay,patience, lr, act = "RELU", device = None):
        heads = int(heads)
        hidden_dim = int(hidden_dim)
        num_layers = int(num_layers)
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Assumes labels are 0..C-1. max+1 stays correct even when some class id is
        # absent from y, where unique().numel() would under-count.
        out_dim = int(data[target_node_type].y.max()) + 1
        metadata = data.metadata()
        input_dims = {key: data.x_dict[key].shape[-1] for key in data.x_dict}

        layers = torch.nn.ModuleList()
        if num_layers == 1:
            # A single layer maps the raw features straight to class logits.
            layers.append(HGTConv(input_dims, out_dim, metadata= metadata))
        else:
            layers.append(HGTConv(input_dims, hidden_dim, heads = heads, metadata= metadata))
            # num_layers - 2 hidden-to-hidden layers (the previous range(1, num_layers - 1)
            # loop matched this for num_layers >= 2 but built 2 layers for num_layers == 1).
            for _ in range(num_layers - 2):
                layers.append(HGTConv({key: hidden_dim for key in data.x_dict}, hidden_dim, heads = heads, metadata= metadata))
            layers.append(HGTConv(hidden_dim, out_dim, metadata= metadata))

        super().__init__(layers=layers, target_node_type=target_node_type, dropout=dropout, weight_decay=weight_decay,
                         patience = patience, lr=lr, act = act, device = device)

class HAN(TransductiveHeteroGNN):
    """Heterogeneous graph attention classifier built from ``HANConv`` layers.

    Builds exactly ``num_layers`` HANConv layers. Hidden layers output
    ``hidden_dim`` channels with ``heads`` heads; the final layer outputs one
    logit per class with a single head. Integer hyperparameters may arrive as
    floats (e.g. from hyperopt) and are cast with ``int``.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, target_node_type, data, num_layers, hidden_dim, heads, dropout, weight_decay,patience, lr, act = "RELU", device = None):
        heads = int(heads)
        hidden_dim = int(hidden_dim)
        num_layers = int(num_layers)
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Assumes labels are 0..C-1. max+1 stays correct even when some class id is
        # absent from y, where unique().numel() would under-count.
        out_dim = int(data[target_node_type].y.max()) + 1
        metadata = data.metadata()
        input_dims = {key: data.x_dict[key].shape[-1] for key in data.x_dict}

        layers = torch.nn.ModuleList()
        if num_layers == 1:
            # A single layer maps the raw features straight to class logits.
            layers.append(HANConv(input_dims, out_dim, metadata= metadata))
        else:
            layers.append(HANConv(input_dims, hidden_dim, heads = heads, metadata= metadata))
            # num_layers - 2 hidden-to-hidden layers (the previous range(1, num_layers - 1)
            # loop matched this for num_layers >= 2 but built 2 layers for num_layers == 1).
            for _ in range(num_layers - 2):
                layers.append(HANConv({key: hidden_dim for key in data.x_dict}, hidden_dim, heads = heads, metadata= metadata))
            layers.append(HANConv(hidden_dim, out_dim, metadata= metadata))

        super().__init__(layers=layers, target_node_type=target_node_type, dropout=dropout, weight_decay=weight_decay,
                         patience = patience, lr=lr, act = act, device = device)

In [5]:
# device = torch.device("cuda:0")
# model = HGT("author", data, num_layers = 3, hidden_dim=64, heads=4, dropout=.1, weight_decay=1e-5,patience =100, lr=5e-4, act = "RELU", device = device)
# model = model.to(device)
# # model.fit(data["author"].train_mask)
# start_train_time = time.time()
# model.fit(data, data["author"].train_mask.to(device))
# end_train_time = time.time() - start_train_time
# model.evaluate(data, data["author"].test_mask, f1_score, {"average":"micro"})

In [6]:
def train_val_masks(train_mask, manual_seed = None, train_size = 0.8):
    """Randomly split the set entries of a 1-D ``train_mask`` into train and validation masks.

    Args:
        train_mask: 1-D mask whose set entries are split.
        manual_seed: if not None (0 included), seeds the global torch RNG first.
        train_size: fraction of the set entries assigned to the train mask (rounded down).

    Returns:
        ``(train_mask, val_mask)``: disjoint masks with the input's shape and dtype
        whose union is the input's set entries.
    """
    # `is not None` so that a seed of 0 (falsy) is honoured too.
    if manual_seed is not None:
        torch.manual_seed(manual_seed)
    # nonzero(as_tuple=True)[0] is always 1-D; the old .squeeze() produced a 0-d
    # tensor when exactly one entry was set, which broke .shape[0].
    candidate_idx = train_mask.nonzero(as_tuple=True)[0]
    n_train = int(train_size * candidate_idx.shape[0])
    perm = torch.randperm(candidate_idx.shape[0])

    new_train_mask = torch.zeros_like(train_mask)
    val_mask = torch.zeros_like(train_mask)
    new_train_mask[candidate_idx[perm[:n_train]]] = 1
    val_mask[candidate_idx[perm[n_train:]]] = 1
    return new_train_mask, val_mask

def space_to_spaces(space, hops):
    """Return one independent deep copy of ``space`` per element of ``hops``.

    Only the number of elements in ``hops`` matters; their values are unused.
    """
    return [copy.deepcopy(space) for _ in hops]

class GNNNestedCVEvaluation:
    """Nested transductive cross-validation with hyperparameter search for one GNN class.

    Wraps ``NestedTransductiveCV`` and supplies two functions. The training
    function builds ``Model`` from a hyperparameter dict and fits it on the
    fold's training mask. The evaluation function computes micro-F1 on a mask.
    """

    def __init__(self,target_node_type, Model, device_id, data, minimize = False, max_evals = 100, parallelism = 1):
        """
        Args:
            target_node_type: node type whose labels are predicted.
            Model: model class instantiated by ``train_fun`` (e.g. ``HGT`` or ``HAN``).
            device_id: CUDA device index, or ``None`` to train on CPU.
            data: graph handed to ``NestedTransductiveCV``.
            minimize: forwarded to ``NestedTransductiveCV`` as ``minimalize``.
            max_evals: forwarded to ``NestedTransductiveCV`` (presumably the search budget).
            parallelism: forwarded to ``NestedTransductiveCV``.
        """
        self.target_node_type = target_node_type
        self.device = torch.device(f"cuda:{device_id}") if device_id is not None else torch.device("cpu")
        # Wall-clock seconds of every fit() run by train_fun in this process.
        # NOTE(review): if NestedTransductiveCV runs trials in worker processes,
        # appends happen there and this list stays empty here - confirm.
        self.training_times = []
        self.Model = Model
        self.minimize = minimize
        self.data = data
        self.nested_transd_cv = None
        self.max_evals = max_evals
        self.parallelism = parallelism

    def nested_cross_validate(self, k_outer, k_inner, space):
        """Run ``k_outer`` x ``k_inner`` nested CV over the hyperparameter ``space``.

        Returns the ``NestedTransductiveCV`` object, which is also stored on
        ``self.nested_transd_cv``.
        """

        def evaluate_fun(fitted_model, data, mask):
            # Micro-F1 of the argmax predictions on the masked target nodes.
            pred_proba = fitted_model.predict_proba(data, mask)
            return f1_score(data[self.target_node_type].y[mask].cpu(), pred_proba.argmax(1).cpu(), average="micro")

        def train_fun(data, inner_train_mask, hyperparameters):
            # Same seed for every trial, so all configurations share the RNG starting state.
            torch.manual_seed(0)
            model = self.Model(self.target_node_type, data, num_layers=hyperparameters["num_layers"], hidden_dim=hyperparameters["hidden_dim"], heads=hyperparameters["heads"], dropout=hyperparameters["dropout"],
                        weight_decay=hyperparameters["weight_decay"],patience = int(hyperparameters["patience"]), lr=hyperparameters["lr"], act = hyperparameters["act"],
                        device = self.device)
            model = model.to(self.device)
            data = data.to(self.device)
            start_train_time = time.time()
            # Train on the fold mask supplied by the CV driver. This previously used the
            # dataset's fixed data[target].train_mask, so the CV folds were ignored.
            model.fit(data, inner_train_mask.to(self.device))
            self.training_times.append(time.time() - start_train_time)
            data = data.cpu()
            return model

        self.nested_transd_cv = NestedTransductiveCV(self.data, self.target_node_type, k_outer, k_inner, train_fun, evaluate_fun,max_evals = self.max_evals, parallelism = self.parallelism, minimalize = self.minimize)
        self.nested_transd_cv.outer_cv(space)
        return self.nested_transd_cv

In [7]:
class GNNModelSpace:
    """Hyperopt search space shared by the GNN models.

    ``self.space`` maps each hyperparameter name to a hyperopt expression.
    Subclasses add model-specific entries in ``get_space``.
    """

    def __init__(self):
        self.space = None
        self.initialize_space()

    def initialize_space(self):
        """(Re)build ``self.space`` with the hyperparameters common to every model."""
        self.space = {}
        self.add_choice("act", ["RELU", "LEAKY_RELU", "ELU"])
        self.add_choice("num_layers", [2, 3, 4])
        # Patience is currently fixed (a qloguniform over [10, 100] was tried before).
        self.add_choice("patience", [50])
        self.add_loguniform("lr", (1e-5, 1e-1))
        self.add_loguniform("weight_decay", (1e-6, 1e-2))
        # hyperopt quantises this to multiples of 16; the models cast it to int.
        self.add_qloguniform("hidden_dim", (16, 256), 16)
        self.add_uniform("dropout", (0.0, 0.8))

    def add_choice(self, key, items):
        """Categorical hyperparameter ``key`` drawn from ``items``."""
        self.space[key] = hp.choice(key, items)

    def add_uniform(self, key, limits: tuple):
        """Hyperparameter ``key`` uniform on ``[limits[0], limits[1]]``."""
        low, high = limits[0], limits[1]
        self.space[key] = hp.uniform(key, low, high)

    def add_loguniform(self, key, limits: tuple):
        """Hyperparameter ``key`` whose logarithm is uniform between log(limits[0]) and log(limits[1])."""
        log_low, log_high = np.log(limits[0]), np.log(limits[1])
        self.space[key] = hp.loguniform(key, log_low, log_high)

    def add_qloguniform(self, key, limits, q):
        """Log-uniform hyperparameter ``key`` over ``limits``, quantised to multiples of ``q``."""
        log_low, log_high = np.log(limits[0]), np.log(limits[1])
        self.space[key] = hp.qloguniform(key, low=log_low, high=log_high, q=q)

class HGTSpace(GNNModelSpace):
    """Search space for ``HGT``: the shared GNN space plus the number of attention heads."""

    def get_space(self):
        """Add the ``heads`` choice to ``self.space`` and return that same dict.

        Calling this repeatedly just re-sets the ``heads`` entry.
        """
        head_options = [1, 2, 4]
        self.add_choice("heads", head_options)
        return self.space

class HANSpace(GNNModelSpace):
    """Search space for ``HAN``: the shared GNN space plus the number of attention heads."""

    def get_space(self):
        """Add the ``heads`` choice to ``self.space`` and return that same dict.

        Calling this repeatedly just re-sets the ``heads`` entry.
        """
        head_options = [1, 2, 4]
        self.add_choice("heads", head_options)
        return self.space

In [8]:
# NOTE(review): the search space comes from HGTSpace while the model is HAN. The
# two spaces are currently identical, but HANSpace would match the model - confirm.
hgt_space = HGTSpace()
# Budget: 40 search evaluations per hyperparameter; device_id 0 selects cuda:0.
gnn_nested_cv_eval = GNNNestedCVEvaluation(
    "author",
    HAN,
    0,
    data,
    max_evals=40 * len(hgt_space.get_space()),
    parallelism=4,
)

In [None]:
# 3 outer folds x 3 inner folds over the search space defined above.
gnn_nested_cv_eval.nested_cross_validate(k_outer=3, k_inner=3, space=hgt_space.get_space())

0it [00:00, ?it/s]

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/08/23 07:42:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/08/23 07:42:04 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _t

In [None]:
# Display the NestedTransductiveCV object stored by nested_cross_validate above.
gnn_nested_cv_eval.nested_transd_cv