# First Break Picking Prediction Demo

This notebook shows how to generate predictions for a new dataset using a pretrained model. Before
working through it, we strongly suggest running the training notebook (`fbp_train_with_api.ipynb`)
located in the same folder: by default, the cells below will use the model trained by that notebook.

Note that we do not retrain the provided model at all here, and assume that the new dataset is
"inside the distribution" of the training data that was previously used. If this is not the case,
the predictions may be completely useless.

Finally, note that this example will once again not use an external configuration file. Instead, it
will directly invoke the functions and class constructors that would use the content of such a
configuration file. This will help clarify the link between the content of these files, the role
of each parameter, and the step where they are involved.

In [None]:
# all imports are centralized here (helps identify environment issues before doing anything else!)

# these packages are part of the 'standard' library, and should be available in all environments
import functools
import glob
import os

# these packages are 3rd-party dependencies, and must be installed via pip or anaconda
import matplotlib.pyplot as plt
import torch.utils.data
import tqdm

# these packages are part of our API and must be manually installed (see the top-level README.md)
import hardpicks
import hardpicks.data.fbp.data_module as fbp_data_module
import hardpicks.data.fbp.gather_transforms as fbp_data_transforms
import hardpicks.metrics.fbp.utils as metrics_utils
import hardpicks.models.fbp.utils as model_utils
import hardpicks.models.fbp.unet as fbp_unet
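You may want the environment check to fail with a friendlier message than an `ImportError` traceback. In that case, a small pre-flight cell can list the missing top-level packages before anything is imported. This is a minimal sketch that uses only the standard library. The helper name `find_missing_packages` is hypothetical and not part of `hardpicks`.

```python
import importlib.util


def find_missing_packages(package_names):
    """Returns the subset of top-level package names that cannot be found in this environment."""
    missing = []
    for name in package_names:
        # for top-level names, find_spec returns None (without importing anything) if not installed
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


# the 3rd-party and API packages imported by this notebook
required = ["matplotlib", "torch", "tqdm", "hardpicks"]
missing = find_missing_packages(required)
if missing:
    print(f"Missing packages (install them before continuing): {missing}")
else:
    print("All required packages were found.")
```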

## Model Reloading

Instantiating a pretrained model is usually done in two steps: first, we need to use the
configuration of the model to recreate an identical copy of the network architecture that
was trained. Second, we need to ask PyTorch to reload the model's weights (i.e. the parameters
that were fitted during training) into that new copy of the model.

Here, since we used PyTorch-Lightning for training, these two steps are merged into one. Under the
hood, PyTorch-Lightning stores a copy of the model's configuration directly inside the checkpoint,
next to the weights. Thus, a single call is all we need:

In [None]:
pretrained_model_path = "output/notebook_train_example/"  # this path is the one from the 1st demo!

print(f"Parsing pretrained model artifacts from: {pretrained_model_path}")
assert os.path.isdir(pretrained_model_path), \
    f"invalid pretrained model directory path: {pretrained_model_path}"
# the pretrained model weights file (or "checkpoint") has a name that varies a bit, we'll glob it
model_ckpt_path_pattern = os.path.join(pretrained_model_path, "best*.ckpt")
# glob returns paths in arbitrary order, so we sort them to make the selection below deterministic
model_ckpt_paths = sorted(glob.glob(model_ckpt_path_pattern))
assert len(model_ckpt_paths) >= 1, \
    f"could not locate at least one 'best' checkpoint using: {model_ckpt_path_pattern}"
model_ckpt_path = model_ckpt_paths[-1]  # arbitrary: we'll keep the last (sorted) one if there are many
assert os.path.isfile(model_ckpt_path), \
    f"invalid checkpoint path: {model_ckpt_path}"

# alright, time to reinstantiate the model!
model = fbp_unet.FBPUNet.load_from_checkpoint(model_ckpt_path)
model.eval()  # switch layers such as dropout and batch norm to their inference behavior

param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Reinstantiated a pretrained U-Net model with {(param_count / 1000000):2.1f}M parameters.")
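The single `load_from_checkpoint` call works because the checkpoint bundles both pieces of information that the two-step approach needs. The pure-Python sketch below mimics that idea with a hypothetical `TinyModel` class and an in-memory checkpoint dictionary. The `"hyper_parameters"` and `"state_dict"` key names follow the layout that PyTorch-Lightning uses in its checkpoint files. Everything else here is illustrative and is not the actual Lightning or `hardpicks` implementation.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a network: its architecture is fully defined by its config."""

    def __init__(self, input_size, output_size):
        self.hparams = {"input_size": input_size, "output_size": output_size}
        # freshly-initialized weights (all zeros here to keep things simple)
        self.weights = [[0.0] * input_size for _ in range(output_size)]

    def state_dict(self):
        return {"weights": copy.deepcopy(self.weights)}

    def load_state_dict(self, state_dict):
        # shapes must match, which is why the architecture must be rebuilt identically first
        expected = (self.hparams["output_size"], self.hparams["input_size"])
        actual = (len(state_dict["weights"]), len(state_dict["weights"][0]))
        if actual != expected:
            raise ValueError(f"weight shape mismatch: expected {expected}, got {actual}")
        self.weights = copy.deepcopy(state_dict["weights"])

    @classmethod
    def load_from_checkpoint(cls, checkpoint):
        # step 1: rebuild the architecture from the stored config
        model = cls(**checkpoint["hyper_parameters"])
        # step 2: reload the fitted parameters into that new copy
        model.load_state_dict(checkpoint["state_dict"])
        return model


# "training" produces a model with fitted weights
trained = TinyModel(input_size=3, output_size=2)
trained.weights = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

# the checkpoint stores both the config and the weights...
checkpoint = {"hyper_parameters": trained.hparams, "state_dict": trained.state_dict()}

# ... so a single call is enough to get an identical copy back
reloaded = TinyModel.load_from_checkpoint(checkpoint)
print(reloaded.hparams, reloaded.weights)
```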

## Data Loading

In this case, we will load the Denare Beach data from scratch, and generate predictions for it.
Luckily, since the Denare Beach HDF5 archives are structured the same way as the HDF5 archives
we previously used, we can once again rely on our gather parser implementation. The only new thing
we need to define is a "site info" dictionary that will provide basic info such as where the data
is located.

In [None]:
dataset_root_path = hardpicks.FBP_DATA_DIR  # change this path if your data is located elsewhere!
denare_dir_path = os.path.join(dataset_root_path, "Denare_2D")
denare_hdf5_path = os.path.join(denare_dir_path, "Denare_beach_dynamite_geom_2s_for_Mila.hdf")
assert os.path.isfile(denare_hdf5_path), \
    f"could not locate the Denare Beach HDF5 file at: {denare_hdf5_path}"
print(f"Will parse HDF5 data from: {denare_hdf5_path}")

# as long as the HDF5 structure matches the one used previously, we can reuse the existing parser!
# ... we just need to define the proper site info dictionary and config, and pass them to the API
test_site_info = {
    "site_name": "Denare",  # the name used to identify this site in logs/tables/plots/etc.
    "raw_hdf5_path": denare_hdf5_path,  # the path where the raw hdf5 file can be found
    "processed_hdf5_path": denare_hdf5_path,  # we won't be preprocessing the raw data, so same as above
    "receiver_id_digit_count": 3,  # this is used to decompose pegs into unique receiver identifiers
    "first_break_field_name": "SPARE1",  # specified in case the dataset provides multiple picks per trace
    "raw_md5_checksum": "332874c28971ab8029ac52bb9480fe0f",  # to make sure we're not using corrupted data
    "processed_md5_checksum": "332874c28971ab8029ac52bb9480fe0f",  # same as above!
}
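The two checksum fields let the parser detect a corrupted or truncated copy of the archive. A check of that kind can be sketched with the standard library alone. The file is hashed in chunks, so multi-gigabyte HDF5 files never need to fit in memory, and the hex digest is compared with the expected value. The `compute_md5` helper below is hypothetical. It only illustrates the idea and is not necessarily how `hardpicks` performs this validation.

```python
import hashlib


def compute_md5(file_path, chunk_size=1 << 20):
    """Computes the MD5 hex digest of a file by reading it in chunks (1 MiB by default)."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as fd:
        # iter() with a sentinel keeps reading until read() returns an empty bytes object
        for chunk in iter(lambda: fd.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


# usage with the site info dictionary defined above (this can take a while on a large file!):
# assert compute_md5(test_site_info["raw_hdf5_path"]) == test_site_info["raw_md5_checksum"]
```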

The rest of the data loading from now on is similar to what was done in the training demo notebook.

**NOTE**: it is important to remember that if the pretrained model used a particular set of
preprocessing operations to prepare its input data, we should be using the SAME operations here.
Deviating from these would take the input data "out-of-distribution", and the predictions would
suffer!
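One lightweight way to enforce this is to keep the preprocessing parameters used at training time alongside the model. They can then be compared with the parameters about to be used for inference, and any mismatch can fail loudly. The sketch below assumes that both sets of parameters are plain dictionaries. The helper name and the example values are hypothetical.

```python
def find_preprocessing_mismatches(train_params, test_params):
    """Returns {key: (train_value, test_value)} for every parameter that differs or is missing."""
    mismatches = {}
    for key in sorted(set(train_params) | set(test_params)):
        train_val = train_params.get(key, "<missing>")
        test_val = test_params.get(key, "<missing>")
        if train_val != test_val:
            mismatches[key] = (train_val, test_val)
    return mismatches


# hypothetical example: normalization was enabled during training, but forgotten at test time
train_params = {"normalize_samples": True, "provide_offset_dists": True}
test_params = {"normalize_samples": False, "provide_offset_dists": True}
print(find_preprocessing_mismatches(train_params, test_params))
```

In practice, a non-empty result would be turned into an exception before any prediction is made.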

In [None]:
generic_site_params = dict(
    convert_to_fp16=True,  # convert trace sample data to 16-bit floats (saves memory!)
    convert_to_int16=True,  # same as above, but for identifiers, picks, and other integer data
    preload_trace_data=True,  # we'll put everything in memory right away (should be <5GB)
    cache_trace_metadata=True,
    provide_offset_dists=True,  # finally, we'll generate new offset distance arrays/maps
)

test_data_parser = fbp_data_module.FBPDataModule.create_parser(
    site_info=test_site_info,
    site_params={
        "use_cache": False,
        "normalize_samples": True,  # this was used during training, we need to reuse it here too!
    },
    prefix="test",
    dataset_hyper_params=generic_site_params,
    segm_class_count=model.segm_class_count,
)
print(f"Test dataset parser ready with {len(test_data_parser)} gathers!")

test_data_loader = torch.utils.data.DataLoader(
    dataset=test_data_parser,
    batch_size=6,
    shuffle=False,
    num_workers=2,
    collate_fn=functools.partial(
        fbp_data_module.fbp_batch_collate,
        pad_to_nearest_pow2=True,
    ),
)
print(f"Test data loader ready with {len(test_data_loader)} minibatches!")
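Gathers do not all contain the same number of traces, so a collate function must pad them to a common size before stacking them into a minibatch. Rounding that size up to a power of two is a common way to keep dimensions compatible with the repeated 2x downsampling of an encoder-decoder such as a U-Net. The numpy sketch below shows the general idea, with zero padding appended after the real traces and samples. The function names are hypothetical. The actual `fbp_batch_collate` implementation may differ, for example in which axes it pads or in the padding value.

```python
import numpy as np


def next_pow2(n):
    """Returns the smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def pad_and_stack_gathers(gathers):
    """Stacks 2D (trace_count, sample_count) arrays into one zero-padded 3D array.

    Padding is appended AFTER the real data, and both axes are rounded up to a power of two.
    Also returns the original trace counts so that the padding can be ignored later on.
    """
    trace_counts = [g.shape[0] for g in gathers]
    max_traces = next_pow2(max(trace_counts))
    max_samples = next_pow2(max(g.shape[1] for g in gathers))
    batch = np.zeros((len(gathers), max_traces, max_samples), dtype=gathers[0].dtype)
    for idx, gather in enumerate(gathers):
        batch[idx, :gather.shape[0], :gather.shape[1]] = gather
    return batch, np.asarray(trace_counts)


gathers = [np.ones((5, 100), dtype=np.float32), np.ones((12, 100), dtype=np.float32)]
batch, trace_counts = pad_and_stack_gathers(gathers)
print(batch.shape, trace_counts.tolist())  # (2, 16, 128) [5, 12]
```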

In [None]:
# once again, let's display the first minibatch of line gathers (the loader does not shuffle) as 2D images
fig, axes = plt.subplots(6, figsize=(12, 12))
minibatch = next(iter(test_data_loader))
for gather_idx in range(6):
    # we'll use some utility functions that are already-written to convert gathers into images
    gather_image = model_utils.generate_pred_image(
        # note: the provided first break picks will be shown in green
        batch=minibatch,
        raw_preds=None,
        batch_gather_idx=gather_idx,
        segm_class_count=model.segm_class_count,
        segm_first_break_prob_threshold=0.,
        draw_prior=False,
    )
    ax = axes[gather_idx]
    ax.imshow(
        gather_image,
        interpolation="none",
        aspect="auto",
    )
fig.tight_layout()
plt.show()

## Prediction

To generate predictions, we could delegate everything to PyTorch-Lightning (and deal with how
it expects to receive/return the data), or call the underlying prediction functions directly
(PyTorch-style). Here, we will do the latter, but using the input tensor preparation function
that's already implemented with the model, and that PyTorch-Lightning would also rely on.

In [None]:
if torch.cuda.device_count():
    print("Will predict on GPU.")
    model = model.cuda()
else:
    print("Will predict on CPU.")

model.test_evaluator.reset()  # each time this cell is executed, we'll reset the evaluator...

loader_wrapper = tqdm.tqdm(test_data_loader, total=len(test_data_loader))
for batch_idx, batch in enumerate(loader_wrapper):  # loops over all minibatches in the test data loader
    with torch.no_grad():  # since we don't want to do backpropagation and track gradients like in training
        input_tensor = model_utils.prepare_input_features(
            batch,
            use_dist_offsets=model.use_dist_offsets,
            use_first_break_prior=model.use_first_break_prior,
        ).to(model.device).float()
        predictions = model(input_tensor)  # calls the forward pass of the model
        # NOTE: since our pre-trained model is a segmentation model (i.e. a U-Net encoder-decoder),
        # the "predictions" are actually a stack of class score maps, one for each of the gathers in
        # the provided minibatch. The shape of the 'predictions' tensor is thus:
        #    predictions.shape = (BATCH_SIZE, CLASS_COUNT, TRACE_COUNT, SAMPLE_COUNT)
        # ... to get actual first break pick predictions from this map, we need to search for the
        # sample in each trace (row) that maximizes the score of the "first break" class. We provide
        # a function to do this:
        predicted_picks = metrics_utils.get_regr_preds_from_raw_preds(
            raw_preds=predictions,
            segm_class_count=model.segm_class_count,
            prob_threshold=0.01,  # this sets the "minimum bar" for the confidence in a first break!
        )
        # ... our metrics evaluator actually calls that function under the hood!
        metrics = model.test_evaluator.ingest(batch, batch_idx, predictions.detach())
        # ... if we wanted to do something else with the predictions, we would do it here!
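The conversion from class score maps to per-trace picks described in the comments above can be sketched in numpy. This is a hypothetical re-implementation for illustration only and is not the code of `get_regr_preds_from_raw_preds`. It assumes that the raw predictions are unnormalized logits with shape `(BATCH_SIZE, CLASS_COUNT, TRACE_COUNT, SAMPLE_COUNT)` and that the "first break" class sits at a known channel index. It also assumes that traces whose best probability falls below the threshold get no pick, marked with `-1`.

```python
import numpy as np


def picks_from_score_maps(logits, first_break_class_idx, prob_threshold=0.01):
    """Converts (B, C, T, S) class logits into (B, T) first break sample indices (-1 = no pick)."""
    # softmax over the class axis, shifted by the max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    # probability of the "first break" class at every (trace, sample) location: (B, T, S)
    fb_probs = probs[:, first_break_class_idx]
    # for each trace (row), find the sample that maximizes that probability
    best_sample = fb_probs.argmax(axis=-1)
    best_prob = fb_probs.max(axis=-1)
    # reject picks that do not clear the minimum confidence bar
    return np.where(best_prob >= prob_threshold, best_sample, -1)


# toy example: 1 gather, 2 classes, 2 traces, 4 samples; class #1 is the "first break" class
logits = np.full((1, 2, 2, 4), -5.0)
logits[0, 0] = 5.0          # class #0 dominates everywhere...
logits[0, 1, 0, 2] = 10.0   # ... except at sample #2 of trace #0
print(picks_from_score_maps(logits, first_break_class_idx=1, prob_threshold=0.5).tolist())
```

With a threshold of 0.5, trace #0 gets a pick at sample #2. Trace #1 never shows a confident first break, so it gets no pick.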

A useful part of the already-implemented model objects is that they contain a test set evaluator
with the metrics we were already using during training. This evaluator has been provided with the
model predictions at every step of the loop above, meaning we can just print how well the model did!

**NOTE**: this assumes that the ground-truth data was also packaged by the data loader as part of
each minibatch, and that it can be used by the `ingest` function.

In [None]:
test_results_map = model.test_evaluator.summarize()
print(f"Test results for {test_site_info['site_name']}:")
for key, val in test_results_map.items():
    print(f"\t{key}: {val}")
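The `reset()` / `ingest()` / `summarize()` cycle used above is a common accumulate-then-reduce pattern. A minimal sketch of such an evaluator is shown below. The class and its two metrics, the mean absolute error in samples and the fraction of picks within a tolerance, are hypothetical examples. They are not the metrics that `model.test_evaluator` actually reports. The sketch also assumes that missing ground-truth picks are marked with `-1` and skipped.

```python
import numpy as np


class MinimalPickEvaluator:
    """Hypothetical evaluator that accumulates pick errors across minibatches."""

    def __init__(self, tolerance=1):
        self.tolerance = tolerance  # max error (in samples) for a pick to count as a hit
        self.reset()

    def reset(self):
        self.errors = []

    def ingest(self, target_picks, predicted_picks):
        target_picks = np.asarray(target_picks).ravel()
        predicted_picks = np.asarray(predicted_picks).ravel()
        valid = target_picks >= 0  # skip traces without a ground-truth pick
        self.errors.extend(np.abs(predicted_picks[valid] - target_picks[valid]).tolist())

    def summarize(self):
        errors = np.asarray(self.errors, dtype=np.float64)
        return {
            "mae": float(errors.mean()),
            f"hit_rate@{self.tolerance}": float((errors <= self.tolerance).mean()),
        }


evaluator = MinimalPickEvaluator(tolerance=1)
evaluator.ingest(target_picks=[10, 20, -1], predicted_picks=[11, 25, 7])  # minibatch #1
evaluator.ingest(target_picks=[30], predicted_picks=[30])                # minibatch #2
print(evaluator.summarize())
```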

Finally, let's display predictions on top of the images we were showing earlier...

In [None]:
fig, axes = plt.subplots(6, figsize=(12, 24))
minibatch = next(iter(test_data_loader))
with torch.no_grad():
    # a single forward pass generates predictions for all gathers of the minibatch at once
    input_tensor = model_utils.prepare_input_features(
        minibatch,
        use_dist_offsets=model.use_dist_offsets,
        use_first_break_prior=model.use_first_break_prior,
    ).to(model.device).float()
    predictions = model(input_tensor)  # calls the forward pass of the model
for gather_idx in range(6):
    gather_image = model_utils.generate_pred_image(
        # note: the provided first break picks will be shown in green, predictions in red
        batch=minibatch,
        raw_preds=predictions,
        batch_gather_idx=gather_idx,
        segm_class_count=model.segm_class_count,
        segm_first_break_prob_threshold=0.,
        draw_prior=False,
        draw_prob_heatmap=False,
    )
    ax = axes[gather_idx]
    ax.imshow(
        gather_image,
        interpolation="none",
        aspect="auto",
    )
fig.tight_layout()
plt.show()


**Bonus**: as mentioned in our report, there are many ways to improve the quality of first break
picking predictions. Most of them are related to the training of the model, so we cannot apply them
here. One test-time improvement we can apply, however, is ensembling. Since we only have a single
model, we cannot create a conventional "model ensemble"; instead, we will create an ensemble based
on augmentations of the input gathers.

To keep things simple, we will only apply a single augmentation operation to our input gathers,
namely a trace-wise flip (i.e. reversing the order of the traces in each gather). This creates a
pair of inputs whose predictions will be averaged, resulting in (hopefully) a small boost in
performance.

In [None]:
# we will create a 2nd data loader with a slightly modified collate function...
def flip_then_collate(list_of_gathers):
    for gather in list_of_gathers:
        # we flip gathers using an already implemented function that also flips metadata as needed!
        fbp_data_transforms.flip(gather)
    return fbp_data_module.fbp_batch_collate(
        list_of_gathers,
        pad_to_nearest_pow2=True,
    )

tta_data_loader = torch.utils.data.DataLoader(
    dataset=test_data_parser,
    batch_size=6,
    shuffle=False,
    num_workers=2,
    collate_fn=flip_then_collate,
)

# important note: we add padding AFTER the flip, so we'll need to be careful when unflipping below!

model.test_evaluator.reset()  # restart with a fresh evaluator
loader_wrapper = tqdm.tqdm(
    zip(test_data_loader, tta_data_loader),
    total=len(test_data_loader),
)
for batch_idx, (batch, batch_flip) in enumerate(loader_wrapper):
    with torch.no_grad():
        # the disadvantage to test-time-augmentation: we need to call the model multiple times...
        input_tensor = model_utils.prepare_input_features(
            batch,
            use_dist_offsets=model.use_dist_offsets,
            use_first_break_prior=model.use_first_break_prior,
        ).to(model.device).float()
        predictions = model(input_tensor).detach()

        # once again for the flipped batches...
        input_tensor_flip = model_utils.prepare_input_features(
            batch_flip,
            use_dist_offsets=model.use_dist_offsets,
            use_first_break_prior=model.use_first_break_prior,
        ).to(model.device).float()
        predictions_flip = model(input_tensor_flip).detach()

        assert predictions_flip.shape == predictions.shape

        # now we only need to combine the two prediction maps into a single one!
        # (combining classification scores can be easily done by averaging!)

        predictions_unflip = torch.flip(predictions_flip, [2])  # dim#2 = trace axis
        for gather_idx in range(len(predictions_flip)):
            # note: we need to unflip the right traces without touching the padding...
            expected_trace_count = batch_flip["trace_count"][gather_idx].item()
            expected_padding_count = predictions_flip.shape[2] - expected_trace_count
            predictions[gather_idx, :, 0:expected_trace_count, :] = torch.mean(
                torch.stack([
                    predictions[gather_idx, :, 0:expected_trace_count, :],
                    predictions_unflip[gather_idx, :, expected_padding_count:, :]
                ]),
                dim=0,
            )

        metrics = model.test_evaluator.ingest(batch, batch_idx, predictions)

# finally, let's see if the results improved!

tta_results_map = model.test_evaluator.summarize()
print(f"Test results (with TTA) for {test_site_info['site_name']}:")
for key, val in tta_results_map.items():
    diff = val - test_results_map[key]
    print(f"\t{key}: {val}  ({diff:+1.4f} difference with original)")
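The trickiest part of the cell above is the index bookkeeping. Padding is appended after the traces are flipped, so flipping the prediction map back moves that padding to the front. The small numpy sketch below reproduces that logic and checks that each real trace ends up aligned with its unflipped counterpart. It uses a hypothetical helper name and works on a single gather of shape `(CLASS_COUNT, PADDED_TRACE_COUNT, SAMPLE_COUNT)`, assuming the padding is appended at the end of the trace axis.

```python
import numpy as np


def average_with_flipped(preds, preds_flip, trace_count):
    """Averages a gather's predictions with those of its trace-flipped (then padded) copy.

    Both arrays have shape (CLASS_COUNT, PADDED_TRACE_COUNT, SAMPLE_COUNT), and the padding
    occupies the last (PADDED_TRACE_COUNT - trace_count) trace rows of each one.
    """
    padding_count = preds_flip.shape[1] - trace_count
    # flipping the padded map moves the padding to the front, ahead of the real traces
    preds_unflip = np.flip(preds_flip, axis=1)[:, padding_count:, :]
    output = preds.copy()
    output[:, :trace_count, :] = (preds[:, :trace_count, :] + preds_unflip) / 2
    return output


# toy gather: 3 real traces padded to 4, with 1 class and 2 samples per trace
trace_values = np.array([1.0, 2.0, 3.0])
preds = np.zeros((1, 4, 2))
preds[0, :3, :] = trace_values[:, None]  # trace #i holds the value i+1
preds_flip = np.zeros((1, 4, 2))
preds_flip[0, :3, :] = trace_values[::-1, None] * 3  # flipped copy, scaled so the average is visible
print(average_with_flipped(preds, preds_flip, trace_count=3)[0, :, 0].tolist())  # [2.0, 4.0, 6.0, 0.0]
```

Each real trace averages its own value with three times that value, which gives twice the original, while the padded row stays untouched. If the padding offset were ignored, the traces would be averaged with their neighbors instead.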