In [1]:
import torch
from sklearn.base import BaseEstimator
from typing import TypedDict
import numpy as np
import numpy
from sklearn.base import clone
from sklearn.model_selection import GridSearchCV

# Built-in aggregation rules, keyed by name. Each takes the positional arguments
# (origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors)
# and returns the new node features; none of them uses ``mul_neighbors``.
USER_FUNCTIONS = dict(
    sum=lambda origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors:
        sum_neighbors,
    mean=lambda origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors:
        sum_neighbors / num_neighbors,
    diff_of_origin_mean=lambda origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors:
        origin_features - sum_neighbors / num_neighbors,
    diff_of_updated_mean=lambda origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors:
        updated_features - sum_neighbors / num_neighbors,
    sum_of_origin_mean=lambda origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors:
        origin_features + sum_neighbors / num_neighbors,
    sum_of_updated_mean=lambda origin_features, updated_features, sum_neighbors, mul_neighbors, num_neighbors:
        updated_features + sum_neighbors / num_neighbors,
)
## Assumption: overall prediction performance improves when the performance of the individual predictors improves.
## TODO: more input validation; a grid-search method which accepts the same params.
class Framework:    
    
    def __init__(self, user_functions, 
                 hops_list:list[int],
                 clfs:list,
                 gpu_idx:int|None=None,
                 handle_nan:float|None=None,
                attention_configs:list=[]) -> None:
        self.user_functions = user_functions
        self.hops_list:list[int] = hops_list
        self.clfs:list[int] = clfs
        self.trained_clfs = None
        self.gpu_idx:int|None = gpu_idx
        self.handle_nan:float|int|None = handle_nan
        self.attention_configs = attention_configs
        self.device:torch.DeviceObjType = torch.device(f"cuda:{str(self.gpu_idx)}") if self.gpu_idx is not None and torch.cuda.is_available() else torch.device("cpu")
    
    def update_user_function(self):
        if self.user_function in USER_FUNCTIONS:
            self.user_function = USER_FUNCTIONS[self.user_function]
        else:
            raise Exception(f"Only the following string values are valid inputs for the user function: {[key for key in USER_FUNCTIONS]}. You can also specify your own function for aggregatioon.")
            
    def get_features(self,
                     X:torch.FloatTensor|numpy._typing.NDArray,
                     edge_index:torch.LongTensor|numpy._typing.NDArray,
                     mask:torch.BoolTensor|numpy._typing.NDArray,
                    is_training:bool = False) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        if mask is None:
            mask = torch.ones(X.shape[0]).type(torch.bool)
