In [1]:
!nvidia-smi

Sun Feb 19 14:38:39 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0  On |                    0 |
| N/A   42C    P0    94W / 300W |   8716MiB / 32505MiB |     98%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   43C    P0    98W / 300W |   6501MiB / 32508MiB |     80%      Default |
|       

In [2]:
import sys
sys.path.append('../')

In [3]:
import os
import sys
import json
import time
import rich
import numpy as np
import pickle
import wandb
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from configs.finetune import FinetuneConfig
from tasks.classification import Classification

from models.backbone.base import calculate_out_features
from models.backbone.densenet import DenseNetBackbone
from models.backbone.resnet import build_resnet_backbone
from models.head.projector import MLPHead
from models.head.classifier import LinearClassifier

from datasets.brain import BrainProcessor, Brain, BrainMoCo
from datasets.transforms import make_transforms, compute_statistics

from utils.logging import get_rich_logger
from utils.gpu import set_gpu

from easydict import EasyDict as edict
from torch.utils.data import DataLoader, Subset

import matplotlib.pyplot as plt
from matplotlib import colors
import seaborn as sns

import nibabel as nib
from skimage.transform import resize

from copy import deepcopy

In [4]:
# One entry per experiment: (pre-training run directory, fine-tuning run directory).
# Both elements are timestamp strings that are interpolated into the checkpoint path
# `.../resnet/<first>/finetune/<second>/` further down in the notebook.
hashs = [
    ("2022-07-02_08-00-31", "2022-07-03_13-41-32"),
    ("2022-07-02_08-00-57", "2022-07-03_13-37-29"),
    ("2022-07-02_09-38-52", "2022-07-03_13-33-23"),
    ("2022-07-02_09-40-42", "2022-07-03_13-29-10"),
    ("2022-07-02_11-17-38", "2022-07-03_13-25-05"),
    ("2022-07-02_11-20-21", "2022-07-03_13-21-00"),
    ("2022-07-02_17-15-14", "2022-07-03_13-16-54"),
    ("2022-07-02_17-15-34", "2022-07-03_13-12-44"),
    ("2022-07-02_18-53-46", "2022-07-03_13-08-35"),
    ("2022-07-02_18-54-27", "2022-07-03_13-04-32"),
]

In [5]:
# Runtime placement settings; both are copied onto the per-run `config`
# object that is handed to `set_gpu` in the Grad-CAM++ export cell.
gpus = ["3"]    # list of GPU ids, kept as strings
server = "dgx"  # name of the machine the notebook runs on

In [6]:
from monai.visualize import (
    GradCAMpp,
    OcclusionSensitivity,
    SmoothGrad,
    GuidedBackpropGrad,
    GuidedBackpropSmoothGrad,
)

