In [None]:
# Show the NVIDIA GPUs visible on this machine (IPython `!` shell escape) before one is chosen in `gpus` below.
!nvidia-smi

In [None]:
# Put the parent directory on the import path so the project packages imported below
# (configs, tasks, models, datasets, utils) can be resolved; presumably the repo root — confirm.
import sys
sys.path.append('../')

In [None]:
import glob
import json
import os
import pickle
import sys
import time
from copy import deepcopy

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import rich
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import wandb
from easydict import EasyDict as edict
from matplotlib import colors
from skimage.transform import resize
from torch.utils.data import DataLoader, Subset

from configs.finetune import FinetuneConfig
from tasks.classification import Classification

from models.backbone.base import calculate_out_features
from models.backbone.densenet import DenseNetBackbone
from models.backbone.resnet import build_resnet_backbone
from models.head.projector import MLPHead
from models.head.classifier import LinearClassifier

from datasets.brain import BrainProcessor, Brain, BrainMoCo
from datasets.transforms import make_transforms, compute_statistics

from utils.logging import get_rich_logger
from utils.gpu import set_gpu

In [None]:
# Runs to analyse, as (pre-training run dir, fine-tuning run dir) timestamp pairs.
# Each pair locates ../checkpoints/pet-supmoco/resnet/<hash[0]>/finetune/<hash[1]>/.
hashs =[("2022-07-02_08-00-31", "2022-07-03_13-41-32"),
        ("2022-07-02_08-00-57", "2022-07-03_13-37-29"),
        ("2022-07-02_09-38-52", "2022-07-03_13-33-23"),
        ("2022-07-02_09-40-42", "2022-07-03_13-29-10"),
        ("2022-07-02_11-17-38", "2022-07-03_13-25-05"),
        ("2022-07-02_11-20-21", "2022-07-03_13-21-00"),
        ("2022-07-02_17-15-14", "2022-07-03_13-16-54"),
        ("2022-07-02_17-15-34", "2022-07-03_13-12-44"),
        ("2022-07-02_18-53-46", "2022-07-03_13-08-35"),
        ("2022-07-02_18-54-27", "2022-07-03_13-04-32")]

In [None]:
# Execution target: GPU id list handed to set_gpu() and the server tag stored in each config.
gpus, server = ['3'], 'dgx'

In [None]:
# Default run used by the quick file listing in the next cell.
# NOTE: `hash` shadows the builtin of the same name; the main loop below reuses it as its loop variable.
hash = hashs[0]

In [None]:
# List the per-split mean-CAM pickles written for this run's layer1 maps (empty until the main loop has run).
run_tag = f'{hash[0]}-{hash[1]}'
glob.glob(f'./gcam/layer1/{run_tag}/*.pkl')

