# Imports

In [None]:
import time, os, warnings, torch, multiprocessing, skimage, scipy
import numpy as np
import pandas as pd
import torch.nn as nn
from PIL import Image
from timm import create_model
from prodigyopt import Prodigy
import pytorch_lightning as pl
from torchvision import transforms
from model_extractors import resnet50_img_extractor
from masking_network import resnet50_trained_extractor
from sklearn.model_selection import train_test_split
from captum.attr import GradientShap
from captum.attr import IntegratedGradients
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from quantus.metrics.faithfulness.faithfulness_estimate import FaithfulnessEstimate


# Data

In [2]:

class lung_data(torch.utils.data.Dataset):
    """Dataset of chest X-ray image files with one-hot binary targets.

    Each row of ``dataframe`` must provide an ``"img_path"`` column (path to the
    image file) and an integer ``"target"`` column used to index a
    ``num_classes``-sized identity matrix.

    Args:
        dataframe: table of samples, indexed positionally via ``iloc``.
        transform: optional callable applied to the loaded image.
        grayscale: if True, read the raw image with ``skimage`` and add a
            leading channel axis; otherwise load it with PIL as RGB.

    Items are ``(image, one_hot_target, img_path)`` in RGB mode and
    ``(image, one_hot_target)`` in grayscale mode.
    """

    def __init__(self, dataframe, transform=None, grayscale=False):
        self.dataframe = dataframe
        self.transform = transform
        self.grayscale = grayscale
        self.num_classes = 2

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        one_hot = np.eye(self.num_classes)[row["target"]]
        if self.grayscale:
            # Explicit submodule import: on older scikit-image releases a bare
            # `import skimage` does not load `skimage.io`.
            import skimage.io

            img = np.expand_dims(skimage.io.imread(row["img_path"]), axis=0)
            if self.transform is not None:
                img = self.transform(img)
            # Transforms such as ToTensor already return a tensor; only wrap raw arrays.
            if isinstance(img, np.ndarray):
                img = torch.from_numpy(img)
            # NOTE(review): this branch returns a 2-tuple, unlike the RGB branch;
            # kept for compatibility, but `x, y, _ = batch` callers cannot use it.
            return img, one_hot

        # Context manager closes the file handle once the RGB copy is made.
        with Image.open(row["img_path"]) as pil_img:
            img = pil_img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, one_hot, row["img_path"]
    


In [None]:

# Build the RSNA pneumonia sample table: one image path, binary target and
# bounding box per (patientId, Target) pair, then a stratified 80/20 split.
data_path = "data/RSNA_DATASET/"
df = pd.read_csv(data_path + "stage_2_train_labels.csv")
# If the CSV lists several rows (boxes) per patient, only the first is kept.
df = df.drop_duplicates(subset=["patientId", "Target"], keep="first")
# Images live in a per-class folder: PNEUMONIA for Target == 1, NORMAL otherwise.
df["img_path"] = (
    data_path
    + df["Target"].map(lambda t: "PNEUMONIA" if t == 1 else "NORMAL")
    + "/"
    + df["patientId"].astype(str)
    + ".png"
)
df = df[["img_path", "Target", "x", "y", "width", "height"]].rename(columns={"Target": "target"})

# Rows without a box get 0 coordinates; boxes are rescaled from a 1024-pixel
# frame to the 224-pixel model input (assumes 1024x1024 source images).
df = df.fillna(0)
for coord in ("x", "y", "width", "height"):
    df[coord] = (df[coord] / 1024 * 224).astype(int)

train_df, test_df = train_test_split(df, test_size=0.2, stratify=df["target"], random_state=42)
df

## Model Training data

In [4]:

# Preprocessing for the classifier's train/test splits: tensor conversion,
# per-channel normalisation with the commonly used ImageNet statistics, then
# resizing to the 224x224 network input (resize runs after normalisation).
TRANS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize(size=(224, 224), antialias=True),
    ]
)

data_set_model_train = lung_data(dataframe=train_df, transform=TRANS)
data_set_model_test = lung_data(dataframe=test_df, transform=TRANS)


In [5]:
# Inverse-frequency class weights for the weighted cross-entropy loss,
# normalised to sum to 1. sort_index() aligns entry i with class label i;
# value_counts() alone orders by frequency, which would silently swap the
# weights whenever class 1 were the majority class.
weights = len(train_df) / np.array(train_df.target.value_counts().sort_index())
weights = weights / np.sum(weights)
weights

