In [None]:
import torch
import torch.nn as nn
import torchio as tio
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import math

parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, parent_dir)
from dataset.Dataset import OneSliceDataset, TranformedMaskedDataset
from encoder.encoder_decoder_model import Encoder, Autoencoder

In [None]:
# Load a single sample from the dataset.
# Channel names used only for plot titles further down. NOTE(review): this string is
# not passed to OneSliceDataset, so it is assumed to match the channel order the
# dataset returns — TODO confirm.
modalities = "t2w+adc+pet+mask"

# torchio intensity rescaling: the 0th..99.5th percentile range is mapped onto [0, 1].
transform = tio.RescaleIntensity(
    out_min_max=(0, 1),
    percentiles=(0, 99.5),
)
dataset = OneSliceDataset(
    root_dir="../../data",
    modality_transform=transform,
)

img_id = 3  # index of the sample to visualise
img = dataset[img_id]["image"]

In [None]:
# Get reconstructed image
model = Autoencoder()
# map_location="cpu": the model is never moved to a GPU here and its outputs are later
# converted with .numpy(), so load the weights onto the CPU. This also lets a checkpoint
# saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load("../encoder/checkpoints/encoder.pth", map_location="cpu"))
model.eval()  # switch layers such as dropout / batch-norm (if any) to inference behaviour

data = torch.tensor(img)
data = data.unsqueeze(0)            # add a leading batch dimension
data = data.squeeze(dim=2).float()  # drop the singleton axis at index 2 (img's axis 1 before unsqueeze)
print("d:", data.size())

# Inference only: no_grad skips building an autograd graph for this forward pass.
with torch.no_grad():
    output = model(data)

In [None]:
# Visualize before and after
def plot_slice(image, titles=None):
    """Show every channel of a single-slice image side by side in one figure row.

    Args:
        image: array-like indexed as ``image[channel, 0, :, :]``; only index 0 of
            the second axis is drawn for each channel.
        titles: optional sequence of per-channel subplot titles. When omitted, the
            notebook-level ``modalities`` string is split on "+" and upper-cased
            (the previous behaviour). Channels without a matching title are
            labelled "Channel <i>" instead of raising IndexError.
    """
    if titles is None:
        # Backward-compatible default: derive titles from the notebook-level list.
        titles = [name.upper() for name in modalities.split("+")]

    num_channels = image.shape[0]

    fig = plt.figure(figsize=(15, 5))

    for channel in range(num_channels):
        ax = fig.add_subplot(1, num_channels, channel + 1)  # rows, columns, 1-based index
        ax.imshow(image[channel, 0, :, :], cmap=plt.cm.Greys_r)
        ax.set_title(titles[channel] if channel < len(titles) else f"Channel {channel}")

    plt.show()

print("Original Image")
plot_slice(img)

print("Reconstructed Image")
# plot_slice indexes image[channel, 0], so the reconstruction needs the same layout:
# drop the leading batch axis (must be size 1), then insert a singleton axis at
# position 1. This is equivalent to expand_dims(axis=2) followed by squeeze(axis=0).
output_img = np.squeeze(output.detach().numpy(), axis=0)[:, np.newaxis]
plot_slice(output_img)

In [None]:
# Get Encoded Image
# Run only the model's `encoder` submodule to inspect its feature maps.
# Inference only: no_grad skips building an autograd graph for this forward pass.
with torch.no_grad():
    encoded_images = model.encoder(data)

print(encoded_images.size())

def plot_encoded(image, ncols=10):
    """Show every channel of the first item of an encoded batch on a grid.

    Args:
        image: array indexed as ``image[0, channel, :, :]``; the channel count is
            read from ``image.shape[1]`` and only batch item 0 is drawn.
        ncols: number of grid columns (default 10, as before).

    The figure height grows with the number of grid rows, so feature maps keep
    a roughly constant size instead of being squeezed into a fixed 4-inch height.

    Raises:
        ValueError: if ``ncols`` is smaller than 1.
    """
    if ncols < 1:
        raise ValueError("ncols must be >= 1")

    num_channels = image.shape[1]
    nrows = math.ceil(num_channels / ncols)

    # 15 in of width split over ncols columns; give each row the same height so panels
    # stay roughly square. Never go below the original 4 in (unchanged for <= 2 rows of 10).
    fig = plt.figure(figsize=(15, max(4, nrows * 15 / ncols)))

    for channel in range(num_channels):
        ax = fig.add_subplot(nrows, ncols, channel + 1)  # rows, columns, 1-based index
        ax.imshow(image[0, channel, :, :], cmap=plt.cm.Greys_r)
        ax.set_title(str(channel), fontsize=8)  # channel index, to identify each map
        ax.axis("off")  # ticks are clutter on these small panels

    plt.show()

# Detach from autograd and convert to NumPy before plotting the encoder's feature maps.
encoded_np = encoded_images.detach().numpy()
plot_encoded(encoded_np)