In [None]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import sys
import cv2

In [None]:
class Interpretation:
    """Turns max-pool gradients and convolution activations into per-position scores.

    Naming used throughout:

    * ``mbr`` - "max-pool backward result": one value per feature map.
    * ``cfr`` - "CNN forward result": indexed as ``cfr[feature_map]`` to obtain
      that feature map's per-position activations (assumed 2-D,
      feature maps x positions - TODO confirm against the caller).
    * ``max_prob`` - sequence whose first element scales the gradient spacing.

    None of the methods modify their array arguments.
    """

    def __init__(self, feature_map):
        """Remember how many top-ranked feature maps each beta computation sums.

        Args:
            feature_map: Number of feature maps to accumulate; must be between
                1 and the number of feature maps in ``mbr`` (checked lazily,
                when a beta is actually computed).
        """
        self.feature_map = feature_map

    @staticmethod
    def __relu(array):
        """Return a new array equal to ``array`` with negative entries set to 0."""
        # np.maximum allocates its result, so a row that is a view into the
        # caller's ``cfr`` is never overwritten (the previous in-place clamp
        # did overwrite it when only one feature map was selected).
        return np.maximum(array, 0)

    @staticmethod
    def __prepend_zero(arr):
        """Shift ``arr`` right by one position: drop the last item, insert 0 first."""
        return np.concatenate((np.array([0]), arr[:-1]))

    @staticmethod
    def __alpha(mbr, max_prob):
        """Numerical gradient of ``mbr`` (offset by machine epsilon).

        The sample spacing handed to ``np.gradient`` is ``1 / max_prob[0]``.
        """
        return np.gradient(mbr + sys.float_info.epsilon, 1 / max_prob[0])

    def __get_beta(self, mbr, cfr, max_prob):
        """Sum the activations of the ``feature_map`` highest-alpha feature maps.

        Feature maps are ranked by the rectified gradient of ``mbr``; the rows
        of ``cfr`` belonging to the top ``self.feature_map`` entries are added
        in rank order and the sum is rectified.

        Raises:
            ValueError: if ``self.feature_map`` is smaller than 1.
            IndexError: if ``self.feature_map`` exceeds the number of ranked
                feature maps.
        """
        alpha = self.__relu(self.__alpha(mbr, max_prob))
        sorted_index = np.argsort(alpha)[::-1]

        if self.feature_map < 1:
            raise ValueError(
                "feature_map must be at least 1, got %r" % (self.feature_map,)
            )
        if self.feature_map > len(sorted_index):
            raise IndexError(
                "feature_map (%r) exceeds the number of available feature maps (%d)"
                % (self.feature_map, len(sorted_index))
            )

        top_indices = sorted_index[:self.feature_map]

        # Accumulate in rank order with out-of-place additions so the result is
        # numerically identical to a sequential sum and ``cfr`` stays untouched.
        total = cfr[top_indices[0]]
        for index in top_indices[1:]:
            total = total + cfr[index]

        return self.__relu(total)

    def final_beta(self, mbr1, mbr2, mbr3, cfr1, cfr2, cfr3, max_prob):
        """Combine the betas of three branches into one score vector.

        The second and third betas are right-padded with zeros to the length of
        the first one before the element-wise sum, so the first branch must be
        at least as long as the other two.

        Returns:
            1-D array with the length of the first branch's beta.
        """
        result1 = self.__get_beta(mbr1, cfr1, max_prob)
        result2 = self.__get_beta(mbr2, cfr2, max_prob)
        result3 = self.__get_beta(mbr3, cfr3, max_prob)

        result2 = np.append(result2, np.zeros(len(result1) - len(result2)))
        result3 = np.append(result3, np.zeros(len(result1) - len(result3)))

        return result1 + result2 + result3

    def individual_beta(self, mbr, cfr, max_prob, channel):
        """Spread one branch's beta over ``channel`` successive right-shifts.

        The beta is shifted right by one position ``channel`` times and the
        shifted copies (shift 1 .. ``channel``) are summed.

        NOTE(review): the unshifted beta (shift 0) is never part of the sum -
        confirm that starting at shift 1 is the intended alignment.

        Returns:
            1-D array of the same length as the beta, or ``None`` when
            ``channel`` is 0.
        """
        beta = self.__get_beta(mbr, cfr, max_prob)
        result = None

        for _ in range(channel):
            beta = self.__prepend_zero(beta)
            result = (result + beta) if result is not None else beta

        return result

