# Setup

In [1]:
import torch
import pandas as pd
import numpy as np
from torch_geometric.explain import Explainer, CaptumExplainer,ModelConfig,GNNExplainer

In [2]:
data_folder = "../../data/processed/graph_data_nohubs/"
models_folder = "../../data/models/"
experiments_folder = "../../data/experiments/design_space_experiment/"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import sys
sys.path.append("../..")
from src.models.base_model import base_model

# Utility

In [166]:
import torch
from torch_geometric.nn import SAGEConv,GATConv, to_hetero

class inner_product_decoder(torch.nn.Module):
    def forward(self,x_source,x_target,edge_index,apply_sigmoid=True):
        nodes_src = x_source[edge_index[0]]
        nodes_trg = x_target[edge_index[1]]
        pred = (nodes_src * nodes_trg).sum(dim=-1)

        if apply_sigmoid:
            pred = torch.sigmoid(pred)

        return pred

class base_message_layer(torch.nn.Module):

    def __init__(self, model_params,hidden_layer=True):
        super().__init__()

        # Currently SageConv or GATConv, might have to modify this to support other Convs
        conv_type = model_params["conv_type"]
        self.conv = layer_dict[conv_type]((-1,-1), model_params["hidden_channels"],aggr=model_params["micro_aggregation"],add_self_loops=False)
        self.normalize = model_params["L2_norm"]

        post_conv_modules = []
        if model_params["batch_norm"]:
            bn = torch.nn.BatchNorm1d(model_params["hidden_channels"])
            post_conv_modules.append(bn)
        
        if model_params["dropout"] > 0:    
            dropout = torch.nn.Dropout(p=model_params["dropout"])
            post_conv_modules.append(dropout)
        
        # No activation on final embedding layer
        if hidden_layer:
            activation = model_params["activation"]()
            post_conv_modules.append(activation)
        
        self.post_conv = torch.nn.Sequential(*post_conv_modules)

    def forward(self, x:dict, edge_index:dict) -> dict:
        x = self.conv(x,edge_index)
        x = self.post_conv(x)
        if self.normalize:
            x = torch.nn.functional.normalize(x,2,-1)
        return x

class multilayer_message_passing(torch.nn.Module):
    #TODO: consider input and output dims with skipcat. Currently the two supported convs auto-detect dimensions. Might have to modify this if i add more convs in the future.
    def __init__(self,num_layers,model_params,metadata):
        super().__init__()

        self.skip = model_params["layer_connectivity"]
        self.num_layers = num_layers

        for i in range(self.num_layers):
            hidden_layer = i != self.num_layers-1
            layer = to_hetero(base_message_layer(model_params,hidden_layer),metadata,model_params["macro_aggregation"])
            self.add_module(f"Layer_{i}",layer)
    
    def hetero_skipsum(self,x: dict, x_i:dict) -> dict:
        x_transformed = {}
        for key,x_val in x.items():
            x_i_val = x_i[key]
            transformed_val = x_val + x_i_val
            x_transformed[key] = transformed_val

        return x_transformed

    def hetero_skipcat(self,x: dict, x_i:dict) -> dict:
        x_transformed = {}
        for key,x_val in x.items():
            x_i_val = x_i[key]
            transformed_val = torch.cat([x_val,x_i_val],dim=-1)
            x_transformed[key] = transformed_val

        return x_transformed
    
    def forward(self, x:dict, edge_index:dict) -> dict:
        for i, layer in enumerate(self.children()):
            x_i = x
            x = layer(x,edge_index)
            if self.skip == "skipsum":
                x = self.hetero_skipsum(x,x_i)
            elif self.skip == "skipcat" and i < self.num_layers -1:
                x = self.hetero_skipcat(x,x_i)
        
        return x 

class MLP(torch.nn.Module):
    def __init__(self,num_layers,in_dim,out_dim,model_params,hidden_dim=None):
        super().__init__()

        hidden_dim = out_dim if hidden_dim is None else hidden_dim
        
        modules = []
        if num_layers == 1:
            modules.append(torch.nn.Linear(in_dim,out_dim))
        else:
            for i in range(num_layers):
                final_layer = i == num_layers-1
                first_layer = i == 0
                if first_layer:
                    modules.append(torch.nn.Linear(in_dim,hidden_dim))
                    modules.append(model_params["activation"]())
                elif final_layer:
                    modules.append(torch.nn.Linear(hidden_dim,out_dim))
                else:
                    modules.append(torch.nn.Linear(hidden_dim,hidden_dim))
                    modules.append(model_params["activation"]())
        
        self.model = torch.nn.Sequential(*modules)
    
    def forward(self,x):
        x = self.model(x)
        return x

