In [None]:
# Plot the sum of intensities of each slice vs. slice number

In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

# Make the project's `src/` directory importable. Set the CAIPI_SRC_DIR
# environment variable to point elsewhere instead of editing this cell.
SRC_DIR = os.environ.get('CAIPI_SRC_DIR', '/home/quahb/caipi_denoising/src')
sys.path.insert(1, SRC_DIR)

from preparation.gen_data import get_train_data

In [None]:
# Load the training data. Only X is used by the cells below; y is unused here.
# NOTE(review): X looks 4-D (slice, H, W, channel) judging by the
# `[0,:,:,0]` indexing in the reconstruction cells — confirm against get_train_data.
X, y = get_train_data()

In [None]:
# Show the dimensions of the loaded data array.
X.shape

# Reconstruction Artifacts from Patch Extraction and Reassembly

## 1. Using `sklearn`

In [None]:
from preparation.preprocessing_pipeline import pad_square

from sklearn.feature_extraction.image import extract_patches_2d, reconstruct_from_patches_2d

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

In [None]:
# Extract every overlapping 256x256 patch (stride 1) from one padded slice and
# stitch them back together; ideally this round trip reproduces the slice.
# NOTE(review): stride-1 extraction yields (S - 255)**2 patches for an S x S
# slice — 16,641 patches (~1.1e9 values) when S = 384 — so this cell is
# very memory hungry.
slice_i = 150
PATCH_SIZE = (256, 256)

slc = pad_square(np.expand_dims(X[slice_i], 0))[0, :, :, 0]
patches = extract_patches_2d(slc, PATCH_SIZE)
print(patches.shape)

# Rebuild to the padded slice's own shape instead of a hard-coded (384, 384),
# so this keeps working if pad_square produces a different size.
recon = reconstruct_from_patches_2d(patches, slc.shape)

In [None]:
# Pixel-level fidelity of the round trip. scikit-image cannot infer a
# meaningful dynamic range for float images (newer releases of
# structural_similarity refuse to guess), so pass the reference range explicitly.
data_range = float(np.max(slc) - np.min(slc))
print(f'PSNR: {psnr(slc, recon, data_range=data_range)}')
print(f'SSIM: {ssim(slc, recon, data_range=data_range)}')

# Original, reconstruction, and their sum/difference, each annotated with
# summary statistics; the difference panel exposes reconstruction artifacts.
panels = [
    (f'Slice {slice_i}', slc),
    ('Reconstructed', recon),
    ('Add Reconstructed', slc + recon),
    ('Subtract Reconstructed', slc - recon),
]

figure, axis = plt.subplots(2, 2, figsize=(18, 18))
for ax, (title, im) in zip(axis.flat, panels):
    ax.imshow(im, cmap='gray')
    ax.set_title(title)
    ax.set(xlabel=f'min: {np.min(im):.3f}, max: {np.max(im):.3f}, '
                  f'mean: {np.mean(im):.3f}, std: {np.std(im):.3f}')

plt.show()

## 2. Using `patchify`

In [None]:
from patchify import patchify, unpatchify

In [None]:
# Split the same padded slice into a 2-D grid of 256x256 patches with a
# 128-pixel step (50% overlap), then stitch them back with `unpatchify`.
slice_i = 150
PATCH_SIZE = (256, 256)
PATCH_STEP = (128, 128)

slc = pad_square(np.expand_dims(X[slice_i], 0))[0, :, :, 0]
print(slc.shape)

# unpatchify can only rebuild the image when the patch grid tiles it exactly,
# i.e. (size - patch) is a multiple of the step along each axis.
assert all((size - patch) % step == 0
           for size, patch, step in zip(slc.shape, PATCH_SIZE, PATCH_STEP)), \
    f'{slc.shape} cannot be tiled exactly by {PATCH_SIZE} patches with step {PATCH_STEP}'

patches = patchify(slc, PATCH_SIZE, step=PATCH_STEP)
print(patches.shape)

# Rebuild to the padded slice's own shape instead of a hard-coded (384, 384).
recon = unpatchify(patches, slc.shape)

