In [1]:
import os
import time
import torch
import warnings
import numpy as np
import nibabel as nib
from time import strftime
from functools import partial
from copy import deepcopy
from torch.utils.data import DataLoader
import torch.nn.functional as F
from captum.attr import IntegratedGradients, InputXGradient, DeepLift
from captum.attr import Deconvolution, Occlusion
from captum.attr import FeatureAblation, FeaturePermutation
from captum.attr import GuidedBackprop, GuidedGradCam, LayerGradCam
from captum.attr import DeepLiftShap, GradientShap, KernelShap
from models import SimpleResNet
from datasets import ShapesDataset, LocationDataset, RotationDataset, ScaleDataset
from datasets import ContrastDataset
from utils import color_codes, time_to_string

# Silence all warnings so the progress logs below stay readable.
warnings.simplefilter("ignore")

# Seed used by run_experiments to draw the per-model seeds and to build the
# shared testing set.
master_seed = 42
# Folder holding model checkpoints and saliency maps. Set the
# SALIENCY_DATA_FOLDER environment variable to use another location without
# editing the notebook; the original path remains the default.
data_folder = os.environ.get(
    'SALIENCY_DATA_FOLDER', '/home/mariano/data/Saliency/'
)

def attribution(x, y, attr_m, *args, target=0, **kwargs):
    """Compute a saliency map for ``x`` with a captum attribution object.

    :param x: input tensor of shape (N, C, *spatial); the result is squeezed,
        so it is meant for a single image (N == 1) — assumed, TODO confirm.
    :param y: ground-truth label of ``x`` (a single value or a one-element
        array). When ``y != 1`` the map is negated, presumably so that
        positive values always support the true class of the single-output
        network.
    :param attr_m: attribution object exposing ``attribute(...)``.
    :param args: extra positional arguments forwarded to ``attribute``.
    :param target: output index to attribute (default 0).
    :param kwargs: extra keyword arguments forwarded to ``attribute``.
    :return: numpy array with the (squeezed) attribution. If the method
        returned a map with a different shape (e.g. layer-level GradCAM), it
        is resized to the spatial size of ``x`` first.
    :raises ValueError: if a map needing resizing has an unsupported rank.
    """
    attr = attr_m.attribute(x, *args, target=target, **kwargs)
    if x.shape != attr.shape:
        # Pick the interpolation mode matching the number of spatial dims
        # instead of assuming 3D volumes ('trilinear' only accepts 5D input).
        modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
        spatial_dims = attr.dim() - 2
        if spatial_dims not in modes:
            raise ValueError(
                'Cannot resize an attribution map with {:d} spatial '
                'dimensions'.format(spatial_dims)
            )
        attr = F.interpolate(attr, size=x.shape[2:], mode=modes[spatial_dims])
    attr_map = attr.squeeze().detach().cpu().numpy()
    if y != 1:
        attr_map = -attr_map
    return attr_map
    
def _model_path(weight_path, property_name, seed):
    """Return the checkpoint path of the model trained with ``seed``."""
    return os.path.join(
        weight_path, '{:}_s{:05d}.pt'.format(property_name, seed)
    )


def _saliency_methods(net):
    """Return the ``(factory, name)`` pairs of attribution methods to run.

    Each factory is called as ``factory(net)`` to build an attribution
    object. The GradCAM variants are bound to ``net.extractor.down[1].conv``
    of *this* ``net`` instance, so the factories must be applied to the same
    object; checkpoints are loaded into it with ``net.load_model``, which is
    assumed to update the weights in place — TODO confirm.
    """
    gradcam_layer = net.extractor.down[1].conv
    return [
        # 'Input'-based
        (IntegratedGradients, 'IntegratedGradients'),
        (InputXGradient, 'InputXGradient'),
        (DeepLift, 'DeepLift'),
        # Inverse
        (Deconvolution, 'Deconvolution'),
        # Perturbation-based (currently disabled)
        # (partial(Occlusion, sliding_window_shapes=(8, 8, 8)), 'Occlusion'),
        # (FeaturePermutation, 'FeaturePermutation'),
        # GradCAM-related
        (partial(LayerGradCam, layer=gradcam_layer), 'LayerGradCam'),
        (partial(GuidedGradCam, layer=gradcam_layer), 'GuidedGradCam'),
        (GuidedBackprop, 'GuidedBackprop'),
    ]