In [7]:
class ModelViz(nn.Module):
    """Chain a backbone and a classifier into a single module for visualisation.

    Wrapping both networks in one ``nn.Module`` lets attribution tools address
    inner layers by dotted name (e.g. ``backbone.<layer>``) and obtain logits
    from a single forward call.

    Args:
        backbone: feature extractor; stored as ``self.backbone``.
        classifier: head applied to the backbone output; stored as ``self.classifier``.
        local_rank: device (or device index) both sub-networks are moved to.
    """

    def __init__(self, backbone, classifier, local_rank):
        super().__init__()
        self.local_rank = local_rank
        self._build_model(backbone, classifier)

    def _build_model(self, backbone, classifier):
        """Register both sub-networks, move them to ``local_rank`` and switch them to eval mode."""
        self.backbone = backbone
        self.classifier = classifier

        # Device placement first, then inference mode -- for both sub-networks.
        for submodule in (self.backbone, self.classifier):
            submodule.to(self.local_rank)
        for submodule in (self.backbone, self.classifier):
            submodule.eval()

    def forward(self, x):
        """Return the classifier logits for input ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

In [None]:
# Export per-sample Grad-CAM++ figures.
#
# For every (pre-training, fine-tuning) run in `hashs` and every target layer:
#   1. rebuild the fine-tuned backbone + classifier from the checkpoint,
#   2. rebuild the train / test datasets from the saved fine-tuning config,
#   3. for each correctly classified sample, compute one Grad-CAM++ map and save
#      two figures (the map and its 1 - map complement) next to the PET slices.
#
# The objects created in the last iteration (`model`, `gcam`, `optimizer`,
# `train_set`, `test_set`, `train_loader`, `test_loader`, `local_rank`) stay in
# the notebook namespace and are reused by the averaging cells below.

# Keys copied from each run's saved `configs.json` onto the runtime config.
FINETUNE_CONFIG_KEYS = [
    # data_parser
    'data_type', 'root', 'data_info', 'mci_only', 'n_splits', 'n_cv',
    'image_size', 'small_kernel', 'random_state',
    'intensity', 'crop', 'crop_size', 'rotate', 'flip', 'affine', 'blur', 'blur_std', 'prob',
    # model_parser
    'backbone_type', 'init_features', 'growth_rate', 'block_config', 'bn_size', 'dropout_rate',
    'arch', 'no_max_pool',
    # train
    'batch_size',
    # moco / supmoco
    'alphas',
    # others
    'task', 'projector_dim'
]

# Slices drawn in every figure: index 72 along axes 0 and 1, index 90 along axis 2.
GCAM_SLICES = (np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])

local_rank = 0

for pretrain_hash, finetune_hash in hashs:
    for layer in ['layer1', 'layer2']:

        # ------------------------------------------------------------------
        # Configuration: notebook-level settings + values saved at fine-tuning
        # ------------------------------------------------------------------
        config = edict()
        config.server = server
        config.gpus = gpus

        run_dir = f'../checkpoints/pet-supmoco/resnet/{pretrain_hash}/finetune/{finetune_hash}'
        config.finetune_file = f'{run_dir}/ckpt.last.pth.tar'
        config_path = f'{run_dir}/configs.json'
        with open(config_path, 'rb') as fb:
            finetune_config = json.load(fb)

        # Only keys present in the saved file are copied; a key missing there
        # surfaces later as an AttributeError on `config`.
        for name in FINETUNE_CONFIG_KEYS:
            if name in finetune_config:
                setattr(config, name, finetune_config[name])

        # ------------------------------------------------------------------
        # Device / reproducibility
        # ------------------------------------------------------------------
        set_gpu(config)  # project helper -- presumably restricts the visible GPUs; confirm
        np.random.seed(config.random_state)
        torch.manual_seed(config.random_state)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.set_device(local_rank)

        # ------------------------------------------------------------------
        # Networks
        # ------------------------------------------------------------------
        if config.backbone_type == 'densenet':
            backbone = DenseNetBackbone(in_channels=1,
                                        init_features=config.init_features,
                                        growth_rate=config.growth_rate,
                                        block_config=config.block_config,
                                        bn_size=config.bn_size,
                                        dropout_rate=config.dropout_rate,
                                        semi=False)
            activation = True
        elif config.backbone_type == 'resnet':
            backbone = build_resnet_backbone(arch=config.arch,
                                             no_max_pool=config.no_max_pool,
                                             in_channels=1,
                                             semi=False)
            activation = False
        else:
            raise NotImplementedError

        if config.small_kernel:
            backbone._fix_first_conv()

        # The classifier input width depends on the spatial size fed to the backbone.
        input_size = config.crop_size if config.crop else config.image_size
        out_dim = calculate_out_features(backbone=backbone, in_channels=1, image_size=input_size)
        classifier = LinearClassifier(in_channels=out_dim, num_classes=2, activation=activation)

        backbone.load_weights_from_checkpoint(path=config.finetune_file, key='backbone')
        classifier.load_weights_from_checkpoint(path=config.finetune_file, key='classifier')

        # ------------------------------------------------------------------
        # Data
        # ------------------------------------------------------------------
        data_processor = BrainProcessor(root=config.root,
                                        data_info=config.data_info,
                                        data_type=config.data_type,
                                        mci_only=config.mci_only,
                                        random_state=config.random_state)
        datasets = data_processor.process(n_splits=config.n_splits, n_cv=config.n_cv)

        # Intensity normalisation: only 'minmax' needs statistics loaded from disk.
        assert config.intensity in [None, 'scale', 'minmax']
        min_max = (None, None)
        if config.intensity == 'minmax':
            # NOTE(review): pickle.load executes arbitrary code -- only use on trusted files.
            with open(os.path.join(config.root, 'labels/minmax.pkl'), 'rb') as fb:
                minmax_stats = pickle.load(fb)
            min_max = (minmax_stats[config.data_type]['min'], minmax_stats[config.data_type]['max'])

        train_transform, test_transform = make_transforms(image_size=config.image_size,
                                                          intensity=config.intensity,
                                                          min_max=min_max,
                                                          crop_size=config.crop_size,
                                                          rotate=config.rotate,
                                                          flip=config.flip,
                                                          affine=config.affine,
                                                          blur_std=config.blur_std,
                                                          prob=config.prob)

        # Both splits use the *test* transform, and samples are visited one at a time.
        train_set = Brain(dataset=datasets['train'], data_type=config.data_type, transform=test_transform)
        test_set = Brain(dataset=datasets['test'], data_type=config.data_type, transform=test_transform)

        train_loader = DataLoader(dataset=train_set, batch_size=1, drop_last=False)
        test_loader = DataLoader(dataset=test_set, batch_size=1, drop_last=False)

        # ------------------------------------------------------------------
        # Model wrapper + Grad-CAM++ on the requested backbone layer
        # ------------------------------------------------------------------
        model = ModelViz(backbone=backbone, classifier=classifier, local_rank=local_rank)
        gcam = GradCAMpp(model, f'backbone.{layer}')

        # The optimizer never steps; it is only used to clear parameter gradients.
        optimizer = torch.optim.AdamW(model.parameters())

        # ------------------------------------------------------------------
        # Save one pair of figures per correctly classified sample
        # ------------------------------------------------------------------
        for mode, dset, loader in zip(['train', 'test'], [train_set, test_set], [train_loader, test_loader]):

            path = f'gcam/{layer}/{pretrain_hash}-{finetune_hash}/{mode}'
            os.makedirs(path + '-converter', exist_ok=True)
            os.makedirs(path + '-nonconverter', exist_ok=True)

            for batch in tqdm.tqdm(loader):

                x = batch['x'].to(local_rank)
                idx = batch['idx'].item()
                label = batch['y'].item()

                logit = model(x).detach()

                # Skip misclassified samples.
                if label != logit.argmax().item():
                    continue

                # Softmax probability of the true class, used in the file name.
                confidence = "{:.3f}".format(logit.softmax(dim=1)[0, label].item())

                # The attribution map is the same for the plain and the reversed
                # figure, so the expensive Grad-CAM++ pass is run only once.
                optimizer.zero_grad()
                gcam_map = gcam(x).cpu().numpy()[0][0]
                if np.isnan(gcam_map).any():
                    continue

                status = 'nonconverter' if label == 0 else 'converter'

                # Original PET volume, used as background image and to build the mask.
                pet_file = dset.pet[idx]
                pet_id = pet_file.split('/')[-1].replace('.pkl', '')
                with open(pet_file, 'rb') as fb:
                    pet = pickle.load(fb)

                mask = pet <= 0  # voxels with non-positive PET intensity are blanked out

                for reverse in [True, False]:
                    heatmap = np.abs(1 - gcam_map) if reverse else gcam_map

                    # Resample to the PET grid so that `mask` can index the map
                    # (the stored volumes are presumably 145^3 -- confirm).
                    heatmap = resize(heatmap, pet.shape)
                    heatmap[mask] = np.nan

                    fig, axs = plt.subplots(3, 2, figsize=(10, 15))
                    for row, view in enumerate(GCAM_SLICES):
                        axs[row, 0].imshow(pet[view], cmap='binary')
                        axs[row, 1].imshow(heatmap[view], cmap='jet')
                    plt.savefig(
                        path + '-' + status + f'/{pet_id}-{confidence}-{reverse}.png',
                        dpi=300,
                        bbox_inches='tight'
                    )
                    plt.close(fig)

100%|█████████████████████████████████████████████████████████████████████████████████| 565/565 [45:16<00:00,  4.81s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 56/56 [03:58<00:00,  4.25s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 565/565 [45:42<00:00,  4.85s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 56/56 [04:19<00:00,  4.63s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 551/551 [43:51<00:00,  4.78s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 70/70 [05:20<00:00,  4.58s/it]
100%|█████████████████████████████████████████████████████████████████████████████████| 551/551 [44:15<00:00,  4.82s/it]
100%|███████████████████████████████████████████████████████████████████████████████████| 70/70 [04:44<00:00,  4.07s/it]
100%|███████████████████████████

In [None]:
# save average-train
converter = []
nonconverter = []

for batch in tqdm.tqdm(train_loader):

    x = batch['x'].to(local_rank)
    logit = model(x)    
    idx = batch['idx'].item()

    # correctly classified
    if batch['y'].item() == logit.argmax().item():
        gcam_map = gcam(x)
        gcam_map = gcam_map.cpu().numpy()[0][0]
        # gcam_map = np.abs(1 - gcam_map)
        # gcam_map = np.log(1 + gcam_map)

        if not np.isnan(gcam_map).any():

            pet_file = train_set.pet[idx]
            pet_id = pet_file.split('/')[-1].replace('.pkl', '')
            with open(pet_file, 'rb') as fb:
                pet = pickle.load(fb)

            mask = pet <= 0

            gcam_map = resize(gcam_map, [145, 145, 145])
            # gcam_map[mask] = np.nan
            
            # status
            if batch['y'].item() == 0:
                nonconverter.append(gcam_map)
                nonconverter_pet = pet
                nonconverter_mask = mask
            else:
                converter.append(gcam_map)
                converter_pet = pet
                converter_mask = mask
    optimizer.zero_grad()

In [None]:
# Voxel-wise mean of the collected converter maps, blanked where the last
# converter PET volume is non-positive.
a = np.array(converter).mean(axis=0)
a[converter_mask] = np.nan

# Left column: PET background, right column: mean map -- one row per view.
fig, axs = plt.subplots(3, 2, figsize=(10, 15))
for row, view in enumerate((np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])):
    axs[row, 0].imshow(converter_pet[view], cmap='binary')
    axs[row, 1].imshow(a[view], cmap='jet')
plt.show()

In [None]:
# Voxel-wise mean of the collected nonconverter maps, blanked where the last
# nonconverter PET volume is non-positive.
a = np.array(nonconverter).mean(axis=0)
a[nonconverter_mask] = np.nan

# Left column: PET background, right column: mean map -- one row per view.
fig, axs = plt.subplots(3, 2, figsize=(10, 15))
for row, view in enumerate((np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])):
    axs[row, 0].imshow(nonconverter_pet[view], cmap='binary')
    axs[row, 1].imshow(a[view], cmap='jet')
plt.show()

In [None]:
# Voxel-wise mean over the maps of both classes together. The nonconverter
# PET volume and mask serve as background / mask for the pooled map.
a = np.array(converter)
b = np.array(nonconverter)
c = np.concatenate([a, b], axis=0).mean(axis=0)
c[nonconverter_mask] = np.nan

# Left column: PET background, right column: pooled mean map -- one row per view.
fig, axs = plt.subplots(3, 2, figsize=(10, 15))
for row, view in enumerate((np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])):
    axs[row, 0].imshow(nonconverter_pet[view], cmap='binary')
    axs[row, 1].imshow(c[view], cmap='jet')
plt.show()

In [None]:
# save average-test
# Gather the Grad-CAM++ volumes of all correctly classified test samples,
# split by label (0 -> nonconverter, otherwise converter). Relies on `model`,
# `gcam`, `optimizer`, `test_set`, `test_loader` and `local_rank` left behind
# by the export cell above; overwrites the lists built for the training split.
converter, nonconverter = [], []

for batch in tqdm.tqdm(test_loader):

    x = batch['x'].to(local_rank)
    logit = model(x)
    idx = batch['idx'].item()

    # Stage 1: keep only correctly classified samples.
    usable = batch['y'].item() == logit.argmax().item()

    # Stage 2: keep only maps without NaN entries.
    if usable:
        gcam_map = gcam(x)
        gcam_map = gcam_map.cpu().numpy()[0][0]
        usable = not np.isnan(gcam_map).any()

    if usable:
        pet_file = test_set.pet[idx]
        pet_id = pet_file.split('/')[-1].replace('.pkl', '')
        with open(pet_file, 'rb') as fb:
            pet = pickle.load(fb)

        mask = pet <= 0

        # Masking is deferred to the plotting cells; here the map is only resampled.
        gcam_map = resize(gcam_map, [145, 145, 145])

        # The PET volume and mask of the most recent sample of each class are
        # remembered and used as background / mask when plotting the averages.
        if batch['y'].item() == 0:
            nonconverter.append(gcam_map)
            nonconverter_pet, nonconverter_mask = pet, mask
        else:
            converter.append(gcam_map)
            converter_pet, converter_mask = pet, mask

    # Clear parameter gradients after every sample, used or not.
    optimizer.zero_grad()

In [None]:
# Voxel-wise mean of the collected converter maps, blanked where the last
# converter PET volume is non-positive.
a = np.array(converter).mean(axis=0)
a[converter_mask] = np.nan

# Left column: PET background, right column: mean map -- one row per view.
fig, axs = plt.subplots(3, 2, figsize=(10, 15))
for row, view in enumerate((np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])):
    axs[row, 0].imshow(converter_pet[view], cmap='binary')
    axs[row, 1].imshow(a[view], cmap='jet')
plt.show()

In [None]:
# Voxel-wise mean of the collected nonconverter maps, blanked where the last
# nonconverter PET volume is non-positive.
a = np.array(nonconverter).mean(axis=0)
a[nonconverter_mask] = np.nan

# Left column: PET background, right column: mean map -- one row per view.
fig, axs = plt.subplots(3, 2, figsize=(10, 15))
for row, view in enumerate((np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])):
    axs[row, 0].imshow(nonconverter_pet[view], cmap='binary')
    axs[row, 1].imshow(a[view], cmap='jet')
plt.show()

In [None]:
# Voxel-wise mean over the maps of both classes together. The nonconverter
# PET volume and mask serve as background / mask for the pooled map.
a = np.array(converter)
b = np.array(nonconverter)
c = np.concatenate([a, b], axis=0).mean(axis=0)
c[nonconverter_mask] = np.nan

# Left column: PET background, right column: pooled mean map -- one row per view.
fig, axs = plt.subplots(3, 2, figsize=(10, 15))
for row, view in enumerate((np.s_[72, :, :], np.s_[:, 72, :], np.s_[:, :, 90])):
    axs[row, 0].imshow(nonconverter_pet[view], cmap='binary')
    axs[row, 1].imshow(c[view], cmap='jet')
plt.show()