In [11]:
import numpy as np
import logging
import os 
from tqdm import tqdm
from torch import optim
import torch.nn as nn
import pandas as pd
import torch_geometric.transforms as T
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GNNExplainer
from torch_geometric.loader import DataLoader
from utils.model_explainable import explainer,Config
from utils.graph_dataset import SMILESDataset, smiles_string, smiles_to_one_hot
from sklearn.metrics import matthews_corrcoef, f1_score, cohen_kappa_score, accuracy_score, auc, roc_auc_score, average_precision_score, precision_score
from utils.edgeshaper import edgeshaper
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx
from rdkit_heatmaps import mapvalues2mol
from rdkit_heatmaps.utils import transform2png
from rdkit import Chem
from rdkit.Chem import Draw

In [2]:
# Hyper-parameters handed to the project's Config object. The grouping comments
# below follow the key prefixes only; what each field controls is defined in
# utils.model_explainable (not visible here).
config = Config.from_dict(
    dict(
        # transformer-related keys
        transformer_layer=2,
        num_attention_heads=8,
        hidden_size=256,
        # cnn_* / embedding keys
        cnn_input_dim=256,
        embedding_dim=65,
        cnn_dropout=0.4,
        cnn_output_dim=128,
        # gnn_* / pooling keys
        gnn_input_dim=33,
        gnn_head=8,
        gnn_hidden_dim=64,
        gnn_output_dim=128,
        pool_input_dim=128,
        gnn_dropout=0.4,
        # output size
        num_class=1,
    )
)

In [3]:
# Build the project's `explainer` network (imported from utils.model_explainable)
# from the hyper-parameters in `config`.
# NOTE: a later cell rebinds the name `explainer` to a GNNExplainer instance, so
# re-running this cell after that one requires re-running the import cell first.
model = explainer(config)

In [4]:
# Load the trained checkpoint. Despite its name, `model2` is used by the
# following cells as a mapping from parameter names to tensors (`.keys()` and
# `model2[key]`), i.e. as a state dict rather than a module.
# map_location="cpu" lets a checkpoint that was saved from a GPU run be opened on
# a machine without CUDA; the tensors are only used as the source of in-place
# copies afterwards, which work across devices.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model2 = torch.load(
    '/home/zengxin/fpk/pycharm_project/GNN-DDAS/save_model/DL/one_label_gat/best_one_label_gat_merge_lr0.001_wdeacy0.01.pt',
    map_location="cpu",
)

In [5]:
# Grab the freshly built model's state dict. The next cell writes checkpoint
# values into these tensors in place with `copy_`; state_dict() tensors share
# storage with the module's parameters, so those copies update `model` itself.
new_model_state_dict = model.state_dict()

In [None]:
# Transfer the checkpoint weights (`model2`) into the new model's state dict.
#
# The `fc_layer` entries sit one index later in the checkpoint than in the new
# model, so they are copied through an explicit remap table. Every other entry
# is copied by identical name.
FC_LAYER_REMAP = {0: 1, 1: 2, 3: 4}  # new fc_layer index -> checkpoint fc_layer index
remapped_keys = {
    f"fc_layer.{new_idx}.{suffix}": f"fc_layer.{old_idx}.{suffix}"
    for new_idx, old_idx in FC_LAYER_REMAP.items()
    for suffix in ("weight", "bias")
}

# Pass 1: same-name copies. Entries whose shape differs are skipped and reported
# instead of being hidden by a bare `except: pass`; the explicit shape check also
# prevents `copy_` from silently broadcasting a tensor of a different shape.
skipped_keys = []
for key, target in new_model_state_dict.items():
    if key in remapped_keys or key not in model2:
        continue
    source = model2[key]
    if not torch.is_tensor(source) or source.shape != target.shape:
        skipped_keys.append(key)
        continue
    target.copy_(source)

if skipped_keys:
    logging.warning(
        "Skipped %d checkpoint entries that do not match the new model: %s",
        len(skipped_keys),
        skipped_keys,
    )

# Pass 2: remapped fc_layer copies. These are required, so a missing key or a
# shape mismatch is allowed to raise.
for new_key, old_key in remapped_keys.items():
    new_model_state_dict[new_key].copy_(model2[old_key])

In [7]:
# Seed the torch (CPU and CUDA) and NumPy random number generators with one
# shared value so repeated runs draw the same random numbers.
SEED = 42
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    seed_rng(SEED)


In [8]:
# Data file path: root directory of the test split, plus the raw and processed
# file names that are passed on to SMILESDataset.
test_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/test_data'
test_set = SMILESDataset(
    root=test_root,
    raw_dataset='test_data.csv',
    processed_data='test.pt',
)

