# Imports

In [None]:
import multiprocessing, time, torch, os, scipy

import numpy as np
import pandas as pd
from PIL import Image
from relax import RELAX
import pytorch_lightning as pl
from torchvision import transforms
from medclip.dataset import constants
from masking_network import medclip_masking_net
from medclip import  MedCLIPVisionModel, constants
from sklearn.model_selection import train_test_split







# Data

In [2]:

class lung_data(torch.utils.data.Dataset):
    """Map-style dataset serving the images listed in a DataFrame.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table with at least an ``img_path`` column holding the path of each
        image file.
    transform : callable
        Applied to the RGB ``PIL.Image`` of every sample; whatever it returns
        is handed back to the caller unchanged.
    """

    def __init__(self, dataframe, transform):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (one sample per row)."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(transform(image), image_path)`` for the row at ``index``.

        ``index`` is positional (``iloc``); an out-of-range position raises
        ``IndexError``, which is also what ends a plain ``for`` loop over the
        dataset.
        """
        img_path = self.dataframe.iloc[index]["img_path"]

        # The context manager releases the file handle as soon as the pixels
        # have been copied; ``convert`` returns an independent image.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")

        return self.transform(img), img_path
    


In [None]:

data_path = "data/RSNA_DATASET/"

# Label table -> one row per (patientId, Target) pair with the image path,
# the label and the (first) bounding box rescaled from 1024 px to 224 px.
df = (
    pd.read_csv('data/RSNA_DATASET/stage_2_train_labels.csv')
    .drop_duplicates(subset=['patientId', "Target"], keep='first')
    .assign(
        # <data_path>/<PNEUMONIA|NORMAL>/<patientId>.png
        img_path=lambda d: data_path
        + d["Target"].map(lambda t: "PNEUMONIA" if t == 1 else "NORMAL")
        + "/"
        + d["patientId"].astype(str)
        + ".png"
    )
    .loc[:, ['img_path', 'Target', "x", "y", "width", "height"]]
    .rename(columns={"Target": "target"})
    # Rows without a box have NaN coordinates; they become an all-zero box.
    .fillna(0)
    .assign(
        x=lambda d: (d["x"] / 1024 * 224).astype(int),
        y=lambda d: (d["y"] / 1024 * 224).astype(int),
        width=lambda d: (d["width"] / 1024 * 224).astype(int),
        height=lambda d: (d["height"] / 1024 * 224).astype(int),
    )
)

# Stratified 80/20 split with a fixed seed so the partition is reproducible.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['target'],
    random_state=42,
)
df

## Only take test data with bounding boxes

In [4]:

# Preprocessing applied to every image, in this order:
# PIL image -> tensor -> normalise with the MedCLIP constants -> resize to 224x224.
TRANS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=constants.IMG_MEAN, std=constants.IMG_STD),
        transforms.Resize((224, 224), antialias=True),
    ]
)

# Evaluation subset: 1000 test rows with target == 1, drawn without
# replacement using a fixed seed.
only_bounding_box = test_df.loc[test_df["target"] == 1].sample(
    1000, replace=False, random_state=42
)

data_set_train = lung_data(train_df, transform=TRANS)
data_set_test = lung_data(only_bounding_box, transform=TRANS)


# Model
## TODO: automatically download the vision model weights

In [None]:

# Build the MedCLIP vision encoder, move it to the GPU, then load the
# pretrained weights from the local checkpoint directory.
model = MedCLIPVisionModel()
model = model.to("cuda")
model.load_from_medclip("./pretrained/medclip-resnet")
# Switch to evaluation mode.
model.eval()

# Methodology

# NEM-U

### Hyperparameters

In [7]:

# Batch size for the data loaders; also passed to the masking network.
BATCH = 32
# Leave two cores free, but never go below zero on small machines
# (a negative worker count is rejected by torch's DataLoader).
NUM_WORKERS = max(multiprocessing.cpu_count() - 2, 0)
# Number of training epochs; also passed to the masking network.
EPOCHS = 10
# True -> the Trainer runs with precision 16, otherwise 32.
MIXED_PRECISION = False
# Forwarded to the Trainer's ``deterministic`` flag.
DETERMINISTIC = False
# Forwarded to the masking network constructor.
CONTRASTIVE = True
NOISE_MASK = False


### Model

In [None]:
# Rebind `model` to the masking network that is trained below.
model = medclip_masking_net(
    EPOCHS,
    BATCH,
    # Keyword spelling kept as-is; presumably it matches the signature in
    # masking_network -- verify before renaming.
    constrastive=CONTRASTIVE,
    noise_mask=NOISE_MASK,
)

# Training

In [None]:

# Shuffled batches for optimisation, fixed-order batches for evaluation.
# NOTE(review): NUM_WORKERS is defined in the hyperparameter cell but is not
# passed to either loader, so loading runs in the main process.
train_loader = torch.utils.data.DataLoader(data_set_train, batch_size=BATCH, shuffle=True)
test_loader = torch.utils.data.DataLoader(data_set_test, batch_size=BATCH, shuffle=False)

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    devices="auto",
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    precision=16 if MIXED_PRECISION else 32,
    default_root_dir="masking_unsupervised_pneumonia_logs",
    # Accumulate towards an effective batch of 128 samples. Integer division
    # clamped to 1 so that BATCH > 128 does not produce an invalid 0.
    accumulate_grad_batches=max(1, 128 // BATCH),
    deterministic=DETERMINISTIC,
)

trainer.fit(model, train_loader)


# Prediction

In [None]:

# Run the trained masking network over the evaluation set, save every
# attribution map as an 8-bit greyscale image and time the forward pass.
model.eval()
model = model.to("cuda")
nem_path = "result/unsupervised/pneunomia/nem/"
os.makedirs(nem_path, exist_ok=True)
execution_times = []
for (img, name) in data_set_test:
    name = name.split("/")[-1]
    img = img.unsqueeze(0).to("cuda")

    # CUDA kernels are launched asynchronously: synchronise before and after
    # the forward pass so the clock measures the computation itself and not
    # just the kernel launch.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    # Inference only -- no autograd graph is needed.
    with torch.no_grad():
        attr, _ = model(img)
    torch.cuda.synchronize()
    execution_times.append(time.perf_counter() - start_time)

    # Min-max scale the attribution to [0, 255] for storage.
    attr = attr.cpu().squeeze().detach().numpy()
    attr = (((attr - attr.min()) / (attr.max() - attr.min()))*255).astype(np.uint8)

    Image.fromarray(attr).save(f"{nem_path}{name}")

print("Average execution time: ", np.array(execution_times).sum()/len(execution_times))

# Comparisons

## RELAX

In [None]:

# Baseline: RELAX importance and U-RELAX uncertainty maps computed on the
# pretrained MedCLIP vision encoder for the same evaluation set.
model = MedCLIPVisionModel().to("cuda")
model.load_from_medclip("pretrained/medclip-resnet")
model.eval()
relax_path = "result/unsupervised/pneunomia/relax/"
U_relax_path  = "result/unsupervised/pneunomia/u_relax/"
# Image.save does not create missing directories; make sure both exist so the
# cell also works on a fresh checkout.
os.makedirs(relax_path, exist_ok=True)
os.makedirs(U_relax_path, exist_ok=True)
execution_times = []
for (img, name) in data_set_test:
    name = name.split("/")[-1]
    relax = RELAX(img.unsqueeze(0).cuda(),model)

    # CUDA work is asynchronous: synchronise around the timed call so the
    # measurement covers the actual computation.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    relax()
    torch.cuda.synchronize()
    execution_times.append(time.perf_counter() - start_time)

    # Min-max scale both maps to [0, 255] for storage as 8-bit images.
    attr = relax.importance().cpu().squeeze().detach().numpy()
    attr = (((attr - attr.min()) / (attr.max() - attr.min()))*255).astype(np.uint8)
    uncertainty = (relax.U_RELAX().cpu().squeeze().detach().numpy())
    uncertainty = ((uncertainty - uncertainty.min()) / (uncertainty.max()- uncertainty.min())*255).astype(np.uint8)

    Image.fromarray(attr).save(f"{relax_path}{name}")
    Image.fromarray(uncertainty).save(f"{U_relax_path}{name}")
print("Average execution time: ", np.array(execution_times).sum()/len(execution_times))

# Results

## Generate Masks

In [None]:

def gen_bounding_box(x, y, w, h, size=(224, 224)):
    """Rasterise an axis-aligned box into a binary ``uint8`` mask.

    The returned array has shape ``size``; rows ``y .. y+h-1`` and columns
    ``x .. x+w-1`` are 1 and every other pixel is 0. Parts of the box that
    fall outside the array are clipped by NumPy slicing.
    """
    rows = slice(y, y + h)
    cols = slice(x, x + w)
    canvas = np.zeros(size, dtype=np.uint8)
    canvas[rows, cols] = 1
    return canvas
mask_path = "result/masks/pneunomia/"
os.makedirs(mask_path, exist_ok=True)

# Write one ground-truth mask per evaluation sample, named after its image file.
for sample in only_bounding_box.iterrows():
    row = sample[1]
    file_name = row["img_path"].split("/")[-1]
    mask = gen_bounding_box(row["x"], row["y"], row["width"], row["height"])
    Image.fromarray(mask).save(f"{mask_path}{file_name}")

## Experimental method

In [None]:
def run_exp(metrics, metric_names):
    """Score the saved RELAX, U-RELAX and NEM explanations and print the results.

    Every callable in ``metrics`` is invoked once per explanation directory
    with the keyword arguments ``mask_path``, ``explanation_path`` and
    ``samples`` (the module-level ``only_bounding_box`` frame). The three
    scores are printed under the matching entry of ``metric_names``.
    """
    samples = only_bounding_box
    root = "result/unsupervised/pneunomia/"
    masks = "result/masks/pneunomia/"

    for metric_func, metric_name in zip(metrics, metric_names):
        # Evaluated lazily in this order: relax, u_relax, nem.
        relax_score, u_relax_score, nem_score = (
            metric_func(mask_path=masks, explanation_path=f"{root}{method}/", samples=samples)
            for method in ("relax", "u_relax", "nem")
        )

        print(f"""
        {metric_name}:
        MEDCLIP_relax:     {relax_score}
        MEDCLIP_U_relax:   {u_relax_score}
        MEDCLIP_NEM:       {nem_score}
        """)



## Locality

In [None]:
def relevance_rank(mask_path, explanation_path, samples, uncertainty=False):
    """Mean relevance-rank accuracy of saved explanations against their masks.

    For each sample, let ``k`` be the number of non-zero mask pixels. The
    score is the fraction of the ``k`` highest-valued explanation pixels that
    lie inside the mask.

    Parameters
    ----------
    mask_path : str
        Directory prefix (with trailing separator) of the mask images.
    explanation_path : str
        Directory prefix (with trailing separator) of the explanation images.
    samples : pandas.DataFrame
        Rows with an ``img_path`` column; only the file name part is used.
    uncertainty : bool
        Unused; kept so all metrics share one signature.

    Returns
    -------
    float
        Average score over ``samples``; ``nan`` when ``samples`` is empty.
    """
    rank_accuracy = 0
    samps = 0
    for _, sample in samples.iterrows():
        img_name = sample["img_path"].split("/")[-1]
        # Context managers release the file handles once the pixels are copied.
        with Image.open(mask_path + img_name) as mask_img:
            mask = np.array(mask_img) > 0
        with Image.open(explanation_path + img_name) as expl_img:
            a = (np.array(expl_img) / 255).flatten()

        # Flat indices of the ground-truth pixels and their count.
        s = np.flatnonzero(mask)
        k = len(s)
        samps += 1
        if k == 0:
            # An empty mask cannot be hit. Handled explicitly because
            # ``[-0:]`` would select every pixel instead of none.
            continue

        # argsort is ascending, so the last k entries are the top-k pixels.
        top_k = np.argsort(a)[-k:]
        hits = len(np.intersect1d(s, top_k))
        rank_accuracy += hits / float(k)

    if samps == 0:
        # Mean of an empty set is undefined; avoid ZeroDivisionError.
        return float("nan")
    return rank_accuracy / samps

def relevancy_mass(mask_path, explanation_path, samples, uncertainty=False):
    """Mean relevance-mass accuracy of saved explanations against their masks.

    For each sample the score is the sum of explanation values inside the
    mask divided by the sum over the whole image.

    Parameters
    ----------
    mask_path : str
        Directory prefix (with trailing separator) of the mask images.
    explanation_path : str
        Directory prefix (with trailing separator) of the explanation images.
    samples : pandas.DataFrame
        Rows with an ``img_path`` column; only the file name part is used.
    uncertainty : bool
        Unused; kept so all metrics share one signature.

    Returns
    -------
    float
        Average score over ``samples``; ``nan`` when ``samples`` is empty.
    """
    mass_accuracy_total = 0
    samps = 0
    for _, sample in samples.iterrows():
        img_name = sample["img_path"].split("/")[-1]
        # Context managers release the file handles once the pixels are copied.
        with Image.open(mask_path + img_name) as mask_img:
            inside = (np.array(mask_img) > 0).flatten()
        with Image.open(explanation_path + img_name) as expl_img:
            a = (np.array(expl_img) / 255).flatten()

        r_within = np.sum(a[inside])
        r_total = np.sum(a)
        # An all-zero explanation has no mass anywhere: count it as 0 instead
        # of letting 0/0 turn the whole average into NaN.
        if r_total > 0:
            mass_accuracy_total += r_within / r_total
        samps += 1

    if samps == 0:
        # Mean of an empty set is undefined; avoid ZeroDivisionError.
        return float("nan")
    return mass_accuracy_total / samps
     

# Report both localisation metrics for every explanation method.
run_exp(
    [relevance_rank, relevancy_mass],
    ["Relevance Rank", "Relevancy Mass"],
)

## Complexity

In [None]:

def complexity(mask_path, explanation_path, samples, uncertainty=False):
    """Mean Shannon entropy of the normalised explanation maps.

    Each explanation image is scaled to [0, 1], flattened, divided by its
    absolute sum and passed to ``scipy.stats.entropy``.

    Parameters
    ----------
    mask_path : str
        Unused; kept so all metrics share one signature.
    explanation_path : str
        Directory prefix (with trailing separator) of the explanation images.
    samples : pandas.DataFrame
        Rows with an ``img_path`` column; only the file name part is used.
    uncertainty : bool
        Unused; kept so all metrics share one signature.

    Returns
    -------
    float
        Average entropy over ``samples``; ``nan`` when ``samples`` is empty.
    """
    # Imported explicitly: older SciPy releases do not expose submodules such
    # as ``scipy.stats`` after a bare ``import scipy``.
    from scipy.stats import entropy

    n_samples = len(samples)
    if n_samples == 0:
        # Mean of an empty set is undefined; avoid ZeroDivisionError.
        return float("nan")

    # Named differently from the function so the function is not shadowed.
    total_entropy = 0
    for _, sample in samples.iterrows():
        img_name = sample["img_path"].split("/")[-1]
        # Context manager releases the file handle once the pixels are copied.
        with Image.open(explanation_path + img_name) as expl_img:
            a = np.array(expl_img) / 255

        # Normalise to a distribution over pixels, then flatten.
        denom = np.sum(np.abs(a))
        a = np.array(np.reshape(a, np.prod(a.shape)), dtype=np.float64) / denom
        total_entropy += entropy(pk=a)

    return total_entropy / n_samples


def sparseness(mask_path, explanation_path, samples, uncertainty=False):
    """Mean Gini-style sparseness of the saved explanation maps.

    Each explanation image is scaled to [0, 1], flattened, offset by 1e-7 and
    sorted ascending. With ``n`` pixels and sorted values ``a_i`` (i = 1..n),
    the per-sample score is ``sum((2i - n - 1) * a_i) / (n * sum(a))``.

    Parameters
    ----------
    mask_path : str
        Unused; kept so all metrics share one signature.
    explanation_path : str
        Directory prefix (with trailing separator) of the explanation images.
    samples : pandas.DataFrame
        Rows with an ``img_path`` column; only the file name part is used.
    uncertainty : bool
        Unused; kept so all metrics share one signature.

    Returns
    -------
    float
        Average score over ``samples``; ``nan`` when ``samples`` is empty.
    """
    n_samples = len(samples)
    if n_samples == 0:
        # Mean of an empty set is undefined; avoid ZeroDivisionError.
        return float("nan")

    gini_total = 0
    for _, sample in samples.iterrows():
        img_name = sample["img_path"].split("/")[-1]
        # Context manager releases the file handle once the pixels are copied.
        with Image.open(explanation_path + img_name) as expl_img:
            a = np.array(expl_img) / 255

        a = np.array(np.reshape(a, np.prod(a.shape)), dtype=np.float64)
        # Small offset keeps the denominator non-zero for all-black maps.
        a += 0.0000001
        a = np.sort(a)

        n = a.shape[0]
        weights = 2 * np.arange(1, n + 1) - n - 1
        gini_total += np.sum(weights * a) / (n * np.sum(a))

    return gini_total / n_samples


# Report both complexity metrics for every explanation method.
run_exp(
    [complexity, sparseness],
    ["Complexity", "Sparseness"],
)
