In [8]:
from deepautoqc.ae_architecture import Autoencoder, Encoder, Decoder
import torch
import matplotlib.pyplot as plt
from deepautoqc.data_structures import load_from_pickle, BrainScan
import numpy as np
import torchio as tio

In [9]:
def load_to_tensor(
    img: np.ndarray,
    target_shape: tuple = (3, 704, 800),
) -> torch.Tensor:
    """Convert a channels-last image array into a fixed-size float tensor.

    The array is reordered from ``(H, W, C)`` to ``(C, H, W)`` and then cropped
    and/or padded with ``torchio.CropOrPad`` so every image ends up with the
    same shape.

    Args:
        img: image array with the channel axis last, ``(H, W, C)``.
        target_shape: output shape ``(C, H, W)``; defaults to ``(3, 704, 800)``.

    Returns:
        ``float32`` tensor of shape ``target_shape``.
    """
    # Channels first. ascontiguousarray copies the transposed view so that
    # torch.from_numpy also accepts inputs with negative strides (e.g. flipped arrays).
    channels_first = np.ascontiguousarray(img.transpose((2, 0, 1)))

    # torchio works on 4-D (channels, x, y, z) tensors: wrap the whole image as a
    # single "channel" so the crop/pad is applied to all three (C, H, W) axes.
    scalar_image = tio.ScalarImage(tensor=torch.from_numpy(channels_first)[None])
    resized = tio.CropOrPad(tuple(target_shape))(scalar_image)

    return resized.data[0].float()


In [10]:
def _to_display_array(image: torch.Tensor) -> np.ndarray:
    """Turn an unbatched ``(C, H, W)`` tensor into an imshow-ready array: ``(H, W)`` for one channel, ``(H, W, C)`` otherwise."""
    array = image.detach().cpu().numpy()
    if array.ndim == 3 and array.shape[0] == 1:
        # Single channel: drop the channel axis so imshow renders it with the colormap.
        return array[0]
    if array.ndim == 3:
        # Multi-channel: move channels last, the layout imshow expects for RGB(A).
        return array.transpose((1, 2, 0))
    return array


def display_reconstruction(model, image, verbose: bool = True):
    """Plot an image next to the autoencoder's reconstruction of it.

    Args:
        model: autoencoder module exposing ``.device``; it is called on a batch
            of one image.
        image: a single unbatched image tensor, e.g. the output of
            ``load_to_tensor`` (``(C, H, W)``).
        verbose: print the min/max value of the original and the reconstruction
            (on by default).
    """
    model.eval()

    # Move to the model's device and add the batch dimension the model expects.
    batch = image.to(model.device).unsqueeze(0)
    with torch.no_grad():
        reconstruction = model(batch)

    # [0] strips the batch dimension of size one.
    original_np = _to_display_array(batch[0])
    reconstructed_np = _to_display_array(reconstruction[0])

    if verbose:
        print("ORIG", original_np.min(), original_np.max())
        print("REC", reconstructed_np.min(), reconstructed_np.max())

    # NOTE(review): imshow clips float RGB data to [0, 1]. If the printed ranges
    # exceed that, the panels will look saturated rather than show the true image.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (original_np, 'Original Image'),
        (reconstructed_np, 'Reconstructed Image'),
    )
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture, cmap='gray')  # cmap only affects single-channel panels
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Example usage:
# image = load_to_tensor(img=subject.img)
# display_reconstruction(model, image)


In [11]:
def is_anomaly(model, image, threshold):
    """Return the mean-squared reconstruction error of ``image`` under ``model``.

    Despite the name, no classification happens here. The raw error is returned
    as a 0-d tensor, and ``threshold`` is currently unused; it is kept so existing
    calls keep working. Compare the result against a cut-off at the call site.

    Args:
        model: autoencoder module exposing ``.device``; called on a batch of one.
        image: a single unbatched image tensor (e.g. from ``load_to_tensor``).
        threshold: unused; reserved for an anomaly cut-off.

    Returns:
        0-d tensor with the MSE between the input and its reconstruction.

    Raises:
        ValueError: if the reconstruction's shape differs from the input's.
            Otherwise ``mse_loss`` would broadcast (with only a warning) and
            return a meaningless number.
    """
    model.eval()

    batch = image.to(model.device).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        reconstruction = model(batch)

    if reconstruction.shape != batch.shape:
        raise ValueError(
            f"reconstruction shape {tuple(reconstruction.shape)} does not match "
            f"input shape {tuple(batch.shape)}"
        )

    return torch.nn.functional.mse_loss(batch, reconstruction)


In [12]:
# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Bottleneck size, only needed when building a fresh (untrained) model, e.g.
#   model = Autoencoder(base_channel_size=32, latent_dim=latent_dim)
latent_dim = 64

In [13]:
import os

# Directory holding the trained autoencoder checkpoints. Set the
# DEEPAUTOQC_CKPT_DIR environment variable instead of editing this absolute path.
CKPT_DIR = os.environ.get(
    "DEEPAUTOQC_CKPT_DIR", "/Users/Dominik/Charite/DeepAutoQC/src/deepautoqc/ckpts"
)
ckpt_path = os.path.join(CKPT_DIR, "AE_384-22-1886.ckpt")
ckpt_path_2 = os.path.join(CKPT_DIR, "AE_64-24-2050.ckpt")