In [None]:
# For each (pre-training run, fine-tuning run) pair in `hashs`: rebuild the fine-tuned
# backbone + classifier, then for backbone.layer1..layer4 average the Grad-CAM++ maps of
# correctly classified samples, separately per label and per split, and pickle the averages
# to gcam/<layer>/<hash[0]>-<hash[1]>/<split>-{normal,abnormal}.pkl.
# NOTE(review): `ModelViz` and `GradCAMpp` are neither defined nor imported in any visible
# cell; they must already exist in the kernel for this cell to run.
for hash in hashs:
    print(hash)
    #########################################
    # Config: execution settings first, then selected fields of the configs.json saved
    # next to the fine-tuned checkpoint.
    config = edict()
    config.server = server
    config.gpus = gpus
    local_rank = 0

    # hash[0] names the pre-training run directory, hash[1] the fine-tuning run below it.
    run_dir = os.path.join('../checkpoints/pet-supmoco/resnet', hash[0], 'finetune', hash[1])
    config.finetune_file = os.path.join(run_dir, 'ckpt.last.pth.tar')
    with open(os.path.join(run_dir, 'configs.json'), 'rb') as fb:
        finetune_config = json.load(fb)

    finetune_config_names = [
        # data_parser
        'data_type', 'root', 'data_info', 'mci_only', 'n_splits', 'n_cv',
        'image_size', 'small_kernel', 'random_state',
        'intensity', 'crop', 'crop_size', 'rotate', 'flip', 'affine', 'blur', 'blur_std', 'prob',
        # model_parser
        'backbone_type', 'init_features', 'growth_rate', 'block_config', 'bn_size', 'dropout_rate',
        'arch', 'no_max_pool',
        # train
        'batch_size',
        # moco / supmoco
        'alphas',
        # others
        'task', 'projector_dim'
    ]
    # Copy only the keys present in the saved config; missing ones stay unset on `config`.
    for name in finetune_config_names:
        if name in finetune_config:
            setattr(config, name, finetune_config[name])

    #########################################
    # Device, seeds and TF32/cuDNN settings.
    set_gpu(config)
    np.random.seed(config.random_state)
    torch.manual_seed(config.random_state)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.set_device(local_rank)

    # Networks
    if config.backbone_type == 'densenet':
        backbone = DenseNetBackbone(in_channels=1,
                                    init_features=config.init_features,
                                    growth_rate=config.growth_rate,
                                    block_config=config.block_config,
                                    bn_size=config.bn_size,
                                    dropout_rate=config.dropout_rate,
                                    semi=False)
        activation = True
    elif config.backbone_type == 'resnet':
        backbone = build_resnet_backbone(arch=config.arch,
                                         no_max_pool=config.no_max_pool,
                                         in_channels=1,
                                         semi=False)
        activation = False
    else:
        raise NotImplementedError

    if config.small_kernel:
        backbone._fix_first_conv()

    # The classifier's input width depends on the spatial size actually fed to the backbone.
    input_size = config.crop_size if config.crop else config.image_size
    out_dim = calculate_out_features(backbone=backbone, in_channels=1, image_size=input_size)
    classifier = LinearClassifier(in_channels=out_dim, num_classes=2, activation=activation)

    backbone.load_weights_from_checkpoint(path=config.finetune_file, key='backbone')
    classifier.load_weights_from_checkpoint(path=config.finetune_file, key='classifier')

    # load finetune data
    data_processor = BrainProcessor(root=config.root,
                                    data_info=config.data_info,
                                    data_type=config.data_type,
                                    mci_only=config.mci_only,
                                    random_state=config.random_state)
    datasets = data_processor.process(n_splits=config.n_splits, n_cv=config.n_cv)

    # intensity normalization: only 'minmax' needs precomputed statistics.
    assert config.intensity in [None, 'scale', 'minmax']
    min_max = (None, None)
    if config.intensity == 'minmax':
        with open(os.path.join(config.root, 'labels/minmax.pkl'), 'rb') as fb:
            minmax_stats = pickle.load(fb)
        min_max = (minmax_stats[config.data_type]['min'], minmax_stats[config.data_type]['max'])
    elif config.intensity not in (None, 'scale'):
        raise NotImplementedError

    train_transform, test_transform = make_transforms(image_size=config.image_size,
                                                      intensity=config.intensity,
                                                      min_max=min_max,
                                                      crop_size=config.crop_size,
                                                      rotate=config.rotate,
                                                      flip=config.flip,
                                                      affine=config.affine,
                                                      blur_std=config.blur_std,
                                                      prob=config.prob)

    #########################################
    # Both splits use `test_transform` (presumably the non-augmenting one — confirm) so maps
    # are computed on unaugmented inputs.
    train_set = Brain(dataset=datasets['train'], data_type=config.data_type, transform=test_transform)
    test_set = Brain(dataset=datasets['test'], data_type=config.data_type, transform=test_transform)

    train_loader = DataLoader(dataset=train_set, batch_size=1, drop_last=False)
    test_loader = DataLoader(dataset=test_set, batch_size=1, drop_last=False)

    #######################################
    for layer in ['layer1', 'layer2', 'layer3', 'layer4']:
        model = ModelViz(backbone=backbone, classifier=classifier, local_rank=local_rank)
        # Inference mode, as in `get_map`: BatchNorm uses (and no longer overwrites) the
        # loaded running statistics and dropout is disabled. No-op if ModelViz already did it.
        model.backbone.eval()
        model.classifier.eval()
        gcam = GradCAMpp(model, f'backbone.{layer}')

        out_dir = f'gcam/{layer}/{hash[0]}-{hash[1]}'
        os.makedirs(out_dir, exist_ok=True)

        for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
            # Reset per split so '<mode>-*.pkl' averages only that split's samples
            # (previously the 'test' files silently also included every train sample).
            norm_cnt, ab_cnt = 0, 0
            avg_map_norm = np.zeros((64, 64, 64))
            avg_map_ab = np.zeros((64, 64, 64))

            for batch in tqdm.tqdm(loader):
                x = batch['x'].to(local_rank)
                label = batch['y'].item()

                logit = model(x)
                if label != logit.argmax().item():
                    continue  # only correctly classified samples contribute

                gcam_map = gcam(x)
                gcam_map = gcam_map.detach().cpu().numpy()[0][0]
                gcam_map = np.abs(1 - gcam_map)
                if np.isnan(gcam_map).any():
                    continue  # skip degenerate maps

                # Incremental (running) mean per label: 0 -> "normal", anything else -> "abnormal".
                if label == 0:
                    norm_cnt += 1
                    avg_map_norm += (gcam_map - avg_map_norm) / norm_cnt
                else:
                    ab_cnt += 1
                    avg_map_ab += (gcam_map - avg_map_ab) / ab_cnt

            #################################
            with open(os.path.join(out_dir, f'{mode}-normal.pkl'), 'wb') as fb:
                pickle.dump(avg_map_norm, fb)
            with open(os.path.join(out_dir, f'{mode}-abnormal.pkl'), 'wb') as fb:
                pickle.dump(avg_map_ab, fb)