class Model(pl.LightningModule):
    """Three-branch 1-D CNN classifier with 11 output logits.

    The input is convolved in parallel with kernel widths 4, 8 and 16
    (128 filters each). Every branch is rectified, max-pooled over its whole
    length, flattened and batch-normalised; the three 128-value vectors are
    concatenated and passed through two 512-unit dense layers (batch norm +
    ReLU) and a final linear layer.

    Optional hooks (see ``register_hook``) cache the convolution outputs and the
    gradients arriving at the max-pool outputs as attributes on the model.
    """

    def __init__(self, dimension, sequence_length=1000):
        """Build the layers.

        Args:
            dimension: Number of input channels fed to each convolution.
            sequence_length: Length of the input along the convolved axis. Each
                pooling window is sized to the matching convolution's output
                length (``sequence_length - kernel + 1``), so pooling collapses
                the branch to one value per filter. The default of 1000 yields
                the window sizes 997, 993 and 985.
        """
        super().__init__()

        self.norm1 = nn.BatchNorm1d(128)
        self.norm2 = nn.BatchNorm1d(128)
        self.norm3 = nn.BatchNorm1d(128)
        self.norm4 = nn.BatchNorm1d(512)
        self.norm5 = nn.BatchNorm1d(512)

        self.conv1 = nn.Conv1d(dimension, 128, 4)
        self.conv2 = nn.Conv1d(dimension, 128, 8)
        self.conv3 = nn.Conv1d(dimension, 128, 16)

        self.max1 = nn.MaxPool1d(sequence_length - 4 + 1)
        self.max2 = nn.MaxPool1d(sequence_length - 8 + 1)
        self.max3 = nn.MaxPool1d(sequence_length - 16 + 1)

        self.dense1 = nn.Linear(128 * 3, 512)
        self.dense2 = nn.Linear(512, 512)
        self.dense3 = nn.Linear(512, 11)

        # Handles of the hooks installed by register_hook(), kept so that a
        # repeated call replaces the hooks instead of stacking duplicates.
        self._hook_handles = []

    def forward(self, x):
        """Return the 11 raw logits for a batch shaped (batch, dimension, length)."""
        x1 = self.max1(F.relu(self.conv1(x)))
        x2 = self.max2(F.relu(self.conv2(x)))
        x3 = self.max3(F.relu(self.conv3(x)))
        x1 = self.norm1(torch.flatten(x1, 1))
        x2 = self.norm2(torch.flatten(x2, 1))
        x3 = self.norm3(torch.flatten(x3, 1))
        x = torch.cat((x1, x2, x3), dim=1)
        x = F.relu(self.norm4(self.dense1(x)))
        x = F.relu(self.norm5(self.dense2(x)))
        return self.dense3(x)

    # The hooks below squeeze *every* singleton dimension, so for a batch of
    # one the batch axis disappears from the cached tensors as well.

    def cnn_1_forward_hook(self, _, input, output):
        """Forward hook: cache the squeezed output of ``conv1``."""
        self.cnn_1_forward_result = torch.squeeze(output)

    def cnn_2_forward_hook(self, _, input, output):
        """Forward hook: cache the squeezed output of ``conv2``."""
        self.cnn_2_forward_result = torch.squeeze(output)

    def cnn_3_forward_hook(self, _, input, output):
        """Forward hook: cache the squeezed output of ``conv3``."""
        self.cnn_3_forward_result = torch.squeeze(output)

    def max_pool_1_backward_hook(self, _, grad_input, grad_output):
        """Backward hook: cache the squeezed gradient w.r.t. ``max1``'s output."""
        self.max_pool_1_backward_result = torch.squeeze(grad_output[0])

    def max_pool_2_backward_hook(self, _, grad_input, grad_output):
        """Backward hook: cache the squeezed gradient w.r.t. ``max2``'s output."""
        self.max_pool_2_backward_result = torch.squeeze(grad_output[0])

    def max_pool_3_backward_hook(self, _, grad_input, grad_output):
        """Backward hook: cache the squeezed gradient w.r.t. ``max3``'s output."""
        self.max_pool_3_backward_result = torch.squeeze(grad_output[0])

    @staticmethod
    def _attach_backward_hook(module, hook):
        """Register ``hook`` as a backward hook on ``module`` and return its handle.

        ``register_backward_hook`` is deprecated in favour of
        ``register_full_backward_hook``; the latter is used when the installed
        torch provides it, with the legacy call kept as a fallback.
        """
        register = getattr(module, 'register_full_backward_hook', None)
        if register is None:
            register = module.register_backward_hook
        return register(hook)

    def register_hook(self):
        """Install the forward hooks on the convolutions and the backward hooks
        on the max-pool layers.

        Calling this more than once is safe: hooks installed by a previous call
        are removed first.
        """
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()

        self._hook_handles = [
            self.conv1.register_forward_hook(self.cnn_1_forward_hook),
            self.conv2.register_forward_hook(self.cnn_2_forward_hook),
            self.conv3.register_forward_hook(self.cnn_3_forward_hook),
            self._attach_backward_hook(self.max1, self.max_pool_1_backward_hook),
            self._attach_backward_hook(self.max2, self.max_pool_2_backward_hook),
            self._attach_backward_hook(self.max3, self.max_pool_3_backward_hook),
        ]

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 3e-4)."""
        return torch.optim.Adam(self.parameters(), lr=3e-4)

    def training_step(self, batch, batch_idx):
        """Compute the binary cross-entropy-with-logits loss for one batch.

        ``batch`` is an ``(x, y)`` pair; ``x`` has its axes 1 and 2 swapped
        before being passed to ``forward``.
        """
        x, y = batch
        x = x.transpose(1, 2).contiguous()
        y_hat = self(x)
        # Functional form of nn.BCEWithLogitsLoss() with default settings;
        # avoids constructing a loss module on every step.
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def validation_step(self, batch, batch_idx):
        """Intentionally a no-op."""
        return None

    def test_step(self, batch, batch_idx):
        """Intentionally a no-op."""
        return None

In [None]:
# Number of input channels handed to the model's convolutions; matches the
# 21-symbol alphabet of SeqEncoder.
dimension = 21
# Lightning checkpoint whose 'state_dict' entry holds the trained weights.
model_path = './models/epoch=34.ckpt'
model = Model(dimension)
# Load the weights onto the CPU regardless of the device they were saved from.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['state_dict'])
# Switch to evaluation mode before running inference.
model.eval()
# Attach the hooks that cache conv outputs and max-pool output gradients on `model`.
model.register_hook()
# Number of top-ranked feature maps passed to Interpretation further down.
num_of_feature_map = 50

In [None]:
from PIL import Image, ImageDraw, ImageFont

def draw_grad_cam_heatmap_on_text(sequence, hm, file_name, source, max_text_width=100000):
    """Draw ``sequence`` as text on top of a per-character heatmap and save a PNG.

    Each character gets a background rectangle whose colour is derived from the
    matching entry of ``hm``: white-to-red when ``source == 'train'``,
    white-to-blue otherwise. The text is wrapped every ``max_text_width``
    characters.

    Args:
        sequence: Non-empty string to render.
        hm: Indexable with one value per character. ``int(255 - 255 * hm[i])``
            is used as a colour channel, so values are presumably expected in
            [0, 1] - they are not validated here.
        file_name: Destination path of the PNG.
        source: ``'train'`` selects the red palette; anything else the blue one.
        max_text_width: Maximum number of characters per rendered line.

    Raises:
        ValueError: if ``sequence`` is empty or ``max_text_width`` is below 1.
    """
    sequence_length = len(sequence)
    if sequence_length == 0:
        raise ValueError("sequence must contain at least one character")
    if max_text_width < 1:
        raise ValueError("max_text_width must be at least 1")

    d2coding_font = ImageFont.truetype("./D2Coding.ttf", size=20, encoding="unic")

    margin_top = 3
    margin_bottom = 3
    # Fixed layout metrics: the row height is hard-coded, not measured from the font.
    line_space = 7
    text_height = 14

    # Ceiling division: number of rendered lines.
    sequence_lines = (sequence_length + max_text_width - 1) // max_text_width

    # Average width of one character. FreeTypeFont.getsize() was removed in
    # Pillow 10; the width of getbbox() is the documented replacement. The old
    # call is kept only as a fallback for Pillow versions without getbbox().
    if hasattr(d2coding_font, "getbbox"):
        left, _, right, _ = d2coding_font.getbbox(sequence)
        rendered_width = right - left
    else:
        rendered_width = d2coding_font.getsize(sequence)[0]
    text_width = int(rendered_width / sequence_length)

    canvas_width = text_width * min(max_text_width, sequence_length)
    canvas_height = margin_top + (text_height * sequence_lines) + (line_space * (sequence_lines - 1)) + margin_bottom
    canvas = Image.new('RGBA', (canvas_width, canvas_height), (255, 255, 255, 255))

    draw = ImageDraw.Draw(canvas)

    # Draw Grad-CAM: one filled rectangle per character.
    for char_index in range(sequence_length):
        line_index, column_index = divmod(char_index, max_text_width)

        x0 = text_width * column_index
        y0 = margin_top + line_index * (text_height + line_space)
        x1 = x0 + text_width
        y1 = y0 + text_height

        shade = int(255 - 255 * hm[char_index])
        if source == 'train':
            fill = (255, shade, shade)
        else:
            fill = (shade, shade, 255)
        draw.rectangle([(x0, y0), (x1, y1)], fill=fill)

    # Split Sequence into lines of at most max_text_width characters.
    if sequence_lines > 1:
        sequence = "\n".join(
            sequence[start:start + max_text_width]
            for start in range(0, sequence_length, max_text_width)
        )

    # Draw Sequence
    # NOTE(review): no `spacing` argument is passed, so for multi-line output
    # the rendered line pitch is Pillow's default rather than
    # text_height + line_space - confirm the text rows line up with the rectangles.
    draw.multiline_text((0, 0), sequence, font=d2coding_font, fill="#000000")

    canvas.save(f"{file_name}", "PNG")

In [None]:
def plot(y_pred, color):
    """Return a 30x30, three-channel float64 image uniformly filled with ``color``.

    ``y_pred`` is accepted but not used.
    """
    side = 30
    return np.full((side, side, 3), color, dtype=np.float64)


def plotting_alpha(value, text, color):
    """Create a colour tile with ``plot`` and draw ``text`` onto it.

    ``value`` is only forwarded to ``plot``. The label is rendered with
    OpenCV's Hershey Simplex font at scale 1, thickness 1, in colour
    (10, 10, 10).
    """
    tile = plot(value, color)
    label = "%s" % text

    # putText origin (bottom-left of the text): x = 1, y = two rows above the tile's height.
    origin = (1, len(tile) - 2)
    cv2.putText(tile, label, origin, cv2.FONT_HERSHEY_SIMPLEX, 1, (10, 10, 10), 1)

    return tile

In [None]:
class SeqEncoder:
    """One-hot encoder/decoder over the fixed alphabet ``ACDEFGHIKLMNPQRSTVWYX``."""

    def __init__(self):
        """Create the encoder; ``categories`` holds the alphabet as a NumPy array."""
        self.categories = np.array(list('ACDEFGHIKLMNPQRSTVWYX'))

    def char_to_one_hot_encoding(self, c):
        """Return the int8 one-hot vector of ``c``.

        Symbols that are not in the alphabet (for example the space used as
        padding) map to an all-zero vector.
        """
        X_int = np.zeros(len(self.categories), dtype=np.int8)
        # An empty match leaves the vector untouched.
        X_int[np.where(self.categories == c)[0]] = 1
        return X_int

    def seq_to_one_hot_encoding(self, seq):
        """Encode every symbol of ``seq``.

        Returns:
            int8 array of shape ``(len(seq), alphabet size)``; an empty ``seq``
            yields shape ``(0, alphabet size)``.
        """
        rows = [self.char_to_one_hot_encoding(c) for c in seq]
        if not rows:
            # np.array([]) would be 1-D; keep the 2-D contract for empty input.
            return np.zeros((0, len(self.categories)), dtype=np.int8)
        return np.array(rows)

    def one_hot_encoding_to_seq(self, one_hot_encoding):
        """Decode a one-hot encoding back into a string.

        The input is reshaped to rows of the alphabet size. Decoding stops at
        the first all-zero row (padding or an unknown symbol); when there is no
        such row the whole input is decoded, whatever its length.
        """
        X = np.array(one_hot_encoding).reshape(-1, len(self.categories))
        zero_rows = np.where(np.sum(X, axis=1) == 0)[0]
        last_index = len(X) if len(zero_rows) == 0 else zero_rows[0]
        return ''.join(self.categories[X.argmax(axis=1)][:last_index])

In [None]:
# Base of the exponential rescaling applied to the relevance scores.
sensitivity = 1.3

sequence = 'AAAAA'

# Per-level accumulators; with a single sequence each one simply ends up
# holding that sequence's result.
result_int_1 = None
result_int_2 = None
result_int_3 = None

seq = sequence
if not 0 < len(seq) <= 1000:
    raise ValueError(f"sequence length must be between 1 and 1000, got {len(seq)}")

# Pad with spaces (encoded as all-zero rows) to the fixed model input length.
sequence = sequence + ' ' * (1000 - len(sequence))
sequence = SeqEncoder().seq_to_one_hot_encoding(sequence)
# from_numpy avoids the slow "tensor from a list of ndarrays" construction path.
sequence = torch.from_numpy(sequence).unsqueeze(0).transpose(1, 2).contiguous().float()

# Forward pass, then back-propagate the top-scoring logit so the registered
# hooks capture the conv activations and the max-pool output gradients.
outputs_p = model(sequence)
outs = outputs_p.squeeze()
index_class = torch.argmax(outputs_p, dim=1)[0].item()
outs[index_class].backward(retain_graph=True)

mbr_1 = model.max_pool_1_backward_result.detach().cpu().numpy()
mbr_2 = model.max_pool_2_backward_result.detach().cpu().numpy()
mbr_3 = model.max_pool_3_backward_result.detach().cpu().numpy()

cfr_1 = model.cnn_1_forward_result.detach().cpu().numpy()
cfr_2 = model.cnn_2_forward_result.detach().cpu().numpy()
cfr_3 = model.cnn_3_forward_result.detach().cpu().numpy()

max_prob = torch.max(outputs_p, dim=1, keepdim=True)[0].detach().cpu().numpy().reshape(sequence.size()[0])
interpretation = Interpretation(num_of_feature_map)

# Level 1 Interpretation
final_beta = interpretation.final_beta(mbr_1, mbr_2, mbr_3, cfr_1, cfr_2, cfr_3, max_prob)
result_int_1 = (result_int_1 + final_beta) if result_int_1 is not None else final_beta

# Level 2 Interpretation
# The second argument is the kernel width of the matching convolution; the
# shorter branches are zero-padded to the length of the first one.
individual_beta_c1 = interpretation.individual_beta(mbr_1, cfr_1, max_prob, 4)
individual_beta_c2 = interpretation.individual_beta(mbr_2, cfr_2, max_prob, 8)
individual_beta_c2 = np.append(individual_beta_c2, np.zeros(len(cfr_1[0]) - len(cfr_2[0])))
individual_beta_c3 = interpretation.individual_beta(mbr_3, cfr_3, max_prob, 16)
individual_beta_c3 = np.append(individual_beta_c3, np.zeros(len(cfr_1[0]) - len(cfr_3[0])))
individual_beta = individual_beta_c1 + individual_beta_c2 + individual_beta_c3
result_int_2 = (result_int_2 + individual_beta) if result_int_2 is not None else individual_beta

# Level 3 Interpretation
result_int_3 = (result_int_3 + individual_beta) if result_int_3 is not None else individual_beta
if len(seq) > len(result_int_3):
    raise ValueError(
        f"sequence has {len(seq)} residues but only {len(result_int_3)} "
        "positions carry an interpretation score"
    )
# Keep one score per residue. len(seq) is used rather than the number of
# non-zero one-hot entries, which undercounts when the sequence contains
# symbols outside the encoder's alphabet.
result = sensitivity ** result_int_3[:len(seq)]

# Moving-average smoothing. np.convolve(mode='same') returns max(M, N) samples,
# i.e. more values than residues (and mis-centred) for sequences shorter than
# the kernel, so take the centred slice of the full convolution instead; for
# longer sequences this is exactly what mode='same' yields.
kernel_size = 7
kernel = np.ones(kernel_size) / kernel_size
smoothing_offset = (kernel_size - 1) // 2
result = np.convolve(result, kernel, mode='full')[smoothing_offset:smoothing_offset + len(result)]

# One 30x30 tile per residue, laid out left to right.
r_img = np.zeros((30, 30 * len(seq), 3))

cmap = plt.get_cmap('Reds')
norm = mcolors.Normalize(vmin=min(result), vmax=max(result))
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

for e in range(len(seq)):
    # to_rgba yields (r, g, b, a); OpenCV expects channels in B, G, R order.
    red, green, blue, _ = sm.to_rgba(result[e])
    color = (blue, green, red)
    img = plotting_alpha(result[e], seq[e], np.array(color) * 255)
    r_img[:, e * 30:(e + 1) * 30] = img

cv2.imwrite('PEPIC_interpretation.png', r_img)