In [None]:
import os

# Switch to the parent directory, which is expected to be the project root
# containing `src/` (the `from src...` imports below rely on it). Skipped when
# `src/` is already in the cwd so re-running this cell does not climb again.
if not os.path.isdir("src"):
    os.chdir("../")

print(os.getcwd())

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from logging import getLogger

from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdDepictor
from chainer_chemistry.saliency.visualizer.base_visualizer import BaseVisualizer
from chainer_chemistry.saliency.visualizer.visualizer_utils import red_blue_cmap, abs_max_scaler

from src.GraphWisconsin.generate_graph_dataset import graph_dataset
from src.GraphWisconsin.model_GNN import GCNReg

In [None]:


class MolVisualizer(BaseVisualizer):
    """Saliency visualizer for RDKit ``Mol`` objects.

    Atoms are coloured via ``color_fn`` applied to their (scaled) saliency;
    bonds are optionally coloured by the mean saliency of their end atoms.
    """

    def __init__(self, logger=None):
        # Fall back to a module-level logger so save problems are reported.
        self.logger = logger or getLogger(__name__)

    @staticmethod
    def _is_visible(begin, end, bond_color):
        """Return the value fed to ``color_fn`` for a bond.

        When ``bond_color`` is truthy and both end-atom saliencies have the same
        non-zero sign, returns their mean; otherwise returns 0.
        """
        if bond_color:
            return (begin + end) * 0.5 if begin * end > 0 else 0
        return 0

    def visualize(self, saliency, mol, save_filepath=None,
                  visualize_ratio=1.0, color_fn=red_blue_cmap,
                  scaler=abs_max_scaler, legend='',
                  bond_color=True, global_scaler_coeff=0,
                  width=500, height=375):
        """Draw ``mol`` with atoms/bonds coloured by ``saliency``.

        Args:
            saliency: per-atom saliency array; only the first
                ``mol.GetNumAtoms()`` entries are used.
            mol: RDKit molecule. A copy is drawn, so the caller's object is not
                modified by coordinate generation / sanitization / kekulization.
            save_filepath: optional output path. ``.svg`` is written directly,
                ``.png`` needs ``cairosvg``; matching is case-insensitive and any
                other extension is logged and skipped.
            visualize_ratio: fraction of atoms, ranked by |saliency|, to highlight.
            color_fn: maps a scalar saliency value to a colour for RDKit.
            scaler: callable applied to the saliency array; falsy disables scaling.
            legend: text drawn with the molecule.
            bond_color: if truthy, colour bonds from their end-atom saliencies.
            global_scaler_coeff: if > 0, saliency is divided by this value
                instead of applying ``scaler`` (consistent scale across molecules).
            width: SVG canvas width in pixels.
            height: SVG canvas height in pixels.

        Returns:
            tuple: ``(svg_text, scaler_coeff)``. With the default scaler path,
            ``scaler_coeff`` is max(|saliency|) before scaling; it is
            ``global_scaler_coeff`` when that is used, and 1 when no scaler is.
        """
        # Compute2DCoords / SanitizeMol / Kekulize all mutate in place: work on a copy.
        mol = Chem.Mol(mol)
        rdDepictor.Compute2DCoords(mol)
        Chem.SanitizeMol(mol)
        Chem.Kekulize(mol)
        num_atoms = mol.GetNumAtoms()

        # Ignore any saliency entries beyond this molecule's atom count.
        saliency = saliency[:num_atoms]

        if global_scaler_coeff > 0:
            saliency = saliency / global_scaler_coeff
            scaler_coeff = global_scaler_coeff
        elif scaler:
            # Coefficient is computed before scaling; it assumes an abs-max style scaler.
            scaler_coeff = np.max(np.abs(saliency))
            saliency = scaler(saliency)
        else:
            scaler_coeff = 1

        abs_saliency = np.abs(saliency)
        # Keep the top `visualize_ratio` fraction of atoms by |saliency|.
        # np.percentile raises on an empty array, so guard atom-less molecules.
        if abs_saliency.size:
            threshold = np.percentile(abs_saliency, (1 - visualize_ratio) * 100)
        else:
            threshold = 0.0
        highlight_atoms = [idx for idx, value in enumerate(abs_saliency) if value >= threshold]
        atom_colors = {idx: color_fn(value) for idx, value in enumerate(saliency)}

        bond_colors = {
            bond.GetIdx(): color_fn(self._is_visible(
                saliency[bond.GetBeginAtomIdx()],
                saliency[bond.GetEndAtomIdx()],
                bond_color))
            for bond in mol.GetBonds()
        }

        drawer = rdMolDraw2D.MolDraw2DSVG(width, height)
        drawer.DrawMolecule(
            mol,
            highlightAtoms=highlight_atoms,
            highlightAtomColors=atom_colors,
            highlightBonds=list(bond_colors.keys()),
            highlightBondColors=bond_colors,
            legend=legend
        )
        drawer.FinishDrawing()
        svg = drawer.GetDrawingText()

        if save_filepath:
            self._save(svg, save_filepath)
        return svg, scaler_coeff

    def _save(self, svg, save_filepath):
        """Write ``svg`` to ``save_filepath`` as SVG or PNG, chosen by extension."""
        ext = os.path.splitext(save_filepath)[1].lower()
        if ext == '.svg':
            with open(save_filepath, 'w', encoding='utf-8') as f:
                f.write(svg)
        elif ext == '.png':
            # Only the import is guarded, so real conversion errors still surface.
            try:
                import cairosvg
            except ImportError:
                self.logger.error('Please install cairosvg to save as png.')
                return
            cairosvg.svg2png(bytestring=svg, write_to=save_filepath)
        else:
            self.logger.warning(
                'Unsupported extension %r for %s; expected .svg or .png. Nothing saved.',
                ext, save_filepath)