class base_encoder(torch.nn.Module):
    def __init__(self,model_params,metadata):
        super().__init__()

        self.has_pre_mlp = model_params["pre_process_layers"] > 0
        self.has_post_mlp = model_params["post_process_layers"] > 0

        if self.has_pre_mlp:
            self.pre_mlp = to_hetero(MLP(model_params["pre_process_layers"],model_params["feature_dim"],model_params["hidden_channels"],model_params),metadata)
        
        self.message_passing = multilayer_message_passing(model_params["msg_passing_layers"],model_params,metadata)

        if self.has_post_mlp:
            self.post_mlp = to_hetero(MLP(model_params["post_process_layers"],model_params["hidden_channels"],model_params["hidden_channels"],model_params),metadata)
    
    def forward(self,x:dict,edge_index:dict) -> dict :
        if self.has_pre_mlp:
            x = self.pre_mlp(x)

        x = self.message_passing(x,edge_index)
        
        if self.has_post_mlp:
            x = self.post_mlp(x)

        return x

class base_explainable_model(torch.nn.Module):
    def __init__(self, model_params,metadata,supervision_types=[('gene_protein', 'gda', 'disease')]):
        super().__init__()

        default_model_params = {
            "hidden_channels":32,
            "conv_type":"SAGEConv",
            "batch_norm": True,
            "dropout":0,
            "activation":torch.nn.LeakyReLU,
            "micro_aggregation":"mean",
            "macro_aggregation":"mean",
            "layer_connectivity":None,
            "L2_norm":False,
            "feature_dim": 10,
            "pre_process_layers":0,
            "msg_passing_layers":2,
            "post_process_layers":0,
        }
        
        for arg in default_model_params:
            if arg not in model_params:
                model_params[arg] = default_model_params[arg]
        
        self.encoder = base_encoder(model_params,metadata)
        self.decoder = inner_product_decoder()
        self.loss_fn = torch.nn.BCELoss()
        self.supervision_types = supervision_types
    
    def decode(self,x:dict,edge_label_index):
        # pred_dict = {}
        # for edge_type in self.supervision_types:
        #     edge_index = edge_label_index[edge_type]

        #     src_type = edge_type[0]
        #     trg_type = edge_type[2]

        #     x_src = x[src_type]
        #     x_trg = x[trg_type]

        #     pred = self.decoder(x_src,x_trg,edge_index)

        #     pred_dict[edge_type] = pred
        
        # final_pred = pred_dict[self.supervision_types[0]]
        # # final_pred = final_pred.reshape([1,final_pred.shape[0]])
        edge_index = edge_label_index
        edge_type = self.supervision_types[0]

        src_type = edge_type[0]
        trg_type = edge_type[2]

        x_src = x[src_type]
        x_trg = x[trg_type]

        pred = self.decoder(x_src,x_trg,edge_index)
        
        final_pred = pred
        # final_pred = final_pred.reshape([1,final_pred.shape[0]])
        return final_pred
    
    def encode(self,data):
        x = data.x_dict
        adj_t = data.adj_t_dict

        encodings = self.encoder(x,adj_t)
        return encodings
    
    def forward(self,x:dict,adj_t:dict,edge_label_index:dict):
        # x = data.x_dict
        # adj_t = data.adj_t_dict
        # edge_label_index = data.edge_label_index_dict

        x = self.encoder(x,adj_t)
        pred = self.decode(x,edge_label_index)
        return pred
    
    def loss(self, prediction_dict, label_dict):
        loss = 0
        num_types = len(prediction_dict.keys())
        for edge_type,pred in prediction_dict.items():
            y = label_dict[edge_type]
            loss += self.loss_fn(pred, y.type(pred.dtype))
        return loss/num_types

layer_dict = {
    "GATConv":GATConv,
    "SAGEConv":SAGEConv
}

In [167]:
import copy

def load_data(folder_path,load_test = False):
    if load_test:
        names = ["train","validation","test"]
    else:
        names = ["train","validation"]
    datasets = []
    for name in names:
        path = folder_path+name+".pt"
        datasets.append(torch.load(path))
    
    return datasets

def initialize_features(data,feature,dim,inplace=False):
    if inplace:
        data_object = data
    else:
        data_object = copy.copy(data)
    for nodetype, store in data_object.node_items():
        if feature == "random":
            data_object[nodetype].x = torch.rand(store["num_nodes"],dim)
        if feature == "ones":
            data_object[nodetype].x = torch.ones(store["num_nodes"],dim)
    return data_object

