In [3]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchio as tio

#from deepautoqc.ae_architecture import Autoencoder, Encoder, Decoder
from deepautoqc.ae_arch2 import Encoder_AE, Decoder_AE, Autoencoder
from deepautoqc.data_structures import load_from_pickle, BrainScan

In [4]:
def load_to_tensor(img: np.ndarray, target_shape: tuple = (3, 704, 800)) -> torch.Tensor:
    """Convert an image array into a float tensor with a fixed (C, H, W) shape.

    The array is transposed to channel-first, then cropped and/or padded with torchio's
    ``CropOrPad`` so that every image fed to the autoencoder has the same shape.

    Args:
        img: Image array, assumed to be laid out as (H, W, C) — the transpose below
            relies on this (TODO confirm against ``BrainScan.img``).
        target_shape: Output shape (C, H, W). The default matches the shape this
            notebook previously hard-coded.

    Returns:
        A float32 tensor of shape ``target_shape``.
    """
    transform = tio.CropOrPad(target_shape)
    chw = torch.from_numpy(img.transpose((2, 0, 1)))
    # torchio images are 4-D (channels, x, y, z). Prepend a singleton channel so the
    # whole (C, H, W) block is treated as the spatial extent to crop/pad.
    scalar_image = tio.ScalarImage(tensor=chw[None])
    resized = transform(scalar_image)
    # Drop the singleton torchio channel again.
    return resized.data[0].float()


In [5]:
def _to_display_array(chw: np.ndarray) -> np.ndarray:
    """Turn a (C, H, W) array into an array that ``plt.imshow`` can render directly."""
    if chw.shape[0] == 1:
        # Single channel: drop the channel axis and let imshow apply the gray colormap.
        return chw[0]
    # Multi-channel: imshow wants (H, W, C). It clips float RGB data to [0, 1] anyway,
    # so clip explicitly; this gives the same picture without the clipping warning.
    return np.clip(chw.transpose((1, 2, 0)), 0.0, 1.0)


