# UIB XSDS

![SegmentLocal](office.gif "segment")

**Objectius**:

- [x] Crear una xarxa simple de regressió que permeti predir quants objectes de cada tipus hi ha en una imatge, **sense overfitting**.
- [x] Aplicar LIME per obtenir explicacions de les prediccions i comprovar que són coherents.
- [x] Modificar la sortida de la xarxa perquè, en lloc de predir la quantitat d'objectes, calculi les funcions de *P. Cortez et al.*
- [x] Calcular mètriques per saber si les explicacions corresponen a les esperades.

In [None]:
from typing import List, Callable, Tuple
import time
import copy
import glob

from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
import torch.nn.functional as F
from torch import nn

from skimage.segmentation import mark_boundaries
from prettytable import PrettyTable
from IPython import display
from lime import lime_image
from tqdm.auto import tqdm 
import pandas as pd
import numpy as np
import pylab as pl
import torch
import cv2


# Data

In [None]:
# Root folders of the training and validation splits. Later cells read the
# *.png images and the ``dades.csv`` label file from each of them; the
# validation folder also holds the ground-truth explanations under ``gt/``.
DATASET_PATH = "./xsds_shape_count_v2/train"
VAL_PATH = "./xsds_shape_count_v2/val"

In [None]:
class ImageDataset(Dataset):
    """Map-style PyTorch dataset pairing image files with regression targets.

    Item ``i`` is built from ``file_names[i]`` and ``gt_info[i]``:

    * the image is ``get_img_fn(file_names[i])``, returned unchanged;
    * the target is ``fn_agg(np.array(gt_info[i], dtype=np.float32))`` wrapped
      with ``torch.from_numpy`` (so ``fn_agg`` must return a NumPy array).

    Args:
        file_names: image paths, in the same order as ``gt_info``.
        get_img_fn: loader that turns a path into the model input.
        gt_info: one row of raw label values per image.
        fn_agg: aggregation applied to each label row to build the target.

    Raises:
        ValueError: if ``file_names`` and ``gt_info`` have different lengths.
            Without this check, images and labels would be paired by position
            and silently misaligned.
    """

    def __init__(
        self,
        file_names: List[str],
        get_img_fn: Callable,
        gt_info: List[List[float]],
        fn_agg: Callable[[np.ndarray], np.ndarray]
    ):
        if len(file_names) != len(gt_info):
            raise ValueError(
                f"Got {len(file_names)} images but {len(gt_info)} label rows"
            )

        self.__get_img_fn = get_img_fn

        self.__file_names = file_names
        self.__labels = gt_info
        self.__fn_agg = fn_agg

    def __getitem__(self, index: int) -> Tuple[np.ndarray, torch.Tensor]:
        """Return ``(image, target)`` for sample ``index``."""
        img_path = self.__file_names[index]
        image = self.__get_img_fn(img_path)

        label = np.array(self.__labels[index], dtype=np.float32)

        y = self.__fn_agg(label)

        return image, torch.from_numpy(y)

    def __len__(self) -> int:
        """Number of samples (one per file name)."""
        return len(self.__file_names)

In [None]:
def normalize_pandas(df, reference=None):
    """Min-max scale each column of ``df``.

    Args:
        df: DataFrame to scale. It is left unmodified.
        reference: optional DataFrame supplying the per-column min and max,
            e.g. the training split when scaling validation data. Defaults to
            ``df`` itself, which maps every column onto [0, 1].

    Returns:
        A new DataFrame with ``(df - min) / (max - min)`` per column. If a
        column's reference range is zero (a constant column), the range is
        treated as 1, so the column is only shifted. It becomes 0 instead of
        the NaN produced by 0 / 0.
    """
    ref = df if reference is None else reference
    out = df.copy()  # do not mutate the caller's frame
    for key in df.columns:
        col_min = ref[key].min()
        col_range = ref[key].max() - col_min
        if col_range == 0:
            col_range = 1
        out[key] = (df[key] - col_min) / col_range

    return out

def ssim(label):
    """Sine-weighted aggregate of the first three label values.

    Computes ``1/2*sin(pi*l0/2) + 1/4*sin(pi*l1/2) + 1/6*sin(pi*l2/2)`` and
    returns it wrapped in a one-element NumPy array, the form that
    ``ImageDataset`` converts into a tensor target. This is presumably one of
    the *P. Cortez et al.* target functions named at the top of the notebook
    (confirm).
    """
    terms = [
        weight * np.sin(np.pi * (label[idx] / 2))
        for idx, weight in ((0, 1/2), (1, 1/4), (2, 1/6))
    ]
    return np.array([terms[0] + terms[1] + terms[2]])