array([0.2253244, 0.7746756])

## Evaluation data with bounding boxes 

In [6]:

# Same preprocessing as the classifier cell, redefined so this cell can run on
# its own: tensor conversion, normalisation, resize to 224x224.
TRANS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    transforms.Resize((224, 224), antialias=True),
])

# Explanation-evaluation set: only target == 1 test rows (per the section
# title, the ones with bounding boxes), capped at 1000 samples. min() avoids
# the ValueError sample() raises when fewer rows are available.
only_bounding_box = test_df[test_df.target == 1]
only_bounding_box = only_bounding_box.sample(
    n=min(1000, len(only_bounding_box)), replace=False, random_state=42
)
data_set_train = lung_data(train_df, transform=TRANS)
data_set_test = lung_data(only_bounding_box, transform=TRANS)

# Model


## Architecture

In [7]:

class classifier_model(pl.LightningModule):
    """Two-class image classifier: a timm backbone trained with weighted cross-entropy.

    Args:
        model_string: timm architecture name passed to ``create_model``
            (pretrained weights, 2 output classes, 3 input channels).

    NOTE: the loss class weights are read from the module-level ``weights``
    array when the module is constructed, so that cell must have run first.
    Batches are expected as ``(images, one_hot_targets, paths)``.
    """

    def __init__(self, model_string="resnet50"):
        super().__init__()
        self.model = create_model(model_string, pretrained=True, num_classes=2, in_chans=3)
        # torch.tensor(ndarray) would keep NumPy's float64; use float32 to match the logits.
        self.criterion = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32))

    def forward(self, x):
        return self.model(x)

    def _batch_loss(self, batch):
        """Weighted cross-entropy of one batch (shared by training and validation)."""
        x, y, _ = batch
        logits = self(x)
        # One-hot targets are collated as float64 from np.eye; match the logits
        # dtype so the loss stays in float32 (or fp16 under mixed precision).
        return self.criterion(logits, y.to(logits.dtype))

    def training_step(self, batch, batch_idx):
        loss = self._batch_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._batch_loss(batch)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        # Prodigy adapts the step size itself; lr=1 is its recommended setting.
        optimizer = Prodigy(self.parameters(), lr=1, weight_decay=1e-4)
        return optimizer


model = classifier_model("resnet50")

## Model Training

In [None]:

BATCH = 32
# Leave two cores free; max() keeps the worker count non-negative (DataLoader
# rejects negative values) on machines with fewer than three CPUs.
NUM_WORKERS = max(0, multiprocessing.cpu_count() - 2)
EPOCHS = 10
MIXED_PRECISION = False
DETERMINISTIC = False


