In [1]:
import torch
import collections
import matplotlib.pyplot as plt
import numpy as np
import random
import torch.nn.functional as F

In [2]:
import sys
sys.path.insert(1, '/home/buehlern/Documents/Masterarbeit/models')
from src.data.mri_datamodule import MRIDataModule
from src.models.swin_mae_module import SWINTransformerMAE
from src.models.components.mask_generator import MaskGenerator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Freshly constructed SWIN MAE (no checkpoint is loaded in this notebook).
# 3072 / 8 -> a 384 x 384 grid of patches.
mae = SWINTransformerMAE(
    image_size=3072,
    patch_size=8,
    encoder_stride=64,
)

In [4]:
mri_datamodule = MRIDataModule(
    image_size=3072,    # kept equal to the model's image_size above
    square=True,
    output_channels=1,  # single channel, which show_image asserts
    cache=False,
    fix_inverted=True,
)

initializing MRIDatasetBase ...
reading /home/buehlern/Documents/Masterarbeit/data/clean_df_slim_frac.pkl file ...
PATH /home/buehlern/Documents/Masterarbeit/data/BodyPartExamined_mappings_mergemore.json
/home/buehlern/Documents/Masterarbeit/data/cache-full/df_labelcomparison.pkl does not exit --> no items excluded by it
MRIDatasetBase(len=639877) initialized

initializing MRIDataset(mode=train) ...
MRIDataset(mode=train, len=516402) initialized

initializing MRIDataset(mode=val) ...
MRIDataset(mode=val, len=27518) initialized

initializing MRIDataset(mode=test) ...
WARN: including test data
MRIDataset(mode=test, len=95957) initialized


In [5]:
def show_image(image, title=''):
    """Render one single-channel image on the current matplotlib axes.

    Args:
        image: array-like laid out as [H, W, 1]; the channel axis is
            checked with an assert.
        title: text drawn above the image in 16pt font.

    The image is drawn with the `bone` colormap and the axes are hidden.
    Nothing is returned.
    """
    assert image.shape[2] == 1  # [H, W, 1] expected
    plt.imshow(image, cmap=plt.cm.bone)
    plt.title(title, fontsize=16)
    plt.axis('off')

In [6]:
# Random patch masks matched to the model: masking happens at the model's own
# patch granularity (mask patch size == model patch size) with its mask ratio.
mask_generator = MaskGenerator(
    input_size=mae.image_size,
    mask_patch_size=mae.patch_size,
    model_patch_size=mae.patch_size,
    mask_ratio=mae.mask_ratio,
)

In [10]:
def visualize(pixel_values, model, imgname=None,
              save_dir='/home/buehlern/Documents/Masterarbeit/notebooks/Data Exploration Graphics/Model Eval/SWIN MAE Untrained/',
              mask_gen=None):
    """Run one masked forward pass of a SWIN MAE and plot the result.

    Draws four panels side by side for the first item of the batch:
    the original image, the masked input, the raw reconstruction (titled
    with the model loss) and the reconstruction with the visible patches
    pasted back in.

    Args:
        pixel_values: image batch in NCHW layout (only item 0 is plotted;
            `show_image` requires a single channel).
        model: wrapper exposing `.net`, `.image_size` and `.patch_size`
            (e.g. the `SWINTransformerMAE` above). The forward pass uses
            THIS model, not the global `mae`.
        imgname: if not None, the figure is also saved as
            `<imgname>.png` inside `save_dir`.
        save_dir: directory for saved figures; defaults to the original
            hard-coded location.
        mask_gen: zero-argument callable returning one boolean patch mask
            (a tensor) per call; defaults to the notebook-level
            `mask_generator`. Its output must contain
            (image_size // patch_size) ** 2 entries for `model`.
    """
    import os

    if mask_gen is None:
        mask_gen = mask_generator

    # One independent random mask per image, placed on the input's device so
    # the forward pass also works when model/inputs live on a GPU.
    bool_masked_pos = torch.stack(
        [mask_gen() for _ in range(len(pixel_values))]
    ).to(pixel_values.device)

    # Inference only: skip building an autograd graph (very large at 3072x3072).
    with torch.no_grad():
        outputs = model.net(pixel_values, bool_masked_pos=bool_masked_pos,
                            interpolate_pos_encoding=False)
    y = torch.einsum('nchw->nhwc', outputs.reconstruction).detach().cpu()
    loss = outputs.loss.item()

    # Upsample the flat per-patch mask to a per-pixel mask:
    # (N, P*P) -> (N, P, P) -> (N, 1, H, W) -> (N, H, W, 1).
    num_patches = model.image_size // model.patch_size
    mask = bool_masked_pos.view(-1, num_patches, num_patches)
    mask = F.interpolate(mask.unsqueeze(1).float(),
                         size=(model.image_size, model.image_size),
                         mode="nearest")
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # Bring the input to CPU as well so it can be combined with `mask`/`y`
    # and handed to matplotlib.
    x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()

    # masked image: masked pixels (mask == 1) are zeroed
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = im_masked + y * mask

    # Size this figure explicitly instead of mutating global rcParams.
    plt.figure(figsize=(24, 10))

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], f"reconstruction (loss: {loss:.4f})")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    if imgname is not None:
        plt.savefig(os.path.join(save_dir, f"{imgname}.png"))
    plt.show()

In [8]:
# Iterator over the validation split. Each next() presumably yields a single
# sample rather than a DataLoader batch (the cell below adds the batch
# dimension itself with unsqueeze(0)) — TODO confirm.
dl_iter = iter(mri_datamodule.data_val)

In [None]:
# Select a validation sample by explicit index rather than next(dl_iter):
# an iterator advances on every re-run while the figure name stayed "0",
# so each re-run silently overwrote 0.png with a different image. Indexing
# keeps the cell idempotent and ties the saved file name to the sample shown.
# NOTE(review): assumes data_val supports integer indexing (map-style
# dataset with __getitem__) — TODO confirm.
sample_idx = 0
item = mri_datamodule.data_val[sample_idx]
image = item[0]  # presumably the image tensor is the first element — TODO confirm
batch = image.unsqueeze(0)  # add the batch dimension expected by visualize (NCHW)
visualize(batch, mae, imgname=str(sample_idx))