def suum(label):
    """Linear weighted sum of the first three label values.

    Computes ``1/2*l0 + 1/4*l1 + 1/6*l2`` and returns it wrapped in a
    one-element NumPy array, which is the target form ``ImageDataset``
    expects from its aggregation function.
    """
    weighted = [
        weight * label[idx]
        for idx, weight in ((0, 1/2), (1, 1/4), (2, 1/6))
    ]
    return np.array([weighted[0] + weighted[1] + weighted[2]])


In [None]:
df = pd.read_csv(f"{DATASET_PATH}/dades.csv", sep=";")
df_val = pd.read_csv(f"{VAL_PATH}/dades.csv", sep=";")

# Drop the unnamed first column (presumably an index written by ``to_csv``).
df = df.drop('Unnamed: 0', axis=1)
df_val = df_val.drop('Unnamed: 0', axis=1)

# Scale BOTH splits with the TRAIN split's per-column min/max. Normalising the
# validation split with its own statistics would map the same raw count to
# different targets in train and val, so validation scores would not measure
# the function the model was trained on.
train_min = df.min()
train_range = (df.max() - train_min).replace(0, 1)  # avoid 0/0 on constant columns

df = (df - train_min) / train_range
df_val = (df_val - train_min) / train_range

gt = df.values.tolist()
gt_val = df_val.values.tolist();

