In [1]:
import glob
import os

import cv2
import numpy as np
from matplotlib import pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.io import ImageReadMode, read_image

from xai.model import nets, methods
from mmn_xai.methods import sidu
from mmn_xai.data import dataset as xai_data

In [2]:
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
MODEL_NAME = "resnet"
NUM_CLASSES = 2

BATCH_SIZE = 8
# Training hyper-parameters. NOTE(review): not referenced by the inference
# cells in this notebook; presumably kept for parity with the training setup.
EPOCHS = 15
LEARNING_RATE = 0.0005
MOMENTUM = 0.9

REGRESSION = True  # NOTE(review): not referenced in this notebook; confirm it is still needed
# Flag for feature extracting. When False, we finetune the whole model; when
# True, we only update the reshaped layer params.
FEATURE_EXTRACT = False
# Prefer the first GPU, but fall back to CPU so the notebook also runs on
# machines without CUDA (a hard-coded 'cuda:0' fails at the first .to(DEVICE)).
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Build the backbone selected in the config cell, then replace its 3-channel
# stem with a 1-channel conv so it accepts grayscale images (presumably the
# checkpoint loaded below was trained with this 1-channel stem — confirm).
# The conv1 swap is ResNet-specific, so refuse other architectures explicitly
# instead of silently building a ResNet regardless of MODEL_NAME.
assert MODEL_NAME == "resnet", "the conv1 replacement below only applies to ResNet"
net, input_size = nets.initialize_model(MODEL_NAME, NUM_CLASSES, FEATURE_EXTRACT, use_pretrained=True, device=DEVICE)
# Reuse the original stem's geometry (7x7, stride 2, padding 3, 64 filters for
# torchvision ResNets) rather than repeating those numbers by hand.
rgb_stem = net.conv1
net.conv1 = nn.Conv2d(
    1,
    rgb_stem.out_channels,
    kernel_size=rgb_stem.kernel_size,
    stride=rgb_stem.stride,
    padding=rgb_stem.padding,
    bias=False,
)
# The new conv1 is created on the CPU, so move the whole model again.
net = net.to(DEVICE)

In [4]:
# Fine-tuned weights (noted as val_acc ~ 0.8).
# map_location lets a checkpoint saved on one device (e.g. GPU) load on
# whatever DEVICE is configured here.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state = torch.load("./out/model_3.pt", map_location=DEVICE)
# Checkpoints saved from an nn.DataParallel wrapper prefix every key with
# "module."; strip it so the keys match the bare (unwrapped) model.
_DP_PREFIX = "module."
if state and all(key.startswith(_DP_PREFIX) for key in state):
    state = {key[len(_DP_PREFIX):]: value for key, value in state.items()}
net.load_state_dict(state)
# Inference only: freeze dropout / batch-norm statistics.
net = net.eval()

In [5]:
class Gray2RGB(nn.Module):
    """Tile a single-channel batch into three identical channels.

    Intended for feeding grayscale images to layers that expect RGB input.
    The forward pass calls ``repeat(1, 3, 1, 1)``, so for an input assumed to
    be (N, 1, H, W) the output is (N, 3, H, W). The result is a copy, not a view.
    Constructor arguments are forwarded to ``nn.Module`` unchanged (inherited).
    """

    def forward(self, x):
        # Triple only the channel axis; batch and spatial axes are kept.
        tile_counts = (1, 3, 1, 1)
        return x.repeat(*tile_counts)

# Data

In [6]:
import os

# Basenames of the images read by `llegir_image`, in read order. The dataset
# only yields (image, label), so file names travel through this list; the
# evaluation loop consumes it once per batch, which keeps it aligned with the
# batch only when the DataLoader reads items sequentially (num_workers=0).
PATHS = []

# Without recursive=True, "**" behaves like "*": this matches exactly one
# directory level below val/ (e.g. val/<class>/*.jpg).
VAL_IMAGES_GLOB = "./in/COVIDGR/val/**/*.jpg"


def llegir_image(path, mode):
    """Read an image and record its file name in PATHS ("llegir" = "read").

    Args:
        path: filesystem path of the image file.
        mode: torchvision ``ImageReadMode`` (e.g. ``ImageReadMode.GRAY``).

    Returns:
        The tensor returned by ``torchvision.io.read_image(path, mode)``.
    """
    # Mutating (not rebinding) the module-level list, so no `global` needed.
    PATHS.append(os.path.basename(path))
    return read_image(path, mode)