In [14]:
# Restore the trained autoencoder onto the selected device. load_from_checkpoint is a
# classmethod that returns a new model; to try the other checkpoint, pass ckpt_path_2.
model = Autoencoder.load_from_checkpoint(ckpt_path, map_location=device)
# Switch to inference mode once here, so cells that call model.encoder directly
# do not depend on a helper having called eval() earlier.
model.eval();

In [15]:
# Show the restored module's structure (its repr) as this cell's output.
model

Autoencoder(
  (encoder): Encoder(
    (net): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): GELU(approximate='none')
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): GELU(approximate='none')
      (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (5): GELU(approximate='none')
      (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): GELU(approximate='none')
      (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (9): GELU(approximate='none')
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): GELU(approximate='none')
      (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (13): GELU(approximate='none')
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): GELU(approximate='none')
      (16): Conv2d(256, 51

In [16]:
import os

# Root folder of the processed skull-strip report pickles. Set the
# DEEPAUTOQC_DATA_DIR environment variable instead of editing this absolute path.
DATA_DIR = os.environ.get("DEEPAUTOQC_DATA_DIR", "/Volumes/PortableSSD/data")

# NOTE(security): load_from_pickle presumably unpickles these files, and unpickling
# can execute arbitrary code. Only load files from a trusted source.
# "bad" = a report from the unusable set, "good" = a report from the usable set.
bad_subject_data = load_from_pickle(
    os.path.join(
        DATA_DIR,
        "skullstrip_rpt_processed_unusable",
        "ds-hcp_sub-122620_skull_strip_report_ds-hcp_sub-122620_report-skull.pkl",
    )
)
good_subject_data = load_from_pickle(
    os.path.join(
        DATA_DIR,
        "skullstrip_rpt_processed_usable",
        "ds-pnc_chunk-9_reports_sub-607733289129_report-skull.pkl",
    )
)

In [None]:
# Reconstruction error for each scan loaded from the "unusable" pickle.
for scan in bad_subject_data:
    scan_tensor = load_to_tensor(img=scan.img)
    recon_error = is_anomaly(model, image=scan_tensor, threshold=0)
    print(recon_error.item())

In [None]:
# Reconstruction error for each scan loaded from the "usable" pickle.
for scan in good_subject_data:
    scan_tensor = load_to_tensor(img=scan.img)
    recon_error = is_anomaly(model, image=scan_tensor, threshold=0)
    print(recon_error.item())

In [None]:
# Original vs. reconstruction, side by side, for each "unusable" scan.
for scan in bad_subject_data:
    display_reconstruction(model=model, image=load_to_tensor(img=scan.img))

In [None]:
# Original vs. reconstruction, side by side, for each "usable" scan.
for scan in good_subject_data:
    display_reconstruction(model=model, image=load_to_tensor(img=scan.img))

In [17]:
import umap

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [18]:
# Collector for per-scan encoder outputs; re-run this cell before re-encoding,
# otherwise the encoding cells append duplicates.
encoder_outputs = list()

In [19]:
# Encode every "unusable" scan and append its latent code to encoder_outputs.
# This cell runs before the "usable" one, so these rows come first.
# Re-running it without resetting encoder_outputs appends duplicates.
model.eval()  # explicit: don't rely on another cell having switched to inference mode
for scan in bad_subject_data:
    batch = load_to_tensor(img=scan.img).to(model.device).unsqueeze(0)
    with torch.no_grad():
        latent = model.encoder(batch)
    encoder_outputs.append(latent.cpu().numpy())
    

In [20]:
# Encode every "usable" scan and append its latent code to encoder_outputs,
# after the rows from the "unusable" scans.
# Re-running it without resetting encoder_outputs appends duplicates.
model.eval()  # explicit: don't rely on another cell having switched to inference mode
for scan in good_subject_data:
    batch = load_to_tensor(img=scan.img).to(model.device).unsqueeze(0)
    with torch.no_grad():
        latent = model.encoder(batch)
    encoder_outputs.append(latent.cpu().numpy())

In [21]:
# Stack the per-scan outputs (each with a leading batch axis of 1) into one array,
# one row per scan. Guarded so re-running this cell is a no-op: concatenating the
# already-stacked array again would merge its first two axes (a 2-D array becomes 1-D).
if isinstance(encoder_outputs, list):
    encoder_outputs = np.concatenate(encoder_outputs, axis=0)

In [None]:
# Project the latent codes to 2-D. A fixed random_state makes the embedding
# reproducible across runs, at the cost of umap-learn disabling parallelism.
UMAP_SEED = 42
reducer = umap.UMAP(random_state=UMAP_SEED)
embedding = reducer.fit_transform(encoder_outputs)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [None]:
# Plot the 2-D UMAP embedding, coloured by which pickle each scan came from.
# The encoding cells append all "unusable" scans first, then all "usable" ones.
# NOTE(review): this assumes one embedding row per scan. If the counts don't
# match, fall back to a single colour instead of mislabelling points.
n_bad = sum(1 for _scan in bad_subject_data)
n_good = sum(1 for _scan in good_subject_data)

fig, ax = plt.subplots()
if n_bad + n_good == len(embedding):
    ax.scatter(embedding[:n_bad, 0], embedding[:n_bad, 1], label="unusable")
    ax.scatter(embedding[n_bad:, 0], embedding[n_bad:, 1], label="usable")
    ax.legend()
else:
    ax.scatter(embedding[:, 0], embedding[:, 1])
ax.set_aspect('equal', 'datalim')
ax.set_title('UMAP projection of the data', fontsize=24)
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
plt.show()