#         if isinstance(self.user_function, str):
#             self.update_user_function()
        ## To tensor
        X = Framework.get_feature_tensor(X)
        edge_index = Framework.get_edge_index_tensor(edge_index)
        mask = Framework.get_mask_tensor(mask)
        
        ## To device
        X = self.shift_tensor_to_device(X)
        edge_index = self.shift_tensor_to_device(edge_index)
        mask = self.shift_tensor_to_device(mask)
        
        aggregated_train_features_list = []
        ## Aggregate
        for hop_idx in range(len(self.hops_list)):
            neighbor_features = self.aggregate(X, edge_index, hop_idx, is_training)
            aggregated_train_features_list.append(neighbor_features[mask])
        
        return aggregated_train_features_list
    
    def aggregate(self, X:torch.FloatTensor, edge_index:torch.LongTensor,hop_idx, is_training:bool=False) -> torch.FloatTensor: 
        original_features = X
        features_for_aggregation:torch.FloatTensor = torch.clone(X)
        hops_list = self.hops_list[hop_idx]
        for i, hop in enumerate(range(hops_list)):
            if self.attention_configs[hop_idx] and self.attention_configs[hop_idx]["inter_layer_normalize"]:
                features_for_aggregation = torch.nn.functional.normalize(features_for_aggregation, dim = 0)
            source_lift = features_for_aggregation.index_select(0, edge_index[0])
            target = edge_index[1]
            
            if self.attention_configs[hop_idx] and self.attention_configs[hop_idx]["use_pseudo_attention"]:
                source_lift = self.apply_attention_mechanism(source_lift, features_for_aggregation, target,self.attention_configs[hop_idx], is_training)
            
            summed_neighbors = torch.zeros_like(features_for_aggregation, device=self.device).scatter_reduce(0, target.unsqueeze(0).repeat(features_for_aggregation.shape[1], 1).t(), source_lift, reduce="sum", include_self = False)
            summed_neighbors = torch.zeros_like(features_for_aggregation, device=self.device).scatter_(0, target.unsqueeze(0).repeat(features_for_aggregation.shape[1], 1).t(), source_lift, reduce="add")
            multiplied_neighbors = torch.ones_like(features_for_aggregation, device=self.device).scatter_reduce(0, target.unsqueeze(0).repeat(features_for_aggregation.shape[1], 1).t(), source_lift, reduce="prod", include_self = False)
            mean_neighbors = torch.zeros_like(features_for_aggregation, device=self.device).scatter_reduce(0, target.unsqueeze(0).repeat(features_for_aggregation.shape[1], 1).t(), source_lift, reduce="mean", include_self = False)
            max_neighbors = torch.zeros_like(features_for_aggregation, device=self.device).scatter_reduce(0, target.unsqueeze(0).repeat(features_for_aggregation.shape[1], 1).t(), source_lift, reduce="amax", include_self = False)
            min_neighbors = torch.zeros_like(features_for_aggregation, device=self.device).scatter_reduce(0, target.unsqueeze(0).repeat(features_for_aggregation.shape[1], 1).t(), source_lift, reduce="amin", include_self = False)

            num_source_neighbors = torch.zeros(features_for_aggregation.shape[0], dtype=torch.float, device=self.device)
            num_source_neighbors.scatter_reduce(0, target, torch.ones_like(target, dtype=torch.float, device=self.device), reduce="sum", include_self = False)
            num_source_neighbors = num_source_neighbors.unsqueeze(-1)

            user_function = self.user_functions[hop_idx]
            updated_features = features_for_aggregation ## just renaming so that the key in the user function is clear
            user_function_kwargs = {
                                'original_features':original_features,
                                'updated_features':updated_features,
                                'summed_neighbors':summed_neighbors,
                                'multiplied_neighbors':multiplied_neighbors,
                                'mean_neighbors':mean_neighbors,
                                'max_neighbors':max_neighbors,
                                'min_neighbors':min_neighbors,
                                'num_source_neighbors':num_source_neighbors,
                                'hop':hop}
            out = user_function(user_function_kwargs)
            
            if self.handle_nan is not None:
                out = torch.nan_to_num(out, nan=self.handle_nan)
            features_for_aggregation = out
        return features_for_aggregation
    
    def apply_attention_mechanism(self, source_lift:torch.FloatTensor,
                                  features_for_aggregation:torch.FloatTensor,
                                  target:torch.LongTensor,
                                  attention_config,
                                 is_training:bool = False) -> torch.FloatTensor:
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        score = cos(source_lift, features_for_aggregation.index_select(0, target))
        dropout_tens = None
        
        origin_scores = torch.clone(score)
        if attention_config["cosine_eps"]:
            score[score < attention_config["cosine_eps"]] = -torch.inf
        if attention_config["dropout_attn"] is not None and is_training:
            dropout_tens = torch.FloatTensor(score.shape[0]).uniform_(0, 1)
            score[dropout_tens < attention_config["dropout_attn"]] = -torch.inf
        exp_score = torch.exp(score)
        summed_exp_score = torch.zeros_like(exp_score).scatter(0, target,exp_score, reduce="add")
        target_lifted_summed_exp_score = summed_exp_score.index_select(0, target)
        normalized_scores = exp_score / target_lifted_summed_exp_score
        source_lift = normalized_scores.unsqueeze(1) * source_lift
        return source_lift
    
    def fit(self,
            X_train:torch.FloatTensor|numpy._typing.NDArray,
            edge_index:torch.LongTensor|numpy._typing.NDArray,
            y_train:torch.LongTensor|numpy._typing.NDArray,
            train_mask:torch.BoolTensor|numpy._typing.NDArray|None,
            kwargs_list = None
            ) -> BaseEstimator:   
        if train_mask is None:
            train_mask = torch.ones(X_train.shape[0]).type(torch.bool)
            
        y_train = Framework.get_label_tensor(y_train)
        y_train = y_train[train_mask]
        
        self.validate_input()
        
        aggregated_train_features_list = self.get_features(X_train, edge_index, train_mask, True)  
        
        trained_clfs = []
        for i, aggregated_train_features in enumerate(aggregated_train_features_list):
            clf = clone(self.clfs[i])
            kwargs = kwargs_list[i] if kwargs_list and len(kwargs_list)>i is not None else {}
            clf.fit(aggregated_train_features.cpu().numpy(), y_train,**kwargs)
            trained_clfs.append(clf)
        self.trained_clfs = trained_clfs
        return trained_clfs    
    
    def predict_proba(self, X_test:torch.FloatTensor|numpy._typing.NDArray,
                      edge_index:torch.LongTensor|numpy._typing.NDArray,
                      test_mask:torch.BoolTensor|numpy._typing.NDArray|None,
                      weights=None,
                     kwargs_list = None):  
        if test_mask is None:
            test_mask = torch.ones(X_test.shape[0]).type(torch.bool)
        aggregated_test_features_list = self.get_features(X_test, edge_index, test_mask)
        
        pred_probas = []
        for i, clf in enumerate(self.trained_clfs):
            aggregated_test_features = aggregated_test_features_list[i]
            kwargs = kwargs_list[i] if kwargs_list is not None else {}
            pred_proba = clf.predict_proba(aggregated_test_features.cpu().numpy(),**kwargs) if kwargs else clf.predict_proba(aggregated_test_features.cpu().numpy())
            pred_probas.append(pred_proba)
        final_pred_proba = np.average(np.asarray(pred_probas), weights=weights, axis=0)
        return final_pred_proba
        
    
    def predict(self,
                X_test:torch.FloatTensor|numpy._typing.NDArray,
                edge_index:torch.LongTensor|numpy._typing.NDArray,
                test_mask:torch.BoolTensor|numpy._typing.NDArray|None,
                 weights=None,
                     kwargs_list = None):
        return self.predict_proba(X_test, edge_index, test_mask, weights, kwargs_list).argmax(1)
        

    def validate_input(self):
        pass
            
    @staticmethod
    def get_feature_tensor(X:torch.FloatTensor|numpy._typing.NDArray) -> torch.FloatTensor|None:
        if not torch.is_tensor(X):
            try:
                return torch.from_numpy(X).type(torch.float)
            except:
                raise Exception("Features input X must be numpy array or torch tensor!")
                return None 
        return X
    
    @staticmethod
    def get_label_tensor(y:torch.LongTensor|numpy._typing.NDArray) -> torch.LongTensor|None:
        if not torch.is_tensor(y):
            try:
                return torch.from_numpy(y).type(torch.long)
            except:
                raise Exception("Label input y must be numpy array or torch tensor!")
                return None
        return y
    
    @staticmethod
    def get_mask_tensor(mask:torch.BoolTensor|numpy._typing.NDArray) -> torch.BoolTensor|None:
        if not torch.is_tensor(mask):
            try:
                return torch.from_numpy(mask).type(torch.bool)
            except:
                raise Exception("Input mask must be numpy array or torch tensor!")
                return None
        return mask
            
    @staticmethod
    def get_edge_index_tensor(edge_index:torch.LongTensor|numpy._typing.NDArray) -> torch.LongTensor|None:
        if not torch.is_tensor(edge_index):
            try:
                edge_index =  torch.from_numpy(edge_index).type(torch.long)
                Framework.validate_edge_index(edge_index)
                return edge_index
            except:
                raise Exception("Edge index must be numpy array or torch tensor")
                return None
        return edge_index
    
    @staticmethod
    def validate_edge_index(edge_index:torch.LongTensor) -> None:
        if edge_index.shape[0] != 2:
            raise Exception("Edge index must have the shape 2 x NumberOfEdges")
            # TODO: check max edge index and shape of features
    
    def shift_tensor_to_device(self,
                               t:torch.FloatTensor) -> torch.FloatTensor:
        if self.gpu_idx is not None:
            return t.to(self.device) 
        return t
    
    def validate_grid_input(self, grid_params):
        if len(grid_params) != 1 and self.use_feature_based_aggregation:
            raise Exception("You need to provide grid parameter for the classifier!")
        if len(grid_params) != 2 and not self.use_feature_based_aggregation:
            raise Exception("You need to provide two grid parameter, one for each classifier!")
        return
    
    def hyper_param_tuning(spaces, objectives, n_iter, X_train, y_train, X_val, y_val):
        ## bayes optim
        pass



