In [None]:
import sys
import os
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import monai
from monai.data import ImageDataset
from monai.networks.layers.factories import Conv, Dropout, Pool, Norm
from monai.transforms import Compose, AddChannel, ScaleIntensity, EnsureType
from collections import OrderedDict
from typing import Callable, Sequence
import math
from functools import partial
import matplotlib.pyplot as plt

# Prefer the GPU when one is visible; pinned host memory only pays off for CUDA transfers.
_cuda_ok = torch.cuda.is_available()
pin_memory = _cuda_ok
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")

In [None]:
class _DenseLayer(nn.Sequential):
    """Bottlenecked dense layer: BN-ReLU-1x1 conv, BN-ReLU-3x3 conv, optional dropout.

    ``forward`` concatenates the ``growth_rate`` new feature maps onto the input
    along dim 1, so the output has ``in_channels + growth_rate`` channels.
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
    ) -> None:
        super().__init__()

        conv: Callable = Conv[Conv.CONV, spatial_dims]
        norm: Callable = Norm[Norm.BATCH, spatial_dims]
        bottleneck_width = bn_size * growth_rate

        # Registration order defines the nn.Sequential execution order.
        stages = [
            ("norm1", norm(in_channels)),
            ("relu1", nn.ReLU(inplace=True)),
            ("conv1", conv(in_channels, bottleneck_width, kernel_size=1, bias=False)),
            ("norm2", norm(bottleneck_width)),
            ("relu2", nn.ReLU(inplace=True)),
            ("conv2", conv(bottleneck_width, growth_rate, kernel_size=3, padding=1, bias=False)),
        ]
        if dropout_prob > 0:
            stages.append(("dropout", Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)))

        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    def forward(self, x):
        # Run the stacked stages, then append their output to the incoming features.
        return torch.cat((x, super().forward(x)), dim=1)


class _DenseBlock(nn.Sequential):
    """A stack of ``layers`` dense layers.

    Layer ``k`` (0-based) receives ``in_channels + k * growth_rate`` channels,
    because every preceding layer appended ``growth_rate`` maps to its input.
    """

    def __init__(
        self, spatial_dims: int, layers: int, in_channels: int, bn_size: int, growth_rate: int, dropout_prob: float
    ) -> None:
        super().__init__()
        for layer_no in range(layers):
            layer_in = in_channels + layer_no * growth_rate
            self.add_module(
                "denselayer%d" % (layer_no + 1),
                _DenseLayer(spatial_dims, layer_in, growth_rate, bn_size, dropout_prob),
            )


class _Transition(nn.Sequential):
    """Transition between dense blocks.

    Applies BN-ReLU, then a 1x1 conv mapping ``in_channels`` to ``out_channels``,
    then average pooling with kernel 2 / stride 2.
    """

    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        super().__init__()

        # The tuple is fully built (in this order) before any module is registered.
        for part_name, part in (
            ("norm", Norm[Norm.BATCH, spatial_dims](in_channels)),
            ("relu", nn.ReLU(inplace=True)),
            ("conv", Conv[Conv.CONV, spatial_dims](in_channels, out_channels, kernel_size=1, bias=False)),
            ("pool", Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)),
        ):
            self.add_module(part_name, part)


class DenseNet(nn.Module):
    """
    Densenet based on: "Densely Connected Convolutional Networks" https://arxiv.org/pdf/1608.06993.pdf
    Adapted from PyTorch Hub 2D version:
    https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of outputs of the final linear layer (classes or regression targets).
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
                      (i.e. bn_size * k features in the bottleneck layer)
        dropout_prob: dropout rate after each dense layer.

    Note:
        Sub-module names (``features.*``, ``class_layers.*``) become the keys of the
        saved ``state_dict`` and are the strings passed to GradCAM ``target_layers``,
        so they must stay stable.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        dropout_prob: float = 0.0,
    ) -> None:

        super(DenseNet, self).__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
        pool_type: Callable = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: Callable = Pool[Pool.ADAPTIVEAVG, spatial_dims]

        # Stem: 7-wide strided conv + BN-ReLU + strided max-pool.
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", norm_type(init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Dense blocks; each non-final block is followed by a transition that halves the channels.
        in_channels = init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=in_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            in_channels += num_layers * growth_rate
            if i == len(block_config) - 1:
                self.features.add_module("norm5", norm_type(in_channels))
            else:
                _out_channels = in_channels // 2
                trans = _Transition(spatial_dims, in_channels=in_channels, out_channels=_out_channels)
                self.features.add_module("transition%d" % (i + 1), trans)
                in_channels = _out_channels

        # Head: ReLU -> global average pool -> flatten -> linear.
        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", nn.ReLU(inplace=True)),
                    # NOTE: despite its name this is the global adaptive average pool;
                    # the name is kept so existing module paths do not change.
                    ("norm", avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)),
                    ("class", nn.Linear(in_channels, out_channels)),
                ]
            )
        )

        # Kaiming-normal conv weights, unit/zero BN affine params, zero linear bias.
        # Avoid Built-in function isinstance was called with the wrong arguments warning
        # pytype: disable=wrong-arg-types
        for m in self.modules():
            if isinstance(m, conv_type):  # type: ignore
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, norm_type):  # type: ignore
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
        # pytype: enable=wrong-arg-types

    def forward(self, x):
        """Run the feature extractor and head; returns ``(N, out_channels)``."""
        x = self.features(x)
        return self.class_layers(x)


