In [2]:
import numpy as np
import imageio
import glob
import cv2
import os
import argparse

import PIL.Image as Image
import imageio
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator

import torch
import torchvision.transforms as transforms

from models import SwinTransformer

# Pytorch grad cam
from pytorch_grad_cam import GradCAM, \
                             ScoreCAM, \
                             GradCAMPlusPlus, \
                             AblationCAM, \
                             XGradCAM, \
                             EigenCAM, \
                             EigenGradCAM

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
                                         deprocess_image, \
                                         preprocess_image

## Load Model

In [3]:
# Prefer the second GPU when there is one; otherwise fall back to the first GPU,
# then to the CPU (a hard-coded 'cuda:1' fails on single-GPU machines).
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model = SwinTransformer(img_size=224,
                        patch_size=4,
                        in_chans=3,
                        num_classes=1,
                        embed_dim=96,
                        depths=[2, 2, 6, 2],
                        num_heads=[3, 6, 12, 24],
                        window_size=7,
                        mlp_ratio=4.0,
                        qkv_bias=True,
                        qk_scale=None,
                        drop_rate=0.0,
                        drop_path_rate=0.2,
                        ape=False,
                        patch_norm=True,
                        use_checkpoint=False,
                        device=device)

# map_location lets a checkpoint that was saved on a GPU load onto whichever
# device was selected above (including CPU-only machines).
state_dict = torch.load('mse_weight/w_mse_epoch_70_f1_0.9172.pt', map_location=device)
# strict=False tolerates key mismatches; report them so a partially loaded
# model does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'Checkpoint mismatch: {len(load_result.missing_keys)} missing keys, '
          f'{len(load_result.unexpected_keys)} unexpected keys')
    print('  missing:', load_result.missing_keys)
    print('  unexpected:', load_result.unexpected_keys)
model.to(device)
# Inference only: switch off training-mode behaviour. drop_path_rate=0.2 above means
# stochastic depth is configured; presumably its DropPath layers (like Dropout) are
# inactive in eval mode — confirm in models.SwinTransformer.
model.eval()

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNo

## Preprocessing

In [4]:
# Normalization statistics for the CT slices; the same value is applied to all
# three channels.
# NOTE(review): presumably the mean/std of the training images — confirm provenance.
m = 0.39221061670618984
s = 0.11469786773730418
t = transforms.Compose([
    transforms.ToPILImage(),          # ndarray image -> PIL image
    transforms.Resize((224, 224)),    # match the model's img_size
    transforms.ToTensor(),            # PIL image -> float tensor, channels first
    transforms.Normalize(mean=(m,) * 3, std=(s,) * 3),
])

## GIF helper functions

In [None]:
def fig2data(fig):
    """Render a Matplotlib figure and return its pixels as an RGBA array.

    Reads the Agg renderer's RGBA buffer directly. This replaces the previous
    ``tostring_argb`` / ``np.fromstring`` / ``ndarray.tostring`` round-trip, which
    relies on APIs deprecated or removed in current NumPy and Matplotlib. It also
    takes the size from the buffer itself rather than from ``get_width_height()``,
    so HiDPI canvases are not misread.

    Args:
        fig: a Matplotlib figure whose canvas provides ``buffer_rgba()``
            (Agg-based canvases, such as the inline backend, do).

    Returns:
        A new ``uint8`` array of shape ``(height, width, 4)`` in RGBA order.
    """
    canvas = fig.canvas
    canvas.draw()
    # Copy: the renderer buffer is reused/freed when the figure is redrawn or closed.
    return np.array(canvas.buffer_rgba(), dtype=np.uint8)

def plot_fig(img, pop, ct_len, title='Positive CT Scan'):
    """Render one GIF frame: the CAM image beside the running prediction curve.

    Args:
        img: image shown in the left panel (passed to ``imshow``).
        pop: model outputs accumulated so far; drawn as a line in the right panel.
        ct_len: total number of frames. It fixes the x-range so the curve grows
            across frames instead of being rescaled each time.
        title: figure title (the default keeps the original text).

    Returns:
        The rendered frame as an RGBA array, as produced by ``fig2data``.
    """
    plt.ioff()
    fig, (ax_img, ax_curve) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        # Title this figure. The original called plt.title() before creating the
        # figure, so the text went to a stale (or auto-created) figure instead.
        fig.suptitle(title)

        ax_img.imshow(img)
        ax_img.axis('off')

        ax_curve.autoscale(enable=False)
        ax_curve.spines['top'].set_visible(False)
        ax_curve.spines['right'].set_visible(False)
        ax_curve.spines['left'].set_position('zero')
        ax_curve.spines['bottom'].set_position('zero')
        ax_curve.set_xticks([])
        ax_curve.set_xlim(0, ct_len)
        ax_curve.set_ylim(-1.5, 1.5)
        ax_curve.plot(pop)

        return fig2data(fig)
    finally:
        # One figure is created per frame; close it so memory does not grow with the scan.
        plt.close(fig)