In [57]:
from xgboost import XGBClassifier
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import add_self_loops
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.svm import SVC

dataset = Planetoid(root='/tmp/Cora', name='Cora', split="public")
dataset.transform = T.NormalizeFeatures()

# Index the dataset once. PyG applies ``dataset.transform`` on every access, so reading each
# attribute via ``dataset[0]`` would redo that work every time.
data = dataset[0]

X = data.x
y = data.y

test = data.test_mask
train = data.train_mask
val = data.val_mask

# add_self_loops returns a tuple; only the (self-loop augmented) edge index is needed.
edge_index = add_self_loops(data.edge_index)[0]


# Train XGBoost on CPU when no GPU is visible, mirroring Framework's own device fallback.
XGB_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

clf_1 = XGBClassifier(tree_method='hist',
                      device=XGB_DEVICE,
                      n_estimators=1100,
                      max_depth=2,
                      random_state=42,
                      eta=0.3,
                      reg_lambda=0.001,
                      min_child_weight=1,
                      max_delta_step=3,
                      sampling_method="uniform")

clf_2 = XGBClassifier(tree_method='hist',
                      device=XGB_DEVICE,
                      n_estimators=900,
                      max_depth=2,
                      random_state=42,
                      reg_lambda=0.2953684210526316,
                      eta=0.2733333333333333,
                      min_child_weight=2,
                      max_delta_step=4,
                      sampling_method="uniform",
                      subsample=0.5)