def densenet121(**kwargs) -> DenseNet:
    """DenseNet-121: 64 initial features, growth rate 32, blocks of (6, 12, 24, 16) layers."""
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)


def densenet169(**kwargs) -> DenseNet:
    """DenseNet-169: 64 initial features, growth rate 32, blocks of (6, 12, 32, 32) layers."""
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)


def densenet201(**kwargs) -> DenseNet:
    """DenseNet-201: 64 initial features, growth rate 32, blocks of (6, 12, 48, 32) layers."""
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)


def densenet264(**kwargs) -> DenseNet:
    """DenseNet-264: 64 initial features, growth rate 32, blocks of (6, 12, 64, 48) layers."""
    return DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs)

In [None]:
def get_inplanes():
    """Base channel widths of the four ResNet stages (a fresh list on every call)."""
    return [2 ** power for power in range(6, 10)]


def conv3x3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3x3 ``Conv3d`` with padding 1 (keeps spatial size when ``stride == 1``)."""
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **options)


def conv1x1x1(in_planes, out_planes, stride=1):
    """Bias-free pointwise (1x1x1) ``Conv3d``; used for channel projection and strided shortcuts."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv3d(in_planes, out_planes, **options)


class BasicBlock(nn.Module):
    """Two 3x3x3 conv + BN stages with a residual shortcut (ResNet-10/18/34 block).

    ``downsample`` (if given) maps the input onto the shortcut's shape; otherwise
    the input is added unchanged. Output has ``planes * expansion`` channels.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        # Attribute order is kept so named_modules()/state_dict() ordering is unchanged.
        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += x if self.downsample is None else self.downsample(x)
        return self.relu(y)


class Bottleneck(nn.Module):
    """1x1x1 reduce -> 3x3x3 (strided) -> 1x1x1 expand, with a residual shortcut.

    Used by ResNet-50 and deeper. Output has ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        expanded = planes * self.expansion
        # Attribute order is kept so named_modules()/state_dict() ordering is unchanged.
        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = conv3x3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, expanded)
        self.bn3 = nn.BatchNorm3d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += x if self.downsample is None else self.downsample(x)
        return self.relu(y)