# glob's order depends on the filesystem; sorting makes the evaluation order
# (and therefore batch composition and PATHS) reproducible across machines.
val_files = sorted(glob.glob(VAL_IMAGES_GLOB))
assert val_files, f"No images matched {VAL_IMAGES_GLOB!r}"

dataset = xai_data.ImageDataset(val_files, lambda path: llegir_image(path, ImageReadMode.GRAY))
# shuffle=False keeps batches in file order, matching the order of PATHS.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# Build the explanation methods (a dict keyed by method name; "grad_cam" is
# used below) around a target conv layer: the second conv of block [1] in
# ResNet's last stage. For ResNet-18/34 that is the final conv layer —
# presumably the variant built above; confirm if the backbone changes.
explainers = methods.get_xai_methods(
    net,
    net.layer4[1].conv2,
    DEVICE,
)

Generating filters: 100%|██████████| 6000/6000 [00:08<00:00, 695.70it/s]


In [8]:
# Output folder names. The raw images and heatmaps are written to the literal
# "img" and "xai" folders; subfolders[2:] are the thresholded-contour folders.
# NOTE(review): subfolders[0] says "imgs" but images are written to "img".
subfolders = ["imgs", "xai", "xai_10", "xai_25", "xai_50", "xai_75"]
# Saliency thresholds paired 1:1 with subfolders[2:]; the folder suffix is
# (1 - threshold) * 100 (assumes explanations are normalised to [0, 1]).
XAI_THRESHOLDS = [0.9, 0.75, 0.5, 0.25]
CONTOUR_COLOR = (0, 255, 255)  # BGR yellow
CONTOUR_THICKNESS = 2
RESULTS_DIR = os.path.join(".", "out", "res")


def _save_result(part, gt_str, pred_str, file_name, image):
    """Write `image` to RESULTS_DIR/<part>/<gt_str>/<pred_str>/<file_name>.

    Raises:
        OSError: if cv2.imwrite reports failure (it signals many failures
            only through its False return value).
    """
    folder_path = os.path.join(RESULTS_DIR, part, gt_str, pred_str)
    os.makedirs(folder_path, exist_ok=True)
    out_path = os.path.join(folder_path, file_name)
    if not cv2.imwrite(out_path, image):
        raise OSError(f"cv2.imwrite failed for {out_path}")


def _heatmap_to_uint8(heatmap):
    """Scale a saliency map (assumed in [0, 1]) to a rounded, clipped uint8 image."""
    # Explicit version of the conversion OpenCV would otherwise apply
    # implicitly when handed a float array.
    return np.clip(np.rint(heatmap * 255), 0, 255).astype(np.uint8)


iguals = 0  # predictions matching the ground truth
diffs = 0   # mismatched predictions

# Drop names left over from earlier reads (e.g. an interrupted previous run),
# otherwise output files would be written under the wrong names.
PATHS.clear()

for img, gt in data_loader:
    # llegir_image appended this batch's file names while the loader built it.
    batch_names = list(PATHS)
    PATHS.clear()
    assert len(batch_names) == gt.shape[0], (
        f"PATHS out of sync: {len(batch_names)} names for a batch of {gt.shape[0]}"
    )

    # Plain inference needs no autograd graph.
    with torch.no_grad():
        logits = net(img[:, 0:1, :, :].float().to(DEVICE))
    pred = logits.cpu().numpy()
    # Grad-CAM needs gradients, so it runs outside no_grad.
    expl = explainers["grad_cam"](img.type(torch.float32))

    for ib in range(gt.shape[0]):
        gt_str = dataset.map(int(np.argmax(gt[ib])))
        pred_str = dataset.map(int(np.argmax(pred[ib])))
        if gt_str == pred_str:
            iguals += 1
        else:
            diffs += 1

        file_name = batch_names[ib]
        img_3d = cv2.cvtColor(img[ib, 0, :, :].cpu().detach().numpy(), cv2.COLOR_GRAY2BGR)

        _save_result("img", gt_str, pred_str, file_name, img_3d)
        _save_result("xai", gt_str, pred_str, file_name, _heatmap_to_uint8(expl[ib]))

        # One overlay per threshold: outline the regions whose saliency
        # exceeds it on a copy of the input image.
        for ths, name_part in zip(XAI_THRESHOLDS, subfolders[2:]):
            mask = (expl[ib] > ths).astype(np.uint8)
            contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            img_border = cv2.drawContours(np.copy(img_3d), contours, -1, CONTOUR_COLOR, CONTOUR_THICKNESS)
            _save_result(name_part, gt_str, pred_str, file_name, img_border)

print(iguals, diffs)

  input = module(input)


230 52