## Filter correctly classified samples and save visualization maps

In [None]:
# Reduce the last input seen by the loop to a 3-D numpy volume (first sample, first channel).
# Guarded so re-running this cell does not try to re-index an already-converted array.
if torch.is_tensor(x):
    x = x[0, 0, :, :, :].detach().cpu().numpy()

In [None]:
# Panel 0: slice 32 of the input along axis 0. Panels 1-3: slice 32 along axes 0, 1, 2 with the
# mean CAM of correctly classified label!=0 samples (`avg_map_ab`) overlaid.
slice_idx = 32
fig, axs = plt.subplots(1, 4, figsize=(20, 5))
axs[0].imshow(x[slice_idx, :, :], cmap='binary')
for axis, ax in enumerate(axs[1:]):
    view = tuple(slice_idx if dim == axis else slice(None) for dim in range(3))
    ax.imshow(x[view], cmap='binary')
    ax.imshow(avg_map_ab[view], cmap='jet', alpha=0.2,)
plt.show()

In [None]:

# Grad-CAM++ extractors for the deeper stages plus a guided-backprop extractor.
# Target layers use the same 'backbone.<stage>' string form passed to GradCAMpp elsewhere in
# this notebook (previously a bare list ['layerN'], which lacked the 'backbone.' prefix).
# NOTE(review): GradCAMpp / GuidedBackpropGrad are not defined in any visible cell.
gcam_2 = GradCAMpp(model, 'backbone.layer2')
gcam_3 = GradCAMpp(model, 'backbone.layer3')
gcam_4 = GradCAMpp(model, 'backbone.layer4')
guided = GuidedBackpropGrad(model)

In [None]:
# Collect the indices of training samples the model classifies correctly.
# (The previous version had an empty `if` body, a SyntaxError that broke Run All.)
# train_loader uses batch_size=1 and no shuffling, so the enumerate index is the dataset index.
correct_train_idx = []
with torch.no_grad():
    model.backbone.eval()
    model.classifier.eval()

    for idx, batch in enumerate(train_loader):
        logit = model(batch['x'].float().to(local_rank))
        if (batch['y'] == logit.argmax().cpu()).item():
            correct_train_idx.append(idx)

