In [2]:
# Reload edited modules (e.g. the local `brainiac` package) automatically before
# each cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

from PIL import Image

from brainiac.loader import ADNIClassificationDataset
from brainiac.utils import load_model
from brainiac.models import SimpleCNN
from brainiac.grad_cam import GradCAM

In [15]:
# Path to the ADNI metadata CSV. The default is machine-specific; set the
# ADNI_DATA_CSV environment variable to point at a local copy instead of editing it.
DATA_CSV = os.environ.get('ADNI_DATA_CSV', '/home/basimova_nf/ADNI-processed/data.csv')

df = pd.read_csv(DATA_CSV)
dataset = ADNIClassificationDataset(df)

# First element of sample 0 (presumably the image volume), with a leading batch
# dimension added so it can be fed to the model as a batch of size 1.
images = dataset[0][0].unsqueeze(0)

In [35]:
def from_3d_img_to_frames(img):
    """Convert a 3D volume tensor into a list of palette-mode PIL frames.

    Dimension 0 is squeezed twice (dropping e.g. singleton batch and channel
    dims), then the volume is iterated along its first remaining axis,
    producing one frame per slice.

    Args:
        img: torch tensor, e.g. shape (1, 1, D, H, W). After the two squeezes
            each item along the first axis is assumed to be a 2D slice.
            Values are expected in [0, 1]; anything outside is clipped.

    Returns:
        list of ``PIL.Image`` frames converted to mode ``'P'`` with an
        adaptive palette.
    """
    # detach()/cpu(): ``Tensor.numpy()`` refuses tensors that require grad or
    # live on a GPU, which is common for Grad-CAM outputs.
    volume = img.detach().cpu().squeeze(0).squeeze(0).numpy()
    # Clip before scaling: out-of-range floats do not saturate when cast to
    # uint8 (the result wraps or is platform-dependent), producing speckle.
    volume = np.clip(volume, 0.0, 1.0)

    frames = []
    for slice_2d in volume:
        gray = Image.fromarray(np.uint8(slice_2d * 255))
        frames.append(gray.convert('P', palette=Image.ADAPTIVE))
    return frames

In [6]:
def save_gif(images, path, duration=0.1):
    """Write ``images`` to ``path`` as an infinitely looping animated GIF.

    Args:
        images: non-empty sequence of PIL images. The first image performs the
            save and the remaining ones are appended as extra frames.
        path: output file path.
        duration: display time of each frame in SECONDS (default 0.1 s).

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError('save_gif needs at least one frame')
    # Pillow's GIF writer takes ``duration`` in milliseconds. Passing 0.1
    # straight through (as before) meant a 0.1 ms delay, which the GIF format
    # (1/100 s resolution) stores as zero.
    duration_ms = int(round(duration * 1000))
    images[0].save(path, save_all=True, append_images=images[1:],
                   optimize=False, duration=duration_ms, loop=0)

In [7]:
# Checkpoint to explain (the directory name suggests a CN-MCI-AD run, epoch 8).
MODEL_CHECKPOINT = 'trained_model/CNN/CN-MCI-AD_Adam_10_4_0.0001_1e-05/model_epoch8.pth'

model = SimpleCNN()
model, _ = load_model(model, MODEL_CHECKPOINT)  # second return value unused here
# Grad-CAM should explain inference-time behaviour, so put any Dropout/BatchNorm
# layers in eval mode (assumes SimpleCNN is an nn.Module). Gradients still flow.
# A no-op if load_model already did this. Trailing ';' hides the module repr.
model.eval();

In [37]:
def make_regions_from(model, images, target_layers=('conv', 'pool'), target_class=0):
    """Compute a 3D Grad-CAM map of ``target_class`` for a batch of inputs.

    Args:
        model: network to explain; handed to ``GradCAM``.
        images: batch of input volumes. One class id is created per element
            (``len(images)``), so the first dimension is the batch.
        target_layers: names of the layers passed to ``GradCAM``. The map is
            generated for the LAST entry. This argument used to be silently
            overwritten with ``['conv', 'pool']``; it is now honoured, and that
            list is the default.
        target_class: class index whose score is back-propagated (default 0).

    Returns:
        Whatever ``GradCAM.generate`` returns for ``target_layers[-1]``.

    Raises:
        ValueError: if ``target_layers`` is empty.
    """
    target_layers = list(target_layers)
    if not target_layers:
        raise ValueError('target_layers must contain at least one layer name')

    gcam = GradCAM(model, target_layers, mode='3D')
    gcam.forward(images)

    # One class id per batch element, shape (batch, 1). Build it on the input's
    # device (CPU if ``images`` has no ``device``) rather than always on CPU.
    ids_ = torch.full((len(images), 1), target_class, dtype=torch.long,
                      device=getattr(images, 'device', None))
    gcam.backward(ids=ids_)

    return gcam.generate(target_layer=target_layers[-1])

In [38]:
# Grad-CAM map for class 0, requested for the 'conv' and 'pool' layers.
regions = make_regions_from(model, images, target_layers=['conv', 'pool'])

In [39]:
# Element-wise mean of the Grad-CAM map and the input volume (simple overlay).
gcam = 0.5 * (regions + images)

In [40]:
# Inspect the overlay's dimensions.
gcam.size()

torch.Size([1, 1, 128, 96, 96])

In [41]:
# Slice the overlay volume into one frame per slice and save it as an animated GIF.
frames = from_3d_img_to_frames(img=gcam)
save_gif(images=frames, path='out.gif')