def load_model(state_dict,params,metadata):
    model = base_model(params,metadata,supervision_types=[('gene_protein', 'gda', 'disease')])
    model.load_state_dict(state_dict)
    return model

def load_experiment(eid:int,date:str,metadata:tuple) -> tuple:
    """Returns tuple (model,params).
    date format: d_m_y"""
    df_path = f"{experiments_folder}experiment_{date}.parquet"
    weights_path = f"{experiments_folder}experiment_{eid}_{date}__.pth"

    df = pd.read_parquet(df_path)
    #TODO: this is only temporal, remove after fix
    df["conv_type"] = df.conv_type.apply(lambda x: x.split(".")[-1].rstrip("\'>"))
    df["activation"] = torch.nn.LeakyReLU
    params = df.loc[eid].to_dict()
    weights = torch.load(weights_path,map_location=torch.device(device))

    model = base_explainable_model(params,metadata)
    model.load_state_dict(weights)

    return model,params

def load_node_csv(path, index_col,type_col, **kwargs):
    """Returns node dataframe and a dict of mappings for each node type. 
    Each mapping maps from original df index to "heterodata index" { node_type : { dataframe_index : heterodata_index}}"""
    df = pd.read_csv(path, **kwargs,index_col=index_col)
    node_types = df[type_col].unique()
    mappings_dict = dict()
    for node_type in node_types:
        mapping = {index: i for i, index in enumerate(df[df[type_col] == node_type].index.unique())}
        mappings_dict[node_type] = mapping

    return df,mappings_dict

In [175]:
train_data,val_data = load_data(data_folder+"split_dataset/")

eid = 34
date = "18_04_23"
model,params = load_experiment(eid,date,train_data.metadata())

train_data = initialize_features(train_data,params["feature_type"],params["feature_dim"])
val_data = initialize_features(train_data,params["feature_type"],params["feature_dim"])

In [176]:
with torch.no_grad():
    model.eval()
    aver = model(train_data.x_dict,train_data.edge_index_dict,train_data["gene_protein","disease"].edge_label_index)

In [172]:
aver

tensor([0.3982, 0.7912, 0.8521,  ..., 0.4634, 0.0412, 0.0242])

In [178]:
x = train_data.x_dict
adj_t = train_data.adj_t_dict
edge_index = train_data.edge_index_dict
edge_label_index = train_data["gene_protein","disease"].edge_label_index

In [219]:
allow_unused = True
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='binary_classification',
        task_level='edge',
        return_type='probs',
    ),
    node_mask_type='attributes',
    edge_mask_type='object'
)

In [162]:
train_data["gene_protein","disease"].edge_label_index

tensor([[ 2929,  6037,  4913,  ...,  3391, 15286,  7687],
        [  870,  4171,  4728,  ...,  7081,  7141, 11498]])

In [222]:
prediction = explainer.get_prediction(x,adj_t,edge_label_index)
target = explainer.get_target(prediction)

In [226]:
edge_label_index.shape

torch.Size([2, 26884])

In [230]:
explainer(x,edge_index,edge_label_index=edge_label_index)

TypeError: 'IntegratedGradients' object is not callable

In [160]:
explainer.algorithm.attribution_method

TypeError: 'IntegratedGradients' object is not callable

In [156]:
explainer.algorithm(model,x,edge_index,edge_label_index=edge_label_index,target=target)

TypeError: 'IntegratedGradients' object is not callable

In [138]:
from torch_geometric.explain.algorithm.captum import CaptumHeteroModel
from torch_geometric.nn import to_captum_model

In [139]:
captum_model = to_captum_model(model)

In [142]:
explainer.algorithm.attribution_method

<captum.attr._core.integrated_gradients.IntegratedGradients at 0x7f00e998d5a0>

In [128]:
explainer.algorithm.attribution_method()

TypeError: 'IntegratedGradients' object is not callable

In [113]:
explainer.algorithm.supports()

True

In [144]:
explainer(x,adj_t,edge_index)

TypeError: Explainer.__call__() takes 3 positional arguments but 4 were given

In [94]:
explainer(x,edge_index,edge_label_index=edge_label_index)

AssertionError: Tensor target dimension torch.Size([1, 26884]) is not valid. torch.Size([1, 26884])

In [12]:
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

explainer = Explainer(
    model,  # It is assumed that model outputs a single tensor.
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config = model_config
)

hetero_explanation = explainer(
    train_data.x_dict,
    train_data.edge_index_dict,
    index=torch.tensor([1, 3]),
)
print(hetero_explanation.edge_mask_dict)
print(hetero_explanation.node_mask_dict)


AttributeError: 'dict' object has no attribute 'x_dict'