In [None]:
# Import necessary modules of astroCAST
from astroCAST.denoising import SubFrameGenerator, Network
from astroCAST.analysis import Video
from pathlib import Path

%load_ext memory_profiler

In [None]:
import os

# Root folder holding the raw recording and the trained model. Defaults to the
# original author's local data folder; set ASTROCAST_DATA_DIR to run elsewhere.
folder = Path(os.environ.get("ASTROCAST_DATA_DIR", "/media/janrei1/data/method_testing/"))

# Input recording (your file name here); 'h5' or 'tdb' is recommended for parallel processing.
input_file = folder.joinpath("22A7x9-1.h5")
assert input_file.is_file(), f"input file not found: {input_file}"

# Directory the trained model is saved to / loaded from.
model_path = folder.joinpath("model")

# Base path for the denoised output: 22A7x9-1.h5 -> 22A7x9-1.denoised.h5.
infer_path = input_file.with_suffix(".denoised.h5")
# Dataset location passed to the inference output.
out_loc = "inf/ch0"

In [None]:
# Data-loading options shared by the training, validation and inference generators below.
param = {
    "paths": input_file,
    "loc": "data/ch0",
    "input_size": (256, 256),
    "pre_post_frame": 5,
    "gap_frames": 0,
    "normalize": "global",
    "in_memory": True,
}

# Train model

In [None]:
# Train only when no saved model (*.h5) exists yet; delete `model_path` to force retraining.
# any() stops at the first match instead of materialising the whole glob listing.
has_saved_model = model_path.is_dir() and any(model_path.glob("*.h5"))

if not has_saved_model:
    # Training generator. allowed_rotation / allowed_flip presumably select the random
    # augmentations applied to each sample -- confirm against the astroCAST docs.
    train_gen = SubFrameGenerator(padding=None, batch_size=64, max_per_file=640,
                                  allowed_rotation=[1, 2, 3], allowed_flip=[0, 1],
                                  shuffle=True, **param)

    # Validation generator: small (max_per_file=16) and cached; rotation [0] / flip [-1]
    # presumably mean "no augmentation" -- TODO confirm.
    val_gen = SubFrameGenerator(padding=None, batch_size=16, max_per_file=16, cache_results=True,
                                allowed_rotation=[0], allowed_flip=[-1],
                                shuffle=True, **param)

    net = Network(train_generator=train_gen, val_generator=val_gen,
                  n_stacks=3, kernel=32, batchNormalize=False, use_cpu=True)
    # patience / min_delta look like early-stopping settings; the model is written to model_path.
    net.run(batch_size=1,
            num_epochs=50,
            patience=2, min_delta=0.01,
            save_model=model_path, load_weights=True)


# Denoise

In [None]:
def test_infer(output=None, z_select=(0, 25)):
    """Denoise the input recording with the trained model and write the result.

    Any existing file at the output path is deleted first, so each run starts
    from a clean output file.

    Args:
        output: destination file. Defaults to ``infer_path`` with a ``.tiff``
            suffix (e.g. ``22A7x9-1.denoised.tiff``). Pass ``infer_path`` to write
            an .h5 file instead; ``out_loc`` is passed as the dataset location
            (presumably only meaningful for .h5 output -- confirm).
        z_select: frame selection handed to the generator, presumably a
            (start, stop) range; the default ``(0, 25)`` keeps this a short test run.

    Returns:
        Path: the file the denoised data was written to.
    """
    output = Path(infer_path.with_suffix(".tiff") if output is None else output)

    # Remove stale output from a previous run. This must be the same path that
    # infer() writes below, otherwise an old result is left in place.
    if output.is_file():
        output.unlink()

    inf_gen = SubFrameGenerator(padding="edge", batch_size=2000, allowed_rotation=[0], allowed_flip=[-1],
                                z_select=z_select, shuffle=False, max_per_file=None, **param)

    inf_gen.infer(model=model_path, output=output.as_posix(), out_loc=out_loc, rescale=True, dtype=float)
    return output


test_infer()

In [None]:
# Optional: inspect the denoised result interactively.
# NOTE(review): `infer_path` is the .denoised.h5 path, but the inference cell above writes its
# result to infer_path.with_suffix(".tiff") and deletes any existing .denoised.h5 first. Point
# Video at the file that was actually written before uncommenting (confirm Video accepts .tiff).
# vid = Video(infer_path, h5_loc=out_loc)
# vid.show()