In [None]:
def _read_gray_512(path):
    """Load ``path`` as a grayscale image reshaped to ``(1, 512, 512)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            as an obscure ``AttributeError`` on ``reshape``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image.reshape(1, 512, 512)


# Images are listed in sorted order; ``ImageDataset`` pairs them by position
# with the label rows ``gt`` / ``gt_val``. Targets are built with ``suum``.
dat_train = ImageDataset(
    sorted(glob.glob(f"{DATASET_PATH}/*.png")),
    _read_gray_512,
    gt,
    suum
)
dat_valid = ImageDataset(
    sorted(glob.glob(f"{VAL_PATH}/*.png")),
    _read_gray_512,
    gt_val,
    suum
)


# Shuffle the training set so every epoch sees the samples in a new order.
# Keep validation in sorted-file order: later cells pair its first batch with
# the sorted ground-truth explanation images in ``gt/``.
data_loader_train = DataLoader(dat_train, batch_size=128, shuffle=True)
data_loader_val = DataLoader(dat_valid, batch_size=128, shuffle=False)

# Instance NN and train

![SegmentLocal](https://imgs.xkcd.com/comics/machine_learning.png "segment")

### Quadern de bitàcora
**Canvi actual**: 

Val Acc : 0.97

- Retorn a l'arquitectura simple.
- Funció $ssim(x)$ original, ara dividim els valors de $x_n$ per 2.

**Old**


- Dropout 0.4  &#8594; 0.2
- ReLU -> Leaky ReLU

***

- Dropout 0.2  &#8594; 0.4

***

Val ACC : 0.887

- `1/3 * np.sin(np.pi * (label[0])) if label[0] > 0 else 0` &#8594; `1/2 * np.sin(np.pi * (1 - label[0])) if label[0] > 0 else 0` 
- Tercera capa densa.

***

- `1/2 * np.sin(np.pi * (1 - label[0])) if label[0] > 0 else 0` &#8594; `1/3 * np.sin(np.pi * (label[0])) if label[0] > 0 else 0`

***

- *Learning rate*. De 0.005 a 0.001
- *Dropout*. De 0.5 a 0.2

In [None]:
# Use the current CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Net(nn.Module):
    """Three-stage CNN followed by a two-layer fully connected head.

    Each stage is ``Conv2d(3x3, padding='same') -> ReLU -> MaxPool2d(2x2)``,
    with 25, 50 and 50 output channels. The flattened features then go
    through ``Linear(-> 50) -> ReLU -> Linear(-> classes)``. ``forward``
    returns the raw outputs of the last linear layer.

    Args:
        numChannels: number of channels of the input images.
        classes: number of outputs.
        input_size: side length of the square input images. Each of the three
            poolings halves the spatial size (rounding down), so ``fc1``
            expects ``(input_size // 8) ** 2 * 50`` features. Defaults to 512,
            which gives the original ``64 * 64 * 50``.
    """

    def __init__(self, numChannels, classes, input_size=512):
        # call the parent constructor
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=numChannels, out_channels=25,
            kernel_size=(3, 3), padding="same")
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv2 = nn.Conv2d(in_channels=25, out_channels=50,
            kernel_size=(3, 3), padding="same")
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv3 = nn.Conv2d(in_channels=50, out_channels=50,
            kernel_size=(3, 3), padding="same")
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # Spatial side after three 2x2/stride-2 poolings ("same" convs keep size).
        pooled_side = input_size // 8
        self.fc1 = nn.Linear(in_features=pooled_side * pooled_side * 50, out_features=50)
        self.relu5 = nn.ReLU()

        self.fc2 = nn.Linear(in_features=50, out_features=classes)

    def forward(self, x):
        """Map a ``(batch, numChannels, input_size, input_size)`` tensor to ``(batch, classes)``."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.relu3(x)
        x = self.maxpool3(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = self.relu5(x)

        x = self.fc2(x)

        return x

In [None]:
%matplotlib inline

def train_model(
    model, 
    dataloaders, 
    criterion, 
    optimizer, 
    num_epochs=25, 
    is_inception=False, 
    do_validation=True, 
    regression=False,
    plot_acc=False,
):
    """Train ``model`` and return it loaded with its best-scoring weights.

    Args:
        model: module to train; batches are moved to the global ``device``.
        dataloaders: mapping with a ``'train'`` loader and, when
            ``do_validation`` is true, a ``'val'`` loader, each yielding
            ``(inputs, labels)``.
        criterion: loss, called as ``criterion(outputs.reshape(1, -1),
            labels.reshape(1, -1))``.
        optimizer: optimizer for ``model``'s parameters.
        num_epochs: number of epochs.
        is_inception: unused; kept for signature compatibility.
        do_validation: also run a ``'val'`` phase every epoch. The best
            weights are then chosen on the validation score; otherwise the
            train score is used.
        regression: if true, the score is ``1 - min(1, mean per-sample sum of
            |outputs - labels|)``. Otherwise it is classification accuracy
            (argmax of outputs vs. argmax of labels).
        plot_acc: redraw the per-phase score curves after every epoch.

    Returns:
        ``(model, history)``, where ``history`` maps each phase name to its
        list of per-epoch scores.

    Ctrl-C stops training early. Any other exception is reported with its
    traceback and also stops training. In every case the best weights seen so
    far are loaded before returning.
    """
    import traceback

    since = time.time()

    history = {}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_acc_std = 0.0

    phases = ['train']

    if do_validation:
        phases.append('val')

    def _score(running_acc):
        # Regression accumulates per-batch absolute errors, so turn them into
        # an "accuracy-like" score (higher is better, clipped at 0).
        mean_acc = np.mean(running_acc)
        return 1 - min(1, mean_acc) if regression else mean_acc

    try:
        for epoch in range(num_epochs):
            # Each epoch has a training and validation phase
            for phase in phases:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = []
                running_acc = []

                pbar = tqdm(dataloaders[phase], desc='Time, he\'s waiting in the wings')
                # Iterate over data.
                for inputs, labels in pbar:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    inputs = inputs.type(torch.float32)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward; track history only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)

                        loss = criterion(outputs.reshape(1, -1), labels.reshape(1, -1))

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss.append((loss.item() * inputs.size(0)) / len(outputs))
                    if not regression:
                        _, preds = torch.max(outputs, 1)
                        aux = (preds == torch.max(labels, 1)[1])
                        running_acc.append((torch.sum(aux).double().cpu().detach().numpy() / len(outputs)))
                    else:
                        running_acc.append(torch.sum(torch.abs(outputs - labels)).double().cpu().detach().numpy() / len(outputs))

                    # Show the same score that is used for model selection
                    # (previously classification displayed 1 - accuracy).
                    pbar.set_description('Epoch {}/{} - {} - ACC: {:.4f} LOSS: {:.4f}'.format(epoch, num_epochs - 1, phase.capitalize(), _score(running_acc), np.mean(running_loss)))

                epoch_acc = _score(running_acc)

                # deep copy the model when the selection phase improves
                if (phase == 'val' or 'val' not in phases) and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_acc_std = np.std(running_acc)

                    best_model_wts = copy.deepcopy(model.state_dict())

                history.setdefault(phase, []).append(epoch_acc)

            if plot_acc:
                plt.figure(figsize=(15,15))

                pl.title('Best Acc: [{:4f}-{:4f}]'.format(best_acc - (best_acc_std / 2), best_acc + (best_acc_std / 2)))

                for phase in phases:
                    pl.plot(history[phase], label=phase.capitalize())

                pl.legend()
                display.clear_output(wait=True) 
                display.display(pl.gcf())
                plt.close()

    except KeyboardInterrupt:
        print('Interrupted')
    except Exception:
        # Best effort: keep the weights trained so far, but show where and why
        # training failed instead of only the exception message.
        traceback.print_exc()
    finally:
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: [{:4f}-{:4f}]'.format(best_acc - best_acc_std/2, best_acc + best_acc_std/2))

        # load best model weights
        model.load_state_dict(best_model_wts)

    # Returning outside ``finally`` lets SystemExit & co. propagate.
    return model, history

In [None]:
# Full model: the CNN produces 3 outputs, and a small MLP head
# (3 -> 10 -> 10 -> 1) maps them to one scalar. That matches the
# one-element targets built by ``suum``.
net = Net(1, 3)

model = nn.Sequential(net, nn.Sequential(
                            nn.Linear(in_features=3, out_features=10),
                            nn.ReLU(),
                            nn.Linear(in_features=10, out_features=10),
                            nn.ReLU(),
                            nn.Linear(in_features=10, out_features=1)
                        )
                     )

# After wrapping, the Sequential above is reachable as ``model.module``.
model = nn.DataParallel(model)
model = model.to(device)

In [None]:
# L1 (mean absolute error) loss, optimised with SGD + momentum.
criterion = nn.L1Loss()
optimizer = optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.7)

