## Apply Shifter to Dataset

In [None]:
import numpy as np
from common.image_preprocessing import data_all, data_temp, temp_from_original
from pathlib import Path, PurePath

# Side length, in pixels, of the square working resolution; it names the
# "{sz}x{sz}" sub-folders used throughout this notebook.
sz = 128
# Source directory: the 'no_finding' class of the chest-nihcc image set.
cxr8_original_path = data_all / 'chest-nihcc' / 'by_class' / 'no_finding'
# Directory holding saved model weights for this resolution, as derived by
# temp_from_original (presumably a mirror of the source path under a temp
# tree — confirm against common.image_preprocessing).
weights_cxr8_temp = temp_from_original(
    cxr8_original_path,
    PurePath(f"{sz}x{sz}", 'weights'),
)

In [None]:
from cxr_projection.zebrastack_v0_model import create_autoencoder
# Build the three models returned by create_autoencoder(), in the order
# (encoder, decoder, full autoencoder).
(encoder, decoder, autoencoder) = create_autoencoder()

In [None]:
# Load the most recent "rgc" weights for the encoder, decoder and autoencoder.
# "Most recent" means the lexicographically last filename, which matches
# chronological order only if the filename prefix sorts by time (e.g. a
# zero-padded timestamp) — TODO confirm the naming scheme.


def _latest_weights(directory, pattern):
    """Return the lexicographically last file in *directory* matching *pattern*.

    Raises:
        FileNotFoundError: if no file matches. This replaces the bare
            IndexError that indexing an empty sorted list would raise.
    """
    matches = sorted(directory.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no file matching {pattern!r} in {directory}")
    return matches[-1]


# Resolve every checkpoint before loading any of them, so that a missing file
# fails fast instead of leaving the models partially restored.
encode_fn = _latest_weights(weights_cxr8_temp, "*_rgc_encoder.h5")
decode_fn = _latest_weights(weights_cxr8_temp, "*_rgc_decoder.h5")
autoencoder_fn = _latest_weights(weights_cxr8_temp, "*_rgc_autoencoder.h5")

# str(): some older tf.keras releases call filepath.endswith(...) inside
# load_weights and so reject pathlib.Path arguments.
print("Loading ", encode_fn)
encoder.load_weights(str(encode_fn))

print("Loading ", decode_fn)
decoder.load_weights(str(decode_fn))

print("Loading ", autoencoder_fn)
autoencoder.load_weights(str(autoencoder_fn))

In [None]:
import time
# Directory of CLAHE-processed .npy images at the working resolution.
cxr8_temp = temp_from_original(cxr8_original_path, PurePath(f"{sz}x{sz}") / 'clahe_processed')

# Maps file stem -> image array with an explicit trailing channel axis.
processed_imgs = {}

# perf_counter is monotonic and intended for measuring elapsed time;
# time.time() can jump if the wall clock is adjusted.
start_time = time.perf_counter()
# Path.glob yields entries in arbitrary, filesystem-dependent order. Sorting
# makes the dict insertion order, and hence the order in which later cells
# iterate the images, reproducible.
for npy_filepath in sorted(cxr8_temp.glob('*.npy')):
    img = np.load(npy_filepath)
    # Reshape to (H, W, 1). np.reshape raises if the array holds more than
    # H*W elements, i.e. if it is not single-channel.
    img = np.reshape(img, (img.shape[0], img.shape[1], 1))
    processed_imgs[npy_filepath.stem] = img
    # Lightweight progress indicator, rewritten in place every 100 files.
    if len(processed_imgs) % 100 == 0:
        print(npy_filepath.stem, img.shape, end='\r')
end_time = time.perf_counter()

# End the '\r' progress line so the summary is not printed over it.
print()
print(f"Loaded {len(processed_imgs)} npy in {end_time - start_time:.2f} seconds")

In [None]:
from common.shifter_ops import shift

# Maps file stem -> result of shift(img, encoder). The dict is rebuilt from
# scratch every time this cell runs.
processed_imgs_shifted = {}
# Keys of processed_imgs are unique and the output dict starts empty, so each
# image is shifted exactly once; no membership check is needed.
for name, img in processed_imgs.items():
    processed_imgs_shifted[name] = shift(img, encoder)
    # NOTE(review): the bare print() presumably terminates any progress output
    # written by shift() — confirm before removing it.
    print()
    print(name, end='\r')