# Training
Train the network using `dss.train.train`. The function will:
- load train/val/test data
- initialize the network
- save all parameters for reproducibility
- train the network and save the best network to disk
- run inference and evaluate the network on the test data

In the process, four files are created, each prefixed with the date and time at which the process was started (`YYYYMMDD_HHMMSS`, as in `20191023_091032`):
- `*_params.yaml` - training parameters etc
- `*_arch.yaml` - network architecture
- `*_model.h5` - model architecture and weights
- `*_results.h5` - evaluation results
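
As a rough illustration of the naming scheme above, the following sketch builds the four file names for a given start time. The helper `expected_files` is hypothetical and not part of `dss`; how `dss.train` actually joins the save directory and the prefix is an assumption here.

```python
import os
import time

# Suffixes of the four files listed above
SUFFIXES = ('params.yaml', 'arch.yaml', 'model.h5', 'results.h5')


def expected_files(save_dir, start_time=None):
    """Return the four file names expected for a run started at `start_time`.

    Hypothetical helper for illustration only; not part of dss.
    `start_time` is a `time.struct_time`; defaults to the current local time.
    """
    prefix = time.strftime('%Y%m%d_%H%M%S', start_time or time.localtime())
    return [os.path.join(save_dir, f'{prefix}_{suffix}') for suffix in SUFFIXES]


# Example: a run started on 23 October 2019 at 09:10:32
start = time.strptime('20191023_091032', '%Y%m%d_%H%M%S')
for name in expected_files('res', start):
    print(name)
```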

Training can be invoked from within a script or notebook (see below) or from the command line (the command line interface is generated by [defopt](https://defopt.readthedocs.io/en/stable/index.html)):
```shell
python -m dss.train --data-dir dat/dmel_single_raw.npy --save-dir res --model-name tcn --kernel-size 16 --nb-filters 16 --nb-hist 512 --nb-epoch 20 -i
```
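
When launching several training runs from a script, it can be convenient to assemble this command line programmatically. The sketch below assumes, as in the command above, that each keyword argument of `train` maps to a `--flag` with underscores replaced by hyphens. `build_command` is a hypothetical helper, and boolean flags are left out because their command-line form is not spelled out here.

```python
import shlex
import sys


def build_command(**params):
    """Build an argument list for `python -m dss.train` from keyword arguments.

    Hypothetical helper: assumes every parameter is passed as `--name value`,
    with underscores in the name replaced by hyphens.
    """
    cmd = [sys.executable, '-m', 'dss.train']
    for name, value in params.items():
        cmd += ['--' + name.replace('_', '-'), str(value)]
    return cmd


cmd = build_command(data_dir='dat/dmel_single_raw.npy', save_dir='res',
                    model_name='tcn', kernel_size=16, nb_filters=16,
                    nb_hist=512, nb_epoch=20)
print(shlex.join(cmd))  # shell-quoted command, for inspection
# To start training, pass the list to subprocess:
# import subprocess; subprocess.run(cmd, check=True)
```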

The above example uses the dataset created by [1_prepare_data.ipynb](1_prepare_data.ipynb). Note that the training data set and the model parameters are chosen for demonstration purposes and do not yield state-of-the-art performance.
However, training for 10 epochs should yield a reasonably good model and takes around 15 minutes on a GPU. Training will typically converge after ~80 epochs.

In [1]:
import dss.train

In [2]:
help(dss.train.train)

Help on function train in module dss.train:

train(*, data_dir: str, y_suffix: str = '', save_dir: str = './', save_prefix: str = None, model_name: str = 'tcn', nb_filters: int = 16, kernel_size: int = 16, nb_conv: int = 3, use_separable: List[bool] = False, nb_hist: int = 1024, ignore_boundaries: bool = True, batch_norm: bool = True, nb_pre_conv: int = 0, pre_kernel_size: int = 3, pre_nb_filters: int = 16, pre_nb_conv: int = 2, verbose: int = 2, batch_size: int = 32, nb_epoch: int = 400, learning_rate: float = None, reduce_lr: bool = False, reduce_lr_patience: int = 5, fraction_data: float = None, seed: int = None, batch_level_subsampling: bool = False, tensorboard: bool = False, log_messages: bool = False, nb_stacks: int = 2, with_y_hist: bool = True, x_suffix: str = '')
    Train a DeepSS network.
    
    Args:
        data_dir (str): Path to the directory or file with the dataset for training.
                        Accepts npy-dirs (recommended), h5 files or zarr files.
        


In [3]:

dss.train.train(model_name='tcn',  # see `dss.models` for valid model_names
                data_dir='dat/dmel_single_raw.npy', 
                save_dir='res',
                nb_hist=512,
                kernel_size=16,
                nb_filters=16,
                ignore_boundaries=True,
                verbose=1,
                nb_epoch=2,
                log_messages=True)

_______________________________________________
activation_1 (Activation)       (None, 512, 16)      0           conv1d_3[0][0]                   
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 512, 16)      0           activation_1[0][0]               
__________________________________________________________________________________________________
spatial_dropout1d_1 (SpatialDro (None, 512, 16)      0           lambda_1[0][0]                   
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 512, 16)      272         spatial_dropout1d_1[0][0]        
__________________________________________________________________________________________________
add_1 (Add)                     (None, 512, 16)      0           add[0][0]                        
                                                             

KeyboardInterrupt: 
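
After a run, its output files can be located via the timestamp prefix. Because the prefix has the form `YYYYMMDD_HHMMSS`, sorting the file names alphabetically also sorts them chronologically. The sketch below is a minimal example under that assumption; `latest_run_prefix` is a hypothetical helper, not part of `dss`, and it assumes the default timestamp prefix is used. For an interrupted run such as the one above, the model and results files may be missing, so the example checks for them.

```python
import glob
import os


def latest_run_prefix(save_dir):
    """Return the path prefix of the most recent run in `save_dir`, or None.

    Hypothetical helper: relies on timestamp prefixes sorting chronologically.
    """
    params_files = sorted(glob.glob(os.path.join(save_dir, '*_params.yaml')))
    if not params_files:
        return None
    # Strip the suffix to recover '<save_dir>/<YYYYMMDD_HHMMSS>'
    return params_files[-1][:-len('_params.yaml')]


prefix = latest_run_prefix('res')
if prefix is None:
    print('No runs found in res/')
else:
    for suffix in ('_model.h5', '_results.h5'):
        path = prefix + suffix
        status = 'found' if os.path.exists(path) else 'missing'
        print(f'{path}: {status}')
```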