In [None]:
# Show every patch produced by patchify, laid out in its grid position.
# The grid size comes from the patches array rather than assuming 2x2.
n_rows, n_cols = patches.shape[:2]
figure, axis = plt.subplots(n_rows, n_cols, figsize=(9 * n_cols, 9 * n_rows), squeeze=False)
for r in range(n_rows):
    for c in range(n_cols):
        axis[r, c].imshow(patches[r, c], cmap='gray')
        axis[r, c].set_title(f'Patch ({r}, {c})')

plt.show()

In [None]:
# Pixel-level fidelity of the patchify round trip. scikit-image cannot infer a
# meaningful dynamic range for float images (newer releases of
# structural_similarity refuse to guess), so pass the reference range explicitly.
data_range = float(np.max(slc) - np.min(slc))
print(f'PSNR: {psnr(slc, recon, data_range=data_range)}')
print(f'SSIM: {ssim(slc, recon, data_range=data_range)}')

# Original, reconstruction, and their sum/difference, each annotated with
# summary statistics; the difference panel exposes reconstruction artifacts.
panels = [
    (f'Slice {slice_i}', slc),
    ('Reconstructed', recon),
    ('Add Reconstructed', slc + recon),
    ('Subtract Reconstructed', slc - recon),
]

figure, axis = plt.subplots(2, 2, figsize=(18, 18))
for ax, (title, im) in zip(axis.flat, panels):
    ax.imshow(im, cmap='gray')
    ax.set_title(title)
    ax.set(xlabel=f'min: {np.min(im):.3f}, max: {np.max(im):.3f}, '
                  f'mean: {np.mean(im):.3f}, std: {np.std(im):.3f}')

plt.show()

# Raw intensities

In [None]:
# Total raw intensity of each slice, one entry per element of X.
intensities = list(map(np.sum, X))

In [None]:
# Sum of raw intensities per slice across the whole dataset, with two
# reference levels drawn as horizontal lines.
# NOTE(review): the meaning of these two levels is not stated here — TODO document.
LOW_THRESHOLD, HIGH_THRESHOLD = 2_900_000, 5_500_000

plt.figure(figsize=(25, 10))
plt.plot(range(len(X)), intensities, label='slice intensity sum')
plt.axhline(LOW_THRESHOLD, color='tab:orange', label=f'{LOW_THRESHOLD:,}')
plt.axhline(HIGH_THRESHOLD, color='tab:green', label=f'{HIGH_THRESHOLD:,}')
plt.title('Sum of raw slice intensities')
plt.xlabel('Slice Number')
plt.ylabel('Sum of slice intensity')
plt.legend()
plt.show()

In [None]:
# Inspect one slice of one subject. The flat index assumes each subject
# contributes 256 consecutive slices of X (TODO confirm against get_train_data).
slc_i = 220
subj_i = 30
slices_per_subj = 256

flat_i = slices_per_subj * subj_i + slc_i
# Drop any singleton channel axis so imshow always receives a 2-D array.
slice_img = np.squeeze(X[flat_i])
slice_sum = np.sum(slice_img)
print(slice_sum)

plt.imshow(slice_img, cmap='gray')
plt.title(f'Subject {subj_i}, slice {slc_i} (sum = {slice_sum:.1f})')
plt.show()

In [None]:
# Grid of per-subject curves: each panel plots the raw intensity sum of every
# slice of one subject. Assumes each subject contributes 256 consecutive
# slices of X (TODO confirm against get_train_data).
rows = 5
cols = 5
slices_per_subj = 256

figure, axis = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
for img_i, ax in enumerate(axis.flat):
    start = img_i * slices_per_subj
    curve = intensities[start:start + slices_per_subj]
    if len(curve) == 0:
        # Fewer subjects than panels: blank the spare panel instead of failing
        # on an x/y length mismatch.
        ax.axis('off')
        continue
    ax.plot(range(len(curve)), curve)
    ax.set_title(f'Subject {img_i}')

plt.show()

# Normalized Intensities (global z-score over all slices)