In [None]:
def get_map(data, target_layer, model, local_rank):
    """Predict one sample and compute its Grad-CAM++ and guided-backprop maps.

    Args:
        data: a dataset item with keys 'x' (input tensor without batch dim; one is added
            here) and 'y' (true label, only used in the title).
        target_layer: layer name handed to GradCAMpp, e.g. 'backbone.layer1'.
        model: wrapper exposing `.backbone` and `.classifier`; both are put in eval mode.
        local_rank: device the input is moved to.

    Returns:
        (img, gcam_map, guided_map, title): the input with dim 1 moved last, the Grad-CAM++
        map for the predicted class, and the L2 norm over dim 1 of the guided-backprop
        gradients, all as squeezed numpy arrays, plus a "Pred/True" title string.

    NOTE(review): both extractors are constructed on every call; if they register hooks on
    `model`, those may accumulate across calls — confirm against their implementation.
    """
    def _as_volume(tensor):
        # detach -> host -> numpy, then drop singleton axes
        return np.squeeze(tensor.cpu().detach().numpy())

    for module in (model.backbone, model.classifier):
        module.eval()

    cam_extractor = GradCAMpp(model, target_layer)
    guided_extractor = GuidedBackpropGrad(model)

    # Prediction on the batched input.
    inputs = torch.as_tensor(data['x'][None].to(local_rank))
    logits = model(inputs)
    model.zero_grad()

    label = logits.argmax(dim=1).item()
    confidence = logits.softmax(dim=1)[0, label].item() * 100
    title = f"Pred: {label} ({confidence:.2f}%) | True: {data['y']}"

    channels_last = torch.moveaxis(inputs, 1, -1)

    # Grad-CAM++ for the predicted class.
    cam = cam_extractor(x=inputs, class_idx=label)[0]
    model.zero_grad()

    # Guided backprop, reduced to a per-voxel magnitude over dim 1.
    guided_grad = guided_extractor(inputs)
    model.zero_grad()
    guided_norm = torch.sum(guided_grad ** 2, dim=1) ** 0.5

    result = (_as_volume(channels_last), _as_volume(cam), _as_volume(guided_norm), title)
    model.zero_grad()
    return result

In [None]:
# Root of the non-skull-stripped PUP FBP outputs (machine-specific absolute path; adjust per host).
PUP_FBP_ROOT = '/raidWorkspace/mingu/Data/ADNI/PUP_FBP'

# For every stage and every test sample in `pmci_idx`, save a 3x4 figure: rows are slices along
# array axes 0/1/2; columns are stripped PET, CAM alone, non-stripped PET + CAM, MRI + CAM.
# NOTE(review): `pmci_idx` and `hyparam` (loc1/loc2/loc3 slice indices, vmin threshold) are not
# defined in any visible cell; `model`/`test_set` are whatever the main loop left behind (last run).
for layer in ['layer1', 'layer2', 'layer3', 'layer4']:
    print(layer)
    os.makedirs(f'../cam/{layer}', exist_ok=True)
    for i in pmci_idx:

        # load data
        stripped_pet_file = test_set.pet[i]
        stripped_mri_file = test_set.mri[i]

        pet_id = stripped_pet_file.split('/')[-1].replace('.pkl', '')
        mri_id = stripped_mri_file.split('/')[-1].replace('.pkl', '')

        with open(stripped_mri_file, 'rb') as fb:
            stripped_mri = pickle.load(fb)
        with open(stripped_pet_file, 'rb') as fb:
            stripped_pet = pickle.load(fb)

        nonstripped_pet_file = os.path.join(PUP_FBP_ROOT, pet_id, 'pet_proc', f'w_{pet_id}_SUVR.nii.gz')
        nonstripped_pet = nib.load(nonstripped_pet_file).get_fdata()
        nonstripped_pet = np.pad(nonstripped_pet, ((12, 12), (0, 0), (12, 12)), 'constant')

        # get activation map (test_set[i] is exactly what Subset(test_set, [i])[0] yields)
        d = test_set[i]
        img, gcamp_map, guided_map, title = get_map(d, f'backbone.{layer}', model, local_rank)

        gcamp_map_t = np.abs(1 - gcamp_map)
        # Resample the CAM onto the PET grid; the mask indexing below requires equal shapes.
        tr = resize(gcamp_map_t, stripped_pet.shape)
        tr_m = tr.copy()
        tr_m[stripped_pet <= 0] = np.nan       # hide voxels with non-positive PET
        tr_m[tr_m < hyparam.vmin] = np.nan     # hide activations below the display threshold

        # show map
        fig, axs = plt.subplots(3, 4, figsize=(20, 15))
        plt.suptitle(title + f' | {mri_id} | {pet_id}')

        slice_locs = (hyparam.loc1, hyparam.loc2, hyparam.loc3)
        for row, loc in enumerate(slice_locs):
            # Row r slices array axis r at its configured location.
            view = tuple(loc if dim == row else slice(None) for dim in range(3))
            axs[row, 0].imshow(stripped_pet[view], cmap='binary')
            axs[row, 1].imshow(tr_m[view], cmap='jet', alpha=0.2,)
            for col, background in ((2, nonstripped_pet), (3, stripped_mri)):
                axs[row, col].imshow(background[view], cmap='binary')
                axs[row, col].imshow(tr_m[view], cmap='jet', alpha=0.2,)

        fig.savefig(f'../cam/{layer}/{mri_id}_{pet_id}.png', dpi=400,
                    bbox_inches='tight')
        plt.close(fig)
        model.zero_grad()