class SmilesVisualizer(MolVisualizer):
    """Saliency visualizer that takes a SMILES string instead of an RDKit Mol."""

    def visualize(self, saliency, smiles, save_filepath=None,
                  visualize_ratio=1.0, color_fn=red_blue_cmap,
                  scaler=abs_max_scaler, legend='', add_Hs=False,
                  use_canonical_smiles=True, bond_color=True, global_scaler_coeff=0):
        """Parse ``smiles`` and delegate drawing to ``MolVisualizer.visualize``.

        Args:
            saliency: per-atom saliency array, in the atom order of the molecule
                that is finally drawn.
            smiles: SMILES string to draw.
            add_Hs: add explicit hydrogens before drawing.
            use_canonical_smiles: re-parse from RDKit's canonical SMILES first.
            Remaining arguments are forwarded unchanged to the parent method.

        Returns:
            tuple: ``(svg_text, scaler_coeff)`` from the parent method.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # RDKit returns None instead of raising; fail early with a clear message.
            raise ValueError(f"Could not parse SMILES: {smiles!r}")
        if use_canonical_smiles:
            # NOTE(review): re-parsing canonical SMILES may renumber atoms; the
            # saliency array must follow the same ordering — confirm against the
            # featurizer used to build the graphs.
            mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, canonical=True))
        if add_Hs:
            mol = Chem.AddHs(mol)
        return super().visualize(
            saliency, mol, save_filepath=save_filepath,
            visualize_ratio=visualize_ratio, color_fn=color_fn, scaler=scaler,
            legend=legend, bond_color=bond_color, global_scaler_coeff=global_scaler_coeff)


In [None]:
class SaliencyVisualizer:
    """Run a trained GCNReg model over a SMILES CSV and render per-atom saliency maps."""

    # Columns of the node features whose grad*input is summed into the per-atom
    # saliency. NOTE(review): presumably the atom-type one-hot block of the
    # 74-dim atom featurizer — confirm against generate_graph_dataset.
    ATOM_FEATURE_SLICE = slice(0, 43)

    def __init__(self, model_path, model_name, data_path):
        """Load the model from ``model_path/model_name``; ``data_path`` is a header-less CSV of SMILES."""
        self.model = self._load_model(model_path, model_name)
        self.data_path = data_path

    @staticmethod
    def _load_model(model_path, model_name):
        """Build a GCNReg, load its ``model_state_dict`` checkpoint, and set eval mode."""
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # the weights are copied into the CPU-constructed model below anyway.
        checkpoint = torch.load(f"{model_path}/{model_name}", map_location="cpu")
        # Positional constructor args as used at training time; semantics defined by GCNReg.
        model = GCNReg(74, 256, 1, True)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Inference only: disable any training-mode behaviour (e.g. dropout).
        model.eval()
        return model

    def visualize_saliency(self, file_path='./reports/saliency_R2'):
        """Render one saliency PNG per SMILES row of ``self.data_path``.

        Files are written as ``{file_path}/actualdata_row_NNN.png`` (1-based row
        number); ``file_path`` is created if missing.

        Returns:
            list[str]: SVG text for each row, in CSV order (e.g. for inline display
            with ``IPython.display.SVG``).
        """
        smiles_list = pd.read_csv(self.data_path, header=None)[0].to_list()
        # graph_dataset takes target values; they are not used here, so pass zeros.
        dummy_targets = [0] * len(smiles_list)
        graphs = graph_dataset(smiles_list, dummy_targets)

        os.makedirs(file_path, exist_ok=True)
        visualizer = SmilesVisualizer()
        cols = self.ATOM_FEATURE_SLICE
        svgs = []
        for row, smiles in enumerate(smiles_list):
            graph = graphs[row][0]
            node_feats = graph.ndata['h'].numpy()
            pred, grad = self.model(graph)
            pred = pred.cpu().detach().numpy().flatten()[0]
            node_grads = grad.cpu().detach().numpy()
            # Per-atom saliency = sum over the selected features of gradient * input.
            atom_saliency = np.sum(node_grads[:, cols] * node_feats[:, cols], axis=1)

            svg, _ = visualizer.visualize(
                atom_saliency, smiles,
                save_filepath=f"{file_path}/actualdata_row_{row + 1:03d}.png",
                visualize_ratio=1, bond_color=False, scaler=abs_max_scaler,
                legend=f"{smiles}, pred:{pred:.2f}"
            )
            svgs.append(svg)
        return svgs


if __name__ == "__main__":
    model_path = ".//models"
    model_name = "GCN_early_stop//ep1000bs5lr0.005kf11hu256cvid5es.pth.tar"
    # The path previously carried a long run of trailing spaces inside the
    # literal, so read_csv looked for a file that does not exist.
    data_path = ".//data//test_more_dataNoinionc.csv"

    visualizer = SaliencyVisualizer(model_path, model_name, data_path)
    visualizer.visualize_saliency(file_path='.')


In [None]:
import matplotlib.pyplot as plt

def plot_colorbar(cmap_name='bwr', vmin=-1.0, vmax=1.0, label='Saliency Value'):
    """Draw a standalone horizontal colorbar to use as a legend for the saliency maps.

    Args:
        cmap_name: matplotlib colormap name.
        vmin: value at the left end of the bar.
        vmax: value at the right end of the bar. The [-1, 1] defaults assume
            saliency was scaled into that range — adjust if a different scaler is used.
        label: text shown under the bar.

    Returns:
        matplotlib.figure.Figure: the figure, e.g. for ``fig.savefig(...)``.
    """
    fig, ax = plt.subplots(figsize=(6, 1))
    # Leave room under the thin bar for tick labels and the axis label.
    fig.subplots_adjust(bottom=0.5)

    cmap = plt.get_cmap(cmap_name)
    norm = plt.Normalize(vmin, vmax)
    # A bare ScalarMappable gives a colorbar without any plotted data.
    cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
                        cax=ax, orientation='horizontal')
    cbar.set_label(label)
    plt.show()
    return fig

# Trailing `;` suppresses re-displaying the returned figure in the notebook.
plot_colorbar('bwr');


In [None]:
# Report the working directory the notebook ended up in.
current_dir = os.getcwd()
print(current_dir)