# ``degree`` is only used by the polynomial kernel, so it is dropped for kernel="linear".
# ``random_state`` makes the internal cross-validation behind probability=True reproducible.
clf_3 = SVC(probability=True, C=100, kernel="linear", random_state=42)

def user_function(kwargs):
    """Aggregation rule: a node's current features plus the sum of its neighbors' features."""
    current = kwargs["updated_features"]
    neighbor_sum = kwargs["summed_neighbors"]
    return current + neighbor_sum

hops_list = [0, 3, 8]
clfs = [clf_1, clf_2, clf_3]
# The same aggregation rule is used for every hop setting.
user_functions = [user_function] * len(hops_list)
# The same pseudo-attention settings for every hop setting. A fresh dict per entry means
# tweaking one hop's config later does not silently change the others.
attention_configs = [
    {'inter_layer_normalize': True,
     'use_pseudo_attention': True,
     'cosine_eps': .01,
     'dropout_attn': None}
    for _ in hops_list
]

In [58]:
import time

start = time.perf_counter()
framework = Framework(user_functions,
                      hops_list=hops_list,  ## to obtain best for local neighborhood
                      clfs=clfs,
                      gpu_idx=0,
                      handle_nan=0.0,
                      attention_configs=attention_configs)

# Validation features for every hop setting, used by the XGBoost models for early stopping.
# Built per classifier, so the cell keeps working if hops_list/clfs change length.
val_features = [features.cpu().numpy() for features in framework.get_features(X, edge_index, val)]
y_val_np = y[val].cpu().numpy()
# NOTE(review): recent XGBoost versions expect early_stopping_rounds in the constructor
# rather than in fit() — confirm against the installed version.
kwargs_list = [
    {"eval_set": [(features, y_val_np)], "early_stopping_rounds": 5} if isinstance(clf, XGBClassifier) else {}
    for clf, features in zip(clfs, val_features)
]
framework.fit(X, edge_index, y, train, kwargs_list)
print(time.perf_counter() - start)

[0]	validation_0-mlogloss:1.73400
[1]	validation_0-mlogloss:1.63995
[2]	validation_0-mlogloss:1.54837
[3]	validation_0-mlogloss:1.48037
[4]	validation_0-mlogloss:1.45511
[5]	validation_0-mlogloss:1.42616
[6]	validation_0-mlogloss:1.40100
[7]	validation_0-mlogloss:1.37746
[8]	validation_0-mlogloss:1.37575
[9]	validation_0-mlogloss:1.36987
[10]	validation_0-mlogloss:1.36355
[11]	validation_0-mlogloss:1.35918
[12]	validation_0-mlogloss:1.34453
[13]	validation_0-mlogloss:1.34202
[14]	validation_0-mlogloss:1.33372
[15]	validation_0-mlogloss:1.33217
[16]	validation_0-mlogloss:1.32909
[17]	validation_0-mlogloss:1.32382
[18]	validation_0-mlogloss:1.32752
[19]	validation_0-mlogloss:1.32067
[20]	validation_0-mlogloss:1.31594
[21]	validation_0-mlogloss:1.32297
[22]	validation_0-mlogloss:1.32562
[23]	validation_0-mlogloss:1.32744
[24]	validation_0-mlogloss:1.32578
[25]	validation_0-mlogloss:1.32492