def _generate_saliency_maps(
    net, methods, seeds, testing_set, property_name, master_seed,
    weight_path, saliency_path
):
    """Save every test case and its per-seed attribution maps as NIfTI.

    For each case of ``testing_set`` the image and mask are written to
    ``saliency_path``. Then, for each ``(factory, name)`` in ``methods``, the
    checkpoint of every seed is loaded into ``net`` and its attribution map
    is computed. The maps are stacked along a new last axis into one file,
    and each seed's 0/1 prediction is stored in the header's ``descrip``
    field. Map files that already exist are skipped, so an interrupted run
    resumes where it stopped.
    """
    n_cases = len(testing_set)
    init_start = time.time()
    for i, (x, (y, mask)) in enumerate(DataLoader(testing_set, 1)):
        y = y.numpy().astype(np.uint8)
        label = y.tolist()[0]
        x_cuda = x.to(net.device)
        x_cuda.requires_grad_()
        # Image
        image_path = os.path.join(
            saliency_path, '{:}-{:04d}_s{:05d}_c{:d}_image.nii.gz'.format(
                property_name, i, master_seed, label
            )
        )
        nib.Nifti1Image(np.squeeze(x.numpy()), np.eye(4)).to_filename(image_path)
        # Mask
        mask_path = os.path.join(
            saliency_path, '{:}-{:04d}_s{:05d}_c{:d}_mask.nii.gz'.format(
                property_name, i, master_seed, label
            )
        )
        mask_nii = nib.Nifti1Image(
            np.squeeze(mask.numpy()).astype(np.uint8), np.eye(4)
        )
        mask_nii.to_filename(mask_path)
        for attr_m, attr_name in methods:
            map_path = os.path.join(
                saliency_path, '{:}-{:04d}_{:}_c{:d}.nii.gz'.format(
                    property_name, i, attr_name, label
                )
            )
            if os.path.isfile(map_path):
                # Finished by a previous run. Maps are only renamed to this
                # path once fully written, so an existing file is complete.
                continue
            maps = []
            predictions = []
            for test_n, seed in enumerate(seeds):
                time_elapsed = time.time() - init_start
                eta = (n_cases - (i + 1)) * time_elapsed / (i + 1)
                net.load_model(_model_path(weight_path, property_name, seed))
                net.eval()
                # Blank line first, presumably to wipe leftovers of a longer
                # previous progress line (Jupyter overlays '\r' lines rather
                # than erasing them).
                print(' ' * 599, end='\r')
                print(
                    '\033[KGenerating {:} map '
                    '(case {:d}/{:d} | seed {:05d} [{:02d}/{:02d}]) {:} ETA {:}'.format(
                        attr_name, i + 1, n_cases,
                        seed, test_n + 1, len(seeds),
                        time_to_string(time_elapsed),
                        time_to_string(eta),
                    ), end='\r'
                )
                maps.append(attribution(x_cuda, y, attr_m(net)))
                with torch.no_grad():
                    logit = net(x_cuda).squeeze().item()
                # NOTE: '>= 0' here vs '> 0' in the accuracy loop; they only
                # differ for a logit of exactly 0.
                predictions.append(int(logit >= 0))
            maps_array = np.stack(maps, axis=-1)
            hdr = deepcopy(mask_nii.header)
            # The copied header comes from the uint8 mask image. nibabel keeps
            # a supplied header's dtype, so without this the float maps would
            # be saved rescaled to 8-bit integers.
            hdr.set_data_dtype(maps_array.dtype)
            hdr['descrip'] = 'Predictions: {:}'.format(
                ' '.join(str(p) for p in predictions)
            ).encode()
            # Write under a temporary name and rename, so an interrupted write
            # never leaves a truncated file that the resume check would skip.
            partial_path = map_path[:-len('.nii.gz')] + '.partial.nii.gz'
            nib.Nifti1Image(maps_array, None, hdr).to_filename(partial_path)
            os.replace(partial_path, map_path)


