In [2]:
# imports
from pathlib import Path
from astrocast.denoising import Network, PyTorchNetwork, SubFrameDataset, SubFrameGenerator

# General considerations

Training the denoiser is generally a one-time requirement; once trained, the model can be applied to subsequent datasets without retraining. Training can take from 1 to 12 hours, depending on the complexity of the data and the available computational resources. The architecture used in astroCAST is modeled on the network published in [Lecoq et al. 2021](https://doi.org/10.1038/s41592-021-01285-2).

### Table 1: Parameters for denoiser training

| Parameter               | Default value         | Comment                                                                                                         |
|-------------------------|-----------------------|-----------------------------------------------------------------------------------------------------------------|
| **Input Data**          |                       |                                                                                                                 |
| `training_files`        | `path/to/training/files` | Path to your training files. Accepts individual file paths or multiple files (e.g., `./train_*.h5`).           |
| `validation_files`      |                       | Path to your validation files, similar to training files.                                                       |
| `loc`                   | `dataset_name`        | Name of the dataset inside the `.h5` files.                                                                     |
| `batch_size` / `batch_size_val` | `32`          | Number of data samples processed at once. Can be reduced if memory is a constraint.                             |
| `max_per_file` / `max_per_val_file` | `16`      | Limits the number of training samples per video to avoid overtraining. Aim for roughly one to a few minutes per training epoch. |
| **Data Transformation** |                       |                                                                                                                 |
| `train_rotation`        | `(1, 2, 3)`           | Allows rotating training images to increase data variety. Rotation in steps of 90°.                             |
| `train_flip`            | `(0, 1)`              | Allows flipping training images, again for variety.                                                             |
| `normalize`             | `global`              | Controls scaling of input data. Global normalization is recommended.                                            |
| `batch_normalize`       | `False`               | Optional feature to further scale data during training.                                                         |
| **Denoiser Architecture** |                     |                                                                                                                 |
| `input_size`            | `(256, 256)`          | Size of the area denoised at once. Should ideally match your frame size.                                        |
| `n_stacks`              | `2`                   | Number of layers in the encoder and decoder. More layers might improve quality but also increase complexity.    |
| `kernel`                | `32`                  | Complexity of the initial layer. Higher values may improve results but increase risk of overfitting.            |
| `pre_post_frames`       | `5`                   | Frames around the target frame used for denoising. More frames might improve quality but increase requirements.  |
| `gap_frames`            | `0`                   | Skips frames immediately before and after the target frame for better denoising in some cases.                  |
| **Training Parameters** |                       |                                                                                                                 |
| `epochs`                | `10`                  | Maximum training cycles. Can be high if using early stopping.                                                   |
| `loss`                  | `annealed_loss`       | Loss function used to assess reconstruction quality.                                                           |
| `learning_rate`         | `0.001`               | Speed of model learning. Too high values can make the model unstable.                                           |
| `decay_rate`, `decay_steps` | `0.99`, `250`     | Gradually reduces the learning rate for stability.                                                              |
| `patience`              | `3`                   | Stops training if no improvement after this many cycles.                                                        |
| `min_delta`             | `0.001`               | Minimum change that counts as an improvement in model performance.                                              |
| `pretrained_weights`    | `None`                | Use weights from a previous model to speed up training or for transfer learning.                                |
| `use_cpu`               | `True`                | If `True`, train on the CPU; set to `False` to train on a GPU, which is faster.                                 |
| `in_memory`             | `False`               | Whether to load the training data into memory. Not recommended due to high memory usage.                        |
| **Output Parameters**   |                       |                                                                                                                 |
| `save_path`             | `None`                | Where to save your trained model.                                                                               |
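
The roles of `pre_post_frames` and `gap_frames` in Table 1 can be illustrated with a minimal sketch. It assumes the frame-selection scheme described by Lecoq et al. (frames on both sides of the target, with the target itself left out); the helper name `input_frame_indices` is hypothetical and is not part of astroCAST, whose implementation may differ in detail.

```python
def input_frame_indices(target, pre_post_frames=5, gap_frames=0):
    """Return the indices of the frames used to predict frame `target`.

    Assumed scheme: `pre_post_frames` frames are taken on each side of the
    target, after skipping `gap_frames` frames directly next to it.
    The target frame itself is never part of the input.
    """
    before = range(target - gap_frames - pre_post_frames, target - gap_frames)
    after = range(target + gap_frames + 1, target + gap_frames + pre_post_frames + 1)
    return list(before) + list(after)


# default settings: five frames on each side of frame 10
print(input_frame_indices(10))
# two frames on each side, skipping one neighbor on each side
print(input_frame_indices(10, pre_post_frames=2, gap_frames=1))
```

Under this assumption, a target frame needs at least `pre_post_frames + gap_frames` frames on either side, so frames at the very start and end of a recording cannot be predicted without padding.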

### **CRITICAL** considerations

#### Architecture
The settings of the *Denoiser Architecture* cannot be changed without training a new model from scratch.

#### Overtraining
Overtraining is a significant concern because the denoiser uses a neural network internally. To mitigate this risk, training stops automatically if no further improvement is observed on the validation dataset. Nevertheless, choose a high value for `n_stacks` or a small training dataset with caution, as both increase the likelihood of overtraining. An overtrained model may denoise the training data well but generalize poorly to new data.
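
The automatic stopping mechanism can be illustrated with a minimal sketch of how `patience` and `min_delta` typically interact. This is a generic early-stopping rule written for illustration; the function name `epochs_until_stop` is hypothetical, and astroCAST's internal logic may differ.

```python
def epochs_until_stop(val_losses, patience=3, min_delta=0.001):
    """Return the epoch (1-based) at which training would stop.

    An epoch counts as an improvement only if the validation loss drops
    by more than `min_delta` below the best loss seen so far. Training
    stops after `patience` consecutive epochs without improvement.
    """
    best = float("inf")
    stale_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best - loss > min_delta:
            best = loss
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                return epoch
    # the loss kept improving, so all epochs were used
    return len(val_losses)


# the loss plateaus after epoch 3, so training stops three epochs later
losses = [1.0, 0.8, 0.7, 0.6995, 0.6993, 0.6992, 0.5]
print(epochs_until_stop(losses, patience=3, min_delta=0.001))
```

With a rule like this, `epochs` acts only as an upper bound: a small `min_delta` and a larger `patience` let training continue through slow plateaus, at the cost of longer runs.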


# Set up training files
## Adjust settings

In [None]:
use_pytorch = True
use_cpu = True  # set to False if a CUDA-compatible GPU is available

my_data_directory = Path("data/")  # change to your data directory
model_path = Path("models/")  # change to your output directory

## Set up folders for training and saving models

In [None]:
# create the data directory if it does not exist yet
my_data_directory.mkdir(parents=True, exist_ok=True)

print(f"loading data from: {my_data_directory}")

# Optional: download training files or provide your own
# !astrocast download_datasets {my_data_directory}

# models are stored in a subfolder of the data directory
model_path = my_data_directory.joinpath(model_path)
model_path.mkdir(parents=True, exist_ok=True)
print(f"saving models to: {model_path}")

# Define training parameters

In [None]:
# choose the file extension that matches the selected framework
if use_pytorch:
    save_model_path = model_path.joinpath("trained_model.pth")
else:
    save_model_path = model_path.joinpath("trained_model.h5")

# training data parameters
input_size = (256, 256)
pre_post_frames = 5
gap_frames = 0
train_rotation = [1, 2, 3]
loc = "data/"  # dataset name in .h5 files

# architecture parameters
n_stacks = 2
kernel = 32

# training parameters
epochs = 50
patience = 5
min_delta = 1e-4
max_per_file = 16

# Training files
train_str = "private_data/train*.h5"  # change to your training data
if "*" in train_str:
    train_paths = list(my_data_directory.glob(train_str))
else:
    train_paths = my_data_directory.joinpath(train_str)

# Validation files
val_str = "private_data/test*.h5"  # change to your validation data
if "*" in val_str:
    val_paths = list(my_data_directory.glob(val_str))
else:
    val_paths = my_data_directory.joinpath(val_str)
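
A wildcard pattern that matches nothing yields an empty list, which would only surface later as an error during dataset creation. The following minimal sketch shows one way to fail early instead. The helper `resolve_files` is hypothetical and not part of astroCAST; it uses only `pathlib`.

```python
from pathlib import Path


def resolve_files(base_dir, pattern):
    """Return a sorted list of existing files matching `pattern` in `base_dir`.

    Raises FileNotFoundError if the pattern matches nothing or if a
    plain (non-wildcard) path does not point to an existing file.
    """
    base_dir = Path(base_dir)
    if "*" in pattern:
        paths = sorted(base_dir.glob(pattern))
    else:
        paths = [base_dir / pattern]

    missing = [p for p in paths if not p.is_file()]
    if not paths or missing:
        raise FileNotFoundError(f"no files found for '{pattern}' in {base_dir}")
    return paths
```

It could be used in place of the branches above, for example `train_paths = resolve_files(my_data_directory, train_str)`. Unlike the original code, this sketch always returns a list, even for a single file.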


### PyTorch implementation

In [ ]:
if use_pytorch:
    
    # create training dataset
    train_dataset = SubFrameDataset(paths=train_paths, input_size=input_size, loc=loc,
                                    pre_post_frames=pre_post_frames, max_per_file=max_per_file,
                                    gap_frames=gap_frames, allowed_rotation=train_rotation, padding=None,
                                    normalize="global", in_memory=False, allowed_flip=[0, 1], shuffle=True)
    
    # create validation dataset
    val_dataset = None
    if val_paths is not None:
        val_dataset = SubFrameDataset(paths=val_paths, input_size=input_size, loc=loc,
                                      pre_post_frames=pre_post_frames, max_per_file=3,
                                      gap_frames=gap_frames, allowed_rotation=0, padding=None,
                                      normalize="global", in_memory=False, allowed_flip=-1, shuffle=True)
    
    # create network
    net = PyTorchNetwork(train_dataset, val_dataset=val_dataset, batch_size=16,
                         shuffle=True, num_workers=4,
                         learning_rate=0.001, momentum=0.9, decay_rate=0.1, decay_steps=30,
                         n_stacks=n_stacks, kernels=kernel, kernel_size=3,
                         batch_normalize=False, use_cpu=use_cpu)
    
    # train
    net.run(num_epochs=epochs,
            save_model=save_model_path,
            patience=patience,
            min_delta=min_delta)
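
The call above passes `decay_rate=0.1` and `decay_steps=30`, while Table 1 lists `0.99` and `250`. How astroCAST applies these values internally is not documented here, so the sketch below only compares two common schedules to show how differently the same parameter names can behave. Both function names are hypothetical.

```python
def step_decay(initial_lr, epoch, decay_rate, decay_steps):
    """Multiply the learning rate by `decay_rate` once every `decay_steps` epochs."""
    return initial_lr * decay_rate ** (epoch // decay_steps)


def exponential_decay(initial_lr, step, decay_rate, decay_steps):
    """Shrink the learning rate smoothly; it is scaled by `decay_rate` every `decay_steps` steps."""
    return initial_lr * decay_rate ** (step / decay_steps)


# step-wise schedule with the values from the PyTorch cell
for epoch in (0, 29, 30, 60):
    print(epoch, step_decay(0.001, epoch, decay_rate=0.1, decay_steps=30))

# smooth schedule with the defaults listed in Table 1
for step in (0, 250, 2500):
    print(step, exponential_decay(0.001, step, decay_rate=0.99, decay_steps=250))
```

Under the step-wise assumption, a rate of `0.1` cuts the learning rate tenfold every 30 epochs, which is far more aggressive than the gentle reduction that `0.99` over 250 steps gives. It is worth checking which schedule applies before reusing values between the two implementations.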


### TensorFlow implementation

In [ ]:
if not use_pytorch:
    
    train_gen = SubFrameGenerator(paths=train_paths, max_per_file=max_per_file, loc=loc,
                                  input_size=input_size,
                                  pre_post_frames=pre_post_frames, gap_frames=gap_frames,
                                  allowed_rotation=train_rotation,
                                  padding=None, batch_size=8, normalize="global", in_memory=False,
                                  allowed_flip=[0, 1], shuffle=True)
    
    # Validator
    if val_paths is not None:
        
        val_gen = SubFrameGenerator(
                paths=val_paths, max_per_file=3, loc=loc, input_size=input_size,
                pre_post_frames=pre_post_frames, gap_frames=gap_frames, allowed_rotation=[0],
                padding=None, batch_size=16, normalize="global", in_memory=False,
                cache_results=True,
                allowed_flip=[-1], shuffle=True)
    
    else:
        val_gen = None
    
    # Network
    net = Network(train_generator=train_gen, val_generator=val_gen, learning_rate=0.001, decay_rate=0.99,
                  pretrained_weights=None,
                  n_stacks=n_stacks, kernel=kernel,
                  batchNormalize=False, use_cpu=use_cpu)
    
    net.run(batch_size=1, num_epochs=epochs, patience=patience, min_delta=min_delta,
            save_model=save_model_path)