def compose_gif(path, save_path):
    """Build an animated GIF of CAM overlays for the middle part of one CT scan.

    Slices in ``path`` are ordered by the integer before the first ``.`` in the
    file name (e.g. ``12.png`` -> 12). Only slices from roughly 30% to 70% of the
    scan are rendered. Each frame shows the CAM overlay next to the running curve
    of the model's output for the slices seen so far.

    Uses the notebook globals ``model``, ``device``, ``t``, ``get_cam_image`` and
    ``plot_fig``.

    Args:
        path: directory holding the slice images of a single CT scan.
        save_path: file name for the output GIF.

    Raises:
        ValueError: if ``path`` contains no numerically named slice files.
        OSError: if OpenCV cannot read a slice image.
    """
    # Skip stray entries (e.g. .DS_Store, .ipynb_checkpoints) whose name does not
    # start with a slice number; int() would otherwise crash on them.
    slice_names = sorted(
        (name for name in os.listdir(path) if name.split('.')[0].isdecimal()),
        key=lambda name: int(name.split('.')[0]),
    )
    ct_len = len(slice_names)
    if ct_len == 0:
        raise ValueError(f'No numerically named slice images found in {path!r}')

    start_idx = int(round(ct_len / 10 * 3, 0))
    # Clamp: for very short scans round(0.7 * n) + 1 can exceed n.
    end_idx = min(ct_len, int(round(ct_len / 10 * 7, 0)) + 1)
    n_frames = end_idx - start_idx

    gif_images = []
    pop = []
    for name in slice_names[start_idx:end_idx]:
        img_path = os.path.join(path, name)
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread signals failure by returning None
            raise OSError(f'Could not read image {img_path!r}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        input_img = t(img).to(device).unsqueeze(0)
        # CAM needs gradients, so it runs outside no_grad.
        cam_image = get_cam_image(model, img, input_img)
        # NOTE(review): imshow expects RGB; this swap presumably compensates for a
        # BGR heatmap from show_cam_on_image — confirm before removing.
        cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
        # The forward pass for the prediction curve needs no autograd graph.
        with torch.no_grad():
            output = model(input_img)
        pop.append(output.item())
        gif_images.append(plot_fig(cam_image, pop, n_frames))

    # fps must be at least 1: scans with fewer than 5 slices would otherwise get fps=0.
    imageio.mimsave(save_path, gif_images, fps=max(1, min(5, ct_len // 5)))

## Grad-CAM functions

In [15]:
def get_args(argv=None):
    """Return the CAM options as an ``argparse.Namespace``.

    The notebook has no real command line, so by default an empty argument list is
    parsed and every option takes its default. Pass ``argv`` (a list of strings) to
    override options programmatically, e.g. ``get_args(['--method', 'scorecam'])``.

    Args:
        argv: optional list of command-line style arguments; ``None`` means none.

    Returns:
        Namespace with ``use_cuda``, ``image_path``, ``aug_smooth``, ``eigen_smooth``
        and ``method``. ``use_cuda`` is forced to False when CUDA is unavailable.

    Raises:
        SystemExit: (from argparse) on an unknown option or an invalid ``--method``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--image-path', type=str, default='./examples/both.png',
                        help='Input image path')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument('--eigen_smooth', action='store_true',
                        help='Reduce noise by taking the first principal component '
                        'of cam_weights*activations')

    # Validate here so a typo fails immediately rather than as a KeyError later.
    parser.add_argument('--method', type=str, default='gradcam',
                        choices=['gradcam', 'gradcam++', 'scorecam', 'xgradcam',
                                 'ablationcam', 'eigencam', 'eigengradcam'],
                        help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam/'
                        'eigencam/eigengradcam')

    # Never fall back to sys.argv: inside a notebook it belongs to the kernel process.
    args = parser.parse_args([] if argv is None else list(argv))
    args.use_cuda = args.use_cuda and torch.cuda.is_available()

    return args

def reshape_transform(tensor, height=7, width=7):
    """Turn transformer token activations into a CNN-style feature map.

    Args:
        tensor: activations of shape ``(batch, height * width, channels)``, with
            tokens laid out row-major on a ``height x width`` grid.
        height: number of token rows.
        width: number of token columns.

    Returns:
        Tensor of shape ``(batch, channels, height, width)``.
    """
    batch, channels = tensor.size(0), tensor.size(2)
    grid = tensor.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W): put channels first, as in CNN feature maps.
    return grid.permute(0, 3, 1, 2)

def get_cam_image(model, rgb_image, input_tensor):
    """Compute a class-activation map for one input and overlay it on the image.

    The CAM algorithm and its smoothing flags come from ``get_args()``.

    Args:
        model: the Swin model; its ``layers[-1].blocks[-2].norm1`` is the target layer.
        rgb_image: the original image as a uint8 ndarray; resized to 224x224 here.
        input_tensor: the preprocessed batch (one image) fed to the CAM.

    Returns:
        The overlay image produced by ``show_cam_on_image``.
    """
    opts = get_args()

    cam_classes = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
    }
    cam_cls = cam_classes[opts.method]

    # Token activations from this layer are reshaped onto the 7x7 grid
    # that reshape_transform assumes by default.
    # NOTE(review): presumably the final-stage resolution at img_size=224 — confirm.
    layer = model.layers[-1].blocks[-2].norm1
    cam = cam_cls(model=model,
                  target_layer=layer,
                  use_cuda=opts.use_cuda,
                  reshape_transform=reshape_transform)
    # NOTE(review): presumably only used by batched methods (scorecam/ablationcam).
    cam.batch_size = 32

    # target_category=None leaves the choice of target to the CAM library.
    heatmaps = cam(input_tensor=input_tensor,
                   target_category=None,
                   eigen_smooth=opts.eigen_smooth,
                   aug_smooth=opts.aug_smooth)
    # Only one image in the batch.
    heatmap = heatmaps[0, :]

    base = np.float32(cv2.resize(rgb_image, (224, 224))) / 255
    return show_cam_on_image(base, heatmap)

## Main

In [17]:
# CT scan directory to animate and where to write the GIF.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
CT_SCAN_DIR = '/covid/data/val/neg/ct_scan_33/'
GIF_SAVE_PATH = 'neg.gif'
compose_gif(CT_SCAN_DIR, GIF_SAVE_PATH)

  # This is added back by InteractiveShellApp.init_path()
  app.launch_new_instance()
