In [None]:
import argparse
from datetime import datetime
import glob
import json
import os
import random
import sys
import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torch.utils import data
from torchvision import models
from vgg import VGG8

In [None]:
import nibabel
import numpy as np
import pandas as pd
from scipy.ndimage import convolve1d
from torch.utils import data

from utils import get_lds_kernel_window

def prepare_weights(df, reweight, lds, lds_kernel, lds_ks, lds_sigma):
    """Compute per-sample loss weights that counteract imbalance across IQ bins.

    Args:
        df: DataFrame with an ``iqbin`` column holding each subject's bin label.
        reweight: ``'none'`` (no weighting), ``'inv'`` (inverse bin count) or
            ``'sqrt_inv'`` (inverse square root of the bin count).
        lds: If true, the weights are derived from the bin counts smoothed with
            the kernel returned by ``get_lds_kernel_window``.
        lds_kernel, lds_ks, lds_sigma: Forwarded unchanged to
            ``get_lds_kernel_window``.

    Returns:
        ``None`` when ``reweight == 'none'``; otherwise a list with one weight
        per row of ``df``, scaled so that the weights average to 1.

    Raises:
        ValueError: If ``reweight`` is not a known scheme and ``lds`` is off
            (previously this surfaced as a NameError further down).
    """
    if reweight == 'none':
        return None

    # value_counts() orders bins by frequency; sort by bin label so that the
    # convolution below mixes counts of neighbouring bins.
    # NOTE(review): bins with no subjects are absent from this index, so
    # "neighbouring" means adjacent among the observed bins.
    bin_counts = df['iqbin'].value_counts().sort_index()

    if lds:
        lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
        per_bin = pd.Series(
            convolve1d(bin_counts.values, weights=lds_kernel_window, mode='constant'),
            index=bin_counts.index)
        # NOTE(review): as before, smoothing works on the raw counts, so the
        # square root of 'sqrt_inv' has no effect when lds is enabled.
    elif reweight == 'inv':
        per_bin = bin_counts
    elif reweight == 'sqrt_inv':
        per_bin = np.sqrt(bin_counts)
    else:
        raise ValueError(
            f"Unknown reweight scheme {reweight!r}; expected 'none', 'inv' or 'sqrt_inv'")

    # num_per_label[i] = (smoothed) size of the bin that the i-th subject falls in
    num_per_label = [per_bin[label] for label in df['iqbin']]

    weights = [1. / x for x in num_per_label]
    scaling = len(weights) / np.sum(weights)
    return [scaling * x for x in weights]