In [None]:
# Global z-score: a single mean/std computed over every value in X, then the
# total normalized intensity of each slice.
mean = np.mean(X)
std = np.std(X)
norm_X = (X - mean) / std
norm_intensities = list(map(np.sum, norm_X))

In [None]:
# Sum of globally z-scored intensities per slice, with a reference level.
# The level is drawn with axhline instead of a third positional data argument to
# plt.plot, which matplotlib only accepts by re-parsing it as a separate series.
# NOTE(review): the meaning of the -50,000 level is undocumented — TODO explain.
NORM_THRESHOLD = -50_000

plt.figure(figsize=(25, 10))
plt.title('Sum of normalized slice intensities')
plt.plot(range(len(X)), norm_intensities, label='slice intensity sum')
plt.axhline(NORM_THRESHOLD, color='tab:orange', label=f'{NORM_THRESHOLD:,}')
plt.xlabel('Slice Number')
plt.ylabel('Sum of normalized slice intensity')
plt.legend()
plt.show()

In [None]:
# Grid of per-subject curves: each panel plots the z-scored intensity sum of
# every slice of one subject. Assumes each subject contributes 256 consecutive
# slices of X (TODO confirm against get_train_data).
rows = 5
cols = 5
slices_per_subj = 256

figure, axis = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
for img_i, ax in enumerate(axis.flat):
    start = img_i * slices_per_subj
    curve = norm_intensities[start:start + slices_per_subj]
    if len(curve) == 0:
        # Fewer subjects than panels: blank the spare panel instead of failing
        # on an x/y length mismatch.
        ax.axis('off')
        continue
    ax.plot(range(len(curve)), curve)
    ax.set_title(f'Subject {img_i}')

plt.show()

# Rescale Intensities by Subject (per-subject min-max scaling to [0, 1])

In [None]:
# Per-subject min-max scaling to [0, 1]: each block of 256 consecutive slices
# (assumed to be one subject — TODO confirm) is rescaled with its own min/max.
slices_per_subj = 256

std_X = np.zeros(X.shape)
# Step over every block, including a trailing partial one; the previous
# int(len(X) / 256) loop silently left such slices at zero.
for start in range(0, len(X), slices_per_subj):
    stop = start + slices_per_subj
    subj_vol = X[start:stop]
    min_val, max_val = np.min(subj_vol), np.max(subj_vol)
    den = max_val - min_val
    if den == 0:
        # A constant volume has no range to scale by; leave it at zero
        # instead of filling it with NaN from 0/0.
        continue
    std_X[start:stop] = (subj_vol - min_val) / den

std_intensities = [np.sum(slc) for slc in std_X]
min(std_intensities)

In [None]:
# How many slices have a min-max-scaled intensity sum of at least 5000.
count = sum(1 for total in std_intensities if total >= 5000)
count

In [None]:
# Sum of per-subject min-max-scaled intensities per slice, with a reference
# level drawn via axhline rather than as a third positional data argument to
# plt.plot. NOTE(review): the meaning of the 5000 level is undocumented — TODO.
SCALED_THRESHOLD = 5000

plt.figure(figsize=(25, 10))
plt.plot(range(len(X)), std_intensities, label='slice intensity sum')
plt.axhline(SCALED_THRESHOLD, color='tab:orange', label=f'{SCALED_THRESHOLD:,}')
plt.title('Sum of per-subject min-max scaled slice intensities')
plt.xlabel('Slice Number')
plt.ylabel('Sum of slice intensity')
plt.legend()
plt.show()

In [None]:
# Grid of per-subject curves: each panel plots the min-max-scaled intensity sum
# of every slice of one subject. Assumes each subject contributes 256
# consecutive slices of X (TODO confirm against get_train_data).
rows = 5
cols = 5
slices_per_subj = 256

figure, axis = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
for img_i, ax in enumerate(axis.flat):
    start = img_i * slices_per_subj
    curve = std_intensities[start:start + slices_per_subj]
    if len(curve) == 0:
        # Fewer subjects than panels: blank the spare panel instead of failing
        # on an x/y length mismatch.
        ax.axis('off')
        continue
    ax.plot(range(len(curve)), curve)
    ax.set_title(f'Subject {img_i}')

plt.show()