# Create Guided Backpropagation Saliency Maps.

In [None]:
import os
import sys
import mne
import pywt
import copy
import numpy as np
import pandas as pd
from PIL import Image, ImageFilter
import matplotlib.cm as mpl_color_map
from torch import nn
from scipy import signal
import copy
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from torchvision import models, transforms

In [None]:
# EEG channel labels, one per lead (61 total); index 16 is 'CP5'.
# NOTE(review): presumably this matches the lead order of the .fif data -- confirm.
chan_names = """
    Fp1 Fp2 F7 F3 Fz F4 F8 FC5 FC1 FC2 FC6 T7 C3 Cz C4 T8
    CP5 CP1 CP2 CP6 AFz P7 P3 Pz P4 P8 PO9 O1 Oz O2 PO10
    AF7 AF3 AF4 AF8 F5 F1 F2 F6 FT7 FC3 FC4 FT8 C5 C1 C2 C6
    TP7 CP3 CPz CP4 TP8 P5 P1 P2 P6 PO7 PO3 POz PO4 PO8
""".split()

# 1.) Load Model

In [None]:
class WLCNN(nn.Module):
    """
    2-D CNN over wavelet scalograms in which the 61 EEG leads are the input channels.

    Expects input of shape (batch, 61, 32, 250) -- lead x wavelet scale x time
    sample, as produced by ``preprocess_image``.  Two conv blocks with max
    pooling shrink the 32x250 plane to 1x1, leaving a 244-feature vector per
    example that ``fc1`` maps to 3 class scores (0 = mild, 1 = moderate,
    2 = severe in the example lists).  Each conv is followed by ReLU and then
    BatchNorm.

    Sub-module registration order matters: ``GuidedBackprop`` hooks the first
    registered child (``conv1a``) and the ``nn.ReLU`` modules, and saved
    checkpoints are keyed by these attribute names.
    """
    def __init__(self):
        super(WLCNN,self).__init__()

        self.conv1a = nn.Conv2d(61,122, kernel_size=(4,4), stride=(1,3))
        self.a1 = nn.ReLU()
        self.conv1a_bn = nn.BatchNorm2d(122)

        self.conv1b = nn.Conv2d(122,122, kernel_size=(4,2), stride=(1,3))
        self.a2 = nn.ReLU()
        self.conv1b_bn = nn.BatchNorm2d(122)

        self.maxPool1 = nn.MaxPool2d((2,2))

        self.conv2a = nn.Conv2d(122,244, kernel_size=(3,4), stride=(2,2))
        self.a3 = nn.ReLU()
        self.conv2a_bn = nn.BatchNorm2d(244)

        self.conv2b = nn.Conv2d(244,244, kernel_size=(4,4), stride=(2,2))
        self.a4 = nn.ReLU()
        self.conv2b_bn = nn.BatchNorm2d(244)

        self.maxPool2 = nn.MaxPool2d((2,2))

        self.fc1 = nn.Linear(244,3)

    def forward(self,x):
        """Map (batch, 61, 32, 250) scalograms to (batch, 3) class scores."""
        # (B, 61, 32, 250) -> (B, 122, 29, 83)
        x = self.conv1a(x)
        x = self.a1(x)
        x = self.conv1a_bn(x)

        # -> (B, 122, 26, 28)
        x = self.conv1b(x)
        x = self.a2(x)
        x = self.conv1b_bn(x)

        # -> (B, 122, 13, 14)
        x = self.maxPool1(x)

        # -> (B, 244, 6, 6)
        x = self.conv2a(x)
        x = self.a3(x)
        x = self.conv2a_bn(x)

        # -> (B, 244, 2, 2)
        x = self.conv2b(x)
        x = self.a4(x)
        x = self.conv2b_bn(x)

        # -> (B, 244, 1, 1)
        x = self.maxPool2(x)

        # Flatten per example (keeps the batch axis), so an input of the wrong
        # spatial size fails in fc1 instead of being silently re-batched.
        x = torch.flatten(x, 1)

        x = self.fc1(x)  # (B, 3) class scores
        return x

# Choose the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# Instantiate the 2-D model on that device (the GBP cells below build their
# own model through get_example_params).
model = WLCNN().to(device)

In [None]:
class W3DCNN(nn.Module):
    """
    3-D CNN over wavelet scalograms that treats the 61 EEG leads as a depth axis.

    Expects input of shape (batch, 61, 32, 250) -- lead x wavelet scale x time
    sample, as produced by ``preprocess_image``.  A singleton channel axis is
    inserted so the leads become the Conv3d depth dimension.  Two conv blocks
    with max pooling reduce the volume to 32 channels of size 1x1x1, which
    ``fc1`` maps to 3 class scores (0 = mild, 1 = moderate, 2 = severe in the
    example lists).  Each conv is followed by ReLU and then BatchNorm.

    Sub-module registration order matters: ``GuidedBackprop`` hooks the first
    registered child (``conv1a``) and the ``nn.ReLU`` modules, and saved
    checkpoints are keyed by these attribute names.
    """
    def __init__(self):
        super(W3DCNN,self).__init__()

        self.conv1a = nn.Conv3d(1,4, kernel_size=(3,4,4), stride=(2,1,3))
        self.a1 = nn.ReLU()
        self.conv1a_bn = nn.BatchNorm3d(4)

        self.conv1b = nn.Conv3d(4,8, kernel_size=(4,4,2), stride=(2,1,3))
        self.a2 = nn.ReLU()
        self.conv1b_bn = nn.BatchNorm3d(8)

        self.maxPool1 = nn.MaxPool3d((2,2,2))

        self.conv2a = nn.Conv3d(8,16, kernel_size=(3,3,4), stride=(2,2,2))
        self.a3 = nn.ReLU()
        self.conv2a_bn = nn.BatchNorm3d(16)

        self.conv2b = nn.Conv3d(16,32, kernel_size=(2,4,4), stride=(1,2,2))
        self.a4 = nn.ReLU()
        self.conv2b_bn = nn.BatchNorm3d(32)

        self.maxPool2 = nn.MaxPool3d((2,2,2))

        self.fc1 = nn.Linear(32,3)

    def forward(self,x):
        """Map (batch, 61, 32, 250) scalograms to (batch, 3) class scores."""
        # (B, 61, 32, 250) -> (B, 1, 61, 32, 250). Unlike a hard-coded view,
        # unsqueeze never reinterprets a wrongly-shaped input.
        x = x.unsqueeze(1)

        # -> (B, 4, 30, 29, 83)
        x = self.conv1a(x)
        x = self.a1(x)
        x = self.conv1a_bn(x)

        # -> (B, 8, 14, 26, 28)
        x = self.conv1b(x)
        x = self.a2(x)
        x = self.conv1b_bn(x)

        # -> (B, 8, 7, 13, 14)
        x = self.maxPool1(x)

        # -> (B, 16, 3, 6, 6)
        x = self.conv2a(x)
        x = self.a3(x)
        x = self.conv2a_bn(x)

        # -> (B, 32, 2, 2, 2)
        x = self.conv2b(x)
        x = self.a4(x)
        x = self.conv2b_bn(x)

        # -> (B, 32, 1, 1, 1)
        x = self.maxPool2(x)

        # Flatten per example so size mismatches raise in fc1.
        x = torch.flatten(x, 1)

        x = self.fc1(x)  # (B, 3) class scores
        return x
    
# Choose the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# Instantiate the 3-D model on that device (the GBP cells below build their
# own model through get_example_params).
model = W3DCNN().to(device)

# 2.) Pre-Processing

### Optional: Run GBP on all mild, moderate, and severe examples
Gather all mild, moderate, and severe examples into their respective folders (`mild_examp`, `moderate_examp`, `severe_examp`) and run the following cell to build the example lists of files.

In [None]:
# Build (file path, class label) pairs for every mild example (label 0).
ls1 = [('mild_examp/' + fi, 0) for fi in os.listdir('mild_examp')]
# For another class, rebuild ls1 from its folder instead, e.g.:
#   ls1 = [('moderate_examp/' + fi, 1) for fi in os.listdir('moderate_examp')]
#   ls1 = [('severe_examp/' + fi, 2) for fi in os.listdir('severe_examp')]

In [None]:
def get_example_params(example_index, model, path_wts=None, examples=None):
    """
        Load one example, preprocess it, and build the requested trained model.
    Args:
        example_index (int): index into ``examples`` (or into the built-in
            three-example list when ``examples`` is None)
        model (int): 1 -> WLCNN, 2 -> W3DCNN
        path_wts (str): path to a checkpoint whose ``'model_state_dict'``
            entry holds the trained weights (required)
        examples (sequence of (path, label) pairs, optional): e.g. ``ls1``
    returns:
        original_image (torch.Tensor): the preprocessed input (same object as prep_img)
        prep_img (torch.Tensor): output of ``preprocess_image`` for the example
        target_class (int): label stored with the example
        file_name_to_export (string): example file name without directory or extension
        pretrained_model (Pytorch model): model with the checkpoint weights loaded
    raises:
        ValueError: if ``model`` is not 1 or 2, or ``path_wts`` is None
    """
    # Validate cheap arguments before the expensive wavelet preprocessing; an
    # unknown model id would otherwise surface later as UnboundLocalError.
    if model not in (1, 2):
        raise ValueError(f"model must be 1 (WLCNN) or 2 (W3DCNN), got {model!r}")
    if path_wts is None:
        raise ValueError("path_wts must point to a checkpoint containing 'model_state_dict'")

    # Pick one of the examples
    if examples is None:
        example_list = (
            ('input_images/sub-032320_EC_mild_1_raw.fif', 0),
            ('input_images/sub-032464_EC_moderate_75_raw.fif', 1),
            ('input_images/sub-032455_EC_severe_101_raw.fif', 2),
        )
    else:
        example_list = examples

    img_path = example_list[example_index][0]
    target_class = example_list[example_index][1]
    # Base name without extension; also correct when the name contains no '.'
    file_name_to_export = os.path.splitext(os.path.basename(img_path))[0]

    # Process image
    prep_img = preprocess_image(img_path)
    original_image = prep_img

    # Define model
    if model == 1:
        pretrained_model = WLCNN()
    else:
        pretrained_model = W3DCNN()
    print(f"loading from: {path_wts}")

    checkpoint = torch.load(path_wts, map_location=torch.device('cpu'))
    pretrained_model.load_state_dict(checkpoint['model_state_dict'])

    return (original_image,
            prep_img,
            target_class,
            file_name_to_export,
            pretrained_model)

In [None]:
def convert_to_grayscale(im_as_arr):
    """
        Collapse a channel-first array to one normalized channel.
    Args:
        im_as_arr (numpy arr): array with shape (C, H, W)
    returns:
        grayscale_im (numpy_arr): shape (1, H, W), values in [0, 1]. Each
            pixel is the sum of absolute values across channels, shifted by
            the minimum and scaled by (99th percentile - minimum), then clipped.
    """
    grayscale_im = np.sum(np.abs(im_as_arr), axis=0)
    im_max = np.percentile(grayscale_im, 99)
    im_min = np.min(grayscale_im)
    value_range = im_max - im_min
    if value_range > 0:
        grayscale_im = np.clip((grayscale_im - im_min) / value_range, 0, 1)
    else:
        # Degenerate range (e.g. an all-zero gradient): x/0 would give inf
        # (clipped to 1) above the minimum and 0/0 = NaN at it. Keep the 1s
        # and use 0 instead of NaN.
        grayscale_im = (grayscale_im > im_min).astype(float)
    return np.expand_dims(grayscale_im, axis=0)

def save_gradient_images(gradient, file_name):
    """
        Normalize a gradient array to [0, 1] and export it as ./results/<file_name>.jpg
    Args:
        gradient (np arr): Numpy array of the gradient, e.g. shape (3, 224, 224)
        file_name (str): File name (without extension) to be exported
    """
    if not os.path.exists('./results'):
        os.makedirs('./results', exist_ok=True)
        print('dir made')
    # Normalize to [0, 1]. Out-of-place division also works for integer
    # arrays, and a constant gradient is left at 0 instead of dividing by zero.
    gradient = gradient - gradient.min()
    grad_max = gradient.max()
    if grad_max > 0:
        gradient = gradient / grad_max
    # Save image
    # NOTE(review): save_image is not defined in this notebook; it must be
    # imported/defined before this function is called.
    path_to_file = os.path.join('./results', file_name + '.jpg')
    save_image(gradient, path_to_file)

In [None]:
class HiddenPrints:
    """
    Context manager that silences output written to ``sys.stdout`` (used to
    hide mne print statements) by redirecting it to ``os.devnull`` for the
    duration of the ``with`` block. The original stream is restored on exit,
    including when the block raises.
    """
    def __enter__(self):
        self._original_stdout = sys.stdout
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first, then close our own devnull handle (not whatever
        # sys.stdout happens to be), so stdout is never left closed.
        sys.stdout = self._original_stdout
        self._devnull.close()

In [None]:
def Wavelet(path:str):
    """
    Build a (lead x scale x time) array of complex-Morlet CWT magnitudes from
    the EEG recording stored in the MNE ``.fif`` file at ``path``.

    Args:
        path: path to a raw ``.fif`` file readable by ``mne.io.read_raw_fif``
            (e.g. ``input_images/sub-032320_EC_mild_1_raw.fif``).

    Returns: 3D numpy float array of shape (61, 32, 250): for each of the first
             61 channels, |CWT coefficients| at scales 1..32 over 250 samples,
             using the 'cmor0.4-1.0' wavelet with sampling_period=1.

    Raises: ValueError if the recording has fewer than 61 channels or is not
            exactly 250 samples long.
    """
    channels = 61
    scale = 32  # number of scales (1..scale) for the Morlet wavelet
    length = 250
    scales = np.arange(1, scale + 1)

    raw = mne.io.read_raw_fif(path)
    signals = raw.get_data(picks=raw.ch_names, start=0)

    # Keep the first `channels` leads; the model requires exactly `length` samples.
    y_voltage = np.asarray(signals[:channels], dtype=float)
    if y_voltage.shape != (channels, length):
        raise ValueError(
            f"expected at least {channels} channels x {length} samples, "
            f"got data of shape {np.shape(signals)} from {path}")

    # Lead-first array: channel x scale x time, storing |coef| directly.
    waves_mag = np.empty([channels, scale, length], dtype=float)
    for i in range(channels):
        coef, _ = pywt.cwt(y_voltage[i], scales, 'cmor0.4-1.0', sampling_period=1)
        waves_mag[i, :, :] = np.abs(coef)

    return waves_mag

In [None]:
def preprocess_image(path, resize_im=True):
    """
        Load an EEG recording and turn it into a normalized model input.
    Args:
        path (str): path to the raw ``.fif`` file, passed to ``Wavelet``
        resize_im (bool): unused; kept only for call compatibility
    returns:
        ten (torch.FloatTensor): shape (1, 61, 32, 250) -- the ``Wavelet``
            magnitude cube, with each of the 61 leads standardized by the
            per-lead ``mean``/``dev`` below (presumably training-set
            statistics), plus a leading batch axis. ``requires_grad`` is
            enabled before the batch axis is added, so the returned tensor is
            a non-leaf view: its ``.grad`` is not populated by backward.
            Guided backprop reads input gradients through a layer hook instead.
    """

    mean = [
    6.057787360077358e-06,
    6.087420497235737e-06,
    7.3990380621281874e-06,
    4.26021991865975e-06,
    2.539638463888475e-06,
    4.114347200376418e-06,
    7.82875798194875e-06,
    6.485739319775771e-06,
    1.9192647558707106e-06,
    2.473070259677167e-06,
    7.191401187038389e-06,
    9.746760885852078e-06,
    6.5584541997872675e-06,
    3.800515075142793e-06,
    7.667242056921622e-06,
    1.098459778463806e-05,
    1.0834701802504087e-05,
    8.28091010997476e-06,
    9.337240549820418e-06,
    1.2834776165816766e-05,
    4.27243188727822e-06,
    1.5791480155150454e-05,
    1.4001914284608656e-05,
    1.3751925491812879e-05,
    1.5291240753861295e-05,
    1.947148781941967e-05,
    1.7938964069618997e-05,
    1.987472439042673e-05,
    1.845479014263127e-05,
    1.968688715451437e-05,
    1.8165726731432112e-05,
    6.777176750486097e-06,
    4.754154756755766e-06,
    4.764475652547024e-06,
    6.87751954700681e-06,
    5.729166948061963e-06,
    2.850098232526941e-06,
    3.228867711013545e-06,
    5.226466231830351e-06,
    8.474320071937972e-06,
    4.148349543128007e-06,
    4.737125846372745e-06,
    8.874316209515818e-06,
    8.175465098051129e-06,
    4.443676008117992e-06,
    5.4173081548627475e-06,
    9.241705597005122e-06,
    1.237329241107373e-05,
    9.616053051768428e-06,
    8.303972253779376e-06,
    1.1170897624137052e-05,
    1.5039500493486429e-05,
    1.4841113242660599e-05,
    1.3065708563015742e-05,
    1.3679599337093769e-05,
    1.7859899457787005e-05,
    1.9265918713678717e-05,
    1.8492493007637822e-05,
    1.7724493279678855e-05,
    1.9086136839310716e-05,
    2.2065214391782833e-05
    ]

    dev = [
    6.226787724404643e-06,
    6.334461101252833e-06,
    7.532137981857288e-06,
    4.239707086526406e-06,
    2.6669268673174235e-06,
    4.085261767155552e-06,
    8.121753271608523e-06,
    6.6090573909851414e-06,
    1.8582565106933258e-06,
    2.516306726592908e-06,
    7.5948400778550435e-06,
    1.049799509243296e-05,
    7.064685104230226e-06,
    4.4963659700270465e-06,
    8.31816463510664e-06,
    1.2149702294967242e-05,
    1.2092646101663928e-05,
    9.530515704417179e-06,
    1.1090785547727506e-05,
    1.4668592146355854e-05,
    4.48958615696688e-06,
    1.9380885397891452e-05,
    1.676695544359015e-05,
    1.707348620669288e-05,
    1.8257697630001285e-05,
    2.3635050137116193e-05,
    2.212982535067034e-05,
    2.4296133802117022e-05,
    2.227763445353628e-05,
    2.3789918993075088e-05,
    2.2138089933686428e-05,
    6.8588371249472165e-06,
    4.9794779785127445e-06,
    4.944180658889942e-06,
    7.057348321808729e-06,
    5.829374656656345e-06,
    2.916004499001558e-06,
    3.5322337253732193e-06,
    5.240591107098919e-06,
    8.796968997040724e-06,
    4.125534189419075e-06,
    4.855986216498394e-06,
    9.435575567982905e-06,
    8.593776081275078e-06,
    5.074781659449896e-06,
    6.242808802312015e-06,
    9.991918583732963e-06,
    1.409254057377304e-05,
    1.0894342352575892e-05,
    9.875674795134744e-06,
    1.2910368478758826e-05,
    1.781295720914918e-05,
    1.7783326314145297e-05,
    1.5700134151438275e-05,
    1.6390968593499346e-05,
    2.1365723152930735e-05,
    2.4059720280776322e-05,
    2.262522612710923e-05,
    2.1502597301469412e-05,
    2.3205764264590672e-05,
    2.7072069650931403e-05
    ]

    # Per-lead standardization over axis 0 of the (61, 32, 250) cube.
    normalize = transforms.Normalize(mean, dev)
    ten = torch.FloatTensor(Wavelet(path))
    ten = normalize(ten)
    ten.requires_grad=True
    return ten.unsqueeze(0)

# 3.) Guided Back Prop Class

In [None]:
"""
Created on Thu Oct 26 11:23:47 2017
@author: Utku Ozbulak - github.com/utkuozbulak
"""
import torch
from torch.nn import ReLU

class GuidedBackprop():
    """
       Produces gradients generated with guided back propagation from the given image.

       On construction the model is put in eval mode and hooks are attached:
       every ``nn.ReLU`` sub-module records its forward output. On the
       backward pass it lets through only positive gradients at positions
       where that output was positive. The first child module records the
       gradient flowing into it, which ``generate_gradients`` returns.

       Hooks stay on the model until ``remove_hooks()`` is called. Call it
       before wrapping the same model instance in another GuidedBackprop,
       otherwise hooks accumulate.
    """
    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.forward_relu_outputs = []
        self._hook_handles = []
        # Put model in evaluation mode (BatchNorm uses running statistics)
        self.model.eval()
        self.update_relus()
        self.hook_layers()

    def hook_layers(self):
        """Record the gradient w.r.t. the input of the model's first child module."""
        def hook_function(module, grad_in, grad_out):
            # With the legacy (non-full) backward hook, grad_in[0] is the
            # gradient w.r.t. the layer input for the Conv layers used here.
            # NOTE(review): for other first layers (e.g. nn.Linear) it can be
            # a different tensor -- verify before reusing.
            self.gradients = grad_in[0]
        first_layer = next(iter(self.model.children()))
        self._hook_handles.append(first_layer.register_backward_hook(hook_function))

    def update_relus(self):
        """
            Updates relu activation functions so that
                1- stores output in forward pass
                2- imputes zero for gradient values that are less than zero
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            Keep only positive gradients at positions where the ReLU was active.
            """
            # Backward visits ReLUs in reverse forward order, so pop the latest output.
            forward_output = self.forward_relu_outputs.pop()
            # Build a fresh 0/1 mask instead of overwriting the stored activation
            # in place (which would mutate a tensor autograd may still hold).
            active_mask = (forward_output > 0).to(grad_in[0].dtype)
            modified_grad_out = active_mask * torch.clamp(grad_in[0], min=0.0)
            return (modified_grad_out,)

        def relu_forward_hook_function(module, ten_in, ten_out):
            """
            Store results of forward pass
            """
            self.forward_relu_outputs.append(ten_out)

        # Hook every ReLU, including ones nested inside containers.
        for module in self.model.modules():
            if isinstance(module, ReLU):
                self._hook_handles.append(module.register_backward_hook(relu_backward_hook_function))
                self._hook_handles.append(module.register_forward_hook(relu_forward_hook_function))

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate_gradients(self, input_image, target_class):
        """
            Run guided backprop for one class score.
        Args:
            input_image (torch.Tensor): batched model input; it must require
                grad so the first layer receives an input gradient
            target_class (int): index of the output score to backpropagate
        returns:
            numpy array: guided gradient w.r.t. the first layer's input for
                the first sample in the batch
        raises:
            RuntimeError: if no gradient reached the first layer
        """
        # Start clean: drop activations/gradients left by any earlier pass.
        self.forward_relu_outputs = []
        self.gradients = None
        # Forward pass
        model_output = self.model(input_image)
        # Zero gradients
        self.model.zero_grad()
        # One-hot target on the same device/dtype/shape as the output.
        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[0, target_class] = 1
        # Backward pass
        model_output.backward(gradient=one_hot_output)
        if self.gradients is None:
            raise RuntimeError("no gradient reached the first layer; "
                               "does the input require grad?")
        # Convert to numpy (works for CPU and GPU tensors)
        gradients_as_arr = self.gradients.detach().cpu().numpy()[0]
        return gradients_as_arr

# 4.) Run Guided Back Prop

In [None]:
fs = 250  # sampling rate in Hz
scale = 32  # number of wavelet scales (1..scale)
file_name_to_export = "guided_back_prop"

# path_wts = ""  # set to the trained-weights checkpoint before running the cells below

# Pseudo-frequency (Hz) of each wavelet scale, truncated to int for axis labels.
frequencies = (pywt.scale2frequency('cmor0.4-1.0', np.arange(1, scale + 1)) / (1.0/fs)).astype(int)

### Run GBP on one example

In [None]:
########## Get Params ############
file_name_to_export = "guided_back_prop"
target_example = 0  # default example list index: 0 = mild, 1 = moderate, 2 = severe
# NOTE(review): path_wts must be defined first (its assignment above is commented out).
(original_image, prep_img, target_class, file_name_to_export, pretrained_model) = get_example_params(
    example_index=target_example, model=1, path_wts=path_wts, examples=None)

########## Guided backprop ############
GBP = GuidedBackprop(pretrained_model)
# Get gradients: (61, 32, 250) input saliency for the WLCNN model
guided_grads = GBP.generate_gradients(prep_img, target_class)

print('Guided backprop completed')

### Run GBP on example list

In [None]:
# Sum GBP saliency (x1) and preprocessed inputs (x2) over every example in
# `examples`, then divide by the count to get class means.
x1 = np.zeros([61, 32, 250])
x2 = np.zeros([1, 61, 32, 250])
examples = mild_examples
path_wts = path_wts3
print("Guided Back Prop...")
print(f"{len(examples)} examples...")
for i, _ in enumerate(examples):
    target_example = i
    (original_image, prep_img, target_class,
     file_name_to_export, pretrained_model) = get_example_params(
        target_example, model=2, path_wts=path_wts, examples=examples)

    ########## Guided backprop ############
    GBP = GuidedBackprop(pretrained_model)
    guided_grads = GBP.generate_gradients(prep_img, target_class)
    x1 += guided_grads[0]   # W3DCNN gradient: drop the depth-channel axis -> (61, 32, 250)
    x2 += prep_img.detach().numpy()
x1 = x1 / len(examples)
x2 = x2 / len(examples)
print('Guided backprop completed')

### Create Mean activations:
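Create the mean activations one class at a time (mild, moderate, severe): re-run the previous cell with each class's example list before running the next cell.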

In [None]:
# Mean over time of the GBP saliency, as magnitudes -> (scales, channels) map.
# np.abs is used instead of sqrt(x**2), which under/overflows for extreme values.
x1 = x1.mean(2)
mi_act = np.abs(x1).T

#uncomment for 3d
# Mean over wavelet scales of lead 0 of the averaged input -> shape (250,)
x2 = x2[0]
mi_img = x2[0].mean(0)
print(mi_img.shape)

# For the moderate / severe classes, re-run the accumulation cell with that
# class's example list, then repeat the steps above with new names, e.g.:
#   x1 = x1.mean(2)
#   mo_act = np.abs(x1).T        # or se_act for severe
#   x2 = x2[0]
#   mo_img = x2[0].mean(0)       # or se_img for severe
#   print(mo_img.shape)

### Plot Scalogram

In [None]:
#milx
import matplotlib.pyplot as plt
%matplotlib inline
#guided_grads = guided_grads.reshape(32,250)
plt.figure(figsize=(15,15))
#plt.title('GBP Saliency Map')
#plt.ylabel('Frequency (Hz)')
#plt.xlabel('Time (sec.)')
plt.xticks([])
plt.xticks([])

#x = np.sqrt(guided_grads[0]**2)
x = prep_img[0][60].detach().numpy()

plt.yticks([])
plt.imshow(x,extent=[0, 61, 0, 66],aspect='auto',
            vmax=x.max(), vmin=x.min(), cmap = 'inferno')

### Plot Saliency Map

In [None]:
#milx
import matplotlib.pyplot as plt
%matplotlib inline
#guided_grads = guided_grads.reshape(32,250)
plt.figure(figsize=(15,15))
plt.title('GBP Saliency Map')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (sec.)')
plt.xticks(np.arange(0,25.1,1),labels=np.round(np.arange(0,1.01,1./25),3),rotation=60, fontsize=18)
plt.xticks(np.arange(1,62,1),labels=chan_names,rotation=60)

#x = np.sqrt(guided_grads[0]**2)
x = np.sqrt(guided_grads[0]**2)

plt.yticks(np.arange(0,31.1,1),labels=np.flip(frequencies),rotation=45)
plt.imshow(x,extent=[0, 61, 0, 31],aspect='auto',
            vmax=x.max(), vmin=x.min(), cmap = 'plasma')

# 5.) Create Saliency and Two-Channel Plots for Each Example

In [None]:
plt_fr = np.array(frequencies[::2])  # every 2nd frequency, for sparser y-axis labels

In [None]:
act_max = max(se_act.max(), mo_act.max(), mi_act.max())
import matplotlib.pyplot as plt
%matplotlib inline
#create Big subplot
fig = plt.figure(figsize=(24,10))  
ax = fig.add_subplot(1,1,1)    # The big subplot
plt.rcParams.update({'font.size': 24})
# Set common labels  
ax.set_xlabel('Time (Sec.)',fontsize=24)
ax.set_ylabel('Frequency (Hz)',fontsize=24)

#Plot Saliency Map   
ax1 = fig.add_subplot(1,3,1)
ax1.set_title('Mild',fontsize=24)

ax1.set_xticks(np.arange(1,62,4), minor = False)
ax1.set_xticklabels(names, rotation = 70)
ax1.set_yticks(np.arange(0,61,4), minor = False)
ax1.set_yticklabels(np.flip(plt_fr))
ax1.imshow(mi_act,extent=[0, 61, 0, 61], aspect='equal',cmap = 'plasma',vmax=mi_act.max(), vmin=0)

#Plot Saliency Map   
ax2 = fig.add_subplot(1,3,2)
ax2.set_title('Moderate',fontsize=24)

ax2.set_xticks(np.arange(1,62,4), minor = False)
ax2.set_xticklabels(names, rotation = 70)
ax2.set_yticks(np.arange(0,31.1,1), minor = False)
ax2.set_yticklabels([])
ax2.imshow(mo_act,extent=[0, 61, 0, 61], aspect='equal',cmap = 'plasma',vmax=mo_act.max(), vmin=0)

#Plot Saliency Map   
ax3 = fig.add_subplot(1,3,3)
ax3.set_title('severe',fontsize=24)

ax3.set_xticks(np.arange(1,62,4), minor = False)
ax3.set_xticklabels(names, rotation = 70)
ax3.set_yticks(np.arange(0,31.1,1), minor = False)
ax3.set_yticklabels([])
ax3.imshow(se_act,extent=[0, 61, 0, 61], aspect='equal',cmap = 'plasma',vmax=se_act.max(), vmin=0)

fig.subplots_adjust(wspace=0.06, hspace=0.02)


# Turn off axis lines and ticks of the big subplot
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['left'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)



In [None]:
import matplotlib.gridspec as gridspec
# NOTE(review): `names` (16 x tick labels) must be defined before running this cell.
#create Big subplot
fig = plt.figure(figsize=(16,10), constrained_layout=False)

# Keep the GridSpec: the subplots below index into it (previously the return
# value was discarded and `gs` was undefined).
gs = fig.add_gridspec(2, 4, wspace=0.0, hspace=0.001)
ax1 = fig.add_subplot(gs[0, :2])

ax1.set_title('Mild',fontsize=24)
ax1.set_xticks(np.arange(1,62,4), minor = False)
ax1.set_xticklabels(names, rotation = 70)
ax1.set_yticks(np.arange(0,61,4), minor = False)
ax1.set_yticklabels(np.flip(plt_fr))
ax1.imshow(mi_act,extent=[0, 61, 0, 61], aspect='equal',cmap = 'plasma',vmax=mi_act.max(), vmin=0)

# Same y tick positions as ax1 so the panels line up; labels hidden.
ax2 = fig.add_subplot(gs[0, 2:])
ax2.set_title('Moderate',fontsize=24)
ax2.set_xticks(np.arange(1,62,4), minor = False)
ax2.set_xticklabels(names, rotation = 70)
ax2.set_yticks(np.arange(0,61,4), minor = False)
ax2.set_yticklabels([])
ax2.imshow(mo_act,extent=[0, 61, 0, 61], aspect='equal',cmap = 'plasma',vmax=mo_act.max(), vmin=0)

ax3 = fig.add_subplot(gs[1, 1:3])
ax3.set_title('severe',fontsize=24)
ax3.set_xticks(np.arange(1,62,4), minor = False)
ax3.set_xticklabels(names, rotation = 70)
ax3.set_yticks(np.arange(0,61,4), minor = False)
ax3.set_yticklabels([])
ax3.imshow(se_act,extent=[0, 61, 0, 61], aspect='equal',cmap = 'plasma',vmax=se_act.max(), vmin=0)

plt.show()

In [None]:
import matplotlib.gridspec as gridspec
# 2x4 grid: Mild and Moderate side by side on top, Severe centred underneath.
fig = plt.figure(figsize=(20,23), constrained_layout=False)
gs = gridspec.GridSpec(2, 4, wspace=0.15, hspace=0.15)


def _draw_saliency_panel(spec, title, act, show_freq_labels):
    """Draw one (frequency x channel) saliency map into a new subplot of ``fig``."""
    axis = fig.add_subplot(spec)
    axis.set_title(title, fontsize=40)
    axis.set_xticks(np.arange(1, 62, 4), minor=False)
    axis.set_xticklabels(names, rotation=70, fontsize=30)
    axis.set_yticks(np.arange(0, 61, 4), minor=False)
    if show_freq_labels:
        axis.set_yticklabels(np.flip(plt_fr), fontsize=30)
    else:
        axis.set_yticklabels([])
    axis.imshow(act, extent=[0, 61, 0, 61], aspect='equal', cmap='plasma', vmax=act.max(), vmin=0)
    return axis


ax1 = _draw_saliency_panel(gs[0, :2], 'Mild', mi_act, True)
ax2 = _draw_saliency_panel(gs[0, 2:], 'Moderate', mo_act, False)
ax3 = _draw_saliency_panel(gs[1, 1:3], 'Severe', se_act, True)

plt.show()

# 6.) Create Saliency and One-Channel Plots for Each Example

In [None]:
plt_fr = np.array(frequencies[::6])  # every 6th frequency, for sparser y-axis labels

In [None]:
#create Big subplot
fig = plt.figure(figsize=(17,17))
ax = fig.add_subplot(1,1,1)    # The big subplot: only carries the shared axis labels

# Set common labels
ax.set_xlabel('Time (Sec.)',fontsize=26)
ax.set_ylabel('Frequency (Hz)',fontsize=26)

# Lead 16 is 'CP5' in chan_names. One y tick per entry of plt_fr, so the tick
# and label counts always match (16 fixed ticks vs 6 labels is an error in
# recent matplotlib).
freq_ticks = np.linspace(0, 61, len(plt_fr))

#Compute WLCNN GBP Map
# (get_example_params only builds models 1 (WLCNN) and 2 (W3DCNN).)
file_name_to_export = "guided_back_prop"
target_example = 1  # default example list: 1 = moderate
(original_image, prep_img, target_class, file_name_to_export, pretrained_model) = get_example_params(example_index=target_example,path_wts=path_wts2, model=1)
GBP = GuidedBackprop(pretrained_model)
guided_grads = GBP.generate_gradients(prep_img, target_class)
x = guided_grads[16]  # (scales, time) saliency for CP5

#Plot Saliency Map
ax4 = fig.add_subplot(2,2,1)
ax4.set_title('WLCNN',fontsize=24)
ax4.set_xticks(np.arange(0,61.1,12.2), minor = False)
ax4.set_xticklabels([])
ax4.set_yticks(freq_ticks, minor = False)
ax4.set_yticklabels(np.flip(plt_fr))
ax4.imshow(np.abs(x), extent=[0, 61, 0, 61],aspect='equal',
            vmax=np.abs(x).max(), vmin=0, cmap = 'plasma')

#Plot CP5 Scalogram
scalogram = prep_img.detach()
ax6 = fig.add_subplot(2,2,3)
ax6.set_xticks(np.arange(0,61.1,12.2), minor = False)
ax6.set_xticklabels(np.round(np.arange(0,1.1,0.2),2))
ax6.set_yticks(freq_ticks, minor = False)
ax6.set_yticklabels(np.flip(plt_fr))
for tick in ax6.xaxis.get_major_ticks():
    tick.set_pad(8)
ax6.imshow(scalogram.numpy()[0][16], extent=[0, 61, 0, 61],aspect='equal',
            vmax=0.75*float(scalogram.abs().max()), vmin=float(scalogram.min()),cmap ='inferno')

#Compute W3DCNN GBP Map
file_name_to_export = "guided_back_prop"
target_example = 1  # default example list: 1 = moderate
(original_image, prep_img, target_class, file_name_to_export, pretrained_model) = get_example_params(example_index=target_example,path_wts=path_wts3 ,model = 2)
GBP = GuidedBackprop(pretrained_model)
guided_grads = GBP.generate_gradients(prep_img, target_class)
x = guided_grads[0][16]  # W3DCNN gradient has an extra depth-channel axis

#Plot Saliency Map
ax7 = fig.add_subplot(2,2,2)
ax7.set_title('W3DCNN',fontsize=24)
ax7.set_xticks(np.arange(0,25.1,5), minor = False)
ax7.set_xticklabels([])
ax7.set_yticks(np.arange(0,25.1,5), minor = False)
ax7.set_yticklabels([])
ax7.imshow(np.abs(x), extent=[-1, 25, -1, 25],aspect='equal',
            vmax=np.abs(x).max(), vmin=0, cmap = 'plasma')

#Plot CP5 Scalogram
scalogram = prep_img.detach()
ax9 = fig.add_subplot(2,2,4)
ax9.set_xticks(np.arange(0,25.1,5), minor = False)
ax9.set_xticklabels(np.round(np.arange(0,1.1,0.2),2))
ax9.set_yticks(np.arange(0,25.1,5), minor = False)
ax9.set_yticklabels([])
for tick in ax9.xaxis.get_major_ticks():
    tick.set_pad(8)
ax9.imshow(scalogram.numpy()[0][16], extent=[-1, 25, -1, 25],aspect='equal',
            vmax=0.75*float(scalogram.abs().max()), vmin=float(scalogram.min()),cmap ='inferno')

ax7.yaxis.set_label_position("right")
ax7.set_ylabel('Class Activation Map', fontsize=24)

ax9.yaxis.set_label_position("right")
ax9.set_ylabel('Scalogram', fontsize=24)

# Turn off axis lines and ticks of the big subplot
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['left'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

fig.subplots_adjust(wspace=0.01, hspace=0.05)

fig.savefig('CP5_Scalo_Conv_Gbp.jpg')