In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import h5py
import patchify

from utils.lodopab_dataset import LodopabDataset

In [None]:
# Load the ground-truth images of the first LoDoPaB training file.
# Reading the array inside a `with` block closes the HDF5 handle immediately
# instead of leaving it open for the rest of the kernel session; the rest of
# the notebook only needs `.shape`, indexing and reductions, which an in-memory
# array supports directly.
file_path = '../datasets/LoDoPaB/ground_truth_train/ground_truth_train_000.hdf5'
with h5py.File(file_path, 'r') as f:
    dataset = f['data'][()]  # `[()]` reads the whole HDF5 dataset into a numpy array
dataset.shape

In [None]:
# Put the top-left 64x8 region of the first image next to the full image.
first_image = dataset[0]
panels = [
    (first_image[:64, :8], 'Slice'),
    (first_image, 'Full example'),
]
fig, ax = plt.subplots(1, 2)
for panel_ax, (image, title) in zip(ax, panels):
    panel_ax.imshow(image)
    panel_ax.grid(False)
    panel_ax.set_title(title)
fig.suptitle('First example')
plt.show()

In [None]:
# Summarise the value range and layout of the loaded images.
# Materialise the pixels once: if `dataset` is a lazy h5py Dataset, calling
# np.min and np.max on it separately would read the whole file twice.
pixel_values = np.asarray(dataset)
# float() gives a plain, readable repr instead of numpy scalar reprs.
data_range = (float(pixel_values.min()), float(pixel_values.max()))
print(f'data range: {data_range}')
print(f'num examples: {dataset.shape[0]}')
print(f'image size: {dataset.shape[1:]}')

In [None]:
P = 8  # side length of the square patches
N = dataset.shape[-1]  # length of the image's last axis
# step=1 gives fully overlapping windows: neighbouring patches are offset by one pixel.
patches = patchify.patchify(dataset[0], patch_size=(P, P), step=1)

In [None]:
# Display the first GRID_SIDE**2 patches in row-major order. This is the same
# order that flattening the patch grid with `reshape(-1, P, P)` produces.
GRID_SIDE = 8  # the figure is a GRID_SIDE x GRID_SIDE grid of thumbnails
n_patch_cols = patches.shape[1]

fig, ax = plt.subplots(GRID_SIDE, GRID_SIDE)
for k, thumb_ax in enumerate(ax.flat):
    # Map the flat index k back to its (row, col) position in the patch grid.
    # This avoids flattening (and possibly copying) the whole patch array.
    row, col = divmod(k, n_patch_cols)
    thumb_ax.imshow(patches[row, col])
    thumb_ax.set_xticks([])
    thumb_ax.set_yticks([])
fig.suptitle(f'First {GRID_SIDE * GRID_SIDE} patches')
plt.show()

In [None]:
# Flatten the (rows, cols) grid of patches into a single row-major list.
flat_patches = patches.reshape(-1, P, P)
n_patches_per_image = flat_patches.shape[0]

# With stride 1, an H x W image yields (H - P + 1) * (W - P + 1) patches.
# Using both axes keeps the formula correct for non-square images too.
n_images, height, width = dataset.shape
n_patches_calc = (height - P + 1) * (width - P + 1)

print(f'calculated number of patches per image: {n_patches_calc}')
print(f'actual number of patches per image: {n_patches_per_image}')
# Use the real number of images in the file rather than assuming 128.
print(f'total number of patches: {n_images * n_patches_per_image}')
assert n_patches_calc == n_patches_per_image, 'patch count differs from (H-P+1)*(W-P+1)'

In [None]:
# Size each batch to exactly one image's worth of patches, so the first batch
# can be reassembled into the first image below. Deriving batch_size from the
# patch count (instead of the literal 126025) keeps this right if P changes.
# NOTE(review): this assumes LodopabDataset yields patches image by image in
# the same row-major order as patchify. Confirm against utils/lodopab_dataset.py.
lodo_dataset = LodopabDataset(file_path=file_path, patch_size=P, print_output=True)
lodo_loader = DataLoader(lodo_dataset, batch_size=n_patches_per_image, shuffle=False)

In [None]:
# Arrange the first batch as the (rows, cols, P, P) patch grid that
# patchify.unpatchify expects. The grid size comes from the patchify call
# above instead of the hard-coded 355. Reshaping straight to 4-D also avoids
# .squeeze(), which would silently drop axes as well if P were 1.
first_batch = next(iter(lodo_loader))
n_rows, n_cols = patches.shape[:2]
loaded_patches = first_batch.reshape(n_rows, n_cols, P, P).numpy()
loaded_patches.shape

In [None]:
# Stitch the overlapping patches back into an image with the original size,
# taken from the data instead of the hard-coded (362, 362).
reconstructed_image = patchify.unpatchify(loaded_patches, tuple(dataset.shape[1:]))

fig, ax = plt.subplots()
ax.imshow(reconstructed_image)
ax.grid(False)
fig.suptitle('Reconstruction')
plt.show()