def run_experiments(
    master_seed, property_name, dataset, network, weight_path,
    test_samples=100, train_samples=10000, val_samples=1000,
    im_size=(64, 64, 64), n_seeds=30
):
    """Train one model per seed for a property and compute saliency maps.

    ``n_seeds`` seeds are drawn from ``master_seed``. For each seed:
    - a fresh ``network`` is built;
    - it is loaded from ``<weight_path>/<property_name>_s<seed>.pt``, or
      trained on newly generated training/validation sets and saved there;
    - its accuracy on a testing set shared by all seeds (built with
      ``master_seed``) is printed.
    Afterwards, every attribution method in ``_saliency_methods`` is run with
    every seed's model on each test case. The results are written as NIfTI
    files under ``<weight_path>/<Property_name>``.

    :param master_seed: seed for the seed list and for the testing set.
    :param property_name: property being classified; used in file names and
        log messages.
    :param dataset: dataset class, called as
        ``dataset(im_size=..., n_samples=...[, seed=...])``.
    :param network: model class, called as ``network(n_images=1)``.
    :param weight_path: folder for model checkpoints and saliency maps.
    :param test_samples: ``n_samples`` passed to the testing set.
    :param train_samples: ``n_samples`` passed to each training set.
    :param val_samples: ``n_samples`` passed to each validation set.
    :param im_size: image size passed to every dataset.
    :param n_seeds: number of independently trained models.
    """
    np.random.seed(master_seed)
    seeds = np.random.randint(0, 100000, n_seeds)
    c = color_codes()
    saliency_path = os.path.join(weight_path, property_name.capitalize())
    # Also creates weight_path if missing; no-op if the folder exists.
    os.makedirs(saliency_path, exist_ok=True)
    print(
        '{:}[{:}] {:}Creating {:}testing{:} dataset '
        'for {:} classification{:}'.format(
            c['clr'] + c['c'], strftime("%m/%d/%Y - %H:%M:%S"), c['g'], c['nc'] + c['y'],
            c['nc'] + c['g'], c['b'] + property_name + c['nc'] + c['g'], c['nc']
        )
    )
    testing_set = dataset(im_size=im_size, n_samples=test_samples, seed=master_seed)
    train_batch = 20
    test_batch = 50
    epochs = 2
    patience = 2
    # Last model built in the loop; reused below to compute the saliency maps.
    net = None
    for test_n, seed in enumerate(seeds):
        print(
            '{:}[{:}] {:}Starting experiment {:}(seed {:05d}){:} [{:02d}/{:02d}] '
            '{:}for {:} classification{:}'.format(
                c['clr'] + c['c'], strftime("%m/%d/%Y - %H:%M:%S"), c['g'],
                c['nc'] + c['y'], seed, c['nc'] + c['c'], test_n + 1, len(seeds),
                c['nc'] + c['g'], c['b'] + property_name + c['nc'] + c['g'], c['nc']
            )
        )
        np.random.seed(seed)
        torch.manual_seed(seed)
        net = network(n_images=1)
        net.init = False
        training_set = dataset(im_size=im_size, n_samples=train_samples)
        validation_set = dataset(im_size=im_size, n_samples=val_samples)
        training_loader = DataLoader(
            training_set, train_batch, True
        )
        validation_loader = DataLoader(
            validation_set, test_batch
        )
        model_path = _model_path(weight_path, property_name, seed)
        try:
            net.load_model(model_path)
        except IOError:
            net.fit(training_loader, validation_loader, epochs=epochs, patience=patience)
            net.save_model(model_path)

        print(
            '{:}[{:}] {:}Testing {:}(seed {:05d}){:} [{:02d}/{:02d}] '
            '{:}for {:} classification <{:04d}/{:04d} samples>{:}'.format(
                c['clr'] + c['c'], strftime("%m/%d/%Y - %H:%M:%S"), c['g'],
                c['nc'] + c['y'], seed, c['nc'] + c['c'], test_n + 1, len(seeds),
                c['nc'] + c['g'], c['b'] + property_name + c['nc'] + c['g'],
                len(training_set), len(validation_set), c['nc']
            )
        )

        acc = 0
        for x, (y, _) in DataLoader(testing_set, test_batch):
            pred_y = net.inference(x.numpy(), False)
            pred_y = (pred_y > 0).squeeze().astype(np.uint8)
            y = y.numpy().astype(np.uint8)
            acc += np.sum(y == pred_y) / len(testing_set)
        print(
            '{:}[{:}] {:}Accuracy {:}(seed {:05d}){:} [{:02d}/{:02d}] {:}'
            '{:5.3f}{:}'.format(
                c['clr'] + c['c'], strftime("%m/%d/%Y - %H:%M:%S"), c['g'],
                c['nc'] + c['y'], seed, c['nc'] + c['c'], test_n + 1, len(seeds),
                c['nc'] + c['b'], acc, c['nc']
            )
        )
    print(
        '{:}[{:}] {:}Saliency maps{:}'.format(
            c['clr'] + c['c'], strftime("%m/%d/%Y - %H:%M:%S"), c['g'], c['nc']
        )
    )
    # With n_seeds == 0 there are no trained models and nothing to explain.
    if net is not None:
        _generate_saliency_maps(
            net, _saliency_methods(net), seeds, testing_set, property_name,
            master_seed, weight_path, saliency_path
        )

    print(
        '{:}[{:}] {:}Experiments for {:} classification finished{:}'.format(
            c['clr'] + c['c'], strftime("%m/%d/%Y - %H:%M:%S"), c['r'],
            c['b'] + property_name + c['nc'] + c['r'], c['nc']
        )
    )