# Train and evaluate
# ``train_model`` returns the same ``model`` object with its best weights
# loaded. From here on ``net`` therefore names the full wrapped model, not
# the bare ``Net`` instance.
net, hist = train_model(
    model, 
    {"train": data_loader_train, "val": data_loader_val}, 
    criterion, 
    optimizer_ft, 
    num_epochs=500, 
    do_validation=True,
    regression=True,
    plot_acc=True
)

In [None]:
# Best per-epoch validation score recorded during training.
max(hist['val'])

### Resultats

![SegmentLocal](./mem.jpg "segment")

In [None]:
# Checkpoint helpers (disabled). Uncomment the first line to reload
# previously saved weights instead of training, or the second to save them.
# model.load_state_dict(torch.load('ssim_xsds_values.pt'));
# torch.save(model.state_dict(), 'ssim_xsds_values.pt')

In [None]:
# Take only the first validation batch. Note: this rebinds ``gt`` (previously
# the training label list) to that batch's target tensor.
for img, gt in data_loader_val:
    break
model = model.eval()
img = img.to(device).type(torch.float32)
# Inference only: without no_grad, autograd would keep every intermediate
# activation of the whole batch alive.
with torch.no_grad():
    pred = model(img)

res_val = []
diff = np.abs(pred.detach().cpu().numpy() - gt.cpu().numpy())

# Mean and spread of the absolute prediction error on this batch.
print(np.mean(diff), np.std(diff))

In [None]:
# Convert the predictions to NumPy once. The guard makes the cell safe to
# re-run: an ndarray has no ``.detach`` and would raise AttributeError.
if isinstance(pred, torch.Tensor):
    pred = pred.detach().cpu().numpy()

In [None]:
# Table of the first 12 predictions (indices 0..11) against ground truth.
# Both values are shown as absolute values, so the sign of a negative
# prediction is not visible here.
x = PrettyTable()
x.field_names = ["Prediction", "GT", "Diff"]

for i, (p, gt_img) in enumerate(zip(pred, gt.numpy())):
    p = np.abs(p)[0]
    gt_img = np.abs(gt_img[0])
    
    x.add_row([f"{p:.2f}", f"{gt_img:.2f}", f"{(p - gt_img):.2f}"])
    
    if i > 10:
        break

print(x)

# XAI

In [None]:
def gt_exp_gen(paths):
    """Lazily yield the ground-truth explanation image stored at each path.

    Images are read with ``cv2.imread`` using its default flags.

    Raises:
        FileNotFoundError: if a file cannot be read. ``cv2.imread`` returns
            ``None`` instead of raising, which would otherwise only surface
            later as an obscure error inside the metric code.
    """
    for path in paths:
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not read ground-truth image: {path}")
        yield img

