### 1. Import dependencies

In [None]:
from pathlib import Path

import tifffile
from tqdm.auto import tqdm
import numpy as np
from matplotlib import pyplot as pl

from n2v.models import N2VConfig, N2V
from n2v.utils.n2v_utils import manipulate_val_data
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator

### 2. Define main parameters

In [None]:
# --- Data -------------------------------------------------------------------
# Directory holding the noisy training movies.
image_dir = './train_videos/'

# --- Model output -----------------------------------------------------------
# Identifier of the model; a model already stored under this name is overwritten.
model_name = 'wound_healing_n2v'
# Directory the trained model is written to once training finishes.
model_path = './models/'

# --- Training ---------------------------------------------------------------
# Number of training epochs. A first impression of the results is usually
# visible after 10-30 epochs; a complete training typically uses 100-200.
number_of_epochs = 50

# Size (Y, X) of the patches the images are cut into for training.
patch_shape = (64, 64)

# Percentage of the generated patches held out to validate the network during training.
percentage_validation = 10

# Upper bound on the number of patches used for training (helps on low-RAM machines).
max_patches = 10_000

### 3. Create data generator

In [None]:
# Load every TIFF in `image_dir` (declared axis order: TCYX) and cut the
# images into patches of `patch_shape`.
datagen = N2V_DataGenerator()
images = datagen.load_imgs_from_directory(directory=image_dir, filter='*.tif', dims='TCYX')
patches = datagen.generate_patches_from_list(images, shape=patch_shape)

# generate_patches_from_list is called without shuffling, so the patches are
# (per the n2v implementation) ordered source image by source image. Slicing
# the first `max_patches` would therefore keep only patches from the first
# file(s)/timepoints. Draw a random subset instead. The fixed seed makes both
# the selection and the train/validation split below reproducible.
# NOTE: the cap is applied after generation, so it bounds the training set
# size but not the peak memory used while generating the patches.
# NOTE(review): if augmentation is enabled in the generator, rotated/flipped
# copies of one tile may land on both sides of the split.
patch_seed = 42
rng = np.random.default_rng(patch_seed)
patches = patches[rng.permutation(len(patches))[:max_patches]]

total_number_of_patches = len(patches)
if total_number_of_patches < 2:
    raise ValueError(
        f"Need at least 2 patches to build training and validation sets, got "
        f"{total_number_of_patches}. Check image_dir and patch_shape."
    )

# Hold out `percentage_validation` percent of the patches for validation.
# Always keep at least one patch on each side: a rounded-down share of 0
# would otherwise leave the validation set empty.
num_validation_patches = int(total_number_of_patches * (percentage_validation / 100))
num_validation_patches = min(max(num_validation_patches, 1), total_number_of_patches - 1)
num_training_patches = total_number_of_patches - num_validation_patches
X = patches[:num_training_patches]
X_val = patches[num_training_patches:]

### 4. Define model

In [None]:
# Batch size shared by the training configuration and the steps-per-epoch
# computation below (previously hard-coded twice as 128).
train_batch_size = 128

config = N2VConfig(
    X,
    blurpool=True,
    skip_skipone=True,
    unet_residual=False,
    n2v_manipulator='median',
    unet_kern_size=3,
    unet_n_first=64,
    unet_n_depth=3,
    # Roughly one pass over the training patches per epoch. Clamp to at least
    # one step: with fewer patches than one batch, floor division gives 0.
    train_steps_per_epoch=max(1, X.shape[0] // train_batch_size),
    train_epochs=number_of_epochs,
    train_loss='mse',
    batch_norm=True,
    train_batch_size=train_batch_size,
    n2v_perc_pix=0.198,
    # Must match the shape the training patches were generated with.
    n2v_patch_shape=patch_shape,
    n2v_neighborhood_radius=5,
    single_net_per_channel=False
)

# Create the model, stored in `model_path` under `model_name`.
model = N2V(config=config, name=model_name, basedir=model_path)

### 5. Train and save model

In [None]:
# Train on the training patches while tracking the loss on the held-out ones;
# `history` is kept for the loss plot further down.
history = model.train(X, X_val)

# Export the trained model. The first training patch (a YXC array) is passed
# as the test image.
export_kwargs = dict(
    name='N2V',
    description="",
    authors=[""],
    test_img=X[0],
    axes='YXC',
    patch_shape=patch_shape,
)
model.export_TF(**export_kwargs)

### 6. Plot training and validation loss

In [None]:
# Plot training and validation loss together in one wide figure.
from csbdeep.utils import plot_history
import matplotlib.pyplot as plt

loss_keys = ['loss', 'val_loss']
plt.figure(figsize=(16, 5))
plot_history(history, loss_keys);

### 7. Prediction

In [None]:
# Folder containing the movies to denoise.
data_folder = "./all_videos/"

# Glob pattern selecting which files in `data_folder` are processed.
data_filter = "*.tif"

# Results are written next to the inputs, with this postfix appended to the
# file name (e.g. movie.tif -> movie_denoised.tif).
output_postfix = "denoised"

# To predict with a different model than the one trained above, uncomment
# these lines and point them at that model:

# model_name = "custom model name"
# model_path = "custom model path"

In [None]:
# Load the model saved as `model_name` in `model_path` (no config is passed,
# so an existing model is loaded rather than a new one created).
model = N2V(
    config=None,
    name=model_name,
    basedir=model_path
)

# Results are written into the input folder and match the same filter, so
# exclude files produced by a previous run; otherwise re-running this cell
# would denoise them again (e.g. movie_denoised_denoised.tif).
samples = [
    path for path in sorted(Path(data_folder).glob(data_filter))
    if not path.stem.endswith(f"_{output_postfix}")
]

if not samples:
    print(f"No data found in folder: {data_folder}")

for sample in tqdm(samples):
    image = tifffile.imread(sample)
    if image.ndim != 4:
        # The axis handling below assumes a TCYX stack (as used for training).
        print(f"Skipping {sample.name}: expected a 4D TCYX stack, got shape {image.shape}")
        continue

    # TCYX -> TYXC: move channels last so every timepoint is a YXC frame.
    # (np.swapaxes(image, 1, -1) would produce TXYC, i.e. transposed frames.)
    image = np.moveaxis(image, 1, -1)

    # Denoise the movie one timepoint at a time.
    result = np.stack([model.predict(frame, axes='YXC') for frame in image])

    # TYXC -> TCYX to match the ImageJ axes metadata written below.
    result = np.moveaxis(result, -1, 1)

    result_path = sample.parent / f"{sample.stem}_{output_postfix}.tif"
    tifffile.imwrite(result_path, result, imagej=True, metadata={'axes': 'TCYX', 'mode': 'rgb'})