In [2]:
# Shape classification experiments.
run_experiments(
    master_seed=master_seed,
    property_name='shape',
    dataset=ShapesDataset,
    network=SimpleResNet,
    weight_path=data_folder,
)

[K[36m[04/21/2023 - 09:39:45] [32mCreating [0m[33mtesting[0m[32m dataset for [1mshape[0m[32m classification[0m
[K[36m[04/21/2023 - 09:39:53] [32mStarting experiment [0m[33m(seed 15795)[0m[36m [01/30] [0m[32mfor [1mshape[0m[32m classification[0m                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
[K[36m[04/21/2023 - 09:39:57] [32mTesting [0m[33m(seed 15795)[K[36m [01/30] [0m[32mfor [1mshape[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:39:58] [32mAccuracy [0m[33m(seed 15795)[K[36m [01/30] [0m[1m1.000[0m
[K[36m[04/21/2023 -

[K[36m[04/21/2023 - 09:40:18] [32mAccuracy [0m[33m(seed 05311)[K[36m [20/30] [0m[1m1.000[0m
[K[36m[04/21/2023 - 09:40:18] [32mStarting experiment [0m[33m(seed 83104)[0m[36m [21/30] [0m[32mfor [1mshape[0m[32m classification[0m
[K[36m[04/21/2023 - 09:40:18] [32mTesting [0m[33m(seed 83104)[K[36m [21/30] [0m[32mfor [1mshape[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:40:19] [32mAccuracy [0m[33m(seed 83104)[K[36m [21/30] [0m[1m1.000[0m
[K[36m[04/21/2023 - 09:40:19] [32mStarting experiment [0m[33m(seed 53707)[0m[36m [22/30] [0m[32mfor [1mshape[0m[32m classification[0m
[K[36m[04/21/2023 - 09:40:19] [32mTesting [0m[33m(seed 53707)[K[36m [22/30] [0m[32mfor [1mshape[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:40:20] [32mAccuracy [0m[33m(seed 53707)[K[36m [22/30] [0m[1m1.000[0m
[K[36m[04/21/2023 - 09:40:20] [32mStarting experiment [0m[33m(seed 85305)[0m[36m [2

In [3]:
# Location classification experiments.
run_experiments(
    master_seed=master_seed,
    property_name='location',
    dataset=LocationDataset,
    network=SimpleResNet,
    weight_path=data_folder,
)

[K[36m[04/21/2023 - 09:41:37] [32mCreating [0m[33mtesting[0m[32m dataset for [1mlocation[0m[32m classification[0m
[K[36m[04/21/2023 - 09:41:45] [32mStarting experiment [0m[33m(seed 15795)[0m[36m [01/30] [0m[32mfor [1mlocation[0m[32m classification[0m                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
[K[36m[04/21/2023 - 09:41:45] [32mTesting [0m[33m(seed 15795)[K[36m [01/30] [0m[32mfor [1mlocation[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:41:46] [32mAccuracy [0m[33m(seed 15795)[K[36m [01/30] [0m[1m1.000[0m
[K[36m[04/21/

[K[36m[04/21/2023 - 09:42:05] [32mAccuracy [0m[33m(seed 67969)[K[36m [19/30] [0m[1m1.000[0m
[K[36m[04/21/2023 - 09:42:05] [32mStarting experiment [0m[33m(seed 05311)[0m[36m [20/30] [0m[32mfor [1mlocation[0m[32m classification[0m
[K[36m[04/21/2023 - 09:42:05] [32mTesting [0m[33m(seed 05311)[K[36m [20/30] [0m[32mfor [1mlocation[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:42:06] [32mAccuracy [0m[33m(seed 05311)[K[36m [20/30] [0m[1m1.000[0m
[K[36m[04/21/2023 - 09:42:06] [32mStarting experiment [0m[33m(seed 83104)[0m[36m [21/30] [0m[32mfor [1mlocation[0m[32m classification[0m
[K[36m[04/21/2023 - 09:42:06] [32mTesting [0m[33m(seed 83104)[K[36m [21/30] [0m[32mfor [1mlocation[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:42:07] [32mAccuracy [0m[33m(seed 83104)[K[36m [21/30] [0m[1m1.000[0m
[K[36m[04/21/2023 - 09:42:07] [32mStarting experiment [0m[33m(seed 53707)

In [4]:
# Rotation classification experiments.
run_experiments(
    master_seed=master_seed,
    property_name='rotation',
    dataset=RotationDataset,
    network=SimpleResNet,
    weight_path=data_folder,
)

[K[36m[04/21/2023 - 09:43:39] [32mCreating [0m[33mtesting[0m[32m dataset for [1mrotation[0m[32m classification[0m
[K[36m[04/21/2023 - 09:43:48] [32mStarting experiment [0m[33m(seed 15795)[0m[36m [01/30] [0m[32mfor [1mrotation[0m[32m classification[0m                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
[K[36m[04/21/2023 - 09:43:48] [32mTesting [0m[33m(seed 15795)[K[36m [01/30] [0m[32mfor [1mrotation[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:43:49] [32mAccuracy [0m[33m(seed 15795)[K[36m [01/30] [0m[1m0.975[0m
[K[36m[04/21/

[K[36m[04/21/2023 - 09:44:09] [32mAccuracy [0m[33m(seed 67969)[K[36m [19/30] [0m[1m0.970[0m
[K[36m[04/21/2023 - 09:44:09] [32mStarting experiment [0m[33m(seed 05311)[0m[36m [20/30] [0m[32mfor [1mrotation[0m[32m classification[0m
[K[36m[04/21/2023 - 09:44:09] [32mTesting [0m[33m(seed 05311)[K[36m [20/30] [0m[32mfor [1mrotation[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:44:10] [32mAccuracy [0m[33m(seed 05311)[K[36m [20/30] [0m[1m0.985[0m
[K[36m[04/21/2023 - 09:44:10] [32mStarting experiment [0m[33m(seed 83104)[0m[36m [21/30] [0m[32mfor [1mrotation[0m[32m classification[0m
[K[36m[04/21/2023 - 09:44:10] [32mTesting [0m[33m(seed 83104)[K[36m [21/30] [0m[32mfor [1mrotation[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:44:11] [32mAccuracy [0m[33m(seed 83104)[K[36m [21/30] [0m[1m0.980[0m
[K[36m[04/21/2023 - 09:44:11] [32mStarting experiment [0m[33m(seed 53707)

In [None]:
# Scale classification experiments.
run_experiments(
    master_seed=master_seed,
    property_name='scale',
    dataset=ScaleDataset,
    network=SimpleResNet,
    weight_path=data_folder,
)

[K[36m[04/21/2023 - 09:45:34] [32mCreating [0m[33mtesting[0m[32m dataset for [1mscale[0m[32m classification[0m
[K[36m[04/21/2023 - 09:45:43] [32mStarting experiment [0m[33m(seed 15795)[0m[36m [01/30] [0m[32mfor [1mscale[0m[32m classification[0m                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
[K[36m[04/21/2023 - 09:45:43] [32mTesting [0m[33m(seed 15795)[K[36m [01/30] [0m[32mfor [1mscale[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 09:45:44] [32mAccuracy [0m[33m(seed 15795)[K[36m [01/30] [0m[1m1.000[0m
[K[36m[04/21/2023 -

[K           Epoch 001 | [32m 0.0305[0m |  0.0000 | [36m  0.0178[0m |   0.0000 |   0.0000 | [36m  0.0035[0m | 0.000 | 18m 56s                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
Training finished in 2 epochs (37m 43s) with minimum loss = 0.000000 (epoch 0)
[K[36m[04/21/2023 - 10:23:44] [32mTesting [0m[33m(seed 62955)[K[36m [17/30] [0m[32mfor [1mscale[0m[32m classification <20000/2000 samples>[0m
[K[36m[04/21/2023 - 10:23:45] [32mAccuracy [0m[33m(seed 62955)[K[36m [17/30] [0m[1m0.990[0m
[K[36m[04/21/2023 - 10:23:45] [32mStarting experiment [0m[3

[K[36m[04/21/2023 - 12:51:47] [32mAccuracy [0m[33m(seed 83104)[K[36m [21/30] [0m[1m0.995[0m
[K[36m[04/21/2023 - 12:51:47] [32mStarting experiment [0m[33m(seed 53707)[0m[36m [22/30] [0m[32mfor [1mscale[0m[32m classification[0m
[K           Epoch num |  train  |   val   |   xent   |    fn    |    fp    |   acc    |  drp  |
           ----------|---------|---------|----------|----------|----------|----------|-------|
[K           [32mEpoch 000[0m | [32m 0.0500[0m | [32m 0.0000[0m | [36m  0.0220[0m | [36m  0.0000[0m | [36m  0.0000[0m | [36m  0.0060[0m | 0.000 | 18m 25s                                                                                                                                                                                                                                                                                                                                                                                                      

[K           [32mEpoch 000[0m | [32m 0.0431[0m | [32m 0.0000[0m | [36m  0.0512[0m | [36m  0.0000[0m | [36m  0.0000[0m | [36m  0.0105[0m | 0.000 | 18m 31s                                                                                                                                                                                                                                                                                                                                                                                                                                               
[K[0mEpoch 001 (942/1000 - 94.20%) [███████████████████████  ] train_loss 0.092040 (0.024870) 16m 16s / ETA 1m 0s[0m                                                                                                                                                                                                                                                                                         

In [None]:
# Contrast classification experiments.
run_experiments(
    master_seed=master_seed,
    property_name='contrast',
    dataset=ContrastDataset,
    network=SimpleResNet,
    weight_path=data_folder,
)