In [None]:

# Root of the non-skull-stripped PUP FBP outputs (machine-specific absolute path; adjust per host).
PUP_FBP_ROOT = '/raidWorkspace/mingu/Data/ADNI/PUP_FBP'
# Guided-backprop magnitudes below this value are hidden in the overlay.
GUIDED_VMIN = 0.2

# For every test sample in `pmci_idx`, save a 3x4 figure of the guided-backprop magnitude:
# rows are slices along array axes 0/1/2; columns are stripped PET, map alone,
# non-stripped PET + map, MRI + map.
# NOTE(review): `pmci_idx` and `hyparam` (loc1/loc2/loc3) are not defined in any visible cell.
os.makedirs('../guided/', exist_ok=True)
for i in pmci_idx:

    # load data
    stripped_pet_file = test_set.pet[i]
    stripped_mri_file = test_set.mri[i]

    pet_id = stripped_pet_file.split('/')[-1].replace('.pkl', '')
    mri_id = stripped_mri_file.split('/')[-1].replace('.pkl', '')

    with open(stripped_mri_file, 'rb') as fb:
        stripped_mri = pickle.load(fb)
    with open(stripped_pet_file, 'rb') as fb:
        stripped_pet = pickle.load(fb)

    nonstripped_pet_file = os.path.join(PUP_FBP_ROOT, pet_id, 'pet_proc', f'w_{pet_id}_SUVR.nii.gz')
    nonstripped_pet = nib.load(nonstripped_pet_file).get_fdata()
    nonstripped_pet = np.pad(nonstripped_pet, ((12, 12), (0, 0), (12, 12)), 'constant')

    # get activation map; only the guided map and the title are plotted here, so the
    # Grad-CAM output is discarded (its masked version was previously computed but never used).
    d = test_set[i]
    _, _, guided_map, title = get_map(d, 'backbone.layer1', model, local_rank)

    # Resample onto the PET grid (resize returns a new array, so no copy of the input is needed).
    g_tr_m = resize(guided_map, stripped_pet.shape)
    g_tr_m[stripped_pet <= 0] = np.nan         # hide voxels with non-positive PET
    g_tr_m[g_tr_m < GUIDED_VMIN] = np.nan      # hide weak gradients

    # show map
    fig, axs = plt.subplots(3, 4, figsize=(20, 15))
    plt.suptitle(title + f' | {mri_id} | {pet_id}')

    slice_locs = (hyparam.loc1, hyparam.loc2, hyparam.loc3)
    for row, loc in enumerate(slice_locs):
        # Row r slices array axis r at its configured location.
        view = tuple(loc if dim == row else slice(None) for dim in range(3))
        axs[row, 0].imshow(stripped_pet[view], cmap='binary')
        axs[row, 1].imshow(g_tr_m[view], cmap='jet', alpha=0.2,)
        for col, background in ((2, nonstripped_pet), (3, stripped_mri)):
            axs[row, col].imshow(background[view], cmap='binary')
            axs[row, col].imshow(g_tr_m[view], cmap='jet', alpha=0.2,)

    fig.savefig(f'../guided/{mri_id}_{pet_id}.png', dpi=400,
                bbox_inches='tight')
    plt.close(fig)
    model.zero_grad()