In [12]:
def visualize_explanations(test_cpd, phi_edges, SAVE_PATH=None, data=None):
    """Draw a molecule with per-edge importance scores mapped onto its bonds.

    Each directed edge of ``test_cpd.edge_index`` is matched to the RDKit bond
    joining the same two atom indices, and its score is added to that bond.
    Both directions of an edge therefore accumulate on the same bond; edges
    with no matching bond are ignored. The resulting list is ordered by RDKit
    bond index and passed to ``mapvalues2mol`` as the bond weights.

    Args:
        test_cpd: Graph object exposing ``smiles`` (a SMILES string) and
            ``edge_index`` (indexed as ``edge_index[0][i]`` / ``edge_index[1][i]``
            for the two end points of edge ``i``).
            NOTE(review): assumes node indices in ``edge_index`` equal the RDKit
            atom indices of the parsed SMILES - confirm against the dataset code.
        phi_edges: Sequence of per-edge scores, indexed like the columns of
            ``edge_index``.
        SAVE_PATH: Directory to write the PNG into; nothing is saved when None.
        data: Prefix for the output file name (the notebook passes the sample
            index as a string). None is treated as an empty prefix.

    Returns:
        The image object produced by ``transform2png``.

    Raises:
        ValueError: If ``test_cpd.smiles`` cannot be parsed by RDKit.
    """
    import os  # local import keeps the function self-contained within its cell

    test_mol = Chem.MolFromSmiles(test_cpd.smiles)
    if test_mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail early
        # with a readable message instead of an obscure error further down.
        raise ValueError(f"Could not parse SMILES: {test_cpd.smiles!r}")
    test_mol = Draw.PrepareMolForDrawing(test_mol)

    # Map both orientations of every bond to its RDKit bond index so a directed
    # edge can be resolved with a single lookup.
    num_bonds = len(test_mol.GetBonds())
    bond_lookup = {}
    for bond_idx in range(num_bonds):
        bond = test_mol.GetBondWithIdx(bond_idx)
        begin_atom = bond.GetBeginAtomIdx()
        end_atom = bond.GetEndAtomIdx()
        bond_lookup[(begin_atom, end_atom)] = bond_idx
        bond_lookup[(end_atom, begin_atom)] = bond_idx

    # Convert the edge list to plain Python ints once rather than calling
    # .item() twice per edge.
    edge_index = test_cpd.edge_index.to("cpu")
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()

    rdkit_bonds_phi = [0.0] * num_bonds
    for i in range(len(phi_edges)):
        bond_idx = bond_lookup.get((sources[i], targets[i]))
        if bond_idx is not None:
            # float() accepts NumPy scalars and 0-d tensors alike.
            rdkit_bonds_phi[bond_idx] += float(phi_edges[i])

    plt.clf()
    # NOTE(review): the weights are ordered by RDKit bond index; confirm that
    # this is the order mapvalues2mol expects.
    canvas = mapvalues2mol(test_mol, None, rdkit_bonds_phi, atom_width=0.2, bond_length=0.5, bond_width=0.5)
    img = transform2png(canvas.GetDrawingText())

    if SAVE_PATH is not None:
        prefix = "" if data is None else str(data)
        img.save(os.path.join(SAVE_PATH, prefix + "_gnnexplanier_edgeshaper.png"), dpi=(1200, 1200))

    return img

In [27]:
# Wrap the model in PyG's GNNExplainer (400 optimisation epochs).
# NOTE: this rebinds `explainer`, which the import cell bound to the model class
# from utils.model_explainable; after running this cell, `explainer(config)` no
# longer builds a model until the imports are re-run.
# NOTE(review): config sets num_class=1 - confirm that the model output really is
# a log-probability, as return_type='log_prob' implies.
explainer = GNNExplainer(model, epochs=400,return_type='log_prob')

In [None]:
# Explain every test graph labelled 1 and save two pictures per graph: the
# GNNExplainer subgraph plot and the molecule heat-map of the edge mask.
EXPLANATION_DIR = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/explanier'
os.makedirs(EXPLANATION_DIR, exist_ok=True)  # both save calls need the directory to exist

for idx, data in enumerate(test_set):
    if data.y == 1:
        node_feat_mask, edge_mask = explainer.explain_graph(data.x, data.edge_index)
        # Move the mask to the CPU once. The original converted the un-moved
        # tensor with np.array(), which fails when the mask lives on a GPU.
        edge_mask = edge_mask.detach().cpu()
        ax, G = explainer.visualize_subgraph(-1, data.edge_index.to('cpu'), edge_mask, data.y.to('cpu'), node_size=100)
        plt.savefig(f'{EXPLANATION_DIR}/smiles_{idx}.png', dpi=1200, bbox_inches='tight')
        edge_mask = edge_mask.numpy()
        # visualize_explanations calls plt.clf(), which clears the current
        # figure before the next iteration draws on it.
        visualize_explanations(data, phi_edges=edge_mask, SAVE_PATH=EXPLANATION_DIR, data=f'{idx}')