train_loader = torch.utils.data.DataLoader(data_set_model_train, batch_size=BATCH, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(data_set_model_test, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS)

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    devices="auto",
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    precision=16 if MIXED_PRECISION else 32,
    default_root_dir="classifier_pneunomia_resnet50_logs",
    # Accumulate towards an effective batch of 32; max(1, ...) keeps the value
    # valid when BATCH > 32 (int(32 / BATCH) would be 0).
    accumulate_grad_batches=max(1, 32 // BATCH),
    deterministic=DETERMINISTIC,
)


In [9]:
trainer.fit(model, train_loader, test_loader)

# Test-set accuracy with every sample weighted equally. A mean of per-batch
# means would over-weight a smaller final batch.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
correct, total = 0, 0
with torch.no_grad():  # inference only: no autograd graph needed
    for x, y, _ in test_loader:
        preds = model(x.to(device)).argmax(1).cpu()
        correct += (preds == y.argmax(1)).sum().item()
        total += y.shape[0]
model.cpu()
print("Accuracy: ", correct / total)

# Methodology

## NEM

### Hyperparameters

In [10]:
BATCH = 32
# Leave two cores free, but never go below zero workers.
NUM_WORKERS = max(0, multiprocessing.cpu_count() - 2)
EPOCHS = 5
MIXED_PRECISION = False
DETERMINISTIC = False
# Flags forwarded to resnet50_trained_extractor. INVERSE is also used when
# saving predictions, where the normalised map is flipped to 255 - attr.
CONTRASTIVE = False
NOISE_MASK = False
INVERSE = True

### Model

In [11]:

# load_from_checkpoint is a classmethod: call it on the class instead of on a
# throwaway instance (which builds an extra backbone, and which recent
# Lightning releases reject outright). strict=False ignores missing or
# unexpected state-dict keys.
model = classifier_model.load_from_checkpoint(
    "classifier_pneunomia_resnet50_logs/lightning_logs/version_1/checkpoints/epoch=9-step=6680.ckpt",
    strict=False,
)
extractor = resnet50_img_extractor(model.model)
masking_model = resnet50_trained_extractor(
    extractor, EPOCHS, batch_size=BATCH, lr=1,
    center=False, partition=1,
    noise_mask=NOISE_MASK, constrastive=CONTRASTIVE, inverse=INVERSE,
)

constrastive False
 The encoder channels are (1024, 512, 256, 64, 3)
 The decoder channels are (256, 128, 64, 32, 16)
 The buttom layer is 2048
 Using image space is True


### Training

In [None]:

train_loader = torch.utils.data.DataLoader(data_set_train, batch_size=BATCH, shuffle=True)
test_loader = torch.utils.data.DataLoader(data_set_test, batch_size=BATCH, shuffle=False)
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    devices="auto",
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    precision=16 if MIXED_PRECISION else 32,
    default_root_dir="masking_supervised_pneumonia_logs",
    # Accumulate towards an effective batch of 128; max(1, ...) keeps the value
    # valid when BATCH > 128 (int(128 / BATCH) would be 0).
    accumulate_grad_batches=max(1, 128 // BATCH),
    deterministic=DETERMINISTIC,
)

# Only the training loader is passed; test_loader is not used by this fit.
trainer.fit(masking_model, train_loader)

### Prediction

In [None]:

# Predict a mask for every evaluation image with the trained NEM model, save it
# as an 8-bit PNG (inverted when INVERSE), and report mean inference time.
masking_model.eval()
masking_model = masking_model.to("cuda")
nem_path = "result/supervised/pneunomia/nem_inv/"
os.makedirs(nem_path, exist_ok=True)
execution_times = []
with torch.no_grad():  # inference only: no autograd graph needed
    for img, _, img_path in data_set_test:
        name = img_path.split("/")[-1]
        img = img.unsqueeze(0).to("cuda")
        start_time = time.time()
        attr, _ = masking_model(img)
        # CUDA kernels run asynchronously; wait so the timing covers the forward pass.
        torch.cuda.synchronize()
        execution_times.append(time.time() - start_time)
        attr = attr.cpu().squeeze().numpy()
        # Min-max scale to 0-255; a constant map would divide by zero, so use zeros.
        attr_range = attr.max() - attr.min()
        if attr_range > 0:
            attr = ((attr - attr.min()) / attr_range * 255).astype(np.uint8)
        else:
            attr = np.zeros(attr.shape, dtype=np.uint8)
        if INVERSE:
            attr = 255 - attr
        Image.fromarray(attr).save(f"{nem_path}{name}")

print("Average execution time: ", np.mean(execution_times))

# Comparisons

## GradCAM

In [None]:
# Grad-CAM++ on the last block of the backbone's layer4, explaining each
# sample's ground-truth class; maps are saved as 8-bit PNGs.
method = GradCAMPlusPlus(model, [model.model.layer4[-1]])
path = "result/supervised/pneunomia/gradcam/"
os.makedirs(path, exist_ok=True)

execution_times = []
for img, target, img_path in data_set_test:
    name = img_path.split("/")[-1]
    start_time = time.time()
    attr = method(
        input_tensor=img.unsqueeze(0).cuda(),
        targets=[ClassifierOutputTarget(int(target.argmax()))],
    ).squeeze()
    execution_times.append(time.time() - start_time)
    # Min-max scale to 0-255; a constant map would divide by zero, so use zeros.
    attr_range = attr.max() - attr.min()
    if attr_range > 0:
        attr = ((attr - attr.min()) / attr_range * 255).astype(np.uint8)
    else:
        attr = np.zeros(attr.shape, dtype=np.uint8)
    Image.fromarray(attr).save(f"{path}{name}")
print("Average execution time: ", np.mean(execution_times))

# GradientShap

In [None]:
# GradientShap against an all-zero baseline, explaining each sample's
# ground-truth class; maps are saved as 8-bit PNGs.
model.eval()
method = GradientShap(model)
baseline = torch.zeros((1, 3, 224, 224)).cuda()
path = "result/supervised/pneunomia/grad_shap/"
os.makedirs(path, exist_ok=True)
execution_times = []
for img, target, img_path in data_set_test:
    name = img_path.split("/")[-1]
    start_time = time.time()
    attr = method.attribute(img.unsqueeze(0).cuda(), baselines=baseline, target=int(target.argmax())).squeeze()
    # CUDA kernels run asynchronously; finish them before stopping the clock.
    torch.cuda.synchronize()
    execution_times.append(time.time() - start_time)
    # Collapse channels into one non-negative map, then min-max scale to 0-255
    # (zeros for a constant map instead of dividing by zero).
    attr = torch.abs(attr).sum(0).cpu().numpy()
    attr_range = attr.max() - attr.min()
    if attr_range > 0:
        attr = ((attr - attr.min()) / attr_range * 255).astype(np.uint8)
    else:
        attr = np.zeros(attr.shape, dtype=np.uint8)
    Image.fromarray(attr).save(f"{path}{name}")
print("Average execution time: ", np.mean(execution_times))

# Integrated Gradients

In [None]:
# Integrated Gradients from an all-zero baseline, explaining each sample's
# ground-truth class; maps are saved as 8-bit PNGs.
model.eval()
method = IntegratedGradients(model)
baseline = torch.zeros((1, 3, 224, 224)).cuda()
path = "result/supervised/pneunomia/integrated_gradients/"
os.makedirs(path, exist_ok=True)
execution_times = []
for img, target, img_path in data_set_test:
    name = img_path.split("/")[-1]
    start_time = time.time()
    attr = method.attribute(img.unsqueeze(0).cuda(), baselines=baseline, target=int(target.argmax())).squeeze()
    # CUDA kernels run asynchronously; finish them before stopping the clock.
    torch.cuda.synchronize()
    execution_times.append(time.time() - start_time)
    # Collapse channels into one non-negative map, then min-max scale to 0-255
    # (zeros for a constant map instead of dividing by zero).
    attr = torch.abs(attr).sum(0).cpu().numpy()
    attr_range = attr.max() - attr.min()
    if attr_range > 0:
        attr = ((attr - attr.min()) / attr_range * 255).astype(np.uint8)
    else:
        attr = np.zeros(attr.shape, dtype=np.uint8)
    Image.fromarray(attr).save(f"{path}{name}")
print("Average execution time: ", np.mean(execution_times))

Average execution time:  0.04876405501365662


# Smooth Pixel mask

In [None]:
# Silences every UserWarning for the rest of the session (presumably noisy
# library warnings from the perturbation runs below).
warnings.filterwarnings("ignore", category=UserWarning)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
softmax = torch.nn.Softmax(dim=-1)


class SmoothMask:
    """Callable wrapper around TorchRay's ``extremal_perturbation`` for one fixed area.

    Args:
        area: value passed as the single entry of ``areas`` (presumably the
            fraction of the image the mask may cover).
        model: classifier being explained.
    """

    def __init__(self, area, model):
        self.area = area
        self.model = model

    def __call__(self, x, pred):
        """Return the extremal-perturbation mask of input ``x`` for class ``pred``."""
        mask, _ = extremal_perturbation(
            self.model, x, pred,
            reward_func=contrastive_reward,
            debug=False,
            areas=[self.area],
        )
        return mask

In [None]:
# Extremal-perturbation masks (5% area) for each sample's ground-truth class;
# masks are saved as 8-bit PNGs.
method = SmoothMask(model=model, area=0.05)
path = "result/supervised/pneunomia/smooth_mask/"
os.makedirs(path, exist_ok=True)
execution_times = []
for img, target, img_path in data_set_test:
    name = img_path.split("/")[-1]
    start_time = time.time()
    # Explain the sample's own label rather than a fixed class index.
    attr = method(img.unsqueeze(0).cuda(), int(target.argmax()))
    # CUDA kernels run asynchronously; finish them before stopping the clock.
    torch.cuda.synchronize()
    execution_times.append(time.time() - start_time)

    attr = attr.squeeze().cpu().detach().numpy()
    # Min-max scale to 0-255; a constant mask would divide by zero, so use zeros.
    attr_range = attr.max() - attr.min()
    if attr_range > 0:
        attr = ((attr - attr.min()) / attr_range * 255).astype(np.uint8)
    else:
        attr = np.zeros(attr.shape, dtype=np.uint8)
    Image.fromarray(attr).save(f"{path}{name}")
print("Average execution time: ", np.mean(execution_times))

Average execution time:  3.66284157204628


# Results

# masks

In [14]:
def gen_bounding_box(x,y,w,h, size=(224,224)):
    """Return a uint8 mask of shape ``size``: 1 inside the box, 0 elsewhere.

    The box spans rows ``y:y+h`` and columns ``x:x+w``. A box running past
    the mask edge is truncated by NumPy slicing.
    """
    box = np.zeros(size, dtype=np.uint8)
    box[y:y + h, x:x + w] = 1
    return box
# Rasterise each evaluation row's bounding box into a 0/1 PNG, named after
# the source image file.
mask_path = "result/masks/pneunomia/"
os.makedirs(mask_path, exist_ok=True)
for _, box_row in only_bounding_box.iterrows():
    box_mask = gen_bounding_box(box_row["x"], box_row["y"], box_row["width"], box_row["height"])
    file_name = box_row["img_path"].split("/")[-1]
    Image.fromarray(box_mask).save(f"{mask_path}{file_name}")

## Experimental method

In [19]:
def run_exp(metrics, metric_names,data):
    """Score every saved explanation method with each metric and print a report.

    Args:
        metrics: callables invoked as
            ``metric(mask_path=..., explanation_path=..., samples=data)``.
        metric_names: report labels, paired positionally with ``metrics``
            (like ``zip``, extra entries in the longer sequence are ignored).
        data: forwarded unchanged as ``samples`` to every metric call.
    """
    res_path = "result/supervised/pneunomia/"
    mask_path = "result/masks/pneunomia/"
    # Report label -> directory of that method's saved maps, in report order.
    method_dirs = {
        "resnet_50_smoothmask": f"{res_path}smooth_mask/",
        "resnet_50_gradcam": f"{res_path}gradcam/",
        "resnet_50_gradshape": f"{res_path}grad_shap/",
        "resnet_50_integrated_gradients": f"{res_path}integrated_gradients/",
        "resnet_50_nem": f"{res_path}nem_inv/",
    }
    indent = " " * 8

    for metric_func, metric_name in zip(metrics, metric_names):
        scores = {
            label: metric_func(mask_path=mask_path, explanation_path=directory, samples=data)
            for label, directory in method_dirs.items()
        }
        # Labels padded so values line up in one column (longer labels get one space).
        lines = [f"{metric_name}:"] + [f"{label + ':':<25} {score}" for label, score in scores.items()]
        report = "".join(f"\n{indent}{line}" for line in lines)
        print(f"{report}\n{indent}")




## Locality

In [None]:
def relevance_rank( mask_path, explanation_path, samples):
    """Mean relevance-rank accuracy of saved explanations against ground-truth masks.

    For each row, with K the number of non-zero pixels in the mask, the score is
    the fraction of the K highest-valued explanation pixels that fall inside the
    mask (0.0 when none do).

    Args:
        mask_path: directory prefix of the mask PNGs.
        explanation_path: directory prefix of the explanation PNGs.
        samples: DataFrame whose ``"img_path"`` file names key both directories.

    Returns:
        The mean score over the rows of ``samples``.
    """
    scores = []
    for _, row in samples.iterrows():
        file_name = row["img_path"].split("/")[-1]
        inside = np.array(Image.open(mask_path + file_name)) > 0
        relevance = (np.array(Image.open(explanation_path + file_name)) / 255).flatten()
        inside_idx = np.flatnonzero(inside.flatten())
        k = len(inside_idx)
        # Indices of the k largest explanation values (argsort is ascending).
        top_k = np.argsort(relevance)[-int(k):]
        hits = len(np.intersect1d(inside_idx, top_k))
        scores.append(hits / float(k) if hits != 0 else 0.0)
    return sum(scores) / len(scores)

def relevancy_mass(mask_path, explanation_path, samples):
    """Mean share of explanation mass that lies inside the ground-truth mask.

    Per row: sum of explanation values at non-zero mask pixels divided by the
    sum of all explanation values (NaN if the explanation is all zero).

    Args:
        mask_path: directory prefix of the mask PNGs.
        explanation_path: directory prefix of the explanation PNGs.
        samples: DataFrame whose ``"img_path"`` file names key both directories.

    Returns:
        The mean ratio over the rows of ``samples``.
    """
    ratios = []
    for _, row in samples.iterrows():
        file_name = row["img_path"].split("/")[-1]
        inside = (np.array(Image.open(mask_path + file_name)) > 0).flatten()
        relevance = (np.array(Image.open(explanation_path + file_name)) / 255).flatten()
        ratios.append(np.sum(relevance[inside]) / np.sum(relevance))
    return sum(ratios) / len(ratios)
     

run_exp(
    metrics=[relevance_rank, relevancy_mass],
    metric_names=["Relevance Rank", "Relevancy Mass"],
    data=only_bounding_box,
)

## Complexity

In [None]:

def complexity( mask_path, explanation_path, samples):
    """Mean Shannon entropy of saved explanation maps (lower = more concentrated).

    Each map is read from ``explanation_path + <file name>``, scaled to [0, 1],
    flattened, normalised to sum to 1 and scored with ``scipy.stats.entropy``.

    Args:
        mask_path: not used by this metric; accepted for the shared
            ``run_exp`` metric signature.
        explanation_path: directory prefix of the explanation PNGs.
        samples: DataFrame with an ``"img_path"`` column.

    Returns:
        The mean entropy over the rows of ``samples``. An all-zero map gives
        NaN (division by a zero sum), which propagates to the mean.
    """
    # Explicit import: on older SciPy releases `import scipy` does not load the
    # `stats` subpackage, so `scipy.stats` only worked if something else had
    # already imported it.
    from scipy.stats import entropy

    total_entropy = 0.0
    for _, row in samples.iterrows():
        img_name = row["img_path"].split("/")[-1]
        a = np.array(Image.open(explanation_path + img_name)) / 255
        a = a.reshape(-1).astype(np.float64)
        total_entropy += entropy(pk=a / np.sum(np.abs(a)))
    return total_entropy / len(samples)


def sparseness( mask_path, explanation_path, samples):
    """Mean Gini-index sparseness of saved explanation maps (higher = sparser).

    Each map is scaled to [0, 1], flattened, shifted by 1e-7 (so the sum is
    never zero), sorted ascending and scored with the Gini formula
    ``sum((2i - n - 1) * a_i) / (n * sum(a))`` for i = 1..n.

    Args:
        mask_path: not used by this metric; accepted for the shared
            ``run_exp`` metric signature.
        explanation_path: directory prefix of the explanation PNGs.
        samples: DataFrame with an ``"img_path"`` column.

    Returns:
        The mean Gini index over the rows of ``samples``.
    """
    gini_total = 0
    for _, row in samples.iterrows():
        file_name = row["img_path"].split("/")[-1]
        values = np.array(Image.open(explanation_path + file_name)) / 255
        values = np.sort(values.ravel().astype(np.float64) + 0.0000001)
        n_px = values.shape[0]
        rank_weights = 2 * np.arange(1, n_px + 1) - n_px - 1
        gini_total += np.sum(rank_weights * values) / (n_px * np.sum(values))
    return gini_total / len(samples)


run_exp(
    metrics=[complexity, sparseness],
    metric_names=["Complexity", "Sparseness"],
    data=only_bounding_box,
)


## Faithfulness

In [None]:
def faithfullness(mask_path, explanation_path, samples):
    """Mean Quantus ``FaithfulnessEstimate`` of saved explanations for the global ``model``.

    Args:
        mask_path: not used by this metric; accepted for the shared
            ``run_exp`` metric signature.
        explanation_path: directory prefix of the explanation PNGs.
        samples: iterable of ``(image_tensor, one_hot_target, img_path)``
            triples, e.g. a ``lung_data`` dataset.

    Returns:
        ``np.nanmean`` of all per-sample scores (NaN scores are skipped).
    """
    estimator = FaithfulnessEstimate(features_in_step=224 * 4)
    scores = []
    for image, one_hot, img_path in samples:
        file_name = img_path.split("/")[-1]
        explanation = np.array(Image.open(explanation_path + file_name)) / 255
        label = one_hot.argmax()
        scores.extend(
            estimator(
                model=model.cuda().eval(),
                x_batch=image.unsqueeze(0).numpy(),
                y_batch=np.expand_dims(label.astype(np.uint8), axis=0),
                a_batch=np.expand_dims(explanation, axis=0),
                device="cuda",
            )
        )
    return np.nanmean(scores)

    
run_exp(
    metrics=[faithfullness],
    metric_names=["Faithfullness"],
    data=data_set_test,
)