In [7]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
%matplotlib inline
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import Compose, ToTensor
import torch.nn.functional as F
import torch.nn as nn


from captum.attr import IntegratedGradients, DeepLiftShap, GuidedGradCam



from captum.attr import DeepLiftShap
from captum.attr import visualization as viz

from torch.utils.data import Dataset, DataLoader
import warnings
# Silence every warning category for the rest of the session (Warning is the
# base class of all warning types).
warnings.simplefilter("ignore", Warning)
import torch.nn as nn 
from matplotlib.pyplot import figure
# Let pandas print up to 5000 rows before truncating a displayed frame.
pd.set_option('display.max_rows', 5000)
# Default resolution for new figures and no padding around saved figures.
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.pad_inches'] = 0
px = 1/plt.rcParams['figure.dpi']  # size of one pixel in inches at the configured dpi
# Open a 512x512-pixel figure at 300 dpi.
# NOTE(review): nothing in the visible cells draws on this figure — confirm it is needed.
figure(figsize=(512*px, 512*px), dpi=300)

import ast
import matplotlib
# Switch matplotlib to the Agg backend (this runs after `%matplotlib inline`
# earlier in the same cell).
# NOTE(review): presumably so figures are only written to disk instead of being
# rendered in the notebook — confirm.
matplotlib.use('Agg')
from pathlib import Path
# import hgs_to_pixel_converter as hgs_to_pix
# Never truncate long cell values (e.g. file paths) when displaying frames.
pd.set_option('display.max_colwidth', None)
from matplotlib.colors import ListedColormap