def calc_metric_norm(gt_exp, explanations, dist_fn, full_data = False, plot = False):
    """Score each explanation separately on every ground-truth value class.

    For each ``(gt_img, expl)`` pair:

    * ``gt_img`` is cast to float64. Its three channels are multiplied by
      ``[0.55, 0.27 * 2, 0.18 * 3]``, the result is divided by its maximum
      (1 if the maximum is 0), and the channels are summed.
    * ``expl[:, :, 0]`` is divided by its maximum (1 if the maximum is 0).
    * For each candidate value ``[0, 1 * 0.55, 2 * 0.27, 3 * 0.18] / max``,
      the pixels of the summed GT exactly equal to it are selected, and
      ``1 - dist_fn(expl_pixels, gt_pixels)`` is recorded.

    NOTE(review): the exact-equality selection assumes binary (0/1) GT
    channels with one active channel per pixel. Also, ``2 * 0.27`` and
    ``3 * 0.18`` may be the same float, in which case the last two entries
    select the same pixels. Confirm this is intended.

    Args:
        gt_exp: iterable of ground-truth images (H, W, 3).
        explanations: iterable of explanation maps; channel 0 is used.
        dist_fn: distance ``dist_fn(explanation_pixels, gt_pixels)``.
        full_data: ignored. The per-image list is always returned.
        plot: draw the first 16 normalised explanations on a 4x4 grid.

    Returns:
        A list with one ``np.ndarray`` of 4 scores per image. A class with no
        pixels yields whatever ``dist_fn`` returns for empty input (NaN for
        this notebook's ``mse``/``mae``).
    """
    diff = []
    ii = 1
    for gt_img, expl in zip(gt_exp, explanations):
        gt_img = gt_img.astype(np.float64)
        gt_img *= ([0.55, 0.27 * 2, 0.18 * 3])
        max_img = gt_img.max() if gt_img.max() > 0 else 1

        gt_img /= max_img

        gt_img_resumed = np.sum(gt_img, axis=2)

        expl = expl[:,:,0].astype(np.float64)
        expl /= expl.max() if expl.max() > 0 else 1

        local_diff = []
        for c_val in [0/max_img, (1 * 0.55)/max_img, (2 * 0.27)/max_img, (3 * 0.18)/max_img]:
            selected = gt_img_resumed == c_val
            local_diff.append(1 - dist_fn(expl[selected], gt_img_resumed[selected]))

        diff.append(np.array(local_diff))

        if (ii < 17) and plot:
            plt.subplot(4, 4, ii)
            plt.axis('off')
            plt.imshow(expl)

            ii += 1

    return diff


def mae(img1, img2):
    """Sum of absolute differences divided by the number of elements of ``img2``."""
    abs_err = np.abs(img1 - img2)
    return abs_err.sum() / img2.size

def mse(img1, img2):
    """Sum of squared differences divided by the number of elements of ``img2``."""
    sq_err = (img1 - img2) ** 2
    return sq_err.sum() / img2.size

def calc_metric(gt_exp, explanations, dist_fn, full_data = False, plot = False):
    """Score each explanation against its ground truth on the relevant pixels.

    For each ``(gt_img, expl)`` pair:

    * ``gt_img`` is cast to float64. Its channels are multiplied by
      ``[0.55, 0.27, 0.18]``, the result is divided by its maximum (1 if the
      maximum is 0), and the channels are summed.
    * ``expl[:, :, 0]`` is divided by its maximum (1 if the maximum is 0).
    * Only pixels where either map exceeds 0.05 are compared, and
      ``1 - dist_fn(expl_pixels, gt_pixels)`` is recorded. NaN scores are
      dropped.

    Args:
        gt_exp: iterable of ground-truth images (H, W, 3).
        explanations: iterable of explanation maps; channel 0 is used.
        dist_fn: distance ``dist_fn(explanation_pixels, gt_pixels)``.
        full_data: also return the list of per-image scores.
        plot: draw the first 16 normalised explanations with their score.

    Returns:
        ``(mean, std)`` of the kept scores, or ``(mean, std, scores)`` when
        ``full_data`` is true.
    """
    diff = []
    ii = 1
    for gt_img, expl in zip(gt_exp, explanations):
        gt_img = gt_img.astype(np.float64)
        gt_img *= [0.55, 0.27, 0.18]

        gt_img /= gt_img.max() if gt_img.max() > 0 else 1
        gt_img_resumed = np.sum(gt_img, axis=2)

        expl = expl[:,:,0].astype(np.float64)

        expl /= expl.max() if expl.max() > 0 else 1

        # Compare only pixels that matter for either map (presumably so the
        # empty background does not dominate the score).
        relevant = (gt_img_resumed > 0.05) | (expl > 0.05)
        local_diff = 1 - dist_fn(expl[relevant], gt_img_resumed[relevant])

        if not np.isnan(local_diff):
            diff.append(local_diff)

        if (ii < 17) and plot:
            plt.subplot(4, 4, ii)
            plt.axis('off')
            plt.imshow(expl)
            # Title with THIS image's score; ``diff[-1]`` would show a stale
            # value (or raise IndexError) when the score was NaN and dropped.
            plt.title(f"Metric {local_diff:.2f}")

            ii += 1

    if not full_data:
        return np.mean(diff), np.std(diff)
    else:
        return np.mean(diff), np.std(diff), diff