class AgePredictionDataset(data.Dataset):
    """Dataset of cropped image volumes with IQ targets and optional loss weights.

    Each item is built from one row of ``df``: the volume at ``row['path']`` is
    loaded, cropped to ``n_slices`` slices centred on index 75 of the third
    axis, moved to slice-first layout and divided by its 95th percentile.

    Args:
        df: DataFrame with a ``path`` column and, when ``labeled``, the score
            columns ``fiq``/``viq``/``piq`` (or their ``_r`` variants). An
            ``iqbin`` column is required when ``reweight`` is not ``'none'``.
        iq: Which score(s) to return: ``'all'``, ``'fiq'``, ``'piq'`` or ``'viq'``.
        iq_type: ``'absolute'`` reads the plain columns; any other value reads
            the ``_r`` columns.
        n_slices: Number of slices kept along the third axis.
        im_type: ``'int_rav'`` adds the matching ``/ravens_`` volume to the
            ``/reg_`` volume; any other value loads ``row['path']`` as is.
        reweight, lds, lds_kernel, lds_ks, lds_sigma: Forwarded to
            ``prepare_weights``.
        labeled: When false, ``__getitem__`` returns only the image.
    """

    # Score columns returned for each value of ``iq`` (order matters for 'all').
    _IQ_COLUMNS = {
        'all': ('fiq', 'viq', 'piq'),
        'fiq': ('fiq',),
        'piq': ('piq',),
        'viq': ('viq',),
    }
    # Index along the third axis around which the slice window is centred.
    _CENTER_SLICE = 75

    def __init__(self, df, iq, iq_type, n_slices, im_type, reweight='none', lds=False, lds_kernel='gaussian', lds_ks=9, lds_sigma=1, labeled=True):
        self.df = df
        self.iq = iq
        self.iq_type = iq_type
        self.n_slices = n_slices
        self.im_type = im_type
        self.weights = prepare_weights(df, reweight, lds, lds_kernel, lds_ks, lds_sigma)
        self.labeled = labeled

    def _load_volume(self, path):
        """Load the volume stored at ``path`` as a floating-point array."""
        return nibabel.load(path).get_fdata()

    def _target(self, row):
        """Return the requested score(s) of ``row`` as a 1-D array."""
        try:
            columns = self._IQ_COLUMNS[self.iq]
        except KeyError:
            raise ValueError(
                f"Invalid iq: {self.iq!r}; expected one of {sorted(self._IQ_COLUMNS)}") from None
        suffix = '' if self.iq_type == 'absolute' else '_r'
        return np.array([row[column + suffix] for column in columns])

    def __getitem__(self, idx):
        """Return ``(image, iq, weight)`` for row ``idx`` (just ``image`` if unlabeled).

        Raises:
            ValueError: If the slice window does not fit inside the volume, or
                if ``self.iq`` is not a known target selection.
        """
        row = self.df.iloc[idx]
        path = row['path']
        image = self._load_volume(path)
        if self.im_type == "int_rav":
            # Derive the companion path locally instead of writing into `row`.
            ravens = self._load_volume(path.replace('/reg_', '/ravens_'))
            image = image + ravens

        start = self._CENTER_SLICE - (self.n_slices // 2)
        stop = start + self.n_slices
        if start < 0 or stop > image.shape[2]:
            # A negative start or an overshooting stop would silently yield the
            # wrong number of slices; fail here with a clear message instead.
            raise ValueError(
                f"n_slices={self.n_slices} needs slices [{start}, {stop}) but the "
                f"volume has {image.shape[2]} slices along its third axis")
        image = image[50:190, 25:195, start:stop]
        image = np.transpose(image, (2, 0, 1))  # slices first

        scale = np.percentile(image, 95)
        if scale != 0:  # an all-background crop would otherwise become inf/NaN
            image /= scale  # Normalize intensity

        if not self.labeled:
            return image

        iq = self._target(row)
        weight = self.weights[idx] if self.weights is not None else 1.
        return (image, iq, weight)

    def __len__(self):
        return len(self.df)


In [None]:

# Experiment configuration read by the cells below.
arch = 'vgg8'         # 'vgg8' or 'resnet18'
im_type = 'int'       # 'int', 'rav' or 'int_rav'
n_slices = 100        # 130 for resnet18
iq = 'fiq'            # 'all', 'fiq', 'viq' or 'piq'
iq_type = 'absolute'  # 'absolute' or 'residual'

# Alternative configuration (kept as comments rather than as a string literal,
# which the cell would otherwise echo as its output):
# arch = 'resnet18'
# im_type = 'int'
# n_slices = 130
# iq = 'fiq'
# iq_type = 'absolute'

In [None]:
def setup_model(arch, no_of_slices, n_classes, device):
    """Build a double-precision network that takes the slices as input channels.

    Args:
        arch: ``'resnet18'`` or ``'vgg8'``.
        no_of_slices: Number of input channels of the first convolution.
        n_classes: Number of outputs of the network.
        device: Device the model is moved to.

    Returns:
        The model, converted to double precision and moved to ``device``.

    Raises:
        Exception: If ``arch`` is not one of the supported names.
    """
    if arch not in ('resnet18', 'vgg8'):
        raise Exception(f"Invalid arch: {arch}")

    if arch == 'vgg8':
        net = VGG8(in_channels=no_of_slices, num_classes=n_classes)
    else:
        net = models.resnet18(num_classes=n_classes)
        # Swap the stem convolution so it accepts `no_of_slices` input channels.
        net.conv1 = nn.Conv2d(no_of_slices, 64, kernel_size=7, stride=2, padding=3, bias=False)

    net.double()
    net.to(device)
    return net

In [None]:
def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` with a mean-absolute-error loss.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable yielding ``(images, targets, weights)`` batches.
            The weights are ignored.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, predictions)``: the mean of the per-batch L1 losses and a
        1-D NumPy array with every prediction in iteration order. For an empty
        loader this is ``(nan, empty array)``.
    """
    model.eval()

    losses = []
    all_preds = []

    with torch.no_grad():
        for (images, ages, _) in val_loader:
            # Fallback for loaders that yield raw arrays instead of tensors.
            if not torch.is_tensor(images):
                images = torch.tensor(images).unsqueeze(0)
            if not torch.is_tensor(ages):
                ages = torch.tensor(ages).unsqueeze(0)
            images, ages = images.to(device), ages.to(device)

            age_preds = model(images).view(-1)
            # Flatten the targets the same way as the predictions. Comparing
            # (B,) predictions with (B, 1) targets would broadcast to (B, B)
            # and average over every prediction/target pair.
            targets = ages.reshape(-1)
            loss = F.l1_loss(age_preds, targets, reduction='mean')

            losses.append(loss.item())
            all_preds.append(age_preds)

    if not all_preds:
        return float('nan'), np.empty(0)

    return np.mean(losses), torch.cat(all_preds).cpu().numpy()

In [None]:
def correct_path(path, image_type=None):
    """Rewrite a path from the data list to match the on-disk directory names.

    Args:
        path: Path string as stored in the data list.
        image_type: Image type to resolve the path for. When omitted, the
            notebook-level ``im_type`` setting is used, so existing
            single-argument calls keep working.

    Returns:
        The path with ``/ABIDE/`` and ``/NIH-PD/`` renamed, and with ``/reg_``
        replaced by ``/ravens_`` when the image type is ``"rav"``.
    """
    if image_type is None:
        image_type = im_type
    path = path.replace('/ABIDE/', '/ABIDE_I/')
    path = path.replace('/NIH-PD/', '/NIH_PD/')
    if image_type == "rav":
        path = path.replace('/reg_', '/ravens_')
    return path

# Seed every random number generator used by the notebook.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(42)

checkpoint_dir = '/home/ch225256/iq_prediction/IQ_prediction/checkpoints/2729048'

# One model output per predicted score.
n_classes = 3 if iq == 'all' else 1

data_list = '/home/ch225256/iq_prediction/IQ_prediction/folderlist2/all_iq_data_with_residuals_allok.list'
# Column name -> dtype of the space-separated, header-less data list.
schema = {
    'id': str,
    'fiq': float,
    'fiq_r': float,
    'piq': float,
    'piq_r': float,
    'viq': float,
    'viq_r': float,
    'sex': int,
    'dg': int,
    'age': float,
    'path': str,
    'site': int,
}
df = pd.read_csv(data_list, sep=' ', header=None, names=list(schema), dtype=schema)
df['path'] = df['path'].apply(correct_path)

kf5 = KFold(n_splits=5, shuffle=True, random_state=1)

In [None]:
# For producing validation results only
# The device and checkpoint prefix do not change between folds.
device = torch.device('cuda')
prefix = '2drn' if arch == 'resnet18' else '2dvgg'

i = 1
for _, val_idx in kf5.split(df):
    validation = df.iloc[val_idx]
    val_dataset = AgePredictionDataset(validation, iq=iq, iq_type=iq_type, n_slices=n_slices, im_type=im_type)
    val_loader = data.DataLoader(val_dataset, batch_size=1)

    model = setup_model(arch, n_slices, n_classes, device)
    model.load_state_dict(torch.load(f"{checkpoint_dir}/{im_type}_{prefix}_{iq}_{n_slices}sl_{iq_type}_f{i}.pth"))

    # Reuse validate() instead of repeating its evaluation loop inline.
    fold_loss, fold_preds = validate(model, val_loader, device)

    print(f"Loss in fold {i} is: {fold_loss}")
    i = i + 1
    # Only the first fold is evaluated here; remove the break to cover all folds.
    break

In [None]:
# For selecting the layer for GradCAM
# Name the fold explicitly. This cell used to read the loop counter `i` left
# behind by the validation cell, so it failed (or loaded a different file)
# depending on which cells had been run. The printed layer description does
# not depend on which fold's weights are loaded.
layer_fold = 1

device = torch.device('cuda')
model = setup_model(arch, n_slices, n_classes, device)
prefix = '2drn' if arch == 'resnet18' else '2dvgg'

model.load_state_dict(torch.load(f"{checkpoint_dir}/{im_type}_{prefix}_{iq}_{n_slices}sl_{iq_type}_f{layer_fold}.pth"))

# Print the convolution layer used as the GradCAM target for this architecture.
if arch == 'resnet18':
    print(model.layer4[1].conv2)
else:
    print(model.conv41.conv1)

In [None]:
class InfoHolder:
    """Stores the activation of a layer and the gradient flowing back into it.

    ``hook`` is meant to be registered as a forward hook on the layer; after a
    forward and backward pass ``activation`` and ``gradient`` hold the values
    captured for that pass.
    """

    def __init__(self, heatmap_layer):
        self.heatmap_layer = heatmap_layer
        self.activation = None
        self.gradient = None

    def get_gradient(self, grad):
        """Tensor hook: remember the gradient with respect to the layer output."""
        self.gradient = grad

    def hook(self, model, input, output):
        """Forward hook: keep a detached copy of ``output`` and watch its gradient."""
        self.activation = output.detach()
        output.register_hook(self.get_gradient)

def generate_heatmap(weighted_activation):
    """Collapse weighted activations into a heatmap normalised to [0, 1].

    Args:
        weighted_activation: Tensor whose first axis is averaged away (the
            caller passes channel-weighted activations).

    Returns:
        NumPy array holding the mean over the first axis, with negative values
        clipped to zero and the result divided by its maximum.
    """
    raw_heatmap = torch.mean(weighted_activation, 0).detach().cpu()
    # Clip with torch rather than applying a NumPy ufunc to a tensor, which
    # only returned a tensor thanks to Tensor.__array_wrap__.
    heatmap = torch.clamp(raw_heatmap, min=0)
    # The epsilon keeps an all-zero heatmap from dividing by zero.
    heatmap = heatmap / (torch.max(heatmap) + 1e-10)
    return heatmap.numpy()

def superimpose(input_img, heatmap):
    """Blend a JET-coloured heatmap over a 3-channel uint8 image.

    Args:
        input_img: Image array of shape ``(height, width, 3)``.
        heatmap: 2-D array of values in ``[0, 1]``; it is resized to the image.

    Returns:
        uint8 array of shape ``(height, width, 3)``: 60 % heatmap, 40 % image.
    """
    img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
    height, width = img.shape[:2]
    # cv2.resize expects dsize as (width, height). The previous code passed
    # (height, width) and then transposed the result to repair the shape, which
    # left the heatmap spatially transposed relative to the image.
    heatmap = cv2.resize(heatmap, (width, height))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    superimposed_img = np.uint8(heatmap * 0.6 + img * 0.4)
    # The second channel swap undoes the one applied to `img` above and puts
    # the colour map (returned by OpenCV in BGR order) into RGB order.
    pil_img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)
    return pil_img

def to_RGB(tensor):
    """Min-max scale a channel-first tensor to a channel-last uint8 array.

    Args:
        tensor: Tensor with three axes, channels first.

    Returns:
        uint8 NumPy array with the channel axis moved last and values scaled
        into ``[0, 255]``.
    """
    shifted = tensor - tensor.min()
    scaled = shifted / (shifted.max() + 1e-10)  # epsilon guards a constant input
    channels_last = scaled.cpu().detach().numpy().transpose(1, 2, 0)
    return np.uint8(255 * channels_last)

#image_binary = np.transpose(tensor.numpy(), (1, 2, 0))

def grad_cam(model, input_tensor, heatmap_layer, truelabel=None, display_channels=(74, 77)):
    """Compute a Grad-CAM overlay for one input.

    Args:
        model: Network to explain.
        input_tensor: Single input without a batch axis (channels first).
        heatmap_layer: Layer of ``model`` whose activations and gradients are
            used for the heatmap.
        truelabel: Index of the output to explain. Defaults to the index of
            the largest output.
        display_channels: ``(start, stop)`` range of input channels shown as
            the background image. The range must span three channels because
            the result goes through a 3-channel colour conversion.

    Returns:
        ``(image, overlay)``: the background image and the same image with the
        heatmap blended on top, both as uint8 arrays.

    Raises:
        RuntimeError: If no gradient reached ``heatmap_layer``.
    """
    info = InfoHolder(heatmap_layer)
    handle = heatmap_layer.register_forward_hook(info.hook)
    try:
        output = model(input_tensor.unsqueeze(0))[0]
        # `is None` rather than truthiness, so that index 0 can be requested.
        if truelabel is None:
            truelabel = torch.argmax(output)
        output[truelabel].backward()
    finally:
        # Without this every call would leave one more hook on the layer.
        handle.remove()

    if info.gradient is None:
        raise RuntimeError("No gradient was recorded for heatmap_layer")

    weights = torch.mean(info.gradient, [0, 2, 3])
    activation = info.activation.squeeze(0)

    # Scale each channel by its weight. The result is moved to a float32 CPU
    # tensor, matching the buffer the per-channel loop used to fill.
    weighted_activation = (weights[:, None, None] * activation).float().cpu()

    heatmap = generate_heatmap(weighted_activation)
    first_channel, last_channel = display_channels
    input_image = to_RGB(input_tensor)[:, :, first_channel:last_channel]
    pil_img = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

    return pil_img, superimpose(input_image, heatmap)

In [None]:
# ResNet18
import matplotlib.pyplot as plt

# The device and checkpoint prefix do not change between folds.
device = torch.device('cuda')
prefix = '2drn' if arch == 'resnet18' else '2dvgg'

for f, (_, val_idx) in enumerate(kf5.split(df), start=1):
    validation = df.iloc[val_idx]
    val_dataset = AgePredictionDataset(validation, iq=iq, iq_type=iq_type, n_slices=n_slices, im_type=im_type)
    val_loader = data.DataLoader(val_dataset, batch_size=1)

    model = setup_model(arch, n_slices, n_classes, device)
    model.load_state_dict(torch.load(f"{checkpoint_dir}/{im_type}_{prefix}_{iq}_{n_slices}sl_{iq_type}_f{f}.pth"))
    model.eval()

    # Only the first validation image of each fold is visualised.
    images, _, _ = next(iter(val_loader))
    # Drop only the batch axis; a bare squeeze would also drop any other
    # axis of length one.
    images = images.to(device).squeeze(0)

    heatmap_layer = model.layer4[1].conv2
    image, cam_image = grad_cam(model, images, heatmap_layer)

    fig = plt.figure(figsize=(15, 7))
    plt.subplot(1, 2, 1)
    plt.imshow(image, cmap = 'gray')
    # Saved before the second panel is drawn, so this file holds only the input image.
    plt.savefig(f"/home/ch225256/iq_prediction/IQ_prediction/age_prediction/images/resnet18_image_{f}.png")
    plt.subplot(1, 2, 2)
    plt.imshow(cam_image)
    plt.savefig(f"/home/ch225256/iq_prediction/IQ_prediction/age_prediction/images/resnet18_cam_{f}.png")
    plt.show()
    plt.close(fig)  # release the figure; one is created per fold

In [None]:
# ResNet18 - GradCam Averaging
import matplotlib.pyplot as plt

# The device and checkpoint prefix do not change between folds.
device = torch.device('cuda')
prefix = '2drn' if arch == 'resnet18' else '2dvgg'

cam_combined = None  # float64 running sum of the overlays
count = 0
for f, (_, val_idx) in enumerate(kf5.split(df), start=1):
    validation = df.iloc[val_idx]
    val_dataset = AgePredictionDataset(validation, iq=iq, iq_type=iq_type, n_slices=n_slices, im_type=im_type)
    val_loader = data.DataLoader(val_dataset, batch_size=1)

    model = setup_model(arch, n_slices, n_classes, device)
    model.load_state_dict(torch.load(f"{checkpoint_dir}/{im_type}_{prefix}_{iq}_{n_slices}sl_{iq_type}_f{f}.pth"))
    model.eval()

    heatmap_layer = model.layer4[1].conv2
    for (images, _, _) in val_loader:
        # Drop only the batch axis.
        images = images.to(device).squeeze(0)
        image, cam_image = grad_cam(model, images, heatmap_layer)

        # The overlays are uint8; summing them in uint8 wraps around at 256
        # and corrupts the average, so accumulate in float64.
        if cam_combined is None:
            cam_combined = np.zeros(cam_image.shape, dtype=np.float64)
        cam_combined += cam_image
        count += 1

if count == 0:
    raise RuntimeError("No validation images were processed; nothing to average.")

# Zero the average wherever the last processed image is zero. The mask is kept
# separate so that `image` itself can still be displayed.
brain_mask = image > 0
cam_combined2 = np.uint8((cam_combined / count) * brain_mask)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(image, cmap = 'gray')
plt.subplot(1, 2, 2)
plt.imshow(cam_combined2)
plt.show()

In [None]:
# VGG8
import matplotlib.pyplot as plt

# The device and checkpoint prefix do not change between folds.
device = torch.device('cuda')
prefix = '2drn' if arch == 'resnet18' else '2dvgg'

for f, (_, val_idx) in enumerate(kf5.split(df), start=1):
    # Only the fifth fold is visualised.
    if f != 5:
        continue
    validation = df.iloc[val_idx]
    val_dataset = AgePredictionDataset(validation, iq=iq, iq_type=iq_type, n_slices=n_slices, im_type=im_type)
    val_loader = data.DataLoader(val_dataset, batch_size=1)

    model = setup_model(arch, n_slices, n_classes, device)
    model.load_state_dict(torch.load(f"{checkpoint_dir}/{im_type}_{prefix}_{iq}_{n_slices}sl_{iq_type}_f{f}.pth"))
    model.eval()

    heatmap_layer = model.conv41.conv1
    for j, (images, _, _) in enumerate(val_loader, start=1):
        # Drop only the batch axis.
        images = images.to(device).squeeze(0)
        image, cam_image = grad_cam(model, images, heatmap_layer)

        fig = plt.figure(figsize=(15, 7))
        plt.subplot(1, 2, 1)
        plt.imshow(image, cmap = 'gray')
        # Saved before the second panel is drawn, so this file holds only the input image.
        plt.savefig(f"/home/ch225256/iq_prediction/IQ_prediction/age_prediction/images/vgg8_image_{j}.png")
        plt.subplot(1, 2, 2)
        plt.imshow(cam_image)
        plt.savefig(f"/home/ch225256/iq_prediction/IQ_prediction/age_prediction/images/vgg8_cam_{j}.png")
        plt.show()
        # One figure per validation image: close it so they do not pile up in memory.
        plt.close(fig)

In [None]:
# VGG8 - Gradcam average
import matplotlib.pyplot as plt

# The device and checkpoint prefix do not change between folds.
device = torch.device('cuda')
prefix = '2drn' if arch == 'resnet18' else '2dvgg'

cam_combined = None  # float64 running sum of the overlays
count = 0
for f, (_, val_idx) in enumerate(kf5.split(df), start=1):
    # Only the fifth fold contributes to the average.
    if f != 5:
        continue
    validation = df.iloc[val_idx]
    val_dataset = AgePredictionDataset(validation, iq=iq, iq_type=iq_type, n_slices=n_slices, im_type=im_type)
    val_loader = data.DataLoader(val_dataset, batch_size=1)

    model = setup_model(arch, n_slices, n_classes, device)
    model.load_state_dict(torch.load(f"{checkpoint_dir}/{im_type}_{prefix}_{iq}_{n_slices}sl_{iq_type}_f{f}.pth"))
    model.eval()

    heatmap_layer = model.conv41.conv1
    for (images, _, _) in val_loader:
        # Drop only the batch axis.
        images = images.to(device).squeeze(0)
        image, cam_image = grad_cam(model, images, heatmap_layer)

        # The overlays are uint8; summing them in uint8 wraps around at 256
        # and corrupts the average, so accumulate in float64.
        if cam_combined is None:
            cam_combined = np.zeros(cam_image.shape, dtype=np.float64)
        cam_combined += cam_image
        count += 1

if count == 0:
    raise RuntimeError("No validation images were processed; nothing to average.")

# Zero the average wherever the last processed image is zero. The mask is kept
# separate so that `image` itself can still be displayed.
brain_mask = image > 0
cam_combined2 = np.uint8((cam_combined / count) * brain_mask)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(image, cmap = 'gray')
plt.subplot(1, 2, 2)
plt.imshow(cam_combined2)
plt.show()