[0]	validation_0-mlogloss:1.66496
[1]	validation_0-mlogloss:1.47006
[2]	validation_0-mlogloss:1.35592
[3]	validation_0-mlogloss:1.25656
[4]	validation_0-mlogloss:1.16575
[5]	validation_0-mlogloss:1.10393
[6]	validation_0-mlogloss:1.04138
[7]	validation_0-mlogloss:0.99219
[8]	validation_0-mlogloss:0.94844
[9]	validation_0-mlogloss:0.89773
[10]	validation_0-mlogloss:0.87204
[11]	validation_0-mlogloss:0.85190
[12]	validation_0-mlogloss:0.82722
[13]	validation_0-mlogloss:0.82055
[14]	validation_0-mlogloss:0.80748
[15]	validation_0-mlogloss:0.79850
[16]	validation_0-mlogloss:0.79455
[17]	validation_0-mlogloss:0.79733
[18]	validation_0-mlogloss:0.78743
[19]	validation_0-mlogloss:0.78669
[20]	validation_0-mlogloss:0.76802
[21]	validation_0-mlogloss:0.76707
[22]	validation_0-mlogloss:0.76692
[23]	validation_0-mlogloss:0.76274
[24]	validation_0-mlogloss:0.76288
[25]	validation_0-mlogloss:0.75665
[26]	validation_0-mlogloss:0.75304
[27]	validation_0-mlogloss:0.74388
[28]	validation_0-mlogloss:0.7

In [59]:
import numpy as np
from sklearn.metrics import accuracy_score

# Hard-label predictions for the held-out splits (test first, then validation).
pred = framework.predict(X, edge_index, test)
pred_val = framework.predict(X, edge_index, val)
y_test, y_val = y[test], y[val]
# Report validation accuracy, then test accuracy.
for labels, predictions in ((y_val, pred_val), (y_test, pred)):
    print(accuracy_score(labels, predictions))

0.806
0.834


In [30]:
from itertools import product

from tqdm.notebook import tqdm

# Compute each classifier's probabilities once. A one-hot ``weights`` vector makes
# predict_proba return a single classifier's probabilities, and every weighting below is the
# same np.average that predict_proba applies internally. This avoids re-running the graph
# aggregation for every grid point.
one_hot_weights = np.eye(len(framework.trained_clfs))
val_probas = np.asarray([framework.predict_proba(X, edge_index, val, weights=w) for w in one_hot_weights])
test_probas = np.asarray([framework.predict_proba(X, edge_index, test, weights=w) for w in one_hot_weights])
y_test = y[test]
y_val = y[val]

weight_grid = np.linspace(0, 1, 5)
max_val = 0
max_test = 0
best_weights = None
for weight_0, weight_1 in tqdm(list(product(weight_grid, weight_grid))):
    weight_2 = 1 - weight_0 - weight_1
    # Only convex combinations: skip grid points that would give the last classifier a negative weight.
    if weight_2 < -1e-9:
        continue
    weights = [float(weight_0), float(weight_1), max(float(weight_2), 0.0)]
    pred = np.average(test_probas, weights=weights, axis=0).argmax(1)
    pred_val = np.average(val_probas, weights=weights, axis=0).argmax(1)
    acc_val = accuracy_score(y_val, pred_val)
    acc_test = accuracy_score(y_test, pred)

    # Select on validation accuracy only; ``>=`` keeps the last of equally good weightings.
    if acc_val >= max_val:
        max_val = acc_val
        max_test = acc_test
        best_weights = weights

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

In [31]:
# Best validation accuracy, the test accuracy at that weighting, and the weights themselves.
print(max_val, max_test, best_weights, sep="\n")

0.804
0.831
[0.25, 0.25, 0.5]