class ResNet(nn.Module):
    """3D ResNet for 5-D inputs ``(N, C, D1, D2, D3)``.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        block_inplanes: base channel width of each stage (scaled by ``widen_factor``).
        n_input_channels: channels of the input volume.
        conv1_t_size: stem kernel size along the first spatial axis.
        conv1_t_stride: stem stride along the first spatial axis.
        no_max_pool: skip the stem max-pooling when True.
        shortcut_type: 'A' = parameter-free subsample + zero-pad shortcut;
            any other value = 1x1x1 conv + BN projection shortcut.
        widen_factor: multiplier applied to ``block_inplanes``.
        n_classes: size of the final linear output.
    """

    def __init__(self,
                 block,
                 layers,
                 block_inplanes,
                 n_input_channels=3,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 shortcut_type='B',
                 widen_factor=1.0,
                 n_classes=400):
        super().__init__()

        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        self.conv1 = nn.Conv3d(n_input_channels,
                               self.in_planes,
                               kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(conv1_t_size // 2, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       shortcut_type,
                                       stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        """Parameter-free shortcut (type 'A'): strided subsampling plus zero channel padding.

        Args:
            x: 5-D tensor ``(N, C, D1, D2, D3)``.
            planes: channel count of the result (must be >= C).
            stride: subsampling stride (kernel-1 average pooling == pure subsampling).

        Returns:
            ``(N, planes, ...)`` tensor whose first C channels are the subsampled input.
        """
        # nn.functional is used because `F` is never imported in this notebook.
        out = nn.functional.avg_pool3d(x, kernel_size=1, stride=stride)
        # Zero channels share out's device and dtype, so this works for any device and
        # precision. `out` itself (not `out.data`) is concatenated so gradients keep
        # flowing through the shortcut.
        zero_pads = out.new_zeros((out.size(0), planes - out.size(1)) + tuple(out.shape[2:]))
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first block may stride/project."""
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(self._downsample_basic_block,
                                     planes=planes * block.expansion,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> four residual stages -> global average pool -> linear; returns ``(N, n_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def generate_model(model_depth, **kwargs):
    """Build a 3D ResNet of depth 10/18/34/50/101/152/200.

    Depths up to 34 use ``BasicBlock``; 50 and deeper use ``Bottleneck``.
    Extra keyword arguments are forwarded to ``ResNet``.
    """
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    depth_to_spec = {
        10: (BasicBlock, [1, 1, 1, 1]),
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3]),
        200: (Bottleneck, [3, 24, 36, 3]),
    }
    block_cls, stage_sizes = depth_to_spec[model_depth]
    return ResNet(block_cls, stage_sizes, get_inplanes(), **kwargs)

In [None]:
data_dir = '/home/marafath/scratch/abide_iq_combined'

np.random.seed(42)
idx = np.random.permutation(850)

# options
im_type = 'int'  # 'int', 'rav', 'int_rav'
fold = 0  # 0-4
iq = 'viq'  # 'all', 'fiq', 'viq', 'piq'
iq_type = 'absolute'  # 'absolute', 'residual'
arch = 'resnet50'  # 'resnet' (ResNet-18), 'resnet50', 'densenet' (DenseNet-121), 'densenet169'
win_size = (140, 175, 130)  # spatial size of the input volumes

n_folds = 5
n_subjects = len(idx)              # the fixed permutation covers exactly this many files
fold_size = n_subjects // n_folds  # 170 validation subjects per fold

# Fail fast on a typo instead of a NameError further down.
if im_type not in ('int', 'rav', 'int_rav'):
    raise ValueError(f"unknown im_type: {im_type!r}")
if iq not in ('all', 'fiq', 'viq', 'piq'):
    raise ValueError(f"unknown iq: {iq!r}")
if iq_type not in ('absolute', 'residual'):
    raise ValueError(f"unknown iq_type: {iq_type!r}")
if not 0 <= fold < n_folds:
    raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}")

# Each image type lives in a sub-directory of the same name.
data_dir2 = os.path.join(data_dir, im_type)
# NOTE(review): os.listdir order is filesystem-dependent; the fold split below only
# matches the training split if the listing order matches the one used at training time.
datalist = os.listdir(data_dir2)
if len(datalist) != n_subjects:
    raise ValueError(f"expected {n_subjects} files in {data_dir2}, found {len(datalist)}")

# IQ scores are '_'-separated fields of the file name: absolute FIQ/VIQ/PIQ at fields
# 2/6/10, and the matching residual score two fields later (4/8/12).
iq_field = {'fiq': 2, 'viq': 6, 'piq': 10}
field_shift = 0 if iq_type == 'absolute' else 2
label_iqs = ('fiq', 'viq', 'piq') if iq == 'all' else (iq,)
label_fields = [iq_field[score_name] + field_shift for score_name in label_iqs]
no_out_classes = len(label_fields)

train_image = []
train_label = []
validation_image = []
validation_label = []

start = fold * fold_size
end = (fold + 1) * fold_size

for i in range(n_subjects):
    nifti_name = datalist[idx[i]]
    fields = nifti_name.split('_')
    scores = [float(fields[k]) for k in label_fields]
    # 'all' keeps a 3-vector per subject; a single score is stored as a scalar.
    iq_label = scores if iq == 'all' else scores[0]
    image_path = os.path.join(data_dir2, nifti_name)
    if start <= i < end:
        validation_image.append(image_path)
        validation_label.append(iq_label)
    else:
        train_image.append(image_path)
        train_label.append(iq_label)

train_label = torch.tensor(train_label)
validation_label = torch.tensor(validation_label)

print(no_out_classes)
print(validation_label)
print(arch)

# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), EnsureType()])

# create a training data loader
train_ds = ImageDataset(image_files=train_image, labels=train_label, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=1, pin_memory=pin_memory)

# create a validation data loader
val_ds = ImageDataset(image_files=validation_image, labels=validation_label, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=pin_memory)

# Build the network selected by `arch` (trained weights are loaded in a later cell).
if arch == 'densenet169':
    model = densenet169(spatial_dims=3, in_channels=1, out_channels=no_out_classes).to(device)
elif arch == 'densenet':
    model = densenet121(spatial_dims=3, in_channels=1, out_channels=no_out_classes).to(device)
elif arch == 'resnet50':
    model = generate_model(model_depth=50, n_input_channels=1, n_classes=no_out_classes, shortcut_type='B').to(device)
elif arch == 'resnet':
    model = generate_model(model_depth=18, n_input_channels=1, n_classes=no_out_classes, shortcut_type='B').to(device)
else:
    raise ValueError(f"unknown arch: {arch!r}")


1
tensor([115.,  88.,  87.,  96., 100., 119., 113., 104., 127.,  98.,  92., 113.,
        113., 145., 107., 180., 111., 129., 111.,  88., 106., 119., 121., 106.,
        129., 129., 126., 126.,  94., 121., 135., 136.,  91., 107., 118.,  87.,
        108., 113.,  89., 149., 115., 126., 102., 111., 118.,  67.,  99., 119.,
        101.,  98., 111., 129.,  99., 111., 106., 125.,  99., 108., 102., 101.,
        113., 118., 113.,  78., 112.,  99., 124.,  84., 108., 117., 118., 121.,
        109., 127., 100.,  98., 129., 107., 131., 103., 121., 102., 121.,  93.,
         98., 109.,  83., 101., 116., 123., 121.,  97.,  94., 104.,  86., 125.,
        108., 106.,  97., 119., 107.,  88.,  98.,  81., 126., 106., 106.,  95.,
         72.,  96.,  91., 121.,  85.,  90., 133., 108., 119., 112., 136., 120.,
         98., 123., 126.,  93., 139., 115.,  98.,  98.,  95.,  99., 102.,  98.,
        115., 116., 115., 106., 113.,  90., 103., 117., 106.,  99.,  96.,  99.,
        108., 130.,  98., 116.,  87., 

In [None]:
# Load this fold's trained weights. map_location lets a checkpoint saved on a GPU
# load on a CPU-only machine.
ckpt_path = f'/home/marafath/scratch/abide_saved_models/{im_type}_{arch}_{iq}_{iq_type}_f{fold}.pth'
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# GradCAM target: the last feature stage of each architecture.
gradcam_target_layers = {
    'densenet': 'class_layers.relu',
    'densenet169': 'class_layers.relu',
    'resnet50': 'layer4.2',  # last Bottleneck of layer4 ([3, 4, 6, 3] blocks)
    'resnet': 'layer4.1',    # last BasicBlock of layer4 ([2, 2, 2, 2] blocks)
}
cam = monai.visualize.GradCAM(nn_module=model, target_layers=gradcam_target_layers[arch])
#cam = monai.visualize.GradCAM(nn_module=model, target_layers="layer4.1") # for Resnet18

In [None]:
# Print every sub-module under its dotted name (e.g. "layer4.2"), the form used for GradCAM target_layers.
for module_name, submodule in model.named_modules():
    print(module_name, submodule)

In [None]:
from torchsummary import summary

# Summarise on the device the model already lives on instead of forcing `.cuda()`,
# which fails on CPU-only machines. Input is one channel of the configured window.
summary(model, (1,) + tuple(win_size), device=device.type)

In [None]:
import nibabel as nib

# Save every validation volume and its GradCAM map as NIfTI files for an external viewer.
# NOTE(review): np.eye(4) discards the source files' affine/orientation, so the maps are
# aligned with the saved image, not with the original scanner space.
gradcam_dir = '/home/marafath/scratch/im_gradcam'
file_prefix = f'{im_type}_{arch}_{iq}_{iq_type}_f{fold}'

for count, val_data in enumerate(val_loader):
    val_images = val_data[0].to(device)
    im = val_images.cpu().detach().numpy()
    cam_result = cam(x=val_images, class_idx=None)
    c = cam_result.cpu().detach().numpy()
    # val_loader uses batch_size=1 and the images have one channel: drop both leading axes.
    im1 = im.squeeze(0).squeeze(0)
    c1 = c.squeeze(0).squeeze(0)

    nib.save(nib.Nifti1Image(im1, np.eye(4)),
             os.path.join(gradcam_dir, f'{file_prefix}_image_{count}.nii.gz'))
    nib.save(nib.Nifti1Image(c1, np.eye(4)),
             os.path.join(gradcam_dir, f'{file_prefix}_gradcam_{count}.nii.gz'))



In [None]:
import nibabel as nib

# Save the first few validation volumes with GradCAM maps restricted to the brain
# (voxels with positive intensity).
# NOTE(review): these file names are the same as those written by the unmasked export
# loop, so running both overwrites the unmasked maps for these cases.
max_cases = 2
gradcam_dir = '/home/marafath/scratch/im_gradcam'
file_prefix = f'{im_type}_{arch}_{iq}_{iq_type}_f{fold}'

for count, val_data in enumerate(val_loader):
    val_images = val_data[0].to(device)
    im = val_images.cpu().detach().numpy()
    cam_result = cam(x=val_images, class_idx=None)
    c = cam_result.cpu().detach().numpy()
    # val_loader uses batch_size=1 and the images have one channel: drop both leading axes.
    im1 = im.squeeze(0).squeeze(0)
    c1 = c.squeeze(0).squeeze(0)

    # Build the mask as a NEW array. The old `im11 = im1` aliased the same buffer, so
    # setting mask voxels to 1 overwrote `im1` and the saved "image" was the mask.
    # Assumes intensities are >= 0 (ScaleIntensity defaults to [0, 1]), in which case this
    # equals the old "set positive voxels to 1" mask.
    brain_mask = (im1 > 0).astype(c1.dtype)
    c11 = c1 * brain_mask

    nib.save(nib.Nifti1Image(im1, np.eye(4)),
             os.path.join(gradcam_dir, f'{file_prefix}_image_{count}.nii.gz'))
    nib.save(nib.Nifti1Image(c11, np.eye(4)),
             os.path.join(gradcam_dir, f'{file_prefix}_gradcam_{count}.nii.gz'))

    # Stop after max_cases volumes, without loading another batch first.
    if count + 1 == max_cases:
        break


In [None]:
# Reload this fold's weights; map_location lets GPU-saved checkpoints load on CPU.
ckpt_path = f'/home/marafath/scratch/abide_saved_models/{im_type}_{arch}_{iq}_{iq_type}_f{fold}.pth'
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# Last feature stage per architecture. The previous hard-coded "layer4.1" is not the last
# block for ResNet-50 (layer4 has 3 blocks) and does not exist in DenseNet.
gradcam_target_layers = {
    'densenet': 'class_layers.relu',
    'densenet169': 'class_layers.relu',
    'resnet50': 'layer4.2',
    'resnet': 'layer4.1',
}
cam = monai.visualize.GradCAM(nn_module=model, target_layers=gradcam_target_layers[arch])

# Compare the raw feature-map size at the target layer with the input size it is upsampled to.
input_shape = [1, 1] + list(win_size)
print("original feature shape", cam.feature_map_size(input_shape, device))
print("upsampled feature shape", input_shape)

original feature shape torch.Size([1, 1, 9, 6, 5])
upsampled feature shape [1, 1, 140, 175, 130]