In [8]:
class VGG16(nn.Module):
    """VGG-16-style image classifier.

    The feature extractor is five stages of 3x3 convolutions (13 in total),
    each convolution followed by a non-inplace ReLU and each stage closed by
    a 2x2 max-pool. The pooled features are resized to 7x7 by adaptive
    average pooling, flattened, and passed through a three-layer fully
    connected head with dropout.

    Args:
        num_classes: Width of the final linear layer (number of logits).
    """

    # Output channels of every convolution, in order; 'M' marks a 2x2 max-pool.
    _LAYOUT = (
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    )

    def __init__(self, num_classes=2):
        super().__init__()

        # Expand the layout table into the flat layer list. The order (and
        # therefore every index inside the Sequential, e.g. features[28] being
        # the last convolution) is the same as spelling each layer out by hand.
        stack = []
        channels_in = 3
        for spec in self._LAYOUT:
            if spec == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.append(nn.Conv2d(channels_in, spec, kernel_size=3, padding=1))
            stack.append(nn.ReLU(inplace=False))
            channels_in = spec
        self.features = nn.Sequential(*stack)

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of 3-channel images."""
        pooled = self.avgpool(self.features(x))
        # Keep the batch dimension, collapse everything else for the linear head.
        return self.classifier(torch.flatten(pooled, 1))


In [9]:
class MyJP2Dataset(Dataset):
    """Image dataset driven by a CSV annotation file.

    The CSV is read positionally: column 0 is the image path relative to
    ``root_dir``, column 1 is a probability (parsed as float) and column 2 is
    a label (kept as a string).

    Args:
        csv_file: Path of the annotation CSV.
        root_dir: Directory the relative image paths are joined onto.
        transform: Optional callable applied to the RGB ``PIL.Image``. When
            omitted, the untransformed PIL image is returned.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, probability rounded to 2 decimals, label string)`` for one row."""
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        hmi = Image.open(img_path).convert('RGB')

        # Fall back to the raw image when no transform is configured; the
        # original left `image` unbound in that case and raised
        # UnboundLocalError on the return below.
        image = self.transform(hmi) if self.transform else hmi

        y_prob = round(float(self.annotations.iloc[index, 1]), 2)
        y_label = str(self.annotations.iloc[index, 2])

        return (image, y_prob, y_label)

    def __len__(self):
        """Number of annotation rows."""
        return len(self.annotations)

In [10]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_PATH1 = '../create_models/trained_models/fold1/fold1.pth'
# map_location places the checkpoint tensors directly on `device`; without it a
# checkpoint saved from a GPU cannot be deserialised on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
weights1 = torch.load(model_PATH1, map_location=device)
test_model = VGG16().to(device)
# The checkpoint is a dict; the network weights live under 'model_state_dict'.
test_model.load_state_dict(weights1['model_state_dict'])
# Inference mode (disables dropout); left as the last expression so the cell
# displays the model summary.
test_model.eval()

VGG16(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), p

In [11]:
# Annotation CSV, read positionally by MyJP2Dataset (image path, probability, label).
csv_file = 'x_class.csv'
# Root directory of the images. Set the HMI_DATA_DIR environment variable to
# point somewhere else; the default keeps the original cluster location.
data_path = os.environ.get('HMI_DATA_DIR', '/scratch/cpandey1/hmi_jpgs_512/')
# Only conversion to a tensor is applied — no resizing or normalisation.
data_transforms = Compose([ToTensor()])
dataset1 = MyJP2Dataset(
    csv_file=csv_file,
    root_dir=data_path,
    transform=data_transforms,
)
# NOTE(review): later cells take a single batch from this loader and index it
# for every row of `df`, so batch_size presumably has to be at least the number
# of CSV rows — confirm.
batch_size = 164
# shuffle=False keeps the batch order aligned with the CSV row order.
loader1 = DataLoader(dataset=dataset1, batch_size=batch_size, shuffle=False, num_workers=8)
# Separate copy of the annotations used by the plotting loops for file names.
df = pd.read_csv(csv_file)
df.head(1)

Unnamed: 0,timestamp,flare_prob,goes_class,fl_location,flare_start
0,2014/02/24/HMI.m2014.02.24_01.00.00.jpg,0.807306,X4.9,"(-82, -12)",2014-02-25 00:39:00


In [12]:
def imshow(img, transpose = True):
    """Display an image tensor with matplotlib.

    Pixel values are shown as given; no un-normalisation is applied.

    Args:
        img: Image tensor. With ``transpose=True`` it is read as
            (channels, height, width); with ``transpose=False`` it is passed
            to matplotlib unchanged.
        transpose: Move the channel axis last before plotting. Previously this
            flag was accepted but ignored (the image was always transposed);
            the default keeps that behaviour.
    """
    # detach()/cpu() make this work for tensors that track gradients or live on the GPU.
    npimg = img.detach().cpu().numpy()
    if transpose:
        npimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(npimg)
    plt.show()

# Pull only the FIRST batch from the loader. `images`, `probs` and `labels`
# are the three items MyJP2Dataset yields per sample, collated over the batch.
# NOTE(review): the plotting cells below index `images[i]` for every row of
# `df`, which assumes the whole CSV fits inside this single batch — confirm.
dataiter = iter(loader1)
images, probs, labels = next(dataiter)
# Cell output: the probability of the fourth sample in the batch.
probs[3]

tensor(0.8600, dtype=torch.float64)

In [16]:
def attribute_image_features(algorithm, input, target, **kwargs):
    """Run an attribution algorithm on ``input`` for class ``target``.

    The gradients of the module-level ``test_model`` are zeroed first; any
    extra keyword arguments are forwarded untouched to
    ``algorithm.attribute``, whose result is returned as is.
    """
    # Clear gradients left on the model by any previous backward pass.
    test_model.zero_grad()
    return algorithm.attribute(input, target=target, **kwargs)

def plot_attributions_guided_grad_cam(img, filename_grad, target):
    """Save a Guided Grad-CAM heat-map of one image as ``<filename_grad>.svg``.

    Attributions are computed against ``test_model.features[28]`` of the
    module-level ``test_model`` on ``device`` and drawn as an absolute-value
    heat-map with a colour bar and hidden axes.

    Args:
        img: Single image tensor in (channels, height, width) layout.
        filename_grad: Output path without extension; ``.svg`` is appended.
        target: Index of the output class to attribute.

    Side effects:
        Sets ``figure.dpi`` and ``savefig.pad_inches`` in ``plt.rcParams``
        and writes the SVG file.
    """
    dpi = 300
    plt.rcParams['figure.dpi'] = dpi
    plt.rcParams['savefig.pad_inches'] = 0

    # Add the batch dimension and mark the input as requiring gradients.
    inp = img.unsqueeze(0)
    inp.requires_grad = True

    guided_gc = GuidedGradCam(test_model, test_model.features[28])
    grads = guided_gc.attribute(inp.to(device), target=target)
    # (1, C, H, W) on device -> (H, W, C) numpy array for plotting.
    grads = np.transpose(grads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    fig, ax = viz.visualize_image_attr(grads, method="heat_map", sign="absolute_value",
                                       alpha_overlay=1.0, show_colorbar=True)
    try:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        fig.tight_layout(pad=0.05)
        fig.savefig(f'{filename_grad}.svg', dpi=dpi)
    finally:
        # This function is called once per image in a loop; without closing,
        # every figure stays registered with pyplot and memory keeps growing.
        plt.close(fig)
    
    
def plot_attributions_deepshap(img, filename_shap, i, target):
    """Save a DeepLiftShap heat-map of one image as ``<filename_shap>.svg``.

    Baselines are a window of neighbouring images taken from the module-level
    ``images`` batch: ``images[i:i+9]`` when ``i < 10``, otherwise
    ``images[i-5:i+5]``. Attributions are computed with the module-level
    ``test_model`` on ``device`` and drawn as an absolute-value heat-map with
    a colour bar and hidden axes.

    Args:
        img: Single image tensor in (channels, height, width) layout.
        filename_shap: Output path without extension; ``.svg`` is appended.
        i: Position used to pick the baseline window inside ``images``.
        target: Index of the output class to attribute.

    Side effects:
        Sets ``figure.dpi`` and ``savefig.pad_inches`` in ``plt.rcParams``
        and writes the SVG file.
    """
    dpi = 300
    plt.rcParams['figure.dpi'] = dpi
    plt.rcParams['savefig.pad_inches'] = 0

    # Add the batch dimension and mark the input as requiring gradients.
    inp = img.unsqueeze(0)
    inp.requires_grad = True

    # NOTE(review): both windows contain position i itself, and the slice is
    # shorter than 10 images near the end of the batch — confirm this is intended.
    if i < 10:
        baseline_batch = images[i:i+9]
    else:
        baseline_batch = images[i-5:i+5]

    saliency = DeepLiftShap(test_model)
    grads = saliency.attribute(inp.to(device), baselines=baseline_batch.to(device), target=target)
    # (1, C, H, W) on device -> (H, W, C) numpy array for plotting.
    grads = np.transpose(grads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    fig, ax = viz.visualize_image_attr(grads, method="heat_map", sign="absolute_value",
                                       alpha_overlay=1.0, show_colorbar=True)
    try:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        fig.tight_layout(pad=0.05)
        fig.savefig(f'{filename_shap}.svg', dpi=dpi)
    finally:
        # This function is called once per image in a loop; without closing,
        # every figure stays registered with pyplot and memory keeps growing.
        plt.close(fig)
    
    
def plot_attributions_intgrad(img, filename_intgrad, target):
    """Save an Integrated Gradients heat-map of one image as ``<filename_intgrad>.svg``.

    Attributions are computed through ``attribute_image_features`` with the
    module-level ``test_model`` on ``device``, against a baseline of the same
    shape as the input filled with 0.5, and drawn as an absolute-value
    heat-map with a colour bar and hidden axes.

    Args:
        img: Single image tensor in (channels, height, width) layout.
        filename_intgrad: Output path without extension; ``.svg`` is appended.
        target: Index of the output class to attribute.

    Side effects:
        Sets ``figure.dpi`` and ``savefig.pad_inches`` in ``plt.rcParams``
        and writes the SVG file.
    """
    dpi = 300
    plt.rcParams['figure.dpi'] = dpi
    plt.rcParams['savefig.pad_inches'] = 0

    # Add the batch dimension, mark the input as requiring gradients, move to device.
    inp = img.unsqueeze(0)
    inp.requires_grad = True
    inp = inp.to(device)

    ig = IntegratedGradients(test_model)
    # The convergence delta is requested as before but not used for plotting.
    attr_ig, _delta = attribute_image_features(ig, inp, target, baselines=inp * 0 + 0.5,
                                               return_convergence_delta=True)
    # (1, C, H, W) on device -> (H, W, C) numpy array for plotting.
    attr_ig = np.transpose(attr_ig.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    fig, ax = viz.visualize_image_attr(attr_ig, method="heat_map", sign="absolute_value",
                                       alpha_overlay=1.0, show_colorbar=True)
    try:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        fig.tight_layout(pad=0.05)
        fig.savefig(f'{filename_intgrad}.svg', dpi=dpi)
    finally:
        # This function is called once per image in a loop; without closing,
        # every figure stays registered with pyplot and memory keeps growing.
        plt.close(fig)

In [17]:
# Guided Grad-CAM heat-map for every CSV row, written under gradcam/ with the
# same relative path as the source image (extension stripped; the plotting
# function appends '.svg').
# NOTE(review): target index 1 is presumably the positive class — confirm.
for i, timestamp in enumerate(df['timestamp']):
    # Derive the output path from the path itself instead of fixed-width
    # string slices, so other directory depths or extensions also work.
    filename_grad = str(Path('gradcam') / Path(timestamp).with_suffix(''))
    pth_grad = Path(filename_grad).parent
    pth_grad.mkdir(parents=True, exist_ok=True)
    # Row i of the CSV matches position i of the (unshuffled) first batch.
    img = images[i]
    plot_attributions_guided_grad_cam(img, filename_grad, 1)

In [18]:

# DeepLiftShap heat-map for every CSV row, written under shap/ with the same
# relative path as the source image (extension stripped; the plotting function
# appends '.svg').
# NOTE(review): target index 1 is presumably the positive class — confirm.
for i, timestamp in enumerate(df['timestamp']):
    # Derive the output path from the path itself instead of fixed-width
    # string slices, so other directory depths or extensions also work.
    filename_shap = str(Path('shap') / Path(timestamp).with_suffix(''))
    pth_shap = Path(filename_shap).parent
    pth_shap.mkdir(parents=True, exist_ok=True)
    # Row i of the CSV matches position i of the (unshuffled) first batch;
    # i is also passed on so the function can pick its baseline window.
    img = images[i]
    plot_attributions_deepshap(img, filename_shap, i, 1)

In [21]:

# Integrated Gradients heat-map for every CSV row, written under intgradtest/
# with the same relative path as the source image (extension stripped; the
# plotting function appends '.svg').
# NOTE(review): target index 1 is presumably the positive class — confirm.
for i, timestamp in enumerate(df['timestamp']):
    # Derive the output path from the path itself instead of fixed-width
    # string slices, so other directory depths or extensions also work.
    filename_intgrad = str(Path('intgradtest') / Path(timestamp).with_suffix(''))
    pth_ig = Path(filename_intgrad).parent
    pth_ig.mkdir(parents=True, exist_ok=True)
    # Row i of the CSV matches position i of the (unshuffled) first batch.
    img = images[i]
    plot_attributions_intgrad(img, filename_intgrad, 1)