# First Break Picking Model Training Demo

This notebook shows how to locally train a model to predict first breaks on one of the already-known
survey sites. The AMLRT at Mila rarely uses notebooks, so the code is mostly designed to be run as
scripts from IDEs such as PyCharm. Nevertheless, this notebook should give a good overview of the
main steps required to train a model from scratch using the API.

This example is also kept fairly minimal: it does not support the exploration of hyperparameters
across multiple training sessions, and it does not support resuming an interrupted session, as that
would be impractical in a notebook. We will however still highlight some key concepts surrounding
hyperparameter exploration that are fundamental in the definition of experiments.

Since we are training a model from scratch here for a significant number of epochs, it is strongly
recommended to have a GPU on the machine where this notebook is executed. It will be detected and
used by PyTorch below.

Finally, note that this example will not use an external configuration file. Instead, it will
directly invoke the functions and class constructors that would use the content of such a
configuration file. This will help clarify the link between the content of these files, the role
of each parameter, and the step at which each one is used. An example of a valid configuration
file (`unet-mini.yaml`) is also provided in this folder, but it is meant to be used with the
`hardpicks/main.py` script.

In [None]:
# all imports are centralized here (helps identify environment issues before doing anything else!)

# these packages are part of the 'standard' library, and should be available in all environments
import functools
import logging
import os

# these packages are 3rd-party dependencies, and must be installed via pip or anaconda
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning.callbacks  # used below for model checkpointing
import pytorch_lightning.loggers
import torch.utils.data

# these packages are part of our API and must be manually installed (see the top-level README.md)
import hardpicks
import hardpicks.data.fbp.data_module as fbp_data_module
import hardpicks.data.fbp.site_info as fbp_site_info
import hardpicks.models.fbp.utils as model_utils
import hardpicks.models.fbp.unet as fbp_unet

In [None]:
# first, let's define the path to the folder where we'll be dumping all the training results/logs
output_root_path = "output/notebook_train_example/"  # relative to the working dir (usually the notebook's)
# make sure it doesn't exist already! (change the folder name otherwise)
output_root_path = os.path.abspath(output_root_path)
os.makedirs(output_root_path, exist_ok=False)
print(f"Experiment results will be in: {output_root_path}")

# next, we'll set up the TensorBoard logger so that metrics/losses are logged somewhere
tensorboard_output_path = os.path.join(output_root_path, "tensorboard")
os.makedirs(tensorboard_output_path, exist_ok=True)
tbx_logger = pytorch_lightning.loggers.TensorBoardLogger(
    save_dir=tensorboard_output_path,
    name="default",
    default_hp_metric=False,
)
logging.getLogger().setLevel(logging.INFO)
# the output folder and logging is now properly set up, we'll focus on the data next...

## Data Loading

The `main.py` script essentially prepares the output directory structure and loggers just like we
did above, and then calls the `create_data_module` function of the
`hardpicks/data/data_loader.py` module. Since we are trying to train a model
using first break picking data, this function creates an `FBPDataModule` object that will parse a
provided configuration dictionary and return a "data loader factory". The goal of using such a
factory is to allow PyTorch Lightning to easily create new data loaders and use them across
different processes while avoiding unnecessarily costly raw-data copies in memory. Here, since we
instantiate all objects directly, we will skip the factory object and create the dataset parsers
and data loaders ourselves.
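To make the factory idea more concrete, here is a minimal sketch in plain Python. `ToyLoaderFactory`, `get_loader`, and `make_loader` are hypothetical names (this is not the `FBPDataModule` API). The sketch only illustrates that loaders can be built on request from parser objects that are shared by reference instead of being copied:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Sequence


@dataclass
class ToyLoaderFactory:
    """Hypothetical factory: keeps parser references and builds loaders on request."""

    parsers: Dict[str, Sequence[Any]]  # one (already-built) parser per split
    make_loader: Callable[[Sequence[Any]], Any]  # e.g. a DataLoader constructor

    def get_loader(self, split: str) -> Any:
        # the parser is passed by reference, so no raw data is duplicated here
        return self.make_loader(self.parsers[split])


train_parser = [0, 1, 2]  # stand-in for a real dataset parser
factory = ToyLoaderFactory(
    parsers={"train": train_parser},
    make_loader=lambda dataset: {"dataset": dataset},  # stand-in for DataLoader(dataset=...)
)
loader_a = factory.get_loader("train")
loader_b = factory.get_loader("train")  # a new loader object that wraps the same parser
```

In a real multi-process setup, each worker still receives its own copy of the parser. The factory only avoids redundant copies inside the process that builds the loaders.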

A data loader in the context of PyTorch-based projects is essentially an object that loads and
combines data samples into tensors so that a model can ingest them. This means that all the
preprocessing and augmentation operations that must be applied to the raw data happen as part of
data loading (in practice, inside the dataset object that the data loader wraps, as explained next).
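The sketch below gives a rough idea of what "combining samples into tensors" means. It uses plain NumPy and a hypothetical `iterate_minibatches` helper, not the PyTorch implementation. The helper applies an optional per-sample transform, then stacks fixed-size chunks along a new leading batch dimension. It only works when all samples share the same shape, which is why gathers of different sizes need a special collate function later on:

```python
import numpy as np


def iterate_minibatches(samples, batch_size, transform=None):
    """Yield stacked minibatches from a list of equally-shaped sample arrays."""
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        if transform is not None:
            # preprocessing/augmentation happens per sample, before batching
            chunk = [transform(sample) for sample in chunk]
        # stacking adds a leading 'batch' dimension, similar to PyTorch's default collate
        yield np.stack(chunk, axis=0)


samples = [np.full((4, 8), fill_value=i, dtype=np.float32) for i in range(10)]
batches = list(iterate_minibatches(samples, batch_size=4, transform=lambda s: s * 2))
print([b.shape for b in batches])  # [(4, 4, 8), (4, 4, 8), (2, 4, 8)]
```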

The data loader interface that is universally used in PyTorch-based projects is
[defined by PyTorch itself](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader).
Most of the arguments it expects are related to how examples should be combined ('batched')
together, how memory should be managed, or how many processes should be used in parallel. The key
argument it expects is the `dataset` object it will be requesting examples from. That `dataset`
object (which we call in practice a 'dataset parser', or simply a 'parser') is the object that
holds all the logic required to read the raw data from the disk. It is also responsible for the
cleaning and preprocessing of the raw data. In our case, the parser is based on the
`ShotLineGatherDataset` class of the `hardpicks/data/fbp/gather_parser.py`
module, and is further wrapped into cleaning and preprocessing classes. We show how to instantiate
this parser below using existing API methods for a training and a validation set:

In [None]:
# first, we need to identify which site we will be training on...
# ... we will use the predefined site information arrays from the API to select Lalor for this demo
lalor_site_info = fbp_site_info.get_site_info_by_name("Lalor")
assert os.path.isfile(lalor_site_info["raw_hdf5_path"]), \
    f"could not locate Lalor site raw data at: {lalor_site_info['raw_hdf5_path']}"
assert os.path.isfile(lalor_site_info["processed_hdf5_path"]), \
    f"could not locate Lalor site preprocessed data at: {lalor_site_info['processed_hdf5_path']}"

print("Training site info:")
for key, val in lalor_site_info.items():
    print(f"\t{key}: {val}")

# we need to define right away what kind of segmentation class setup we want; we'll do binary!
segmentation_class_count = 1  # one class of interest = first break, this is defined internally

# next, we'll specify a few site-level hyperparameters that will influence how the data is processed
# ...note that most of these hyperparameters can be omitted and they will default to proper values
lalor_site_params = {
    # this path points to a file that indicates bad gathers (by receiver/shot ids) to be avoided
    "rejected_gather_yaml_path":
        os.path.join(hardpicks.FBP_BAD_GATHERS_DIR, "bad-gather-ids_combined.yaml"),
    # if we were to re-instantiate the parser often, we could cache its metadata, but no need here
    "use_cache": False,
    # all trace samples should be normalized in order to improve model training behavior
    "normalize_samples": True,  # by default, the normalization will be independent for each trace
    # do not use a 'buffer' zone around the annotated first break picks to make predictions easier
    "segm_first_break_buffer": 0,
}

# the dataset parsers also need a couple of hyperparameters to specify non-site-related settings
generic_site_params = dict(
    convert_to_fp16=True,  # convert trace sample data to 16-bit floats (saves memory!)
    convert_to_int16=True,  # same as above, but for identifiers, picks, and other integer data
    preload_trace_data=False,  # we cannot preload all the dataset into memory (it's too big!)
    cache_trace_metadata=True,  # we can however cache its metadata into memory... (faster!)
    provide_offset_dists=True,  # finally, we'll generate new offset distance arrays/maps
)

# during training, it's always good to use a couple of augmentation operations, so we'll define
# them here (this is how the `ShotLineGatherPreprocessor` class expects to see them defined)
lalor_site_augmentations = [
    {
        "type": "crop",
        "params": {
            "low_sample_count": 512,
            "high_sample_count": 1024,
            "max_crop_fraction": 0.333,
        },
    },
    {"type": "flip"},
]

# all the data-related hyperparameters are ready, now let's create our training/validation parsers!

lalor_eval_split_ratio = 0.15  # will split using 15% of shots for validation and 85% for training
lalor_train_parser = fbp_data_module.FBPDataModule.create_parser(
    site_info=lalor_site_info,
    site_params={
        **lalor_site_params,
        "augmentations": lalor_site_augmentations,  # augmentations are only used for training!
        "subset": {"eval_ratio": lalor_eval_split_ratio, "use_eval_split": False},
    },
    prefix="train",
    dataset_hyper_params=generic_site_params,
    segm_class_count=segmentation_class_count,
)
print(f"Training dataset parser ready with {len(lalor_train_parser)} gathers!")

lalor_valid_parser = fbp_data_module.FBPDataModule.create_parser(
    site_info=lalor_site_info,
    site_params={
        **lalor_site_params,
        "subset": {"eval_ratio": lalor_eval_split_ratio, "use_eval_split": True},
    },
    prefix="valid",
    dataset_hyper_params=generic_site_params,
    segm_class_count=segmentation_class_count,
)
print(f"Validation dataset parser ready with {len(lalor_valid_parser)} gathers!")
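The parsers created above can be handed to a PyTorch `DataLoader` because they follow the map-style dataset protocol: `__len__` returns the number of examples and `__getitem__` returns one example. The toy classes below are hypothetical stand-ins, not the `ShotLineGatherDataset` API. They show that protocol and how a wrapper can add preprocessing on top of a parser. The per-trace zero-mean/unit-variance normalization is an assumption, since the exact scheme behind `normalize_samples` is not described here:

```python
import numpy as np


class ToyGatherParser:
    """Hypothetical map-style dataset returning one 'gather' dict per index."""

    def __init__(self, gathers):
        self._gathers = gathers  # stands in for data that would be read lazily from disk

    def __len__(self):
        return len(self._gathers)

    def __getitem__(self, idx):
        return {"gather_id": idx, "samples": self._gathers[idx]}


class ToyNormalizingWrapper:
    """Hypothetical wrapper that preprocesses the samples returned by another parser."""

    def __init__(self, parser):
        self._parser = parser

    def __len__(self):
        return len(self._parser)

    def __getitem__(self, idx):
        sample = dict(self._parser[idx])  # shallow copy: the wrapped parser stays untouched
        traces = sample["samples"].astype(np.float32)
        # normalize each trace (row) independently along the time axis
        mean = traces.mean(axis=1, keepdims=True)
        std = traces.std(axis=1, keepdims=True)
        sample["samples"] = (traces - mean) / np.where(std > 0, std, 1.0)
        return sample


toy_parser = ToyNormalizingWrapper(ToyGatherParser([np.arange(12, dtype=np.float32).reshape(3, 4)]))
normalized = toy_parser[0]["samples"]
```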

In [None]:
# once we have our parsers, it's time to create the data loaders to get batches of gathers!
# ...we need to use a special collate function to combine gathers of different sizes
collate_fn = functools.partial(
    fbp_data_module.fbp_batch_collate,
    pad_to_nearest_pow2=True,
)

# with that collate function defined, we can create the two data loaders directly...
train_data_loader = torch.utils.data.DataLoader(
    dataset=lalor_train_parser,
    batch_size=6,
    shuffle=True,  # always shuffle training data!
    num_workers=2,
    collate_fn=collate_fn,
)
print(f"Training data loader ready with {len(train_data_loader)} minibatches!")

valid_data_loader = torch.utils.data.DataLoader(
    dataset=lalor_valid_parser,
    batch_size=6,
    shuffle=False,  # never shuffle validation/test data!
    num_workers=2,
    collate_fn=collate_fn,
)
print(f"Validation data loader ready with {len(valid_data_loader)} minibatches!")
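The `pad_to_nearest_pow2=True` option suggests that the collate function zero-pads every gather of a minibatch up to a shared power-of-2 size before stacking. The sketch below is a simplified version of our own, not `fbp_batch_collate` itself. It only handles the 2D `samples` arrays (traces x samples) and assumes that padding is appended at the end of both axes. The real function would presumably also need to pad per-trace arrays such as labels and masks:

```python
import numpy as np


def next_pow2(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def pad_and_stack(gathers):
    """Zero-pad 2D gathers (traces x samples) to a shared power-of-2 shape, then stack."""
    max_traces = max(g.shape[0] for g in gathers)
    max_samples = max(g.shape[1] for g in gathers)
    out_shape = (len(gathers), next_pow2(max_traces), next_pow2(max_samples))
    batch = np.zeros(out_shape, dtype=gathers[0].dtype)
    for idx, gather in enumerate(gathers):
        # padded regions stay at zero (they appear as black borders when plotted)
        batch[idx, :gather.shape[0], :gather.shape[1]] = gather
    return batch


batch = pad_and_stack([np.ones((50, 700), np.float16), np.ones((64, 512), np.float16)])
print(batch.shape)  # (2, 64, 1024)
```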


Now that we have a data loader for both the training and validation splits of the Lalor data, we
can display a few training examples and verify that the two splits do not overlap by intersecting
their gather IDs.
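The train/valid separation relies on passing the same `eval_ratio` to two separately-created parsers, with only `use_eval_split` differing between them. We have not checked how the API implements this split. Below is a minimal sketch of one way to get a disjoint and reproducible shot-level split, assuming shots are selected by ID with a fixed random seed (`split_shots` is a hypothetical helper):

```python
import random


def split_shots(shot_ids, eval_ratio, seed=0):
    """Hypothetical shot-level split returning disjoint (train_ids, eval_ids) sets."""
    unique_ids = sorted(set(shot_ids))  # sorting makes the result independent of input order
    rng = random.Random(seed)  # fixed seed: both parsers compute the exact same split
    rng.shuffle(unique_ids)
    eval_count = int(round(len(unique_ids) * eval_ratio))
    eval_ids = set(unique_ids[:eval_count])
    train_ids = set(unique_ids[eval_count:])
    return train_ids, eval_ids


train_ids, eval_ids = split_shots(range(100), eval_ratio=0.15, seed=42)
print(len(train_ids), len(eval_ids))  # 85 15
```

Splitting by shot, rather than by individual gather, keeps all receiver lines of a given shot on the same side of the split.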

Note that each 'data sample' that can be loaded is a dictionary that contains a line gather (i.e.
an array of traces recorded along a single continuous receiver line). This data comes with
metadata, and everything is described using the following fields:
 - `Origin`: the name of the site that the data comes from (useful when we're mixing them!);
 - `shot_id`: the unique identifier of the shot that corresponds to the loaded gather;
 - `rec_line_id`: the unique identifier of the receiver line that corresponds to the loaded gather;
 - `rec_ids`: the unique identifiers of the receivers that recorded the traces in the loaded gather;
 - `gather_id`: the unique identifier of the loaded gather (a 0-based int created by the parser);
 - `gather_trace_ids`: the unique identifiers of the traces in the loaded gather;
 - `first_break_labels`: the array of first break pixel index labels (one for each trace);
 - `first_break_timestamps`: the array of first break timestamps (one for each trace);
 - `bad_first_breaks_mask`: a mask array that indicates which traces have an invalid first break label;
 - `rec_coords`: the array of receiver ground coordinates (one for each trace);
 - `shot_coords`: the array of shot ground coordinates (one for each trace);
 - `trace_count`: the number of traces in the loaded gather;
 - `sample_count`: the number of recorded seismic samples in each trace of the loaded gather;
 - `sample_rate_ms`: the time interval between two consecutive samples of a trace, in milliseconds;
 - `filled_first_breaks_mask`: a mask array that indicates which traces have interpolated labels;
 - `dead_rec_mask`: a mask array that indicates which traces have dead receivers (no amplitudes);
 - `samples`: the 2D array of recorded seismic samples (one vector per trace);
 - `segmentation_mask`: (if activated) the 2D segmentation class map used for training.
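A small helper can check the relationships between some of these fields. This is a hypothetical sketch run on a synthetic dictionary (not real data). It assumes it receives a single, un-collated sample, since padding inside a minibatch changes array shapes. It also assumes that valid `first_break_labels` are sample indices that fall inside the trace:

```python
import numpy as np


def check_gather_sample(sample):
    """Raise ValueError if some of the field relationships described above do not hold."""
    trace_count, sample_count = sample["trace_count"], sample["sample_count"]
    if np.shape(sample["samples"]) != (trace_count, sample_count):
        raise ValueError("'samples' should have a (trace_count, sample_count) shape")
    for key in ("first_break_labels", "bad_first_breaks_mask", "dead_rec_mask"):
        if len(sample[key]) != trace_count:
            raise ValueError(f"'{key}' should have one entry per trace")
    labels = np.asarray(sample["first_break_labels"])
    valid = ~np.asarray(sample["bad_first_breaks_mask"], dtype=bool)
    # valid labels are sample (pixel) indices, so they should fall inside the trace
    if np.any((labels[valid] < 0) | (labels[valid] >= sample_count)):
        raise ValueError("found a valid first break label outside of the trace")


toy_sample = {  # synthetic values, NOT real Lalor data
    "trace_count": 3,
    "sample_count": 10,
    "samples": np.zeros((3, 10), dtype=np.float16),
    "first_break_labels": [2, 5, -1],
    "bad_first_breaks_mask": [False, False, True],
    "dead_rec_mask": [False, False, False],
}
check_gather_sample(toy_sample)  # passes silently
```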

In [None]:
# let's display an actual minibatch of random line gathers as a grid of 2D images
fig, axes = plt.subplots(6, figsize=(9, 15))
minibatch = next(iter(train_data_loader))
for gather_idx in range(6):
    # we'll use some existing utility functions to convert gathers into images
    gather_image = model_utils.generate_pred_image(
        # note: the provided first break picks will be shown in green
        batch=minibatch,
        raw_preds=None,
        batch_gather_idx=gather_idx,
        segm_class_count=segmentation_class_count,
        segm_first_break_prob_threshold=0.,
        draw_prior=False,
        draw_prob_heatmap=False,
    )
    ax = axes[gather_idx]
    ax.set_xlabel("time (ms)")
    ax.set_xticks(np.linspace(0, gather_image.shape[1], num=9, dtype=np.int32))
    real_sample_count = minibatch["samples"][gather_idx].shape[1]
    real_sample_rate = minibatch["sample_rate_ms"][gather_idx]
    ax.set_xticklabels(np.linspace(0, real_sample_count * real_sample_rate, num=9, dtype=np.int32))
    ax.set_ylabel("receiver (index)")
    ax.set_yticks(np.linspace(0, gather_image.shape[0], num=5, dtype=np.int32))
    real_trace_count = minibatch["samples"][gather_idx].shape[0]
    ax.set_yticklabels(np.linspace(0, real_trace_count, num=5, dtype=np.int32))
    ax.imshow(
        gather_image,
        interpolation="none",
        aspect="auto",
    )
fig.tight_layout()
plt.show()

# note that black borders in the displayed images are normal, these correspond to necessary padding!
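The x-axis tick labels in the cell above rely on a linear mapping between sample (pixel) indices and time, `time_ms = index * sample_rate_ms`. The two hypothetical helpers below implement that mapping in both directions. They assume that the first sample of each trace is recorded at time zero (no recording delay). Under that assumption, the same mapping would relate `first_break_labels` to `first_break_timestamps`, but we have not verified this against the API:

```python
import numpy as np


def sample_idx_to_ms(sample_idxs, sample_rate_ms):
    """Convert sample (pixel) indices along the time axis into milliseconds."""
    return np.asarray(sample_idxs, dtype=np.float64) * sample_rate_ms


def ms_to_sample_idx(times_ms, sample_rate_ms):
    """Convert times in milliseconds into the nearest sample (pixel) index."""
    return np.rint(np.asarray(times_ms, dtype=np.float64) / sample_rate_ms).astype(np.int64)


# synthetic example with a 2 ms sampling interval (not necessarily Lalor's actual value)
times = sample_idx_to_ms([0, 500, 1023], sample_rate_ms=2.0)
```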

In [None]:
# we'll do a quick intersection test to show that the train/valid parsers load disjoint gathers
train_gather_ids = [
    lalor_train_parser.get_meta_gather(idx)["gather_id"]
    for idx in range(len(lalor_train_parser))
]
valid_gather_ids = [
    lalor_valid_parser.get_meta_gather(idx)["gather_id"]
    for idx in range(len(lalor_valid_parser))
]
assert len(np.intersect1d(train_gather_ids, valid_gather_ids)) == 0, \
    "there is some overlap between the unique gather IDs of the two dataset parsers? oh-oh..."
print("All good, no intersection found!")


## Model Creation

Once the data is loaded and ready to be used, we can create a model and prepare it for training. In
theory, any kind of CNN with a semantic segmentation (encoder-decoder) setup should do the trick.
We will use a fairly well-known and flexible U-Net architecture here.
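A U-Net's skip connections concatenate encoder and decoder feature maps of matching sizes, so the input height and width must survive repeated halving. The sketch below assumes that each of the `encoder_block_count` blocks halves the spatial resolution; we have not verified this for the `resnet18` encoder used here. The sketch also hints at why power-of-2 padding is convenient: a padded size that is at least `2 ** encoder_block_count` always divides cleanly (`encoder_feature_shapes` is a hypothetical helper):

```python
def encoder_feature_shapes(height, width, block_count):
    """Spatial shape after each encoder block, assuming every block halves H and W."""
    shapes = []
    for _ in range(block_count):
        if height % 2 or width % 2:
            raise ValueError(f"cannot cleanly halve shape {(height, width)}")
        height, width = height // 2, width // 2
        shapes.append((height, width))
    return shapes


# a padded 64 x 1024 gather goes through 5 blocks without any rounding issue
shapes = encoder_feature_shapes(64, 1024, block_count=5)
```

The decoder then mirrors these shapes in reverse, upsampling at each step and concatenating with the encoder map of the same size. An unpadded 50 x 700 gather would already fail at the second block, because 50 / 2 = 25 is odd.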

The important thing to note here is that since we're using PyTorch Lightning, the model has to be
derived from PyTorch Lightning's [LightningModule class](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html).
This interface gives the model the responsibility of implementing what to do during both the
training and validation steps. This means that it should manage its own loss function and define
which optimizer and scheduler combo to use. Therefore, the hyperparameters for all these settings
have to be provided to the model class's constructor, as shown below.

In [None]:
model_config = {
    # hyperparameters that define the architecture of the model:
    "unet_encoder_type": "resnet18",  # our CNN encoder will be based on the ResNet18 architecture
    "unet_decoder_type": "vanilla",  # our CNN decoder will be based on the original U-Net architecture
    "encoder_block_count": 5,  # the number of blocks that are defined by the encoder architecture
    "mid_block_channels": 0,  # channel count of the (optional) extra block between encoder and decoder (0 = none)
    "decoder_block_channels": "[256, 128, 64, 32, 16]",  # the depth of the 1st conv layer in each decoder block
    "decoder_attention_type": None,  # the type of attention layer to use in the decoder (experimental!)
    "segm_class_count": segmentation_class_count,  # the segmentation class setup to use (1, 2, or 3)

    # hyperparameters that define what data to use as input from the loaded minibatches:
    "use_dist_offsets": True,  # toggles whether to use offset distances as extra input channels
    "use_first_break_prior": False,  # toggles whether to use fbp priors as an extra input channel (experimental!)
    "coordconv": False,  # toggles whether to use tensor space coordinates as extra input channels (experimental!)

    # hyperparameters that define the optimizer/scheduler/loss setups:
    "optimizer_type": "Adam",  # will use the Adam optimizer defined by PyTorch
    "optimizer_params": {  # specifies the arguments to be passed to the optimizer's constructor
        "lr": 0.002136,
        "weight_decay": 0.000001,
    },
    "scheduler_type": "StepLR",  # will use the StepLR scheduler defined by PyTorch
    "scheduler_params": {  # specifies the arguments to be passed to the scheduler's constructor
        "step_size": 10,
        "gamma": 0.1,
    },
    "update_scheduler_at_epochs": True,  # specifies that the scheduler should be updated each epoch
    "loss_type": "crossentropy",  # cross-entropy loss (binary here, since we use a single segmentation class)
    "loss_params": {},  # specifies the arguments to be passed to the loss function's constructor

    # hyperparameters that define the evaluator configuration (for metrics):
    "use_full_metrics_during_training": False,  # toggles whether to also evaluate all metrics during training
    "eval_type": "FBPEvaluator",  # type of the evaluator that will be instantiated
    "segm_first_break_prob_threshold": 0.,  # minimum sensitivity threshold for first break predictions
    "eval_metrics": [  # list of metrics that should be evaluated during validation (we'll use only two)
        {"metric_type": "HitRate", "metric_params": {"buffer_size_px": 1}},
        {"metric_type": "MeanBiasError"},
    ],

    # other hyperparameters:
    "gathers_to_display": 10,  # specifies the number of gathers to render for tensorboard each epoch
    "use_checkpointing": False,  # specifies whether to use gradient checkpointing to lower GPU memory usage
    "max_epochs": 15,  # maximum number of epochs that training should run for (matches the trainer below)
}

# let's instantiate the model using that config now!
model = fbp_unet.FBPUNet(model_config)

# last step: we'll give a reference to the tensorboard logger we created earlier to the model...
setattr(model, "_tbx_logger", tbx_logger)

param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Instantiated a U-Net model from scratch with {(param_count / 1000000):2.1f}M parameters.")
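With `scheduler_type` set to `StepLR` and `update_scheduler_at_epochs` enabled, the learning rate should follow PyTorch's closed-form StepLR rule, `lr = lr0 * gamma ** (epoch // step_size)`, assuming the scheduler is stepped exactly once per epoch. The helper below is our own re-implementation for illustration, not code from the API. It makes one consequence easy to see: with `step_size=10`, the decay only kicks in at epoch 10, so a run that stops earlier never sees the reduced learning rate:

```python
def step_lr(initial_lr, step_size, gamma, epoch):
    """Learning rate at a 0-based epoch under a StepLR-style schedule (no warm-up)."""
    return initial_lr * gamma ** (epoch // step_size)


# same values as in `model_config` above
lr_per_epoch = [step_lr(0.002136, step_size=10, gamma=0.1, epoch=e) for e in range(15)]
print(lr_per_epoch[9], lr_per_epoch[10])
```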



## Training

For the training itself, we also delegate the loop logic to PyTorch Lightning. The [trainer
object](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) receives the data
loaders and model, and it takes care of implementing the iterative forward/backward steps required
to update model parameters and evaluate its performance. Below, we show the minimal code required
to start this process; in the API, this is done in the `hardpicks/train.py`
module.
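Conceptually, `trainer.fit` combined with a `ModelCheckpoint` callback that monitors a metric in `mode="max"` boils down to the loop below. This is a heavily simplified, hypothetical sketch (`run_fit_loop` is not a Lightning function). It ignores batching, logging, sanity checks, multi-GPU logic, and the actual saving of checkpoint files:

```python
def run_fit_loop(max_epochs, train_one_epoch, validate):
    """Toy fit loop: train, validate, and remember the best epoch (mode='max')."""
    best_score, best_epoch = float("-inf"), None
    for epoch in range(max_epochs):
        train_one_epoch(epoch)  # forward/backward passes + optimizer steps over all minibatches
        score = validate(epoch)  # e.g. the validation hit rate
        if score > best_score:  # a real checkpoint callback would save the model here
            best_score, best_epoch = score, epoch
    return best_epoch, best_score


# synthetic validation scores, used only to exercise the loop
fake_scores = [0.10, 0.50, 0.40]
best_epoch, best_score = run_fit_loop(3, lambda epoch: None, lambda epoch: fake_scores[epoch])
```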

In [None]:
# if a GPU is available in the current environment, we'll use it
if torch.cuda.device_count():
    print("Will train on GPU.")
else:
    print("Will train on CPU.")

# if tensorboard is installed in your environment, you can launch it with the printed command!
print(f"Tensorboard can now be launched with:\n\t tensorboard --logdir {tensorboard_output_path}")

# this callback saves the 'best' model based on the 1px hit rate that the model logs during validation
checkpoint_callback = pytorch_lightning.callbacks.ModelCheckpoint(
    dirpath=output_root_path,
    filename="best-{epoch:03d}-{step:06d}",
    monitor="valid/HitRate1px",
    mode="max",
)

# finally, we can create the trainer object...
trainer = pytorch_lightning.Trainer(
    logger=tbx_logger,
    callbacks=[checkpoint_callback],
    gpus=int(bool(torch.cuda.device_count())),
    max_epochs=model_config["max_epochs"],  # reuse the model's setting to keep both in sync
)
# ... and start training!
trainer.fit(
    model,
    train_data_loader,  # passed positionally: the keyword name differs across Lightning versions
    valid_data_loader,
)

# if we get here, training is done! (the 'fit' call above will block)
best_model_path = os.path.abspath(checkpoint_callback.best_model_path)
print(f"Best model is saved at: {best_model_path}")

# to restore the best model and use it to generate predictions, we would call the classmethod:
# model = fbp_unet.FBPUNet.load_from_checkpoint(best_model_path)

...now that you've got a trained model on a particular site, the next step would be to test it
on another site! See the `fbp_predict_with_api.ipynb` notebook in the same folder for an example.

There are two things we need to highlight about the kind of experiment we just ran here:
 - **We trained a model with a set of arbitrary hyperparameters that are likely suboptimal.** Doing
   a proper hyperparameter search (using e.g. Orion) is highly recommended when trying to obtain
   the maximum performance on a particular dataset. Using an off-the-shelf recipe (like we did here)
   can offer a good starting point, but every experiment deserves its own fine-tuning.
 - **We trained and validated on the same survey site, and did not test the final 'best' model.**
   The performance reported in TensorBoard may therefore be inflated by overfitting to patterns
   shared across the shots and receiver lines of Lalor. This is why it is best to train on one
   site, validate on another, and finally test the 'best' models on a third one.
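When more than three sites are available, one way to organize such experiments is to rotate the roles so that every site serves as the test site once. The sketch below is a hypothetical helper: the site names are placeholders, not actual survey sites. All remaining sites are used for training, which generalizes the "train on one site" setup:

```python
def rotating_site_splits(site_names):
    """Return (train_sites, valid_site, test_site) tuples with no site shared between roles."""
    if len(site_names) < 3:
        raise ValueError("need at least 3 sites to keep train/valid/test disjoint")
    splits = []
    for idx, test_site in enumerate(site_names):
        valid_site = site_names[(idx + 1) % len(site_names)]  # the next site validates
        train_sites = [s for s in site_names if s not in (test_site, valid_site)]
        splits.append((train_sites, valid_site, test_site))
    return splits


splits = rotating_site_splits(["SiteA", "SiteB", "SiteC", "SiteD"])
```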