def display_reconstruction(model, image):
    """Show ``image`` side by side with the autoencoder's reconstruction of it.

    Also prints the min/max of both arrays, which helps spot range mismatches between
    input and output.

    Args:
        model: Autoencoder exposing a ``.device`` attribute. It is switched to eval mode
            as a side effect.
        image: Tensor of shape (C, H, W), or (1, C, H, W) if it is already batched.
    """
    model.eval()

    image = image.to(model.device)
    # Add a batch dimension only when the caller passed an unbatched (C, H, W) image.
    if image.dim() == 3:
        image = image.unsqueeze(0)

    with torch.no_grad():
        reconstructed_image = model(image)

    # Drop the batch dimension to get (C, H, W) numpy arrays.
    original_image_np = image[0].cpu().numpy()
    reconstructed_image_np = reconstructed_image[0].cpu().numpy()

    print("ORIG", original_image_np.min(), original_image_np.max())
    print("REC", reconstructed_image_np.min(), reconstructed_image_np.max())

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (axes[0], original_image_np, 'Original Image'),
        (axes[1], reconstructed_image_np, 'Reconstructed Image'),
    )
    for ax, array, title in panels:
        ax.imshow(_to_display_array(array), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # This function is called in loops over many images, so release each figure.
    plt.close(fig)

# Example usage:
# image = load_to_tensor(img=subject.img)  # (C, H, W) tensor
# display_reconstruction(model, image)


In [6]:
def is_anomaly(model, image, threshold):
    """Return the reconstruction MSE of ``image`` under ``model``.

    Despite the name, this function makes no classification. ``threshold`` is accepted
    for call compatibility but is currently unused, and the raw error is returned.

    NOTE(review): in this notebook's recorded outputs, the unusable scans score *lower*
    MSE than the usable ones. A rule like ``error > threshold -> anomaly`` would therefore
    invert the labels here; confirm this before turning the function into a classifier.

    Args:
        model: Autoencoder exposing a ``.device`` attribute. It is switched to eval mode.
        image: Tensor of shape (C, H, W), or (1, C, H, W) if it is already batched.
        threshold: Unused.

    Returns:
        A 0-dim tensor holding the mean squared error between input and reconstruction.
    """
    model.eval()

    image = image.to(model.device)
    # Add the batch axis only when it is missing. Unsqueezing an already-batched input
    # would hand the model a 5-D tensor.
    if image.dim() == 3:
        image = image.unsqueeze(0)

    with torch.no_grad():
        reconstructed_image = model(image)
        error = torch.nn.functional.mse_loss(image, reconstructed_image)

    return error


In [7]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# NOTE(review): latent_dim is not referenced by any executed code in this notebook. Its
# only use is the commented-out Autoencoder construction below; the model is instead
# restored from a checkpoint without it.
latent_dim = 64
#model = Autoencoder(base_channel_size=32, latent_dim=latent_dim)

In [8]:
# Checkpoint of the trained ae_arch2 autoencoder. To use a different file, set the
# DEEPAUTOQC_CKPT environment variable instead of editing this machine-specific default.
# (An earlier run used ckpts/AE_384-22-1886.ckpt in the same directory.)
ckpt_path = os.environ.get(
    "DEEPAUTOQC_CKPT",
    "/Users/Dominik/Charite/DeepAutoQC/src/deepautoqc/ckpts/autoencoder-epoch=19-val_loss=0.00.ckpt",
)

In [9]:
# Restore the trained autoencoder directly onto `device`. No data module is passed,
# because this notebook only runs inference.
model = Autoencoder.load_from_checkpoint(
    ckpt_path,
    map_location=device,
    data_module=None,
)

In [10]:
# Display the module tree of the loaded autoencoder (encoder/decoder layers).
model

Autoencoder(
  (encoder): Encoder_AE(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (decoder): Decoder_AE(
    (deconv1): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (deconv2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (deconv3): ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [11]:
# Root of the processed skull-strip reports. Set DEEPAUTOQC_DATA_DIR to point elsewhere;
# the default is the original external-drive location.
DATA_DIR = os.environ.get("DEEPAUTOQC_DATA_DIR", "/Volumes/PortableSSD/data")

# NOTE(review): load_from_pickle presumably unpickles these files, so only load trusted data.
bad_subject_data = load_from_pickle(os.path.join(
    DATA_DIR,
    "skullstrip_rpt_processed_unusable",
    "ds-hcp_sub-122620_skull_strip_report_ds-hcp_sub-122620_report-skull.pkl",
))
good_subject_data = load_from_pickle(os.path.join(
    DATA_DIR,
    "skullstrip_rpt_processed_usable",
    "ds-pnc_chunk-9_reports_sub-607733289129_report-skull.pkl",
))

In [26]:
# Reconstruction MSE for every image in the unusable report. Each value is printed and
# also kept in `errors_bad` so it can be compared with the usable report's errors.
errors_bad = []
for subject in bad_subject_data:
    img_tensor = load_to_tensor(img=subject.img)
    error = is_anomaly(model, image=img_tensor, threshold=0)
    errors_bad.append(error.item())
    print(errors_bad[-1])

0.004826251417398453
0.005108018405735493
0.005246909335255623
0.004962102044373751
0.004873546306043863
0.004788018297404051
0.004441829398274422
0.004180642776191235
0.004379771649837494
0.004226073622703552
0.0033179151359945536
0.004206518642604351
0.004475178197026253
0.00424937903881073
0.003019413910806179
0.0039051014464348555
0.0038943735416978598
0.003423169255256653
0.0033993409015238285
0.0028970001731067896
0.0021845384035259485


In [27]:
# Reconstruction MSE for every image in the usable report. Each value is printed and
# also kept in `errors_good` so it can be compared with the unusable report's errors.
errors_good = []
for subject in good_subject_data:
    img_tensor = load_to_tensor(img=subject.img)
    error = is_anomaly(model, image=img_tensor, threshold=0)
    errors_good.append(error.item())
    print(errors_good[-1])

0.0069738030433654785
0.00784993451088667
0.008321424946188927
0.00819751899689436
0.00839664600789547
0.008240921422839165
0.006975465454161167
0.006753566209226847
0.008451194502413273
0.008425210602581501
0.006614774465560913
0.008393190801143646
0.008549758233129978
0.006958923768252134
0.006926208268851042
0.009643963538110256
0.010753598995506763
0.0090930862352252
0.008254941552877426
0.0068747722543776035
0.00451842462643981


In [None]:
# Visual check: input vs. reconstruction for every image of the unusable report.
for subject in bad_subject_data:
    display_reconstruction(model=model, image=load_to_tensor(img=subject.img))

In [None]:
# Visual check: input vs. reconstruction for every image of the usable report.
for subject in good_subject_data:
    display_reconstruction(model=model, image=load_to_tensor(img=subject.img))

In [12]:
import umap

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [19]:
# Fresh, independent accumulators for encoder feature maps. Re-run this cell before the
# extraction cells, which append to these lists.
encoder_outputs, encoder_outputs_good, encoder_outputs_bad = [], [], []

In [20]:
# Put the encoder's BatchNorm layers in inference mode. In train mode they would normalise
# with per-image statistics and also overwrite their running mean/var, even under no_grad.
model.eval()
for subject in bad_subject_data:
    img_tensor = load_to_tensor(img=subject.img).to(model.device).unsqueeze(0)
    with torch.no_grad():
        encoder_output = model.encoder(img_tensor)
    # NOTE(review): each map is (1, 128, 176, 200), about 18 MB as float32. Pooling it here
    # instead of in the later averaging cell would cut memory use considerably.
    encoder_outputs_bad.append(encoder_output.cpu().numpy())
    

In [21]:
# Put the encoder's BatchNorm layers in inference mode. In train mode they would normalise
# with per-image statistics and also overwrite their running mean/var, even under no_grad.
model.eval()
for subject in good_subject_data:
    img_tensor = load_to_tensor(img=subject.img).to(model.device).unsqueeze(0)
    with torch.no_grad():
        encoder_output = model.encoder(img_tensor)
    # NOTE(review): each map is (1, 128, 176, 200), about 18 MB as float32. Pooling it here
    # instead of in the later averaging cell would cut memory use considerably.
    encoder_outputs_good.append(encoder_output.cpu().numpy())

In [22]:
# Turn each list of per-image (1, C, H, W) maps into one (n_images, 1, C, H, W) array.
# NOTE(review): this rebinds the list names to arrays, so the extraction cells cannot
# append again until the accumulator cell is re-run.
encoder_outputs_bad, encoder_outputs_good = (
    np.array(encoder_outputs_bad),
    np.array(encoder_outputs_good),
)
for stacked in (encoder_outputs_good, encoder_outputs_bad):
    print(stacked.shape)
print(encoder_outputs_bad.shape)


(21, 1, 128, 176, 200)
(21, 1, 128, 176, 200)


In [23]:
# Unusable-report images come first and usable ones second; later labelling relies on this order.
feature_parts = (encoder_outputs_bad, encoder_outputs_good)
encoder_outputs = np.vstack(feature_parts)
encoder_outputs.shape

(42, 1, 128, 176, 200)

In [25]:
# Average over the two spatial axes, then drop the singleton per-image batch axis,
# giving one channel-wise feature vector per image.
encoder_outputs_avg = encoder_outputs.mean(axis=(3, 4)).squeeze(axis=1)
encoder_outputs_avg.shape

(42, 128)

In [29]:
# Fix the seed so the embedding, and any plot made from it, is reproducible across runs.
# With random_state set, umap-learn runs single-threaded, which is fine for a few dozen points.
UMAP_SEED = 42
reducer = umap.UMAP(random_state=UMAP_SEED)
embedding = reducer.fit_transform(encoder_outputs_avg)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


: 

In [None]:
# Plot the result
plt.scatter(embedding[:, 0], embedding[:, 1])
plt.gca().set_aspect('equal', 'datalim')
plt.title('UMAP projection of the data', fontsize=24)
plt.show()