In [None]:
def resume_metrics():
    """Placeholder for an unfinished helper.

    The notebook only contained the bare header ``def resume_metrics()``,
    which is a SyntaxError that stops the cell from running. Nothing in view
    calls it, so it raises until it is actually implemented.
    """
    raise NotImplementedError("resume_metrics has not been implemented yet")

## LIME

In [None]:
def segment_circles(image):
    """LIME segmentation: use the raw values of channel 0 as segment labels.

    Every distinct pixel value of ``image[:, :, 0]`` becomes one segment.
    """
    channel_index = 0
    return image[:, :, channel_index]

def own_segment(image):
    """Segment channel 0 by filling each detected contour with its own label.

    Contours are found with ``cv2.findContours`` (``RETR_TREE``,
    ``CHAIN_APPROX_SIMPLE``) on channel 0 cast to uint8. Contour number ``i``
    (0-based) is filled with label ``i + 1`` on a zero uint8 canvas, and later
    contours overwrite earlier ones. Unfilled pixels keep label 0.
    """
    channel0 = image.astype(np.uint8)[:, :, 0]
    labels = np.zeros_like(channel0, dtype=np.uint8)
    contours, _hierarchy = cv2.findContours(channel0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    for label, contour in enumerate(contours, start=1):
        cv2.drawContours(labels, [contour], -1, label, -1)

    return labels

In [None]:
def batch_predict(images):
    """LIME classifier function: run the global ``model`` on perturbed images.

    Args:
        images: NumPy array of shape (N, H, W, C), as passed by LIME. The
            images are 3-channel copies of the grayscale input, built by the
            explanation loop.

    Returns:
        NumPy array with the model outputs, one row per image.
    """
    batch = torch.from_numpy(np.transpose(images, (0, 3, 1, 2)))
    # The model takes one channel; the channels are identical copies, so
    # keep channel 1 before moving data to the device.
    batch = batch[:, 1:2, :, :].to(device).type(torch.float32)

    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(batch)

    return logits.cpu().numpy()

explainer = lime_image.LimeImageExplainer()
lime_res = []

# Explain the first images of the validation batch: at most 64, fewer if the
# batch is smaller (a fixed range(64) raised IndexError in that case).
n_explained = min(64, img.shape[0])
for i in tqdm(range(n_explained)):
    # LIME needs an (H, W, 3) image: replicate the single grayscale channel of
    # image i only, instead of repeating the whole batch on every iteration.
    instance = img[i].repeat(3, 1, 1).type(torch.float32).cpu().permute(1, 2, 0).numpy()
    explanation = explainer.explain_instance(instance,
                                             batch_predict, # classification function
                                             top_labels=3, 
                                             hide_color=0,
                                             num_samples=200,
                                             segmentation_fn=segment_circles,
                                             progress_bar=False,
                                             random_seed=42)

    # Channel k holds, for every segment, |weight| of that segment for label k.
    mask = np.zeros((explanation.segments.shape[0], explanation.segments.shape[1], 3), dtype=np.float64)
    for k, segment_weights in explanation.local_exp.items():
        for key, val in segment_weights:
            mask[:, :, k][explanation.segments == key] = np.abs(val)
    lime_res.append(mask)


In [None]:
plt.figure(figsize=(25, 25))

# Compare each LIME map with its ground-truth explanation. The sorted gt/ file
# list is zipped with ``lime_res``, so it must follow the same order as the
# (unshuffled, sorted) validation images; zip stops at the shorter sequence.
expl_gen = gt_exp_gen(sorted(glob.glob(VAL_PATH + "/gt/*.png")))
expl_metric = calc_metric_norm(expl_gen, lime_res, mse, full_data=True, plot=True)

In [None]:
# Mean score per GT class over all explained images, ignoring NaN entries
# (e.g. classes with no pixels in an image).
np.nanmean(np.vstack(expl_metric), axis=0)

In [None]:
# Raw per-image, per-class scores.
expl_metric

In [None]:
# Normalised label values of the first validation sample.
df_val.iloc[0]

In [None]:
# Hand evaluation of the ``ssim`` formula. The three inputs are presumably the
# normalised labels printed by ``df_val.iloc[0]``; confirm.
(1/2 * np.sin(np.pi * (0.0/2))) + (1/4 * np.sin(np.pi * (0.175/2))) + (1/6 * np.sin(np.pi * (0.461/2)))

In [None]:
# Host copy of the batch with every pixel equal to 1 set to 0, presumably to
# remove one kind of object and probe the model's reaction (TODO confirm).
aux = np.copy(img.cpu())
aux[aux == 1] = 0

# Show the modified first image.
plt.imshow(aux[0,0,:,:])
plt.colorbar()

In [None]:
# Model output for the modified first image.
model(torch.from_numpy(aux[0:0+1, :, :, :]).to(device))

In [None]:
from uib_xai.metrics import faithfullness

table = PrettyTable()
table.field_names = ["Real metric", "Faithfulness", "DIFF"]
diff_s = []
# Convert the batch once, not on every iteration.
img = img.to(device).type(torch.float32)
for i, (sal, x_met) in enumerate(zip(lime_res, expl_metric)):
    faith = faithfullness.faithfullness(img[i:i+1,:,:,:], sal, lambda x: [torch.squeeze(model(x)).cpu().detach().numpy()], (50, 50), 0)
    faith = np.abs(faith)

    # ``x_met`` is an array with one score per GT class (NaN for classes with
    # no pixels). Formatting it with ``:.2f`` raised TypeError, so summarise
    # it as the NaN-ignoring mean over classes.
    real_metric = np.nanmean(x_met)
    diff = np.abs(real_metric - faith)
    diff_s.append(diff)
    table.add_row([f"{(real_metric):.2f}", f"{faith:.2f}", f"{diff:.2f}"])

print(table)

In [None]:
# LIME importance mask of the last explained image.
plt.imshow(lime_res[-1])

# RISE

In [None]:
from rise import RISE

In [None]:
# RISE wrapper around the trained model for 512x512 inputs, with n_masks=5000.
wrapped_model = RISE(model, input_size=(512, 512), n_masks=5000).to(device)

In [None]:
rise_img = []
# Same first validation batch as the LIME cell: at most 64 images, fewer if
# the batch is smaller (a fixed range(64) raised IndexError in that case).
n_explained = min(64, img.shape[0])
rise_input = img.to(device).type(torch.float32)  # convert once, not per image
for i in tqdm(range(n_explained)):
    torch.cuda.empty_cache()
    with torch.no_grad():
        saliency = wrapped_model(rise_input[i:i+1, :, :, :])

        # Stored channel-last so it can be shown with imshow.
        rise_img.append(saliency.cpu().numpy().transpose((1,2,0)))

In [None]:
plt.figure(figsize=(20,20))

expl_gen = gt_exp_gen(sorted(glob.glob(f"{VAL_PATH}/gt/*.png")))
# ``calc_metric_norm`` returns a list of per-class score arrays (one per
# image), not ``(mean, std, data)``, so unpacking three values failed.
# Aggregate here per class, ignoring NaN entries.
expl_metric = calc_metric_norm(expl_gen, rise_img, mse, full_data=True, plot=True)
per_image_class = np.vstack(expl_metric)
mean = np.nanmean(per_image_class, axis=0)
std = np.nanstd(per_image_class, axis=0)

In [None]:
# Summary of the RISE metric computed in the previous cell.
print(mean, std)

In [None]:
from uib_xai.metrics import faithfullness

table = PrettyTable()
table.field_names = ["Metric", "Faithfulness"]
# Convert the batch once, not on every iteration.
img = img.to(device).type(torch.float32)
for i, (sal, x_met) in enumerate(zip(rise_img, expl_metric)):
    faith = faithfullness.faithfullness(img[i:i+1,:,:,:], sal, lambda x: [torch.squeeze(model(x)).cpu().detach().numpy()], (50, 50), 0)
    faith = np.abs(faith)

    # ``x_met`` is an array with one score per GT class (NaN for classes with
    # no pixels). An array cannot be formatted with ``:.2f``, so report the
    # NaN-ignoring mean over classes.
    table.add_row([f"{np.nanmean(x_met):.2f}" , f"{faith:.2f}"])

    # Only the first 12 images (indices 0..11).
    if i > 10:
        break

print(table)

In [None]:
# Third RISE saliency map, titled with its array shape.
plt.imshow(rise_img[2])
plt.title(f"{rise_img[2].shape}");

In [None]:
# Save the third RISE map scaled by 255 to prova.png; the cell shows
# cv2.imwrite's boolean success flag.
cv2.imwrite("prova.png", rise_img[2] * 255)

In [None]:
# Third input image divided by 255 (assumes 0-255 pixel values; confirm),
# with its RISE map added on top.
plt.imshow((img[2,:,:,:].cpu().numpy()[0, :, :] /255) + rise_img[2][:, :, 0])

In [None]:
# Shape of one batch image; ``i`` is whatever index the last executed loop left.
img[i,:,:,:].cpu().numpy().shape

In [None]:
# ``saliency`` lives on ``device``, and ``Tensor.numpy()`` only works on CPU
# tensors, so move it to host memory first.
saliency.cpu().numpy().transpose((1,2,0)).shape

In [None]:
plt.figure(figsize=(15, 15))

# Grid of all RISE saliency maps (8x8 layout, so up to 64 maps). The per-image
# input copy computed here before was never used and has been dropped.
for i, sal in enumerate(rise_img):
    plt.subplot(8,8, i+1)
    plt.imshow(sal)

In [None]:
# Shape of the last RISE saliency tensor (before the channel-last transpose).
saliency.shape

# CAM

In [None]:
# Module tree, used to pick the CAM target layer.
print(model)

In [None]:
from pytorch_grad_cam import GradCAMPlusPlus, ScoreCAM, GradCAM

# CAM target: the second pooling layer of the ``Net`` inside the DataParallel wrapper.
target_layers = [model.module[0].maxpool2]
# ``device`` falls back to the CPU when CUDA is unavailable, so request CUDA
# only in that case instead of hard-coding ``use_cuda=True``.
interpreter = ScoreCAM(model=net, target_layers=target_layers, use_cuda=device.type == "cuda")

imgs = []
interpretations = []

In [None]:
plt.figure(figsize=(25, 25))

# For the first 16 images: compute the ScoreCAM map, keep the input and the
# map, then draw the map (alpha 0.5) over the grayscale image.
for i in range(16):
    ax1 = plt.subplot(4, 4, i + 1)
    res_interpret = interpreter(input_tensor = img[i: i+ 1,:,:,:].type(torch.float32))
    grayscale_cam = res_interpret[0, :]
    
    imgs.append(img[i: i+ 1,:,:,:])
    interpretations.append(grayscale_cam)
        
    plt.imshow(img[i, 0, :, :].cpu().numpy(), cmap='gray');
    plt.imshow(grayscale_cam[: ,: ], alpha=.5);
#     ax1.title.set_text(f"GT {np.argmax(labels[i])} Predicció {np.argmax(preds[i])}")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import os

plt.figure(figsize=(15,15))

# cv2.imwrite does not create missing folders, so make sure the output
# directory exists before saving the 16 maps.
os.makedirs("./res", exist_ok=True)

for i in range(16):
    # ``gt`` / ``pred`` hold the labels and predictions of the first
    # validation batch computed earlier.
    gt_res = gt[i, :].cpu().detach().numpy()
    pred_res = pred[i, :]
    
    res = f"GT {str(round(gt_res[0], 2))} | PRED {str(round(pred_res[0], 2))}"

    cam_plus_plus = interpreter(input_tensor = img[i: i+ 1,:,:,:].type(torch.float32))
    grayscale_cam = cam_plus_plus[0, :]
    # Keep only the strongest activations (>= 0.5).
    grayscale_cam[grayscale_cam < 0.5] = 0

    plt.subplot( 4, 4, i + 1)
    plt.title(res)
    plt.imshow((grayscale_cam * 255).astype(np.uint8))
    plt.colorbar()
    
    cv2.imwrite(f"./res/{str(i).zfill(3)}.png", grayscale_cam * 255)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.figure(figsize=(15,15))

# The first 16 input images, each with its own colour bar.
for i in range(16):
    plt.subplot( 4, 4, i + 1)
    plt.imshow(img[i, 0, :, :].cpu().numpy())
    plt.colorbar()    

In [None]:
# Fifth input image (index 4) on its own.
plt.figure()
plt.imshow(img[4, 0, :, :].cpu().numpy())
plt.colorbar()

In [None]:
# Distinct values of the most recent (thresholded) CAM map.
np.unique(grayscale_cam)

In [None]:
# Histogram of the CAM values above 0.01 (the zeroed background is excluded).
counts, bins = np.histogram(grayscale_cam[grayscale_cam > 0.01])
plt.hist(bins[:-1], bins, weights=counts)

![SegmentLocal](https://i-download.imgflip.com/ighss.gif "segment")